diff --git a/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Channel/ChannelBase.cs b/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Channel/ChannelBase.cs
index de7886ce..d625d10e 100644
--- a/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Channel/ChannelBase.cs
+++ b/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Channel/ChannelBase.cs
@@ -3,6 +3,7 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
+using System.IO;
using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Protocol.Serializers;
@@ -33,7 +34,9 @@ namespace Microsoft.SqlTools.Hosting.Protocol.Channel
/// Starts the channel and initializes the MessageDispatcher.
///
/// The type of message protocol used by the channel.
- public void Start(MessageProtocolType messageProtocolType)
+ /// Optional stream to use for the input stream
+ /// Optional stream to use for the output stream
+ public void Start(MessageProtocolType messageProtocolType, Stream? inputStream = null, Stream? outputStream = null)
{
IMessageSerializer messageSerializer = null;
if (messageProtocolType == MessageProtocolType.LanguageServer)
@@ -45,7 +48,7 @@ namespace Microsoft.SqlTools.Hosting.Protocol.Channel
messageSerializer = new V8MessageSerializer();
}
- this.Initialize(messageSerializer);
+ this.Initialize(messageSerializer, inputStream, outputStream);
}
///
@@ -70,7 +73,9 @@ namespace Microsoft.SqlTools.Hosting.Protocol.Channel
/// assignment of the MessageReader and MessageWriter properties.
///
/// The IMessageSerializer to use for message serialization.
- protected abstract void Initialize(IMessageSerializer messageSerializer);
+ /// Optional stream to use for the input stream
+ /// Optional stream to use for the output stream
+ protected abstract void Initialize(IMessageSerializer messageSerializer, Stream? inputStream = null, Stream? outputStream = null);
///
/// A method to be implemented by subclasses to handle shutdown
diff --git a/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Channel/StdioClientChannel.cs b/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Channel/StdioClientChannel.cs
index aa4339fc..e917d678 100644
--- a/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Channel/StdioClientChannel.cs
+++ b/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Channel/StdioClientChannel.cs
@@ -50,7 +50,7 @@ namespace Microsoft.SqlTools.Hosting.Protocol.Channel
}
}
- protected override void Initialize(IMessageSerializer messageSerializer)
+ protected override void Initialize(IMessageSerializer messageSerializer, Stream? inputStream = null, Stream? outputStream = null)
{
this.serviceProcess = new Process
{
@@ -73,8 +73,8 @@ namespace Microsoft.SqlTools.Hosting.Protocol.Channel
this.ProcessId = this.serviceProcess.Id;
// Open the standard input/output streams
- this.inputStream = this.serviceProcess.StandardOutput.BaseStream;
- this.outputStream = this.serviceProcess.StandardInput.BaseStream;
+ this.inputStream = inputStream ?? this.serviceProcess.StandardOutput.BaseStream;
+ this.outputStream = outputStream ?? this.serviceProcess.StandardInput.BaseStream;
// Set up the message reader and writer
this.MessageReader =
diff --git a/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Channel/StdioServerChannel.cs b/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Channel/StdioServerChannel.cs
index de669e04..5ae744ad 100644
--- a/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Channel/StdioServerChannel.cs
+++ b/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Channel/StdioServerChannel.cs
@@ -20,7 +20,7 @@ namespace Microsoft.SqlTools.Hosting.Protocol.Channel
private Stream inputStream;
private Stream outputStream;
- protected override void Initialize(IMessageSerializer messageSerializer)
+ protected override void Initialize(IMessageSerializer messageSerializer, Stream? inputStream = null, Stream? outputStream = null)
{
#if !NanoServer
// Ensure that the console is using UTF-8 encoding
@@ -29,8 +29,8 @@ namespace Microsoft.SqlTools.Hosting.Protocol.Channel
#endif
// Open the standard input/output streams
- this.inputStream = System.Console.OpenStandardInput();
- this.outputStream = System.Console.OpenStandardOutput();
+ this.inputStream = inputStream ?? System.Console.OpenStandardInput();
+ this.outputStream = outputStream ?? System.Console.OpenStandardOutput();
// Set up the reader and writer
this.MessageReader =
diff --git a/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/ProtocolEndpoint.cs b/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/ProtocolEndpoint.cs
index f043e429..4337a4a8 100644
--- a/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/ProtocolEndpoint.cs
+++ b/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/ProtocolEndpoint.cs
@@ -6,6 +6,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
+using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Protocol.Channel;
@@ -68,12 +69,12 @@ namespace Microsoft.SqlTools.Hosting.Protocol
///
/// Initializes
///
- public void Initialize()
+ public void Initialize(Stream inputStream = null, Stream outputStream = null)
{
if (!this.isInitialized)
{
// Start the provided protocol channel
- this.protocolChannel.Start(this.messageProtocolType);
+ this.protocolChannel.Start(this.messageProtocolType, inputStream, outputStream);
// Start the message dispatcher
this.MessageDispatcher = new MessageDispatcher(this.protocolChannel);
diff --git a/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs b/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs
index a9c06729..f017bf6f 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs
@@ -40,6 +40,7 @@ using Microsoft.SqlTools.ServiceLayer.ModelManagement;
using Microsoft.SqlTools.ServiceLayer.TableDesigner;
using Microsoft.SqlTools.ServiceLayer.AzureBlob;
using Microsoft.SqlTools.ServiceLayer.ExecutionPlan;
+using System.IO;
namespace Microsoft.SqlTools.ServiceLayer
{
@@ -52,7 +53,7 @@ namespace Microsoft.SqlTools.ServiceLayer
private static object lockObject = new object();
private static bool isLoaded;
- internal static ServiceHost CreateAndStartServiceHost(SqlToolsContext sqlToolsContext)
+ internal static ServiceHost CreateAndStartServiceHost(SqlToolsContext sqlToolsContext, Stream? inputStream = null, Stream? outputStream = null)
{
ServiceHost serviceHost = ServiceHost.Instance;
lock (lockObject)
@@ -60,7 +61,7 @@ namespace Microsoft.SqlTools.ServiceLayer
if (!isLoaded)
{
// Grab the instance of the service host
- serviceHost.Initialize();
+ serviceHost.Initialize(inputStream, outputStream);
InitializeRequestHandlersAndServices(serviceHost, sqlToolsContext);
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs
index be22cbd1..33cbe8be 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs
@@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
+using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
@@ -188,8 +189,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common
var hostDetails = new HostDetails(hostName, hostProfileId, hostVersion);
SqlToolsContext sqlToolsContext = new SqlToolsContext(hostDetails);
- // Grab the instance of the service host
- ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext);
+ // Initialize the ServiceHost, using a MemoryStream for the output stream so that we don't fill up the logs
+ // with a bunch of outgoing messages (which aren't used for anything during tests)
+ ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext, null, new MemoryStream());
}
}