Mirror of https://github.com/ckaczor/azuredatastudio.git (synced 2026-02-11 10:38:31 -05:00)
Merge vscode source through release 1.79.2 (#23482)
* log when an editor action doesn't run because of enablement
* notebooks create/dispose editors. this means controllers must be created eagerly (😢) and that notebooks need a custom way of plugging comparison keys for session. works unless creating another session for the same cell of a duplicated editor
* Set offSide to sql lang configuration to true (#183461)
* Fixes #181764 (#183550)
* fix typo
* Always scroll down and focus the input (#183557)
* Fixes #180386 (#183561)
* cli: ensure ordering of rpc server messages (#183558)
* cli: ensure ordering of rpc server messages. Sending lots of messages to a stream would block them around the async tokio mutex, which is "fair" so doesn't preserve ordering. Instead, use the write_loop approach I introduced to the server_multiplexer for the same reason some time ago.
* fix clippy
* update for May endgame
* testing: allow invalidateTestResults to take an array (#183569)
* Document `ShareProvider` API proposal (#183568)
* Document `ShareProvider` API proposal
* Remove mention of VS Code from JSDoc
* Add support for rendering svg and md in welcome message (#183580)
* Remove toggle setting more eagerly (#183584)
* rm message abt macOS
* Change text (#183589)
* Change text
* Accidentally changed the wrong file
* cli: improve output for code tunnel status (#183571)
* testing: allow invalidateTestResults to take an array
* cli: improve output for code tunnel status. Fixes #183570
* [json/css/html] update services (#183595)
* Add experimental setting to enable this dialog
* Fix exporting chat model to JSON before it is initialized (#183597)
* minimum scrolling to reveal the next cell on shift+enter (#183600): do minimum scrolling to reveal the next cell on Execute cell and select next
* Fixing Jupyter notebook issue 13263 (#183527): fix for the issue, still need to understand why there is strange focusing
* Tweak proposed API JSDoc (#183590)
* Tweak proposed API JSDoc
* workbench -> workspace
* fix ? operator
* Use active editor and show progress when sharing (#183603): use active editor and show progress
* use scroll setting variable correctly
* Schedule welcome widget to show once between typing. (#183606)
* Schedule dialog to show once between typing
* Don't re-render if already displayed once
* Add F10 keybinding for debugger step, even on Web. (#183510) Fixes #181792. Previously, for Web the keyboard shortcut was Alt-F10, because it was believed that F10 could not be bound on browsers. This turned out to be incorrect, so we make the shortcut consistent (F10) with desktop VSCode, which is also what many other debuggers use. We keep Alt-F10 on web as a secondary keybinding to keep the experience some web users may have gotten used to by now.
* Also pass process.env
* Restore missing chat clear commands (#183651)
* chore: update electron@22.5.4 (#183716)
* Show remote indicator in web when remoteAuthority is set (#183728)
* feat: .vuerc as json file (#153017) Co-authored-by: Martin Aeschlimann <martinae@microsoft.com>
* Delete --compatibility=1.63 code from the server (#183738)
* Copy vscode.dev link to tunnel generates an invalid link when an untitled workspace is open (#183739)
* Recent devcontainer display string corrupted on Get Started page (#183740)
* Improve "next codeblock" navigation (#183744)
* Improve "next codeblock" navigation: operate on the current focused response, or the last one, and scroll to the selected item
* Normalize command title
* Git - run git status if similarityThreshold changes (#183762)
* fix aria-label issue in kb editor; fixes "A11y_GradeB_VSCode_Keyboard shortcut reads words together - Blind: Arrow key navigation to row": the binding keys and "when" cell data are read together, resulting in a word "CTRL + FeditorFocus" instead of "CTRL + F editorFocus" #182490
* Status - fix compact padding (#183768)
* Remove angle brackets from VB brackets (#183782) Fixes #183359
* Update language config schema with more details about brackets. (#183779)
* fix comment (#183812)
* Support for `Notebook` CodeAction Kind (#183457)
* nb kind support -- wip
* allow notebook codeactions around single cell edit check
* move notebook code action type out of editor
--------- Co-authored-by: rebornix <penn.lv@gmail.com>
* cli: fix connection default being applied (#183827)
* cli: bump to openssl 1.1.1u (#183828)
* Implement "delete" action for chat history (#183609)
* Use desired file name when generating new md pasted file paths (#183861) Fixes #183851
* Default to filename for markdown new file if empty (#183864) Fixes #183848
* Fix small typo (#183865) Fixes #183819
* Noop when moving a symbol into the file it is already in (#183866) Fixes #183793
* Adjust codeAction validation to account for notebook kind (#183859)
* Make JS/TS `go to configuration` commands work on non-`file:` file systems (#183688): make `go to project` commands work on non-`file:` file systems. Fixes #183685
* Can't do regex search after opening notebook (#183884) Fixes #183858
* Default to current dir for `move to file` select (#183875) Fixes #183870. `showOpenDialog` seems to ignore `defaultUri` if the file doesn't exist
* Use `<...>` style markdown links when needed (#183876) Fixes #183849
* Remove check for context keys
* Update xterm package
* Enable updating a chat model without triggering incremental typing (#183894)
* Enable chat "move" commands on empty sessions (#183895)
* Enable chat "move" commands on empty sessions and also imported sessions
* Fix command name
* Fix some chat keybindings on windows (#183896)
* "Revert File" on inactive editors are ignored (fix #177557) (#183903)
* Empty reason while switching profile (fix #183775) (#183904)
* fix https://github.com/microsoft/vscode-internalbacklog/issues/4278 (#183910)
* fix https://github.com/microsoft/vscode/issues/183770 (#183914)
* code --status displays a lot of errors before actual status output (fix #183787) (#183915)
* joh/icy manatee (#183917)
* Use idle value for widget of interactive editor controller https://github.com/microsoft/vscode/issues/183820
* also make preview editors idle values https://github.com/microsoft/vscode/issues/183820
* Fix #183777 (#183929)
* Fix #182309 (#183925)
* Tree checkbox item -> items (#183931) Fixes #183826
* Fixes #183909 (#183940)
* Fix #183837 (#183943) fix #183837
* Git - fix #183941 (#183944)
* Update xterm.css. Fixes #181242
* chore: add @ulugbekna and @aiday-mar to my-endgame notebook (#183946)
* Revert "When snippet mode is active, make `Tab` not accept suggestion but advance placeholder". This reverts commit 50a80cdb61511343996ff1d41d0b676c3d329f48.
* revert not focusing completion list when quick suggest happens during snippet
* change `snippetsPreventQuickSuggestions` default to false
* Fix #181446 (#183956)
* fix https://github.com/microsoft/vscode-internalbacklog/issues/4298 (#183957)
* fix: remove extraneous incorrect context keys (#183959). These were actually getting added in getTestItemContextOverlay, and the test ID was using the extended ID which extensions do not know about. Fixes #183612
* Fixes https://github.com/microsoft/monaco-editor/issues/3920 (#183960)
* fix https://github.com/microsoft/vscode-internalbacklog/issues/4324 (#183961)
* fix #183030
* fix #180826 (#183962)
* make message more generic for interactive editor help
* .
* fix #183968
* Keep codeblock toolbar visible when focused
* Fix when clause on "Run in terminal" command
* add important info to help menu
* fix #183970
* Set `isRefactoring` for all TS refactoring edits (#183982)
* consolidate
* Disable move to file in TS versions < 5.2 (#183992). There are still a few key bugs with refactoring. We will ship this as a preview for TS 5.2+ instead of for 5.1
* Polish query accepting (#183995). We shouldn't send the same request to Copilot if the query hasn't changed, so if the query is the same, we short circuit. Fixes https://github.com/microsoft/vscode-internalbacklog/issues/4286. Also, when we open in chat, we should use the last accepted query, not what's in the input box. Fixes https://github.com/microsoft/vscode-internalbacklog/issues/4280
* Allow widget to have focus (#184000), so that selecting non-code text works. Fixes https://github.com/microsoft/vscode-internalbacklog/issues/4294
* Fix microsoft/vscode-internalbacklog#4257. Mitigate zindex for zone widgets. (#184001)
* Change welcome dialog contribution to Eventually
* Misc fixes
* Workspace folder picker entry descriptions are suboptimal for some filesystems (fix #183418) (#184018)
* cli - ignore std error unless verbose (#183787) (#184031)
* joh/inquisitive meerkat (#184034)
* only stash sessions that are non-empty https://github.com/microsoft/vscode-internalbacklog/issues/4281
* only unstash a session once - unless new exchanges are made, https://github.com/microsoft/vscode-internalbacklog/issues/4281
* account for all exchange types
* Improve declared components (#184039)
* make sure to read setting (#184040). d'oh, related to https://github.com/microsoft/vscode/issues/173387#issuecomment-1571696644
* [html] update service (#184049). Fixes #181176
* reset context keys on reset/hide (#184042). fixes https://github.com/microsoft/vscode-internalbacklog/issues/4330
* use `Lazy`, not `IdleValue` for the IE widget held by the eager controller (#184048) https://github.com/microsoft/vscode/issues/183820
* fix https://github.com/microsoft/vscode-internalbacklog/issues/4333 (#184067)
* use undo-loop instead of undo-edit when discarding chat session (#184063)
* use undo-loop instead of undo-edit when discarding chat session. fixes https://github.com/microsoft/vscode-internalbacklog/issues/4118
* fix tests, wait for correct state
* Add logging to node download (#184070). For #182951
* re-enable default zone widget revealing when showing (#184072). fixes https://github.com/microsoft/vscode-internalbacklog/issues/4332, also fixes https://github.com/microsoft/vscode-internalbacklog/issues/3784
* fix #178202
* Allow APIs in stable (#184062)
* Fix microsoft/vscode-internalbacklog#4206. Override List view whitespace css for monaco editor (#184087)
* Fix JSDoc grammatical error (#184090)
* Pick up TS 5.1.3 (#184091) Fixes #182931
* Misc fixes
* update distro (#184097)
* chore: update electron@22.5.5 (#184116)
* Extension host veto is registered multiple times on restart (fix #183778) (#184127)
* Do not auto start the local web worker extension host (#184137)
* Allow embedders to intercept trustedTypes.createPolicy calls (#184136) (#184100)
* fix: reading from console output for --status on windows and linux (#184138) (#184118)
* Misc fixes
* code --status displays a lot of errors before actual status output (fix #183787) (#184200)
* (cherry-pick to 1.79 from main) Handle galleryExtension failure in featuredExtensionService (#184205) (#184198)
* Fix #184183. Multiple output height updates are skipped. (#184188)
* Post merge init fixes
* Misc build issues
* disable toggle inline diff of `alt` down https://github.com/microsoft/vscode-internalbacklog/issues/4342
* Take into account already activated extensions when computing running locations (#184303) (fixes #184180)
* Avoid `extensionService.getExtension` and use `ActivationKind.Immediate` to allow that URI handling works while resolving (#184310) (fixes #182217)
* WIP
* rm fish auto injection
* More breaks
* Fix Port Attributes constructor (#184412)
* WIP
* WIP
* Allow extensions to get at the exports of other extensions during resolving (#184487) (fixes #184472)
* do not auto finish session when inline chat widgets have focus. re https://github.com/microsoft/vscode-internalbacklog/issues/4354
* fix compile errors caused by new base method
* WIP
* WIP
* WIP
* WIP
* Build errors
* unc - fix path traversal bypass
* Bump version
* cherry-pick prod changes from main
* Disable sandbox
* Build break from merge
* bump version
* Merge pull request #184739 from max06/max06/issue184659: Restore ShellIntegration for fish (#184659)
* Git - only add --find-renames if the value is not the default one (#185053) (#184992)
* Cherry-pick: Revert changes to render featured extensions when available (#184747) (#184573)
* Lower timeouts for experimentation and gallery service
* Revert changes to render extensions when available
* Add audio cues
* fix: disable app sandbox when --no-sandbox is present (#184913)
* fix: disable app sandbox when --no-sandbox is present (#184897)
* fix: loading minimist in packaged builds
* Runtime errors
* UNC allow list checks cannot be disabled in extension host (fix #184989) (#185085)
* UNC allow list checks cannot be disabled in extension host (#184989)
* Update src/vs/base/node/unc.js Co-authored-by: Robo <hop2deep@gmail.com>
--------- Co-authored-by: Robo <hop2deep@gmail.com>
* Add notebook extension
* Fix mangling issues
* Fix mangling issues
* npm install
* npm install
* Issues blocking bundle
* Fix build folder compile errors
* Fix windows bundle build
* Linting fixes
* Fix sqllint issues
* Update yarn.lock files
* Fix unit tests
* Fix a couple breaks from test fixes
* Bump distro
* redo the checkbox style
* Update linux build container dockerfile
* Bump build image tag
* Bump native watch dog package
* Bump node-pty
* Bump distro
* Fix documentation error
* Update distro
* redo the button styles
* Update datasource TS
* Add missing yarn.lock files
* Windows setup fix
* Turn off extension unit tests while investigating
* color box style
* Remove appx
* Turn off test log upload
* update dropdownlist style
* fix universal app build error (#23488)
* Skip flaky bufferContext vscode test
---------
Co-authored-by: Johannes <johannes.rieken@gmail.com>
Co-authored-by: Henning Dieterichs <hdieterichs@microsoft.com>
Co-authored-by: Julien Richard <jairbubbles@hotmail.com>
Co-authored-by: Charles Gagnon <chgagnon@microsoft.com>
Co-authored-by: Megan Rogge <merogge@microsoft.com>
Co-authored-by: meganrogge <megan.rogge@microsoft.com>
Co-authored-by: Rob Lourens <roblourens@gmail.com>
Co-authored-by: Connor Peet <connor@peet.io>
Co-authored-by: Joyce Er <joyce.er@microsoft.com>
Co-authored-by: Bhavya U <bhavyau@microsoft.com>
Co-authored-by: Raymond Zhao <7199958+rzhao271@users.noreply.github.com>
Co-authored-by: Martin Aeschlimann <martinae@microsoft.com>
Co-authored-by: Aaron Munger <aamunger@microsoft.com>
Co-authored-by: Aiday Marlen Kyzy <amarlenkyzy@microsoft.com>
Co-authored-by: rebornix <penn.lv@gmail.com>
Co-authored-by: Ole <oler@google.com>
Co-authored-by: Jean Pierre <jeanp413@hotmail.com>
Co-authored-by: Robo <hop2deep@gmail.com>
Co-authored-by: Yash Singh <saiansh2525@gmail.com>
Co-authored-by: Ladislau Szomoru <3372902+lszomoru@users.noreply.github.com>
Co-authored-by: Ulugbek Abdullaev <ulugbekna@gmail.com>
Co-authored-by: Alex Ross <alros@microsoft.com>
Co-authored-by: Michael Lively <milively@microsoft.com>
Co-authored-by: Matt Bierner <matb@microsoft.com>
Co-authored-by: Andrea Mah <31675041+andreamah@users.noreply.github.com>
Co-authored-by: Benjamin Pasero <benjamin.pasero@microsoft.com>
Co-authored-by: Sandeep Somavarapu <sasomava@microsoft.com>
Co-authored-by: Daniel Imms <2193314+Tyriar@users.noreply.github.com>
Co-authored-by: Tyler James Leonhardt <me@tylerleonhardt.com>
Co-authored-by: Alexandru Dima <alexdima@microsoft.com>
Co-authored-by: Joao Moreno <Joao.Moreno@microsoft.com>
Co-authored-by: Alan Ren <alanren@microsoft.com>
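A note on the "cli: ensure ordering of rpc server messages" entry above: per the commit text, tasks blocking on an async tokio mutex are not guaranteed to acquire it in the order their messages were produced, so concurrent sends can reach the stream out of order. The write_loop approach funnels every message through one writer task instead. A minimal sketch of that pattern, not the actual server_multiplexer code (names here are illustrative):

use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc;

// Spawn a single writer task that owns the stream. Producers enqueue
// bytes on an mpsc channel; each producer's messages are dequeued in
// the order they were sent, so no mutex around the stream is needed.
// Must be called from within a tokio runtime.
fn start_write_loop<W>(mut stream: W) -> mpsc::Sender<Vec<u8>>
where
	W: AsyncWrite + Unpin + Send + 'static,
{
	let (tx, mut rx) = mpsc::channel::<Vec<u8>>(64);
	tokio::spawn(async move {
		while let Some(msg) = rx.recv().await {
			if stream.write_all(&msg).await.is_err() {
				break; // stream is gone; drop any remaining messages
			}
		}
	});
	tx
}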
cli/src/async_pipe.rs (new file, 183 lines)
@@ -0,0 +1,183 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use crate::{constants::APPLICATION_NAME, util::errors::CodeError};
use std::path::{Path, PathBuf};
use uuid::Uuid;

// todo: we could probably abstract this into some crate, if one doesn't already exist

cfg_if::cfg_if! {
	if #[cfg(unix)] {
		pub type AsyncPipe = tokio::net::UnixStream;
		pub type AsyncPipeWriteHalf = tokio::net::unix::OwnedWriteHalf;
		pub type AsyncPipeReadHalf = tokio::net::unix::OwnedReadHalf;

		pub async fn get_socket_rw_stream(path: &Path) -> Result<AsyncPipe, CodeError> {
			tokio::net::UnixStream::connect(path)
				.await
				.map_err(CodeError::AsyncPipeFailed)
		}

		pub async fn listen_socket_rw_stream(path: &Path) -> Result<AsyncPipeListener, CodeError> {
			tokio::net::UnixListener::bind(path)
				.map(AsyncPipeListener)
				.map_err(CodeError::AsyncPipeListenerFailed)
		}

		pub struct AsyncPipeListener(tokio::net::UnixListener);

		impl AsyncPipeListener {
			pub async fn accept(&mut self) -> Result<AsyncPipe, CodeError> {
				self.0.accept().await.map_err(CodeError::AsyncPipeListenerFailed).map(|(s, _)| s)
			}
		}

		pub fn socket_stream_split(pipe: AsyncPipe) -> (AsyncPipeReadHalf, AsyncPipeWriteHalf) {
			pipe.into_split()
		}
	} else {
		use tokio::{time::sleep, io::{AsyncRead, AsyncWrite, ReadBuf}};
		use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions, NamedPipeClient, NamedPipeServer};
		use std::{time::Duration, pin::Pin, task::{Context, Poll}, io};
		use pin_project::pin_project;

		#[pin_project(project = AsyncPipeProj)]
		pub enum AsyncPipe {
			PipeClient(#[pin] NamedPipeClient),
			PipeServer(#[pin] NamedPipeServer),
		}

		impl AsyncRead for AsyncPipe {
			fn poll_read(
				self: Pin<&mut Self>,
				cx: &mut Context<'_>,
				buf: &mut ReadBuf<'_>,
			) -> Poll<io::Result<()>> {
				match self.project() {
					AsyncPipeProj::PipeClient(c) => c.poll_read(cx, buf),
					AsyncPipeProj::PipeServer(c) => c.poll_read(cx, buf),
				}
			}
		}

		impl AsyncWrite for AsyncPipe {
			fn poll_write(
				self: Pin<&mut Self>,
				cx: &mut Context<'_>,
				buf: &[u8],
			) -> Poll<io::Result<usize>> {
				match self.project() {
					AsyncPipeProj::PipeClient(c) => c.poll_write(cx, buf),
					AsyncPipeProj::PipeServer(c) => c.poll_write(cx, buf),
				}
			}

			fn poll_write_vectored(
				self: Pin<&mut Self>,
				cx: &mut Context<'_>,
				bufs: &[io::IoSlice<'_>],
			) -> Poll<Result<usize, io::Error>> {
				match self.project() {
					AsyncPipeProj::PipeClient(c) => c.poll_write_vectored(cx, bufs),
					AsyncPipeProj::PipeServer(c) => c.poll_write_vectored(cx, bufs),
				}
			}

			fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
				match self.project() {
					AsyncPipeProj::PipeClient(c) => c.poll_flush(cx),
					AsyncPipeProj::PipeServer(c) => c.poll_flush(cx),
				}
			}

			fn is_write_vectored(&self) -> bool {
				match self {
					AsyncPipe::PipeClient(c) => c.is_write_vectored(),
					AsyncPipe::PipeServer(c) => c.is_write_vectored(),
				}
			}

			fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
				match self.project() {
					AsyncPipeProj::PipeClient(c) => c.poll_shutdown(cx),
					AsyncPipeProj::PipeServer(c) => c.poll_shutdown(cx),
				}
			}
		}

		pub type AsyncPipeWriteHalf = tokio::io::WriteHalf<AsyncPipe>;
		pub type AsyncPipeReadHalf = tokio::io::ReadHalf<AsyncPipe>;

		pub async fn get_socket_rw_stream(path: &Path) -> Result<AsyncPipe, CodeError> {
			// Tokio says we may need to try in a loop. Do so.
			// https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html
			let client = loop {
				match ClientOptions::new().open(path) {
					Ok(client) => break client,
					// ERROR_PIPE_BUSY https://docs.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499-
					Err(e) if e.raw_os_error() == Some(231) => sleep(Duration::from_millis(100)).await,
					Err(e) => return Err(CodeError::AsyncPipeFailed(e)),
				}
			};

			Ok(AsyncPipe::PipeClient(client))
		}

		pub struct AsyncPipeListener {
			path: PathBuf,
			server: NamedPipeServer
		}

		impl AsyncPipeListener {
			pub async fn accept(&mut self) -> Result<AsyncPipe, CodeError> {
				// see https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeServer.html
				// this is a bit weird in that the server becomes the client once
				// they get a connection, and we create a new client.

				self.server
					.connect()
					.await
					.map_err(CodeError::AsyncPipeListenerFailed)?;

				// Construct the next server to be connected before sending the one
				// we already have off onto a task. This ensures that the server
				// isn't closed (after it's done in the task) before a new one is
				// available. Otherwise the client might error with
				// `io::ErrorKind::NotFound`.
				let next_server = ServerOptions::new()
					.create(&self.path)
					.map_err(CodeError::AsyncPipeListenerFailed)?;

				Ok(AsyncPipe::PipeServer(std::mem::replace(&mut self.server, next_server)))
			}
		}

		pub async fn listen_socket_rw_stream(path: &Path) -> Result<AsyncPipeListener, CodeError> {
			let server = ServerOptions::new()
				.first_pipe_instance(true)
				.create(path)
				.map_err(CodeError::AsyncPipeListenerFailed)?;

			Ok(AsyncPipeListener { path: path.to_owned(), server })
		}

		pub fn socket_stream_split(pipe: AsyncPipe) -> (AsyncPipeReadHalf, AsyncPipeWriteHalf) {
			tokio::io::split(pipe)
		}
	}
}

/// Gets a random name for a pipe/socket on the platform
pub fn get_socket_name() -> PathBuf {
	cfg_if::cfg_if! {
		if #[cfg(unix)] {
			std::env::temp_dir().join(format!("{}-{}", APPLICATION_NAME, Uuid::new_v4()))
		} else {
			PathBuf::from(format!(r"\\.\pipe\{}-{}", APPLICATION_NAME, Uuid::new_v4()))
		}
	}
}
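For orientation, a minimal usage sketch of the API above, assuming a tokio runtime; the function is hypothetical and not part of the diff, and error handling is pared down:

// Hypothetical round-trip using only the functions defined above.
async fn pipe_round_trip() -> Result<(), CodeError> {
	let name = get_socket_name();
	let mut listener = listen_socket_rw_stream(&name).await?;

	// Serve a single connection, echoing bytes back to the client.
	let server = tokio::spawn(async move {
		let pipe = listener.accept().await?;
		let (mut read, mut write) = socket_stream_split(pipe);
		let _ = tokio::io::copy(&mut read, &mut write).await;
		Ok::<(), CodeError>(())
	});

	// Connect; on Windows this retries internally while the pipe is busy.
	let _client = get_socket_rw_stream(&name).await?;
	// ... read/write on `_client`, then let it drop to end the echo loop.

	server.await.expect("server task panicked")
}

The same calling code compiles on both cfg branches, which is the point of the `cfg_if!` split: Unix domain sockets and Windows named pipes behind one type alias and three functions.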
cli/src/auth.rs (new file, 639 lines)
@@ -0,0 +1,639 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use crate::{
	constants::{get_default_user_agent, PRODUCT_NAME_LONG},
	debug, info, log,
	state::{LauncherPaths, PersistedState},
	trace,
	util::{
		errors::{
			wrap, AnyError, OAuthError, RefreshTokenNotAvailableError, StatusError, WrappedError,
		},
		input::prompt_options,
	},
	warning,
};
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use gethostname::gethostname;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{cell::Cell, fmt::Display, path::PathBuf, sync::Arc};
use tokio::time::sleep;
use tunnels::{
	contracts::PROD_FIRST_PARTY_APP_ID,
	management::{Authorization, AuthorizationProvider, HttpError},
};

#[derive(Deserialize)]
struct DeviceCodeResponse {
	device_code: String,
	user_code: String,
	message: Option<String>,
	verification_uri: String,
	expires_in: i64,
}

#[derive(Deserialize)]
struct AuthenticationResponse {
	access_token: String,
	refresh_token: Option<String>,
	expires_in: Option<i64>,
}

#[derive(Deserialize)]
struct AuthenticationError {
	error: String,
	error_description: Option<String>,
}

#[derive(clap::ArgEnum, Serialize, Deserialize, Debug, Clone, Copy)]
pub enum AuthProvider {
	Microsoft,
	Github,
}

impl Display for AuthProvider {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		match self {
			AuthProvider::Microsoft => write!(f, "Microsoft Account"),
			AuthProvider::Github => write!(f, "Github Account"),
		}
	}
}

impl AuthProvider {
	pub fn client_id(&self) -> &'static str {
		match self {
			AuthProvider::Microsoft => "aebc6443-996d-45c2-90f0-388ff96faa56",
			AuthProvider::Github => "01ab8ac9400c4e429b23",
		}
	}

	pub fn code_uri(&self) -> &'static str {
		match self {
			AuthProvider::Microsoft => {
				"https://login.microsoftonline.com/common/oauth2/v2.0/devicecode"
			}
			AuthProvider::Github => "https://github.com/login/device/code",
		}
	}

	pub fn grant_uri(&self) -> &'static str {
		match self {
			AuthProvider::Microsoft => "https://login.microsoftonline.com/common/oauth2/v2.0/token",
			AuthProvider::Github => "https://github.com/login/oauth/access_token",
		}
	}

	pub fn get_default_scopes(&self) -> String {
		match self {
			AuthProvider::Microsoft => format!(
				"{}/.default+offline_access+profile+openid",
				PROD_FIRST_PARTY_APP_ID
			),
			AuthProvider::Github => "read:user+read:org".to_string(),
		}
	}
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StoredCredential {
	#[serde(rename = "p")]
	provider: AuthProvider,
	#[serde(rename = "a")]
	access_token: String,
	#[serde(rename = "r")]
	refresh_token: Option<String>,
	#[serde(rename = "e")]
	expires_at: Option<DateTime<Utc>>,
}

impl StoredCredential {
	pub async fn is_expired(&self, log: &log::Logger, client: &reqwest::Client) -> bool {
		match self.provider {
			AuthProvider::Microsoft => self
				.expires_at
				.map(|e| Utc::now() + chrono::Duration::minutes(5) > e)
				.unwrap_or(false),

			// Make an auth request to Github. Mark the credential as expired
			// only on a verifiable 4xx code. We don't error on any failed
			// request since then a drop in connection could "require" a refresh
			AuthProvider::Github => {
				let res = client
					.get("https://api.github.com/user")
					.header("Authorization", format!("token {}", self.access_token))
					.header("User-Agent", get_default_user_agent())
					.send()
					.await;
				let res = match res {
					Ok(r) => r,
					Err(e) => {
						warning!(log, "failed to check Github token: {}", e);
						return false;
					}
				};

				if res.status().is_success() {
					return false;
				}

				let err = StatusError::from_res(res).await;
				debug!(log, "github token looks expired: {:?}", err);
				true
			}
		}
	}

	fn from_response(auth: AuthenticationResponse, provider: AuthProvider) -> Self {
		StoredCredential {
			provider,
			access_token: auth.access_token,
			refresh_token: auth.refresh_token,
			expires_at: auth.expires_in.map(|e| Utc::now() + Duration::seconds(e)),
		}
	}
}

struct StorageWithLastRead {
	storage: Box<dyn StorageImplementation>,
	last_read: Cell<Result<Option<StoredCredential>, WrappedError>>,
}

#[derive(Clone)]
pub struct Auth {
	client: reqwest::Client,
	log: log::Logger,
	file_storage_path: PathBuf,
	storage: Arc<std::sync::Mutex<Option<StorageWithLastRead>>>,
}

trait StorageImplementation: Send + Sync {
	fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError>;
	fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError>;
	fn clear(&mut self) -> Result<(), WrappedError>;
}

// seal serializes and encrypts the value
fn seal<T>(value: &T) -> String
where
	T: Serialize + ?Sized,
{
	let dec = serde_json::to_string(value).expect("expected to serialize");
	if std::env::var("VSCODE_CLI_DISABLE_KEYCHAIN_ENCRYPT").is_ok() {
		return dec;
	}
	encrypt(&dec)
}

// unseal decrypts and deserializes the value
fn unseal<T>(value: &str) -> Option<T>
where
	T: DeserializeOwned,
{
	// small back-compat for old unencrypted values, or if VSCODE_CLI_DISABLE_KEYCHAIN_ENCRYPT set
	if let Ok(v) = serde_json::from_str::<T>(value) {
		return Some(v);
	}

	let dec = decrypt(value)?;
	serde_json::from_str::<T>(&dec).ok()
}

#[cfg(target_os = "windows")]
const KEYCHAIN_ENTRY_LIMIT: usize = 1024;
#[cfg(not(target_os = "windows"))]
const KEYCHAIN_ENTRY_LIMIT: usize = 128 * 1024;

const CONTINUE_MARKER: &str = "<MORE>";

#[derive(Default)]
struct KeyringStorage {
	// keyring storage can be split into multiple entries due to entry length limits
	// on Windows https://github.com/microsoft/vscode-cli/issues/358
	entries: Vec<keyring::Entry>,
}

macro_rules! get_next_entry {
	($self: expr, $i: expr) => {
		match $self.entries.get($i) {
			Some(e) => e,
			None => {
				let e = keyring::Entry::new("vscode-cli", &format!("vscode-cli-{}", $i));
				$self.entries.push(e);
				$self.entries.last().unwrap()
			}
		}
	};
}

impl StorageImplementation for KeyringStorage {
	fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError> {
		let mut str = String::new();

		for i in 0.. {
			let entry = get_next_entry!(self, i);
			let next_chunk = match entry.get_password() {
				Ok(value) => value,
				Err(keyring::Error::NoEntry) => return Ok(None), // missing entries?
				Err(e) => return Err(wrap(e, "error reading keyring")),
			};

			if next_chunk.ends_with(CONTINUE_MARKER) {
				str.push_str(&next_chunk[..next_chunk.len() - CONTINUE_MARKER.len()]);
			} else {
				str.push_str(&next_chunk);
				break;
			}
		}

		Ok(unseal(&str))
	}

	fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> {
		let sealed = seal(&value);
		let step_size = KEYCHAIN_ENTRY_LIMIT - CONTINUE_MARKER.len();

		for i in (0..sealed.len()).step_by(step_size) {
			let entry = get_next_entry!(self, i / step_size);

			let cutoff = i + step_size;
			let stored = if cutoff <= sealed.len() {
				let mut part = sealed[i..cutoff].to_string();
				part.push_str(CONTINUE_MARKER);
				entry.set_password(&part)
			} else {
				entry.set_password(&sealed[i..])
			};

			if let Err(e) = stored {
				return Err(wrap(e, "error updating keyring"));
			}
		}

		Ok(())
	}

	fn clear(&mut self) -> Result<(), WrappedError> {
		self.read().ok(); // make sure component parts are available
		for entry in self.entries.iter() {
			entry
				.delete_password()
				.map_err(|e| wrap(e, "error updating keyring"))?;
		}
		self.entries.clear();

		Ok(())
	}
}

struct FileStorage(PersistedState<Option<String>>);

impl StorageImplementation for FileStorage {
	fn read(&mut self) -> Result<Option<StoredCredential>, WrappedError> {
		Ok(self.0.load().and_then(|s| unseal(&s)))
	}

	fn store(&mut self, value: StoredCredential) -> Result<(), WrappedError> {
		self.0.save(Some(seal(&value)))
	}

	fn clear(&mut self) -> Result<(), WrappedError> {
		self.0.save(None)
	}
}

impl Auth {
	pub fn new(paths: &LauncherPaths, log: log::Logger) -> Auth {
		Auth {
			log,
			client: reqwest::Client::new(),
			file_storage_path: paths.root().join("token.json"),
			storage: Arc::new(std::sync::Mutex::new(None)),
		}
	}

	fn with_storage<T, F>(&self, op: F) -> T
	where
		F: FnOnce(&mut StorageWithLastRead) -> T,
	{
		let mut opt = self.storage.lock().unwrap();
		if let Some(s) = opt.as_mut() {
			return op(s);
		}

		let mut keyring_storage = KeyringStorage::default();
		let mut file_storage = FileStorage(PersistedState::new(self.file_storage_path.clone()));

		let keyring_storage_result = match std::env::var("VSCODE_CLI_USE_FILE_KEYCHAIN") {
			Ok(_) => Err(wrap("", "user prefers file storage")),
			_ => keyring_storage.read(),
		};

		let mut storage = match keyring_storage_result {
			Ok(v) => StorageWithLastRead {
				last_read: Cell::new(Ok(v)),
				storage: Box::new(keyring_storage),
			},
			Err(_) => StorageWithLastRead {
				last_read: Cell::new(file_storage.read()),
				storage: Box::new(file_storage),
			},
		};

		let out = op(&mut storage);
		*opt = Some(storage);
		out
	}

	/// Gets a tunnel Authentication for use in the tunnel management API.
	pub async fn get_tunnel_authentication(&self) -> Result<Authorization, AnyError> {
		let cred = self.get_credential().await?;
		let auth = match cred.provider {
			AuthProvider::Microsoft => Authorization::Bearer(cred.access_token),
			AuthProvider::Github => Authorization::Github(format!(
				"client_id={} {}",
				cred.provider.client_id(),
				cred.access_token
			)),
		};

		Ok(auth)
	}

	/// Reads the current details from the keyring.
	pub fn get_current_credential(&self) -> Result<Option<StoredCredential>, WrappedError> {
		self.with_storage(|storage| {
			let value = storage.last_read.replace(Ok(None));
			storage.last_read.set(value.clone());
			value
		})
	}

	/// Clears login info from the keyring.
	pub fn clear_credentials(&self) -> Result<(), WrappedError> {
		self.with_storage(|storage| {
			storage.storage.clear()?;
			storage.last_read.set(Ok(None));
			Ok(())
		})
	}

	/// Runs the login flow, optionally pre-filling a provider and/or access token.
	pub async fn login(
		&self,
		provider: Option<AuthProvider>,
		access_token: Option<String>,
	) -> Result<StoredCredential, AnyError> {
		let provider = match provider {
			Some(p) => p,
			None => self.prompt_for_provider().await?,
		};

		let credentials = match access_token {
			Some(t) => StoredCredential {
				provider,
				access_token: t,
				refresh_token: None,
				expires_at: None,
			},
			None => self.do_device_code_flow_with_provider(provider).await?,
		};

		self.store_credentials(credentials.clone());
		Ok(credentials)
	}

	/// Gets the currently stored credentials, or asks the user to log in.
	pub async fn get_credential(&self) -> Result<StoredCredential, AnyError> {
		let entry = match self.get_current_credential() {
			Ok(Some(old_creds)) => {
				trace!(self.log, "Found token in keyring");
				match self.get_refreshed_token(&old_creds).await {
					Ok(Some(new_creds)) => {
						self.store_credentials(new_creds.clone());
						new_creds
					}
					Ok(None) => old_creds,
					Err(e) => {
						info!(self.log, "error refreshing token: {}", e);
						let new_creds = self
							.do_device_code_flow_with_provider(old_creds.provider)
							.await?;
						self.store_credentials(new_creds.clone());
						new_creds
					}
				}
			}

			Ok(None) => {
				trace!(self.log, "No token in keyring, getting a new one");
				let creds = self.do_device_code_flow().await?;
				self.store_credentials(creds.clone());
				creds
			}

			Err(e) => {
				warning!(
					self.log,
					"Error reading token from keyring, getting a new one: {}",
					e
				);
				let creds = self.do_device_code_flow().await?;
				self.store_credentials(creds.clone());
				creds
			}
		};

		Ok(entry)
	}

	/// Stores credentials, logging a warning if it fails.
	fn store_credentials(&self, creds: StoredCredential) {
		self.with_storage(|storage| {
			if let Err(e) = storage.storage.store(creds.clone()) {
				warning!(
					self.log,
					"Failed to update keyring with new credentials: {}",
					e
				);
			}
			storage.last_read.set(Ok(Some(creds)));
		})
	}

	/// Refreshes the token in the credentials if necessary. Returns None if
	/// the token is up to date, or Some new token otherwise.
	async fn get_refreshed_token(
		&self,
		creds: &StoredCredential,
	) -> Result<Option<StoredCredential>, AnyError> {
		if !creds.is_expired(&self.log, &self.client).await {
			return Ok(None);
		}

		let refresh_token = match &creds.refresh_token {
			Some(t) => t,
			None => return Err(AnyError::from(RefreshTokenNotAvailableError())),
		};

		self.do_grant(
			creds.provider,
			format!(
				"client_id={}&grant_type=refresh_token&refresh_token={}",
				creds.provider.client_id(),
				refresh_token
			),
		)
		.await
		.map(Some)
	}

	/// Does a "grant token" request.
	async fn do_grant(
		&self,
		provider: AuthProvider,
		body: String,
	) -> Result<StoredCredential, AnyError> {
		let response = self
			.client
			.post(provider.grant_uri())
			.body(body)
			.header("Accept", "application/json")
			.send()
			.await?;

		let status_code = response.status().as_u16();
		let body = response.bytes().await?;
		if let Ok(body) = serde_json::from_slice::<AuthenticationResponse>(&body) {
			return Ok(StoredCredential::from_response(body, provider));
		}

		if let Ok(res) = serde_json::from_slice::<AuthenticationError>(&body) {
			return Err(OAuthError {
				error: res.error,
				error_description: res.error_description,
			}
			.into());
		}

		return Err(StatusError {
			body: String::from_utf8_lossy(&body).to_string(),
			status_code,
			url: provider.grant_uri().to_string(),
		}
		.into());
	}

	/// Implements the device code flow, returning the credentials upon success.
	async fn do_device_code_flow(&self) -> Result<StoredCredential, AnyError> {
		let provider = self.prompt_for_provider().await?;
		self.do_device_code_flow_with_provider(provider).await
	}

	async fn prompt_for_provider(&self) -> Result<AuthProvider, AnyError> {
		if std::env::var("VSCODE_CLI_ALLOW_MS_AUTH").is_err() {
			return Ok(AuthProvider::Github);
		}

		let provider = prompt_options(
			format!("How would you like to log in to {}?", PRODUCT_NAME_LONG),
			&[AuthProvider::Microsoft, AuthProvider::Github],
		)?;

		Ok(provider)
	}

	async fn do_device_code_flow_with_provider(
		&self,
		provider: AuthProvider,
	) -> Result<StoredCredential, AnyError> {
		loop {
			let init_code = self
				.client
				.post(provider.code_uri())
				.header("Accept", "application/json")
				.body(format!(
					"client_id={}&scope={}",
					provider.client_id(),
					provider.get_default_scopes(),
				))
				.send()
				.await?;

			if !init_code.status().is_success() {
				return Err(StatusError::from_res(init_code).await?.into());
			}

			let init_code_json = init_code.json::<DeviceCodeResponse>().await?;
			let expires_at = Utc::now() + chrono::Duration::seconds(init_code_json.expires_in);

			match &init_code_json.message {
				Some(m) => self.log.result(m),
				None => self.log.result(&format!(
					"To grant access to the server, please log into {} and use code {}",
					init_code_json.verification_uri, init_code_json.user_code
				)),
			};

			let body = format!(
				"client_id={}&grant_type=urn:ietf:params:oauth:grant-type:device_code&device_code={}",
				provider.client_id(),
				init_code_json.device_code
			);

			let mut interval_s = 5;
			while Utc::now() < expires_at {
				sleep(std::time::Duration::from_secs(interval_s)).await;

				match self.do_grant(provider, body.clone()).await {
					Ok(creds) => return Ok(creds),
					Err(AnyError::OAuthError(e)) if e.error == "slow_down" => {
						interval_s += 5; // https://www.rfc-editor.org/rfc/rfc8628#section-3.5
						trace!(self.log, "refresh poll failed, slowing down");
					}
					Err(e) => {
						trace!(self.log, "refresh poll failed, retrying: {}", e);
					}
				}
			}
		}
	}
}

#[async_trait]
impl AuthorizationProvider for Auth {
	async fn get_authorization(&self) -> Result<Authorization, HttpError> {
		self.get_tunnel_authentication()
			.await
			.map_err(|e| HttpError::AuthorizationError(e.to_string()))
	}
}

lazy_static::lazy_static! {
	static ref HOSTNAME: Vec<u8> = gethostname().to_string_lossy().bytes().collect();
}

#[cfg(feature = "vscode-encrypt")]
fn encrypt(value: &str) -> String {
	vscode_encrypt::encrypt(&HOSTNAME, value.as_bytes()).expect("expected to encrypt")
}

#[cfg(feature = "vscode-encrypt")]
fn decrypt(value: &str) -> Option<String> {
	let b = vscode_encrypt::decrypt(&HOSTNAME, value).ok()?;
	String::from_utf8(b).ok()
}

#[cfg(not(feature = "vscode-encrypt"))]
fn encrypt(value: &str) -> String {
	value.to_owned()
}

#[cfg(not(feature = "vscode-encrypt"))]
fn decrypt(value: &str) -> Option<String> {
	Some(value.to_owned())
}
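How the pieces above compose, as a hypothetical driver that is not part of the diff (it uses only the APIs defined in this file and in the crate's state/log modules):

// Hypothetical: read a stored credential or run the RFC 8628 device
// code flow, then exchange it for a tunnel Authorization value.
async fn tunnel_auth_demo(paths: &LauncherPaths, log: log::Logger) -> Result<(), AnyError> {
	let auth = Auth::new(paths, log);

	// Consults keyring (or file storage when VSCODE_CLI_USE_FILE_KEYCHAIN
	// is set), refreshing or re-running the login flow as needed.
	let _cred = auth.get_credential().await?;

	// Bearer token for Microsoft; "client_id=<id> <token>" for Github.
	let _authz = auth.get_tunnel_authentication().await?;
	Ok(())
}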
cli/src/bin/code/legacy_args.rs (new file, 234 lines)
@@ -0,0 +1,234 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::collections::HashMap;

use cli::commands::args::{
	CliCore, Commands, DesktopCodeOptions, ExtensionArgs, ExtensionSubcommand,
	InstallExtensionArgs, ListExtensionArgs, UninstallExtensionArgs,
};

/// Tries to parse the argv using the legacy CLI interface, looking for its
/// flags and generating a CLI with subcommands if those don't exist.
pub fn try_parse_legacy(
	iter: impl IntoIterator<Item = impl Into<std::ffi::OsString>>,
) -> Option<CliCore> {
	let raw = clap_lex::RawArgs::new(iter);
	let mut cursor = raw.cursor();
	raw.next(&mut cursor); // Skip the bin

	// First make a hashmap of all flags and capture positional arguments.
	let mut args: HashMap<String, Vec<String>> = HashMap::new();
	let mut last_arg = None;
	while let Some(arg) = raw.next(&mut cursor) {
		if let Some((long, value)) = arg.to_long() {
			if let Ok(long) = long {
				last_arg = Some(long.to_string());
				match args.get_mut(long) {
					Some(prev) => {
						if let Some(v) = value {
							prev.push(v.to_str_lossy().to_string());
						}
					}
					None => {
						if let Some(v) = value {
							args.insert(long.to_string(), vec![v.to_str_lossy().to_string()]);
						} else {
							args.insert(long.to_string(), vec![]);
						}
					}
				}
			}
		} else if let Ok(value) = arg.to_value() {
			if let Some(last_arg) = &last_arg {
				args.get_mut(last_arg)
					.expect("expected to have last arg")
					.push(value.to_string());
			}
		}
	}

	let get_first_arg_value =
		|key: &str| args.get(key).and_then(|v| v.first()).map(|s| s.to_string());
	let desktop_code_options = DesktopCodeOptions {
		extensions_dir: get_first_arg_value("extensions-dir"),
		user_data_dir: get_first_arg_value("user-data-dir"),
		use_version: None,
	};

	// Now translate them to subcommands.
	// --list-extensions -> ext list
	// --install-extension=id -> ext install <id>
	// --uninstall-extension=id -> ext uninstall <id>
	// --status -> status

	if args.contains_key("list-extensions") {
		Some(CliCore {
			subcommand: Some(Commands::Extension(ExtensionArgs {
				subcommand: ExtensionSubcommand::List(ListExtensionArgs {
					category: get_first_arg_value("category"),
					show_versions: args.contains_key("show-versions"),
				}),
				desktop_code_options,
			})),
			..Default::default()
		})
	} else if let Some(exts) = args.remove("install-extension") {
		Some(CliCore {
			subcommand: Some(Commands::Extension(ExtensionArgs {
				subcommand: ExtensionSubcommand::Install(InstallExtensionArgs {
					id_or_path: exts,
					pre_release: args.contains_key("pre-release"),
					force: args.contains_key("force"),
				}),
				desktop_code_options,
			})),
			..Default::default()
		})
	} else if let Some(exts) = args.remove("uninstall-extension") {
		Some(CliCore {
			subcommand: Some(Commands::Extension(ExtensionArgs {
				subcommand: ExtensionSubcommand::Uninstall(UninstallExtensionArgs { id: exts }),
				desktop_code_options,
			})),
			..Default::default()
		})
	} else if args.contains_key("status") {
		Some(CliCore {
			subcommand: Some(Commands::Status),
			..Default::default()
		})
	} else {
		None
	}
}

#[cfg(test)]
mod tests {
	use super::*;

	#[test]
	fn test_parses_list_extensions() {
		let args = vec![
			"code",
			"--list-extensions",
			"--category",
			"themes",
			"--show-versions",
		];
		let cli = try_parse_legacy(args.into_iter()).unwrap();

		if let Some(Commands::Extension(extension_args)) = cli.subcommand {
			if let ExtensionSubcommand::List(list_args) = extension_args.subcommand {
				assert_eq!(list_args.category, Some("themes".to_string()));
				assert!(list_args.show_versions);
			} else {
				panic!(
					"Expected list subcommand, got {:?}",
					extension_args.subcommand
				);
			}
		} else {
			panic!("Expected extension subcommand, got {:?}", cli.subcommand);
		}
	}

	#[test]
	fn test_parses_install_extension() {
		let args = vec![
			"code",
			"--install-extension",
			"connor4312.codesong",
			"connor4312.hello-world",
			"--pre-release",
			"--force",
		];
		let cli = try_parse_legacy(args.into_iter()).unwrap();

		if let Some(Commands::Extension(extension_args)) = cli.subcommand {
			if let ExtensionSubcommand::Install(install_args) = extension_args.subcommand {
				assert_eq!(
					install_args.id_or_path,
					vec!["connor4312.codesong", "connor4312.hello-world"]
				);
				assert!(install_args.pre_release);
				assert!(install_args.force);
			} else {
				panic!(
					"Expected install subcommand, got {:?}",
					extension_args.subcommand
				);
			}
		} else {
			panic!("Expected extension subcommand, got {:?}", cli.subcommand);
		}
	}

	#[test]
	fn test_parses_uninstall_extension() {
		let args = vec!["code", "--uninstall-extension", "connor4312.codesong"];
		let cli = try_parse_legacy(args.into_iter()).unwrap();

		if let Some(Commands::Extension(extension_args)) = cli.subcommand {
			if let ExtensionSubcommand::Uninstall(uninstall_args) = extension_args.subcommand {
				assert_eq!(uninstall_args.id, vec!["connor4312.codesong"]);
			} else {
				panic!(
					"Expected uninstall subcommand, got {:?}",
					extension_args.subcommand
				);
			}
		} else {
			panic!("Expected extension subcommand, got {:?}", cli.subcommand);
		}
	}

	#[test]
	fn test_parses_user_data_dir_and_extensions_dir() {
		let args = vec![
			"code",
			"--uninstall-extension",
			"connor4312.codesong",
			"--user-data-dir",
			"foo",
			"--extensions-dir",
			"bar",
		];
		let cli = try_parse_legacy(args.into_iter()).unwrap();

		if let Some(Commands::Extension(extension_args)) = cli.subcommand {
			assert_eq!(
				extension_args.desktop_code_options.user_data_dir,
				Some("foo".to_string())
			);
			assert_eq!(
				extension_args.desktop_code_options.extensions_dir,
				Some("bar".to_string())
			);
			if let ExtensionSubcommand::Uninstall(uninstall_args) = extension_args.subcommand {
				assert_eq!(uninstall_args.id, vec!["connor4312.codesong"]);
			} else {
				panic!(
					"Expected uninstall subcommand, got {:?}",
					extension_args.subcommand
				);
			}
		} else {
			panic!("Expected extension subcommand, got {:?}", cli.subcommand);
		}
	}

	#[test]
	fn test_status() {
		let args = vec!["code", "--status"];
		let cli = try_parse_legacy(args.into_iter()).unwrap();

		if let Some(Commands::Status) = cli.subcommand {
			// no-op
		} else {
			panic!("Expected extension subcommand, got {:?}", cli.subcommand);
		}
	}
}
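A short sketch of the translation the shim performs, mirroring the tests above (the wrapper function and extension id are illustrative, not part of the diff):

// Hypothetical caller: a legacy flag becomes the new `ext` subcommand.
fn legacy_shim_example() {
	let argv = ["code", "--install-extension", "ms-python.python"];
	if let Some(cli) = try_parse_legacy(argv) {
		// Equivalent new-style invocation: `code ext install ms-python.python`.
		assert!(matches!(cli.subcommand, Some(Commands::Extension(_))));
	}
}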
cli/src/bin/code/main.rs (new file, 177 lines)
@@ -0,0 +1,177 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
mod legacy_args;
|
||||
|
||||
use std::process::Command;
|
||||
|
||||
use clap::Parser;
|
||||
use cli::{
|
||||
commands::{args, tunnels, update, version, CommandContext},
|
||||
constants::get_default_user_agent,
|
||||
desktop, log,
|
||||
state::LauncherPaths,
|
||||
util::{
|
||||
errors::{wrap, AnyError},
|
||||
is_integrated_cli,
|
||||
prereqs::PreReqChecker,
|
||||
},
|
||||
};
|
||||
use legacy_args::try_parse_legacy;
|
||||
use opentelemetry::sdk::trace::TracerProvider as SdkTracerProvider;
|
||||
use opentelemetry::trace::TracerProvider;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), std::convert::Infallible> {
|
||||
let raw_args = std::env::args_os().collect::<Vec<_>>();
|
||||
let parsed = try_parse_legacy(&raw_args)
|
||||
.map(|core| args::AnyCli::Integrated(args::IntegratedCli { core }))
|
||||
.unwrap_or_else(|| {
|
||||
if let Ok(true) = is_integrated_cli() {
|
||||
args::AnyCli::Integrated(args::IntegratedCli::parse_from(&raw_args))
|
||||
} else {
|
||||
args::AnyCli::Standalone(args::StandaloneCli::parse_from(&raw_args))
|
||||
}
|
||||
});
|
||||
|
||||
let core = parsed.core();
|
||||
let context_paths = LauncherPaths::migrate(core.global_options.cli_data_dir.clone()).unwrap();
|
||||
let context_args = core.clone();
|
||||
|
||||
// gets a command context without installing the global logger
|
||||
let context_no_logger = || CommandContext {
|
||||
http: reqwest::ClientBuilder::new()
|
||||
.user_agent(get_default_user_agent())
|
||||
.build()
|
||||
.unwrap(),
|
||||
paths: context_paths,
|
||||
log: make_logger(&context_args),
|
||||
args: context_args,
|
||||
};
|
||||
|
||||
// gets a command context with the global logger installer. Usually what most commands want.
|
||||
macro_rules! context {
|
||||
() => {{
|
||||
let context = context_no_logger();
|
||||
log::install_global_logger(context.log.clone());
|
||||
context
|
||||
}};
|
||||
}
	let result = match parsed {
		args::AnyCli::Standalone(args::StandaloneCli {
			subcommand: Some(cmd),
			..
		}) => match cmd {
			args::StandaloneCommands::Update(args) => update::update(context!(), args).await,
		},
		args::AnyCli::Standalone(args::StandaloneCli { core: c, .. })
		| args::AnyCli::Integrated(args::IntegratedCli { core: c, .. }) => match c.subcommand {
			None => {
				let context = context!();
				let ca = context.args.get_base_code_args();
				start_code(context, ca).await
			}

			Some(args::Commands::Extension(extension_args)) => {
				let context = context!();
				let mut ca = context.args.get_base_code_args();
				extension_args.add_code_args(&mut ca);
				start_code(context, ca).await
			}

			Some(args::Commands::Status) => {
				let context = context!();
				let mut ca = context.args.get_base_code_args();
				ca.push("--status".to_string());
				start_code(context, ca).await
			}

			Some(args::Commands::Version(version_args)) => match version_args.subcommand {
				args::VersionSubcommand::Use(use_version_args) => {
					version::switch_to(context!(), use_version_args).await
				}
				args::VersionSubcommand::Show => version::show(context!()).await,
			},

			Some(args::Commands::CommandShell) => tunnels::command_shell(context!()).await,

			Some(args::Commands::Tunnel(tunnel_args)) => match tunnel_args.subcommand {
				Some(args::TunnelSubcommand::Prune) => tunnels::prune(context!()).await,
				Some(args::TunnelSubcommand::Unregister) => tunnels::unregister(context!()).await,
				Some(args::TunnelSubcommand::Kill) => tunnels::kill(context!()).await,
				Some(args::TunnelSubcommand::Restart) => tunnels::restart(context!()).await,
				Some(args::TunnelSubcommand::Status) => tunnels::status(context!()).await,
				Some(args::TunnelSubcommand::Rename(rename_args)) => {
					tunnels::rename(context!(), rename_args).await
				}
				Some(args::TunnelSubcommand::User(user_command)) => {
					tunnels::user(context!(), user_command).await
				}
				Some(args::TunnelSubcommand::Service(service_args)) => {
					tunnels::service(context_no_logger(), service_args).await
				}
				None => tunnels::serve(context_no_logger(), tunnel_args.serve_args).await,
			},
		},
	};

	match result {
		Err(e) => print_and_exit(e),
		Ok(code) => std::process::exit(code),
	}
}

fn make_logger(core: &args::CliCore) -> log::Logger {
	let log_level = if core.global_options.verbose {
		log::Level::Trace
	} else {
		core.global_options.log.unwrap_or(log::Level::Info)
	};

	let tracer = SdkTracerProvider::builder().build().tracer("codecli");
	let mut log = log::Logger::new(tracer, log_level);
	if let Some(f) = &core.global_options.log_to_file {
		log = log.tee(log::FileLogSink::new(log_level, f).expect("expected to make file logger"))
	}

	log
}

fn print_and_exit<E>(err: E) -> !
where
	E: std::fmt::Display,
{
	log::emit(log::Level::Error, "", &format!("{}", err));
	std::process::exit(1);
}

async fn start_code(context: CommandContext, args: Vec<String>) -> Result<i32, AnyError> {
	// todo: once the integrated CLI takes the place of the Node.js CLI, this should
	// redirect to the current installation without using the CodeVersionManager.

	let platform = PreReqChecker::new().verify().await?;
	let version_manager =
		desktop::CodeVersionManager::new(context.log.clone(), &context.paths, platform);
	let version = match &context.args.editor_options.code_options.use_version {
		Some(v) => desktop::RequestedVersion::try_from(v.as_str())?,
		None => version_manager.get_preferred_version(),
	};

	let binary = match version_manager.try_get_entrypoint(&version).await {
		Some(ep) => ep,
		None => {
			desktop::prompt_to_install(&version);
			return Ok(1);
		}
	};

	let code = Command::new(&binary)
		.args(args)
		.status()
		.map(|s| s.code().unwrap_or(1))
		.map_err(|e| wrap(e, format!("error running editor from {}", binary.display())))?;

	Ok(code)
}

12 cli/src/commands.rs Normal file
@@ -0,0 +1,12 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

mod context;

pub mod args;
pub mod tunnels;
pub mod update;
pub mod version;
pub use context::CommandContext;

688 cli/src/commands/args.rs Normal file
@@ -0,0 +1,688 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{fmt, path::PathBuf};

use crate::{constants, log, options, tunnels::code_server::CodeServerArgs};
use clap::{ArgEnum, Args, Parser, Subcommand};
use const_format::concatcp;

const CLI_NAME: &str = concatcp!(constants::PRODUCT_NAME_LONG, " CLI");
const HELP_COMMANDS: &str = "Usage: {name} [options][paths...]

To read output from another program, append '-' (e.g. 'echo Hello World | {name} -')";

const STANDALONE_TEMPLATE: &str = concatcp!(
	CLI_NAME,
	" Standalone - {version}

",
	HELP_COMMANDS,
	"
Running editor commands requires installing ",
	constants::QUALITYLESS_PRODUCT_NAME,
	", and may differ slightly.

{all-args}"
);
const INTEGRATED_TEMPLATE: &str = concatcp!(
	CLI_NAME,
	" - {version}

",
	HELP_COMMANDS,
	"

{all-args}"
);

const COMMIT_IN_VERSION: &str = match constants::VSCODE_CLI_COMMIT {
	Some(c) => c,
	None => "unknown",
};
const NUMBER_IN_VERSION: &str = match constants::VSCODE_CLI_VERSION {
	Some(c) => c,
	None => "dev",
};
const VERSION: &str = concatcp!(NUMBER_IN_VERSION, " (commit ", COMMIT_IN_VERSION, ")");

#[derive(Parser, Debug, Default)]
#[clap(
	help_template = INTEGRATED_TEMPLATE,
	long_about = None,
	version = VERSION,
)]
pub struct IntegratedCli {
	#[clap(flatten)]
	pub core: CliCore,
}

/// Common CLI shared between integrated and standalone interfaces.
#[derive(Args, Debug, Default, Clone)]
pub struct CliCore {
	/// One or more files, folders, or URIs to open.
	#[clap(name = "paths")]
	pub open_paths: Vec<String>,

	#[clap(flatten, next_help_heading = Some("EDITOR OPTIONS"))]
	pub editor_options: EditorOptions,

	#[clap(flatten, next_help_heading = Some("EDITOR TROUBLESHOOTING"))]
	pub troubleshooting: EditorTroubleshooting,

	#[clap(flatten, next_help_heading = Some("GLOBAL OPTIONS"))]
	pub global_options: GlobalOptions,

	#[clap(subcommand)]
	pub subcommand: Option<Commands>,
}

#[derive(Parser, Debug, Default)]
#[clap(
	help_template = STANDALONE_TEMPLATE,
	long_about = None,
	version = VERSION,
)]
pub struct StandaloneCli {
	#[clap(flatten)]
	pub core: CliCore,

	#[clap(subcommand)]
	pub subcommand: Option<StandaloneCommands>,
}

pub enum AnyCli {
	Integrated(IntegratedCli),
	Standalone(StandaloneCli),
}

impl AnyCli {
	pub fn core(&self) -> &CliCore {
		match self {
			AnyCli::Integrated(cli) => &cli.core,
			AnyCli::Standalone(cli) => &cli.core,
		}
	}
}

impl CliCore {
	pub fn get_base_code_args(&self) -> Vec<String> {
		let mut args = self.open_paths.clone();
		self.editor_options.add_code_args(&mut args);
		self.troubleshooting.add_code_args(&mut args);
		self.global_options.add_code_args(&mut args);
		args
	}
}

impl<'a> From<&'a CliCore> for CodeServerArgs {
	fn from(cli: &'a CliCore) -> Self {
		let mut args = CodeServerArgs {
			log: cli.global_options.log,
			accept_server_license_terms: true,
			..Default::default()
		};

		if cli.global_options.verbose {
			args.verbose = true;
		}

		if cli.global_options.disable_telemetry {
			args.telemetry_level = Some(options::TelemetryLevel::Off);
		} else if cli.global_options.telemetry_level.is_some() {
			args.telemetry_level = cli.global_options.telemetry_level;
		}

		args
	}
}
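
// Note that --disable-telemetry takes precedence over --telemetry-level above:
// when both are passed, telemetry ends up disabled.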

#[derive(Subcommand, Debug, Clone)]
pub enum StandaloneCommands {
	/// Updates the CLI.
	Update(StandaloneUpdateArgs),
}

#[derive(Args, Debug, Clone)]
pub struct StandaloneUpdateArgs {
	/// Only check for updates, without actually updating the CLI.
	#[clap(long)]
	pub check: bool,
}

#[derive(Subcommand, Debug, Clone)]
pub enum Commands {
	/// Create a tunnel that's accessible on vscode.dev from anywhere.
	/// Run `code tunnel --help` for more usage info.
	Tunnel(TunnelArgs),

	/// Manage editor extensions.
	#[clap(name = "ext")]
	Extension(ExtensionArgs),

	/// Print process usage and diagnostics information.
	Status,

	/// Changes the version of the editor you're using.
	Version(VersionArgs),

	/// Runs the control server on process stdin/stdout.
	#[clap(hide = true)]
	CommandShell,
}

#[derive(Args, Debug, Clone)]
pub struct ExtensionArgs {
	#[clap(subcommand)]
	pub subcommand: ExtensionSubcommand,

	#[clap(flatten)]
	pub desktop_code_options: DesktopCodeOptions,
}

impl ExtensionArgs {
	pub fn add_code_args(&self, target: &mut Vec<String>) {
		self.desktop_code_options.add_code_args(target);
		self.subcommand.add_code_args(target);
	}
}

#[derive(Subcommand, Debug, Clone)]
pub enum ExtensionSubcommand {
	/// List installed extensions.
	List(ListExtensionArgs),
	/// Install an extension.
	Install(InstallExtensionArgs),
	/// Uninstall an extension.
	Uninstall(UninstallExtensionArgs),
}

impl ExtensionSubcommand {
	pub fn add_code_args(&self, target: &mut Vec<String>) {
		match self {
			ExtensionSubcommand::List(args) => {
				target.push("--list-extensions".to_string());
				if args.show_versions {
					target.push("--show-versions".to_string());
				}
				if let Some(category) = &args.category {
					target.push(format!("--category={}", category));
				}
			}
			ExtensionSubcommand::Install(args) => {
				for id in args.id_or_path.iter() {
					target.push(format!("--install-extension={}", id));
				}
				if args.pre_release {
					target.push("--pre-release".to_string());
				}
				if args.force {
					target.push("--force".to_string());
				}
			}
			ExtensionSubcommand::Uninstall(args) => {
				for id in args.id.iter() {
					target.push(format!("--uninstall-extension={}", id));
				}
			}
		}
	}
}
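
// For example (extension id illustrative), `code ext install foo.bar --force`
// is forwarded to the editor as `--install-extension=foo.bar --force`.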

#[derive(Args, Debug, Clone)]
pub struct ListExtensionArgs {
	/// Filters installed extensions by provided category, when using --list-extensions.
	#[clap(long, value_name = "category")]
	pub category: Option<String>,

	/// Show versions of installed extensions, when using --list-extensions.
	#[clap(long)]
	pub show_versions: bool,
}

#[derive(Args, Debug, Clone)]
pub struct InstallExtensionArgs {
	/// Either an extension id or a path to a VSIX. The identifier of an
	/// extension is '${publisher}.${name}'. Use '--force' argument to update
	/// to latest version. To install a specific version provide '@${version}'.
	/// For example: 'vscode.csharp@1.2.3'.
	#[clap(name = "ext-id | id")]
	pub id_or_path: Vec<String>,

	/// Installs the pre-release version of the extension.
	#[clap(long)]
	pub pre_release: bool,

	/// Update to the latest version of the extension if it's already installed.
	#[clap(long)]
	pub force: bool,
}

#[derive(Args, Debug, Clone)]
pub struct UninstallExtensionArgs {
	/// One or more extension identifiers to uninstall. The identifier of an
	/// extension is '${publisher}.${name}'.
	#[clap(name = "ext-id")]
	pub id: Vec<String>,
}

#[derive(Args, Debug, Clone)]
pub struct VersionArgs {
	#[clap(subcommand)]
	pub subcommand: VersionSubcommand,
}

#[derive(Subcommand, Debug, Clone)]
pub enum VersionSubcommand {
	/// Switches the version of the editor in use.
	Use(UseVersionArgs),

	/// Shows the currently configured editor version.
	Show,
}

#[derive(Args, Debug, Clone)]
pub struct UseVersionArgs {
	/// The version of the editor you want to use. Can be "stable", "insiders",
	/// a version number, or an absolute path to an existing install.
	#[clap(value_name = "stable | insiders | x.y.z | path")]
	pub name: String,

	/// The directory where the version can be found.
	#[clap(long, value_name = "path")]
	pub install_dir: Option<String>,
}

#[derive(Args, Debug, Default, Clone)]
pub struct EditorOptions {
	/// Compare two files with each other.
	#[clap(short, long, value_names = &["file", "file"])]
	pub diff: Vec<String>,

	/// Add folder(s) to the last active window.
	#[clap(short, long, value_name = "folder")]
	pub add: Option<String>,

	/// Open a file at the path on the specified line and character position.
	#[clap(short, long, value_name = "file:line[:character]")]
	pub goto: Option<String>,

	/// Force to open a new window.
	#[clap(short, long)]
	pub new_window: bool,

	/// Force to open a file or folder in an already opened window.
	#[clap(short, long)]
	pub reuse_window: bool,

	/// Wait for the files to be closed before returning.
	#[clap(short, long)]
	pub wait: bool,

	/// The locale to use (e.g. en-US or zh-TW).
	#[clap(long, value_name = "locale")]
	pub locale: Option<String>,

	/// Enables proposed API features for extensions. Can receive one or
	/// more extension IDs to enable individually.
	#[clap(long, value_name = "ext-id")]
	pub enable_proposed_api: Vec<String>,

	#[clap(flatten)]
	pub code_options: DesktopCodeOptions,
}

impl EditorOptions {
	pub fn add_code_args(&self, target: &mut Vec<String>) {
		if !self.diff.is_empty() {
			target.push("--diff".to_string());
			for file in self.diff.iter() {
				target.push(file.clone());
			}
		}
		if let Some(add) = &self.add {
			target.push("--add".to_string());
			target.push(add.clone());
		}
		if let Some(goto) = &self.goto {
			target.push("--goto".to_string());
			target.push(goto.clone());
		}
		if self.new_window {
			target.push("--new-window".to_string());
		}
		if self.reuse_window {
			target.push("--reuse-window".to_string());
		}
		if self.wait {
			target.push("--wait".to_string());
		}
		if let Some(locale) = &self.locale {
			target.push(format!("--locale={}", locale));
		}
		if !self.enable_proposed_api.is_empty() {
			for id in self.enable_proposed_api.iter() {
				target.push(format!("--enable-proposed-api={}", id));
			}
		}
		self.code_options.add_code_args(target);
	}
}
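
// As an illustration, `code --diff a.txt b.txt --goto src/main.rs:10:5` (file
// names hypothetical) yields the editor args
// ["--diff", "a.txt", "b.txt", "--goto", "src/main.rs:10:5"].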

/// Arguments applicable whenever the desktop editor is launched.
#[derive(Args, Debug, Default, Clone)]
pub struct DesktopCodeOptions {
	/// Set the root path for extensions.
	#[clap(long, value_name = "dir")]
	pub extensions_dir: Option<String>,

	/// Specifies the directory that user data is kept in. Can be used to
	/// open multiple distinct instances of the editor.
	#[clap(long, value_name = "dir")]
	pub user_data_dir: Option<String>,

	/// Sets the editor version to use for this command. The preferred version
	/// can be persisted with `code version use <version>`. Can be "stable",
	/// "insiders", a version number, or an absolute path to an existing install.
	#[clap(long, value_name = "stable | insiders | x.y.z | path")]
	pub use_version: Option<String>,
}

/// Argument specifying the output format.
#[derive(Args, Debug, Clone)]
pub struct OutputFormatOptions {
	/// Set the data output formats.
	#[clap(arg_enum, long, value_name = "format", default_value_t = OutputFormat::Text)]
	pub format: OutputFormat,
}

impl DesktopCodeOptions {
	pub fn add_code_args(&self, target: &mut Vec<String>) {
		if let Some(extensions_dir) = &self.extensions_dir {
			target.push(format!("--extensions-dir={}", extensions_dir));
		}
		if let Some(user_data_dir) = &self.user_data_dir {
			target.push(format!("--user-data-dir={}", user_data_dir));
		}
	}
}

#[derive(Args, Debug, Default, Clone)]
pub struct GlobalOptions {
	/// Directory where CLI metadata should be stored.
	#[clap(long, env = "VSCODE_CLI_DATA_DIR", global = true)]
	pub cli_data_dir: Option<String>,

	/// Print verbose output (implies --wait).
	#[clap(long, global = true)]
	pub verbose: bool,

	/// Log to a file in addition to stdout. Used when running as a service.
	#[clap(long, global = true, hide = true)]
	pub log_to_file: Option<PathBuf>,

	/// Log level to use.
	#[clap(long, arg_enum, value_name = "level", global = true)]
	pub log: Option<log::Level>,

	/// Disable telemetry for the current command, even if it was previously
	/// accepted as part of the license prompt or specified in '--telemetry-level'.
	#[clap(long, global = true, hide = true)]
	pub disable_telemetry: bool,

	/// Sets the initial telemetry level.
	#[clap(arg_enum, long, global = true, hide = true)]
	pub telemetry_level: Option<options::TelemetryLevel>,
}

impl GlobalOptions {
	pub fn add_code_args(&self, target: &mut Vec<String>) {
		if self.verbose {
			target.push("--verbose".to_string());
		}
		if let Some(log) = self.log {
			target.push(format!("--log={}", log));
		}
		if self.disable_telemetry {
			target.push("--disable-telemetry".to_string());
		}
		if let Some(telemetry_level) = &self.telemetry_level {
			target.push(format!("--telemetry-level={}", telemetry_level));
		}
	}
}

#[derive(Args, Debug, Default, Clone)]
pub struct EditorTroubleshooting {
	/// Run CPU profiler during startup.
	#[clap(long)]
	pub prof_startup: bool,

	/// Disable all installed extensions.
	#[clap(long)]
	pub disable_extensions: bool,

	/// Disable an extension.
	#[clap(long, value_name = "ext-id")]
	pub disable_extension: Vec<String>,

	/// Turn sync on or off.
	#[clap(arg_enum, long, value_name = "on | off")]
	pub sync: Option<SyncState>,

	/// Allow debugging and profiling of extensions. Check the developer tools for the connection URI.
	#[clap(long, value_name = "port")]
	pub inspect_extensions: Option<u16>,

	/// Allow debugging and profiling of extensions with the extension host
	/// being paused after start. Check the developer tools for the connection URI.
	#[clap(long, value_name = "port")]
	pub inspect_brk_extensions: Option<u16>,

	/// Disable GPU hardware acceleration.
	#[clap(long)]
	pub disable_gpu: bool,

	/// Shows all telemetry events which the editor collects.
	#[clap(long)]
	pub telemetry: bool,
}

impl EditorTroubleshooting {
	pub fn add_code_args(&self, target: &mut Vec<String>) {
		if self.prof_startup {
			target.push("--prof-startup".to_string());
		}
		if self.disable_extensions {
			target.push("--disable-extensions".to_string());
		}
		for id in self.disable_extension.iter() {
			target.push(format!("--disable-extension={}", id));
		}
		if let Some(sync) = &self.sync {
			target.push(format!("--sync={}", sync));
		}
		if let Some(port) = &self.inspect_extensions {
			target.push(format!("--inspect-extensions={}", port));
		}
		if let Some(port) = &self.inspect_brk_extensions {
			target.push(format!("--inspect-brk-extensions={}", port));
		}
		if self.disable_gpu {
			target.push("--disable-gpu".to_string());
		}
		if self.telemetry {
			target.push("--telemetry".to_string());
		}
	}
}

#[derive(ArgEnum, Clone, Copy, Debug)]
pub enum SyncState {
	On,
	Off,
}

impl fmt::Display for SyncState {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		match self {
			SyncState::Off => write!(f, "off"),
			SyncState::On => write!(f, "on"),
		}
	}
}

#[derive(ArgEnum, Clone, Copy, Debug)]
pub enum OutputFormat {
	Json,
	Text,
}

#[derive(Args, Clone, Debug, Default)]
pub struct ExistingTunnelArgs {
	/// Name of the preexisting tunnel to connect to.
	#[clap(long, hide = true)]
	pub tunnel_name: Option<String>,

	/// Token used to authenticate to the preexisting tunnel.
	#[clap(long, hide = true)]
	pub host_token: Option<String>,

	/// ID of the preexisting tunnel to connect to.
	#[clap(long, hide = true)]
	pub tunnel_id: Option<String>,

	/// Cluster of the preexisting tunnel to connect to.
	#[clap(long, hide = true)]
	pub cluster: Option<String>,
}

#[derive(Args, Debug, Clone, Default)]
pub struct TunnelServeArgs {
	/// Optional details to connect to an existing tunnel.
	#[clap(flatten, next_help_heading = Some("ADVANCED OPTIONS"))]
	pub tunnel: ExistingTunnelArgs,

	/// Randomly name the machine for the port forwarding service.
	#[clap(long)]
	pub random_name: bool,

	/// Prevents the machine going to sleep while this command runs.
	#[clap(long)]
	pub no_sleep: bool,

	/// Sets the machine name for the port forwarding service.
	#[clap(long)]
	pub name: Option<String>,

	/// Optional parent process id. If provided, the server will be stopped when the process of the given pid no longer exists.
	#[clap(long, hide = true)]
	pub parent_process_id: Option<String>,

	/// If set, the user accepts the server license terms and the server will be started without a user prompt.
	#[clap(long)]
	pub accept_server_license_terms: bool,
}

#[derive(Args, Debug, Clone)]
pub struct TunnelArgs {
	#[clap(subcommand)]
	pub subcommand: Option<TunnelSubcommand>,

	#[clap(flatten)]
	pub serve_args: TunnelServeArgs,
}

#[derive(Subcommand, Debug, Clone)]
pub enum TunnelSubcommand {
	/// Delete all servers which are currently not running.
	Prune,

	/// Stops any running tunnel on the system.
	Kill,

	/// Restarts any running tunnel on the system.
	Restart,

	/// Gets whether there is a tunnel running on the current machine.
	Status,

	/// Rename this machine in the port forwarding service.
	Rename(TunnelRenameArgs),

	/// Remove this machine's association with the port forwarding service.
	Unregister,

	/// Manage the account used for the port forwarding service.
	#[clap(subcommand)]
	User(TunnelUserSubCommands),

	/// (Preview) Manages the tunnel when installed as a system service.
	#[clap(subcommand)]
	Service(TunnelServiceSubCommands),
}

#[derive(Subcommand, Debug, Clone)]
pub enum TunnelServiceSubCommands {
	/// Installs or re-installs the tunnel service on the machine.
	Install(TunnelServiceInstallArgs),

	/// Uninstalls and stops the tunnel service.
	Uninstall,

	/// Shows logs for the running service.
	Log,

	/// Internal command for running the service.
	#[clap(hide = true)]
	InternalRun,
}

#[derive(Args, Debug, Clone)]
pub struct TunnelServiceInstallArgs {
	/// If set, the user accepts the server license terms and the server will be started without a user prompt.
	#[clap(long)]
	pub accept_server_license_terms: bool,
}

#[derive(Args, Debug, Clone)]
pub struct TunnelRenameArgs {
	/// The name you'd like to rename your machine to.
	pub name: String,
}

#[derive(Subcommand, Debug, Clone)]
pub enum TunnelUserSubCommands {
	/// Log in to the port forwarding service.
	Login(LoginArgs),

	/// Log out of the port forwarding service.
	Logout,

	/// Show the account that's logged into the port forwarding service.
	Show,
}

#[derive(Args, Debug, Clone)]
pub struct LoginArgs {
	/// An access token to store for authentication. Note: this will not be
	/// refreshed if it expires!
	#[clap(long, requires = "provider")]
	pub access_token: Option<String>,

	/// The auth provider to use. If not provided, a prompt will be shown.
	#[clap(arg_enum, long)]
	pub provider: Option<AuthProvider>,
}

#[derive(clap::ArgEnum, Debug, Clone, Copy)]
pub enum AuthProvider {
	Microsoft,
	Github,
}

15 cli/src/commands/context.rs Normal file
@@ -0,0 +1,15 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use crate::{log, state::LauncherPaths};

use super::args::CliCore;

pub struct CommandContext {
	pub log: log::Logger,
	pub paths: LauncherPaths,
	pub args: CliCore,
	pub http: reqwest::Client,
}

135 cli/src/commands/output.rs Normal file
@@ -0,0 +1,135 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::fmt::Display;

use std::io::{BufWriter, Write};

use super::args::OutputFormat;

pub struct Column {
	max_width: usize,
	heading: &'static str,
	data: Vec<String>,
}

impl Column {
	pub fn new(heading: &'static str) -> Self {
		Column {
			max_width: heading.len(),
			heading,
			data: vec![],
		}
	}

	pub fn add_row(&mut self, row: String) {
		self.max_width = std::cmp::max(self.max_width, row.len());
		self.data.push(row);
	}
}

impl OutputFormat {
	pub fn print_table(&self, table: OutputTable) -> Result<(), std::io::Error> {
		match *self {
			OutputFormat::Json => JsonTablePrinter().print(table, &mut std::io::stdout()),
			OutputFormat::Text => TextTablePrinter().print(table, &mut std::io::stdout()),
		}
	}
}

pub struct OutputTable {
	cols: Vec<Column>,
}

impl OutputTable {
	pub fn new(cols: Vec<Column>) -> Self {
		OutputTable { cols }
	}
}

trait TablePrinter {
	fn print(&self, table: OutputTable, out: &mut dyn std::io::Write)
		-> Result<(), std::io::Error>;
}

pub struct JsonTablePrinter();

impl TablePrinter for JsonTablePrinter {
	fn print(
		&self,
		table: OutputTable,
		out: &mut dyn std::io::Write,
	) -> Result<(), std::io::Error> {
		let mut bw = BufWriter::new(out);
		bw.write_all(b"[")?;

		if !table.cols.is_empty() {
			let data_len = table.cols[0].data.len();
			for i in 0..data_len {
				if i > 0 {
					bw.write_all(b",{")?;
				} else {
					bw.write_all(b"{")?;
				}
				// write comma-separated `"heading":"value"` pairs, then close the object
				for (j, col) in table.cols.iter().enumerate() {
					if j > 0 {
						bw.write_all(b",")?;
					}
					serde_json::to_writer(&mut bw, col.heading)?;
					bw.write_all(b":")?;
					serde_json::to_writer(&mut bw, &col.data[i])?;
				}
				bw.write_all(b"}")?;
			}
		}

		bw.write_all(b"]")?;
		bw.flush()
	}
}
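
// With columns "ID" and "Name" (illustrative), this prints e.g.
// [{"ID":"1","Name":"a"},{"ID":"2","Name":"b"}]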

/// Type that prints the output as an ASCII, markdown-style table.
pub struct TextTablePrinter();

impl TablePrinter for TextTablePrinter {
	fn print(
		&self,
		table: OutputTable,
		out: &mut dyn std::io::Write,
	) -> Result<(), std::io::Error> {
		let mut bw = BufWriter::new(out);

		let sizes = table.cols.iter().map(|c| c.max_width).collect::<Vec<_>>();

		// print headers
		write_columns(&mut bw, table.cols.iter().map(|c| c.heading), &sizes)?;
		// print --- separators
		write_columns(
			&mut bw,
			table.cols.iter().map(|c| "-".repeat(c.max_width)),
			&sizes,
		)?;
		// print each row of data
		if !table.cols.is_empty() {
			let data_len = table.cols[0].data.len();
			for i in 0..data_len {
				write_columns(&mut bw, table.cols.iter().map(|c| &c.data[i]), &sizes)?;
			}
		}

		bw.flush()
	}
}

fn write_columns<T>(
	mut w: impl Write,
	cols: impl Iterator<Item = T>,
	sizes: &[usize],
) -> Result<(), std::io::Error>
where
	T: Display,
{
	w.write_all(b"|")?;
	for (i, col) in cols.enumerate() {
		write!(w, " {:width$} |", col, width = sizes[i])?;
	}
	w.write_all(b"\r\n")
}
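
// Sample rendering with illustrative data:
// | ID | Name |
// | -- | ---- |
// | 1  | a    |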

427 cli/src/commands/tunnels.rs Normal file
@@ -0,0 +1,427 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use async_trait::async_trait;
use sha2::{Digest, Sha256};
use std::{str::FromStr, time::Duration};
use sysinfo::Pid;

use super::{
	args::{
		AuthProvider, CliCore, ExistingTunnelArgs, TunnelRenameArgs, TunnelServeArgs,
		TunnelServiceSubCommands, TunnelUserSubCommands,
	},
	CommandContext,
};

use crate::{
	auth::Auth,
	constants::{APPLICATION_NAME, TUNNEL_CLI_LOCK_NAME, TUNNEL_SERVICE_LOCK_NAME},
	log,
	state::LauncherPaths,
	tunnels::{
		code_server::CodeServerArgs,
		create_service_manager, dev_tunnels, legal,
		paths::get_all_servers,
		protocol, serve_stream,
		shutdown_signal::ShutdownRequest,
		singleton_client::do_single_rpc_call,
		singleton_server::{
			make_singleton_server, start_singleton_server, BroadcastLogSink, SingletonServerArgs,
		},
		Next, ServeStreamParams, ServiceContainer, ServiceManager,
	},
	util::{
		app_lock::AppMutex,
		errors::{wrap, AnyError, CodeError},
		prereqs::PreReqChecker,
	},
};
use crate::{
	singleton::{acquire_singleton, SingletonConnection},
	tunnels::{
		dev_tunnels::ActiveTunnel,
		singleton_client::{start_singleton_client, SingletonClientArgs},
		SleepInhibitor,
	},
};

impl From<AuthProvider> for crate::auth::AuthProvider {
	fn from(auth_provider: AuthProvider) -> Self {
		match auth_provider {
			AuthProvider::Github => crate::auth::AuthProvider::Github,
			AuthProvider::Microsoft => crate::auth::AuthProvider::Microsoft,
		}
	}
}

impl From<ExistingTunnelArgs> for Option<dev_tunnels::ExistingTunnel> {
	fn from(d: ExistingTunnelArgs) -> Option<dev_tunnels::ExistingTunnel> {
		if let (Some(tunnel_id), Some(tunnel_name), Some(cluster), Some(host_token)) =
			(d.tunnel_id, d.tunnel_name, d.cluster, d.host_token)
		{
			Some(dev_tunnels::ExistingTunnel {
				tunnel_id,
				tunnel_name,
				host_token,
				cluster,
			})
		} else {
			None
		}
	}
}
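
// All four flags (--tunnel-id, --tunnel-name, --cluster, --host-token) must be
// supplied for an existing tunnel to be used; otherwise this yields None and
// `serve` starts a launcher tunnel instead.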

struct TunnelServiceContainer {
	args: CliCore,
}

impl TunnelServiceContainer {
	fn new(args: CliCore) -> Self {
		Self { args }
	}
}

#[async_trait]
impl ServiceContainer for TunnelServiceContainer {
	async fn run_service(
		&mut self,
		log: log::Logger,
		launcher_paths: LauncherPaths,
	) -> Result<(), AnyError> {
		let csa = (&self.args).into();
		serve_with_csa(
			launcher_paths,
			log,
			TunnelServeArgs {
				random_name: true, // avoid prompting
				..Default::default()
			},
			csa,
			TUNNEL_SERVICE_LOCK_NAME,
		)
		.await?;
		Ok(())
	}
}

pub async fn command_shell(ctx: CommandContext) -> Result<i32, AnyError> {
	let platform = PreReqChecker::new().verify().await?;
	serve_stream(
		tokio::io::stdin(),
		tokio::io::stderr(),
		ServeStreamParams {
			log: ctx.log,
			launcher_paths: ctx.paths,
			platform,
			requires_auth: true,
			exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]),
			code_server_args: (&ctx.args).into(),
		},
	)
	.await;

	Ok(0)
}

pub async fn service(
	ctx: CommandContext,
	service_args: TunnelServiceSubCommands,
) -> Result<i32, AnyError> {
	let manager = create_service_manager(ctx.log.clone(), &ctx.paths);
	match service_args {
		TunnelServiceSubCommands::Install(args) => {
			// ensure logged in, otherwise subsequent serving will fail
			Auth::new(&ctx.paths, ctx.log.clone())
				.get_credential()
				.await?;

			// likewise for license consent
			legal::require_consent(&ctx.paths, args.accept_server_license_terms)?;

			let current_exe =
				std::env::current_exe().map_err(|e| wrap(e, "could not get current exe"))?;

			manager
				.register(
					current_exe,
					&[
						"--verbose",
						"--cli-data-dir",
						ctx.paths.root().as_os_str().to_string_lossy().as_ref(),
						"tunnel",
						"service",
						"internal-run",
					],
				)
				.await?;
			ctx.log.result(format!("Service successfully installed! You can use `{} tunnel service log` to monitor it, and `{} tunnel service uninstall` to remove it.", APPLICATION_NAME, APPLICATION_NAME));
		}
		TunnelServiceSubCommands::Uninstall => {
			manager.unregister().await?;
		}
		TunnelServiceSubCommands::Log => {
			manager.show_logs().await?;
		}
		TunnelServiceSubCommands::InternalRun => {
			manager
				.run(ctx.paths.clone(), TunnelServiceContainer::new(ctx.args))
				.await?;
		}
	}

	Ok(0)
}

pub async fn user(ctx: CommandContext, user_args: TunnelUserSubCommands) -> Result<i32, AnyError> {
	let auth = Auth::new(&ctx.paths, ctx.log.clone());
	match user_args {
		TunnelUserSubCommands::Login(login_args) => {
			auth.login(
				login_args.provider.map(|p| p.into()),
				login_args.access_token.to_owned(),
			)
			.await?;
		}
		TunnelUserSubCommands::Logout => {
			auth.clear_credentials()?;
		}
		TunnelUserSubCommands::Show => {
			if let Ok(Some(_)) = auth.get_current_credential() {
				ctx.log.result("logged in");
			} else {
				ctx.log.result("not logged in");
				return Ok(1);
			}
		}
	}

	Ok(0)
}

/// Renames the tunnel used by this gateway.
pub async fn rename(ctx: CommandContext, rename_args: TunnelRenameArgs) -> Result<i32, AnyError> {
	let auth = Auth::new(&ctx.paths, ctx.log.clone());
	let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths);
	dt.rename_tunnel(&rename_args.name).await?;
	ctx.log.result(format!(
		"Successfully renamed this gateway to {}",
		&rename_args.name
	));

	Ok(0)
}

/// Remove the tunnel used by this gateway, if any.
pub async fn unregister(ctx: CommandContext) -> Result<i32, AnyError> {
	let auth = Auth::new(&ctx.paths, ctx.log.clone());
	let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths);
	dt.remove_tunnel().await?;
	Ok(0)
}

pub async fn restart(ctx: CommandContext) -> Result<i32, AnyError> {
	do_single_rpc_call::<_, ()>(
		&ctx.paths.tunnel_lockfile(),
		ctx.log,
		protocol::singleton::METHOD_RESTART,
		protocol::EmptyObject {},
	)
	.await
	.map(|_| 0)
	.map_err(|e| e.into())
}

pub async fn kill(ctx: CommandContext) -> Result<i32, AnyError> {
	do_single_rpc_call::<_, ()>(
		&ctx.paths.tunnel_lockfile(),
		ctx.log,
		protocol::singleton::METHOD_SHUTDOWN,
		protocol::EmptyObject {},
	)
	.await
	.map(|_| 0)
	.map_err(|e| e.into())
}

pub async fn status(ctx: CommandContext) -> Result<i32, AnyError> {
	let status = do_single_rpc_call::<_, protocol::singleton::Status>(
		&ctx.paths.tunnel_lockfile(),
		ctx.log.clone(),
		protocol::singleton::METHOD_STATUS,
		protocol::EmptyObject {},
	)
	.await;

	match status {
		Err(CodeError::NoRunningTunnel) => {
			ctx.log.result(CodeError::NoRunningTunnel.to_string());
			Ok(1)
		}
		Err(e) => Err(e.into()),
		Ok(s) => {
			ctx.log.result(serde_json::to_string(&s).unwrap());
			Ok(0)
		}
	}
}

/// Removes unused servers.
pub async fn prune(ctx: CommandContext) -> Result<i32, AnyError> {
	get_all_servers(&ctx.paths)
		.into_iter()
		.map(|s| s.server_paths(&ctx.paths))
		.filter(|s| s.get_running_pid().is_none())
		.try_for_each(|s| {
			ctx.log
				.result(format!("Deleted {}", s.server_dir.display()));
			s.delete()
		})
		.map_err(AnyError::from)?;

	ctx.log.result("Successfully removed all unused servers");

	Ok(0)
}

/// Starts the gateway server.
pub async fn serve(ctx: CommandContext, gateway_args: TunnelServeArgs) -> Result<i32, AnyError> {
	let CommandContext {
		log, paths, args, ..
	} = ctx;

	let no_sleep = match gateway_args.no_sleep.then(SleepInhibitor::new) {
		Some(i) => match i.await {
			Ok(i) => Some(i),
			Err(e) => {
				warning!(log, "Could not inhibit sleep: {}", e);
				None
			}
		},
		None => None,
	};

	legal::require_consent(&paths, gateway_args.accept_server_license_terms)?;

	let csa = (&args).into();
	let result = serve_with_csa(paths, log, gateway_args, csa, TUNNEL_CLI_LOCK_NAME).await;
	drop(no_sleep);

	result
}

fn get_connection_token(tunnel: &ActiveTunnel) -> String {
	let mut hash = Sha256::new();
	hash.update(tunnel.id.as_bytes());
	let result = hash.finalize();
	base64::encode_config(result, base64::URL_SAFE_NO_PAD)
}
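
// The connection token is a URL-safe, unpadded base64 encoding of the SHA-256
// digest of the tunnel ID; see protocol version 3 in constants.rs.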

async fn serve_with_csa(
	paths: LauncherPaths,
	mut log: log::Logger,
	gateway_args: TunnelServeArgs,
	mut csa: CodeServerArgs,
	app_mutex_name: Option<&'static str>,
) -> Result<i32, AnyError> {
	let log_broadcast = BroadcastLogSink::new();
	log = log.tee(log_broadcast.clone());
	log::install_global_logger(log.clone()); // re-install so that library logs are captured

	// Intentionally read before starting the server. If the server updated and
	// respawn is requested, the old binary will get renamed, and then
	// current_exe will point to the wrong path.
	let current_exe = std::env::current_exe().unwrap();

	let mut vec = vec![
		ShutdownRequest::CtrlC,
		ShutdownRequest::ExeUninstalled(current_exe.to_owned()),
	];
	if let Some(p) = gateway_args
		.parent_process_id
		.and_then(|p| Pid::from_str(&p).ok())
	{
		vec.push(ShutdownRequest::ParentProcessKilled(p));
	}
	let shutdown = ShutdownRequest::create_rx(vec);

	let server = loop {
		if shutdown.is_open() {
			return Ok(0);
		}

		match acquire_singleton(paths.tunnel_lockfile()).await {
			Ok(SingletonConnection::Client(stream)) => {
				debug!(log, "starting as client to singleton");
				let should_exit = start_singleton_client(SingletonClientArgs {
					log: log.clone(),
					shutdown: shutdown.clone(),
					stream,
				})
				.await;
				if should_exit {
					return Ok(0);
				}
			}
			Ok(SingletonConnection::Singleton(server)) => break server,
			Err(e) => {
				warning!(log, "error accessing singleton, retrying: {}", e);
				tokio::time::sleep(Duration::from_secs(2)).await
			}
		}
	};

	debug!(log, "starting as new singleton");

	let mut server =
		make_singleton_server(log_broadcast.clone(), log.clone(), server, shutdown.clone());
	let platform = spanf!(log, log.span("prereq"), PreReqChecker::new().verify())?;
	let _lock = app_mutex_name.map(AppMutex::new);

	let auth = Auth::new(&paths, log.clone());
	let mut dt = dev_tunnels::DevTunnels::new(&log, auth, &paths);
	loop {
		let tunnel = if let Some(d) = gateway_args.tunnel.clone().into() {
			dt.start_existing_tunnel(d).await
		} else {
			dt.start_new_launcher_tunnel(gateway_args.name.as_deref(), gateway_args.random_name)
				.await
		}?;

		csa.connection_token = Some(get_connection_token(&tunnel));

		let mut r = start_singleton_server(SingletonServerArgs {
			log: log.clone(),
			tunnel,
			paths: &paths,
			code_server_args: &csa,
			platform,
			log_broadcast: &log_broadcast,
			shutdown: shutdown.clone(),
			server: &mut server,
		})
		.await?;
		r.tunnel.close().await.ok();

		match r.next {
			Next::Respawn => {
				warning!(log, "respawn requested, starting new server");
				// reuse current args, but specify no-forward since tunnels will
				// already be running in this process, and we cannot do a login
				let args = std::env::args().skip(1).collect::<Vec<String>>();
				let exit = std::process::Command::new(current_exe)
					.args(args)
					.spawn()
					.map_err(|e| wrap(e, "error respawning after update"))?
					.wait()
					.map_err(|e| wrap(e, "error waiting for child"))?;

				return Ok(exit.code().unwrap_or(1));
			}
			Next::Exit => return Ok(0),
			Next::Restart => continue,
		}
	}
}

50 cli/src/commands/update.rs Normal file
@@ -0,0 +1,50 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::sync::Arc;

use indicatif::ProgressBar;

use crate::{
	constants::PRODUCT_NAME_LONG,
	self_update::SelfUpdate,
	update_service::UpdateService,
	util::{errors::AnyError, http::ReqwestSimpleHttp, input::ProgressBarReporter},
};

use super::{args::StandaloneUpdateArgs, CommandContext};

pub async fn update(ctx: CommandContext, args: StandaloneUpdateArgs) -> Result<i32, AnyError> {
	let update_service = UpdateService::new(
		ctx.log.clone(),
		Arc::new(ReqwestSimpleHttp::with_client(ctx.http.clone())),
	);
	let update_service = SelfUpdate::new(&update_service)?;

	let current_version = update_service.get_current_release().await?;
	if update_service.is_up_to_date_with(&current_version) {
		ctx.log.result(format!(
			"{} is already up to date ({})",
			PRODUCT_NAME_LONG, current_version.commit
		));
		return Ok(1);
	}

	if args.check {
		ctx.log
			.result(format!("Update to {} is available", current_version));
		return Ok(0);
	}

	let pb = ProgressBar::new(1);
	pb.set_message("Downloading...");
	update_service
		.do_update(&current_version, ProgressBarReporter::from(pb))
		.await?;
	ctx.log
		.result(format!("Successfully updated to {}", current_version));

	Ok(0)
}

62 cli/src/commands/version.rs Normal file
@@ -0,0 +1,62 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::path::{Path, PathBuf};

use crate::{
	desktop::{prompt_to_install, CodeVersionManager, RequestedVersion},
	log,
	util::{
		errors::{AnyError, NoInstallInUserProvidedPath},
		prereqs::PreReqChecker,
	},
};

use super::{args::UseVersionArgs, CommandContext};

pub async fn switch_to(ctx: CommandContext, args: UseVersionArgs) -> Result<i32, AnyError> {
	let platform = PreReqChecker::new().verify().await?;
	let vm = CodeVersionManager::new(ctx.log.clone(), &ctx.paths, platform);
	let version = RequestedVersion::try_from(args.name.as_str())?;

	let maybe_path = match args.install_dir {
		Some(d) => Some(
			CodeVersionManager::get_entrypoint_for_install_dir(&PathBuf::from(&d))
				.await
				.ok_or(NoInstallInUserProvidedPath(d))?,
		),
		None => vm.try_get_entrypoint(&version).await,
	};

	match maybe_path {
		Some(p) => {
			vm.set_preferred_version(version.clone(), p.clone()).await?;
			print_now_using(&ctx.log, &version, &p);
			Ok(0)
		}
		None => {
			prompt_to_install(&version);
			Ok(1)
		}
	}
}
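
// e.g. `code version use stable` or `code version use /usr/share/code` (path
// illustrative); see RequestedVersion in version_manager.rs for the accepted forms.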

pub async fn show(ctx: CommandContext) -> Result<i32, AnyError> {
	let platform = PreReqChecker::new().verify().await?;
	let vm = CodeVersionManager::new(ctx.log.clone(), &ctx.paths, platform);

	let version = vm.get_preferred_version();
	println!("Current quality: {}", version);
	match vm.try_get_entrypoint(&version).await {
		Some(p) => println!("Installation path: {}", p.display()),
		None => println!("No existing installation found"),
	}

	Ok(0)
}

fn print_now_using(log: &log::Logger, version: &RequestedVersion, path: &Path) {
	log.result(format!("Now using {} from {}", version, path.display()));
}

129 cli/src/constants.rs Normal file
@@ -0,0 +1,129 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::collections::HashMap;

use const_format::concatcp;
use lazy_static::lazy_static;

use crate::options::Quality;

pub const CONTROL_PORT: u16 = 31545;

/// Protocol version sent to clients. This can be used to indicate new or
/// changed capabilities that clients may wish to leverage.
/// 1 - Initial protocol version
/// 2 - Addition of `serve.compressed` property to control whether servermsg's
///     are compressed bidirectionally.
/// 3 - The server's connection token is set to a SHA256 hash of the tunnel ID
/// 4 - The server's msgpack messages are no longer length-prefixed
pub const PROTOCOL_VERSION: u32 = 4;

/// Prefix for the tunnel tag that includes the version.
pub const PROTOCOL_VERSION_TAG_PREFIX: &str = "protocolv";
/// Tag for the current protocol version, which is included in dev tunnels.
pub const PROTOCOL_VERSION_TAG: &str = concatcp!("protocolv", PROTOCOL_VERSION);

pub const VSCODE_CLI_VERSION: Option<&'static str> = option_env!("VSCODE_CLI_VERSION");
pub const VSCODE_CLI_AI_KEY: Option<&'static str> = option_env!("VSCODE_CLI_AI_KEY");
pub const VSCODE_CLI_AI_ENDPOINT: Option<&'static str> = option_env!("VSCODE_CLI_AI_ENDPOINT");
pub const VSCODE_CLI_QUALITY: Option<&'static str> = option_env!("VSCODE_CLI_QUALITY");
pub const DOCUMENTATION_URL: Option<&'static str> = option_env!("VSCODE_CLI_DOCUMENTATION_URL");
pub const VSCODE_CLI_COMMIT: Option<&'static str> = option_env!("VSCODE_CLI_COMMIT");
pub const VSCODE_CLI_UPDATE_ENDPOINT: Option<&'static str> =
	option_env!("VSCODE_CLI_UPDATE_ENDPOINT");

/// Windows lock name for the running tunnel service. Used by the setup script
/// to detect a tunnel process. See #179265.
pub const TUNNEL_SERVICE_LOCK_NAME: Option<&'static str> =
	option_env!("VSCODE_CLI_TUNNEL_SERVICE_MUTEX");

/// Windows lock name for the running tunnel without a service. Used by the setup
/// script to detect a tunnel process. See #179265.
pub const TUNNEL_CLI_LOCK_NAME: Option<&'static str> = option_env!("VSCODE_CLI_TUNNEL_CLI_MUTEX");

pub const TUNNEL_SERVICE_USER_AGENT_ENV_VAR: &str = "TUNNEL_SERVICE_USER_AGENT";

/// Application name as it appears on the CLI.
pub const APPLICATION_NAME: &str = match option_env!("VSCODE_CLI_APPLICATION_NAME") {
	Some(n) => n,
	None => "code",
};

/// Full name of the product with its version.
pub const PRODUCT_NAME_LONG: &str = match option_env!("VSCODE_CLI_NAME_LONG") {
	Some(n) => n,
	None => "Code - OSS",
};

/// Name of the application without quality information.
pub const QUALITYLESS_PRODUCT_NAME: &str = match option_env!("VSCODE_CLI_QUALITYLESS_PRODUCT_NAME")
{
	Some(n) => n,
	None => "Code",
};

/// Name of the server without quality information.
pub const QUALITYLESS_SERVER_NAME: &str = concatcp!(QUALITYLESS_PRODUCT_NAME, " Server");

/// Web URL the editor is hosted at. For VS Code, this is vscode.dev.
pub const EDITOR_WEB_URL: Option<&'static str> = option_env!("VSCODE_CLI_EDITOR_WEB_URL");

/// Name shown in places where we need to tell a user what a process is, e.g. in sleep inhibition.
pub const TUNNEL_ACTIVITY_NAME: &str = concatcp!(PRODUCT_NAME_LONG, " Tunnel");

const NONINTERACTIVE_VAR: &str = "VSCODE_CLI_NONINTERACTIVE";

/// Default parent directory for the CLI data directory.
pub const DEFAULT_DATA_PARENT_DIR: &str = match option_env!("VSCODE_CLI_DEFAULT_PARENT_DATA_DIR") {
	Some(n) => n,
	None => ".vscode-oss",
};

pub fn get_default_user_agent() -> String {
	format!(
		"vscode-server-launcher/{}",
		VSCODE_CLI_VERSION.unwrap_or("dev")
	)
}
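
// e.g. "vscode-server-launcher/1.79.2" when VSCODE_CLI_VERSION is set at
// build time, or "vscode-server-launcher/dev" in development builds
// (version number illustrative).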

const NO_COLOR_ENV: &str = "NO_COLOR";

lazy_static! {
	pub static ref TUNNEL_SERVICE_USER_AGENT: String =
		match std::env::var(TUNNEL_SERVICE_USER_AGENT_ENV_VAR) {
			Ok(ua) if !ua.is_empty() => format!("{} {}", ua, get_default_user_agent()),
			_ => get_default_user_agent(),
		};

	/// Map of quality names to arrays of app IDs used for them, for example, `{"stable":["ABC123"]}`
	pub static ref WIN32_APP_IDS: Option<HashMap<Quality, Vec<String>>> =
		option_env!("VSCODE_CLI_WIN32_APP_IDS").and_then(|s| serde_json::from_str(s).unwrap());

	/// Map of quality names to desktop download URIs
	pub static ref QUALITY_DOWNLOAD_URIS: Option<HashMap<Quality, String>> =
		option_env!("VSCODE_CLI_QUALITY_DOWNLOAD_URIS").and_then(|s| serde_json::from_str(s).unwrap());

	/// Map of qualities to the long name of the app in that quality
	pub static ref PRODUCT_NAME_LONG_MAP: Option<HashMap<Quality, String>> =
		option_env!("VSCODE_CLI_NAME_LONG_MAP").and_then(|s| serde_json::from_str(s).unwrap());

	/// Map of qualities to the application name
	pub static ref APPLICATION_NAME_MAP: Option<HashMap<Quality, String>> =
		option_env!("VSCODE_CLI_APPLICATION_NAME_MAP").and_then(|s| serde_json::from_str(s).unwrap());

	/// Map of qualities to the server name
	pub static ref SERVER_NAME_MAP: Option<HashMap<Quality, String>> =
		option_env!("VSCODE_CLI_SERVER_NAME_MAP").and_then(|s| serde_json::from_str(s).unwrap());

	/// Whether stdin is attached to a TTY.
	pub static ref IS_A_TTY: bool = atty::is(atty::Stream::Stdin);

	/// Whether colors should be used in CLI output.
	pub static ref COLORS_ENABLED: bool = *IS_A_TTY && std::env::var(NO_COLOR_ENV).is_err();

	/// Whether interactive i/o is allowed in the current CLI.
	pub static ref IS_INTERACTIVE_CLI: bool = *IS_A_TTY && std::env::var(NONINTERACTIVE_VAR).is_err();
}

8 cli/src/desktop.rs Normal file
@@ -0,0 +1,8 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

mod version_manager;

pub use version_manager::{prompt_to_install, CodeVersionManager, RequestedVersion};

617 cli/src/desktop/version_manager.rs Normal file
@@ -0,0 +1,617 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
    ffi::OsString,
    fmt, io,
    path::{Path, PathBuf},
};

use lazy_static::lazy_static;
use regex::Regex;
use serde::{Deserialize, Serialize};

use crate::{
    constants::{QUALITYLESS_PRODUCT_NAME, QUALITY_DOWNLOAD_URIS},
    log,
    options::{self, Quality},
    state::{LauncherPaths, PersistedState},
    update_service::Platform,
    util::errors::{AnyError, InvalidRequestedVersion},
};

/// Parsed instance that a user can request.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(tag = "t", content = "c")]
pub enum RequestedVersion {
    Quality(options::Quality),
    Version {
        version: String,
        quality: options::Quality,
    },
    Commit {
        commit: String,
        quality: options::Quality,
    },
    Path(String),
}

lazy_static! {
    static ref SEMVER_RE: Regex = Regex::new(r"^\d+\.\d+\.\d+(-insider)?$").unwrap();
    static ref COMMIT_RE: Regex = Regex::new(r"^[a-z]+/[a-e0-f]{40}$").unwrap();
}

impl RequestedVersion {
    pub fn get_command(&self) -> String {
        match self {
            RequestedVersion::Quality(quality) => {
                format!("code version use {}", quality.get_machine_name())
            }
            RequestedVersion::Version { version, .. } => {
                format!("code version use {}", version)
            }
            RequestedVersion::Commit { commit, quality } => {
                format!("code version use {}/{}", quality.get_machine_name(), commit)
            }
            RequestedVersion::Path(path) => {
                format!("code version use {}", path)
            }
        }
    }
}

impl std::fmt::Display for RequestedVersion {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            RequestedVersion::Quality(quality) => write!(f, "{}", quality.get_capitalized_name()),
            RequestedVersion::Version { version, .. } => {
                write!(f, "{}", version)
            }
            RequestedVersion::Commit { commit, quality } => {
                write!(f, "{}/{}", quality, commit)
            }
            RequestedVersion::Path(path) => write!(f, "{}", path),
        }
    }
}

impl TryFrom<&str> for RequestedVersion {
    type Error = InvalidRequestedVersion;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        if let Ok(quality) = options::Quality::try_from(s) {
            return Ok(RequestedVersion::Quality(quality));
        }

        if SEMVER_RE.is_match(s) {
            return Ok(RequestedVersion::Version {
                quality: if s.ends_with("-insider") {
                    options::Quality::Insiders
                } else {
                    options::Quality::Stable
                },
                version: s.to_string(),
            });
        }

        if Path::is_absolute(&PathBuf::from(s)) {
            return Ok(RequestedVersion::Path(s.to_string()));
        }

        if COMMIT_RE.is_match(s) {
            let idx = s.find('/').expect("expected a /");
            if let Ok(quality) = options::Quality::try_from(&s[0..idx]) {
                return Ok(RequestedVersion::Commit {
                    commit: s[idx + 1..].to_string(),
                    quality,
                });
            }
        }

        Err(InvalidRequestedVersion())
    }
}
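
// Illustrative examples of accepted inputs, mirroring the parsing above and
// the tests at the bottom of this file (the commit hash is hypothetical):
//
//   "stable"            -> RequestedVersion::Quality(Quality::Stable)
//   "1.2.3"             -> RequestedVersion::Version { version: "1.2.3", quality: Stable }
//   "1.2.3-insider"     -> RequestedVersion::Version { version: "1.2.3-insider", quality: Insiders }
//   "insiders/<40-hex>" -> RequestedVersion::Commit { commit: "<40-hex>", quality: Insiders }
//   "/abs/path"         -> RequestedVersion::Path("/abs/path")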

#[derive(Serialize, Deserialize, Clone, Default)]
struct Stored {
    /// Map of requested versions to locations where those versions are installed.
    versions: Vec<(RequestedVersion, OsString)>,
    current: usize,
}

pub struct CodeVersionManager {
    state: PersistedState<Stored>,
    log: log::Logger,
}

impl CodeVersionManager {
    pub fn new(log: log::Logger, lp: &LauncherPaths, _platform: Platform) -> Self {
        CodeVersionManager {
            log,
            state: PersistedState::new(lp.root().join("versions.json")),
        }
    }

    /// Tries to find the binary entrypoint for VS Code installed in the path.
    pub async fn get_entrypoint_for_install_dir(path: &Path) -> Option<PathBuf> {
        use tokio::sync::mpsc;

        // Check whether the user is supplying a path to the CLI directly (e.g. #164622)
        if let Ok(true) = path.metadata().map(|m| m.is_file()) {
            let result = std::process::Command::new(path)
                .args(["--version"])
                .output()
                .map(|o| o.status.success());

            if let Ok(true) = result {
                return Some(path.to_owned());
            }
        }

        let (tx, mut rx) = mpsc::channel(1);

        // Look for all the possible paths in parallel
        for entry in DESKTOP_CLI_RELATIVE_PATH.split(',') {
            let my_path = path.join(entry);
            let my_tx = tx.clone();
            tokio::spawn(async move {
                if tokio::fs::metadata(&my_path).await.is_ok() {
                    my_tx.send(my_path).await.ok();
                }
            });
        }

        drop(tx); // drop so rx gets None if no sender emits

        rx.recv().await
    }

    /// Sets the "version" as the persisted one for the user.
    pub async fn set_preferred_version(
        &self,
        version: RequestedVersion,
        path: PathBuf,
    ) -> Result<(), AnyError> {
        let mut stored = self.state.load();
        stored.current = self.store_version_path(&mut stored, version, path);
        self.state.save(stored)?;
        Ok(())
    }

    /// Stores or updates the path used for the given version. Returns the index
    /// that the path exists at.
    fn store_version_path(
        &self,
        state: &mut Stored,
        version: RequestedVersion,
        path: PathBuf,
    ) -> usize {
        if let Some(i) = state.versions.iter().position(|(v, _)| v == &version) {
            state.versions[i].1 = path.into_os_string();
            i
        } else {
            state
                .versions
                .push((version.clone(), path.into_os_string()));
            state.versions.len() - 1
        }
    }

    /// Gets the currently preferred version based on set_preferred_version.
    pub fn get_preferred_version(&self) -> RequestedVersion {
        let stored = self.state.load();
        stored
            .versions
            .get(stored.current)
            .map(|(v, _)| v.clone())
            .unwrap_or(RequestedVersion::Quality(options::Quality::Stable))
    }

    /// Tries to get the entrypoint for the version, if one can be found.
    pub async fn try_get_entrypoint(&self, version: &RequestedVersion) -> Option<PathBuf> {
        let mut state = self.state.load();
        if let Some((_, install_path)) = state.versions.iter().find(|(v, _)| v == version) {
            let p = PathBuf::from(install_path);
            if p.exists() {
                return Some(p);
            }
        }

        // For simple quality requests, see if that's installed already on the system
        let candidates = match &version {
            RequestedVersion::Quality(q) => match detect_installed_program(&self.log, *q) {
                Ok(p) => p,
                Err(e) => {
                    warning!(self.log, "error looking up installed applications: {}", e);
                    return None;
                }
            },
            _ => return None,
        };

        let found = match candidates.into_iter().next() {
            Some(p) => p,
            None => return None,
        };

        // stash the found path for faster lookup
        self.store_version_path(&mut state, version.clone(), found.clone());
        if let Err(e) = self.state.save(state) {
            debug!(self.log, "error caching version path: {}", e);
        }

        Some(found)
    }
}
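
// Lookup order in try_get_entrypoint above: (1) the previously stored path for
// the version, if it still exists on disk; (2) for plain quality requests,
// platform detection, whose first candidate is cached back into versions.json
// for faster subsequent lookups.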

/// Shows a nice UI prompt to users asking them if they want to install the
/// requested version.
pub fn prompt_to_install(version: &RequestedVersion) {
    println!(
        "No installation of {} {} was found.",
        QUALITYLESS_PRODUCT_NAME, version
    );

    if let RequestedVersion::Quality(quality) = version {
        if let Some(uri) = QUALITY_DOWNLOAD_URIS.as_ref().and_then(|m| m.get(quality)) {
            // todo: on some platforms, we may be able to help automate installation. For example,
            // we can unzip the app ourselves on macOS, and on Windows we can download and spawn the GUI installer
            #[cfg(target_os = "linux")]
            println!("Install it from your system's package manager or {}, restart your shell, and try again.", uri);
            #[cfg(target_os = "macos")]
            println!("Download and unzip it from {} and try again.", uri);
            #[cfg(target_os = "windows")]
            println!("Install it from {} and try again.", uri);
        }
    }

    println!();
    println!("If you already installed {} and we didn't detect it, run `{} --install-dir /path/to/installation`", QUALITYLESS_PRODUCT_NAME, version.get_command());
}

#[cfg(target_os = "macos")]
fn detect_installed_program(log: &log::Logger, quality: Quality) -> io::Result<Vec<PathBuf>> {
    // easy, fast detection for where apps are usually installed
    let mut probable = PathBuf::from("/Applications");
    let app_name = quality.get_long_name();
    probable.push(format!("{}.app", app_name));
    if probable.exists() {
        probable.extend(["Contents/Resources", "app", "bin", "code"]);
        return Ok(vec![probable]);
    }

    // _Much_ slower detection using the system_profiler (~10s for me). While the
    // profiler can output nicely structured plist XML, pulling in an XML parser
    // just for this is overkill. The default output looks something like...
    //
    // Visual Studio Code - Exploration 2:
    //
    //   Version: 1.73.0-exploration
    //   Obtained from: Identified Developer
    //   Last Modified: 9/23/22, 10:16 AM
    //   Kind: Intel
    //   Signed by: Developer ID Application: Microsoft Corporation (UBF8T346G9), Developer ID Certification Authority, Apple Root CA
    //   Location: /Users/connor/Downloads/Visual Studio Code - Exploration 2.app
    //
    // So, use a simple state machine that looks for the first line, and then for
    // the `Location:` line for the path.
    info!(log, "Searching for installations on your machine, this is done once and will take about 10 seconds...");

    let stdout = std::process::Command::new("system_profiler")
        .args(["SPApplicationsDataType", "-detailLevel", "mini"])
        .output()?
        .stdout;

    enum State {
        LookingForName,
        LookingForLocation,
    }

    let mut state = State::LookingForName;
    let mut output: Vec<PathBuf> = vec![];
    const LOCATION_PREFIX: &str = "Location:";
    for mut line in String::from_utf8_lossy(&stdout).lines() {
        line = line.trim();
        match state {
            State::LookingForName => {
                if line.starts_with(app_name) && line.ends_with(':') {
                    state = State::LookingForLocation;
                }
            }
            State::LookingForLocation => {
                if let Some(suffix) = line.strip_prefix(LOCATION_PREFIX) {
                    output.push(
                        [suffix.trim(), "Contents/Resources", "app", "bin", "code"]
                            .iter()
                            .collect(),
                    );
                    state = State::LookingForName;
                }
            }
        }
    }

    // Sort shorter paths to the front, preferring "more global" installs, and
    // incidentally preferring local installs over Parallels 'installs'.
    output.sort_by_key(|a| a.as_os_str().len());

    Ok(output)
}

#[cfg(windows)]
fn detect_installed_program(_log: &log::Logger, quality: Quality) -> io::Result<Vec<PathBuf>> {
    use crate::constants::WIN32_APP_IDS;
    use winreg::enums::{HKEY_CURRENT_USER, HKEY_LOCAL_MACHINE};
    use winreg::RegKey;

    let mut output: Vec<PathBuf> = vec![];
    let app_ids = match WIN32_APP_IDS.as_ref().and_then(|m| m.get(&quality)) {
        Some(ids) => ids,
        None => return Ok(output),
    };

    let scopes = [
        (
            HKEY_LOCAL_MACHINE,
            "SOFTWARE\\Wow6432Node\\Microsoft\\Windows\\CurrentVersion\\Uninstall",
        ),
        (
            HKEY_LOCAL_MACHINE,
            "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Uninstall",
        ),
        (
            HKEY_CURRENT_USER,
            "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Uninstall",
        ),
    ];

    for (scope, key) in scopes {
        let cur_ver = match RegKey::predef(scope).open_subkey(key) {
            Ok(k) => k,
            Err(_) => continue,
        };

        for key in cur_ver.enum_keys().flatten() {
            if app_ids.iter().any(|id| key.contains(id)) {
                let sk = cur_ver.open_subkey(&key)?;
                if let Ok(location) = sk.get_value::<String, _>("InstallLocation") {
                    output.push(
                        [
                            location.as_str(),
                            "bin",
                            &format!("{}.cmd", quality.get_application_name()),
                        ]
                        .iter()
                        .collect(),
                    )
                }
            }
        }
    }

    Ok(output)
}

// Looks for the given binary name in the PATH, returning all candidate matches.
// Based on https://github.dev/microsoft/vscode-js-debug/blob/7594d05518df6700df51771895fcad0ddc7f92f9/src/common/pathUtils.ts#L15
#[cfg(target_os = "linux")]
fn detect_installed_program(log: &log::Logger, quality: Quality) -> io::Result<Vec<PathBuf>> {
    let path = match std::env::var("PATH") {
        Ok(p) => p,
        Err(e) => {
            info!(log, "PATH is empty ({}), skipping detection", e);
            return Ok(vec![]);
        }
    };

    let name = quality.get_application_name();
    let current_exe = std::env::current_exe().expect("expected to read current exe");
    let mut output = vec![];
    for dir in path.split(':') {
        let target: PathBuf = [dir, name].iter().collect();
        match std::fs::canonicalize(&target) {
            Ok(m) if m == current_exe => continue,
            Ok(_) => {}
            Err(_) => continue,
        };

        // note: intentionally store the non-canonicalized version, since if it's a
        // symlink, (1) it's probably desired to use it and (2) resolving the link
        // breaks snap installations.
        output.push(target);
    }

    Ok(output)
}

const DESKTOP_CLI_RELATIVE_PATH: &str = if cfg!(target_os = "macos") {
    "Contents/Resources/app/bin/code"
} else if cfg!(target_os = "windows") {
    "bin/code.cmd,bin/code-insiders.cmd,bin/code-exploration.cmd"
} else {
    "bin/code,bin/code-insiders,bin/code-exploration"
};

#[cfg(test)]
mod tests {
    use std::{
        fs::{create_dir_all, File},
        io::Write,
    };

    use super::*;

    fn make_fake_vscode_install(path: &Path) {
        let bin = DESKTOP_CLI_RELATIVE_PATH
            .split(',')
            .next()
            .expect("expected exe path");

        let binary_file_path = path.join(bin);
        let parent_dir_path = binary_file_path.parent().expect("expected parent path");

        create_dir_all(parent_dir_path).expect("expected to create parent dir");

        let mut binary_file = File::create(binary_file_path).expect("expected to make file");
        binary_file
            .write_all(b"")
            .expect("expected to write binary");
    }

    fn make_multiple_vscode_install() -> tempfile::TempDir {
        let dir = tempfile::tempdir().expect("expected to make temp dir");
        make_fake_vscode_install(&dir.path().join("desktop/stable"));
        make_fake_vscode_install(&dir.path().join("desktop/1.68.2"));
        dir
    }

    #[test]
    fn test_detect_installed_program() {
        // developers can run this test and debug output manually; VS Code will not
        // be installed in CI, so the test only makes sure it doesn't error out
        let result = detect_installed_program(&log::Logger::test(), Quality::Insiders);
        println!("result: {:?}", result);
        assert!(result.is_ok());
    }

    #[test]
    fn test_requested_version_parses() {
        assert_eq!(
            RequestedVersion::try_from("1.2.3").unwrap(),
            RequestedVersion::Version {
                quality: options::Quality::Stable,
                version: "1.2.3".to_string(),
            }
        );

        assert_eq!(
            RequestedVersion::try_from("1.2.3-insider").unwrap(),
            RequestedVersion::Version {
                quality: options::Quality::Insiders,
                version: "1.2.3-insider".to_string(),
            }
        );

        assert_eq!(
            RequestedVersion::try_from("stable").unwrap(),
            RequestedVersion::Quality(options::Quality::Stable)
        );

        assert_eq!(
            RequestedVersion::try_from("insiders").unwrap(),
            RequestedVersion::Quality(options::Quality::Insiders)
        );

        assert_eq!(
            RequestedVersion::try_from("insiders/92fd228156aafeb326b23f6604028d342152313b")
                .unwrap(),
            RequestedVersion::Commit {
                commit: "92fd228156aafeb326b23f6604028d342152313b".to_string(),
                quality: options::Quality::Insiders
            }
        );

        assert_eq!(
            RequestedVersion::try_from("stable/92fd228156aafeb326b23f6604028d342152313b").unwrap(),
            RequestedVersion::Commit {
                commit: "92fd228156aafeb326b23f6604028d342152313b".to_string(),
                quality: options::Quality::Stable
            }
        );

        let exe = std::env::current_exe()
            .expect("expected to get exe")
            .to_string_lossy()
            .to_string();
        assert_eq!(
            RequestedVersion::try_from(exe.as_str()).unwrap(),
            RequestedVersion::Path(exe),
        );
    }

    #[tokio::test]
    async fn test_set_preferred_version() {
        let dir = make_multiple_vscode_install();
        let lp = LauncherPaths::new_without_replacements(dir.path().to_owned());
        let vm1 = CodeVersionManager::new(log::Logger::test(), &lp, Platform::LinuxARM64);

        assert_eq!(
            vm1.get_preferred_version(),
            RequestedVersion::Quality(options::Quality::Stable)
        );
        vm1.set_preferred_version(
            RequestedVersion::Quality(options::Quality::Exploration),
            dir.path().join("desktop/stable"),
        )
        .await
        .expect("expected to store");
        vm1.set_preferred_version(
            RequestedVersion::Quality(options::Quality::Insiders),
            dir.path().join("desktop/stable"),
        )
        .await
        .expect("expected to store");
        assert_eq!(
            vm1.get_preferred_version(),
            RequestedVersion::Quality(options::Quality::Insiders)
        );

        let vm2 = CodeVersionManager::new(log::Logger::test(), &lp, Platform::LinuxARM64);
        assert_eq!(
            vm2.get_preferred_version(),
            RequestedVersion::Quality(options::Quality::Insiders)
        );
    }

    #[tokio::test]
    async fn test_gets_entrypoint() {
        let dir = make_multiple_vscode_install();

        assert!(CodeVersionManager::get_entrypoint_for_install_dir(
            &dir.path().join("desktop").join("stable")
        )
        .await
        .is_some());

        assert!(
            CodeVersionManager::get_entrypoint_for_install_dir(&dir.path().join("invalid"))
                .await
                .is_none()
        );
    }

    #[tokio::test]
    async fn test_gets_entrypoint_as_binary() {
        let dir = tempfile::tempdir().expect("expected to make temp dir");

        #[cfg(windows)]
        let binary_file_path = {
            let path = dir.path().join("code.cmd");
            File::create(&path).expect("expected to create file");
            path
        };

        #[cfg(unix)]
        let binary_file_path = {
            use std::fs;
            use std::os::unix::fs::PermissionsExt;

            let path = dir.path().join("code");
            {
                let mut f = File::create(&path).expect("expected to create file");
                f.write_all(b"#!/bin/sh")
                    .expect("expected to write to file");
            }
            fs::set_permissions(&path, fs::Permissions::from_mode(0o777))
                .expect("expected to set permissions");
            path
        };

        assert_eq!(
            CodeVersionManager::get_entrypoint_for_install_dir(&binary_file_path).await,
            Some(binary_file_path)
        );
    }
}

cli/src/download_cache.rs (new file)
@@ -0,0 +1,119 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
    fs::create_dir_all,
    path::{Path, PathBuf},
};

use futures::Future;
use tokio::fs::remove_dir_all;

use crate::{
    state::PersistedState,
    util::errors::{wrap, AnyError, WrappedError},
};

const KEEP_LRU: usize = 5;
const STAGING_SUFFIX: &str = ".staging";

#[derive(Clone)]
pub struct DownloadCache {
    path: PathBuf,
    state: PersistedState<Vec<String>>,
}

impl DownloadCache {
    pub fn new(path: PathBuf) -> DownloadCache {
        DownloadCache {
            state: PersistedState::new(path.join("lru.json")),
            path,
        }
    }

    /// Gets the download cache path. Names of cache entries can be formed by
    /// joining them to the path.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Gets whether a cache entry with the given name already exists. Marks it
    /// as recently used if it does exist.
    pub fn exists(&self, name: &str) -> Option<PathBuf> {
        let p = self.path.join(name);
        if !p.exists() {
            return None;
        }

        let _ = self.touch(name.to_string());
        Some(p)
    }

    /// Removes the item from the cache, if it exists.
    pub fn delete(&self, name: &str) -> Result<(), WrappedError> {
        let f = self.path.join(name);
        if f.exists() {
            std::fs::remove_dir_all(f).map_err(|e| wrap(e, "error removing cached folder"))?;
        }

        self.state.update(|l| {
            l.retain(|n| n != name);
        })
    }

    /// Calls the function to create the cached folder if it doesn't exist,
    /// returning the path where the folder is. Note that the path passed to
    /// the `do_create` method is a staging path and will not be the same as the
    /// final returned path.
    pub async fn create<F, T>(
        &self,
        name: impl AsRef<str>,
        do_create: F,
    ) -> Result<PathBuf, AnyError>
    where
        F: FnOnce(PathBuf) -> T,
        T: Future<Output = Result<(), AnyError>> + Send,
    {
        let name = name.as_ref();
        let target_dir = self.path.join(name);
        if target_dir.exists() {
            return Ok(target_dir);
        }

        let temp_dir = self.path.join(format!("{}{}", name, STAGING_SUFFIX));
        let _ = remove_dir_all(&temp_dir).await; // clean up any existing staging dir

        create_dir_all(&temp_dir).map_err(|e| wrap(e, "error creating server directory"))?;
        do_create(temp_dir.clone()).await?;

        let _ = self.touch(name.to_string());
        std::fs::rename(&temp_dir, &target_dir)
            .map_err(|e| wrap(e, "error renaming downloaded server"))?;

        Ok(target_dir)
    }

    fn touch(&self, name: String) -> Result<(), AnyError> {
        self.state.update(|l| {
            if let Some(index) = l.iter().position(|s| s == &name) {
                l.remove(index);
            }
            l.insert(0, name);

            if l.len() <= KEEP_LRU {
                return;
            }

            if let Some(f) = l.last() {
                let f = self.path.join(f);
                if !f.exists() || std::fs::remove_dir_all(f).is_ok() {
                    l.pop();
                }
            }
        })?;

        Ok(())
    }
}
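
// Illustrative usage (hypothetical cache root and entry name). `create` runs
// the callback against the "<name>.staging" directory and renames it into
// place once the callback succeeds, so partially-downloaded entries are never
// observed under the final name:
//
//   let cache = DownloadCache::new(PathBuf::from("/tmp/cli-cache"));
//   let dir = cache
//       .create("server-stable-abc123", |staging| async move {
//           // ...download and extract into `staging` here...
//           Ok(())
//       })
//       .await?;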

cli/src/json_rpc.rs (new file)
@@ -0,0 +1,106 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use tokio::{
    io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader},
    pin,
    sync::mpsc,
};

use crate::{
    rpc::{self, MaybeSync, Serialization},
    util::{
        errors::InvalidRpcDataError,
        sync::{Barrier, Receivable},
    },
};
use std::io;

#[derive(Clone)]
pub struct JsonRpcSerializer {}

impl Serialization for JsonRpcSerializer {
    fn serialize(&self, value: impl serde::Serialize) -> Vec<u8> {
        let mut v = serde_json::to_vec(&value).unwrap();
        v.push(b'\n');
        v
    }

    fn deserialize<P: serde::de::DeserializeOwned>(
        &self,
        b: &[u8],
    ) -> Result<P, crate::util::errors::AnyError> {
        serde_json::from_slice(b).map_err(|e| InvalidRpcDataError(e.to_string()).into())
    }
}
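
// Framing note: serialize() emits exactly one JSON value followed by '\n', so
// peers can frame messages with a line-based reader, as start_json_rpc does
// below via read_line. Illustrative wire bytes for some request:
//
//   {"id":1,"method":"ping","params":{}}\n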

/// Creates a new RPC Builder that serializes to JSON.
#[allow(dead_code)]
pub fn new_json_rpc() -> rpc::RpcBuilder<JsonRpcSerializer> {
    rpc::RpcBuilder::new(JsonRpcSerializer {})
}

#[allow(dead_code)]
pub async fn start_json_rpc<C: Send + Sync + 'static, S: Clone>(
    dispatcher: rpc::RpcDispatcher<JsonRpcSerializer, C>,
    read: impl AsyncRead + Unpin,
    mut write: impl AsyncWrite + Unpin,
    mut msg_rx: impl Receivable<Vec<u8>>,
    mut shutdown_rx: Barrier<S>,
) -> io::Result<Option<S>> {
    let (write_tx, mut write_rx) = mpsc::channel::<Vec<u8>>(8);
    let mut read = BufReader::new(read);

    let mut read_buf = String::new();
    let shutdown_fut = shutdown_rx.wait();
    pin!(shutdown_fut);

    loop {
        tokio::select! {
            r = &mut shutdown_fut => return Ok(r.ok()),
            Some(w) = write_rx.recv() => {
                write.write_all(&w).await?;
            },
            Some(w) = msg_rx.recv_msg() => {
                write.write_all(&w).await?;
            },
            n = read.read_line(&mut read_buf) => {
                let r = match n {
                    Ok(0) => return Ok(None),
                    Ok(n) => dispatcher.dispatch(read_buf[..n].as_bytes()),
                    Err(e) => return Err(e)
                };

                read_buf.truncate(0);

                match r {
                    MaybeSync::Sync(Some(v)) => {
                        write.write_all(&v).await?;
                    },
                    MaybeSync::Sync(None) => continue,
                    MaybeSync::Future(fut) => {
                        let write_tx = write_tx.clone();
                        tokio::spawn(async move {
                            if let Some(v) = fut.await {
                                let _ = write_tx.send(v).await;
                            }
                        });
                    },
                    MaybeSync::Stream((dto, fut)) => {
                        if let Some(dto) = dto {
                            dispatcher.register_stream(write_tx.clone(), dto).await;
                        }
                        let write_tx = write_tx.clone();
                        tokio::spawn(async move {
                            if let Some(v) = fut.await {
                                let _ = write_tx.send(v).await;
                            }
                        });
                    }
                }
            }
        }
    }
}

cli/src/lib.rs (new file)
@@ -0,0 +1,26 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

// todo: we should reduce the exported surface area over time as things are
// moved into a common CLI
pub mod auth;
pub mod constants;
#[macro_use]
pub mod log;
pub mod commands;
pub mod desktop;
pub mod options;
pub mod self_update;
pub mod state;
pub mod tunnels;
pub mod update_service;
pub mod util;

mod download_cache;
mod async_pipe;
mod json_rpc;
mod msgpack_rpc;
mod rpc;
mod singleton;

cli/src/log.rs (new file)
@@ -0,0 +1,455 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use chrono::Local;
use opentelemetry::{
    sdk::trace::{Tracer, TracerProvider},
    trace::{SpanBuilder, Tracer as TraitTracer, TracerProvider as TracerProviderTrait},
};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::{
    io::Write,
    sync::atomic::{AtomicU32, Ordering},
};
use std::{path::Path, sync::Arc};

use crate::constants::COLORS_ENABLED;

static INSTANCE_COUNTER: AtomicU32 = AtomicU32::new(0);

// Gets the next incrementing number that can be used in logs
pub fn next_counter() -> u32 {
    INSTANCE_COUNTER.fetch_add(1, Ordering::SeqCst)
}

// Log level
#[derive(clap::ArgEnum, PartialEq, Eq, PartialOrd, Clone, Copy, Debug, Serialize, Deserialize)]
#[derive(Default)]
pub enum Level {
    Trace = 0,
    Debug,
    #[default]
    Info,
    Warn,
    Error,
    Critical,
    Off,
}

impl fmt::Display for Level {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Level::Critical => write!(f, "critical"),
            Level::Debug => write!(f, "debug"),
            Level::Error => write!(f, "error"),
            Level::Info => write!(f, "info"),
            Level::Off => write!(f, "off"),
            Level::Trace => write!(f, "trace"),
            Level::Warn => write!(f, "warn"),
        }
    }
}

impl Level {
    pub fn name(&self) -> Option<&str> {
        match self {
            Level::Trace => Some("trace"),
            Level::Debug => Some("debug"),
            Level::Info => Some("info"),
            Level::Warn => Some("warn"),
            Level::Error => Some("error"),
            Level::Critical => Some("critical"),
            Level::Off => None,
        }
    }

    pub fn color_code(&self) -> Option<&str> {
        if !*COLORS_ENABLED {
            return None;
        }

        match self {
            Level::Trace => None,
            Level::Debug => Some("\x1b[36m"),
            Level::Info => Some("\x1b[35m"),
            Level::Warn => Some("\x1b[33m"),
            Level::Error => Some("\x1b[31m"),
            Level::Critical => Some("\x1b[31m"),
            Level::Off => None,
        }
    }

    pub fn to_u8(self) -> u8 {
        self as u8
    }
}
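
// Because the discriminants are declared in ascending severity (Trace = 0
// through Off), the derived PartialOrd lets sinks filter with a plain
// comparison, e.g. `if level < self.level { return; }` as the sinks below do.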

pub fn new_tunnel_prefix() -> String {
    format!("[tunnel.{}]", next_counter())
}

pub fn new_code_server_prefix() -> String {
    format!("[codeserver.{}]", next_counter())
}

pub fn new_rpc_prefix() -> String {
    format!("[rpc.{}]", next_counter())
}

// Base logger implementation
#[derive(Clone)]
pub struct Logger {
    tracer: Arc<Tracer>,
    sink: Vec<Box<dyn LogSink>>,
    prefix: Option<String>,
}

// Copy trick from https://stackoverflow.com/a/30353928
pub trait LogSinkClone {
    fn clone_box(&self) -> Box<dyn LogSink>;
}

impl<T> LogSinkClone for T
where
    T: 'static + LogSink + Clone,
{
    fn clone_box(&self) -> Box<dyn LogSink> {
        Box::new(self.clone())
    }
}

pub trait LogSink: LogSinkClone + Sync + Send {
    fn write_log(&self, level: Level, prefix: &str, message: &str);
    fn write_result(&self, message: &str);
}

impl Clone for Box<dyn LogSink> {
    fn clone(&self) -> Box<dyn LogSink> {
        self.clone_box()
    }
}

/// The basic log sink that writes output to stdout, with colors when relevant.
#[derive(Clone)]
pub struct StdioLogSink {
    level: Level,
}

impl LogSink for StdioLogSink {
    fn write_log(&self, level: Level, prefix: &str, message: &str) {
        if level < self.level {
            return;
        }

        emit(level, prefix, message);
    }

    fn write_result(&self, message: &str) {
        println!("{}", message);
    }
}

#[derive(Clone)]
pub struct FileLogSink {
    level: Level,
    file: Arc<std::sync::Mutex<std::fs::File>>,
}

impl FileLogSink {
    pub fn new(level: Level, path: &Path) -> std::io::Result<Self> {
        let file = std::fs::File::create(path)?;
        Ok(Self {
            level,
            file: Arc::new(std::sync::Mutex::new(file)),
        })
    }
}

impl LogSink for FileLogSink {
    fn write_log(&self, level: Level, prefix: &str, message: &str) {
        if level < self.level {
            return;
        }

        let line = format(level, prefix, message, false);

        // ignore any errors, not much we can do if logging fails...
        self.file.lock().unwrap().write_all(line.as_bytes()).ok();
    }

    fn write_result(&self, _message: &str) {}
}

impl Logger {
    pub fn test() -> Self {
        Self {
            tracer: Arc::new(TracerProvider::builder().build().tracer("codeclitest")),
            sink: vec![],
            prefix: None,
        }
    }

    pub fn new(tracer: Tracer, level: Level) -> Self {
        Self {
            tracer: Arc::new(tracer),
            sink: vec![Box::new(StdioLogSink { level })],
            prefix: None,
        }
    }

    pub fn span(&self, name: &str) -> SpanBuilder {
        self.tracer.span_builder(format!("serverlauncher/{}", name))
    }

    pub fn tracer(&self) -> &Tracer {
        &self.tracer
    }

    pub fn emit(&self, level: Level, message: &str) {
        let prefix = self.prefix.as_deref().unwrap_or("");
        for sink in &self.sink {
            sink.write_log(level, prefix, message);
        }
    }

    pub fn result(&self, message: impl AsRef<str>) {
        for sink in &self.sink {
            sink.write_result(message.as_ref());
        }
    }

    pub fn prefixed(&self, prefix: &str) -> Logger {
        Logger {
            prefix: Some(match &self.prefix {
                Some(p) => format!("{}{} ", p, prefix),
                None => format!("{} ", prefix),
            }),
            ..self.clone()
        }
    }

    /// Creates a new logger with the additional log sink added.
    pub fn tee<T>(&self, sink: T) -> Logger
    where
        T: LogSink + 'static,
    {
        let mut new_sinks = self.sink.clone();
        new_sinks.push(Box::new(sink));

        Logger {
            sink: new_sinks,
            ..self.clone()
        }
    }

    /// Creates a new logger with the sink replaced by the given sink.
    pub fn with_sink<T>(&self, sink: T) -> Logger
    where
        T: LogSink + 'static,
    {
        Logger {
            sink: vec![Box::new(sink)],
            ..self.clone()
        }
    }

    pub fn get_download_logger<'a>(&'a self, prefix: &'static str) -> DownloadLogger<'a> {
        DownloadLogger {
            prefix,
            logger: self,
        }
    }
}
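
// Illustrative composition of the API above (hypothetical tracer and path):
//
//   let log = Logger::new(tracer, Level::Info)
//       .tee(FileLogSink::new(Level::Debug, &log_path)?)
//       .prefixed("[tunnel.1]");
//
// Messages then reach stdout at Info and above, and the file at Debug and
// above, each line carrying the "[tunnel.1]" prefix.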

pub struct DownloadLogger<'a> {
    prefix: &'static str,
    logger: &'a Logger,
}

impl<'a> crate::util::io::ReportCopyProgress for DownloadLogger<'a> {
    fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64) {
        if total_bytes > 0 {
            self.logger.emit(
                Level::Trace,
                &format!(
                    "{} {}/{} ({:.0}%)",
                    self.prefix,
                    bytes_so_far,
                    total_bytes,
                    (bytes_so_far as f64 / total_bytes as f64) * 100.0,
                ),
            );
        } else {
            self.logger.emit(
                Level::Trace,
                &format!("{} {}/{}", self.prefix, bytes_so_far, total_bytes),
            );
        }
    }
}

fn format(level: Level, prefix: &str, message: &str, use_colors: bool) -> String {
    let current = Local::now();
    let timestamp = current.format("%Y-%m-%d %H:%M:%S").to_string();

    let name = level.name().unwrap();

    if use_colors {
        if let Some(c) = level.color_code() {
            return format!(
                "\x1b[2m[{}]\x1b[0m {}{}\x1b[0m {}{}\n",
                timestamp, c, name, prefix, message
            );
        }
    }

    format!("[{}] {} {}{}\n", timestamp, name, prefix, message)
}
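
// Example of a formatted line, without colors (timestamp is illustrative):
//
//   [2023-06-05 12:34:56] info [tunnel.1] forwarding port 8080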

pub fn emit(level: Level, prefix: &str, message: &str) {
    let line = format(level, prefix, message, true);
    if level == Level::Trace {
        print!("\x1b[2m{}\x1b[0m", line);
    } else {
        print!("{}", line);
    }
}

/// Installs the logger instance as the global logger for the 'log' service.
/// Replaces any existing registered logger. Note that the logger will be leaked.
pub fn install_global_logger(log: Logger) {
    log::set_logger(Box::leak(Box::new(RustyLogger(log))))
        .map(|()| log::set_max_level(log::LevelFilter::Debug))
        .expect("expected to make logger");
}

/// Logger that uses the common rust "log" crate and directs back to one of
/// our managed loggers.
struct RustyLogger(Logger);

impl log::Log for RustyLogger {
    fn enabled(&self, metadata: &log::Metadata) -> bool {
        metadata.level() <= log::Level::Debug
    }

    fn log(&self, record: &log::Record) {
        if !self.enabled(record.metadata()) {
            return;
        }

        // exclude noisy log modules:
        let src = match record.module_path() {
            Some("russh::cipher") => return,
            Some("russh::negotiation") => return,
            Some(s) => s,
            None => "<unknown>",
        };

        self.0.emit(
            match record.level() {
                log::Level::Debug => Level::Debug,
                log::Level::Error => Level::Error,
                log::Level::Info => Level::Info,
                log::Level::Trace => Level::Trace,
                log::Level::Warn => Level::Warn,
            },
            &format!("[{}] {}", src, record.args()),
        );
    }

    fn flush(&self) {}
}

#[macro_export]
macro_rules! error {
    ($logger:expr, $str:expr) => {
        $logger.emit(log::Level::Error, $str)
    };
    ($logger:expr, $($fmt:expr),+) => {
        $logger.emit(log::Level::Error, &format!($($fmt),+))
    };
}

#[macro_export]
macro_rules! trace {
    ($logger:expr, $str:expr) => {
        $logger.emit(log::Level::Trace, $str)
    };
    ($logger:expr, $($fmt:expr),+) => {
        $logger.emit(log::Level::Trace, &format!($($fmt),+))
    };
}

#[macro_export]
macro_rules! debug {
    ($logger:expr, $str:expr) => {
        $logger.emit(log::Level::Debug, $str)
    };
    ($logger:expr, $($fmt:expr),+) => {
        $logger.emit(log::Level::Debug, &format!($($fmt),+))
    };
}

#[macro_export]
macro_rules! info {
    ($logger:expr, $str:expr) => {
        $logger.emit(log::Level::Info, $str)
    };
    ($logger:expr, $($fmt:expr),+) => {
        $logger.emit(log::Level::Info, &format!($($fmt),+))
    };
}

#[macro_export]
macro_rules! warning {
    ($logger:expr, $str:expr) => {
        $logger.emit(log::Level::Warn, $str)
    };
    ($logger:expr, $($fmt:expr),+) => {
        $logger.emit(log::Level::Warn, &format!($($fmt),+))
    };
}
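
// The macros above accept either a plain &str or format! arguments, e.g.
// (illustrative): info!(log, "listening on {}", addr); warning!(log, "retrying");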

#[macro_export]
macro_rules! span {
    ($logger:expr, $span:expr, $func:expr) => {{
        use opentelemetry::trace::TraceContextExt;

        let span = $span.start($logger.tracer());
        let cx = opentelemetry::Context::current_with_span(span);
        let guard = cx.clone().attach();
        let t = $func;

        if let Err(e) = &t {
            cx.span().record_error(e);
        }

        std::mem::drop(guard);

        t
    }};
}

#[macro_export]
macro_rules! spanf {
    ($logger:expr, $span:expr, $func:expr) => {{
        use opentelemetry::trace::{FutureExt, TraceContextExt};

        let span = $span.start($logger.tracer());
        let cx = opentelemetry::Context::current_with_span(span);
        let t = $func.with_context(cx.clone()).await;

        if let Err(e) = &t {
            cx.span().record_error(e);
        }

        cx.span().end();

        t
    }};
}

cli/src/msgpack_rpc.rs (new file)
@@ -0,0 +1,195 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use bytes::Buf;
use serde::de::DeserializeOwned;
use tokio::{
    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
    pin,
    sync::mpsc,
};
use tokio_util::codec::Decoder;

use crate::{
    rpc::{self, MaybeSync, Serialization},
    util::{
        errors::{AnyError, InvalidRpcDataError},
        sync::{Barrier, Receivable},
    },
};
use std::io::{self, Cursor, ErrorKind};

#[derive(Copy, Clone)]
pub struct MsgPackSerializer {}

impl Serialization for MsgPackSerializer {
    fn serialize(&self, value: impl serde::Serialize) -> Vec<u8> {
        rmp_serde::to_vec_named(&value).expect("expected to serialize")
    }

    fn deserialize<P: serde::de::DeserializeOwned>(&self, b: &[u8]) -> Result<P, AnyError> {
        rmp_serde::from_slice(b).map_err(|e| InvalidRpcDataError(e.to_string()).into())
    }
}

pub type MsgPackCaller = rpc::RpcCaller<MsgPackSerializer>;

/// Creates a new RPC Builder that serializes to msgpack.
pub fn new_msgpack_rpc() -> rpc::RpcBuilder<MsgPackSerializer> {
    rpc::RpcBuilder::new(MsgPackSerializer {})
}

/// Starts processing msgpack rpc over the given i/o. It's recommended that
/// the reader be passed in as a BufReader for efficiency.
pub async fn start_msgpack_rpc<
    C: Send + Sync + 'static,
    X: Clone,
    S: Send + Sync + Serialization,
    Read: AsyncRead + Unpin,
    Write: AsyncWrite + Unpin,
>(
    dispatcher: rpc::RpcDispatcher<S, C>,
    mut read: Read,
    mut write: Write,
    mut msg_rx: impl Receivable<Vec<u8>>,
    mut shutdown_rx: Barrier<X>,
) -> io::Result<(Option<X>, Read, Write)> {
    let (write_tx, mut write_rx) = mpsc::channel::<Vec<u8>>(8);
    let mut decoder = MsgPackCodec::new();
    let mut decoder_buf = bytes::BytesMut::new();

    let shutdown_fut = shutdown_rx.wait();
    pin!(shutdown_fut);

    loop {
        tokio::select! {
            r = read.read_buf(&mut decoder_buf) => {
                r?;

                while let Some(frame) = decoder.decode(&mut decoder_buf)? {
                    match dispatcher.dispatch_with_partial(&frame.vec, frame.obj) {
                        MaybeSync::Sync(Some(v)) => {
                            let _ = write_tx.send(v).await;
                        },
                        MaybeSync::Sync(None) => continue,
                        MaybeSync::Future(fut) => {
                            let write_tx = write_tx.clone();
                            tokio::spawn(async move {
                                if let Some(v) = fut.await {
                                    let _ = write_tx.send(v).await;
                                }
                            });
                        }
                        MaybeSync::Stream((stream, fut)) => {
                            if let Some(stream) = stream {
                                dispatcher.register_stream(write_tx.clone(), stream).await;
                            }
                            let write_tx = write_tx.clone();
                            tokio::spawn(async move {
                                if let Some(v) = fut.await {
                                    let _ = write_tx.send(v).await;
                                }
                            });
                        }
                    }
                };
            },
            Some(m) = write_rx.recv() => {
                write.write_all(&m).await?;
            },
            Some(m) = msg_rx.recv_msg() => {
                write.write_all(&m).await?;
            },
            r = &mut shutdown_fut => return Ok((r.ok(), read, write)),
        }

        write.flush().await?;
    }
}

/// Reader that reads msgpack object messages in a cancellation-safe way using Tokio's codecs.
///
/// rmp_serde does not support async reads, and does not plan to. But we know every
/// type in the protocol is some kind of object, so by asking to deserialize the
/// requested object from a reader (repeatedly, if incomplete) we can
/// accomplish streaming.
pub struct MsgPackCodec<T> {
    _marker: std::marker::PhantomData<T>,
}

impl<T> MsgPackCodec<T> {
    pub fn new() -> Self {
        Self {
            _marker: std::marker::PhantomData::default(),
        }
    }
}

pub struct MsgPackDecoded<T> {
    pub obj: T,
    pub vec: Vec<u8>,
}

impl<T: DeserializeOwned> tokio_util::codec::Decoder for MsgPackCodec<T> {
    type Item = MsgPackDecoded<T>;
    type Error = io::Error;

    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let bytes_ref = src.as_ref();
        let mut cursor = Cursor::new(bytes_ref);

        match rmp_serde::decode::from_read::<_, T>(&mut cursor) {
            Err(
                rmp_serde::decode::Error::InvalidDataRead(e)
                | rmp_serde::decode::Error::InvalidMarkerRead(e),
            ) if e.kind() == ErrorKind::UnexpectedEof => {
                src.reserve(1024);
                Ok(None)
            }
            Err(e) => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                e.to_string(),
            )),
            Ok(obj) => {
                let len = cursor.position() as usize;
                let vec = src[..len].to_vec();
                src.advance(len);
                Ok(Some(MsgPackDecoded { obj, vec }))
            }
        }
    }
}
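
// Decode sketch: msgpack frames are self-delimiting, so the codec simply
// attempts a full deserialize from the buffered bytes. An UnexpectedEof from
// rmp_serde means "not enough data yet" (Ok(None), reserve more space);
// otherwise cursor.position() tells us how many bytes the frame consumed, and
// the raw bytes are kept alongside the object for dispatch_with_partial.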

#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;

    #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
    pub struct Msg {
        pub x: i32,
    }

    #[test]
    fn test_protocol() {
        let mut c = MsgPackCodec::<Msg>::new();
        let mut buf = bytes::BytesMut::new();

        assert!(c.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(rmp_serde::to_vec_named(&Msg { x: 1 }).unwrap().as_slice());
        buf.extend_from_slice(rmp_serde::to_vec_named(&Msg { x: 2 }).unwrap().as_slice());

        assert_eq!(
            c.decode(&mut buf).unwrap().expect("expected msg1").obj,
            Msg { x: 1 }
        );
        assert_eq!(
            c.decode(&mut buf).unwrap().expect("expected msg2").obj,
            Msg { x: 2 }
        );
    }
}

cli/src/options.rs (new file)
@@ -0,0 +1,115 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::fmt;

use serde::{Deserialize, Serialize};

use crate::constants::{APPLICATION_NAME_MAP, PRODUCT_NAME_LONG_MAP, SERVER_NAME_MAP};

#[derive(clap::ArgEnum, Copy, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum Quality {
    #[serde(rename = "stable")]
    Stable,
    #[serde(rename = "exploration")]
    Exploration,
    #[serde(other)]
    Insiders,
}

impl Quality {
    /// Lowercased quality name, as used in paths and the protocol
    pub fn get_machine_name(&self) -> &'static str {
        match self {
            Quality::Insiders => "insiders",
            Quality::Exploration => "exploration",
            Quality::Stable => "stable",
        }
    }

    /// Capitalized quality display name for humans
    pub fn get_capitalized_name(&self) -> &'static str {
        match self {
            Quality::Insiders => "Insiders",
            Quality::Exploration => "Exploration",
            Quality::Stable => "Stable",
        }
    }

    /// Product long name
    pub fn get_long_name(&self) -> &'static str {
        PRODUCT_NAME_LONG_MAP
            .as_ref()
            .and_then(|m| m.get(self))
            .map(|s| s.as_str())
            .unwrap_or("Code - OSS")
    }

    /// Product application name
    pub fn get_application_name(&self) -> &'static str {
        APPLICATION_NAME_MAP
            .as_ref()
            .and_then(|m| m.get(self))
            .map(|s| s.as_str())
            .unwrap_or("code")
    }

    /// Server entrypoint name
    pub fn server_entrypoint(&self) -> String {
        let mut server_name = SERVER_NAME_MAP
            .as_ref()
            .and_then(|m| m.get(self))
            .map(|s| s.as_str())
            .unwrap_or("code-server-oss")
            .to_string();

        if cfg!(windows) {
            server_name.push_str(".cmd");
        }

        server_name
    }
}
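
// For example, with no baked-in name maps these fall back to the OSS
// defaults: get_long_name() == "Code - OSS", get_application_name() == "code",
// and server_entrypoint() == "code-server-oss" ("code-server-oss.cmd" on Windows).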

impl fmt::Display for Quality {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", self.get_capitalized_name())
    }
}

impl TryFrom<&str> for Quality {
    type Error = String;

    fn try_from(s: &str) -> Result<Self, Self::Error> {
        match s {
            "stable" => Ok(Quality::Stable),
            "insiders" | "insider" => Ok(Quality::Insiders),
            "exploration" => Ok(Quality::Exploration),
            _ => Err(format!(
                "Unknown quality: {}. Must be one of stable, insiders, or exploration.",
                s
            )),
        }
    }
}

#[derive(clap::ArgEnum, Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum TelemetryLevel {
    Off,
    Crash,
    Error,
    All,
}

impl fmt::Display for TelemetryLevel {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            TelemetryLevel::Off => write!(f, "off"),
            TelemetryLevel::Crash => write!(f, "crash"),
            TelemetryLevel::Error => write!(f, "error"),
            TelemetryLevel::All => write!(f, "all"),
        }
    }
}

cli/src/rpc.rs (new file)
@@ -0,0 +1,693 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
    collections::HashMap,
    future,
    sync::{
        atomic::{AtomicU32, Ordering},
        Arc, Mutex,
    },
};

use crate::log;
use futures::{future::BoxFuture, Future, FutureExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt, DuplexStream, WriteHalf},
    sync::{mpsc, oneshot},
};

use crate::util::errors::AnyError;

pub type SyncMethod = Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> Option<Vec<u8>>>;
pub type AsyncMethod =
    Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> BoxFuture<'static, Option<Vec<u8>>>>;
pub type Duplex = Arc<
    dyn Send
        + Sync
        + Fn(Option<u32>, &[u8]) -> (Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>),
>;

pub enum Method {
    Sync(SyncMethod),
    Async(AsyncMethod),
    Duplex(Duplex),
}

/// Serialization is given to the RpcBuilder and defines how data gets serialized
/// when calling methods.
pub trait Serialization: Send + Sync + 'static {
    fn serialize(&self, value: impl Serialize) -> Vec<u8>;
    fn deserialize<P: DeserializeOwned>(&self, b: &[u8]) -> Result<P, AnyError>;
}

/// RPC is a basic, transport-agnostic builder for RPC methods. You can
/// register methods to it, then call `.build()` to get a "dispatcher" type.
pub struct RpcBuilder<S> {
    serializer: Arc<S>,
    methods: HashMap<&'static str, Method>,
    calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
}

impl<S: Serialization> RpcBuilder<S> {
    /// Creates a new empty RPC builder.
    pub fn new(serializer: S) -> Self {
        Self {
            serializer: Arc::new(serializer),
            methods: HashMap::new(),
            calls: Arc::new(std::sync::Mutex::new(HashMap::new())),
        }
    }

    /// Creates a caller that will be connected to any eventual dispatchers,
    /// and that sends data to the "tx" channel.
    pub fn get_caller(&mut self, sender: mpsc::UnboundedSender<Vec<u8>>) -> RpcCaller<S> {
        RpcCaller {
            serializer: self.serializer.clone(),
            calls: self.calls.clone(),
            sender,
        }
    }

    /// Gets a method builder.
    pub fn methods<C: Send + Sync + 'static>(self, context: C) -> RpcMethodBuilder<S, C> {
        RpcMethodBuilder {
            context: Arc::new(context),
            serializer: self.serializer,
            methods: self.methods,
            calls: self.calls,
        }
    }
}

pub struct RpcMethodBuilder<S, C> {
    context: Arc<C>,
    serializer: Arc<S>,
    methods: HashMap<&'static str, Method>,
    calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
}

#[derive(Serialize)]
struct DuplexStreamStarted {
    pub for_request_id: u32,
    pub stream_ids: Vec<u32>,
}

impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
    /// Registers a synchronous rpc call that returns its result directly.
    pub fn register_sync<P, R, F>(&mut self, method_name: &'static str, callback: F)
    where
        P: DeserializeOwned,
        R: Serialize,
        F: Fn(P, &C) -> Result<R, AnyError> + Send + Sync + 'static,
    {
        if self.methods.contains_key(method_name) {
            panic!("Method already registered: {}", method_name);
        }

        let serial = self.serializer.clone();
        let context = self.context.clone();
        self.methods.insert(
            method_name,
            Method::Sync(Arc::new(move |id, body| {
                let param = match serial.deserialize::<RequestParams<P>>(body) {
                    Ok(p) => p,
                    Err(err) => {
                        return id.map(|id| {
                            serial.serialize(&ErrorResponse {
                                id,
                                error: ResponseError {
                                    code: 0,
                                    message: format!("{:?}", err),
                                },
                            })
                        })
                    }
                };

                match callback(param.params, &context) {
                    Ok(result) => id.map(|id| serial.serialize(&SuccessResponse { id, result })),
                    Err(err) => id.map(|id| {
                        serial.serialize(&ErrorResponse {
                            id,
                            error: ResponseError {
                                code: -1,
                                message: format!("{:?}", err),
                            },
                        })
                    }),
                }
            })),
        );
    }
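
    // Illustrative registration (hypothetical method name, params type, and
    // context; `Empty` stands in for any DeserializeOwned params struct):
    //
    //   let mut methods = new_msgpack_rpc().methods(MyContext::default());
    //   methods.register_sync("ping", |_: Empty, _ctx: &MyContext| Ok("pong"));
    //   let dispatcher = methods.build(log);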

    /// Registers an async rpc call that returns a Future.
    pub fn register_async<P, R, Fut, F>(&mut self, method_name: &'static str, callback: F)
    where
        P: DeserializeOwned + Send + 'static,
        R: Serialize + Send + Sync + 'static,
        Fut: Future<Output = Result<R, AnyError>> + Send,
        F: (Fn(P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
    {
        let serial = self.serializer.clone();
        let context = self.context.clone();
        self.methods.insert(
            method_name,
            Method::Async(Arc::new(move |id, body| {
                let param = match serial.deserialize::<RequestParams<P>>(body) {
                    Ok(p) => p,
                    Err(err) => {
                        return future::ready(id.map(|id| {
                            serial.serialize(&ErrorResponse {
                                id,
                                error: ResponseError {
                                    code: 0,
                                    message: format!("{:?}", err),
                                },
                            })
                        }))
                        .boxed();
                    }
                };

                let callback = callback.clone();
                let serial = serial.clone();
                let context = context.clone();
                let fut = async move {
                    match callback(param.params, context).await {
                        Ok(result) => {
                            id.map(|id| serial.serialize(&SuccessResponse { id, result }))
                        }
                        Err(err) => id.map(|id| {
                            serial.serialize(&ErrorResponse {
                                id,
                                error: ResponseError {
                                    code: -1,
                                    message: format!("{:?}", err),
                                },
                            })
                        }),
                    }
                };

                fut.boxed()
            })),
        );
    }

    /// Registers an async rpc call that returns a Future containing a duplex
    /// stream that should be handled by the client.
    pub fn register_duplex<P, R, Fut, F>(
        &mut self,
        method_name: &'static str,
        streams: usize,
        callback: F,
    ) where
        P: DeserializeOwned + Send + 'static,
        R: Serialize + Send + Sync + 'static,
        Fut: Future<Output = Result<R, AnyError>> + Send,
        F: (Fn(Vec<DuplexStream>, P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
    {
        let serial = self.serializer.clone();
        let context = self.context.clone();
        self.methods.insert(
            method_name,
            Method::Duplex(Arc::new(move |id, body| {
                let param = match serial.deserialize::<RequestParams<P>>(body) {
                    Ok(p) => p,
                    Err(err) => {
                        return (
                            None,
                            future::ready(id.map(|id| {
                                serial.serialize(&ErrorResponse {
                                    id,
                                    error: ResponseError {
                                        code: 0,
                                        message: format!("{:?}", err),
                                    },
                                })
                            }))
                            .boxed(),
                        );
                    }
                };

                let callback = callback.clone();
                let serial = serial.clone();
                let context = context.clone();

                let mut dto = StreamDto {
                    req_id: id.unwrap_or(0),
                    streams: Vec::with_capacity(streams),
                };
                let mut servers = Vec::with_capacity(streams);

                for _ in 0..streams {
                    let (client, server) = tokio::io::duplex(8192);
                    servers.push(server);
                    dto.streams.push((next_message_id(), client));
                }

                let fut = async move {
                    match callback(servers, param.params, context).await {
                        Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })),
                        Err(err) => id.map(|id| {
                            serial.serialize(&ErrorResponse {
                                id,
                                error: ResponseError {
                                    code: -1,
                                    message: format!("{:?}", err),
                                },
                            })
                        }),
                    }
                };

                (Some(dto), fut.boxed())
            })),
        );
    }

    /// Builds into a usable, sync rpc dispatcher.
    pub fn build(mut self, log: log::Logger) -> RpcDispatcher<S, C> {
        let streams = Streams::default();

        let s1 = streams.clone();
        self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| {
            let s1 = s1.clone();
            async move {
                s1.remove(m.stream).await;
                Ok(())
            }
        });

        let s2 = streams.clone();
        self.register_sync(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
            s2.write(m.stream, m.segment);
            Ok(())
        });

        RpcDispatcher {
            log,
            context: self.context,
            calls: self.calls,
            serializer: self.serializer,
            methods: Arc::new(self.methods),
            streams,
        }
    }
}

type DispatchMethod = Box<dyn Send + Sync + FnOnce(Outcome)>;

/// Caller returned from a Builder that provides a transport-agnostic way to
/// serialize and enqueue outgoing RPC calls. This structure may get more
/// advanced as time goes on...
#[derive(Clone)]
pub struct RpcCaller<S: Serialization> {
    serializer: Arc<S>,
    calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
    sender: mpsc::UnboundedSender<Vec<u8>>,
}

impl<S: Serialization> RpcCaller<S> {
    pub fn serialize_notify<M, A>(serializer: &S, method: M, params: A) -> Vec<u8>
    where
        S: Serialization,
        M: AsRef<str> + serde::Serialize,
        A: Serialize,
    {
        serializer.serialize(&FullRequest {
            id: None,
            method,
            params,
        })
    }

    /// Enqueues an outbound call. Returns whether the message was enqueued.
    pub fn notify<M, A>(&self, method: M, params: A) -> bool
    where
        M: AsRef<str> + serde::Serialize,
        A: Serialize,
    {
        self.sender
            .send(Self::serialize_notify(&self.serializer, method, params))
            .is_ok()
    }

    /// Enqueues an outbound call, returning its result.
    pub fn call<M, A, R>(&self, method: M, params: A) -> oneshot::Receiver<Result<R, ResponseError>>
|
||||
where
|
||||
M: AsRef<str> + serde::Serialize,
|
||||
A: Serialize,
|
||||
R: DeserializeOwned + Send + 'static,
|
||||
{
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let id = next_message_id();
|
||||
let body = self.serializer.serialize(&FullRequest {
|
||||
id: Some(id),
|
||||
method,
|
||||
params,
|
||||
});
|
||||
|
||||
if self.sender.send(body).is_err() {
|
||||
drop(tx);
|
||||
return rx;
|
||||
}
|
||||
|
||||
let serializer = self.serializer.clone();
|
||||
self.calls.lock().unwrap().insert(
|
||||
id,
|
||||
Box::new(move |body| {
|
||||
match body {
|
||||
Outcome::Error(e) => tx.send(Err(e)).ok(),
|
||||
Outcome::Success(r) => match serializer.deserialize::<SuccessResponse<R>>(&r) {
|
||||
Ok(r) => tx.send(Ok(r.result)).ok(),
|
||||
Err(err) => tx
|
||||
.send(Err(ResponseError {
|
||||
code: 0,
|
||||
message: err.to_string(),
|
||||
}))
|
||||
.ok(),
|
||||
},
|
||||
};
|
||||
}),
|
||||
);
|
||||
|
||||
rx
|
||||
}
|
||||
}
|
||||
|
||||
/// Dispatcher returned from a Builder that provides a transport-agnostic way to
|
||||
/// deserialize and handle RPC calls. This structure may get more advanced as
|
||||
/// time goes on...
|
||||
#[derive(Clone)]
|
||||
pub struct RpcDispatcher<S, C> {
|
||||
log: log::Logger,
|
||||
context: Arc<C>,
|
||||
serializer: Arc<S>,
|
||||
methods: Arc<HashMap<&'static str, Method>>,
|
||||
calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
|
||||
streams: Streams,
|
||||
}
|
||||
|
||||
static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
|
||||
fn next_message_id() -> u32 {
|
||||
MESSAGE_ID_COUNTER.fetch_add(1, Ordering::SeqCst)
|
||||
}
|
||||
|
||||
impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
|
||||
/// Runs the incoming request, returning the result of the call synchronously
|
||||
/// or in a future. (The caller can then decide whether to run the future
|
||||
/// sequentially in its receive loop, or not.)
|
||||
///
|
||||
/// The future or return result will be optional bytes that should be sent
|
||||
/// back to the socket.
|
||||
pub fn dispatch(&self, body: &[u8]) -> MaybeSync {
|
||||
match self.serializer.deserialize::<PartialIncoming>(body) {
|
||||
Ok(partial) => self.dispatch_with_partial(body, partial),
|
||||
Err(_err) => {
|
||||
warning!(self.log, "Failed to deserialize request, hex: {:X?}", body);
|
||||
MaybeSync::Sync(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Like dispatch, but allows passing an existing PartialIncoming.
|
||||
pub fn dispatch_with_partial(&self, body: &[u8], partial: PartialIncoming) -> MaybeSync {
|
||||
let id = partial.id;
|
||||
|
||||
if let Some(method_name) = partial.method {
|
||||
let method = self.methods.get(method_name.as_str());
|
||||
match method {
|
||||
Some(Method::Sync(callback)) => MaybeSync::Sync(callback(id, body)),
|
||||
Some(Method::Async(callback)) => MaybeSync::Future(callback(id, body)),
|
||||
Some(Method::Duplex(callback)) => MaybeSync::Stream(callback(id, body)),
|
||||
None => MaybeSync::Sync(id.map(|id| {
|
||||
self.serializer.serialize(&ErrorResponse {
|
||||
id,
|
||||
error: ResponseError {
|
||||
code: -1,
|
||||
message: format!("Method not found: {}", method_name),
|
||||
},
|
||||
})
|
||||
})),
|
||||
}
|
||||
} else if let Some(err) = partial.error {
|
||||
if let Some(cb) = self.calls.lock().unwrap().remove(&id.unwrap()) {
|
||||
cb(Outcome::Error(err));
|
||||
}
|
||||
MaybeSync::Sync(None)
|
||||
} else {
|
||||
if let Some(cb) = self.calls.lock().unwrap().remove(&id.unwrap()) {
|
||||
cb(Outcome::Success(body.to_vec()));
|
||||
}
|
||||
MaybeSync::Sync(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Registers a stream call returned from dispatch().
|
||||
pub async fn register_stream(
|
||||
&self,
|
||||
write_tx: mpsc::Sender<impl 'static + From<Vec<u8>> + Send>,
|
||||
dto: StreamDto,
|
||||
) {
|
||||
let r = write_tx
|
||||
.send(
|
||||
self.serializer
|
||||
.serialize(&FullRequest {
|
||||
id: None,
|
||||
method: METHOD_STREAMS_STARTED,
|
||||
params: DuplexStreamStarted {
|
||||
stream_ids: dto.streams.iter().map(|(id, _)| *id).collect(),
|
||||
for_request_id: dto.req_id,
|
||||
},
|
||||
})
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if r.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
for (stream_id, duplex) in dto.streams {
|
||||
let (mut read, write) = tokio::io::split(duplex);
|
||||
self.streams.insert(stream_id, write);
|
||||
|
||||
let write_tx = write_tx.clone();
|
||||
let serial = self.serializer.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0; 4096];
|
||||
loop {
|
||||
match read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => {
|
||||
let r = write_tx
|
||||
.send(
|
||||
serial
|
||||
.serialize(&FullRequest {
|
||||
id: None,
|
||||
method: METHOD_STREAM_DATA,
|
||||
params: StreamDataParams {
|
||||
segment: &buf[..n],
|
||||
stream: stream_id,
|
||||
},
|
||||
})
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if r.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = write_tx
|
||||
.send(
|
||||
serial
|
||||
.serialize(&FullRequest {
|
||||
id: None,
|
||||
method: METHOD_STREAM_ENDED,
|
||||
params: StreamEndedParams { stream: stream_id },
|
||||
})
|
||||
.into(),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn context(&self) -> Arc<C> {
|
||||
self.context.clone()
|
||||
}
|
||||
}
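
// Editor's sketch (not part of the original file): one way a transport's
// receive loop might drive `dispatch()`. The channel wiring and the function
// itself are assumptions for illustration only.
async fn example_receive_loop<S: Serialization, C: Send + Sync>(
    dispatcher: RpcDispatcher<S, C>,
    mut incoming: mpsc::UnboundedReceiver<Vec<u8>>,
    outgoing: mpsc::UnboundedSender<Vec<u8>>,
) {
    while let Some(frame) = incoming.recv().await {
        match dispatcher.dispatch(&frame) {
            // Sync methods already produced their (optional) reply bytes.
            MaybeSync::Sync(Some(reply)) => {
                let _ = outgoing.send(reply);
            }
            MaybeSync::Sync(None) => {}
            // Async methods hand back a boxed future; spawning it keeps slow
            // handlers from blocking the receive loop.
            MaybeSync::Future(fut) => {
                let outgoing = outgoing.clone();
                tokio::spawn(async move {
                    if let Some(reply) = fut.await {
                        let _ = outgoing.send(reply);
                    }
                });
            }
            // Duplex methods also return a StreamDto that must be handed to
            // `register_stream` (above) so stream data gets forwarded.
            MaybeSync::Stream((dto, fut)) => {
                let _ = (dto, fut);
            }
        }
    }
}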

struct StreamRec {
    write: Option<WriteHalf<DuplexStream>>,
    q: Vec<Vec<u8>>,
}

#[derive(Clone, Default)]
struct Streams {
    map: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
}

impl Streams {
    pub async fn remove(&self, id: u32) {
        let stream = self.map.lock().unwrap().remove(&id);
        if let Some(s) = stream {
            // if there's no 'write' right now, it'll shut down in the write_loop
            if let Some(mut w) = s.write {
                let _ = w.shutdown().await;
            }
        }
    }

    pub fn write(&self, id: u32, buf: Vec<u8>) {
        let mut map = self.map.lock().unwrap();
        if let Some(s) = map.get_mut(&id) {
            s.q.push(buf);

            if let Some(w) = s.write.take() {
                tokio::spawn(write_loop(id, w, self.map.clone()));
            }
        }
    }

    pub fn insert(&self, id: u32, stream: WriteHalf<DuplexStream>) {
        self.map.lock().unwrap().insert(
            id,
            StreamRec {
                write: Some(stream),
                q: Vec::new(),
            },
        );
    }
}

/// Write loop started by `Streams.write`. It takes the WriteHalf, and
/// runs until there's no more items in the 'write queue'. At that point, if the
/// record still exists in the `streams` (i.e. we haven't shut down), it'll
/// return the WriteHalf so that the next `write` call starts
/// the loop again. Otherwise, it'll shut down the WriteHalf.
///
/// This is the equivalent of the same write_loop in the server_multiplexer.
/// I couldn't figure out a nice way to abstract it without introducing
/// performance overhead...
async fn write_loop(
    id: u32,
    mut w: WriteHalf<DuplexStream>,
    streams: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
) {
    let mut items_vec = vec![];
    loop {
        {
            let mut lock = streams.lock().unwrap();
            let stream_rec = match lock.get_mut(&id) {
                Some(b) => b,
                None => break,
            };

            if stream_rec.q.is_empty() {
                stream_rec.write = Some(w);
                return;
            }

            std::mem::swap(&mut stream_rec.q, &mut items_vec);
        }

        for item in items_vec.drain(..) {
            if w.write_all(&item).await.is_err() {
                break;
            }
        }
    }

    // got here from `break` above, meaning our record got cleared. Close the bridge if so
    let _ = w.shutdown().await;
}

const METHOD_STREAMS_STARTED: &str = "streams_started";
const METHOD_STREAM_DATA: &str = "stream_data";
const METHOD_STREAM_ENDED: &str = "stream_ended";

trait AssertIsSync: Sync {}
impl<S: Serialization, C: Send + Sync> AssertIsSync for RpcDispatcher<S, C> {}

/// Approximate shape that is used to determine what kind of data is incoming.
#[derive(Deserialize, Debug)]
pub struct PartialIncoming {
    pub id: Option<u32>,
    pub method: Option<String>,
    pub error: Option<ResponseError>,
}

#[derive(Deserialize)]
struct StreamDataIncomingParams {
    #[serde(with = "serde_bytes")]
    pub segment: Vec<u8>,
    pub stream: u32,
}

#[derive(Serialize, Deserialize)]
struct StreamDataParams<'a> {
    #[serde(with = "serde_bytes")]
    pub segment: &'a [u8],
    pub stream: u32,
}

#[derive(Serialize, Deserialize)]
struct StreamEndedParams {
    pub stream: u32,
}

#[derive(Serialize)]
pub struct FullRequest<M: AsRef<str>, P> {
    pub id: Option<u32>,
    pub method: M,
    pub params: P,
}

#[derive(Deserialize)]
struct RequestParams<P> {
    pub params: P,
}

#[derive(Serialize, Deserialize)]
struct SuccessResponse<T> {
    pub id: u32,
    pub result: T,
}

#[derive(Serialize, Deserialize)]
struct ErrorResponse {
    pub id: u32,
    pub error: ResponseError,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseError {
    pub code: i32,
    pub message: String,
}

enum Outcome {
    Success(Vec<u8>),
    Error(ResponseError),
}

pub struct StreamDto {
    req_id: u32,
    streams: Vec<(u32, DuplexStream)>,
}

pub enum MaybeSync {
    Stream((Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>)),
    Future(BoxFuture<'static, Option<Vec<u8>>>),
    Sync(Option<Vec<u8>>),
}
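
A minimal sketch of how the builder and dispatcher fit together. The diff does not show how `RpcMethodBuilder` is constructed, so `RpcMethodBuilder::new(serializer, context)` below is an assumed shape, as are `EchoParams` and the method names.

// Editor's sketch (not in the diff): registering handlers and building.
#[derive(serde::Deserialize)]
struct EchoParams {
    text: String,
}

fn wire_up<S: Serialization>(serializer: S, log: log::Logger) -> RpcDispatcher<S, ()> {
    // Assumed constructor shape; not shown in this diff.
    let mut builder = RpcMethodBuilder::new(serializer, Arc::new(()));

    // Sync handler: replies inline on the dispatcher's receive path.
    builder.register_sync("echo", |p: EchoParams, _ctx: &()| Ok::<_, AnyError>(p.text));

    // Async handler: the returned future can be spawned by the transport.
    builder.register_async("delayed_echo", |p: EchoParams, _ctx: Arc<()>| async move {
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        Ok::<_, AnyError>(p.text)
    });

    builder.build(log)
}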

cli/src/self_update.rs (new file)
@@ -0,0 +1,163 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{fs, path::Path, process::Command};
use tempfile::tempdir;

use crate::{
    constants::{VSCODE_CLI_COMMIT, VSCODE_CLI_QUALITY},
    options::Quality,
    update_service::{unzip_downloaded_release, Platform, Release, TargetKind, UpdateService},
    util::{
        errors::{wrap, AnyError, CorruptDownload, UpdatesNotConfigured},
        http,
        io::{ReportCopyProgress, SilentCopyProgress},
    },
};

pub struct SelfUpdate<'a> {
    commit: &'static str,
    quality: Quality,
    platform: Platform,
    update_service: &'a UpdateService,
}

impl<'a> SelfUpdate<'a> {
    pub fn new(update_service: &'a UpdateService) -> Result<Self, AnyError> {
        let commit = VSCODE_CLI_COMMIT
            .ok_or_else(|| UpdatesNotConfigured("unknown build commit".to_string()))?;

        let quality = VSCODE_CLI_QUALITY
            .ok_or_else(|| UpdatesNotConfigured("no configured quality".to_string()))
            .and_then(|q| Quality::try_from(q).map_err(UpdatesNotConfigured))?;

        let platform = Platform::env_default().ok_or_else(|| {
            UpdatesNotConfigured("Unknown platform, please report this error".to_string())
        })?;

        Ok(Self {
            commit,
            quality,
            platform,
            update_service,
        })
    }

    /// Gets the current release
    pub async fn get_current_release(&self) -> Result<Release, AnyError> {
        self.update_service
            .get_latest_commit(self.platform, TargetKind::Cli, self.quality)
            .await
    }

    /// Gets whether the given release is what this CLI is built against
    pub fn is_up_to_date_with(&self, release: &Release) -> bool {
        release.commit == self.commit
    }

    /// Updates the CLI to the given release.
    pub async fn do_update(
        &self,
        release: &Release,
        progress: impl ReportCopyProgress,
    ) -> Result<(), AnyError> {
        // 1. Download the archive into a temporary directory
        let tempdir = tempdir().map_err(|e| wrap(e, "Failed to create temp dir"))?;
        let stream = self.update_service.get_download_stream(release).await?;
        let archive_path = tempdir.path().join(stream.url_path_basename().unwrap());
        http::download_into_file(&archive_path, progress, stream).await?;

        // 2. Unzip the archive and get the binary
        let target_path =
            std::env::current_exe().map_err(|e| wrap(e, "could not get current exe"))?;
        let staging_path = target_path.with_extension(".update");
        let archive_contents_path = tempdir.path().join("content");
        // unzipping the single binary is pretty small and fast--don't bother with passing progress
        unzip_downloaded_release(&archive_path, &archive_contents_path, SilentCopyProgress())?;
        copy_updated_cli_to_path(&archive_contents_path, &staging_path)?;

        // 3. Copy file metadata, and make sure the new binary is executable
        copy_file_metadata(&target_path, &staging_path)
            .map_err(|e| wrap(e, "failed to set file permissions"))?;
        validate_cli_is_good(&staging_path)?;

        // Try to rename the old CLI to the tempdir, where it can get cleaned up by the
        // OS later. However, this can fail if the tempdir is on a different drive
        // than the installation dir. In this case just rename it to ".old".
        if fs::rename(&target_path, tempdir.path().join("old-code-cli")).is_err() {
            fs::rename(&target_path, target_path.with_extension(".old"))
                .map_err(|e| wrap(e, "failed to rename old CLI"))?;
        }

        fs::rename(&staging_path, &target_path)
            .map_err(|e| wrap(e, "failed to rename newly installed CLI"))?;

        Ok(())
    }
}

fn validate_cli_is_good(exe_path: &Path) -> Result<(), AnyError> {
    let o = Command::new(exe_path)
        .args(["--version"])
        .output()
        .map_err(|e| CorruptDownload(format!("could not execute new binary, aborting: {}", e)))?;

    if !o.status.success() {
        let msg = format!(
            "could not execute new binary, aborting. Stdout:\n\n{}\n\nStderr:\n\n{}",
            String::from_utf8_lossy(&o.stdout),
            String::from_utf8_lossy(&o.stderr),
        );

        return Err(CorruptDownload(msg).into());
    }

    Ok(())
}

fn copy_updated_cli_to_path(unzipped_content: &Path, staging_path: &Path) -> Result<(), AnyError> {
    let unzipped_files = fs::read_dir(unzipped_content)
        .map_err(|e| wrap(e, "could not read update contents"))?
        .collect::<Vec<_>>();
    if unzipped_files.len() != 1 {
        let msg = format!(
            "expected exactly one file in update, got {}",
            unzipped_files.len()
        );
        return Err(CorruptDownload(msg).into());
    }

    let archive_file = unzipped_files[0]
        .as_ref()
        .map_err(|e| wrap(e, "error listing update files"))?;
    fs::copy(archive_file.path(), staging_path)
        .map_err(|e| wrap(e, "error copying to staging file"))?;
    Ok(())
}

#[cfg(target_os = "windows")]
fn copy_file_metadata(from: &Path, to: &Path) -> Result<(), std::io::Error> {
    let permissions = from.metadata()?.permissions();
    fs::set_permissions(to, permissions)?;
    Ok(())
}

#[cfg(not(target_os = "windows"))]
fn copy_file_metadata(from: &Path, to: &Path) -> Result<(), std::io::Error> {
    use std::os::unix::ffi::OsStrExt;
    use std::os::unix::fs::MetadataExt;

    let metadata = from.metadata()?;
    fs::set_permissions(to, metadata.permissions())?;

    // based on coreutils' chown https://github.com/uutils/coreutils/blob/72b4629916abe0852ad27286f4e307fbca546b6e/src/chown/chown.rs#L266-L281
    let s = std::ffi::CString::new(to.as_os_str().as_bytes()).unwrap();
    let ret = unsafe { libc::chown(s.as_ptr(), metadata.uid(), metadata.gid()) };
    if ret != 0 {
        return Err(std::io::Error::last_os_error());
    }

    Ok(())
}
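
A minimal sketch of the check-then-apply flow this module supports. Constructing the `UpdateService` and choosing a progress reporter are left to the caller; the function name is an assumption.

// Editor's sketch (not in the diff): typical self-update flow.
async fn update_if_needed(update_service: &UpdateService) -> Result<(), AnyError> {
    let updater = SelfUpdate::new(update_service)?;
    let latest = updater.get_current_release().await?;
    if updater.is_up_to_date_with(&latest) {
        return Ok(()); // already built against the latest commit
    }
    updater.do_update(&latest, SilentCopyProgress()).await
}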

cli/src/singleton.rs (new file)
@@ -0,0 +1,193 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use serde::{Deserialize, Serialize};
use std::{
    fs::{File, OpenOptions},
    io::{Seek, SeekFrom, Write},
    path::{Path, PathBuf},
    time::Duration,
};
use sysinfo::{Pid, PidExt};

use crate::{
    async_pipe::{
        get_socket_name, get_socket_rw_stream, listen_socket_rw_stream, AsyncPipe,
        AsyncPipeListener,
    },
    util::{
        errors::CodeError,
        file_lock::{FileLock, Lock, PREFIX_LOCKED_BYTES},
        machine::wait_until_process_exits,
    },
};

pub struct SingletonServer {
    server: AsyncPipeListener,
    _lock: FileLock,
}

impl SingletonServer {
    pub async fn accept(&mut self) -> Result<AsyncPipe, CodeError> {
        self.server.accept().await
    }
}

pub enum SingletonConnection {
    /// This instance got the singleton lock. It started listening on a socket
    /// and has the read/write pair. If this gets dropped, the lock is released.
    Singleton(SingletonServer),
    /// Another instance is a singleton, and this client connected to it.
    Client(AsyncPipe),
}

/// Contents of the lock file; the listening socket ID and process ID
/// doing the listening.
#[derive(Deserialize, Serialize)]
struct LockFileMatter {
    socket_path: String,
    pid: u32,
}

/// Tries to acquire the singleton homed at the given lock file, either starting
/// a new singleton if it doesn't exist, or connecting otherwise.
pub async fn acquire_singleton(lock_file: PathBuf) -> Result<SingletonConnection, CodeError> {
    let file = OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        .open(&lock_file)
        .map_err(CodeError::SingletonLockfileOpenFailed)?;

    match FileLock::acquire(file) {
        Ok(Lock::AlreadyLocked(mut file)) => connect_as_client_with_file(&mut file)
            .await
            .map(SingletonConnection::Client),
        Ok(Lock::Acquired(lock)) => start_singleton_server(lock)
            .await
            .map(SingletonConnection::Singleton),
        Err(e) => Err(e),
    }
}

/// Tries to connect to the singleton homed at the given file as a client.
pub async fn connect_as_client(lock_file: &Path) -> Result<AsyncPipe, CodeError> {
    let mut file = OpenOptions::new()
        .read(true)
        .open(lock_file)
        .map_err(CodeError::SingletonLockfileOpenFailed)?;

    connect_as_client_with_file(&mut file).await
}

async fn start_singleton_server(mut lock: FileLock) -> Result<SingletonServer, CodeError> {
    let socket_path = get_socket_name();

    let mut vec = Vec::with_capacity(128);
    let _ = vec.write(&[0; PREFIX_LOCKED_BYTES]);
    let _ = rmp_serde::encode::write(
        &mut vec,
        &LockFileMatter {
            socket_path: socket_path.to_string_lossy().to_string(),
            pid: std::process::id(),
        },
    );

    lock.file_mut()
        .write_all(&vec)
        .map_err(CodeError::SingletonLockfileOpenFailed)?;

    let server = listen_socket_rw_stream(&socket_path).await?;
    Ok(SingletonServer {
        server,
        _lock: lock,
    })
}

const MAX_CLIENT_ATTEMPTS: i32 = 10;

async fn connect_as_client_with_file(mut file: &mut File) -> Result<AsyncPipe, CodeError> {
    // retry, since someone else could get a lock and we could read it before
    // the serialized lock info was finished writing out
    let mut attempt = 0;
    loop {
        let _ = file.seek(SeekFrom::Start(PREFIX_LOCKED_BYTES as u64));
        let r = match rmp_serde::from_read::<_, LockFileMatter>(&mut file) {
            Ok(prev) => {
                let socket_path = PathBuf::from(prev.socket_path);

                tokio::select! {
                    p = retry_get_socket_rw_stream(&socket_path, 5, Duration::from_millis(500)) => p,
                    _ = wait_until_process_exits(Pid::from_u32(prev.pid), 500) => return Err(CodeError::SingletonLockedProcessExited(prev.pid)),
                }
            }
            Err(e) => Err(CodeError::SingletonLockfileReadFailed(e)),
        };

        if r.is_ok() || attempt == MAX_CLIENT_ATTEMPTS {
            return r;
        }

        attempt += 1;
        tokio::time::sleep(Duration::from_millis(500)).await;
    }
}

async fn retry_get_socket_rw_stream(
    path: &Path,
    max_tries: usize,
    interval: Duration,
) -> Result<AsyncPipe, CodeError> {
    for i in 0.. {
        match get_socket_rw_stream(path).await {
            Ok(s) => return Ok(s),
            Err(e) if i == max_tries => return Err(e),
            Err(_) => tokio::time::sleep(interval).await,
        }
    }

    unreachable!()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_acquires_singleton() {
        let dir = tempfile::tempdir().expect("expected to make temp dir");
        let s = acquire_singleton(dir.path().join("lock"))
            .await
            .expect("expected to acquire");

        match s {
            SingletonConnection::Singleton(_) => {}
            _ => panic!("expected to be singleton"),
        }
    }

    #[tokio::test]
    async fn test_acquires_client() {
        let dir = tempfile::tempdir().expect("expected to make temp dir");
        let lockfile = dir.path().join("lock");
        let s1 = acquire_singleton(lockfile.clone())
            .await
            .expect("expected to acquire1");
        match s1 {
            SingletonConnection::Singleton(mut l) => tokio::spawn(async move {
                l.accept().await.expect("expected to accept");
            }),
            _ => panic!("expected to be singleton"),
        };

        let s2 = acquire_singleton(lockfile)
            .await
            .expect("expected to acquire2");
        match s2 {
            SingletonConnection::Client(_) => {}
            _ => panic!("expected to be client"),
        }
    }
}
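
A sketch of how a command entry point might branch on the two connection outcomes. The function name and handler bodies are placeholders, not from the diff.

// Editor's sketch (not in the diff): serve if we own the lock, else forward.
async fn run_or_forward(lock_file: PathBuf) -> Result<(), CodeError> {
    match acquire_singleton(lock_file).await? {
        SingletonConnection::Singleton(mut server) => {
            // We own the lock: accept pipes from other CLI instances.
            while let Ok(_pipe) = server.accept().await {
                // hand the pipe to an RPC dispatcher or per-connection task
            }
        }
        SingletonConnection::Client(_pipe) => {
            // Another instance owns the lock: send our request over the pipe.
        }
    }
    Ok(())
}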

cli/src/state.rs (new file)
@@ -0,0 +1,207 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

extern crate dirs;

use std::{
    fs::{create_dir_all, read_to_string, remove_dir_all, write},
    path::{Path, PathBuf},
    sync::{Arc, Mutex},
};

use serde::{de::DeserializeOwned, Serialize};

use crate::{
    constants::{DEFAULT_DATA_PARENT_DIR, VSCODE_CLI_QUALITY},
    download_cache::DownloadCache,
    util::errors::{wrap, AnyError, NoHomeForLauncherError, WrappedError},
};

const HOME_DIR_ALTS: [&str; 2] = ["$HOME", "~"];

#[derive(Clone)]
pub struct LauncherPaths {
    pub server_cache: DownloadCache,
    pub cli_cache: DownloadCache,
    root: PathBuf,
}

struct PersistedStateContainer<T>
where
    T: Clone + Serialize + DeserializeOwned + Default,
{
    path: PathBuf,
    state: Option<T>,
}

impl<T> PersistedStateContainer<T>
where
    T: Clone + Serialize + DeserializeOwned + Default,
{
    fn load_or_get(&mut self) -> T {
        if let Some(state) = &self.state {
            return state.clone();
        }

        let state = if let Ok(s) = read_to_string(&self.path) {
            serde_json::from_str::<T>(&s).unwrap_or_default()
        } else {
            T::default()
        };

        self.state = Some(state.clone());
        state
    }

    fn save(&mut self, state: T) -> Result<(), WrappedError> {
        let s = serde_json::to_string(&state).unwrap();
        self.state = Some(state);
        write(&self.path, s).map_err(|e| {
            wrap(
                e,
                format!("error saving launcher state into {}", self.path.display()),
            )
        })
    }
}

/// Container that holds some state value that is persisted to disk.
#[derive(Clone)]
pub struct PersistedState<T>
where
    T: Clone + Serialize + DeserializeOwned + Default,
{
    container: Arc<Mutex<PersistedStateContainer<T>>>,
}

impl<T> PersistedState<T>
where
    T: Clone + Serialize + DeserializeOwned + Default,
{
    /// Creates a new state container that persists to the given path.
    pub fn new(path: PathBuf) -> PersistedState<T> {
        PersistedState {
            container: Arc::new(Mutex::new(PersistedStateContainer { path, state: None })),
        }
    }

    /// Loads persisted state.
    pub fn load(&self) -> T {
        self.container.lock().unwrap().load_or_get()
    }

    /// Saves persisted state.
    pub fn save(&self, state: T) -> Result<(), WrappedError> {
        self.container.lock().unwrap().save(state)
    }

    /// Mutates persisted state.
    pub fn update<R>(&self, mutator: impl FnOnce(&mut T) -> R) -> Result<R, WrappedError> {
        let mut container = self.container.lock().unwrap();
        let mut state = container.load_or_get();
        let r = mutator(&mut state);
        container.save(state).map(|_| r)
    }
}

impl LauncherPaths {
    /// todo@conno4312: temporary migration from the old CLI data directory
    pub fn migrate(root: Option<String>) -> Result<LauncherPaths, AnyError> {
        if root.is_some() {
            return Self::new(root);
        }

        let home_dir = match dirs::home_dir() {
            None => return Self::new(root),
            Some(d) => d,
        };

        let old_dir = home_dir.join(".vscode-cli");
        let mut new_dir = home_dir;
        new_dir.push(DEFAULT_DATA_PARENT_DIR);
        new_dir.push("cli");
        if !old_dir.exists() || new_dir.exists() {
            return Self::new_for_path(new_dir);
        }

        if let Err(e) = std::fs::rename(&old_dir, &new_dir) {
            // no logger exists at this point in the lifecycle, so just log to stderr
            eprintln!(
                "Failed to migrate old CLI data directory, will create a new one ({})",
                e
            );
        }

        Self::new_for_path(new_dir)
    }

    pub fn new(root: Option<String>) -> Result<LauncherPaths, AnyError> {
        let root = root.unwrap_or_else(|| format!("~/{}/cli", DEFAULT_DATA_PARENT_DIR));
        let mut replaced = root.to_owned();
        for token in HOME_DIR_ALTS {
            if root.contains(token) {
                if let Some(home) = dirs::home_dir() {
                    replaced = root.replace(token, &home.to_string_lossy())
                } else {
                    return Err(AnyError::from(NoHomeForLauncherError()));
                }
            }
        }

        Self::new_for_path(PathBuf::from(replaced))
    }

    fn new_for_path(root: PathBuf) -> Result<LauncherPaths, AnyError> {
        if !root.exists() {
            create_dir_all(&root)
                .map_err(|e| wrap(e, format!("error creating directory {}", root.display())))?;
        }

        Ok(LauncherPaths::new_without_replacements(root))
    }

    pub fn new_without_replacements(root: PathBuf) -> LauncherPaths {
        // cleanup folders that existed before the new LRU strategy:
        let _ = std::fs::remove_dir_all(root.join("server-insiders"));
        let _ = std::fs::remove_dir_all(root.join("server-stable"));

        LauncherPaths {
            server_cache: DownloadCache::new(root.join("servers")),
            cli_cache: DownloadCache::new(root.join("cli")),
            root,
        }
    }

    /// Root directory for the server launcher
    pub fn root(&self) -> &Path {
        &self.root
    }

    /// Lockfile for the running tunnel
    pub fn tunnel_lockfile(&self) -> PathBuf {
        self.root.join(format!(
            "tunnel-{}.lock",
            VSCODE_CLI_QUALITY.unwrap_or("oss")
        ))
    }

    /// Suggested path for tunnel service logs, when using file logs
    pub fn service_log_file(&self) -> PathBuf {
        self.root.join("tunnel-service.log")
    }

    /// Removes the launcher data directory.
    pub fn remove(&self) -> Result<(), WrappedError> {
        remove_dir_all(&self.root).map_err(|e| {
            wrap(
                e,
                format!(
                    "error removing launcher data directory {}",
                    self.root.display()
                ),
            )
        })
    }
}
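
`PersistedState` round-trips any serde-serializable `Default` type as JSON under a single lock. A sketch, where `RecentTunnels` and the file name are made up for illustration:

// Editor's sketch (not in the diff): load-mutate-save in one call.
#[derive(Clone, Default, serde::Serialize, serde::Deserialize)]
struct RecentTunnels {
    names: Vec<String>,
}

fn remember_tunnel(paths: &LauncherPaths, name: &str) -> Result<(), WrappedError> {
    let state: PersistedState<RecentTunnels> =
        PersistedState::new(paths.root().join("recent-tunnels.json"));
    // `update` loads, applies the mutator, and saves atomically w.r.t. the lock.
    state.update(|s| s.names.push(name.to_string()))
}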

cli/src/tunnels.rs (new file)
@@ -0,0 +1,40 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

pub mod code_server;
pub mod dev_tunnels;
pub mod legal;
pub mod paths;
pub mod protocol;
pub mod shutdown_signal;
pub mod singleton_client;
pub mod singleton_server;

mod challenge;
mod control_server;
mod nosleep;
#[cfg(target_os = "linux")]
mod nosleep_linux;
#[cfg(target_os = "macos")]
mod nosleep_macos;
#[cfg(target_os = "windows")]
mod nosleep_windows;
mod port_forwarder;
mod server_bridge;
mod server_multiplexer;
mod service;
#[cfg(target_os = "linux")]
mod service_linux;
#[cfg(target_os = "macos")]
mod service_macos;
#[cfg(target_os = "windows")]
mod service_windows;
mod socket_signal;

pub use control_server::{serve, serve_stream, Next, ServeStreamParams};
pub use nosleep::SleepInhibitor;
pub use service::{
    create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME,
};

cli/src/tunnels/challenge.rs (new file)
@@ -0,0 +1,41 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

#[cfg(not(feature = "vsda"))]
pub fn create_challenge() -> String {
    use rand::distributions::{Alphanumeric, DistString};
    Alphanumeric.sample_string(&mut rand::thread_rng(), 16)
}

#[cfg(not(feature = "vsda"))]
pub fn sign_challenge(challenge: &str) -> String {
    use sha2::{Digest, Sha256};
    let mut hash = Sha256::new();
    hash.update(challenge.as_bytes());
    let result = hash.finalize();
    base64::encode_config(result, base64::URL_SAFE_NO_PAD)
}

#[cfg(not(feature = "vsda"))]
pub fn verify_challenge(challenge: &str, response: &str) -> bool {
    sign_challenge(challenge) == response
}

#[cfg(feature = "vsda")]
pub fn create_challenge() -> String {
    use rand::distributions::{Alphanumeric, DistString};
    let str = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
    vsda::create_new_message(&str)
}

#[cfg(feature = "vsda")]
pub fn sign_challenge(challenge: &str) -> String {
    vsda::sign(challenge)
}

#[cfg(feature = "vsda")]
pub fn verify_challenge(challenge: &str, response: &str) -> bool {
    vsda::validate(challenge, response)
}
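
The intended round-trip: one side issues the challenge, the peer signs it, and the issuer verifies the response. Because both implementations are feature-gated behind the same signatures, call sites are identical either way. A sketch:

// Editor's sketch (not in the diff): the challenge/response round-trip.
fn challenge_round_trip() {
    let challenge = create_challenge(); // issuer -> peer
    let response = sign_challenge(&challenge); // peer -> issuer
    assert!(verify_challenge(&challenge, &response));
}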

cli/src/tunnels/code_server.rs (new file)
@@ -0,0 +1,799 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use super::paths::{InstalledServer, ServerPaths};
use crate::async_pipe::get_socket_name;
use crate::constants::{
    APPLICATION_NAME, EDITOR_WEB_URL, QUALITYLESS_PRODUCT_NAME, QUALITYLESS_SERVER_NAME,
};
use crate::download_cache::DownloadCache;
use crate::options::{Quality, TelemetryLevel};
use crate::state::LauncherPaths;
use crate::tunnels::paths::{get_server_folder_name, SERVER_FOLDER_NAME};
use crate::update_service::{
    unzip_downloaded_release, Platform, Release, TargetKind, UpdateService,
};
use crate::util::command::{capture_command, kill_tree};
use crate::util::errors::{wrap, AnyError, CodeError, ExtensionInstallFailed, WrappedError};
use crate::util::http::{self, BoxedHttp};
use crate::util::io::SilentCopyProgress;
use crate::util::machine::process_exists;
use crate::{debug, info, log, spanf, trace, warning};
use lazy_static::lazy_static;
use opentelemetry::KeyValue;
use regex::Regex;
use serde::Deserialize;
use std::fs;
use std::fs::File;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::fs::remove_file;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::oneshot::Receiver;
use tokio::time::{interval, timeout};

lazy_static! {
    static ref LISTENING_PORT_RE: Regex =
        Regex::new(r"Extension host agent listening on (.+)").unwrap();
    static ref WEB_UI_RE: Regex = Regex::new(r"Web UI available at (.+)").unwrap();
}

#[derive(Clone, Debug, Default)]
pub struct CodeServerArgs {
    pub host: Option<String>,
    pub port: Option<u16>,
    pub socket_path: Option<String>,

    // common arguments
    pub telemetry_level: Option<TelemetryLevel>,
    pub log: Option<log::Level>,
    pub accept_server_license_terms: bool,
    pub verbose: bool,
    // extension management
    pub install_extensions: Vec<String>,
    pub uninstall_extensions: Vec<String>,
    pub list_extensions: bool,
    pub show_versions: bool,
    pub category: Option<String>,
    pub pre_release: bool,
    pub force: bool,
    pub start_server: bool,
    // connection tokens
    pub connection_token: Option<String>,
    pub connection_token_file: Option<String>,
    pub without_connection_token: bool,
}

impl CodeServerArgs {
    pub fn log_level(&self) -> log::Level {
        if self.verbose {
            log::Level::Trace
        } else {
            self.log.unwrap_or(log::Level::Info)
        }
    }

    pub fn telemetry_disabled(&self) -> bool {
        self.telemetry_level == Some(TelemetryLevel::Off)
    }

    pub fn command_arguments(&self) -> Vec<String> {
        let mut args = Vec::new();
        if let Some(i) = &self.socket_path {
            args.push(format!("--socket-path={}", i));
        } else {
            if let Some(i) = &self.host {
                args.push(format!("--host={}", i));
            }
            if let Some(i) = &self.port {
                args.push(format!("--port={}", i));
            }
        }

        if let Some(i) = &self.connection_token {
            args.push(format!("--connection-token={}", i));
        }
        if let Some(i) = &self.connection_token_file {
            args.push(format!("--connection-token-file={}", i));
        }
        if self.without_connection_token {
            args.push(String::from("--without-connection-token"));
        }
        if self.accept_server_license_terms {
            args.push(String::from("--accept-server-license-terms"));
        }
        if let Some(i) = self.telemetry_level {
            args.push(format!("--telemetry-level={}", i));
        }
        if let Some(i) = self.log {
            args.push(format!("--log={}", i));
        }

        for extension in &self.install_extensions {
            args.push(format!("--install-extension={}", extension));
        }
        if !&self.install_extensions.is_empty() {
            if self.pre_release {
                args.push(String::from("--pre-release"));
            }
            if self.force {
                args.push(String::from("--force"));
            }
        }
        for extension in &self.uninstall_extensions {
            args.push(format!("--uninstall-extension={}", extension));
        }
        if self.list_extensions {
            args.push(String::from("--list-extensions"));
            if self.show_versions {
                args.push(String::from("--show-versions"));
            }
            if let Some(i) = &self.category {
                args.push(format!("--category={}", i));
            }
        }
        if self.start_server {
            args.push(String::from("--start-server"));
        }
        args
    }
}
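
// Editor's note (not in the diff): `command_arguments` maps the struct onto
// the server's CLI flags, with socket-path winning over host/port. A
// hypothetical example of the mapping:
//
//     let args = CodeServerArgs {
//         socket_path: Some("/tmp/code.sock".to_string()),
//         host: Some("127.0.0.1".to_string()), // ignored: socket_path is set
//         accept_server_license_terms: true,
//         ..Default::default()
//     };
//     assert_eq!(
//         args.command_arguments(),
//         vec![
//             "--socket-path=/tmp/code.sock".to_string(),
//             "--accept-server-license-terms".to_string(),
//         ]
//     );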
|
||||
|
||||
/// Base server params that can be `resolve()`d to a `ResolvedServerParams`.
|
||||
/// Doing so fetches additional information like a commit ID if previously
|
||||
/// unspecified.
|
||||
pub struct ServerParamsRaw {
|
||||
pub commit_id: Option<String>,
|
||||
pub quality: Quality,
|
||||
pub code_server_args: CodeServerArgs,
|
||||
pub headless: bool,
|
||||
pub platform: Platform,
|
||||
}
|
||||
|
||||
/// Server params that can be used to start a VS Code server.
|
||||
pub struct ResolvedServerParams {
|
||||
pub release: Release,
|
||||
pub code_server_args: CodeServerArgs,
|
||||
}
|
||||
|
||||
impl ResolvedServerParams {
|
||||
fn as_installed_server(&self) -> InstalledServer {
|
||||
InstalledServer {
|
||||
commit: self.release.commit.clone(),
|
||||
quality: self.release.quality,
|
||||
headless: self.release.target == TargetKind::Server,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerParamsRaw {
|
||||
pub async fn resolve(
|
||||
self,
|
||||
log: &log::Logger,
|
||||
http: BoxedHttp,
|
||||
) -> Result<ResolvedServerParams, AnyError> {
|
||||
Ok(ResolvedServerParams {
|
||||
release: self.get_or_fetch_commit_id(log, http).await?,
|
||||
code_server_args: self.code_server_args,
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_or_fetch_commit_id(
|
||||
&self,
|
||||
log: &log::Logger,
|
||||
http: BoxedHttp,
|
||||
) -> Result<Release, AnyError> {
|
||||
let target = match self.headless {
|
||||
true => TargetKind::Server,
|
||||
false => TargetKind::Web,
|
||||
};
|
||||
|
||||
if let Some(c) = &self.commit_id {
|
||||
return Ok(Release {
|
||||
commit: c.clone(),
|
||||
quality: self.quality,
|
||||
target,
|
||||
name: String::new(),
|
||||
platform: self.platform,
|
||||
});
|
||||
}
|
||||
|
||||
UpdateService::new(log.clone(), http)
|
||||
.get_latest_commit(self.platform, target, self.quality)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[allow(dead_code)]
|
||||
struct UpdateServerVersion {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
pub product_version: String,
|
||||
pub timestamp: i64,
|
||||
}
|
||||
|
||||
/// Code server listening on a port address.
|
||||
#[derive(Clone)]
|
||||
pub struct SocketCodeServer {
|
||||
pub commit_id: String,
|
||||
pub socket: PathBuf,
|
||||
pub origin: Arc<CodeServerOrigin>,
|
||||
}
|
||||
|
||||
/// Code server listening on a socket address.
|
||||
#[derive(Clone)]
|
||||
pub struct PortCodeServer {
|
||||
pub commit_id: String,
|
||||
pub port: u16,
|
||||
pub origin: Arc<CodeServerOrigin>,
|
||||
}
|
||||
|
||||
/// A server listening on any address/location.
|
||||
pub enum AnyCodeServer {
|
||||
Socket(SocketCodeServer),
|
||||
Port(PortCodeServer),
|
||||
}
|
||||
|
||||
pub enum CodeServerOrigin {
|
||||
/// A new code server, that opens the barrier when it exits.
|
||||
New(Box<Child>),
|
||||
/// An existing code server with a PID.
|
||||
Existing(u32),
|
||||
}
|
||||
|
||||
impl CodeServerOrigin {
|
||||
pub async fn wait_for_exit(&mut self) {
|
||||
match self {
|
||||
CodeServerOrigin::New(child) => {
|
||||
child.wait().await.ok();
|
||||
}
|
||||
CodeServerOrigin::Existing(pid) => {
|
||||
let mut interval = interval(Duration::from_secs(30));
|
||||
while process_exists(*pid) {
|
||||
interval.tick().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn kill(&mut self) {
|
||||
match self {
|
||||
CodeServerOrigin::New(child) => {
|
||||
child.kill().await.ok();
|
||||
}
|
||||
CodeServerOrigin::Existing(pid) => {
|
||||
kill_tree(*pid).await.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensures the given list of extensions are installed on the running server.
|
||||
async fn do_extension_install_on_running_server(
|
||||
start_script_path: &Path,
|
||||
extensions: &[String],
|
||||
log: &log::Logger,
|
||||
) -> Result<(), AnyError> {
|
||||
if extensions.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
debug!(log, "Installing extensions...");
|
||||
let command = format!(
|
||||
"{} {}",
|
||||
start_script_path.display(),
|
||||
extensions
|
||||
.iter()
|
||||
.map(|s| get_extensions_flag(s))
|
||||
.collect::<Vec<String>>()
|
||||
.join(" ")
|
||||
);
|
||||
|
||||
let result = capture_command("bash", &["-c", &command]).await?;
|
||||
if !result.status.success() {
|
||||
Err(AnyError::from(ExtensionInstallFailed(
|
||||
String::from_utf8_lossy(&result.stderr).to_string(),
|
||||
)))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ServerBuilder<'a> {
|
||||
logger: &'a log::Logger,
|
||||
server_params: &'a ResolvedServerParams,
|
||||
launcher_paths: &'a LauncherPaths,
|
||||
server_paths: ServerPaths,
|
||||
http: BoxedHttp,
|
||||
}
|
||||
|
||||
impl<'a> ServerBuilder<'a> {
|
||||
pub fn new(
|
||||
logger: &'a log::Logger,
|
||||
server_params: &'a ResolvedServerParams,
|
||||
launcher_paths: &'a LauncherPaths,
|
||||
http: BoxedHttp,
|
||||
) -> Self {
|
||||
Self {
|
||||
logger,
|
||||
server_params,
|
||||
launcher_paths,
|
||||
server_paths: server_params
|
||||
.as_installed_server()
|
||||
.server_paths(launcher_paths),
|
||||
http,
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets any already-running server from this directory.
|
||||
pub async fn get_running(&self) -> Result<Option<AnyCodeServer>, AnyError> {
|
||||
info!(
|
||||
self.logger,
|
||||
"Checking {} and {} for a running server...",
|
||||
self.server_paths.logfile.display(),
|
||||
self.server_paths.pidfile.display()
|
||||
);
|
||||
|
||||
let pid = match self.server_paths.get_running_pid() {
|
||||
Some(pid) => pid,
|
||||
None => return Ok(None),
|
||||
};
|
||||
info!(self.logger, "Found running server (pid={})", pid);
|
||||
if !Path::new(&self.server_paths.logfile).exists() {
|
||||
warning!(self.logger, "{} Server is running but its logfile is missing. Don't delete the {} Server manually, run the command '{} prune'.", QUALITYLESS_PRODUCT_NAME, QUALITYLESS_PRODUCT_NAME, APPLICATION_NAME);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
do_extension_install_on_running_server(
|
||||
&self.server_paths.executable,
|
||||
&self.server_params.code_server_args.install_extensions,
|
||||
self.logger,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let origin = Arc::new(CodeServerOrigin::Existing(pid));
|
||||
let contents = fs::read_to_string(&self.server_paths.logfile)
|
||||
.expect("Something went wrong reading log file");
|
||||
|
||||
if let Some(port) = parse_port_from(&contents) {
|
||||
Ok(Some(AnyCodeServer::Port(PortCodeServer {
|
||||
commit_id: self.server_params.release.commit.to_owned(),
|
||||
port,
|
||||
origin,
|
||||
})))
|
||||
} else if let Some(socket) = parse_socket_from(&contents) {
|
||||
Ok(Some(AnyCodeServer::Socket(SocketCodeServer {
|
||||
commit_id: self.server_params.release.commit.to_owned(),
|
||||
socket,
|
||||
origin,
|
||||
})))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensures the server is set up in the configured directory.
|
||||
pub async fn setup(&self) -> Result<(), AnyError> {
|
||||
debug!(
|
||||
self.logger,
|
||||
"Installing and setting up {}...", QUALITYLESS_SERVER_NAME
|
||||
);
|
||||
|
||||
let update_service = UpdateService::new(self.logger.clone(), self.http.clone());
|
||||
let name = get_server_folder_name(
|
||||
self.server_params.release.quality,
|
||||
&self.server_params.release.commit,
|
||||
);
|
||||
|
||||
self.launcher_paths
|
||||
.server_cache
|
||||
.create(name, |target_dir| async move {
|
||||
let tmpdir =
|
||||
tempfile::tempdir().map_err(|e| wrap(e, "error creating temp download dir"))?;
|
||||
|
||||
let response = update_service
|
||||
.get_download_stream(&self.server_params.release)
|
||||
.await?;
|
||||
let archive_path = tmpdir.path().join(response.url_path_basename().unwrap());
|
||||
|
||||
info!(
|
||||
self.logger,
|
||||
"Downloading {} server -> {}",
|
||||
QUALITYLESS_PRODUCT_NAME,
|
||||
archive_path.display()
|
||||
);
|
||||
|
||||
http::download_into_file(
|
||||
&archive_path,
|
||||
self.logger.get_download_logger("server download progress:"),
|
||||
response,
|
||||
)
|
||||
.await?;
|
||||
|
||||
unzip_downloaded_release(
|
||||
&archive_path,
|
||||
&target_dir.join(SERVER_FOLDER_NAME),
|
||||
SilentCopyProgress(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
|
||||
debug!(self.logger, "Server setup complete");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn listen_on_port(&self, port: u16) -> Result<PortCodeServer, AnyError> {
|
||||
let mut cmd = self.get_base_command();
|
||||
cmd.arg("--start-server")
|
||||
.arg("--enable-remote-auto-shutdown")
|
||||
.arg(format!("--port={}", port));
|
||||
|
||||
let child = self.spawn_server_process(cmd)?;
|
||||
let log_file = self.get_logfile()?;
|
||||
let plog = self.logger.prefixed(&log::new_code_server_prefix());
|
||||
|
||||
let (mut origin, listen_rx) =
|
||||
monitor_server::<PortMatcher, u16>(child, Some(log_file), plog, false);
|
||||
|
||||
let port = match timeout(Duration::from_secs(8), listen_rx).await {
|
||||
Err(e) => {
|
||||
origin.kill().await;
|
||||
Err(wrap(e, "timed out looking for port"))
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
origin.kill().await;
|
||||
Err(wrap(e, "server exited without writing port"))
|
||||
}
|
||||
Ok(Ok(p)) => Ok(p),
|
||||
}?;
|
||||
|
||||
info!(self.logger, "Server started");
|
||||
|
||||
Ok(PortCodeServer {
|
||||
commit_id: self.server_params.release.commit.to_owned(),
|
||||
port,
|
||||
origin: Arc::new(origin),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn listen_on_default_socket(&self) -> Result<SocketCodeServer, AnyError> {
|
||||
let requested_file = get_socket_name();
|
||||
self.listen_on_socket(&requested_file).await
|
||||
}
|
||||
|
||||
pub async fn listen_on_socket(&self, socket: &Path) -> Result<SocketCodeServer, AnyError> {
|
||||
Ok(spanf!(
|
||||
self.logger,
|
||||
self.logger.span("server.start").with_attributes(vec! {
|
||||
KeyValue::new("commit_id", self.server_params.release.commit.to_string()),
|
||||
KeyValue::new("quality", format!("{}", self.server_params.release.quality)),
|
||||
}),
|
||||
self._listen_on_socket(socket)
|
||||
)?)
|
||||
}
|
||||
|
||||
async fn _listen_on_socket(&self, socket: &Path) -> Result<SocketCodeServer, AnyError> {
|
||||
remove_file(&socket).await.ok(); // ignore any error if it doesn't exist
|
||||
|
||||
let mut cmd = self.get_base_command();
|
||||
cmd.arg("--start-server")
|
||||
.arg("--enable-remote-auto-shutdown")
|
||||
.arg(format!("--socket-path={}", socket.display()));
|
||||
|
||||
let child = self.spawn_server_process(cmd)?;
|
||||
let log_file = self.get_logfile()?;
|
||||
let plog = self.logger.prefixed(&log::new_code_server_prefix());
|
||||
|
||||
let (mut origin, listen_rx) =
|
||||
monitor_server::<SocketMatcher, PathBuf>(child, Some(log_file), plog, false);
|
||||
|
||||
let socket = match timeout(Duration::from_secs(8), listen_rx).await {
|
||||
Err(e) => {
|
||||
origin.kill().await;
|
||||
Err(wrap(e, "timed out looking for socket"))
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
				origin.kill().await;
				Err(wrap(e, "server exited without writing socket"))
			}
			Ok(Ok(socket)) => Ok(socket),
		}?;

		info!(self.logger, "Server started");

		Ok(SocketCodeServer {
			commit_id: self.server_params.release.commit.to_owned(),
			socket,
			origin: Arc::new(origin),
		})
	}

	/// Starts with a given opaque set of args. Does not set up any port or
	/// socket, but does return one if present, in the form of a channel.
	pub async fn start_opaque_with_args<M, R>(
		&self,
		args: &[String],
	) -> Result<(CodeServerOrigin, Receiver<R>), AnyError>
	where
		M: ServerOutputMatcher<R>,
		R: 'static + Send + std::fmt::Debug,
	{
		let mut cmd = self.get_base_command();
		cmd.args(args);

		let child = self.spawn_server_process(cmd)?;
		let plog = self.logger.prefixed(&log::new_code_server_prefix());

		Ok(monitor_server::<M, R>(child, None, plog, true))
	}

	fn spawn_server_process(&self, mut cmd: Command) -> Result<Child, AnyError> {
		info!(self.logger, "Starting server...");

		debug!(self.logger, "Starting server with command... {:?}", cmd);

		let child = cmd
			.stderr(std::process::Stdio::piped())
			.stdout(std::process::Stdio::piped())
			.spawn()
			.map_err(|e| wrap(e, "error spawning server"))?;

		self.server_paths
			.write_pid(child.id().expect("expected server to have pid"))?;

		Ok(child)
	}

	fn get_logfile(&self) -> Result<File, WrappedError> {
		File::create(&self.server_paths.logfile).map_err(|e| {
			wrap(
				e,
				format!(
					"error creating log file {}",
					self.server_paths.logfile.display()
				),
			)
		})
	}

	fn get_base_command(&self) -> Command {
		let mut cmd = Command::new(&self.server_paths.executable);
		cmd.stdin(std::process::Stdio::null())
			.args(self.server_params.code_server_args.command_arguments());
		cmd
	}
}

fn monitor_server<M, R>(
	mut child: Child,
	log_file: Option<File>,
	plog: log::Logger,
	write_directly: bool,
) -> (CodeServerOrigin, Receiver<R>)
where
	M: ServerOutputMatcher<R>,
	R: 'static + Send + std::fmt::Debug,
{
	let stdout = child
		.stdout
		.take()
		.expect("child did not have a handle to stdout");

	let stderr = child
		.stderr
		.take()
		.expect("child did not have a handle to stderr");

	let (listen_tx, listen_rx) = tokio::sync::oneshot::channel();

	// Handle stderr and stdout in a separate task. Initially scan lines looking
	// for the listening port. Afterwards, just scan and write out to the file.
	tokio::spawn(async move {
		let mut stdout_reader = BufReader::new(stdout).lines();
		let mut stderr_reader = BufReader::new(stderr).lines();
		let write_line = |line: &str| -> std::io::Result<()> {
			if let Some(mut f) = log_file.as_ref() {
				f.write_all(line.as_bytes())?;
				f.write_all(&[b'\n'])?;
			}
			if write_directly {
				println!("{}", line);
			} else {
				trace!(plog, line);
			}
			Ok(())
		};

		loop {
			let line = tokio::select! {
				l = stderr_reader.next_line() => l,
				l = stdout_reader.next_line() => l,
			};

			match line {
				Err(e) => {
					trace!(plog, "error reading from stdout/stderr: {}", e);
					return;
				}
				Ok(None) => break,
				Ok(Some(l)) => {
					write_line(&l).ok();

					if let Some(listen_on) = M::match_line(&l) {
						trace!(plog, "parsed location: {:?}", listen_on);
						listen_tx.send(listen_on).ok();
						break;
					}
				}
			}
		}

		loop {
			let line = tokio::select! {
				l = stderr_reader.next_line() => l,
				l = stdout_reader.next_line() => l,
			};

			match line {
				Err(e) => {
					trace!(plog, "error reading from stdout/stderr: {}", e);
					break;
				}
				Ok(None) => break,
				Ok(Some(l)) => {
					write_line(&l).ok();
				}
			}
		}
	});

	let origin = CodeServerOrigin::New(Box::new(child));
	(origin, listen_rx)
}
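// Illustrative usage sketch (not part of the original source, and assuming
// `Receiver` here is a oneshot receiver): callers choose a matcher for the
// output they expect, e.g. waiting for the port announcement:
//
//   let (origin, port_rx) = monitor_server::<PortMatcher, u16>(child, None, plog, false);
//   let port = port_rx.await.map_err(|e| wrap(e, "server exited before listening"))?;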
fn get_extensions_flag(extension_id: &str) -> String {
	format!("--install-extension={}", extension_id)
}

/// A type that can be used to scan stdout from the VS Code server. Returns
/// some other type that, in turn, is returned from starting the server.
pub trait ServerOutputMatcher<R>
where
	R: Send,
{
	fn match_line(line: &str) -> Option<R>;
}

/// Parses a line like "Extension host agent listening on /tmp/foo.sock"
struct SocketMatcher();

impl ServerOutputMatcher<PathBuf> for SocketMatcher {
	fn match_line(line: &str) -> Option<PathBuf> {
		parse_socket_from(line)
	}
}

/// Parses a line like "Extension host agent listening on 9000"
pub struct PortMatcher();

impl ServerOutputMatcher<u16> for PortMatcher {
	fn match_line(line: &str) -> Option<u16> {
		parse_port_from(line)
	}
}

/// Parses a line like "Web UI available at http://localhost:9000/?tkn=..."
pub struct WebUiMatcher();

impl ServerOutputMatcher<reqwest::Url> for WebUiMatcher {
	fn match_line(line: &str) -> Option<reqwest::Url> {
		WEB_UI_RE.captures(line).and_then(|cap| {
			cap.get(1)
				.and_then(|uri| reqwest::Url::parse(uri.as_str()).ok())
		})
	}
}

/// Does not do any parsing and just immediately returns an empty result.
pub struct NoOpMatcher();

impl ServerOutputMatcher<()> for NoOpMatcher {
	fn match_line(_: &str) -> Option<()> {
		Some(())
	}
}
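// Illustrative sketch (not part of the original source): implementing a
// custom matcher for a hypothetical "Ready" line printed by a server.
#[cfg(test)]
mod matcher_example {
	use super::*;

	struct ReadyMatcher();

	impl ServerOutputMatcher<()> for ReadyMatcher {
		fn match_line(line: &str) -> Option<()> {
			if line.contains("Ready") {
				Some(())
			} else {
				None
			}
		}
	}

	#[test]
	fn matches_ready_line() {
		assert_eq!(ReadyMatcher::match_line("Server Ready"), Some(()));
		assert_eq!(ReadyMatcher::match_line("still starting"), None);
	}
}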
fn parse_socket_from(text: &str) -> Option<PathBuf> {
	LISTENING_PORT_RE
		.captures(text)
		.and_then(|cap| cap.get(1).map(|path| PathBuf::from(path.as_str())))
}

fn parse_port_from(text: &str) -> Option<u16> {
	LISTENING_PORT_RE.captures(text).and_then(|cap| {
		cap.get(1)
			.and_then(|path| path.as_str().parse::<u16>().ok())
	})
}

pub fn print_listening(log: &log::Logger, tunnel_name: &str) {
	debug!(
		log,
		"{} is listening for incoming connections", QUALITYLESS_SERVER_NAME
	);

	let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from(""));
	let current_dir = std::env::current_dir().unwrap_or_else(|_| PathBuf::from(""));

	let dir = if home_dir == current_dir {
		PathBuf::from("")
	} else {
		current_dir
	};

	let base_web_url = match EDITOR_WEB_URL {
		Some(u) => u,
		None => return,
	};

	let mut addr = url::Url::parse(base_web_url).unwrap();
	{
		let mut ps = addr.path_segments_mut().unwrap();
		ps.push("tunnel");
		ps.push(tunnel_name);
		for segment in &dir {
			let as_str = segment.to_string_lossy();
			if !(as_str.len() == 1 && as_str.starts_with(std::path::MAIN_SEPARATOR)) {
				ps.push(as_str.as_ref());
			}
		}
	}

	let message = &format!("\nOpen this link in your browser {}\n", addr);
	log.result(message);
}

pub async fn download_cli_into_cache(
	cache: &DownloadCache,
	release: &Release,
	update_service: &UpdateService,
) -> Result<PathBuf, AnyError> {
	let cache_name = format!(
		"{}-{}-{}",
		release.quality, release.commit, release.platform
	);
	let cli_dir = cache
		.create(&cache_name, |target_dir| async move {
			let tmpdir =
				tempfile::tempdir().map_err(|e| wrap(e, "error creating temp download dir"))?;
			let response = update_service.get_download_stream(release).await?;

			let name = response.url_path_basename().unwrap();
			let archive_path = tmpdir.path().join(name);
			http::download_into_file(&archive_path, SilentCopyProgress(), response).await?;
			unzip_downloaded_release(&archive_path, &target_dir, SilentCopyProgress())?;
			Ok(())
		})
		.await?;

	let cli = std::fs::read_dir(cli_dir)
		.map_err(|_| CodeError::CorruptDownload("could not read cli folder contents"))?
		.next();

	match cli {
		Some(Ok(cli)) => Ok(cli.path()),
		_ => {
			let _ = cache.delete(&cache_name);
			Err(CodeError::CorruptDownload("cli directory is empty").into())
		}
	}
}
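// Illustrative usage sketch (not part of the original source): the returned
// path is the single entry unpacked inside the per-release cache folder:
//
//   let cli_path = download_cli_into_cache(&cache, &release, &update_service).await?;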
cli/src/tunnels/control_server.rs (new file, 1168 lines): diff suppressed because it is too large

cli/src/tunnels/dev_tunnels.rs (new file, 997 lines)
@@ -0,0 +1,997 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use crate::auth;
use crate::constants::{
	CONTROL_PORT, IS_INTERACTIVE_CLI, PROTOCOL_VERSION_TAG, PROTOCOL_VERSION_TAG_PREFIX,
	TUNNEL_SERVICE_USER_AGENT,
};
use crate::state::{LauncherPaths, PersistedState};
use crate::util::errors::{
	wrap, AnyError, DevTunnelError, InvalidTunnelName, TunnelCreationFailed, WrappedError,
};
use crate::util::input::prompt_placeholder;
use crate::{debug, info, log, spanf, trace, warning};
use async_trait::async_trait;
use futures::TryFutureExt;
use lazy_static::lazy_static;
use rand::prelude::IteratorRandom;
use regex::Regex;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::{mpsc, watch};
use tunnels::connections::{ForwardedPortConnection, RelayTunnelHost};
use tunnels::contracts::{
	Tunnel, TunnelPort, TunnelRelayTunnelEndpoint, PORT_TOKEN, TUNNEL_PROTOCOL_AUTO,
};
use tunnels::management::{
	new_tunnel_management, HttpError, TunnelLocator, TunnelManagementClient, TunnelRequestOptions,
	NO_REQUEST_OPTIONS,
};

#[derive(Clone, Serialize, Deserialize)]
pub struct PersistedTunnel {
	pub name: String,
	pub id: String,
	pub cluster: String,
}

impl PersistedTunnel {
	pub fn into_locator(self) -> TunnelLocator {
		TunnelLocator::ID {
			cluster: self.cluster,
			id: self.id,
		}
	}
	pub fn locator(&self) -> TunnelLocator {
		TunnelLocator::ID {
			cluster: self.cluster.clone(),
			id: self.id.clone(),
		}
	}
}

#[async_trait]
trait AccessTokenProvider: Send + Sync {
	/// Gets the current access token, refreshing it if necessary.
	async fn refresh_token(&self) -> Result<String, WrappedError>;
}

/// Access token provider that provides a fixed token without refreshing.
struct StaticAccessTokenProvider(String);

impl StaticAccessTokenProvider {
	pub fn new(token: String) -> Self {
		Self(token)
	}
}

#[async_trait]
impl AccessTokenProvider for StaticAccessTokenProvider {
	async fn refresh_token(&self) -> Result<String, WrappedError> {
		Ok(self.0.clone())
	}
}

/// Access token provider that looks up the token from the tunnels API.
struct LookupAccessTokenProvider {
	client: TunnelManagementClient,
	locator: TunnelLocator,
	log: log::Logger,
	initial_token: Arc<Mutex<Option<String>>>,
}

impl LookupAccessTokenProvider {
	pub fn new(
		client: TunnelManagementClient,
		locator: TunnelLocator,
		log: log::Logger,
		initial_token: Option<String>,
	) -> Self {
		Self {
			client,
			locator,
			log,
			initial_token: Arc::new(Mutex::new(initial_token)),
		}
	}
}

#[async_trait]
impl AccessTokenProvider for LookupAccessTokenProvider {
	async fn refresh_token(&self) -> Result<String, WrappedError> {
		if let Some(token) = self.initial_token.lock().unwrap().take() {
			return Ok(token);
		}

		let tunnel_lookup = spanf!(
			self.log,
			self.log.span("dev-tunnel.tag.get"),
			self.client.get_tunnel(
				&self.locator,
				&TunnelRequestOptions {
					token_scopes: vec!["host".to_string()],
					..Default::default()
				}
			)
		);

		trace!(self.log, "Successfully refreshed access token");

		match tunnel_lookup {
			Ok(tunnel) => Ok(get_host_token_from_tunnel(&tunnel)),
			Err(e) => Err(wrap(e, "failed to lookup tunnel for host token")),
		}
	}
}

#[derive(Clone)]
pub struct DevTunnels {
	log: log::Logger,
	launcher_tunnel: PersistedState<Option<PersistedTunnel>>,
	client: TunnelManagementClient,
}

/// Representation of a tunnel returned from the `start` methods.
pub struct ActiveTunnel {
	/// Name of the tunnel
	pub name: String,
	/// Underlying dev tunnels ID
	pub id: String,
	manager: ActiveTunnelManager,
}

impl ActiveTunnel {
	/// Closes and unregisters the tunnel.
	pub async fn close(&mut self) -> Result<(), AnyError> {
		self.manager.kill().await?;
		Ok(())
	}

	/// Forwards a port to local connections.
	pub async fn add_port_direct(
		&mut self,
		port_number: u16,
	) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, AnyError> {
		let port = self.manager.add_port_direct(port_number).await?;
		Ok(port)
	}

	/// Forwards a port over TCP.
	pub async fn add_port_tcp(&mut self, port_number: u16) -> Result<(), AnyError> {
		self.manager.add_port_tcp(port_number).await?;
		Ok(())
	}

	/// Removes a forwarded TCP port.
	pub async fn remove_port(&mut self, port_number: u16) -> Result<(), AnyError> {
		self.manager.remove_port(port_number).await?;
		Ok(())
	}

	/// Gets the public URI on which a forwarded port can be accessed in a browser.
	pub async fn get_port_uri(&mut self, port: u16) -> Result<String, AnyError> {
		let endpoint = self.manager.get_endpoint().await?;
		let format = endpoint
			.base
			.port_uri_format
			.expect("expected to have port format");

		Ok(format.replace(PORT_TOKEN, &port.to_string()))
	}
}
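// Illustrative usage sketch (not part of the original source): forwarding a
// port on a started tunnel and fetching the browser URI for it.
//
//   let mut tunnel = dev_tunnels.start_new_launcher_tunnel(None, false).await?;
//   tunnel.add_port_tcp(8080).await?;
//   let uri = tunnel.get_port_uri(8080).await?;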
const VSCODE_CLI_TUNNEL_TAG: &str = "vscode-server-launcher";
const MAX_TUNNEL_NAME_LENGTH: usize = 20;

fn get_host_token_from_tunnel(tunnel: &Tunnel) -> String {
	tunnel
		.access_tokens
		.as_ref()
		.expect("expected to have access tokens")
		.get("host")
		.expect("expected to have host token")
		.to_string()
}

fn is_valid_name(name: &str) -> Result<(), InvalidTunnelName> {
	if name.len() > MAX_TUNNEL_NAME_LENGTH {
		return Err(InvalidTunnelName(format!(
			"Names cannot be longer than {} characters. Please try a different name.",
			MAX_TUNNEL_NAME_LENGTH
		)));
	}

	let re = Regex::new(r"^([\w-]+)$").unwrap();

	if !re.is_match(name) {
		return Err(InvalidTunnelName(
			"Names can only contain letters, numbers, and '-'. Spaces, commas, and all other special characters are not allowed. Please try a different name.".to_string()
		));
	}

	Ok(())
}
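// Illustrative sketch (not part of the original source): expected behavior of
// the validation above (word characters and dashes, at most 20 characters).
#[cfg(test)]
mod name_validation_example {
	use super::*;

	#[test]
	fn validates_tunnel_names() {
		assert!(is_valid_name("my-machine-2").is_ok());
		assert!(is_valid_name("bad name!").is_err());
		assert!(is_valid_name(&"x".repeat(MAX_TUNNEL_NAME_LENGTH + 1)).is_err());
	}
}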
lazy_static! {
	static ref HOST_TUNNEL_REQUEST_OPTIONS: TunnelRequestOptions = TunnelRequestOptions {
		include_ports: true,
		token_scopes: vec!["host".to_string()],
		..Default::default()
	};
}

/// Structure optionally passed into `start_existing_tunnel` to forward an existing tunnel.
#[derive(Clone, Debug)]
pub struct ExistingTunnel {
	/// Name to assign to the preexisting tunnel used to connect to the VS Code Server
	pub tunnel_name: String,

	/// Token to authenticate with the preexisting tunnel
	pub host_token: String,

	/// ID of the preexisting tunnel used to connect to the VS Code Server
	pub tunnel_id: String,

	/// Cluster of the preexisting tunnel used to connect to the VS Code Server
	pub cluster: String,
}

impl DevTunnels {
	pub fn new(log: &log::Logger, auth: auth::Auth, paths: &LauncherPaths) -> DevTunnels {
		let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
		client.authorization_provider(auth);

		DevTunnels {
			log: log.clone(),
			client: client.into(),
			launcher_tunnel: PersistedState::new(paths.root().join("code_tunnel.json")),
		}
	}

	pub async fn remove_tunnel(&mut self) -> Result<(), AnyError> {
		let tunnel = match self.launcher_tunnel.load() {
			Some(t) => t,
			None => {
				return Ok(());
			}
		};

		spanf!(
			self.log,
			self.log.span("dev-tunnel.delete"),
			self.client
				.delete_tunnel(&tunnel.into_locator(), NO_REQUEST_OPTIONS)
		)
		.map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;

		self.launcher_tunnel.save(None)?;
		Ok(())
	}

	/// Renames the current tunnel to the new name.
	pub async fn rename_tunnel(&mut self, name: &str) -> Result<(), AnyError> {
		self.update_tunnel_name(None, name).await.map(|_| ())
	}

	/// Updates the name of the existing persisted tunnel to the new name.
	/// Gracefully creates a new tunnel if the previous one was deleted.
	async fn update_tunnel_name(
		&mut self,
		persisted: Option<PersistedTunnel>,
		name: &str,
	) -> Result<(Tunnel, PersistedTunnel), AnyError> {
		let name = name.to_ascii_lowercase();
		self.check_is_name_free(&name).await?;

		debug!(self.log, "Tunnel name changed, applying updates...");

		let (mut full_tunnel, mut persisted, is_new) = match persisted {
			Some(persisted) => {
				self.get_or_create_tunnel(persisted, Some(&name), NO_REQUEST_OPTIONS)
					.await
			}
			None => self
				.create_tunnel(&name, NO_REQUEST_OPTIONS)
				.await
				.map(|(pt, t)| (t, pt, true)),
		}?;

		if is_new {
			return Ok((full_tunnel, persisted));
		}

		full_tunnel.tags = vec![name.to_string(), VSCODE_CLI_TUNNEL_TAG.to_string()];

		let new_tunnel = spanf!(
			self.log,
			self.log.span("dev-tunnel.tag.update"),
			self.client.update_tunnel(&full_tunnel, NO_REQUEST_OPTIONS)
		)
		.map_err(|e| wrap(e, "failed to rename tunnel"))?;

		persisted.name = name;
		self.launcher_tunnel.save(Some(persisted.clone()))?;

		Ok((new_tunnel, persisted))
	}

	/// Gets the persisted tunnel from the service, or creates a new one.
	/// If `create_with_new_name` is given, the new tunnel has that name
	/// instead of the one previously persisted.
	async fn get_or_create_tunnel(
		&mut self,
		persisted: PersistedTunnel,
		create_with_new_name: Option<&str>,
		options: &TunnelRequestOptions,
	) -> Result<(Tunnel, PersistedTunnel, /* is_new */ bool), AnyError> {
		let tunnel_lookup = spanf!(
			self.log,
			self.log.span("dev-tunnel.tag.get"),
			self.client.get_tunnel(&persisted.locator(), options)
		);

		match tunnel_lookup {
			Ok(ft) => Ok((ft, persisted, false)),
			Err(HttpError::ResponseError(e))
				if e.status_code == StatusCode::NOT_FOUND
					|| e.status_code == StatusCode::FORBIDDEN =>
			{
				let (persisted, tunnel) = self
					.create_tunnel(create_with_new_name.unwrap_or(&persisted.name), options)
					.await?;
				Ok((tunnel, persisted, true))
			}
			Err(e) => Err(wrap(e, "failed to lookup tunnel").into()),
		}
	}

	/// Starts a new tunnel for the code server. Unlike `start_new_tunnel`,
	/// this attempts to reuse or create a tunnel of a preferred name or of a generated friendly tunnel name.
	pub async fn start_new_launcher_tunnel(
		&mut self,
		preferred_name: Option<&str>,
		use_random_name: bool,
	) -> Result<ActiveTunnel, AnyError> {
		let (mut tunnel, persisted) = match self.launcher_tunnel.load() {
			Some(mut persisted) => {
				if let Some(preferred_name) = preferred_name.map(|n| n.to_ascii_lowercase()) {
					if persisted.name.to_ascii_lowercase() != preferred_name {
						(_, persisted) = self
							.update_tunnel_name(Some(persisted), &preferred_name)
							.await?;
					}
				}

				let (tunnel, persisted, _) = self
					.get_or_create_tunnel(persisted, None, &HOST_TUNNEL_REQUEST_OPTIONS)
					.await?;
				(tunnel, persisted)
			}
			None => {
				debug!(self.log, "No code server tunnel found, creating new one");
				let name = self
					.get_name_for_tunnel(preferred_name, use_random_name)
					.await?;
				let (persisted, full_tunnel) = self
					.create_tunnel(&name, &HOST_TUNNEL_REQUEST_OPTIONS)
					.await?;
				(full_tunnel, persisted)
			}
		};

		if !tunnel.tags.iter().any(|t| t == PROTOCOL_VERSION_TAG) {
			tunnel = self
				.update_protocol_version_tag(tunnel, &HOST_TUNNEL_REQUEST_OPTIONS)
				.await?;
		}

		let locator = TunnelLocator::try_from(&tunnel).unwrap();
		let host_token = get_host_token_from_tunnel(&tunnel);

		for port_to_delete in tunnel
			.ports
			.iter()
			.filter(|p| p.port_number != CONTROL_PORT)
		{
			let output_fut = self.client.delete_tunnel_port(
				&locator,
				port_to_delete.port_number,
				NO_REQUEST_OPTIONS,
			);
			spanf!(
				self.log,
				self.log.span("dev-tunnel.port.delete"),
				output_fut
			)
			.map_err(|e| wrap(e, "failed to delete port"))?;
		}

		// clean up any old trailing tunnel endpoints
		for endpoint in tunnel.endpoints {
			let fut = self.client.delete_tunnel_endpoints(
				&locator,
				&endpoint.host_id,
				None,
				NO_REQUEST_OPTIONS,
			);

			spanf!(self.log, self.log.span("dev-tunnel.endpoint.prune"), fut)
				.map_err(|e| wrap(e, "failed to prune tunnel endpoint"))?;
		}

		self.start_tunnel(
			locator.clone(),
			&persisted,
			self.client.clone(),
			LookupAccessTokenProvider::new(
				self.client.clone(),
				locator,
				self.log.clone(),
				Some(host_token),
			),
		)
		.await
	}

	async fn create_tunnel(
		&mut self,
		name: &str,
		options: &TunnelRequestOptions,
	) -> Result<(PersistedTunnel, Tunnel), AnyError> {
		info!(self.log, "Creating tunnel with the name: {}", name);

		let mut tried_recycle = false;

		let new_tunnel = Tunnel {
			tags: vec![
				name.to_string(),
				PROTOCOL_VERSION_TAG.to_string(),
				VSCODE_CLI_TUNNEL_TAG.to_string(),
			],
			..Default::default()
		};

		loop {
			let result = spanf!(
				self.log,
				self.log.span("dev-tunnel.create"),
				self.client.create_tunnel(&new_tunnel, options)
			);

			match result {
				Err(HttpError::ResponseError(e))
					if e.status_code == StatusCode::TOO_MANY_REQUESTS =>
				{
					if !tried_recycle && self.try_recycle_tunnel().await? {
						tried_recycle = true;
						continue;
					}

					return Err(AnyError::from(TunnelCreationFailed(
						name.to_string(),
						"You've exceeded the 10 machine limit for the port forwarding service. Please remove other machines before trying to add this machine.".to_string(),
					)));
				}
				Err(e) => {
					return Err(AnyError::from(TunnelCreationFailed(
						name.to_string(),
						format!("{:?}", e),
					)))
				}
				Ok(t) => {
					let pt = PersistedTunnel {
						cluster: t.cluster_id.clone().unwrap(),
						id: t.tunnel_id.clone().unwrap(),
						name: name.to_string(),
					};

					self.launcher_tunnel.save(Some(pt.clone()))?;
					return Ok((pt, t));
				}
			}
		}
	}

	/// Ensures the tunnel contains a tag for the current PROTOCOL_VERSION, and no
	/// other version tags.
	async fn update_protocol_version_tag(
		&self,
		tunnel: Tunnel,
		options: &TunnelRequestOptions,
	) -> Result<Tunnel, AnyError> {
		debug!(
			self.log,
			"Updating tunnel protocol version tag to {}", PROTOCOL_VERSION_TAG
		);
		let mut new_tags: Vec<String> = tunnel
			.tags
			.into_iter()
			.filter(|t| !t.starts_with(PROTOCOL_VERSION_TAG_PREFIX))
			.collect();
		new_tags.push(PROTOCOL_VERSION_TAG.to_string());

		let tunnel_update = Tunnel {
			tags: new_tags,
			tunnel_id: tunnel.tunnel_id.clone(),
			cluster_id: tunnel.cluster_id.clone(),
			..Default::default()
		};

		let result = spanf!(
			self.log,
			self.log.span("dev-tunnel.protocol-tag-update"),
			self.client.update_tunnel(&tunnel_update, options)
		);

		result.map_err(|e| wrap(e, "tunnel tag update failed").into())
	}

	/// Tries to delete an unused tunnel to make room for a new one.
	/// Returns true if a tunnel was recycled.
	async fn try_recycle_tunnel(&mut self) -> Result<bool, AnyError> {
		trace!(
			self.log,
			"Tunnel limit hit, trying to recycle an old tunnel"
		);

		let existing_tunnels = self.list_all_server_tunnels().await?;

		let recyclable = existing_tunnels
			.iter()
			.filter(|t| {
				t.status
					.as_ref()
					.and_then(|s| s.host_connection_count.as_ref())
					.map(|c| c.get_count())
					.unwrap_or(0) == 0
			})
			.choose(&mut rand::thread_rng());

		match recyclable {
			Some(tunnel) => {
				trace!(self.log, "Recycling tunnel ID {:?}", tunnel.tunnel_id);
				spanf!(
					self.log,
					self.log.span("dev-tunnel.delete"),
					self.client
						.delete_tunnel(&tunnel.try_into().unwrap(), NO_REQUEST_OPTIONS)
				)
				.map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;
				Ok(true)
			}
			None => {
				trace!(self.log, "No tunnels available to recycle");
				Ok(false)
			}
		}
	}

	async fn list_all_server_tunnels(&mut self) -> Result<Vec<Tunnel>, AnyError> {
		let tunnels = spanf!(
			self.log,
			self.log.span("dev-tunnel.listall"),
			self.client.list_all_tunnels(&TunnelRequestOptions {
				tags: vec![VSCODE_CLI_TUNNEL_TAG.to_string()],
				require_all_tags: true,
				..Default::default()
			})
		)
		.map_err(|e| wrap(e, "error listing current tunnels"))?;

		Ok(tunnels)
	}

	async fn check_is_name_free(&mut self, name: &str) -> Result<(), AnyError> {
		let existing = spanf!(
			self.log,
			self.log.span("dev-tunnel.rename.search"),
			self.client.list_all_tunnels(&TunnelRequestOptions {
				tags: vec![VSCODE_CLI_TUNNEL_TAG.to_string(), name.to_string()],
				require_all_tags: true,
				..Default::default()
			})
		)
		.map_err(|e| wrap(e, "failed to list existing tunnels"))?;
		if !existing.is_empty() {
			return Err(AnyError::from(TunnelCreationFailed(
				name.to_string(),
				"tunnel name already in use".to_string(),
			)));
		};
		Ok(())
	}

	async fn get_name_for_tunnel(
		&mut self,
		preferred_name: Option<&str>,
		mut use_random_name: bool,
	) -> Result<String, AnyError> {
		let existing_tunnels = self.list_all_server_tunnels().await?;
		let is_name_free = |n: &str| {
			!existing_tunnels.iter().any(|v| {
				v.status
					.as_ref()
					.and_then(|s| s.host_connection_count.as_ref().map(|c| c.get_count()))
					.unwrap_or(0) > 0 && v.tags.iter().any(|t| t == n)
			})
		};

		if let Some(machine_name) = preferred_name {
			let name = machine_name.to_ascii_lowercase();
			if let Err(e) = is_valid_name(&name) {
				info!(self.log, "{} is an invalid name", e);
				return Err(AnyError::from(wrap(e, "invalid name")));
			}
			if is_name_free(&name) {
				return Ok(name);
			}
			info!(
				self.log,
				"{} is already taken, using a random name instead", &name
			);
			use_random_name = true;
		}

		let mut placeholder_name =
			clean_hostname_for_tunnel(&gethostname::gethostname().to_string_lossy());
		placeholder_name.make_ascii_lowercase();

		if !is_name_free(&placeholder_name) {
			for i in 2.. {
				let fixed_name = format!("{}{}", placeholder_name, i);
				if is_name_free(&fixed_name) {
					placeholder_name = fixed_name;
					break;
				}
			}
		}

		if use_random_name || !*IS_INTERACTIVE_CLI {
			return Ok(placeholder_name);
		}

		loop {
			let mut name = prompt_placeholder(
				"What would you like to call this machine?",
				&placeholder_name,
			)?;

			name.make_ascii_lowercase();

			if let Err(e) = is_valid_name(&name) {
				info!(self.log, "{}", e);
				continue;
			}

			if is_name_free(&name) {
				return Ok(name);
			}

			info!(self.log, "The name {} is already in use", name);
		}
	}

	/// Hosts an existing tunnel, where the tunnel ID and host token are given.
	pub async fn start_existing_tunnel(
		&mut self,
		tunnel: ExistingTunnel,
	) -> Result<ActiveTunnel, AnyError> {
		let tunnel_details = PersistedTunnel {
			name: tunnel.tunnel_name,
			id: tunnel.tunnel_id,
			cluster: tunnel.cluster,
		};

		let mut mgmt = self.client.build();
		mgmt.authorization(tunnels::management::Authorization::Tunnel(
			tunnel.host_token.clone(),
		));

		self.start_tunnel(
			tunnel_details.locator(),
			&tunnel_details,
			mgmt.into(),
			StaticAccessTokenProvider::new(tunnel.host_token),
		)
		.await
	}

	async fn start_tunnel(
		&mut self,
		locator: TunnelLocator,
		tunnel_details: &PersistedTunnel,
		client: TunnelManagementClient,
		access_token: impl AccessTokenProvider + 'static,
	) -> Result<ActiveTunnel, AnyError> {
		let mut manager = ActiveTunnelManager::new(self.log.clone(), client, locator, access_token);

		let endpoint_result = spanf!(
			self.log,
			self.log.span("dev-tunnel.serve.callback"),
			manager.get_endpoint()
		);

		let endpoint = match endpoint_result {
			Ok(endpoint) => endpoint,
			Err(e) => {
				error!(self.log, "Error connecting to tunnel endpoint: {}", e);
				manager.kill().await.ok();
				return Err(e);
			}
		};

		debug!(self.log, "Connected to tunnel endpoint: {:?}", endpoint);

		Ok(ActiveTunnel {
			name: tunnel_details.name.clone(),
			id: tunnel_details.id.clone(),
			manager,
		})
	}
}

struct ActiveTunnelManager {
	close_tx: Option<mpsc::Sender<()>>,
	endpoint_rx: watch::Receiver<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
	relay: Arc<tokio::sync::Mutex<RelayTunnelHost>>,
}

impl ActiveTunnelManager {
	pub fn new(
		log: log::Logger,
		mgmt: TunnelManagementClient,
		locator: TunnelLocator,
		access_token: impl AccessTokenProvider + 'static,
	) -> ActiveTunnelManager {
		let (endpoint_tx, endpoint_rx) = watch::channel(None);
		let (close_tx, close_rx) = mpsc::channel(1);

		let relay = Arc::new(tokio::sync::Mutex::new(RelayTunnelHost::new(locator, mgmt)));
		let relay_spawned = relay.clone();

		tokio::spawn(async move {
			ActiveTunnelManager::spawn_tunnel(
				log,
				relay_spawned,
				close_rx,
				endpoint_tx,
				access_token,
			)
			.await;
		});

		ActiveTunnelManager {
			endpoint_rx,
			relay,
			close_tx: Some(close_tx),
		}
	}

	/// Adds a port for TCP/IP forwarding.
	#[allow(dead_code)] // todo: port forwarding
	pub async fn add_port_tcp(&self, port_number: u16) -> Result<(), WrappedError> {
		self.relay
			.lock()
			.await
			.add_port(&TunnelPort {
				port_number,
				protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()),
				..Default::default()
			})
			.await
			.map_err(|e| wrap(e, "error adding port to relay"))?;
		Ok(())
	}

	/// Adds a port for direct forwarding, returning a stream of raw connections.
	pub async fn add_port_direct(
		&self,
		port_number: u16,
	) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, WrappedError> {
		self.relay
			.lock()
			.await
			.add_port_raw(&TunnelPort {
				port_number,
				protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()),
				..Default::default()
			})
			.await
			.map_err(|e| wrap(e, "error adding port to relay"))
	}

	/// Removes a port from TCP/IP forwarding.
	pub async fn remove_port(&self, port_number: u16) -> Result<(), WrappedError> {
		self.relay
			.lock()
			.await
			.remove_port(port_number)
			.await
			.map_err(|e| wrap(e, "error removing port from relay"))
	}

	/// Gets the most recent details from the tunnel process. Returns None if
	/// the process exited before providing details.
	pub async fn get_endpoint(&mut self) -> Result<TunnelRelayTunnelEndpoint, AnyError> {
		loop {
			if let Some(details) = &*self.endpoint_rx.borrow() {
				return details.clone().map_err(AnyError::from);
			}

			if self.endpoint_rx.changed().await.is_err() {
				return Err(DevTunnelError("tunnel creation cancelled".to_string()).into());
			}
		}
	}

	/// Kills the process, and waits for it to exit.
	/// See https://tokio.rs/tokio/topics/shutdown#waiting-for-things-to-finish-shutting-down for how this works
	pub async fn kill(&mut self) -> Result<(), AnyError> {
		if let Some(tx) = self.close_tx.take() {
			drop(tx);
		}

		self.relay
			.lock()
			.await
			.unregister()
			.await
			.map_err(|e| wrap(e, "error unregistering relay"))?;

		while self.endpoint_rx.changed().await.is_ok() {}

		Ok(())
	}

	async fn spawn_tunnel(
		log: log::Logger,
		relay: Arc<tokio::sync::Mutex<RelayTunnelHost>>,
		mut close_rx: mpsc::Receiver<()>,
		endpoint_tx: watch::Sender<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
		access_token_provider: impl AccessTokenProvider + 'static,
	) {
		let mut backoff = Backoff::new(Duration::from_secs(5), Duration::from_secs(120));

		macro_rules! fail {
			($e: expr, $msg: expr) => {
				warning!(log, "{}: {}", $msg, $e);
				endpoint_tx.send(Some(Err($e))).ok();
				backoff.delay().await;
			};
		}

		loop {
			debug!(log, "Starting tunnel to server...");

			let access_token = match access_token_provider.refresh_token().await {
				Ok(t) => t,
				Err(e) => {
					fail!(e, "Error refreshing access token, will retry");
					continue;
				}
			};

			// we don't bother making a client that can refresh the token, since
			// the tunnel won't be able to host as soon as the access token expires.
			let handle_res = {
				let mut relay = relay.lock().await;
				relay
					.connect(&access_token)
					.await
					.map_err(|e| wrap(e, "error connecting to tunnel"))
			};

			let mut handle = match handle_res {
				Ok(handle) => handle,
				Err(e) => {
					fail!(e, "Error connecting to relay, will retry");
					continue;
				}
			};

			backoff.reset();
			endpoint_tx.send(Some(Ok(handle.endpoint().clone()))).ok();

			tokio::select! {
				// the error is mapped like this to prevent it being used across an await,
				// which Rust dislikes since there's a non-sendable dyn Error in there
				res = (&mut handle).map_err(|e| wrap(e, "error from tunnel connection")) => {
					if let Err(e) = res {
						fail!(e, "Tunnel exited unexpectedly, reconnecting");
					} else {
						warning!(log, "Tunnel exited unexpectedly but gracefully, reconnecting");
						backoff.delay().await;
					}
				},
				_ = close_rx.recv() => {
					trace!(log, "Tunnel closing gracefully");
					trace!(log, "Tunnel closed with result: {:?}", handle.close().await);
					break;
				}
			}
		}
	}
}

struct Backoff {
	failures: u32,
	base_duration: Duration,
	max_duration: Duration,
}

impl Backoff {
	pub fn new(base_duration: Duration, max_duration: Duration) -> Self {
		Self {
			failures: 0,
			base_duration,
			max_duration,
		}
	}

	pub async fn delay(&mut self) {
		tokio::time::sleep(self.next()).await
	}

	pub fn next(&mut self) -> Duration {
		self.failures += 1;
		let duration = self
			.base_duration
			.checked_mul(self.failures)
			.unwrap_or(self.max_duration);
		std::cmp::min(duration, self.max_duration)
	}

	pub fn reset(&mut self) {
		self.failures = 0;
	}
}
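// Illustrative sketch (not part of the original source): the backoff above
// grows linearly per failure and saturates at `max_duration`.
#[cfg(test)]
mod backoff_example {
	use super::*;

	#[test]
	fn grows_linearly_then_saturates() {
		let mut b = Backoff::new(Duration::from_secs(5), Duration::from_secs(120));
		assert_eq!(b.next(), Duration::from_secs(5));
		assert_eq!(b.next(), Duration::from_secs(10));
		for _ in 0..30 {
			b.next();
		}
		assert_eq!(b.next(), Duration::from_secs(120));
		b.reset();
		assert_eq!(b.next(), Duration::from_secs(5));
	}
}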
/// Cleans up the hostname so it can be used as a tunnel name.
/// See TUNNEL_NAME_PATTERN in the tunnels SDK for the rules we try to use.
fn clean_hostname_for_tunnel(hostname: &str) -> String {
	let mut out = String::new();
	for char in hostname.chars().take(60) {
		match char {
			'-' | '_' | ' ' => {
				out.push('-');
			}
			'0'..='9' | 'a'..='z' | 'A'..='Z' => {
				out.push(char);
			}
			_ => {}
		}
	}

	let trimmed = out.trim_matches('-');
	if trimmed.len() < 2 {
		"remote-machine".to_string() // placeholder if the result was too short
	} else {
		trimmed.to_owned()
	}
}

#[cfg(test)]
mod test {
	use super::*;

	#[test]
	fn test_clean_hostname_for_tunnel() {
		assert_eq!(
			clean_hostname_for_tunnel("hello123"),
			"hello123".to_string()
		);
		assert_eq!(
			clean_hostname_for_tunnel("-cool-name-"),
			"cool-name".to_string()
		);
		assert_eq!(
			clean_hostname_for_tunnel("cool!name with_chars"),
			"coolname-with-chars".to_string()
		);
		assert_eq!(clean_hostname_for_tunnel("z"), "remote-machine".to_string());
	}
}
cli/src/tunnels/legal.rs (new file, 66 lines)
@@ -0,0 +1,66 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use crate::constants::{IS_INTERACTIVE_CLI, PRODUCT_NAME_LONG};
use crate::state::{LauncherPaths, PersistedState};
use crate::util::errors::{AnyError, MissingLegalConsent};
use crate::util::input::prompt_yn;
use serde::{Deserialize, Serialize};

const LICENSE_TEXT: Option<&'static str> = option_env!("VSCODE_CLI_REMOTE_LICENSE_TEXT");
const LICENSE_PROMPT: Option<&'static str> = option_env!("VSCODE_CLI_REMOTE_LICENSE_PROMPT");

#[derive(Clone, Default, Serialize, Deserialize)]
struct PersistedConsent {
	pub consented: Option<bool>,
}

pub fn require_consent(
	paths: &LauncherPaths,
	accept_server_license_terms: bool,
) -> Result<(), AnyError> {
	match LICENSE_TEXT {
		Some(t) => println!("{}", t.replace("\\n", "\r\n")),
		None => return Ok(()),
	}

	let prompt = match LICENSE_PROMPT {
		Some(p) => p,
		None => return Ok(()),
	};

	let license: PersistedState<PersistedConsent> =
		PersistedState::new(paths.root().join("license_consent.json"));

	let mut load = license.load();
	if let Some(true) = load.consented {
		return Ok(());
	}

	if accept_server_license_terms {
		load.consented = Some(true);
	} else if !*IS_INTERACTIVE_CLI {
		return Err(MissingLegalConsent(
			"Run this command again with --accept-server-license-terms to indicate your agreement."
				.to_string(),
		)
		.into());
	} else {
		match prompt_yn(prompt) {
			Ok(true) => {
				load.consented = Some(true);
			}
			Ok(false) => {
				return Err(AnyError::from(MissingLegalConsent(format!(
					"Sorry, you cannot use the {} CLI without accepting the terms.",
					PRODUCT_NAME_LONG
				))))
			}
			Err(e) => return Err(AnyError::from(MissingLegalConsent(e.to_string()))),
		}
	}

	license.save(load)?;
	Ok(())
}
cli/src/tunnels/nosleep.rs (new file, 13 lines)
@@ -0,0 +1,13 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

#[cfg(target_os = "windows")]
pub type SleepInhibitor = super::nosleep_windows::SleepInhibitor;

#[cfg(target_os = "linux")]
pub type SleepInhibitor = super::nosleep_linux::SleepInhibitor;

#[cfg(target_os = "macos")]
pub type SleepInhibitor = super::nosleep_macos::SleepInhibitor;
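// Illustrative usage sketch (not part of the original source): the inhibitor
// is platform-selected above and is held for as long as sleep should be
// prevented; dropping it releases the underlying OS assertion or request.
//
//   let _inhibitor = SleepInhibitor::new().await?;
//   // ... keep it alive while the tunnel runs ...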
cli/src/tunnels/nosleep_linux.rs (new file, 79 lines)
@@ -0,0 +1,79 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use zbus::{dbus_proxy, Connection};

use crate::{
	constants::APPLICATION_NAME,
	util::errors::{wrap, AnyError},
};

/// A basically undocumented API, but it seems widely implemented, and is what
/// browsers use for sleep inhibition. The downside is that it also *may*
/// disable the screensaver. A much better and more granular API is available
/// on `org.freedesktop.login1.Manager`, but this requires administrative
/// permission to request inhibition, which is not possible here.
///
/// See https://source.chromium.org/chromium/chromium/src/+/main:services/device/wake_lock/power_save_blocker/power_save_blocker_linux.cc;l=54;drc=2e85357a8b76996981cc6f783853a49df2cedc3a
#[dbus_proxy(
	interface = "org.freedesktop.PowerManagement.Inhibit",
	gen_blocking = false,
	default_service = "org.freedesktop.PowerManagement.Inhibit",
	default_path = "/org/freedesktop/PowerManagement/Inhibit"
)]
trait PMInhibitor {
	#[dbus_proxy(name = "Inhibit")]
	fn inhibit(&self, what: &str, why: &str) -> zbus::Result<u32>;
}

/// A slightly better documented version which seems commonly used.
#[dbus_proxy(
	interface = "org.freedesktop.ScreenSaver",
	gen_blocking = false,
	default_service = "org.freedesktop.ScreenSaver",
	default_path = "/org/freedesktop/ScreenSaver"
)]
trait ScreenSaver {
	#[dbus_proxy(name = "Inhibit")]
	fn inhibit(&self, what: &str, why: &str) -> zbus::Result<u32>;
}

pub struct SleepInhibitor {
	_connection: Connection, // Inhibition is released when the connection is closed
}

impl SleepInhibitor {
	pub async fn new() -> Result<Self, AnyError> {
		let connection = Connection::session()
			.await
			.map_err(|e| wrap(e, "error creating dbus session"))?;

		macro_rules! try_inhibit {
			($proxy:ident) => {
				match $proxy::new(&connection).await {
					Ok(proxy) => proxy.inhibit(APPLICATION_NAME, "running tunnel").await,
					Err(e) => Err(e),
				}
			};
		}

		if let Err(e1) = try_inhibit!(PMInhibitorProxy) {
			if let Err(e2) = try_inhibit!(ScreenSaverProxy) {
				return Err(wrap(
					e2,
					format!(
						"error requesting sleep inhibition, pminhibitor gave {}, screensaver gave",
						e1
					),
				)
				.into());
			}
		}

		Ok(SleepInhibitor {
			_connection: connection,
		})
	}
}
cli/src/tunnels/nosleep_macos.rs (new file, 78 lines)
@@ -0,0 +1,78 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::io;

use core_foundation::base::TCFType;
use core_foundation::string::{CFString, CFStringRef};
use libc::c_int;

use crate::constants::TUNNEL_ACTIVITY_NAME;

extern "C" {
	pub fn IOPMAssertionCreateWithName(
		assertion_type: CFStringRef,
		assertion_level: u32,
		assertion_name: CFStringRef,
		assertion_id: &mut u32,
	) -> c_int;

	pub fn IOPMAssertionRelease(assertion_id: u32) -> c_int;
}

const NUM_ASSERTIONS: usize = 2;

const ASSERTIONS: [&str; NUM_ASSERTIONS] = ["PreventUserIdleSystemSleep", "PreventSystemSleep"];

struct Assertion(u32);

impl Assertion {
	pub fn make(typ: &CFString, name: &CFString) -> io::Result<Self> {
		let mut assertion_id = 0;
		let result = unsafe {
			IOPMAssertionCreateWithName(
				typ.as_concrete_TypeRef(),
				255,
				name.as_concrete_TypeRef(),
				&mut assertion_id,
			)
		};

		if result != 0 {
			Err(io::Error::last_os_error())
		} else {
			Ok(Self(assertion_id))
		}
	}
}

impl Drop for Assertion {
	fn drop(&mut self) {
		unsafe {
			IOPMAssertionRelease(self.0);
		}
	}
}

pub struct SleepInhibitor {
	_assertions: Vec<Assertion>,
}

impl SleepInhibitor {
	pub async fn new() -> io::Result<Self> {
		let mut assertions = Vec::with_capacity(NUM_ASSERTIONS);
		let assertion_name = CFString::from_static_string(TUNNEL_ACTIVITY_NAME);
		for typ in ASSERTIONS {
			assertions.push(Assertion::make(
				&CFString::from_static_string(typ),
				&assertion_name,
			)?);
		}

		Ok(Self {
			_assertions: assertions,
		})
	}
}
cli/src/tunnels/nosleep_windows.rs (new file, 79 lines)
@@ -0,0 +1,79 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::io;

use winapi::{
	ctypes::c_void,
	um::{
		handleapi::CloseHandle,
		minwinbase::REASON_CONTEXT,
		winbase::{PowerClearRequest, PowerCreateRequest, PowerSetRequest},
		winnt::{
			PowerRequestSystemRequired, POWER_REQUEST_CONTEXT_SIMPLE_STRING,
			POWER_REQUEST_CONTEXT_VERSION, POWER_REQUEST_TYPE,
		},
	},
};

use crate::constants::TUNNEL_ACTIVITY_NAME;

struct Request(*mut c_void);

impl Request {
	pub fn new() -> io::Result<Self> {
		let mut reason: Vec<u16> = TUNNEL_ACTIVITY_NAME.encode_utf16().collect();
		let mut context = REASON_CONTEXT {
			Version: POWER_REQUEST_CONTEXT_VERSION,
			Flags: POWER_REQUEST_CONTEXT_SIMPLE_STRING,
			..Default::default()
		};
		unsafe { *context.Reason.SimpleReasonString_mut() = reason.as_mut_ptr() };

		let request = unsafe { PowerCreateRequest(&mut context) };
		if request.is_null() {
			return Err(io::Error::last_os_error());
		}

		Ok(Self(request))
	}

	pub fn set(&self, request_type: POWER_REQUEST_TYPE) -> io::Result<()> {
		let result = unsafe { PowerSetRequest(self.0, request_type) };
		if result == 0 {
			return Err(io::Error::last_os_error());
		}

		Ok(())
	}
}

impl Drop for Request {
	fn drop(&mut self) {
		unsafe {
			CloseHandle(self.0);
		}
	}
}

pub struct SleepInhibitor {
	request: Request,
}

impl SleepInhibitor {
	pub async fn new() -> io::Result<Self> {
		let request = Request::new()?;
		request.set(PowerRequestSystemRequired)?;
		Ok(Self { request })
	}
}

impl Drop for SleepInhibitor {
	fn drop(&mut self) {
		unsafe {
			PowerClearRequest(self.request.0, PowerRequestSystemRequired);
		}
	}
}
cli/src/tunnels/paths.rs (new file, 154 lines)
@@ -0,0 +1,154 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
	fs::{read_dir, read_to_string, remove_dir_all, write},
	path::PathBuf,
};

use serde::{Deserialize, Serialize};

use crate::{
	options::{self, Quality},
	state::LauncherPaths,
	util::{
		errors::{wrap, AnyError, WrappedError},
		machine,
	},
};

pub const SERVER_FOLDER_NAME: &str = "server";

pub struct ServerPaths {
	// Directory into which the server is downloaded
	pub server_dir: PathBuf,
	// Executable path, within the server_dir
	pub executable: PathBuf,
	// File where logs for the server should be written.
	pub logfile: PathBuf,
	// File where the process ID for the server should be written.
	pub pidfile: PathBuf,
}

impl ServerPaths {
	// Queries the system to determine the process ID of the running server.
	// Returns the process ID, if the server is running.
	pub fn get_running_pid(&self) -> Option<u32> {
		if let Some(pid) = self.read_pid() {
			return match machine::process_at_path_exists(pid, &self.executable) {
				true => Some(pid),
				false => None,
			};
		}

		if let Some(pid) = machine::find_running_process(&self.executable) {
			// attempt to backfill process ID:
			self.write_pid(pid).ok();
			return Some(pid);
		}

		None
	}

	/// Delete the server directory
	pub fn delete(&self) -> Result<(), WrappedError> {
		remove_dir_all(&self.server_dir).map_err(|e| {
			wrap(
				e,
				format!("error deleting server dir {}", self.server_dir.display()),
			)
		})
	}

	// VS Code Server pid
	pub fn write_pid(&self, pid: u32) -> Result<(), WrappedError> {
		write(&self.pidfile, format!("{}", pid)).map_err(|e| {
			wrap(
				e,
				format!("error writing process id into {}", self.pidfile.display()),
			)
		})
	}

	fn read_pid(&self) -> Option<u32> {
		read_to_string(&self.pidfile)
			.ok()
			.and_then(|s| s.parse::<u32>().ok())
	}
}

#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct InstalledServer {
	pub quality: options::Quality,
	pub commit: String,
	pub headless: bool,
}

impl InstalledServer {
	/// Gets path information about where a specific server should be stored.
	pub fn server_paths(&self, p: &LauncherPaths) -> ServerPaths {
		let server_dir = self.get_install_folder(p);
		ServerPaths {
			executable: server_dir
				.join(SERVER_FOLDER_NAME)
				.join("bin")
				.join(self.quality.server_entrypoint()),
			logfile: server_dir.join("log.txt"),
			pidfile: server_dir.join("pid.txt"),
			server_dir,
		}
	}

	fn get_install_folder(&self, p: &LauncherPaths) -> PathBuf {
		p.server_cache.path().join(if !self.headless {
			format!("{}-web", get_server_folder_name(self.quality, &self.commit))
		} else {
			get_server_folder_name(self.quality, &self.commit)
		})
	}
}

/// Prunes servers not currently running, and returns the deleted servers.
pub fn prune_stopped_servers(launcher_paths: &LauncherPaths) -> Result<Vec<ServerPaths>, AnyError> {
	get_all_servers(launcher_paths)
		.into_iter()
		.map(|s| s.server_paths(launcher_paths))
		.filter(|s| s.get_running_pid().is_none())
		.map(|s| s.delete().map(|_| s))
		.collect::<Result<_, _>>()
		.map_err(AnyError::from)
}

// Gets a list of all servers which look like they might be running.
pub fn get_all_servers(lp: &LauncherPaths) -> Vec<InstalledServer> {
	let mut servers: Vec<InstalledServer> = vec![];
	if let Ok(children) = read_dir(lp.server_cache.path()) {
		for child in children.flatten() {
			let fname = child.file_name();
			let fname = fname.to_string_lossy();
			let (quality, commit) = match fname.split_once('-') {
				Some(r) => r,
				None => continue,
			};

			let quality = match options::Quality::try_from(quality) {
				Ok(q) => q,
				Err(_) => continue,
			};

			servers.push(InstalledServer {
				quality,
				commit: commit.to_string(),
				headless: true,
			});
		}
	}

	servers
}

pub fn get_server_folder_name(quality: Quality, commit: &str) -> String {
	format!("{}-{}", quality, commit)
}
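// Illustrative sketch (not part of the original source): the cache key built
// above is "<quality>-<commit>", e.g. "stable-abc123", with a "-web" suffix
// added by `get_install_folder` for non-headless installs.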
cli/src/tunnels/port_forwarder.rs (new file, 131 lines)
@@ -0,0 +1,131 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::collections::HashSet;

use tokio::sync::{mpsc, oneshot};

use crate::{
	constants::CONTROL_PORT,
	util::errors::{AnyError, CannotForwardControlPort, ServerHasClosed},
};

use super::dev_tunnels::ActiveTunnel;

pub enum PortForwardingRec {
	Forward(u16, oneshot::Sender<Result<String, AnyError>>),
	Unforward(u16, oneshot::Sender<Result<(), AnyError>>),
}

/// Provides a port forwarding service for connected clients. Clients can make
/// requests on it, which are (and *must be*) processed by calling the `.process()`
/// method on the forwarder.
pub struct PortForwardingProcessor {
	tx: mpsc::Sender<PortForwardingRec>,
	rx: mpsc::Receiver<PortForwardingRec>,
	forwarded: HashSet<u16>,
}

impl PortForwardingProcessor {
	pub fn new() -> Self {
		let (tx, rx) = mpsc::channel(8);
		Self {
			tx,
			rx,
			forwarded: HashSet::new(),
		}
	}

	/// Gets a handle that can be passed off to consumers of port forwarding.
	pub fn handle(&self) -> PortForwarding {
		PortForwarding {
			tx: self.tx.clone(),
		}
	}

	/// Receives port forwarding requests. Consumers MUST call `process()`
	/// with the received requests.
	pub async fn recv(&mut self) -> Option<PortForwardingRec> {
		self.rx.recv().await
	}

	/// Processes the incoming forwarding request.
	pub async fn process(&mut self, req: PortForwardingRec, tunnel: &mut ActiveTunnel) {
		match req {
			PortForwardingRec::Forward(port, tx) => {
				tx.send(self.process_forward(port, tunnel).await).ok();
			}
			PortForwardingRec::Unforward(port, tx) => {
				tx.send(self.process_unforward(port, tunnel).await).ok();
			}
		}
	}

	async fn process_unforward(
		&mut self,
		port: u16,
		tunnel: &mut ActiveTunnel,
	) -> Result<(), AnyError> {
		if port == CONTROL_PORT {
			return Err(CannotForwardControlPort().into());
		}

		tunnel.remove_port(port).await?;
		self.forwarded.remove(&port);
		Ok(())
	}

	async fn process_forward(
		&mut self,
		port: u16,
		tunnel: &mut ActiveTunnel,
	) -> Result<String, AnyError> {
		if port == CONTROL_PORT {
			return Err(CannotForwardControlPort().into());
		}

		if !self.forwarded.contains(&port) {
			tunnel.add_port_tcp(port).await?;
			self.forwarded.insert(port);
		}

		tunnel.get_port_uri(port).await
	}
}

#[derive(Clone)]
pub struct PortForwarding {
	tx: mpsc::Sender<PortForwardingRec>,
}

impl PortForwarding {
	pub async fn forward(&self, port: u16) -> Result<String, AnyError> {
		let (tx, rx) = oneshot::channel();
		let req = PortForwardingRec::Forward(port, tx);

		if self.tx.send(req).await.is_err() {
			return Err(ServerHasClosed().into());
		}

		match rx.await {
			Ok(r) => r,
			Err(_) => Err(ServerHasClosed().into()),
		}
	}

	pub async fn unforward(&self, port: u16) -> Result<(), AnyError> {
		let (tx, rx) = oneshot::channel();
		let req = PortForwardingRec::Unforward(port, tx);

		if self.tx.send(req).await.is_err() {
			return Err(ServerHasClosed().into());
		}

		match rx.await {
			Ok(r) => r,
			Err(_) => Err(ServerHasClosed().into()),
		}
	}
}
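// Illustrative usage sketch: one task owns the processor and drives queued
// requests against the tunnel, while cloned handles issue forward/unforward
// calls. The `tunnel` argument here is a hypothetical ActiveTunnel.
async fn demo_forwarding(tunnel: ActiveTunnel) -> Result<(), AnyError> {
	let processor = PortForwardingProcessor::new();
	let handle = processor.handle();

	// Requests queue on the channel and are only handled once `process()`
	// runs, per the doc comment on PortForwardingProcessor.
	tokio::spawn(async move {
		let (mut processor, mut tunnel) = (processor, tunnel);
		while let Some(req) = processor.recv().await {
			processor.process(req, &mut tunnel).await;
		}
	});

	// Any clone of the handle can request forwarding and receives the URI:
	let uri = handle.forward(8080).await?;
	println!("port 8080 is available at {}", uri);
	handle.unforward(8080).await
}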
249
cli/src/tunnels/protocol.rs
Normal file
@@ -0,0 +1,249 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use std::collections::HashMap;

use crate::{
	constants::{PROTOCOL_VERSION, VSCODE_CLI_VERSION},
	options::Quality,
	update_service::Platform,
};
use serde::{Deserialize, Serialize};

#[derive(Serialize, Debug)]
#[serde(tag = "method", content = "params", rename_all = "camelCase")]
#[allow(non_camel_case_types)]
pub enum ClientRequestMethod<'a> {
	servermsg(RefServerMessageParams<'a>),
	serverlog(ServerLog<'a>),
	makehttpreq(HttpRequestParams<'a>),
	version(VersionResponse),
}

#[derive(Deserialize, Debug)]
pub struct HttpBodyParams {
	#[serde(with = "serde_bytes")]
	pub segment: Vec<u8>,
	pub complete: bool,
	pub req_id: u32,
}

#[derive(Serialize, Debug)]
pub struct HttpRequestParams<'a> {
	pub url: &'a str,
	pub method: &'static str,
	pub req_id: u32,
}

#[derive(Deserialize, Debug)]
pub struct HttpHeadersParams {
	pub status_code: u16,
	pub headers: Vec<(String, String)>,
	pub req_id: u32,
}

#[derive(Deserialize, Debug)]
pub struct ForwardParams {
	pub port: u16,
}

#[derive(Deserialize, Debug)]
pub struct UnforwardParams {
	pub port: u16,
}

#[derive(Serialize)]
pub struct ForwardResult {
	pub uri: String,
}

#[derive(Deserialize, Debug)]
pub struct ServeParams {
	pub socket_id: u16,
	pub commit_id: Option<String>,
	pub quality: Quality,
	pub extensions: Vec<String>,
	/// Optional preferred connection token.
	#[serde(default)]
	pub connection_token: Option<String>,
	#[serde(default)]
	pub use_local_download: bool,
	/// If true, the client and server should gzip servermsg's sent in either direction.
	#[serde(default)]
	pub compress: bool,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct EmptyObject {}

#[derive(Serialize, Deserialize, Debug)]
pub struct UpdateParams {
	pub do_update: bool,
}

#[derive(Deserialize, Debug)]
pub struct ServerMessageParams {
	pub i: u16,
	#[serde(with = "serde_bytes")]
	pub body: Vec<u8>,
}

#[derive(Serialize, Debug)]
pub struct RefServerMessageParams<'a> {
	pub i: u16,
	#[serde(with = "serde_bytes")]
	pub body: &'a [u8],
}

#[derive(Serialize)]
pub struct UpdateResult {
	pub up_to_date: bool,
	pub did_update: bool,
}

#[derive(Serialize, Debug)]
pub struct ToClientRequest<'a> {
	pub id: Option<u32>,
	#[serde(flatten)]
	pub params: ClientRequestMethod<'a>,
}

#[derive(Debug, Default, Serialize)]
pub struct ServerLog<'a> {
	pub line: &'a str,
	pub level: u8,
}

#[derive(Serialize)]
pub struct GetHostnameResponse {
	pub value: String,
}

#[derive(Serialize)]
pub struct GetEnvResponse {
	pub env: HashMap<String, String>,
	pub os_platform: &'static str,
	pub os_release: String,
}

#[derive(Deserialize)]
pub struct FsStatRequest {
	pub path: String,
}

#[derive(Serialize, Default)]
pub struct FsStatResponse {
	pub exists: bool,
	pub size: Option<u64>,
	#[serde(rename = "type")]
	pub kind: Option<&'static str>,
}

#[derive(Deserialize, Debug)]
pub struct CallServerHttpParams {
	pub path: String,
	pub method: String,
	pub headers: HashMap<String, String>,
	pub body: Option<Vec<u8>>,
}

#[derive(Serialize)]
pub struct CallServerHttpResult {
	pub status: u16,
	#[serde(with = "serde_bytes")]
	pub body: Vec<u8>,
	pub headers: HashMap<String, String>,
}

#[derive(Serialize, Debug)]
pub struct VersionResponse {
	pub version: &'static str,
	pub protocol_version: u32,
}

impl Default for VersionResponse {
	fn default() -> Self {
		Self {
			version: VSCODE_CLI_VERSION.unwrap_or("dev"),
			protocol_version: PROTOCOL_VERSION,
		}
	}
}

#[derive(Deserialize)]
pub struct SpawnParams {
	pub command: String,
	pub args: Vec<String>,
	#[serde(default)]
	pub cwd: Option<String>,
	#[serde(default)]
	pub env: HashMap<String, String>,
}

#[derive(Deserialize)]
pub struct AcquireCliParams {
	pub platform: Platform,
	pub quality: Quality,
	pub commit_id: Option<String>,
	#[serde(flatten)]
	pub spawn: SpawnParams,
}

#[derive(Serialize)]
pub struct SpawnResult {
	pub message: String,
	pub exit_code: i32,
}

pub const METHOD_CHALLENGE_ISSUE: &str = "challenge_issue";
pub const METHOD_CHALLENGE_VERIFY: &str = "challenge_verify";

#[derive(Serialize, Deserialize)]
pub struct ChallengeIssueResponse {
	pub challenge: String,
}

#[derive(Deserialize, Serialize)]
pub struct ChallengeVerifyParams {
	pub response: String,
}

pub mod singleton {
	use crate::log;
	use serde::{Deserialize, Serialize};

	pub const METHOD_RESTART: &str = "restart";
	pub const METHOD_SHUTDOWN: &str = "shutdown";
	pub const METHOD_STATUS: &str = "status";
	pub const METHOD_LOG: &str = "log";
	pub const METHOD_LOG_REPLY_DONE: &str = "log_done";

	#[derive(Serialize)]
	pub struct LogMessage<'a> {
		pub level: Option<log::Level>,
		pub prefix: &'a str,
		pub message: &'a str,
	}

	#[derive(Deserialize)]
	pub struct LogMessageOwned {
		pub level: Option<log::Level>,
		pub prefix: String,
		pub message: String,
	}

	#[derive(Serialize, Deserialize)]
	pub struct Status {
		pub tunnel: TunnelState,
	}

	#[derive(Deserialize, Serialize, Debug)]
	pub struct LogReplayFinished {}

	#[derive(Deserialize, Serialize, Debug)]
	pub enum TunnelState {
		Disconnected,
		Connected { name: String },
	}
}
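// Illustrative sketch of the wire shape the serde attributes above produce.
// With `tag = "method", content = "params"` on ClientRequestMethod and
// `#[serde(flatten)]` on ToClientRequest, a version notification encodes
// (shown via serde_json purely for readability) roughly as:
//   {"id":null,"method":"version","params":{"version":"dev","protocol_version":...}}
fn demo_wire_shape() -> serde_json::Result<String> {
	let req = ToClientRequest {
		id: None,
		params: ClientRequestMethod::version(VersionResponse::default()),
	};
	serde_json::to_string(&req)
}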
62
cli/src/tunnels/server_bridge.rs
Normal file
@@ -0,0 +1,62 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use super::socket_signal::{ClientMessageDecoder, ServerMessageSink};
use crate::{
	async_pipe::{get_socket_rw_stream, socket_stream_split, AsyncPipeWriteHalf},
	util::errors::AnyError,
};
use std::path::Path;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

pub struct ServerBridge {
	write: AsyncPipeWriteHalf,
	decoder: ClientMessageDecoder,
}

const BUFFER_SIZE: usize = 65536;

impl ServerBridge {
	pub async fn new(
		path: &Path,
		mut target: ServerMessageSink,
		decoder: ClientMessageDecoder,
	) -> Result<Self, AnyError> {
		let stream = get_socket_rw_stream(path).await?;
		let (mut read, write) = socket_stream_split(stream);

		tokio::spawn(async move {
			let mut read_buf = vec![0; BUFFER_SIZE];
			loop {
				match read.read(&mut read_buf).await {
					Err(_) => return,
					Ok(0) => {
						return; // EOF
					}
					Ok(s) => {
						let send = target.server_message(&read_buf[..s]).await;
						if send.is_err() {
							return;
						}
					}
				}
			}
		});

		Ok(ServerBridge { write, decoder })
	}

	pub async fn write(&mut self, b: Vec<u8>) -> std::io::Result<()> {
		let dec = self.decoder.decode(&b)?;
		if !dec.is_empty() {
			self.write.write_all(dec).await?;
		}
		Ok(())
	}

	pub async fn close(mut self) -> std::io::Result<()> {
		self.write.shutdown().await?;
		Ok(())
	}
}
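// Illustrative sketch of wiring a bridge to a local server socket; the
// `socket_path` and `sink` values are hypothetical and would come from the
// connected client session in the real CLI flow.
async fn demo_bridge(
	socket_path: &std::path::Path,
	sink: ServerMessageSink,
) -> Result<(), AnyError> {
	// A spawned task pumps socket reads into `sink`; writes are decoded
	// (decompressed, if negotiated) before being forwarded to the server.
	let mut bridge =
		ServerBridge::new(socket_path, sink, ClientMessageDecoder::new_plain()).await?;
	bridge.write(b"client bytes".to_vec()).await.ok();
	bridge.close().await.ok();
	Ok(())
}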
145
cli/src/tunnels/server_multiplexer.rs
Normal file
@@ -0,0 +1,145 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::sync::Arc;

use futures::future::join_all;

use crate::log;

use super::server_bridge::ServerBridge;

type Inner = Arc<std::sync::Mutex<Option<Vec<ServerBridgeRec>>>>;

struct ServerBridgeRec {
	id: u16,
	// The bridge is taken (set to None) while a write loop is currently active.
	bridge: Option<ServerBridge>,
	write_queue: Vec<Vec<u8>>,
}

/// The ServerMultiplexer manages multiple server bridges and allows writing
/// to them in a thread-safe way. It is cheap to clone, `Send`, and `Sync`.
#[derive(Clone)]
pub struct ServerMultiplexer {
	inner: Inner,
}

impl ServerMultiplexer {
	pub fn new() -> Self {
		Self {
			inner: Arc::new(std::sync::Mutex::new(Some(Vec::new()))),
		}
	}

	/// Adds a new bridge to the multiplexer.
	pub fn register(&self, id: u16, bridge: ServerBridge) {
		let bridge_rec = ServerBridgeRec {
			id,
			bridge: Some(bridge),
			write_queue: vec![],
		};

		let mut lock = self.inner.lock().unwrap();
		match &mut *lock {
			Some(server_bridges) => (*server_bridges).push(bridge_rec),
			None => *lock = Some(vec![bridge_rec]),
		}
	}

	/// Removes a server bridge by ID.
	pub fn remove(&self, id: u16) {
		let mut lock = self.inner.lock().unwrap();
		if let Some(bridges) = &mut *lock {
			bridges.retain(|sb| sb.id != id);
		}
	}

	/// Handles an incoming server message. This is synchronous and uses a 'write loop'
	/// to ensure message order is preserved exactly, which is necessary for compression.
	/// Returns false if there was no server with the given bridge_id.
	pub fn write_message(&self, log: &log::Logger, bridge_id: u16, message: Vec<u8>) -> bool {
		let mut lock = self.inner.lock().unwrap();

		let bridges = match &mut *lock {
			Some(sb) => sb,
			None => return false,
		};

		let record = match bridges.iter_mut().find(|b| b.id == bridge_id) {
			Some(sb) => sb,
			None => return false,
		};

		record.write_queue.push(message);
		if let Some(bridge) = record.bridge.take() {
			let bridges_lock = self.inner.clone();
			let log = log.clone();
			tokio::spawn(write_loop(log, record.id, bridge, bridges_lock));
		}

		true
	}

	/// Disposes all running server bridges.
	pub async fn dispose(&self) {
		let bridges = {
			let mut lock = self.inner.lock().unwrap();
			lock.take()
		};

		let bridges = match bridges {
			Some(b) => b,
			None => return,
		};

		join_all(
			bridges
				.into_iter()
				.filter_map(|b| b.bridge)
				.map(|b| b.close()),
		)
		.await;
	}
}

/// Write loop started by `write_message`. It takes the ServerBridge and runs
/// until there are no more items in the 'write queue'. At that point, if the
/// record still exists in the bridges_lock (i.e. we haven't shut down), it'll
/// return the ServerBridge so that the next `write_message` call starts the
/// loop again. Otherwise, it'll close the bridge.
async fn write_loop(log: log::Logger, id: u16, mut bridge: ServerBridge, bridges_lock: Inner) {
	let mut items_vec = vec![];
	loop {
		{
			let mut lock = bridges_lock.lock().unwrap();
			let server_bridges = match &mut *lock {
				Some(sb) => sb,
				None => break,
			};

			let bridge_rec = match server_bridges.iter_mut().find(|b| id == b.id) {
				Some(b) => b,
				None => break,
			};

			if bridge_rec.write_queue.is_empty() {
				bridge_rec.bridge = Some(bridge);
				return;
			}

			std::mem::swap(&mut bridge_rec.write_queue, &mut items_vec);
		}

		for item in items_vec.drain(..) {
			if let Err(e) = bridge.write(item).await {
				warning!(log, "Error writing to server: {:?}", e);
				break;
			}
		}
	}

	// We got here via `break` above, meaning our record was cleared; close the bridge.
	bridge.close().await.ok();
}
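// Illustrative sketch of the ordering guarantee: `write_message` queues under
// the lock and a single spawned write loop drains the queue, so messages
// reach a bridge in call order even though the method itself is synchronous.
// Must run inside a tokio runtime; `logger` and `bridge` are hypothetical.
fn demo_multiplexer(logger: &log::Logger, bridge: ServerBridge) {
	let mux = ServerMultiplexer::new();
	mux.register(1, bridge);

	// Both messages land on bridge 1 in exactly this order.
	mux.write_message(logger, 1, b"first".to_vec());
	mux.write_message(logger, 1, b"second".to_vec());

	// An unknown bridge id is reported via the boolean result.
	assert!(!mux.write_message(logger, 2, b"dropped".to_vec()));
}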
92
cli/src/tunnels/service.rs
Normal file
@@ -0,0 +1,92 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::path::{Path, PathBuf};

use async_trait::async_trait;

use crate::log;
use crate::state::LauncherPaths;
use crate::util::errors::{wrap, AnyError};
use crate::util::io::{tailf, TailEvent};

pub const SERVICE_LOG_FILE_NAME: &str = "tunnel-service.log";

#[async_trait]
pub trait ServiceContainer: Send {
	async fn run_service(
		&mut self,
		log: log::Logger,
		launcher_paths: LauncherPaths,
	) -> Result<(), AnyError>;
}

#[async_trait]
pub trait ServiceManager {
	/// Registers the current executable as a service to run with the given set
	/// of arguments.
	async fn register(&self, exe: PathBuf, args: &[&str]) -> Result<(), AnyError>;

	/// Runs the service using the given handle. The executable *must not* take
	/// any action which may fail prior to calling this, to ensure that the
	/// service state can update.
	async fn run(
		self,
		launcher_paths: LauncherPaths,
		handle: impl 'static + ServiceContainer,
	) -> Result<(), AnyError>;

	/// Shows logs from the running service on standard out.
	async fn show_logs(&self) -> Result<(), AnyError>;

	/// Unregisters the current executable as a service.
	async fn unregister(&self) -> Result<(), AnyError>;
}

#[cfg(target_os = "windows")]
pub type ServiceManagerImpl = super::service_windows::WindowsService;

#[cfg(target_os = "linux")]
pub type ServiceManagerImpl = super::service_linux::SystemdService;

#[cfg(target_os = "macos")]
pub type ServiceManagerImpl = super::service_macos::LaunchdService;

#[allow(unreachable_code)]
#[allow(unused_variables)]
pub fn create_service_manager(log: log::Logger, paths: &LauncherPaths) -> ServiceManagerImpl {
	#[cfg(target_os = "macos")]
	{
		super::service_macos::LaunchdService::new(log, paths)
	}
	#[cfg(target_os = "windows")]
	{
		super::service_windows::WindowsService::new(log, paths)
	}
	#[cfg(target_os = "linux")]
	{
		super::service_linux::SystemdService::new(log, paths.clone())
	}
}

#[allow(dead_code)] // unused on Linux
pub(crate) async fn tail_log_file(log_file: &Path) -> Result<(), AnyError> {
	if !log_file.exists() {
		println!("The tunnel service has not started yet.");
		return Ok(());
	}

	let file = std::fs::File::open(log_file).map_err(|e| wrap(e, "error opening log file"))?;
	let mut rx = tailf(file, 20);
	while let Some(line) = rx.recv().await {
		match line {
			TailEvent::Line(l) => print!("{}", l),
			TailEvent::Reset => println!("== Tunnel service restarted =="),
			TailEvent::Err(e) => return Err(wrap(e, "error reading log file").into()),
		}
	}

	Ok(())
}
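// Illustrative sketch of registering the current executable through the
// platform-specific manager; the argument list here is hypothetical.
async fn demo_register(logger: log::Logger, paths: &LauncherPaths) -> Result<(), AnyError> {
	let manager = create_service_manager(logger, paths);
	let exe = std::env::current_exe().map_err(|e| wrap(e, "error getting current exe"))?;
	manager.register(exe, &["tunnel", "service", "run"]).await
}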
240
cli/src/tunnels/service_linux.rs
Normal file
@@ -0,0 +1,240 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
	fs::File,
	io::{self, Write},
	path::PathBuf,
	process::Command,
};

use async_trait::async_trait;
use zbus::{dbus_proxy, zvariant, Connection};

use crate::{
	constants::{APPLICATION_NAME, PRODUCT_NAME_LONG},
	log,
	state::LauncherPaths,
	util::errors::{wrap, AnyError},
};

use super::ServiceManager;

pub struct SystemdService {
	log: log::Logger,
	service_file: PathBuf,
}

impl SystemdService {
	pub fn new(log: log::Logger, paths: LauncherPaths) -> Self {
		Self {
			log,
			service_file: paths.root().join(SystemdService::service_name_string()),
		}
	}
}

impl SystemdService {
	async fn connect() -> Result<Connection, AnyError> {
		let connection = Connection::session()
			.await
			.map_err(|e| wrap(e, "error creating dbus session"))?;
		Ok(connection)
	}

	async fn proxy(connection: &Connection) -> Result<SystemdManagerDbusProxy<'_>, AnyError> {
		let proxy = SystemdManagerDbusProxy::new(connection)
			.await
			.map_err(|e| {
				wrap(
					e,
					"error connecting to systemd, you may need to re-run with sudo:",
				)
			})?;

		Ok(proxy)
	}

	fn service_path_string(&self) -> String {
		self.service_file.as_os_str().to_string_lossy().to_string()
	}

	fn service_name_string() -> String {
		format!("{}-tunnel.service", APPLICATION_NAME)
	}
}

#[async_trait]
impl ServiceManager for SystemdService {
	async fn register(
		&self,
		exe: std::path::PathBuf,
		args: &[&str],
	) -> Result<(), crate::util::errors::AnyError> {
		let connection = SystemdService::connect().await?;
		let proxy = SystemdService::proxy(&connection).await?;

		write_systemd_service_file(&self.service_file, exe, args)
			.map_err(|e| wrap(e, "error creating service file"))?;

		proxy
			.link_unit_files(
				vec![self.service_path_string()],
				/* 'runtime only' = */ false,
				/* replace existing = */ true,
			)
			.await
			.map_err(|e| wrap(e, "error registering service"))?;

		info!(self.log, "Successfully registered service...");

		// note: enablement is implicit in recent systemd versions, but required for older systems
		// https://github.com/microsoft/vscode/issues/167489#issuecomment-1331222826
		proxy
			.enable_unit_files(
				vec![SystemdService::service_name_string()],
				/* 'runtime only' = */ false,
				/* replace existing = */ true,
			)
			.await
			.map_err(|e| wrap(e, "error enabling unit files for service"))?;

		info!(self.log, "Successfully enabled unit files...");

		proxy
			.start_unit(SystemdService::service_name_string(), "replace".to_string())
			.await
			.map_err(|e| wrap(e, "error starting service"))?;

		info!(self.log, "Tunnel service successfully started");

		Ok(())
	}

	async fn run(
		self,
		launcher_paths: crate::state::LauncherPaths,
		mut handle: impl 'static + super::ServiceContainer,
	) -> Result<(), crate::util::errors::AnyError> {
		handle.run_service(self.log, launcher_paths).await
	}

	async fn show_logs(&self) -> Result<(), AnyError> {
		// show the systemctl status header...
		Command::new("systemctl")
			.args([
				"--user",
				"status",
				"-n",
				"0",
				&SystemdService::service_name_string(),
			])
			.status()
			.map(|s| s.code().unwrap_or(1))
			.map_err(|e| wrap(e, "error running systemctl"))?;

		// then follow log files
		Command::new("journalctl")
			.args(["--user", "-f", "-u", &SystemdService::service_name_string()])
			.status()
			.map(|s| s.code().unwrap_or(1))
			.map_err(|e| wrap(e, "error running journalctl"))?;
		Ok(())
	}

	async fn unregister(&self) -> Result<(), crate::util::errors::AnyError> {
		let connection = SystemdService::connect().await?;
		let proxy = SystemdService::proxy(&connection).await?;

		proxy
			.stop_unit(SystemdService::service_name_string(), "replace".to_string())
			.await
			.map_err(|e| wrap(e, "error unregistering service"))?;

		info!(self.log, "Successfully stopped service...");

		proxy
			.disable_unit_files(
				vec![SystemdService::service_name_string()],
				/* 'runtime only' = */ false,
			)
			.await
			.map_err(|e| wrap(e, "error unregistering service"))?;

		info!(self.log, "Tunnel service uninstalled");

		Ok(())
	}
}

fn write_systemd_service_file(
	path: &PathBuf,
	exe: std::path::PathBuf,
	args: &[&str],
) -> io::Result<()> {
	let mut f = File::create(path)?;
	write!(
		&mut f,
		"[Unit]\n\
		Description={} Tunnel\n\
		After=network.target\n\
		StartLimitIntervalSec=0\n\
		\n\
		[Service]\n\
		Type=simple\n\
		Restart=always\n\
		RestartSec=10\n\
		ExecStart={} \"{}\"\n\
		\n\
		[Install]\n\
		WantedBy=default.target\n\
		",
		PRODUCT_NAME_LONG,
		exe.into_os_string().to_string_lossy(),
		args.join("\" \"")
	)?;
	Ok(())
}

/// Minimal implementation of systemd types for the services we need. The full
/// definition can be found on any systemd machine with the command:
///
///     gdbus introspect --system --dest org.freedesktop.systemd1 --object-path /org/freedesktop/systemd1
///
/// See docs here: https://www.freedesktop.org/software/systemd/man/org.freedesktop.systemd1.html
#[dbus_proxy(
	interface = "org.freedesktop.systemd1.Manager",
	gen_blocking = false,
	default_service = "org.freedesktop.systemd1",
	default_path = "/org/freedesktop/systemd1"
)]
trait SystemdManagerDbus {
	#[dbus_proxy(name = "EnableUnitFiles")]
	fn enable_unit_files(
		&self,
		files: Vec<String>,
		runtime: bool,
		force: bool,
	) -> zbus::Result<(bool, Vec<(String, String, String)>)>;

	fn link_unit_files(
		&self,
		files: Vec<String>,
		runtime: bool,
		force: bool,
	) -> zbus::Result<Vec<(String, String, String)>>;

	fn disable_unit_files(
		&self,
		files: Vec<String>,
		runtime: bool,
	) -> zbus::Result<Vec<(String, String, String)>>;

	#[dbus_proxy(name = "StartUnit")]
	fn start_unit(&self, name: String, mode: String) -> zbus::Result<zvariant::OwnedObjectPath>;

	#[dbus_proxy(name = "StopUnit")]
	fn stop_unit(&self, name: String, mode: String) -> zbus::Result<zvariant::OwnedObjectPath>;
}
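// Illustrative sketch of the generated user unit. For a hypothetical binary
// at /usr/bin/code-tunnel with args ["tunnel", "service", "run"], and
// PRODUCT_NAME_LONG rendered as "Code", write_systemd_service_file emits:
//
//   [Unit]
//   Description=Code Tunnel
//   After=network.target
//   StartLimitIntervalSec=0
//
//   [Service]
//   Type=simple
//   Restart=always
//   RestartSec=10
//   ExecStart=/usr/bin/code-tunnel "tunnel" "service" "run"
//
//   [Install]
//   WantedBy=default.target
fn demo_unit_file() -> std::io::Result<()> {
	write_systemd_service_file(
		&std::path::PathBuf::from("/tmp/demo-tunnel.service"),
		std::path::PathBuf::from("/usr/bin/code-tunnel"),
		&["tunnel", "service", "run"],
	)
}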
163
cli/src/tunnels/service_macos.rs
Normal file
@@ -0,0 +1,163 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
	fs::{remove_file, File},
	io::{self, Write},
	path::{Path, PathBuf},
};

use async_trait::async_trait;

use crate::{
	constants::APPLICATION_NAME,
	log,
	state::LauncherPaths,
	util::{
		command::capture_command_and_check_status,
		errors::{wrap, AnyError, CodeError, MissingHomeDirectory},
	},
};

use super::{service::tail_log_file, ServiceManager};

pub struct LaunchdService {
	log: log::Logger,
	log_file: PathBuf,
}

impl LaunchdService {
	pub fn new(log: log::Logger, paths: &LauncherPaths) -> Self {
		Self {
			log,
			log_file: paths.service_log_file(),
		}
	}
}

#[async_trait]
impl ServiceManager for LaunchdService {
	async fn register(
		&self,
		exe: std::path::PathBuf,
		args: &[&str],
	) -> Result<(), crate::util::errors::AnyError> {
		let service_file = get_service_file_path()?;
		write_service_file(&service_file, &self.log_file, exe, args)
			.map_err(|e| wrap(e, "error creating service file"))?;

		info!(self.log, "Successfully registered service...");

		capture_command_and_check_status(
			"launchctl",
			&["load", service_file.as_os_str().to_string_lossy().as_ref()],
		)
		.await?;

		capture_command_and_check_status("launchctl", &["start", &get_service_label()]).await?;

		info!(self.log, "Tunnel service successfully started");

		Ok(())
	}

	async fn show_logs(&self) -> Result<(), AnyError> {
		tail_log_file(&self.log_file).await
	}

	async fn run(
		self,
		launcher_paths: crate::state::LauncherPaths,
		mut handle: impl 'static + super::ServiceContainer,
	) -> Result<(), crate::util::errors::AnyError> {
		handle.run_service(self.log, launcher_paths).await
	}

	async fn unregister(&self) -> Result<(), crate::util::errors::AnyError> {
		let service_file = get_service_file_path()?;

		match capture_command_and_check_status("launchctl", &["stop", &get_service_label()]).await {
			Ok(_) => {}
			// status 3 == "no such process"
			Err(CodeError::CommandFailed { code, .. }) if code == 3 => {}
			Err(e) => return Err(wrap(e, "error stopping service").into()),
		};

		info!(self.log, "Successfully stopped service...");

		capture_command_and_check_status(
			"launchctl",
			&[
				"unload",
				service_file.as_os_str().to_string_lossy().as_ref(),
			],
		)
		.await?;

		info!(self.log, "Tunnel service uninstalled");

		if let Ok(f) = get_service_file_path() {
			remove_file(f).ok();
		}

		Ok(())
	}
}

fn get_service_label() -> String {
	format!("com.visualstudio.{}.tunnel", APPLICATION_NAME)
}

fn get_service_file_path() -> Result<PathBuf, MissingHomeDirectory> {
	match dirs::home_dir() {
		Some(mut d) => {
			d.push(format!("{}.plist", get_service_label()));
			Ok(d)
		}
		None => Err(MissingHomeDirectory()),
	}
}

fn write_service_file(
	path: &PathBuf,
	log_file: &Path,
	exe: std::path::PathBuf,
	args: &[&str],
) -> io::Result<()> {
	let mut f = File::create(path)?;
	let log_file = log_file.as_os_str().to_string_lossy();
	// todo: we may be able to skip file logging and use the ASL instead
	// if/when we no longer need to support older macOS versions.
	write!(
		&mut f,
		"<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
		<!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/PropertyList-1.0.dtd\">\n\
		<plist version=\"1.0\">\n\
		<dict>\n\
		<key>Label</key>\n\
		<string>{}</string>\n\
		<key>LimitLoadToSessionType</key>\n\
		<string>Aqua</string>\n\
		<key>ProgramArguments</key>\n\
		<array>\n\
		<string>{}</string>\n\
		<string>{}</string>\n\
		</array>\n\
		<key>KeepAlive</key>\n\
		<true/>\n\
		<key>StandardErrorPath</key>\n\
		<string>{}</string>\n\
		<key>StandardOutPath</key>\n\
		<string>{}</string>\n\
		</dict>\n\
		</plist>",
		get_service_label(),
		exe.into_os_string().to_string_lossy(),
		args.join("</string><string>"),
		log_file,
		log_file
	)?;
	Ok(())
}
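// Illustrative sketch of where the launchd definition lands: the plist is
// written directly under the user's home directory with a label derived from
// APPLICATION_NAME, e.g. ~/com.visualstudio.code.tunnel.plist when the
// application name is "code"; register() then `launchctl load`s and starts it.
fn demo_plist_location() -> Result<(), MissingHomeDirectory> {
	let path = get_service_file_path()?;
	println!("launchd agent will be written to {}", path.display());
	Ok(())
}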
118
cli/src/tunnels/service_windows.rs
Normal file
@@ -0,0 +1,118 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use async_trait::async_trait;
use shell_escape::windows::escape as shell_escape;
use std::{
	path::PathBuf,
	process::{Command, Stdio},
};
use winreg::{enums::HKEY_CURRENT_USER, RegKey};

use crate::{
	constants::TUNNEL_ACTIVITY_NAME,
	log,
	state::LauncherPaths,
	tunnels::{protocol, singleton_client::do_single_rpc_call},
	util::errors::{wrap, wrapdbg, AnyError},
};

use super::service::{tail_log_file, ServiceContainer, ServiceManager as CliServiceManager};

pub struct WindowsService {
	log: log::Logger,
	tunnel_lock: PathBuf,
	log_file: PathBuf,
}

impl WindowsService {
	pub fn new(log: log::Logger, paths: &LauncherPaths) -> Self {
		Self {
			log,
			tunnel_lock: paths.tunnel_lockfile(),
			log_file: paths.service_log_file(),
		}
	}

	fn open_key() -> Result<RegKey, AnyError> {
		RegKey::predef(HKEY_CURRENT_USER)
			.create_subkey(r"Software\Microsoft\Windows\CurrentVersion\Run")
			.map_err(|e| wrap(e, "error opening run registry key").into())
			.map(|(key, _)| key)
	}
}

#[async_trait]
impl CliServiceManager for WindowsService {
	async fn register(&self, exe: std::path::PathBuf, args: &[&str]) -> Result<(), AnyError> {
		let key = WindowsService::open_key()?;

		let mut reg_str = String::new();
		let mut cmd = Command::new(&exe);
		reg_str.push_str(shell_escape(exe.to_string_lossy()).as_ref());

		let mut add_arg = |arg: &str| {
			reg_str.push(' ');
			reg_str.push_str(shell_escape((*arg).into()).as_ref());
			cmd.arg(arg);
		};

		for arg in args {
			add_arg(arg);
		}

		add_arg("--log-to-file");
		add_arg(self.log_file.to_string_lossy().as_ref());

		key.set_value(TUNNEL_ACTIVITY_NAME, &reg_str)
			.map_err(|e| AnyError::from(wrapdbg(e, "error setting registry key")))?;

		info!(self.log, "Successfully registered service...");

		cmd.stderr(Stdio::null());
		cmd.stdout(Stdio::null());
		cmd.stdin(Stdio::null());
		cmd.spawn()
			.map_err(|e| wrapdbg(e, "error starting service"))?;

		info!(self.log, "Tunnel service successfully started");
		Ok(())
	}

	async fn show_logs(&self) -> Result<(), AnyError> {
		tail_log_file(&self.log_file).await
	}

	async fn run(
		self,
		launcher_paths: LauncherPaths,
		mut handle: impl 'static + ServiceContainer,
	) -> Result<(), AnyError> {
		handle.run_service(self.log, launcher_paths).await
	}

	async fn unregister(&self) -> Result<(), AnyError> {
		let key = WindowsService::open_key()?;
		key.delete_value(TUNNEL_ACTIVITY_NAME)
			.map_err(|e| AnyError::from(wrap(e, "error deleting registry key")))?;
		info!(self.log, "Tunnel service uninstalled");

		let r = do_single_rpc_call::<_, ()>(
			&self.tunnel_lock,
			self.log.clone(),
			protocol::singleton::METHOD_SHUTDOWN,
			protocol::EmptyObject {},
		)
		.await;

		if r.is_err() {
			warning!(self.log, "The tunnel service has been unregistered, but we couldn't find a running tunnel process. You may need to restart or log out and back in to fully stop the tunnel.");
		} else {
			info!(self.log, "Successfully shut down running tunnel.");
		}

		Ok(())
	}
}
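// Illustrative sketch of what register() persists: a single value named
// TUNNEL_ACTIVITY_NAME under HKCU\Software\Microsoft\Windows\CurrentVersion\Run
// holding the shell-escaped command line, so the tunnel is relaunched at login.
// The path and args below are hypothetical.
async fn demo_windows_register(svc: WindowsService) -> Result<(), AnyError> {
	let exe = std::path::PathBuf::from(r"C:\Tools\code-tunnel.exe");
	// register() also spawns the detached process immediately, with stdio nulled out.
	svc.register(exe, &["tunnel", "service", "run"]).await
}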
90
cli/src/tunnels/shutdown_signal.rs
Normal file
@@ -0,0 +1,90 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use futures::{stream::FuturesUnordered, StreamExt};
use std::{fmt, path::PathBuf};
use sysinfo::Pid;

use crate::util::{
	machine::{wait_until_exe_deleted, wait_until_process_exits},
	sync::{new_barrier, Barrier, Receivable},
};

/// Describes the signal to manually stop the server.
#[derive(Copy, Clone)]
pub enum ShutdownSignal {
	CtrlC,
	ParentProcessKilled(Pid),
	ExeUninstalled,
	ServiceStopped,
	RpcShutdownRequested,
	RpcRestartRequested,
}

impl fmt::Display for ShutdownSignal {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		match self {
			ShutdownSignal::CtrlC => write!(f, "Ctrl-C received"),
			ShutdownSignal::ParentProcessKilled(p) => {
				write!(f, "Parent process {} no longer exists", p)
			}
			ShutdownSignal::ExeUninstalled => {
				write!(f, "Executable no longer exists")
			}
			ShutdownSignal::ServiceStopped => write!(f, "Service stopped"),
			ShutdownSignal::RpcShutdownRequested => write!(f, "RPC client requested shutdown"),
			ShutdownSignal::RpcRestartRequested => {
				write!(f, "RPC client requested a tunnel restart")
			}
		}
	}
}

pub enum ShutdownRequest {
	CtrlC,
	ParentProcessKilled(Pid),
	ExeUninstalled(PathBuf),
	Derived(Box<dyn Receivable<ShutdownSignal> + Send>),
}

impl ShutdownRequest {
	async fn wait(self) -> Option<ShutdownSignal> {
		match self {
			ShutdownRequest::CtrlC => {
				let ctrl_c = tokio::signal::ctrl_c();
				ctrl_c.await.ok();
				Some(ShutdownSignal::CtrlC)
			}
			ShutdownRequest::ParentProcessKilled(pid) => {
				wait_until_process_exits(pid, 2000).await;
				Some(ShutdownSignal::ParentProcessKilled(pid))
			}
			ShutdownRequest::ExeUninstalled(exe_path) => {
				wait_until_exe_deleted(&exe_path, 2000).await;
				Some(ShutdownSignal::ExeUninstalled)
			}
			ShutdownRequest::Derived(mut rx) => rx.recv_msg().await,
		}
	}

	/// Creates a receiver barrier that is opened once any of the given
	/// requests fire. Note: does not handle ServiceStopped.
	pub fn create_rx(
		signals: impl IntoIterator<Item = ShutdownRequest>,
	) -> Barrier<ShutdownSignal> {
		let (barrier, opener) = new_barrier();
		let futures = signals
			.into_iter()
			.map(|s| s.wait())
			.collect::<FuturesUnordered<_>>();

		tokio::spawn(async move {
			if let Some(s) = futures.filter_map(futures::future::ready).next().await {
				opener.open(s);
			}
		});

		barrier
	}
}
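// Illustrative sketch: combine several shutdown conditions into one barrier
// and wait for whichever fires first. `parent_pid` is hypothetical.
async fn demo_shutdown(parent_pid: Pid) {
	let mut rx = ShutdownRequest::create_rx([
		ShutdownRequest::CtrlC,
		ShutdownRequest::ParentProcessKilled(parent_pid),
	]);

	// The barrier opens with the first ShutdownSignal produced by any request.
	let _ = rx.wait().await;
}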
190
cli/src/tunnels/singleton_client.rs
Normal file
@@ -0,0 +1,190 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
	path::Path,
	sync::{
		atomic::{AtomicBool, Ordering},
		Arc,
	},
	thread,
};

use const_format::concatcp;
use tokio::sync::mpsc;

use crate::{
	async_pipe::{socket_stream_split, AsyncPipe},
	constants::IS_INTERACTIVE_CLI,
	json_rpc::{new_json_rpc, start_json_rpc, JsonRpcSerializer},
	log,
	rpc::RpcCaller,
	singleton::connect_as_client,
	tunnels::{code_server::print_listening, protocol::EmptyObject},
	util::{errors::CodeError, sync::Barrier},
};

use super::{
	protocol,
	shutdown_signal::{ShutdownRequest, ShutdownSignal},
};

pub struct SingletonClientArgs {
	pub log: log::Logger,
	pub stream: AsyncPipe,
	pub shutdown: Barrier<ShutdownSignal>,
}

struct SingletonServerContext {
	log: log::Logger,
	exit_entirely: Arc<AtomicBool>,
	caller: RpcCaller<JsonRpcSerializer>,
}

const CONTROL_INSTRUCTIONS_COMMON: &str =
	"Connected to an existing tunnel process running on this machine.";

const CONTROL_INSTRUCTIONS_INTERACTIVE: &str = concatcp!(
	CONTROL_INSTRUCTIONS_COMMON,
	" You can press:

- \"x\" + Enter to stop the tunnel and exit
- \"r\" + Enter to restart the tunnel
- Ctrl+C to detach
"
);

/// Serves a client singleton. Returns true if the process should exit after
/// this returns, instead of trying to start a tunnel.
pub async fn start_singleton_client(args: SingletonClientArgs) -> bool {
	let mut rpc = new_json_rpc();
	let (msg_tx, msg_rx) = mpsc::unbounded_channel();
	let exit_entirely = Arc::new(AtomicBool::new(false));

	debug!(
		args.log,
		"An existing tunnel is running on this machine, connecting to it..."
	);

	if *IS_INTERACTIVE_CLI {
		let stdin_handle = rpc.get_caller(msg_tx.clone());
		thread::spawn(move || {
			let mut input = String::new();
			loop {
				input.truncate(0);
				match std::io::stdin().read_line(&mut input) {
					Err(_) | Ok(0) => return, // EOF or not a tty
					_ => {}
				};

				match input.chars().next().map(|c| c.to_ascii_lowercase()) {
					Some('x') => {
						stdin_handle.notify(protocol::singleton::METHOD_SHUTDOWN, EmptyObject {});
						return;
					}
					Some('r') => {
						stdin_handle.notify(protocol::singleton::METHOD_RESTART, EmptyObject {});
					}
					Some(_) | None => {}
				}
			}
		});
	}

	let caller = rpc.get_caller(msg_tx);
	let mut rpc = rpc.methods(SingletonServerContext {
		log: args.log.clone(),
		exit_entirely: exit_entirely.clone(),
		caller,
	});

	rpc.register_sync(protocol::singleton::METHOD_SHUTDOWN, |_: EmptyObject, c| {
		c.exit_entirely.store(true, Ordering::SeqCst);
		Ok(())
	});

	rpc.register_async(
		protocol::singleton::METHOD_LOG_REPLY_DONE,
		|_: EmptyObject, c| async move {
			c.log.result(if *IS_INTERACTIVE_CLI {
				CONTROL_INSTRUCTIONS_INTERACTIVE
			} else {
				CONTROL_INSTRUCTIONS_COMMON
			});

			let res = c.caller.call::<_, _, protocol::singleton::Status>(
				protocol::singleton::METHOD_STATUS,
				protocol::EmptyObject {},
			);

			// we want to ensure the "listening" string always gets printed for
			// consumers (i.e. VS Code). Ask for it. If the tunnel is not currently
			// connected though, it will be soon, and that'll be in the log replays.
			if let Ok(Ok(s)) = res.await {
				if let protocol::singleton::TunnelState::Connected { name } = s.tunnel {
					print_listening(&c.log, &name);
				}
			}

			Ok(())
		},
	);

	rpc.register_sync(
		protocol::singleton::METHOD_LOG,
		|log: protocol::singleton::LogMessageOwned, c| {
			match log.level {
				Some(level) => c.log.emit(level, &format!("{}{}", log.prefix, log.message)),
				None => c.log.result(format!("{}{}", log.prefix, log.message)),
			}
			Ok(())
		},
	);

	let (read, write) = socket_stream_split(args.stream);
	let _ = start_json_rpc(rpc.build(args.log), read, write, msg_rx, args.shutdown).await;

	exit_entirely.load(Ordering::SeqCst)
}

pub async fn do_single_rpc_call<
	P: serde::Serialize + 'static,
	R: serde::de::DeserializeOwned + Send + 'static,
>(
	lock_file: &Path,
	log: log::Logger,
	method: &'static str,
	params: P,
) -> Result<R, CodeError> {
	let client = match connect_as_client(lock_file).await {
		Ok(p) => p,
		Err(CodeError::SingletonLockfileOpenFailed(_))
		| Err(CodeError::SingletonLockedProcessExited(_)) => {
			return Err(CodeError::NoRunningTunnel);
		}
		Err(e) => return Err(e),
	};

	let (msg_tx, msg_rx) = mpsc::unbounded_channel();
	let mut rpc = new_json_rpc();
	let caller = rpc.get_caller(msg_tx);
	let (read, write) = socket_stream_split(client);

	let rpc = tokio::spawn(async move {
		start_json_rpc(
			rpc.methods(()).build(log),
			read,
			write,
			msg_rx,
			ShutdownRequest::create_rx([ShutdownRequest::CtrlC]),
		)
		.await
		.unwrap();
	});

	let r = caller.call(method, params).await.unwrap();
	rpc.abort();
	r.map_err(CodeError::TunnelRpcCallFailed)
}
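// Illustrative sketch of a one-shot RPC against the running singleton, in the
// same way the Windows service requests a shutdown; `lock_file` and `logger`
// are hypothetical.
async fn demo_status(
	lock_file: &std::path::Path,
	logger: log::Logger,
) -> Result<protocol::singleton::Status, CodeError> {
	do_single_rpc_call::<_, protocol::singleton::Status>(
		lock_file,
		logger,
		protocol::singleton::METHOD_STATUS,
		protocol::EmptyObject {},
	)
	.await
}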
263
cli/src/tunnels/singleton_server.rs
Normal file
@@ -0,0 +1,263 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
	pin::Pin,
	sync::{Arc, Mutex},
};

use super::{
	code_server::CodeServerArgs,
	control_server::ServerTermination,
	dev_tunnels::ActiveTunnel,
	protocol,
	shutdown_signal::{ShutdownRequest, ShutdownSignal},
};
use crate::{
	async_pipe::socket_stream_split,
	json_rpc::{new_json_rpc, start_json_rpc, JsonRpcSerializer},
	log,
	rpc::{RpcCaller, RpcDispatcher},
	singleton::SingletonServer,
	state::LauncherPaths,
	tunnels::code_server::print_listening,
	update_service::Platform,
	util::{
		errors::{AnyError, CodeError},
		ring_buffer::RingBuffer,
		sync::{Barrier, ConcatReceivable},
	},
};
use futures::future::Either;
use tokio::{
	pin,
	sync::{broadcast, mpsc},
	task::JoinHandle,
};

pub struct SingletonServerArgs<'a> {
	pub server: &'a mut RpcServer,
	pub log: log::Logger,
	pub tunnel: ActiveTunnel,
	pub paths: &'a LauncherPaths,
	pub code_server_args: &'a CodeServerArgs,
	pub platform: Platform,
	pub shutdown: Barrier<ShutdownSignal>,
	pub log_broadcast: &'a BroadcastLogSink,
}

#[derive(Clone)]
struct SingletonServerContext {
	log: log::Logger,
	shutdown_tx: broadcast::Sender<ShutdownSignal>,
	broadcast_tx: broadcast::Sender<Vec<u8>>,
	current_name: Arc<Mutex<Option<String>>>,
}

pub struct RpcServer {
	fut: JoinHandle<Result<(), CodeError>>,
	shutdown_broadcast: broadcast::Sender<ShutdownSignal>,
	current_name: Arc<Mutex<Option<String>>>,
}

pub fn make_singleton_server(
	log_broadcast: BroadcastLogSink,
	log: log::Logger,
	server: SingletonServer,
	shutdown_rx: Barrier<ShutdownSignal>,
) -> RpcServer {
	let (shutdown_broadcast, _) = broadcast::channel(4);
	let rpc = new_json_rpc();

	let current_name = Arc::new(Mutex::new(None));
	let mut rpc = rpc.methods(SingletonServerContext {
		log: log.clone(),
		shutdown_tx: shutdown_broadcast.clone(),
		broadcast_tx: log_broadcast.get_brocaster(),
		current_name: current_name.clone(),
	});

	rpc.register_sync(
		protocol::singleton::METHOD_RESTART,
		|_: protocol::EmptyObject, ctx| {
			info!(ctx.log, "restarting tunnel after client request");
			let _ = ctx.shutdown_tx.send(ShutdownSignal::RpcRestartRequested);
			Ok(())
		},
	);

	rpc.register_sync(
		protocol::singleton::METHOD_STATUS,
		|_: protocol::EmptyObject, c| {
			Ok(protocol::singleton::Status {
				tunnel: match c.current_name.lock().unwrap().clone() {
					Some(name) => protocol::singleton::TunnelState::Connected { name },
					None => protocol::singleton::TunnelState::Disconnected,
				},
			})
		},
	);

	rpc.register_sync(
		protocol::singleton::METHOD_SHUTDOWN,
		|_: protocol::EmptyObject, ctx| {
			info!(
				ctx.log,
				"closing tunnel and all clients after a shutdown request"
			);
			let _ = ctx.broadcast_tx.send(RpcCaller::serialize_notify(
				&JsonRpcSerializer {},
				protocol::singleton::METHOD_SHUTDOWN,
				protocol::EmptyObject {},
			));
			let _ = ctx.shutdown_tx.send(ShutdownSignal::RpcShutdownRequested);
			Ok(())
		},
	);

	// we tokio spawn instead of keeping a future, since we want it to progress
	// even outside of the start_singleton_server loop (i.e. while the tunnel restarts)
	let fut = tokio::spawn(async move {
		serve_singleton_rpc(log_broadcast, server, rpc.build(log), shutdown_rx).await
	});
	RpcServer {
		shutdown_broadcast,
		current_name,
		fut,
	}
}

pub async fn start_singleton_server<'a>(
	args: SingletonServerArgs<'_>,
) -> Result<ServerTermination, AnyError> {
	let shutdown_rx = ShutdownRequest::create_rx([
		ShutdownRequest::Derived(Box::new(args.server.shutdown_broadcast.subscribe())),
		ShutdownRequest::Derived(Box::new(args.shutdown.clone())),
	]);

	{
		print_listening(&args.log, &args.tunnel.name);
		let mut name = args.server.current_name.lock().unwrap();
		*name = Some(args.tunnel.name.clone())
	}

	let serve_fut = super::serve(
		&args.log,
		args.tunnel,
		args.paths,
		args.code_server_args,
		args.platform,
		shutdown_rx,
	);

	pin!(serve_fut);

	match futures::future::select(Pin::new(&mut args.server.fut), &mut serve_fut).await {
		Either::Left((rpc_result, fut)) => {
			// the rpc server will only end as a result of a graceful shutdown, or
			// with an error. Return the result of the eventual shutdown of the
			// control server.
			rpc_result.unwrap()?;
			fut.await
		}
		Either::Right((ctrl_result, _)) => ctrl_result,
	}
}

async fn serve_singleton_rpc<C: Clone + Send + Sync + 'static>(
	log_broadcast: BroadcastLogSink,
	mut server: SingletonServer,
	dispatcher: RpcDispatcher<JsonRpcSerializer, C>,
	shutdown_rx: Barrier<ShutdownSignal>,
) -> Result<(), CodeError> {
	let mut own_shutdown = shutdown_rx.clone();
	let shutdown_fut = own_shutdown.wait();
	pin!(shutdown_fut);

	loop {
		let cnx = tokio::select! {
			c = server.accept() => c?,
			_ = &mut shutdown_fut => return Ok(()),
		};

		let (read, write) = socket_stream_split(cnx);
		let dispatcher = dispatcher.clone();
		let msg_rx = log_broadcast.replay_and_subscribe();
		let shutdown_rx = shutdown_rx.clone();
		tokio::spawn(async move {
			let _ = start_json_rpc(dispatcher.clone(), read, write, msg_rx, shutdown_rx).await;
		});
	}
}

/// Log sink that can broadcast and replay log events. Used for transmitting
/// logs from the singleton to all clients. This should be created and injected
/// into other services, like the tunnel, before `start_singleton_server`
/// is called.
#[derive(Clone)]
pub struct BroadcastLogSink {
	recent: Arc<Mutex<RingBuffer<Vec<u8>>>>,
	tx: broadcast::Sender<Vec<u8>>,
}

impl Default for BroadcastLogSink {
	fn default() -> Self {
		Self::new()
	}
}

impl BroadcastLogSink {
	pub fn new() -> Self {
		let (tx, _) = broadcast::channel(64);
		Self {
			tx,
			recent: Arc::new(Mutex::new(RingBuffer::new(50))),
		}
	}

	fn get_brocaster(&self) -> broadcast::Sender<Vec<u8>> {
		self.tx.clone()
	}

	fn replay_and_subscribe(
		&self,
	) -> ConcatReceivable<Vec<u8>, mpsc::UnboundedReceiver<Vec<u8>>, broadcast::Receiver<Vec<u8>>> {
		let (log_replay_tx, log_replay_rx) = mpsc::unbounded_channel();

		for log in self.recent.lock().unwrap().iter() {
			let _ = log_replay_tx.send(log.clone());
		}

		let _ = log_replay_tx.send(RpcCaller::serialize_notify(
			&JsonRpcSerializer {},
			protocol::singleton::METHOD_LOG_REPLY_DONE,
			protocol::EmptyObject {},
		));

		ConcatReceivable::new(log_replay_rx, self.tx.subscribe())
	}
}

impl log::LogSink for BroadcastLogSink {
	fn write_log(&self, level: log::Level, prefix: &str, message: &str) {
		let s = JsonRpcSerializer {};
		let serialized = RpcCaller::serialize_notify(
			&s,
			protocol::singleton::METHOD_LOG,
			protocol::singleton::LogMessage {
				level: Some(level),
				prefix,
				message,
			},
		);

		let _ = self.tx.send(serialized.clone());
		self.recent.lock().unwrap().push(serialized);
	}

	fn write_result(&self, message: &str) {
		self.write_log(log::Level::Info, "", message);
	}
}
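// Illustrative sketch of the replay behavior: the sink keeps the last 50
// serialized log notifications, and each new client connection receives that
// backlog first, then a METHOD_LOG_REPLY_DONE marker, then live logs.
fn demo_log_sink() {
	use crate::log::LogSink;

	let sink = BroadcastLogSink::new();
	sink.write_log(log::Level::Info, "[tunnel] ", "starting up");

	// A connection made now has the line above replayed ahead of live events.
	let _rx = sink.replay_and_subscribe();
}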
290
cli/src/tunnels/socket_signal.rs
Normal file
@@ -0,0 +1,290 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use serde::Serialize;
use tokio::sync::mpsc;

use crate::msgpack_rpc::MsgPackCaller;

use super::{
	protocol::{ClientRequestMethod, RefServerMessageParams, ToClientRequest},
	server_multiplexer::ServerMultiplexer,
};

pub struct CloseReason(pub String);

pub enum SocketSignal {
	/// Signals bytes to send to the socket.
	Send(Vec<u8>),
	/// Closes the socket (e.g. as a result of an error)
	CloseWith(CloseReason),
}

impl From<Vec<u8>> for SocketSignal {
	fn from(v: Vec<u8>) -> Self {
		SocketSignal::Send(v)
	}
}

impl SocketSignal {
	pub fn from_message<T>(msg: &T) -> Self
	where
		T: Serialize + ?Sized,
	{
		SocketSignal::Send(rmp_serde::to_vec_named(msg).unwrap())
	}
}

/// todo@connor4312: cleanup once everything is moved to rpc standard interfaces
#[allow(dead_code)]
pub enum ServerMessageDestination {
	Channel(mpsc::Sender<SocketSignal>),
	Rpc(MsgPackCaller),
}

/// Struct that handles sending or closing a connected server socket.
pub struct ServerMessageSink {
	id: u16,
	tx: Option<ServerMessageDestination>,
	multiplexer: ServerMultiplexer,
	flate: Option<FlateStream<CompressFlateAlgorithm>>,
}

impl ServerMessageSink {
	pub fn new_plain(
		multiplexer: ServerMultiplexer,
		id: u16,
		tx: ServerMessageDestination,
	) -> Self {
		Self {
			tx: Some(tx),
			id,
			multiplexer,
			flate: None,
		}
	}

	pub fn new_compressed(
		multiplexer: ServerMultiplexer,
		id: u16,
		tx: ServerMessageDestination,
	) -> Self {
		Self {
			tx: Some(tx),
			id,
			multiplexer,
			flate: Some(FlateStream::new(CompressFlateAlgorithm(
				flate2::Compress::new(flate2::Compression::new(2), false),
			))),
		}
	}

	pub async fn server_message(
		&mut self,
		body: &[u8],
	) -> Result<(), mpsc::error::SendError<SocketSignal>> {
		let id = self.id;
		let mut tx = self.tx.take().unwrap();
		let body = self.get_server_msg_content(body);
		let msg = RefServerMessageParams { i: id, body };

		let r = match &mut tx {
			ServerMessageDestination::Channel(tx) => {
				tx.send(SocketSignal::from_message(&ToClientRequest {
					id: None,
					params: ClientRequestMethod::servermsg(msg),
				}))
				.await
			}
			ServerMessageDestination::Rpc(caller) => {
				caller.notify("servermsg", msg);
				Ok(())
			}
		};

		self.tx = Some(tx);
		r
	}

	pub(crate) fn get_server_msg_content<'a: 'b, 'b>(&'a mut self, body: &'b [u8]) -> &'b [u8] {
		if let Some(flate) = &mut self.flate {
			if let Ok(compressed) = flate.process(body) {
				return compressed;
			}
		}

		body
	}
}

impl Drop for ServerMessageSink {
	fn drop(&mut self) {
		self.multiplexer.remove(self.id);
	}
}

pub struct ClientMessageDecoder {
	dec: Option<FlateStream<DecompressFlateAlgorithm>>,
}

impl ClientMessageDecoder {
	pub fn new_plain() -> Self {
		ClientMessageDecoder { dec: None }
	}

	pub fn new_compressed() -> Self {
		ClientMessageDecoder {
			dec: Some(FlateStream::new(DecompressFlateAlgorithm(
				flate2::Decompress::new(false),
			))),
		}
	}

	pub fn decode<'a: 'b, 'b>(&'a mut self, message: &'b [u8]) -> std::io::Result<&'b [u8]> {
		match &mut self.dec {
			Some(d) => d.process(message),
			None => Ok(message),
		}
	}
}
trait FlateAlgorithm {
|
||||
fn total_in(&self) -> u64;
|
||||
fn total_out(&self) -> u64;
|
||||
fn process(
|
||||
&mut self,
|
||||
contents: &[u8],
|
||||
output: &mut [u8],
|
||||
) -> Result<flate2::Status, std::io::Error>;
|
||||
}
|
||||
|
||||
struct DecompressFlateAlgorithm(flate2::Decompress);
|
||||
|
||||
impl FlateAlgorithm for DecompressFlateAlgorithm {
|
||||
fn total_in(&self) -> u64 {
|
||||
self.0.total_in()
|
||||
}
|
||||
|
||||
fn total_out(&self) -> u64 {
|
||||
self.0.total_out()
|
||||
}
|
||||
|
||||
fn process(
|
||||
&mut self,
|
||||
contents: &[u8],
|
||||
output: &mut [u8],
|
||||
) -> Result<flate2::Status, std::io::Error> {
|
||||
self.0
|
||||
.decompress(contents, output, flate2::FlushDecompress::None)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
|
||||
}
|
||||
}
|
||||
|
||||
struct CompressFlateAlgorithm(flate2::Compress);
|
||||
|
||||
impl FlateAlgorithm for CompressFlateAlgorithm {
|
||||
fn total_in(&self) -> u64 {
|
||||
self.0.total_in()
|
||||
}
|
||||
|
||||
fn total_out(&self) -> u64 {
|
||||
self.0.total_out()
|
||||
}
|
||||
|
||||
fn process(
|
||||
&mut self,
|
||||
contents: &[u8],
|
||||
output: &mut [u8],
|
||||
) -> Result<flate2::Status, std::io::Error> {
|
||||
self.0
|
||||
.compress(contents, output, flate2::FlushCompress::Sync)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
|
||||
}
|
||||
}
|
||||
|
||||
struct FlateStream<A>
|
||||
where
|
||||
A: FlateAlgorithm,
|
||||
{
|
||||
flate: A,
|
||||
output: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<A> FlateStream<A>
|
||||
where
|
||||
A: FlateAlgorithm,
|
||||
{
|
||||
pub fn new(alg: A) -> Self {
|
||||
Self {
|
||||
flate: alg,
|
||||
output: vec![0; 4096],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn process(&mut self, contents: &[u8]) -> std::io::Result<&[u8]> {
|
||||
let mut out_offset = 0;
|
||||
let mut in_offset = 0;
|
||||
loop {
|
||||
let in_before = self.flate.total_in();
|
||||
let out_before = self.flate.total_out();
|
||||
|
||||
match self
|
||||
.flate
|
||||
.process(&contents[in_offset..], &mut self.output[out_offset..])
|
||||
{
|
||||
Ok(flate2::Status::Ok | flate2::Status::BufError) => {
|
||||
let processed_len = in_offset + (self.flate.total_in() - in_before) as usize;
|
||||
let output_len = out_offset + (self.flate.total_out() - out_before) as usize;
|
||||
if processed_len < contents.len() {
|
||||
// If we filled the output buffer but there's more data to compress,
|
||||
// extend the output buffer and keep compressing.
|
||||
out_offset = output_len;
|
||||
in_offset = processed_len;
|
||||
if output_len == self.output.len() {
|
||||
self.output.resize(self.output.len() * 2, 0);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
return Ok(&self.output[..output_len]);
|
||||
}
|
||||
Ok(flate2::Status::StreamEnd) => {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"unexpected stream end",
|
||||
))
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
// Note this useful idiom: importing names from outer (for mod tests) scope.
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_round_trips_compression() {
|
||||
let (tx, _) = mpsc::channel(1);
|
||||
let mut sink = ServerMessageSink::new_compressed(
|
||||
ServerMultiplexer::new(),
|
||||
0,
|
||||
ServerMessageDestination::Channel(tx),
|
||||
);
|
||||
let mut decompress = ClientMessageDecoder::new_compressed();
|
||||
|
||||
// 3000 and 30000 test resizing the buffer
|
||||
for msg_len in [3, 30, 300, 3000, 30000] {
|
||||
let vals = (0..msg_len).map(|v| v as u8).collect::<Vec<u8>>();
|
||||
let compressed = sink.get_server_msg_content(&vals);
|
||||
assert_ne!(compressed, vals);
|
||||
let decompressed = decompress.decode(compressed).unwrap();
|
||||
assert_eq!(decompressed.len(), vals.len());
|
||||
assert_eq!(decompressed, vals);
|
||||
}
|
||||
}
|
||||
}
|
||||
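For readers unfamiliar with the deflate usage above: FlateStream relies on FlushCompress::Sync so that each message ends on a complete deflate block, letting the peer's decompressor recover it immediately without ever reaching a stream end. A minimal standalone sketch of that round trip, using only the flate2 calls already used in this file (buffer sizes are arbitrary and sized generously for a short message):

fn flate_roundtrip_sketch(msg: &[u8]) -> Vec<u8> {
    // Compress one message; the sync flush terminates the output on a
    // byte boundary so it can be decoded on its own.
    let mut c = flate2::Compress::new(flate2::Compression::new(2), false);
    let mut compressed = vec![0u8; msg.len() + 64];
    c.compress(msg, &mut compressed, flate2::FlushCompress::Sync)
        .expect("compress failed");
    compressed.truncate(c.total_out() as usize);

    // Decompress it back without ever seeing Status::StreamEnd.
    let mut d = flate2::Decompress::new(false);
    let mut out = vec![0u8; msg.len() + 64];
    d.decompress(&compressed, &mut out, flate2::FlushDecompress::None)
        .expect("decompress failed");
    out.truncate(d.total_out() as usize);
    out
}
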
cli/src/update_service.rs (Normal file, 319 lines)
@@ -0,0 +1,319 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{ffi::OsStr, fmt, path::Path};

use serde::{Deserialize, Serialize};

use crate::{
    constants::VSCODE_CLI_UPDATE_ENDPOINT,
    debug, log, options, spanf,
    util::{
        errors::{AnyError, CodeError, UpdatesNotConfigured, WrappedError},
        http::{BoxedHttp, SimpleResponse},
        io::ReportCopyProgress,
        tar, zipper,
    },
};

/// Implementation of the VS Code Update service for use in the CLI.
pub struct UpdateService {
    client: BoxedHttp,
    log: log::Logger,
}

/// Describes a specific release; it can be created manually or returned from the update service.
#[derive(Clone, Eq, PartialEq)]
pub struct Release {
    pub name: String,
    pub platform: Platform,
    pub target: TargetKind,
    pub quality: options::Quality,
    pub commit: String,
}

impl std::fmt::Display for Release {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} (commit {})", self.name, self.commit)
    }
}

#[derive(Deserialize)]
struct UpdateServerVersion {
    pub version: String,
    pub name: String,
}

fn quality_download_segment(quality: options::Quality) -> &'static str {
    match quality {
        options::Quality::Stable => "stable",
        options::Quality::Insiders => "insider",
        options::Quality::Exploration => "exploration",
    }
}

impl UpdateService {
    pub fn new(log: log::Logger, http: BoxedHttp) -> Self {
        UpdateService { client: http, log }
    }

    pub async fn get_release_by_semver_version(
        &self,
        platform: Platform,
        target: TargetKind,
        quality: options::Quality,
        version: &str,
    ) -> Result<Release, AnyError> {
        let update_endpoint =
            VSCODE_CLI_UPDATE_ENDPOINT.ok_or_else(UpdatesNotConfigured::no_url)?;
        let download_segment = target
            .download_segment(platform)
            .ok_or_else(|| CodeError::UnsupportedPlatform(platform.to_string()))?;
        let download_url = format!(
            "{}/api/versions/{}/{}/{}",
            update_endpoint,
            version,
            download_segment,
            quality_download_segment(quality),
        );

        let mut response = spanf!(
            self.log,
            self.log.span("server.version.resolve"),
            self.client.make_request("GET", download_url)
        )?;

        if !response.status_code.is_success() {
            return Err(response.into_err().await.into());
        }

        let res = response.json::<UpdateServerVersion>().await?;
        debug!(self.log, "Resolved version {} to {}", version, res.version);

        Ok(Release {
            target,
            platform,
            quality,
            name: res.name,
            commit: res.version,
        })
    }

    /// Gets the latest commit for the target of the given quality.
    pub async fn get_latest_commit(
        &self,
        platform: Platform,
        target: TargetKind,
        quality: options::Quality,
    ) -> Result<Release, AnyError> {
        let update_endpoint =
            VSCODE_CLI_UPDATE_ENDPOINT.ok_or_else(UpdatesNotConfigured::no_url)?;
        let download_segment = target
            .download_segment(platform)
            .ok_or_else(|| CodeError::UnsupportedPlatform(platform.to_string()))?;
        let download_url = format!(
            "{}/api/latest/{}/{}",
            update_endpoint,
            download_segment,
            quality_download_segment(quality),
        );

        let mut response = spanf!(
            self.log,
            self.log.span("server.version.resolve"),
            self.client.make_request("GET", download_url)
        )?;

        if !response.status_code.is_success() {
            return Err(response.into_err().await.into());
        }

        let res = response.json::<UpdateServerVersion>().await?;
        debug!(self.log, "Resolved quality {} to {}", quality, res.version);

        Ok(Release {
            target,
            platform,
            quality,
            name: res.name,
            commit: res.version,
        })
    }

    /// Gets the download stream for the release.
    pub async fn get_download_stream(&self, release: &Release) -> Result<SimpleResponse, AnyError> {
        let update_endpoint =
            VSCODE_CLI_UPDATE_ENDPOINT.ok_or_else(UpdatesNotConfigured::no_url)?;
        let download_segment = release
            .target
            .download_segment(release.platform)
            .ok_or_else(|| CodeError::UnsupportedPlatform(release.platform.to_string()))?;

        let download_url = format!(
            "{}/commit:{}/{}/{}",
            update_endpoint,
            release.commit,
            download_segment,
            quality_download_segment(release.quality),
        );

        let response = self.client.make_request("GET", download_url).await?;
        if !response.status_code.is_success() {
            return Err(response.into_err().await.into());
        }

        Ok(response)
    }
}

pub fn unzip_downloaded_release<T>(
    compressed_file: &Path,
    target_dir: &Path,
    reporter: T,
) -> Result<(), WrappedError>
where
    T: ReportCopyProgress,
{
    if compressed_file.extension() == Some(OsStr::new("zip")) {
        zipper::unzip_file(compressed_file, target_dir, reporter)
    } else {
        tar::decompress_tarball(compressed_file, target_dir, reporter)
    }
}

#[derive(Eq, PartialEq, Copy, Clone)]
pub enum TargetKind {
    Server,
    Archive,
    Web,
    Cli,
}

impl TargetKind {
    fn download_segment(&self, platform: Platform) -> Option<String> {
        match *self {
            TargetKind::Server => Some(platform.headless()),
            TargetKind::Archive => platform.archive(),
            TargetKind::Web => Some(platform.web()),
            TargetKind::Cli => Some(platform.cli()),
        }
    }
}

#[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub enum Platform {
    LinuxAlpineX64,
    LinuxAlpineARM64,
    LinuxX64,
    LinuxARM64,
    LinuxARM32,
    DarwinX64,
    DarwinARM64,
    WindowsX64,
    WindowsX86,
    WindowsARM64,
}

impl Platform {
    pub fn archive(&self) -> Option<String> {
        match self {
            Platform::LinuxX64 => Some("linux-x64".to_owned()),
            Platform::LinuxARM64 => Some("linux-arm64".to_owned()),
            Platform::LinuxARM32 => Some("linux-armhf".to_owned()),
            Platform::DarwinX64 => Some("darwin".to_owned()),
            Platform::DarwinARM64 => Some("darwin-arm64".to_owned()),
            Platform::WindowsX64 => Some("win32-x64-archive".to_owned()),
            Platform::WindowsX86 => Some("win32-archive".to_owned()),
            Platform::WindowsARM64 => Some("win32-arm64-archive".to_owned()),
            _ => None,
        }
    }

    pub fn headless(&self) -> String {
        match self {
            Platform::LinuxAlpineARM64 => "server-alpine-arm64",
            Platform::LinuxAlpineX64 => "server-linux-alpine",
            Platform::LinuxX64 => "server-linux-x64",
            Platform::LinuxARM64 => "server-linux-arm64",
            Platform::LinuxARM32 => "server-linux-armhf",
            Platform::DarwinX64 => "server-darwin",
            Platform::DarwinARM64 => "server-darwin-arm64",
            Platform::WindowsX64 => "server-win32-x64",
            Platform::WindowsX86 => "server-win32",
            Platform::WindowsARM64 => "server-win32-x64", // we don't publish an arm64 server build yet
        }
        .to_owned()
    }

    pub fn cli(&self) -> String {
        match self {
            Platform::LinuxAlpineARM64 => "cli-alpine-arm64",
            Platform::LinuxAlpineX64 => "cli-alpine-x64",
            Platform::LinuxX64 => "cli-linux-x64",
            Platform::LinuxARM64 => "cli-linux-arm64",
            Platform::LinuxARM32 => "cli-linux-armhf",
            Platform::DarwinX64 => "cli-darwin-x64",
            Platform::DarwinARM64 => "cli-darwin-arm64",
            Platform::WindowsARM64 => "cli-win32-arm64",
            Platform::WindowsX64 => "cli-win32-x64",
            Platform::WindowsX86 => "cli-win32",
        }
        .to_owned()
    }

    pub fn web(&self) -> String {
        format!("{}-web", self.headless())
    }

    pub fn env_default() -> Option<Platform> {
        if cfg!(all(
            target_os = "linux",
            target_arch = "x86_64",
            target_env = "musl"
        )) {
            Some(Platform::LinuxAlpineX64)
        } else if cfg!(all(
            target_os = "linux",
            target_arch = "aarch64",
            target_env = "musl"
        )) {
            Some(Platform::LinuxAlpineARM64)
        } else if cfg!(all(target_os = "linux", target_arch = "x86_64")) {
            Some(Platform::LinuxX64)
        } else if cfg!(all(target_os = "linux", target_arch = "arm")) {
            Some(Platform::LinuxARM32)
        } else if cfg!(all(target_os = "linux", target_arch = "aarch64")) {
            Some(Platform::LinuxARM64)
        } else if cfg!(all(target_os = "macos", target_arch = "x86_64")) {
            Some(Platform::DarwinX64)
        } else if cfg!(all(target_os = "macos", target_arch = "aarch64")) {
            Some(Platform::DarwinARM64)
        } else if cfg!(all(target_os = "windows", target_arch = "x86_64")) {
            Some(Platform::WindowsX64)
        } else if cfg!(all(target_os = "windows", target_arch = "x86")) {
            Some(Platform::WindowsX86)
        } else if cfg!(all(target_os = "windows", target_arch = "aarch64")) {
            Some(Platform::WindowsARM64)
        } else {
            None
        }
    }
}

impl fmt::Display for Platform {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match self {
            Platform::LinuxAlpineARM64 => "LinuxAlpineARM64",
            Platform::LinuxAlpineX64 => "LinuxAlpineX64",
            Platform::LinuxX64 => "LinuxX64",
            Platform::LinuxARM64 => "LinuxARM64",
            Platform::LinuxARM32 => "LinuxARM32",
            Platform::DarwinX64 => "DarwinX64",
            Platform::DarwinARM64 => "DarwinARM64",
            Platform::WindowsX64 => "WindowsX64",
            Platform::WindowsX86 => "WindowsX86",
            Platform::WindowsARM64 => "WindowsARM64",
        })
    }
}

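To make the URL scheme above concrete, here is what the segments work out to for one combination; the endpoint placeholder is illustrative, since the real value comes from VSCODE_CLI_UPDATE_ENDPOINT:

// For Platform::LinuxX64, TargetKind::Cli, options::Quality::Stable:
//   download_segment             = platform.cli() -> "cli-linux-x64"
//   quality_download_segment(..)                  -> "stable"
//
// get_latest_commit GET:   {endpoint}/api/latest/cli-linux-x64/stable
// get_download_stream GET: {endpoint}/commit:{commit}/cli-linux-x64/stable
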
cli/src/util.rs (Normal file, 22 lines)
@@ -0,0 +1,22 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

mod is_integrated;

pub mod command;
pub mod errors;
pub mod http;
pub mod input;
pub mod io;
pub mod machine;
pub mod prereqs;
pub mod ring_buffer;
pub mod sync;
pub use is_integrated::*;
pub mod app_lock;
pub mod file_lock;
pub mod os;
pub mod tar;
pub mod zipper;

cli/src/util/app_lock.rs (Normal file, 61 lines)
@@ -0,0 +1,61 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

#[cfg(windows)]
use std::{io, ptr};

#[cfg(windows)]
use winapi::{
    shared::winerror::ERROR_ALREADY_EXISTS,
    um::{handleapi::CloseHandle, synchapi::CreateMutexA, winnt::HANDLE},
};

use super::errors::CodeError;

pub struct AppMutex {
    #[cfg(windows)]
    handle: HANDLE,
}

#[cfg(windows)] // handle is thread-safe, mark it so with this
unsafe impl Send for AppMutex {}

impl AppMutex {
    #[cfg(unix)]
    pub fn new(_name: &str) -> Result<Self, CodeError> {
        Ok(Self {})
    }

    #[cfg(windows)]
    pub fn new(name: &str) -> Result<Self, CodeError> {
        use std::ffi::CString;

        let cname = CString::new(name).unwrap();
        let handle = unsafe { CreateMutexA(ptr::null_mut(), 0, cname.as_ptr() as _) };

        if !handle.is_null() {
            return Ok(Self { handle });
        }

        let err = io::Error::last_os_error();
        let raw = err.raw_os_error();
        // CreateMutexA sets ERROR_ALREADY_EXISTS when a mutex with this name
        // is already held by another process.
        if raw == Some(ERROR_ALREADY_EXISTS as i32) {
            return Err(CodeError::AppAlreadyLocked(name.to_string()));
        }

        Err(CodeError::AppLockFailed(err))
    }
}

impl Drop for AppMutex {
    fn drop(&mut self) {
        #[cfg(windows)]
        unsafe {
            CloseHandle(self.handle)
        };
    }
}

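A minimal usage sketch; the mutex name is hypothetical. On Unix the constructor is a no-op that always succeeds, while on Windows the named mutex handle is closed when the returned value is dropped:

fn ensure_single_instance() -> Result<AppMutex, CodeError> {
    // Keep the returned guard alive for as long as exclusivity is needed.
    AppMutex::new("vscode-cli-example-lock")
}
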
cli/src/util/command.rs (Normal file, 119 lines)
@@ -0,0 +1,119 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use super::errors::CodeError;
use std::{
    borrow::Cow,
    ffi::OsStr,
    process::{Output, Stdio},
};
use tokio::process::Command;

pub async fn capture_command_and_check_status(
    command_str: impl AsRef<OsStr>,
    args: &[impl AsRef<OsStr>],
) -> Result<std::process::Output, CodeError> {
    let output = capture_command(&command_str, args).await?;

    check_output_status(output, || {
        format!(
            "{} {}",
            command_str.as_ref().to_string_lossy(),
            args.iter()
                .map(|a| a.as_ref().to_string_lossy())
                .collect::<Vec<Cow<'_, str>>>()
                .join(" ")
        )
    })
}

pub fn check_output_status(
    output: Output,
    cmd_str: impl FnOnce() -> String,
) -> Result<std::process::Output, CodeError> {
    if !output.status.success() {
        return Err(CodeError::CommandFailed {
            command: cmd_str(),
            code: output.status.code().unwrap_or(-1),
            output: String::from_utf8_lossy(if output.stderr.is_empty() {
                &output.stdout
            } else {
                &output.stderr
            })
            .into(),
        });
    }

    Ok(output)
}

pub async fn capture_command<A, I, S>(
    command_str: A,
    args: I,
) -> Result<std::process::Output, CodeError>
where
    A: AsRef<OsStr>,
    I: IntoIterator<Item = S>,
    S: AsRef<OsStr>,
{
    Command::new(&command_str)
        .args(args)
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .output()
        .await
        .map_err(|e| CodeError::CommandFailed {
            command: command_str.as_ref().to_string_lossy().to_string(),
            code: -1,
            output: e.to_string(),
        })
}

/// Kills a process and all of its children.
#[cfg(target_os = "windows")]
pub async fn kill_tree(process_id: u32) -> Result<(), CodeError> {
    capture_command("taskkill", &["/t", "/pid", &process_id.to_string()]).await?;
    Ok(())
}

/// Kills a process and all of its children.
#[cfg(not(target_os = "windows"))]
pub async fn kill_tree(process_id: u32) -> Result<(), CodeError> {
    use futures::future::join_all;
    use tokio::io::{AsyncBufReadExt, BufReader};

    async fn kill_single_pid(process_id_str: String) {
        capture_command("kill", &[&process_id_str]).await.ok();
    }

    // Rusty version of https://github.com/microsoft/vscode-js-debug/blob/main/src/targets/node/terminateProcess.sh

    let parent_id = process_id.to_string();
    let mut pgrep_cmd = Command::new("pgrep")
        .arg("-P")
        .arg(&parent_id)
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .spawn()
        .map_err(|e| CodeError::CommandFailed {
            command: format!("pgrep -P {}", parent_id),
            code: -1,
            output: e.to_string(),
        })?;

    let mut kill_futures = vec![tokio::spawn(
        async move { kill_single_pid(parent_id).await },
    )];

    if let Some(stdout) = pgrep_cmd.stdout.take() {
        let mut reader = BufReader::new(stdout).lines();
        while let Some(line) = reader.next_line().await.unwrap_or(None) {
            kill_futures.push(tokio::spawn(async move { kill_single_pid(line).await }))
        }
    }

    join_all(kill_futures).await;
    pgrep_cmd.kill().await.ok();
    Ok(())
}

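A short sketch of how these helpers compose; the command and arguments here are illustrative:

async fn git_status_porcelain() -> Result<String, CodeError> {
    // On a non-zero exit this returns CodeError::CommandFailed carrying the
    // captured stderr (or stdout, when stderr is empty).
    let output = capture_command_and_check_status("git", &["status", "--porcelain"]).await?;
    Ok(String::from_utf8_lossy(&output.stdout).into_owned())
}
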
cli/src/util/errors.rs (Normal file, 532 lines)
@@ -0,0 +1,532 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use crate::{
    constants::{APPLICATION_NAME, CONTROL_PORT, DOCUMENTATION_URL, QUALITYLESS_PRODUCT_NAME},
    rpc::ResponseError,
};
use std::fmt::Display;
use thiserror::Error;

// Wraps another error with additional info.
#[derive(Debug, Clone)]
pub struct WrappedError {
    message: String,
    original: String,
}

impl std::fmt::Display for WrappedError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}: {}", self.message, self.original)
    }
}

impl std::error::Error for WrappedError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}

impl WrappedError {
    // fn new(original: Box<dyn std::error::Error>, message: String) -> WrappedError {
    //     WrappedError { message, original }
    // }
}

impl From<reqwest::Error> for WrappedError {
    fn from(e: reqwest::Error) -> WrappedError {
        WrappedError {
            message: format!(
                "error requesting {}",
                e.url().map_or("<unknown>", |u| u.as_str())
            ),
            original: format!("{}", e),
        }
    }
}

pub fn wrapdbg<T, S>(original: T, message: S) -> WrappedError
where
    T: std::fmt::Debug,
    S: Into<String>,
{
    WrappedError {
        message: message.into(),
        original: format!("{:?}", original),
    }
}

pub fn wrap<T, S>(original: T, message: S) -> WrappedError
where
    T: Display,
    S: Into<String>,
{
    WrappedError {
        message: message.into(),
        original: format!("{}", original),
    }
}

// Error generated by an unsuccessful HTTP response
#[derive(Debug)]
pub struct StatusError {
    pub url: String,
    pub status_code: u16,
    pub body: String,
}

impl std::fmt::Display for StatusError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(
            f,
            "error requesting {}: {} {}",
            self.url, self.status_code, self.body
        )
    }
}

impl StatusError {
    pub async fn from_res(res: reqwest::Response) -> Result<StatusError, AnyError> {
        let status_code = res.status().as_u16();
        let url = res.url().to_string();
        let body = res.text().await.map_err(|e| {
            wrap(
                e,
                format!(
                    "failed to read response body on {} code from {}",
                    status_code, url
                ),
            )
        })?;

        Ok(StatusError {
            url,
            status_code,
            body,
        })
    }
}

// When the user has not consented to the licensing terms when using the Launcher
#[derive(Debug)]
pub struct MissingLegalConsent(pub String);

impl std::fmt::Display for MissingLegalConsent {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

// When the provided connection token doesn't match the one used to set up the original VS Code Server.
// This is most likely due to a new user joining.
#[derive(Debug)]
pub struct MismatchConnectionToken(pub String);

impl std::fmt::Display for MismatchConnectionToken {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

// When the VS Code server has an unrecognized extension (rather than zip or gz)
#[derive(Debug)]
pub struct InvalidServerExtensionError(pub String);

impl std::fmt::Display for InvalidServerExtensionError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "invalid server extension '{}'", self.0)
    }
}

// When the tunnel fails to open
#[derive(Debug, Clone)]
pub struct DevTunnelError(pub String);

impl std::fmt::Display for DevTunnelError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "could not open tunnel: {}", self.0)
    }
}

impl std::error::Error for DevTunnelError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}

// When the server was downloaded, but the entrypoint scripts don't exist.
#[derive(Debug)]
pub struct MissingEntrypointError();

impl std::fmt::Display for MissingEntrypointError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Missing entrypoints in server download. Most likely this is a corrupted download. Please retry")
    }
}

#[derive(Debug)]
pub struct SetupError(pub String);

impl std::fmt::Display for SetupError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Note: the message goes first and the documentation URL second, as
        // the format string requires.
        write!(
            f,
            "{}\n\nMore info at {}/remote/linux",
            self.0,
            DOCUMENTATION_URL.unwrap_or("<docs>")
        )
    }
}

#[derive(Debug)]
pub struct NoHomeForLauncherError();

impl std::fmt::Display for NoHomeForLauncherError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(
            f,
            "No $HOME variable was found in your environment. Either set it, or specify a `--data-dir` manually when invoking the launcher.",
        )
    }
}

#[derive(Debug)]
pub struct InvalidTunnelName(pub String);

impl std::fmt::Display for InvalidTunnelName {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}", &self.0)
    }
}

#[derive(Debug)]
pub struct TunnelCreationFailed(pub String, pub String);

impl std::fmt::Display for TunnelCreationFailed {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(
            f,
            "Could not create tunnel with name: {}\nReason: {}",
            &self.0, &self.1
        )
    }
}

#[derive(Debug)]
pub struct TunnelHostFailed(pub String);

impl std::fmt::Display for TunnelHostFailed {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}", &self.0)
    }
}

#[derive(Debug)]
pub struct ExtensionInstallFailed(pub String);

impl std::fmt::Display for ExtensionInstallFailed {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Extension install failed: {}", &self.0)
    }
}

#[derive(Debug)]
pub struct MismatchedLaunchModeError();

impl std::fmt::Display for MismatchedLaunchModeError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "A server is already running, but it was not launched in the same listening mode (port vs. socket) as this request")
    }
}

#[derive(Debug)]
pub struct NoAttachedServerError();

impl std::fmt::Display for NoAttachedServerError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "No server is running")
    }
}

#[derive(Debug)]
pub struct RefreshTokenNotAvailableError();

impl std::fmt::Display for RefreshTokenNotAvailableError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Refresh token not available, authentication is required")
    }
}

#[derive(Debug)]
pub struct NoInstallInUserProvidedPath(pub String);

impl std::fmt::Display for NoInstallInUserProvidedPath {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(
            f,
            "No {} installation could be found in {}. You can run `{} --use-quality=stable` to switch to the latest stable version of {}.",
            QUALITYLESS_PRODUCT_NAME,
            self.0,
            APPLICATION_NAME,
            QUALITYLESS_PRODUCT_NAME
        )
    }
}

#[derive(Debug)]
pub struct InvalidRequestedVersion();

impl std::fmt::Display for InvalidRequestedVersion {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(
            f,
            "The requested version is invalid; expected one of 'stable', 'insiders', a version number (x.y.z), or an absolute path.",
        )
    }
}

#[derive(Debug)]
pub struct UserCancelledInstallation();

impl std::fmt::Display for UserCancelledInstallation {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Installation aborted.")
    }
}

#[derive(Debug)]
pub struct CannotForwardControlPort();

impl std::fmt::Display for CannotForwardControlPort {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Cannot forward or unforward port {}.", CONTROL_PORT)
    }
}

#[derive(Debug)]
pub struct ServerHasClosed();

impl std::fmt::Display for ServerHasClosed {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Request cancelled because the server has closed")
    }
}

#[derive(Debug)]
pub struct UpdatesNotConfigured(pub String);

impl UpdatesNotConfigured {
    pub fn no_url() -> Self {
        UpdatesNotConfigured("no service url".to_owned())
    }
}

impl std::fmt::Display for UpdatesNotConfigured {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Update service is not configured: {}", self.0)
    }
}

#[derive(Debug)]
pub struct ServiceAlreadyRegistered();

impl std::fmt::Display for ServiceAlreadyRegistered {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Already registered the service. Run `{} tunnel service uninstall` to unregister it first", APPLICATION_NAME)
    }
}

#[derive(Debug)]
pub struct WindowsNeedsElevation(pub String);

impl std::fmt::Display for WindowsNeedsElevation {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        writeln!(f, "{}", self.0)?;
        writeln!(f)?;
        writeln!(f, "You may need to run this command as an administrator:")?;
        writeln!(f, " 1. Open the start menu and search for Powershell")?;
        writeln!(f, " 2. Right click and 'Run as administrator'")?;
        if let Ok(exe) = std::env::current_exe() {
            writeln!(
                f,
                " 3. Run &'{}' '{}'",
                exe.display(),
                std::env::args().skip(1).collect::<Vec<_>>().join("' '")
            )
        } else {
            writeln!(f, " 3. Run the same command again")
        }
    }
}

#[derive(Debug)]
pub struct InvalidRpcDataError(pub String);

impl std::fmt::Display for InvalidRpcDataError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "parse error: {}", self.0)
    }
}

#[derive(Debug)]
pub struct CorruptDownload(pub String);

impl std::fmt::Display for CorruptDownload {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(
            f,
            "Error updating the {} CLI: {}",
            QUALITYLESS_PRODUCT_NAME, self.0
        )
    }
}

#[derive(Debug)]
pub struct MissingHomeDirectory();

impl std::fmt::Display for MissingHomeDirectory {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "Could not find your home directory. Please ensure this command is running in the context of a normal user.")
    }
}

#[derive(Debug)]
pub struct OAuthError {
    pub error: String,
    pub error_description: Option<String>,
}

impl std::fmt::Display for OAuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(
            f,
            "Error getting authorization: {} {}",
            self.error,
            self.error_description.as_deref().unwrap_or("")
        )
    }
}

// Makes an "AnyError" enum that contains any of the given errors, in the form
// `enum AnyError { FooError(FooError) }` (when given `makeAnyError!(FooError)`).
// Useful to easily deal with application error types without making tons of "From"
// clauses.
macro_rules! makeAnyError {
    ($($e:ident),*) => {

        #[derive(Debug)]
        #[allow(clippy::enum_variant_names)]
        pub enum AnyError {
            $($e($e),)*
        }

        impl std::fmt::Display for AnyError {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                match *self {
                    $(AnyError::$e(ref e) => e.fmt(f),)*
                }
            }
        }

        impl std::error::Error for AnyError {
            fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
                None
            }
        }

        $(impl From<$e> for AnyError {
            fn from(e: $e) -> AnyError {
                AnyError::$e(e)
            }
        })*
    };
}

/// Internal errors in the VS Code CLI.
/// Note: other errors should be migrated to this type gradually.
#[derive(Error, Debug)]
pub enum CodeError {
    #[error("could not connect to socket/pipe: {0:?}")]
    AsyncPipeFailed(std::io::Error),
    #[error("could not listen on socket/pipe: {0:?}")]
    AsyncPipeListenerFailed(std::io::Error),
    #[error("could not create singleton lock file: {0:?}")]
    SingletonLockfileOpenFailed(std::io::Error),
    #[error("could not read singleton lock file: {0:?}")]
    SingletonLockfileReadFailed(rmp_serde::decode::Error),
    #[error("the process holding the singleton lock file (pid={0}) exited")]
    SingletonLockedProcessExited(u32),
    #[error("no tunnel process is currently running")]
    NoRunningTunnel,
    #[error("rpc call failed: {0:?}")]
    TunnelRpcCallFailed(ResponseError),
    #[cfg(windows)]
    #[error("the windows app lock {0} already exists")]
    AppAlreadyLocked(String),
    #[cfg(windows)]
    #[error("could not get windows app lock: {0:?}")]
    AppLockFailed(std::io::Error),
    #[error("failed to run command \"{command}\" (code {code}): {output}")]
    CommandFailed {
        command: String,
        code: i32,
        output: String,
    },

    #[error("platform not currently supported: {0}")]
    UnsupportedPlatform(String),
    #[error("This machine does not meet {name}'s prerequisites; expected one of: {bullets}")]
    PrerequisitesFailed { name: &'static str, bullets: String },
    #[error("failed to spawn process: {0:?}")]
    ProcessSpawnFailed(std::io::Error),
    #[error("failed to handshake spawned process: {0:?}")]
    ProcessSpawnHandshakeFailed(std::io::Error),
    #[error("download appears corrupted, please retry ({0})")]
    CorruptDownload(&'static str),
    #[error("port forwarding is not available in this context")]
    PortForwardingNotAvailable,
    #[error("'auth' call required")]
    ServerAuthRequired,
    #[error("challenge not yet issued")]
    AuthChallengeNotIssued,
    #[error("unauthorized client refused")]
    AuthMismatch,
}

makeAnyError!(
    MissingLegalConsent,
    MismatchConnectionToken,
    DevTunnelError,
    StatusError,
    WrappedError,
    InvalidServerExtensionError,
    MissingEntrypointError,
    SetupError,
    NoHomeForLauncherError,
    TunnelCreationFailed,
    TunnelHostFailed,
    InvalidTunnelName,
    ExtensionInstallFailed,
    MismatchedLaunchModeError,
    NoAttachedServerError,
    RefreshTokenNotAvailableError,
    NoInstallInUserProvidedPath,
    UserCancelledInstallation,
    InvalidRequestedVersion,
    CannotForwardControlPort,
    ServerHasClosed,
    ServiceAlreadyRegistered,
    WindowsNeedsElevation,
    UpdatesNotConfigured,
    CorruptDownload,
    MissingHomeDirectory,
    OAuthError,
    InvalidRpcDataError,
    CodeError
);

impl From<reqwest::Error> for AnyError {
    fn from(e: reqwest::Error) -> AnyError {
        AnyError::WrappedError(WrappedError::from(e))
    }
}

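The practical effect of makeAnyError! is that every listed type gets a generated From impl, so code returning Result<_, AnyError> can mix error sources with the ? operator or .into(). A small sketch; the function and messages are hypothetical:

fn validate_payload(raw: &str) -> Result<(), AnyError> {
    if raw.is_empty() {
        // InvalidRpcDataError -> AnyError via the macro-generated From impl.
        return Err(InvalidRpcDataError("empty payload".to_string()).into());
    }
    if raw.len() > 1 << 20 {
        // CodeError variants convert through the same mechanism.
        return Err(CodeError::CorruptDownload("payload too large").into());
    }
    Ok(())
}
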
cli/src/util/file_lock.rs (Normal file, 125 lines)
@@ -0,0 +1,125 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use crate::util::errors::CodeError;
use std::{fs::File, io};

pub struct FileLock {
    file: File,
    #[cfg(windows)]
    overlapped: winapi::um::minwinbase::OVERLAPPED,
}

#[cfg(windows)] // overlapped is thread-safe, mark it so with this
unsafe impl Send for FileLock {}

pub enum Lock {
    Acquired(FileLock),
    AlreadyLocked(File),
}

/// Number of locked bytes in the file. On Windows, locking prevents reads,
/// but consumers of the lock may still want to read what the locking process
/// has written. Thus, only PREFIX_LOCKED_BYTES are locked, and any globally-
/// readable content should be written after the prefix.
#[cfg(windows)]
pub const PREFIX_LOCKED_BYTES: usize = 1;

#[cfg(unix)]
pub const PREFIX_LOCKED_BYTES: usize = 0;

impl FileLock {
    #[cfg(windows)]
    pub fn acquire(file: File) -> Result<Lock, CodeError> {
        use std::os::windows::prelude::AsRawHandle;
        use winapi::{
            shared::winerror::{ERROR_IO_PENDING, ERROR_LOCK_VIOLATION},
            um::{
                fileapi::LockFileEx,
                minwinbase::{LOCKFILE_EXCLUSIVE_LOCK, LOCKFILE_FAIL_IMMEDIATELY},
            },
        };

        let handle = file.as_raw_handle();
        let (overlapped, ok) = unsafe {
            let mut overlapped = std::mem::zeroed();
            let ok = LockFileEx(
                handle,
                LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY,
                0,
                PREFIX_LOCKED_BYTES as u32,
                0,
                &mut overlapped,
            );

            (overlapped, ok)
        };

        if ok != 0 {
            return Ok(Lock::Acquired(Self { file, overlapped }));
        }

        let err = io::Error::last_os_error();
        let raw = err.raw_os_error();
        // docs report it should return ERROR_IO_PENDING, but in my testing it actually
        // returns ERROR_LOCK_VIOLATION. Or maybe winapi is wrong?
        if raw == Some(ERROR_IO_PENDING as i32) || raw == Some(ERROR_LOCK_VIOLATION as i32) {
            return Ok(Lock::AlreadyLocked(file));
        }

        Err(CodeError::SingletonLockfileOpenFailed(err))
    }

    #[cfg(unix)]
    pub fn acquire(file: File) -> Result<Lock, CodeError> {
        use std::os::unix::io::AsRawFd;

        let fd = file.as_raw_fd();
        let res = unsafe { libc::flock(fd, libc::LOCK_EX | libc::LOCK_NB) };
        if res == 0 {
            return Ok(Lock::Acquired(Self { file }));
        }

        let err = io::Error::last_os_error();
        if err.kind() == io::ErrorKind::WouldBlock {
            return Ok(Lock::AlreadyLocked(file));
        }

        Err(CodeError::SingletonLockfileOpenFailed(err))
    }

    pub fn file(&self) -> &File {
        &self.file
    }

    pub fn file_mut(&mut self) -> &mut File {
        &mut self.file
    }
}

impl Drop for FileLock {
    #[cfg(windows)]
    fn drop(&mut self) {
        use std::os::windows::prelude::AsRawHandle;
        use winapi::um::fileapi::UnlockFileEx;

        unsafe {
            UnlockFileEx(
                self.file.as_raw_handle(),
                0,
                u32::MAX,
                u32::MAX,
                &mut self.overlapped,
            )
        };
    }

    #[cfg(unix)]
    fn drop(&mut self) {
        use std::os::unix::io::AsRawFd;

        unsafe { libc::flock(self.file.as_raw_fd(), libc::LOCK_UN) };
    }
}

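A sketch of the intended acquire flow (the path handling is hypothetical). Matching on Lock distinguishes winning the lock from finding another holder, in which case the file handle is still returned so content past PREFIX_LOCKED_BYTES can be read:

fn try_singleton(path: &std::path::Path) -> Result<Option<FileLock>, CodeError> {
    let file = std::fs::OpenOptions::new()
        .create(true)
        .read(true)
        .write(true)
        .open(path)
        .map_err(CodeError::SingletonLockfileOpenFailed)?;

    Ok(match FileLock::acquire(file)? {
        // Held until the returned FileLock is dropped.
        Lock::Acquired(lock) => Some(lock),
        // Another process holds it; its globally readable contents past
        // PREFIX_LOCKED_BYTES can still be read from the handle.
        Lock::AlreadyLocked(_file) => None,
    })
}
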
cli/src/util/http.rs (Normal file, 376 lines)
@@ -0,0 +1,376 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use crate::{
    constants::get_default_user_agent,
    log,
    util::errors::{self, WrappedError},
};
use async_trait::async_trait;
use core::panic;
use futures::stream::TryStreamExt;
use hyper::{
    header::{HeaderName, CONTENT_LENGTH},
    http::HeaderValue,
    HeaderMap, StatusCode,
};
use serde::de::DeserializeOwned;
use std::{io, pin::Pin, str::FromStr, sync::Arc, task::Poll};
use tokio::{
    fs,
    io::{AsyncRead, AsyncReadExt},
    sync::mpsc,
};
use tokio_util::compat::FuturesAsyncReadCompatExt;

use super::{
    errors::{wrap, AnyError, StatusError},
    io::{copy_async_progress, ReadBuffer, ReportCopyProgress},
};

pub async fn download_into_file<T>(
    filename: &std::path::Path,
    progress: T,
    mut res: SimpleResponse,
) -> Result<fs::File, WrappedError>
where
    T: ReportCopyProgress,
{
    let mut file = fs::File::create(filename)
        .await
        .map_err(|e| errors::wrap(e, "failed to create file"))?;

    let content_length = res
        .headers
        .get(CONTENT_LENGTH)
        .and_then(|h| h.to_str().ok())
        .and_then(|s| s.parse::<u64>().ok())
        .unwrap_or(0);

    copy_async_progress(progress, &mut res.read, &mut file, content_length)
        .await
        .map_err(|e| errors::wrap(e, "failed to download file"))?;

    Ok(file)
}

pub struct SimpleResponse {
    pub status_code: StatusCode,
    pub headers: HeaderMap,
    pub read: Pin<Box<dyn Send + AsyncRead + 'static>>,
    pub url: Option<url::Url>,
}

impl SimpleResponse {
    pub fn url_path_basename(&self) -> Option<String> {
        self.url.as_ref().and_then(|u| {
            u.path_segments()
                .and_then(|s| s.last().map(|s| s.to_owned()))
        })
    }

    pub fn generic_error(url: &str) -> Self {
        let (_, rx) = mpsc::unbounded_channel();
        SimpleResponse {
            url: url::Url::parse(url).ok(),
            status_code: StatusCode::INTERNAL_SERVER_ERROR,
            headers: HeaderMap::new(),
            read: Box::pin(DelegatedReader::new(rx)),
        }
    }

    /// Converts the response into a StatusError
    pub async fn into_err(mut self) -> StatusError {
        let mut body = String::new();
        self.read.read_to_string(&mut body).await.ok();

        StatusError {
            url: self
                .url
                .map(|u| u.to_string())
                .unwrap_or_else(|| "<invalid url>".to_owned()),
            status_code: self.status_code.as_u16(),
            body,
        }
    }

    /// Deserializes the response body as JSON
    pub async fn json<T: DeserializeOwned>(&mut self) -> Result<T, AnyError> {
        let mut buf = vec![];

        // ideally serde would deserialize a stream, but it does not appear that
        // is supported. reqwest itself reads and decodes separately like we do here:
        self.read
            .read_to_end(&mut buf)
            .await
            .map_err(|e| wrap(e, "error reading response"))?;

        let t = serde_json::from_slice(&buf)
            .map_err(|e| wrap(e, format!("error decoding json from {:?}", self.url)))?;

        Ok(t)
    }
}

/// *Very* simple HTTP implementation. In most cases, this will just delegate to
/// the request library on the server (i.e. `reqwest`) but it can also be used
/// to make update/download requests on the client rather than the server,
/// similar to SSH's `remote.SSH.localServerDownload` setting.
#[async_trait]
pub trait SimpleHttp {
    async fn make_request(
        &self,
        method: &'static str,
        url: String,
    ) -> Result<SimpleResponse, AnyError>;
}

pub type BoxedHttp = Arc<dyn SimpleHttp + Send + Sync + 'static>;

// Implementation of SimpleHttp that uses a reqwest client.
#[derive(Clone)]
pub struct ReqwestSimpleHttp {
    client: reqwest::Client,
}

impl ReqwestSimpleHttp {
    pub fn new() -> Self {
        Self {
            client: reqwest::ClientBuilder::new()
                .user_agent(get_default_user_agent())
                .build()
                .unwrap(),
        }
    }

    pub fn with_client(client: reqwest::Client) -> Self {
        Self { client }
    }
}

impl Default for ReqwestSimpleHttp {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl SimpleHttp for ReqwestSimpleHttp {
    async fn make_request(
        &self,
        method: &'static str,
        url: String,
    ) -> Result<SimpleResponse, AnyError> {
        let res = self
            .client
            .request(reqwest::Method::try_from(method).unwrap(), &url)
            .send()
            .await?;

        Ok(SimpleResponse {
            status_code: res.status(),
            headers: res.headers().clone(),
            url: Some(res.url().clone()),
            read: Box::pin(
                res.bytes_stream()
                    .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
                    .into_async_read()
                    .compat(),
            ),
        })
    }
}

enum DelegatedHttpEvent {
    InitResponse {
        status_code: u16,
        headers: Vec<(String, String)>,
    },
    Body(Vec<u8>),
    End,
}

// Handle for a delegated request that allows manually issuing a response.
pub struct DelegatedHttpRequest {
    pub method: &'static str,
    pub url: String,
    ch: mpsc::UnboundedSender<DelegatedHttpEvent>,
}

impl DelegatedHttpRequest {
    pub fn initial_response(&self, status_code: u16, headers: Vec<(String, String)>) {
        self.ch
            .send(DelegatedHttpEvent::InitResponse {
                status_code,
                headers,
            })
            .ok();
    }

    pub fn body(&self, chunk: Vec<u8>) {
        self.ch.send(DelegatedHttpEvent::Body(chunk)).ok();
    }

    pub fn end(self) {}
}

impl Drop for DelegatedHttpRequest {
    fn drop(&mut self) {
        self.ch.send(DelegatedHttpEvent::End).ok();
    }
}

/// Implementation of SimpleHttp that allows manually controlling responses.
#[derive(Clone)]
pub struct DelegatedSimpleHttp {
    start_request: mpsc::Sender<DelegatedHttpRequest>,
    log: log::Logger,
}

impl DelegatedSimpleHttp {
    pub fn new(log: log::Logger) -> (Self, mpsc::Receiver<DelegatedHttpRequest>) {
        let (tx, rx) = mpsc::channel(4);
        (
            DelegatedSimpleHttp {
                log,
                start_request: tx,
            },
            rx,
        )
    }
}

#[async_trait]
impl SimpleHttp for DelegatedSimpleHttp {
    async fn make_request(
        &self,
        method: &'static str,
        url: String,
    ) -> Result<SimpleResponse, AnyError> {
        trace!(self.log, "making delegated request to {}", url);
        let (tx, mut rx) = mpsc::unbounded_channel();
        let sent = self
            .start_request
            .send(DelegatedHttpRequest {
                method,
                url: url.clone(),
                ch: tx,
            })
            .await;

        if sent.is_err() {
            return Ok(SimpleResponse::generic_error(&url)); // sender shut down
        }

        match rx.recv().await {
            Some(DelegatedHttpEvent::InitResponse {
                status_code,
                headers,
            }) => {
                trace!(
                    self.log,
                    "delegated request to {} resulted in status = {}",
                    url,
                    status_code
                );
                let mut headers_map = HeaderMap::with_capacity(headers.len());
                for (k, v) in &headers {
                    if let (Ok(key), Ok(value)) = (
                        HeaderName::from_str(&k.to_lowercase()),
                        HeaderValue::from_str(v),
                    ) {
                        headers_map.insert(key, value);
                    }
                }

                Ok(SimpleResponse {
                    url: url::Url::parse(&url).ok(),
                    status_code: StatusCode::from_u16(status_code)
                        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
                    headers: headers_map,
                    read: Box::pin(DelegatedReader::new(rx)),
                })
            }
            Some(DelegatedHttpEvent::End) => Ok(SimpleResponse::generic_error(&url)),
            Some(_) => panic!("expected initresponse as first message from delegated http"),
            None => Ok(SimpleResponse::generic_error(&url)), // sender shut down
        }
    }
}

struct DelegatedReader {
    receiver: mpsc::UnboundedReceiver<DelegatedHttpEvent>,
    readbuf: ReadBuffer,
}

impl DelegatedReader {
    pub fn new(rx: mpsc::UnboundedReceiver<DelegatedHttpEvent>) -> Self {
        DelegatedReader {
            readbuf: ReadBuffer::default(),
            receiver: rx,
        }
    }
}

impl AsyncRead for DelegatedReader {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        if let Some((v, s)) = self.readbuf.take_data() {
            return self.readbuf.put_data(buf, v, s);
        }

        match self.receiver.poll_recv(cx) {
            Poll::Ready(Some(DelegatedHttpEvent::Body(msg))) => self.readbuf.put_data(buf, msg, 0),
            Poll::Ready(Some(_)) => Poll::Ready(Ok(())), // EOF
            Poll::Ready(None) => {
                Poll::Ready(Err(io::Error::new(io::ErrorKind::UnexpectedEof, "EOF")))
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

/// Simple HTTP implementation that falls back to delegated HTTP if
/// making a direct request via reqwest fails.
pub struct FallbackSimpleHttp {
    native: ReqwestSimpleHttp,
    delegated: DelegatedSimpleHttp,
}

impl FallbackSimpleHttp {
    pub fn new(native: ReqwestSimpleHttp, delegated: DelegatedSimpleHttp) -> Self {
        FallbackSimpleHttp { native, delegated }
    }

    pub fn native(&self) -> ReqwestSimpleHttp {
        self.native.clone()
    }

    pub fn delegated(&self) -> DelegatedSimpleHttp {
        self.delegated.clone()
    }
}

#[async_trait]
impl SimpleHttp for FallbackSimpleHttp {
    async fn make_request(
        &self,
        method: &'static str,
        url: String,
    ) -> Result<SimpleResponse, AnyError> {
        let r1 = self.native.make_request(method, url.clone()).await;
        if let Ok(res) = r1 {
            if !res.status_code.is_server_error() {
                return Ok(res);
            }
        }

        self.delegated.make_request(method, url).await
    }
}

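A sketch of wiring the three implementations together, assuming a log::Logger is already at hand and satisfies the Send + Sync bounds that the BoxedHttp alias requires. Requests then try reqwest first and fall back to the delegated channel on transport errors or 5xx responses:

fn make_fallback_http(log: log::Logger) -> (BoxedHttp, mpsc::Receiver<DelegatedHttpRequest>) {
    let (delegated, requests_rx) = DelegatedSimpleHttp::new(log);
    let http: BoxedHttp = Arc::new(FallbackSimpleHttp::new(
        ReqwestSimpleHttp::new(),
        delegated,
    ));
    // The receiver side is where the client answers delegated requests
    // via DelegatedHttpRequest::initial_response / body / end.
    (http, requests_rx)
}
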
cli/src/util/input.rs (Normal file, 69 lines)
@@ -0,0 +1,69 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use crate::util::errors::wrap;
use dialoguer::{theme::ColorfulTheme, Confirm, Input, Select};
use indicatif::ProgressBar;
use std::fmt::Display;

use super::{errors::WrappedError, io::ReportCopyProgress};

/// Wrapper around indicatif::ProgressBar that implements ReportCopyProgress.
pub struct ProgressBarReporter {
    bar: ProgressBar,
    has_set_total: bool,
}

impl From<ProgressBar> for ProgressBarReporter {
    fn from(bar: ProgressBar) -> Self {
        ProgressBarReporter {
            bar,
            has_set_total: false,
        }
    }
}

impl ReportCopyProgress for ProgressBarReporter {
    fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64) {
        if !self.has_set_total {
            self.bar.set_length(total_bytes);
            self.has_set_total = true; // only set the length on the first report
        }

        if bytes_so_far == total_bytes {
            self.bar.finish_and_clear();
        } else {
            self.bar.set_position(bytes_so_far);
        }
    }
}

pub fn prompt_yn(text: &str) -> Result<bool, WrappedError> {
    Confirm::with_theme(&ColorfulTheme::default())
        .with_prompt(text)
        .default(true)
        .interact()
        .map_err(|e| wrap(e, "Failed to read confirm input"))
}

pub fn prompt_options<T>(text: impl Into<String>, options: &[T]) -> Result<T, WrappedError>
where
    T: Display + Copy,
{
    let chosen = Select::with_theme(&ColorfulTheme::default())
        .with_prompt(text)
        .items(options)
        .default(0)
        .interact()
        .map_err(|e| wrap(e, "Failed to read select input"))?;

    Ok(options[chosen])
}

pub fn prompt_placeholder(question: &str, placeholder: &str) -> Result<String, WrappedError> {
    Input::with_theme(&ColorfulTheme::default())
        .with_prompt(question)
        .default(placeholder.to_string())
        .interact_text()
        .map_err(|e| wrap(e, "Failed to read text input"))
}
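A brief usage sketch of the prompt helpers (the `Channel` enum and install flow here are illustrative, not part of the CLI):

// Sketch: drive an interactive install flow with the helpers above.
#[derive(Copy, Clone)]
enum Channel {
    Stable,
    Insiders,
}

impl std::fmt::Display for Channel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Channel::Stable => "stable",
            Channel::Insiders => "insiders",
        })
    }
}

fn interactive_install() -> Result<(), WrappedError> {
    if prompt_yn("Install the server?")? {
        let channel = prompt_options("Pick a channel", &[Channel::Stable, Channel::Insiders])?;
        let dir = prompt_placeholder("Install directory", "~/.server")?;
        println!("installing {} to {}", channel, dir);
    }
    Ok(())
}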
355  cli/src/util/io.rs  Normal file
@@ -0,0 +1,355 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use std::{
    fs::File,
    io::{self, BufRead, Seek},
    task::Poll,
    time::Duration,
};

use tokio::{
    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
    sync::mpsc,
    time::sleep,
};

use super::ring_buffer::RingBuffer;

pub trait ReportCopyProgress {
    fn report_progress(&mut self, bytes_so_far: u64, total_bytes: u64);
}

/// Type that doesn't emit anything for download progress.
pub struct SilentCopyProgress();

impl ReportCopyProgress for SilentCopyProgress {
    fn report_progress(&mut self, _bytes_so_far: u64, _total_bytes: u64) {}
}

/// Copies from the reader to the writer, reporting progress to the provided
/// reporter every so often.
pub async fn copy_async_progress<T, R, W>(
    mut reporter: T,
    reader: &mut R,
    writer: &mut W,
    total_bytes: u64,
) -> io::Result<u64>
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
    T: ReportCopyProgress,
{
    let mut buf = vec![0; 8 * 1024];
    let mut bytes_so_far = 0;
    let mut bytes_last_reported = 0;
    let report_granularity = std::cmp::min(total_bytes / 10, 2 * 1024 * 1024);

    reporter.report_progress(0, total_bytes);

    loop {
        let read_buf = match reader.read(&mut buf).await {
            Ok(0) => break,
            Ok(n) => &buf[..n],
            Err(e) => return Err(e),
        };

        writer.write_all(read_buf).await?;

        bytes_so_far += read_buf.len() as u64;
        if bytes_so_far - bytes_last_reported > report_granularity {
            bytes_last_reported = bytes_so_far;
            reporter.report_progress(bytes_so_far, total_bytes);
        }
    }

    reporter.report_progress(bytes_so_far, total_bytes);

    Ok(bytes_so_far)
}

/// Helper used when converting Future interfaces to poll-based interfaces.
/// Stores excess data that can be reused on future polls.
#[derive(Default)]
pub(crate) struct ReadBuffer(Option<(Vec<u8>, usize)>);

impl ReadBuffer {
    /// Removes any data stored in the read buffer
    pub fn take_data(&mut self) -> Option<(Vec<u8>, usize)> {
        self.0.take()
    }

    /// Writes as many bytes as possible to the readbuf, stashing any extra.
    pub fn put_data(
        &mut self,
        target: &mut tokio::io::ReadBuf<'_>,
        bytes: Vec<u8>,
        start: usize,
    ) -> Poll<std::io::Result<()>> {
        if bytes.is_empty() {
            self.0 = None;
            // should not return Ok(), since if nothing is written to the target
            // it signals EOF. Instead wait for more data from the source.
            return Poll::Pending;
        }

        if target.remaining() >= bytes.len() - start {
            target.put_slice(&bytes[start..]);
            self.0 = None;
        } else {
            let end = start + target.remaining();
            target.put_slice(&bytes[start..end]);
            self.0 = Some((bytes, end));
        }

        Poll::Ready(Ok(()))
    }
}

#[derive(Debug)]
pub enum TailEvent {
    /// A new line was read from the file. The line includes its trailing newline character.
    Line(String),
    /// The file appears to have been rewritten (size shrunk)
    Reset,
    /// An error was encountered with the file.
    Err(io::Error),
}

/// Simple, naive implementation of `tail -f -n <n> <path>`. Uses polling, so
/// it's not the fastest, but it is simple and works for easy cases.
pub fn tailf(file: File, n: usize) -> mpsc::UnboundedReceiver<TailEvent> {
    let (tx, rx) = mpsc::unbounded_channel();
    let mut last_len = match file.metadata() {
        Ok(m) => m.len(),
        Err(e) => {
            tx.send(TailEvent::Err(e)).ok();
            return rx;
        }
    };

    let mut reader = io::BufReader::new(file);
    let mut pos = 0;

    // Read the last "n" lines of the file. initial_lines is a small ring
    // buffer, so only the trailing n lines survive this first pass.
    let mut initial_lines = RingBuffer::new(n);
    loop {
        let mut line = String::new();
        let bytes_read = match reader.read_line(&mut line) {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) => {
                tx.send(TailEvent::Err(e)).ok();
                return rx;
            }
        };

        if !line.ends_with('\n') {
            // EOF
            break;
        }

        pos += bytes_read as u64;
        initial_lines.push(line);
    }

    for line in initial_lines.into_iter() {
        tx.send(TailEvent::Line(line)).ok();
    }

    // now spawn the poll process to keep reading new lines
    tokio::spawn(async move {
        let poll_interval = Duration::from_millis(500);

        loop {
            tokio::select! {
                _ = sleep(poll_interval) => {},
                _ = tx.closed() => return
            }

            match reader.get_ref().metadata() {
                Err(e) => {
                    tx.send(TailEvent::Err(e)).ok();
                    return;
                }
                Ok(m) => {
                    if m.len() == last_len {
                        continue;
                    }

                    if m.len() < last_len {
                        tx.send(TailEvent::Reset).ok();
                        pos = 0;
                    }

                    last_len = m.len();
                }
            }

            if let Err(e) = reader.seek(io::SeekFrom::Start(pos)) {
                tx.send(TailEvent::Err(e)).ok();
                return;
            }

            loop {
                let mut line = String::new();
                let n = match reader.read_line(&mut line) {
                    Ok(0) => break,
                    Ok(n) => n,
                    Err(e) => {
                        tx.send(TailEvent::Err(e)).ok();
                        return;
                    }
                };

                if n == 0 || !line.ends_with('\n') {
                    break;
                }

                pos += n as u64;
                if tx.send(TailEvent::Line(line)).is_err() {
                    return;
                }
            }
        }
    });

    rx
}

#[cfg(test)]
mod tests {
    use rand::Rng;
    use std::{fs::OpenOptions, io::Write};

    use super::*;

    #[tokio::test]
    async fn test_tailf_empty() {
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("tmp");

        let read_file = OpenOptions::new()
            .write(true)
            .read(true)
            .create(true)
            .open(&file_path)
            .unwrap();

        let mut rx = tailf(read_file, 32);
        assert!(rx.try_recv().is_err());

        let mut append_file = OpenOptions::new()
            .write(true)
            .append(true)
            .open(&file_path)
            .unwrap();
        writeln!(&mut append_file, "some line").unwrap();

        let recv = rx.recv().await;
        if let Some(TailEvent::Line(l)) = recv {
            assert_eq!("some line\n".to_string(), l);
        } else {
            unreachable!("expected a line event, got {:?}", recv)
        }

        write!(&mut append_file, "partial ").unwrap();
        writeln!(&mut append_file, "line").unwrap();

        let recv = rx.recv().await;
        if let Some(TailEvent::Line(l)) = recv {
            assert_eq!("partial line\n".to_string(), l);
        } else {
            unreachable!("expected a line event, got {:?}", recv)
        }
    }

    #[tokio::test]
    async fn test_tailf_resets() {
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("tmp");

        let mut read_file = OpenOptions::new()
            .write(true)
            .read(true)
            .create(true)
            .open(&file_path)
            .unwrap();

        writeln!(&mut read_file, "some existing content").unwrap();
        let mut rx = tailf(read_file, 0);
        assert!(rx.try_recv().is_err());

        let mut append_file = File::create(&file_path).unwrap(); // truncates
        writeln!(&mut append_file, "some line").unwrap();

        let recv = rx.recv().await;
        if let Some(TailEvent::Reset) = recv {
            // ok
        } else {
            unreachable!("expected a reset event, got {:?}", recv)
        }

        let recv = rx.recv().await;
        if let Some(TailEvent::Line(l)) = recv {
            assert_eq!("some line\n".to_string(), l);
        } else {
            unreachable!("expected a line event, got {:?}", recv)
        }
    }

    #[tokio::test]
    async fn test_tailf_with_data() {
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("tmp");

        let mut read_file = OpenOptions::new()
            .write(true)
            .read(true)
            .create(true)
            .open(&file_path)
            .unwrap();
        let mut rng = rand::thread_rng();

        let mut written = vec![];
        let base_line = "Elit ipsum cillum ex cillum. Adipisicing consequat cupidatat do proident ut in sunt Lorem ipsum tempor. Eiusmod ipsum Lorem labore exercitation sunt pariatur excepteur fugiat cillum velit cillum enim. Nisi Lorem cupidatat ad enim velit officia eiusmod esse tempor aliquip. Deserunt pariatur tempor in duis culpa esse sit nulla irure ullamco ipsum voluptate non laboris. Occaecat officia nulla officia mollit do aliquip reprehenderit ad incididunt.";
        for i in 0..100 {
            let line = format!("{}: {}", i, &base_line[..rng.gen_range(0..base_line.len())]);
            writeln!(&mut read_file, "{}", line).unwrap();
            written.push(line);
        }
        write!(&mut read_file, "partial line").unwrap();
        read_file.seek(io::SeekFrom::Start(0)).unwrap();

        let last_n = 32;
        let mut rx = tailf(read_file, last_n);
        for i in 0..last_n {
            let recv = rx.try_recv().unwrap();
            if let TailEvent::Line(l) = recv {
                let mut expected = written[written.len() - last_n + i].to_string();
                expected.push('\n');
                assert_eq!(expected, l);
            } else {
                unreachable!("expected a line event, got {:?}", recv)
            }
        }

        assert!(rx.try_recv().is_err());

        let mut append_file = OpenOptions::new()
            .write(true)
            .append(true)
            .open(&file_path)
            .unwrap();
        writeln!(append_file, " is now complete").unwrap();

        let recv = rx.recv().await;
        if let Some(TailEvent::Line(l)) = recv {
            assert_eq!("partial line is now complete\n".to_string(), l);
        } else {
            unreachable!("expected a line event, got {:?}", recv)
        }
    }
}
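For context, a minimal consumer of `tailf` might look like this (a hypothetical caller, not part of io.rs):

// Sketch: print the last 20 lines of a log file and keep following it.
async fn follow_log(file: std::fs::File) {
    let mut rx = tailf(file, 20);
    while let Some(event) = rx.recv().await {
        match event {
            TailEvent::Line(l) => print!("{}", l), // lines keep their trailing '\n'
            TailEvent::Reset => println!("-- file was truncated, restarting --"),
            TailEvent::Err(e) => {
                eprintln!("tail failed: {}", e);
                return;
            }
        }
    }
}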
30  cli/src/util/is_integrated.rs  Normal file
@@ -0,0 +1,30 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{env, io};

/// Gets whether the current CLI seems like it's running in integrated mode,
/// by looking at the location of the exe and known VS Code files.
pub fn is_integrated_cli() -> io::Result<bool> {
    let exe = env::current_exe()?;

    let parent = match exe.parent() {
        Some(parent) if parent.file_name().and_then(|n| n.to_str()) == Some("bin") => parent,
        _ => return Ok(false),
    };

    let parent = match parent.parent() {
        Some(p) => p,
        None => return Ok(false),
    };

    let expected_file = if cfg!(target_os = "macos") {
        "node_modules.asar"
    } else {
        "resources.pak"
    };

    Ok(parent.join(expected_file).exists())
}
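A hedged usage sketch (the mode names are illustrative): callers would typically branch on the result, treating detection errors as "standalone" rather than failing hard.

// Sketch: fall back to standalone mode if detection errs.
fn detect_mode() -> &'static str {
    if is_integrated_cli().unwrap_or(false) {
        "integrated" // the exe sits in a bin/ folder next to known editor files
    } else {
        "standalone"
    }
}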
61  cli/src/util/machine.rs  Normal file
@@ -0,0 +1,61 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{path::Path, time::Duration};
use sysinfo::{Pid, PidExt, ProcessExt, System, SystemExt};

pub fn process_at_path_exists(pid: u32, name: &Path) -> bool {
    let mut sys = System::new();
    let pid = Pid::from_u32(pid);
    if !sys.refresh_process(pid) {
        return false;
    }

    let name_str = format!("{}", name.display());
    if let Some(process) = sys.process(pid) {
        for cmd in process.cmd() {
            if cmd.contains(&name_str) {
                return true;
            }
        }
    }

    false
}

pub fn process_exists(pid: u32) -> bool {
    let mut sys = System::new();
    sys.refresh_process(Pid::from_u32(pid))
}

pub async fn wait_until_process_exits(pid: Pid, poll_ms: u64) {
    let mut s = System::new();
    let duration = Duration::from_millis(poll_ms);
    while s.refresh_process(pid) {
        tokio::time::sleep(duration).await;
    }
}

pub fn find_running_process(name: &Path) -> Option<u32> {
    let mut sys = System::new();
    sys.refresh_processes();

    let name_str = format!("{}", name.display());

    for (pid, process) in sys.processes() {
        for cmd in process.cmd() {
            if cmd.contains(&name_str) {
                return Some(pid.as_u32());
            }
        }
    }
    None
}

pub async fn wait_until_exe_deleted(current_exe: &Path, poll_ms: u64) {
    let duration = Duration::from_millis(poll_ms);
    while current_exe.exists() {
        tokio::time::sleep(duration).await;
    }
}
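As a usage sketch (a hypothetical shutdown hook; `parent_pid` would come from CLI arguments, and the 2-second poll interval is an assumption):

use sysinfo::{Pid, PidExt};

// Sketch: exit once the process that spawned us goes away.
async fn exit_with_parent(parent_pid: u32) {
    if process_exists(parent_pid) {
        wait_until_process_exits(Pid::from_u32(parent_pid), 2000).await;
    }
    std::process::exit(0);
}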
39  cli/src/util/os.rs  Normal file
@@ -0,0 +1,39 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

#[cfg(windows)]
pub fn os_release() -> Result<String, std::io::Error> {
    // The Windows API *had* nice GetVersionEx/A APIs, but these were deprecated
    // in Windows 8 and there's no newer win API to get version numbers. So
    // instead read the registry.

    use winreg::{enums::HKEY_LOCAL_MACHINE, RegKey};

    let key = RegKey::predef(HKEY_LOCAL_MACHINE)
        .open_subkey(r"SOFTWARE\Microsoft\Windows NT\CurrentVersion")?;

    let major: u32 = key.get_value("CurrentMajorVersionNumber")?;
    let minor: u32 = key.get_value("CurrentMinorVersionNumber")?;
    let build: String = key.get_value("CurrentBuild")?;

    Ok(format!("{}.{}.{}", major, minor, build))
}

#[cfg(unix)]
pub fn os_release() -> Result<String, std::io::Error> {
    use std::{ffi::CStr, mem};

    unsafe {
        let mut ret = mem::MaybeUninit::zeroed();

        if libc::uname(ret.as_mut_ptr()) != 0 {
            return Err(std::io::Error::last_os_error());
        }

        let ret = ret.assume_init();
        let c_str: &CStr = CStr::from_ptr(ret.release.as_ptr());
        Ok(c_str.to_string_lossy().into_owned())
    }
}
349  cli/src/util/prereqs.rs  Normal file
@@ -0,0 +1,349 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use std::cmp::Ordering;

use super::command::capture_command;
use crate::constants::QUALITYLESS_SERVER_NAME;
use crate::update_service::Platform;
use lazy_static::lazy_static;
use regex::bytes::Regex as BinRegex;
use regex::Regex;
use tokio::fs;

use super::errors::CodeError;

lazy_static! {
    static ref LDCONFIG_STDC_RE: Regex = Regex::new(r"libstdc\+\+.* => (.+)").unwrap();
    static ref LDD_VERSION_RE: BinRegex = BinRegex::new(r"^ldd.*(.+)\.(.+)\s").unwrap();
    static ref GENERIC_VERSION_RE: Regex = Regex::new(r"^([0-9]+)\.([0-9]+)$").unwrap();
    static ref LIBSTD_CXX_VERSION_RE: BinRegex =
        BinRegex::new(r"GLIBCXX_([0-9]+)\.([0-9]+)(?:\.([0-9]+))?").unwrap();
    static ref MIN_CXX_VERSION: SimpleSemver = SimpleSemver::new(3, 4, 18);
    static ref MIN_LDD_VERSION: SimpleSemver = SimpleSemver::new(2, 17, 0);
}

const NIXOS_TEST_PATH: &str = "/etc/NIXOS";

pub struct PreReqChecker {}

impl Default for PreReqChecker {
    fn default() -> Self {
        Self::new()
    }
}

impl PreReqChecker {
    pub fn new() -> PreReqChecker {
        PreReqChecker {}
    }

    #[cfg(not(target_os = "linux"))]
    pub async fn verify(&self) -> Result<Platform, CodeError> {
        Platform::env_default().ok_or_else(|| {
            CodeError::UnsupportedPlatform(format!(
                "{} {}",
                std::env::consts::OS,
                std::env::consts::ARCH
            ))
        })
    }

    #[cfg(target_os = "linux")]
    pub async fn verify(&self) -> Result<Platform, CodeError> {
        let (is_nixos, gnu_a, gnu_b, or_musl) = tokio::join!(
            check_is_nixos(),
            check_glibc_version(),
            check_glibcxx_version(),
            check_musl_interpreter()
        );

        if (gnu_a.is_ok() && gnu_b.is_ok()) || is_nixos {
            return Ok(if cfg!(target_arch = "x86_64") {
                Platform::LinuxX64
            } else if cfg!(target_arch = "arm") {
                Platform::LinuxARM32
            } else {
                Platform::LinuxARM64
            });
        }

        if or_musl.is_ok() {
            return Ok(if cfg!(target_arch = "x86_64") {
                Platform::LinuxAlpineX64
            } else {
                Platform::LinuxAlpineARM64
            });
        }

        let mut errors: Vec<String> = vec![];
        if let Err(e) = gnu_a {
            errors.push(e);
        } else if let Err(e) = gnu_b {
            errors.push(e);
        }

        if let Err(e) = or_musl {
            errors.push(e);
        }

        let bullets = errors
            .iter()
            .map(|e| format!(" - {}", e))
            .collect::<Vec<String>>()
            .join("\n");

        Err(CodeError::PrerequisitesFailed {
            bullets,
            name: QUALITYLESS_SERVER_NAME,
        })
    }
}

#[allow(dead_code)]
async fn check_musl_interpreter() -> Result<(), String> {
    const MUSL_PATH: &str = if cfg!(target_arch = "aarch64") {
        "/lib/ld-musl-aarch64.so.1"
    } else {
        "/lib/ld-musl-x86_64.so.1"
    };

    if fs::metadata(MUSL_PATH).await.is_err() {
        return Err(format!(
            "find {}, which is required to run the {} in musl environments",
            MUSL_PATH, QUALITYLESS_SERVER_NAME
        ));
    }

    Ok(())
}

#[allow(dead_code)]
async fn check_glibc_version() -> Result<(), String> {
    #[cfg(target_env = "gnu")]
    let version = {
        let v = unsafe { libc::gnu_get_libc_version() };
        let v = unsafe { std::ffi::CStr::from_ptr(v) };
        let v = v.to_str().unwrap();
        extract_generic_version(v)
    };
    #[cfg(not(target_env = "gnu"))]
    let version = {
        capture_command("ldd", ["--version"])
            .await
            .ok()
            .and_then(|o| extract_ldd_version(&o.stdout))
    };

    if let Some(v) = version {
        return if v >= *MIN_LDD_VERSION {
            Ok(())
        } else {
            Err(format!(
                "find GLIBC >= 2.17 (but found {} instead) for GNU environments",
                v
            ))
        };
    }

    Ok(())
}

/// Check for nixos to avoid mandating glibc versions. See:
/// https://github.com/microsoft/vscode-remote-release/issues/7129
#[allow(dead_code)]
async fn check_is_nixos() -> bool {
    fs::metadata(NIXOS_TEST_PATH).await.is_ok()
}

#[allow(dead_code)]
async fn check_glibcxx_version() -> Result<(), String> {
    let mut libstdc_path: Option<String> = None;

    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    const DEFAULT_LIB_PATH: &str = "/usr/lib64/libstdc++.so.6";
    #[cfg(any(target_arch = "x86", target_arch = "arm"))]
    const DEFAULT_LIB_PATH: &str = "/usr/lib/libstdc++.so.6";
    const LDCONFIG_PATH: &str = "/sbin/ldconfig";

    if fs::metadata(DEFAULT_LIB_PATH).await.is_ok() {
        libstdc_path = Some(DEFAULT_LIB_PATH.to_owned());
    } else if fs::metadata(LDCONFIG_PATH).await.is_ok() {
        libstdc_path = capture_command(LDCONFIG_PATH, ["-p"])
            .await
            .ok()
            .and_then(|o| extract_libstd_from_ldconfig(&o.stdout));
    }

    match libstdc_path {
        Some(path) => match fs::read(&path).await {
            Ok(contents) => check_for_sufficient_glibcxx_versions(contents),
            Err(e) => Err(format!(
                "validate GLIBCXX version for GNU environments, but could not: {}",
                e
            )),
        },
        None => Err("find libstdc++.so or ldconfig for GNU environments".to_owned()),
    }
}

#[allow(dead_code)]
fn check_for_sufficient_glibcxx_versions(contents: Vec<u8>) -> Result<(), String> {
    let all_versions: Vec<SimpleSemver> = LIBSTD_CXX_VERSION_RE
        .captures_iter(&contents)
        .map(|m| SimpleSemver {
            major: m.get(1).map_or(0, |s| u32_from_bytes(s.as_bytes())),
            minor: m.get(2).map_or(0, |s| u32_from_bytes(s.as_bytes())),
            patch: m.get(3).map_or(0, |s| u32_from_bytes(s.as_bytes())),
        })
        .collect();

    // libstdc++ advertises every GLIBCXX version it supports, so the check
    // passes if any advertised version is at least the minimum.
    if !all_versions.iter().any(|v| v >= &*MIN_CXX_VERSION) {
        return Err(format!(
            "find GLIBCXX >= 3.4.18 (but found {} instead) for GNU environments",
            all_versions
                .iter()
                .map(String::from)
                .collect::<Vec<String>>()
                .join(", ")
        ));
    }

    Ok(())
}

#[allow(dead_code)]
fn extract_ldd_version(output: &[u8]) -> Option<SimpleSemver> {
    LDD_VERSION_RE.captures(output).map(|m| SimpleSemver {
        major: m.get(1).map_or(0, |s| u32_from_bytes(s.as_bytes())),
        minor: m.get(2).map_or(0, |s| u32_from_bytes(s.as_bytes())),
        patch: 0,
    })
}

#[allow(dead_code)]
fn extract_generic_version(output: &str) -> Option<SimpleSemver> {
    GENERIC_VERSION_RE.captures(output).map(|m| SimpleSemver {
        major: m.get(1).map_or(0, |s| s.as_str().parse().unwrap()),
        minor: m.get(2).map_or(0, |s| s.as_str().parse().unwrap()),
        patch: 0,
    })
}

fn extract_libstd_from_ldconfig(output: &[u8]) -> Option<String> {
    String::from_utf8_lossy(output)
        .lines()
        .find_map(|l| LDCONFIG_STDC_RE.captures(l))
        .and_then(|cap| cap.get(1))
        .map(|cap| cap.as_str().to_owned())
}

fn u32_from_bytes(b: &[u8]) -> u32 {
    String::from_utf8_lossy(b).parse::<u32>().unwrap_or(0)
}

#[derive(Debug, Default, PartialEq, Eq)]
struct SimpleSemver {
    major: u32,
    minor: u32,
    patch: u32,
}

impl PartialOrd for SimpleSemver {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SimpleSemver {
    fn cmp(&self, other: &Self) -> Ordering {
        let major = self.major.cmp(&other.major);
        if major != Ordering::Equal {
            return major;
        }

        let minor = self.minor.cmp(&other.minor);
        if minor != Ordering::Equal {
            return minor;
        }

        self.patch.cmp(&other.patch)
    }
}

impl From<&SimpleSemver> for String {
    fn from(s: &SimpleSemver) -> Self {
        format!("v{}.{}.{}", s.major, s.minor, s.patch)
    }
}

impl std::fmt::Display for SimpleSemver {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}", String::from(self))
    }
}

#[allow(dead_code)]
impl SimpleSemver {
    fn new(major: u32, minor: u32, patch: u32) -> SimpleSemver {
        SimpleSemver {
            major,
            minor,
            patch,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_libstd_from_ldconfig() {
        let actual = "
            libstoken.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libstoken.so.1
            libstemmer.so.0d (libc6,x86-64) => /lib/x86_64-linux-gnu/libstemmer.so.0d
            libstdc++.so.6 (libc6,x86-64) => /lib/x86_64-linux-gnu/libstdc++.so.6
            libstartup-notification-1.so.0 (libc6,x86-64) => /lib/x86_64-linux-gnu/libstartup-notification-1.so.0
            libssl3.so (libc6,x86-64) => /lib/x86_64-linux-gnu/libssl3.so
        ".to_owned().into_bytes();

        assert_eq!(
            extract_libstd_from_ldconfig(&actual),
            Some("/lib/x86_64-linux-gnu/libstdc++.so.6".to_owned()),
        );

        assert_eq!(
            extract_libstd_from_ldconfig(&"nothing here!".to_owned().into_bytes()),
            None,
        );
    }

    #[test]
    fn test_gte() {
        assert!(SimpleSemver::new(1, 2, 3) >= SimpleSemver::new(1, 2, 3));
        assert!(SimpleSemver::new(1, 2, 3) >= SimpleSemver::new(0, 10, 10));
        assert!(SimpleSemver::new(1, 2, 3) >= SimpleSemver::new(1, 1, 10));

        assert!(SimpleSemver::new(1, 2, 3) < SimpleSemver::new(1, 2, 10));
        assert!(SimpleSemver::new(1, 2, 3) < SimpleSemver::new(1, 3, 1));
        assert!(SimpleSemver::new(1, 2, 3) < SimpleSemver::new(2, 2, 1));
    }

    #[test]
    fn test_extract_ldd_version() {
        let actual = "ldd (Ubuntu GLIBC 2.31-0ubuntu9.7) 2.31
Copyright (C) 2020 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
Written by Roland McGrath and Ulrich Drepper."
            .to_owned()
            .into_bytes();

        assert_eq!(
            extract_ldd_version(&actual),
            Some(SimpleSemver::new(2, 31, 0)),
        );
    }
}
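A short usage sketch of the checker (hypothetical call site):

// Sketch: resolve the platform before downloading a server build; on
// unsupported systems this surfaces the collected bullet list of errors.
async fn resolve_platform() -> Result<Platform, CodeError> {
    PreReqChecker::new().verify().await
}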
142  cli/src/util/ring_buffer.rs  Normal file
@@ -0,0 +1,142 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

pub struct RingBuffer<T> {
    data: Vec<T>,
    i: usize,
}

impl<T> RingBuffer<T> {
    pub fn new(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
            i: 0,
        }
    }

    pub fn capacity(&self) -> usize {
        self.data.capacity()
    }

    pub fn len(&self) -> usize {
        self.data.len()
    }

    pub fn is_full(&self) -> bool {
        self.data.len() == self.data.capacity()
    }

    pub fn is_empty(&self) -> bool {
        self.data.len() == 0
    }

    pub fn push(&mut self, value: T) {
        // A zero-capacity buffer stores nothing; bail out early to avoid
        // the out-of-bounds index and modulo-by-zero below.
        if self.data.capacity() == 0 {
            return;
        }

        if self.data.len() == self.data.capacity() {
            self.data[self.i] = value;
        } else {
            self.data.push(value);
        }

        self.i = (self.i + 1) % self.data.capacity();
    }

    pub fn iter(&self) -> RingBufferIter<'_, T> {
        RingBufferIter {
            index: 0,
            buffer: self,
        }
    }
}

impl<T: Default> IntoIterator for RingBuffer<T> {
    type Item = T;
    type IntoIter = OwnedRingBufferIter<T>;

    fn into_iter(self) -> OwnedRingBufferIter<T>
    where
        T: Default,
    {
        OwnedRingBufferIter {
            index: 0,
            buffer: self,
        }
    }
}

pub struct OwnedRingBufferIter<T: Default> {
    buffer: RingBuffer<T>,
    index: usize,
}

impl<T: Default> Iterator for OwnedRingBufferIter<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index == self.buffer.len() {
            return None;
        }

        // When the buffer is full, `i` points at the oldest element, so
        // offsetting by it yields items in insertion order.
        let ii = (self.index + self.buffer.i) % self.buffer.len();
        let item = std::mem::take(&mut self.buffer.data[ii]);
        self.index += 1;
        Some(item)
    }
}

pub struct RingBufferIter<'a, T> {
    buffer: &'a RingBuffer<T>,
    index: usize,
}

impl<'a, T> Iterator for RingBufferIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index == self.buffer.len() {
            return None;
        }

        let ii = (self.index + self.buffer.i) % self.buffer.len();
        let item = &self.buffer.data[ii];
        self.index += 1;
        Some(item)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inserts() {
        let mut rb = RingBuffer::new(3);
        assert_eq!(rb.capacity(), 3);
        assert!(!rb.is_full());
        assert_eq!(rb.len(), 0);
        assert_eq!(rb.iter().copied().count(), 0);

        rb.push(1);
        assert!(!rb.is_full());
        assert_eq!(rb.len(), 1);
        assert_eq!(rb.iter().copied().collect::<Vec<i32>>(), vec![1]);

        rb.push(2);
        assert!(!rb.is_full());
        assert_eq!(rb.len(), 2);
        assert_eq!(rb.iter().copied().collect::<Vec<i32>>(), vec![1, 2]);

        rb.push(3);
        assert!(rb.is_full());
        assert_eq!(rb.len(), 3);
        assert_eq!(rb.iter().copied().collect::<Vec<i32>>(), vec![1, 2, 3]);

        rb.push(4);
        assert!(rb.is_full());
        assert_eq!(rb.len(), 3);
        assert_eq!(rb.iter().copied().collect::<Vec<i32>>(), vec![2, 3, 4]);

        assert_eq!(rb.into_iter().collect::<Vec<i32>>(), vec![2, 3, 4]);
    }
}
221  cli/src/util/sync.rs  Normal file
@@ -0,0 +1,221 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use async_trait::async_trait;
use std::{marker::PhantomData, sync::Arc};
use tokio::sync::{
    broadcast, mpsc,
    watch::{self, error::RecvError},
};

#[derive(Clone)]
pub struct Barrier<T>(watch::Receiver<Option<T>>)
where
    T: Clone;

impl<T> Barrier<T>
where
    T: Clone,
{
    /// Waits for the barrier to be opened, returning the value it was opened with.
    pub async fn wait(&mut self) -> Result<T, RecvError> {
        loop {
            self.0.changed().await?;

            if let Some(v) = self.0.borrow().clone() {
                return Ok(v);
            }
        }
    }

    /// Gets whether the barrier is currently open
    pub fn is_open(&self) -> bool {
        self.0.borrow().is_some()
    }
}

#[async_trait]
impl<T: Clone + Send + Sync> Receivable<T> for Barrier<T> {
    async fn recv_msg(&mut self) -> Option<T> {
        self.wait().await.ok()
    }
}

#[derive(Clone)]
pub struct BarrierOpener<T: Clone>(Arc<watch::Sender<Option<T>>>);

impl<T: Clone> BarrierOpener<T> {
    /// Opens the barrier.
    pub fn open(&self, value: T) {
        self.0.send_if_modified(|v| {
            if v.is_none() {
                *v = Some(value);
                true
            } else {
                false
            }
        });
    }
}

/// The Barrier starts closed, can be opened once from one side, and
/// thereafter stays permanently open. It can carry a value.
pub fn new_barrier<T>() -> (Barrier<T>, BarrierOpener<T>)
where
    T: Copy,
{
    let (closed_tx, closed_rx) = watch::channel(None);
    (Barrier(closed_rx), BarrierOpener(Arc::new(closed_tx)))
}

/// Type that can receive messages in an async way.
#[async_trait]
pub trait Receivable<T> {
    async fn recv_msg(&mut self) -> Option<T>;
}

// todo: ideally we would use an Arc in the broadcast::Receiver to avoid having
// to clone bytes everywhere; requires updating rpc consumers as well.
#[async_trait]
impl<T: Clone + Send> Receivable<T> for broadcast::Receiver<T> {
    async fn recv_msg(&mut self) -> Option<T> {
        loop {
            match self.recv().await {
                Ok(v) => return Some(v),
                Err(broadcast::error::RecvError::Lagged(_)) => continue,
                Err(broadcast::error::RecvError::Closed) => return None,
            }
        }
    }
}

#[async_trait]
impl<T: Send> Receivable<T> for mpsc::UnboundedReceiver<T> {
    async fn recv_msg(&mut self) -> Option<T> {
        self.recv().await
    }
}

#[async_trait]
impl<T: Send> Receivable<T> for () {
    async fn recv_msg(&mut self) -> Option<T> {
        futures::future::pending().await
    }
}

pub struct ConcatReceivable<T: Send, A: Receivable<T>, B: Receivable<T>> {
    left: Option<A>,
    right: B,
    _marker: PhantomData<T>,
}

impl<T: Send, A: Receivable<T>, B: Receivable<T>> ConcatReceivable<T, A, B> {
    pub fn new(left: A, right: B) -> Self {
        Self {
            left: Some(left),
            right,
            _marker: PhantomData,
        }
    }
}

#[async_trait]
impl<T: Send, A: Send + Receivable<T>, B: Send + Receivable<T>> Receivable<T>
    for ConcatReceivable<T, A, B>
{
    async fn recv_msg(&mut self) -> Option<T> {
        if let Some(left) = &mut self.left {
            match left.recv_msg().await {
                Some(v) => return Some(v),
                None => {
                    self.left = None;
                }
            }
        }

        self.right.recv_msg().await
    }
}

pub struct MergedReceivable<T: Send, A: Receivable<T>, B: Receivable<T>> {
    left: Option<A>,
    right: Option<B>,
    _marker: PhantomData<T>,
}

impl<T: Send, A: Receivable<T>, B: Receivable<T>> MergedReceivable<T, A, B> {
    pub fn new(left: A, right: B) -> Self {
        Self {
            left: Some(left),
            right: Some(right),
            _marker: PhantomData,
        }
    }
}

#[async_trait]
impl<T: Send, A: Send + Receivable<T>, B: Send + Receivable<T>> Receivable<T>
    for MergedReceivable<T, A, B>
{
    async fn recv_msg(&mut self) -> Option<T> {
        loop {
            match (&mut self.left, &mut self.right) {
                (Some(left), Some(right)) => {
                    tokio::select! {
                        left = left.recv_msg() => match left {
                            Some(v) => return Some(v),
                            None => { self.left = None; continue; },
                        },
                        right = right.recv_msg() => match right {
                            Some(v) => return Some(v),
                            None => { self.right = None; continue; },
                        },
                    }
                }
                (Some(a), None) => break a.recv_msg().await,
                (None, Some(b)) => break b.recv_msg().await,
                (None, None) => break None,
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_barrier_close_after_spawn() {
        let (mut barrier, opener) = new_barrier::<u32>();
        let (tx, rx) = tokio::sync::oneshot::channel::<u32>();

        tokio::spawn(async move {
            tx.send(barrier.wait().await.unwrap()).unwrap();
        });

        opener.open(42);

        assert_eq!(rx.await.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_barrier_close_before_spawn() {
        let (barrier, opener) = new_barrier::<u32>();
        let (tx1, rx1) = tokio::sync::oneshot::channel::<u32>();
        let (tx2, rx2) = tokio::sync::oneshot::channel::<u32>();

        opener.open(42);
        let mut b1 = barrier.clone();
        tokio::spawn(async move {
            tx1.send(b1.wait().await.unwrap()).unwrap();
        });
        let mut b2 = barrier.clone();
        tokio::spawn(async move {
            tx2.send(b2.wait().await.unwrap()).unwrap();
        });

        assert_eq!(rx1.await.unwrap(), 42);
        assert_eq!(rx2.await.unwrap(), 42);
    }
}
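The combinators are easiest to see with a sketch (a hypothetical consumer; the channels and element type are illustrative). `ConcatReceivable` is analogous, but fully drains its left input before ever polling the right one:

// Sketch: treat two independent event channels as one stream. The merged
// stream ends only once both inputs have closed.
async fn pump(a: mpsc::UnboundedReceiver<u32>, b: mpsc::UnboundedReceiver<u32>) {
    let mut merged = MergedReceivable::new(a, b);
    while let Some(v) = merged.recv_msg().await {
        println!("got {}", v);
    }
}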
105  cli/src/util/tar.rs  Normal file
@@ -0,0 +1,105 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use crate::util::errors::{wrap, WrappedError};

use flate2::read::GzDecoder;
use std::fs;
use std::io::{Seek, SeekFrom};
use std::path::{Path, PathBuf};
use tar::Archive;

use super::io::ReportCopyProgress;

fn should_skip_first_segment(file: &fs::File) -> Result<bool, WrappedError> {
    // unfortunately, we need to re-read the archive here since you cannot reuse
    // `.entries()`. But this will generally only look at one or two files, so this
    // should be acceptably speedy... If not, we could hardcode behavior for
    // different types of archives.

    let tar = GzDecoder::new(file);
    let mut archive = Archive::new(tar);
    let mut entries = archive
        .entries()
        .map_err(|e| wrap(e, "error opening archive"))?;

    let first_name = {
        let file = entries
            .next()
            .expect("expected not to have an empty archive")
            .map_err(|e| wrap(e, "error reading entry file"))?;

        let path = file.path().expect("expected to have path");

        path.iter()
            .next()
            .expect("expected to have non-empty name")
            .to_owned()
    };

    let mut had_multiple = false;
    for file in entries.flatten() {
        had_multiple = true;
        if let Ok(name) = file.path() {
            if name.iter().next() != Some(&first_name) {
                return Ok(false);
            }
        }
    }

    Ok(had_multiple) // prefix removal is invalid if there's only a single file
}

pub fn decompress_tarball<T>(
    path: &Path,
    parent_path: &Path,
    mut reporter: T,
) -> Result<(), WrappedError>
where
    T: ReportCopyProgress,
{
    let mut tar_gz = fs::File::open(path)
        .map_err(|e| wrap(e, format!("error opening file {}", path.display())))?;
    let skip_first = should_skip_first_segment(&tar_gz)?;

    // reset since skip logic read the tar already:
    tar_gz
        .seek(SeekFrom::Start(0))
        .map_err(|e| wrap(e, "error resetting seek position"))?;

    let tar = GzDecoder::new(tar_gz);
    let mut archive = Archive::new(tar);

    let results = archive
        .entries()
        .map_err(|e| wrap(e, format!("error opening archive {}", path.display())))?
        .filter_map(|e| e.ok())
        .map(|mut entry| {
            let entry_path = entry
                .path()
                .map_err(|e| wrap(e, "error reading entry path"))?;

            let path = parent_path.join(if skip_first {
                entry_path.iter().skip(1).collect::<PathBuf>()
            } else {
                entry_path.into_owned()
            });

            if let Some(p) = path.parent() {
                fs::create_dir_all(p)
                    .map_err(|e| wrap(e, format!("could not create dir for {}", p.display())))?;
            }

            entry
                .unpack(&path)
                .map_err(|e| wrap(e, format!("error unpacking {}", path.display())))?;
            Ok(path)
        })
        .collect::<Result<Vec<PathBuf>, WrappedError>>()?;

    // Tarballs don't have a way to get the number of entries ahead of time
    reporter.report_progress(results.len() as u64, results.len() as u64);

    Ok(())
}
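A usage sketch (hypothetical paths; `SilentCopyProgress` comes from cli/src/util/io.rs, and a `ProgressBarReporter` from input.rs could serve as the reporter instead):

// Sketch: unpack a downloaded server tarball into a destination directory,
// without progress output.
fn unpack(archive: &Path, dest: &Path) -> Result<(), WrappedError> {
    decompress_tarball(archive, dest, super::io::SilentCopyProgress())
}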
150  cli/src/util/zipper.rs  Normal file
@@ -0,0 +1,150 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use super::errors::{wrap, WrappedError};
use super::io::ReportCopyProgress;
use std::fs::{self, File};
use std::io;
use std::path::Path;
use std::path::PathBuf;
use zip::read::ZipFile;
use zip::{self, ZipArchive};

// Borrowed and modified from https://github.com/zip-rs/zip/blob/master/examples/extract.rs

/// Returns whether all files in the archive start with the same path segment.
/// If so, it's an indication we should skip that segment when extracting.
fn should_skip_first_segment(archive: &mut ZipArchive<File>) -> bool {
    let first_name = {
        let file = archive
            .by_index_raw(0)
            .expect("expected not to have an empty archive");

        let path = file
            .enclosed_name()
            .expect("expected to have path")
            .iter()
            .next()
            .expect("expected to have non-empty name");

        path.to_owned()
    };

    for i in 1..archive.len() {
        if let Ok(file) = archive.by_index_raw(i) {
            if let Some(name) = file.enclosed_name() {
                if name.iter().next() != Some(&first_name) {
                    return false;
                }
            }
        }
    }

    archive.len() > 1 // prefix removal is invalid if there's only a single file
}

pub fn unzip_file<T>(path: &Path, parent_path: &Path, mut reporter: T) -> Result<(), WrappedError>
where
    T: ReportCopyProgress,
{
    let file = fs::File::open(path)
        .map_err(|e| wrap(e, format!("unable to open file {}", path.display())))?;

    let mut archive = zip::ZipArchive::new(file)
        .map_err(|e| wrap(e, format!("failed to open zip archive {}", path.display())))?;

    let skip_segments_no = usize::from(should_skip_first_segment(&mut archive));
    for i in 0..archive.len() {
        reporter.report_progress(i as u64, archive.len() as u64);
        let mut file = archive
            .by_index(i)
            .map_err(|e| wrap(e, format!("could not open zip entry {}", i)))?;

        let outpath: PathBuf = match file.enclosed_name() {
            Some(path) => {
                let mut full_path = PathBuf::from(parent_path);
                full_path.push(PathBuf::from_iter(path.iter().skip(skip_segments_no)));
                full_path
            }
            None => continue,
        };

        if file.is_dir() || file.name().ends_with('/') {
            fs::create_dir_all(&outpath)
                .map_err(|e| wrap(e, format!("could not create dir for {}", outpath.display())))?;
            apply_permissions(&file, &outpath)?;
            continue;
        }

        if let Some(p) = outpath.parent() {
            fs::create_dir_all(p)
                .map_err(|e| wrap(e, format!("could not create dir for {}", outpath.display())))?;
        }

        #[cfg(unix)]
        {
            use libc::S_IFLNK;
            use std::io::Read;
            use std::os::unix::ffi::OsStringExt;

            // Entries whose mode carries the S_IFLNK file-type bits are symlinks;
            // recreate the link rather than copying the target's bytes.
            if matches!(file.unix_mode(), Some(mode) if mode & (S_IFLNK as u32) == (S_IFLNK as u32))
            {
                let mut link_to = Vec::new();
                file.read_to_end(&mut link_to).map_err(|e| {
                    wrap(
                        e,
                        format!("could not read symlink linkpath {}", outpath.display()),
                    )
                })?;

                let link_path = PathBuf::from(std::ffi::OsString::from_vec(link_to));
                std::os::unix::fs::symlink(link_path, &outpath).map_err(|e| {
                    wrap(e, format!("could not create symlink {}", outpath.display()))
                })?;
                continue;
            }
        }

        let mut outfile = fs::File::create(&outpath).map_err(|e| {
            wrap(
                e,
                format!(
                    "unable to open file to write {} (from {:?})",
                    outpath.display(),
                    file.enclosed_name().map(|p| p.to_string_lossy()),
                ),
            )
        })?;

        io::copy(&mut file, &mut outfile)
            .map_err(|e| wrap(e, format!("error copying file {}", outpath.display())))?;

        apply_permissions(&file, &outpath)?;
    }

    reporter.report_progress(archive.len() as u64, archive.len() as u64);

    Ok(())
}

#[cfg(unix)]
fn apply_permissions(file: &ZipFile, outpath: &Path) -> Result<(), WrappedError> {
    use std::os::unix::fs::PermissionsExt;

    if let Some(mode) = file.unix_mode() {
        fs::set_permissions(outpath, fs::Permissions::from_mode(mode)).map_err(|e| {
            wrap(
                e,
                format!("error setting permissions on {}", outpath.display()),
            )
        })?;
    }

    Ok(())
}

#[cfg(windows)]
fn apply_permissions(_file: &ZipFile, _outpath: &Path) -> Result<(), WrappedError> {
    Ok(())
}