mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-01-27 17:24:26 -05:00
Merge pull request #17 from Microsoft/bug/exceptionLoop
Fix: Infinite Exception Loop, Plaintext Password Logging, Ever-expanding Message Buffer
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts
|
||||
{
|
||||
/// <summary>
|
||||
/// Parameters to be used for reporting hosting-level errors, such as protocol violations
|
||||
/// </summary>
|
||||
public class HostingErrorParams
|
||||
{
|
||||
/// <summary>
|
||||
/// The message of the error
|
||||
/// </summary>
|
||||
public string Message { get; set; }
|
||||
}
|
||||
|
||||
public class HostingErrorEvent
|
||||
{
|
||||
public static readonly
|
||||
EventType<HostingErrorParams> Type =
|
||||
EventType<HostingErrorParams>.Create("hosting/error");
|
||||
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.SqlTools.ServiceLayer.Hosting.Contracts;
|
||||
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel;
|
||||
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
|
||||
using Microsoft.SqlTools.EditorServices.Utility;
|
||||
@@ -198,10 +199,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
|
||||
this.SynchronizationContext = SynchronizationContext.Current;
|
||||
|
||||
// Run the message loop
|
||||
bool isRunning = true;
|
||||
while (isRunning && !cancellationToken.IsCancellationRequested)
|
||||
while (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
Message newMessage = null;
|
||||
Message newMessage;
|
||||
|
||||
try
|
||||
{
|
||||
@@ -210,12 +210,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
|
||||
}
|
||||
catch (MessageParseException e)
|
||||
{
|
||||
// TODO: Write an error response
|
||||
|
||||
Logger.Write(
|
||||
LogLevel.Error,
|
||||
"Could not parse a message that was received:\r\n\r\n" +
|
||||
e.ToString());
|
||||
string message = string.Format("Exception occurred while parsing message: {0}", e.Message);
|
||||
Logger.Write(LogLevel.Error, message);
|
||||
await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams
|
||||
{
|
||||
Message = message
|
||||
});
|
||||
|
||||
// Continue the loop
|
||||
continue;
|
||||
@@ -227,18 +227,29 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
var b = e.Message;
|
||||
newMessage = null;
|
||||
// Log the error and send an error event to the client
|
||||
string message = string.Format("Exception occurred while receiving message: {0}", e.Message);
|
||||
Logger.Write(LogLevel.Error, message);
|
||||
await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams
|
||||
{
|
||||
Message = message
|
||||
});
|
||||
|
||||
// Continue the loop
|
||||
continue;
|
||||
}
|
||||
|
||||
// The message could be null if there was an error parsing the
|
||||
// previous message. In this case, do not try to dispatch it.
|
||||
if (newMessage != null)
|
||||
{
|
||||
// Verbose logging
|
||||
string logMessage = string.Format("Received message of type[{0}] and method[{1}]",
|
||||
newMessage.MessageType, newMessage.Method);
|
||||
Logger.Write(LogLevel.Verbose, logMessage);
|
||||
|
||||
// Process the message
|
||||
await this.DispatchMessage(
|
||||
newMessage,
|
||||
this.MessageWriter);
|
||||
await this.DispatchMessage(newMessage, this.MessageWriter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,22 +25,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
|
||||
|
||||
private const int CR = 0x0D;
|
||||
private const int LF = 0x0A;
|
||||
private static string[] NewLineDelimiters = new string[] { Environment.NewLine };
|
||||
private static readonly string[] NewLineDelimiters = { Environment.NewLine };
|
||||
|
||||
private Stream inputStream;
|
||||
private IMessageSerializer messageSerializer;
|
||||
private Encoding messageEncoding;
|
||||
private readonly Stream inputStream;
|
||||
private readonly IMessageSerializer messageSerializer;
|
||||
private readonly Encoding messageEncoding;
|
||||
|
||||
private ReadState readState;
|
||||
private bool needsMoreData = true;
|
||||
private int readOffset;
|
||||
private int bufferEndOffset;
|
||||
private byte[] messageBuffer = new byte[DefaultBufferSize];
|
||||
private byte[] messageBuffer;
|
||||
|
||||
private int expectedContentLength;
|
||||
private Dictionary<string, string> messageHeaders;
|
||||
|
||||
enum ReadState
|
||||
private enum ReadState
|
||||
{
|
||||
Headers,
|
||||
Content
|
||||
@@ -85,7 +85,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
|
||||
this.needsMoreData = false;
|
||||
|
||||
// Do we need to look for message headers?
|
||||
if (this.readState == ReadState.Headers &&
|
||||
if (this.readState == ReadState.Headers &&
|
||||
!this.TryReadMessageHeaders())
|
||||
{
|
||||
// If we don't have enough data to read headers yet, keep reading
|
||||
@@ -94,7 +94,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
|
||||
}
|
||||
|
||||
// Do we need to look for message content?
|
||||
if (this.readState == ReadState.Content &&
|
||||
if (this.readState == ReadState.Content &&
|
||||
!this.TryReadMessageContent(out messageContent))
|
||||
{
|
||||
// If we don't have enough data yet to construct the content, keep reading
|
||||
@@ -106,16 +106,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
|
||||
break;
|
||||
}
|
||||
|
||||
// Now that we have a message, reset the buffer's state
|
||||
ShiftBufferBytesAndShrink(readOffset);
|
||||
|
||||
// Get the JObject for the JSON content
|
||||
JObject messageObject = JObject.Parse(messageContent);
|
||||
|
||||
// Load the message
|
||||
Logger.Write(
|
||||
LogLevel.Verbose,
|
||||
string.Format(
|
||||
"READ MESSAGE:\r\n\r\n{0}",
|
||||
messageObject.ToString(Formatting.Indented)));
|
||||
|
||||
// Return the parsed message
|
||||
return this.messageSerializer.DeserializeMessage(messageObject);
|
||||
}
|
||||
@@ -162,8 +158,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
|
||||
{
|
||||
int scanOffset = this.readOffset;
|
||||
|
||||
// Scan for the final double-newline that marks the
|
||||
// end of the header lines
|
||||
// Scan for the final double-newline that marks the end of the header lines
|
||||
while (scanOffset + 3 < this.bufferEndOffset &&
|
||||
(this.messageBuffer[scanOffset] != CR ||
|
||||
this.messageBuffer[scanOffset + 1] != LF ||
|
||||
@@ -173,45 +168,51 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
|
||||
scanOffset++;
|
||||
}
|
||||
|
||||
// No header or body separator found (e.g CRLFCRLF)
|
||||
// Make sure we haven't reached the end of the buffer without finding a separator (e.g CRLFCRLF)
|
||||
if (scanOffset + 3 >= this.bufferEndOffset)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
this.messageHeaders = new Dictionary<string, string>();
|
||||
// Convert the header block into a array of lines
|
||||
var headers = Encoding.ASCII.GetString(this.messageBuffer, this.readOffset, scanOffset)
|
||||
.Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries);
|
||||
|
||||
var headers =
|
||||
Encoding.ASCII
|
||||
.GetString(this.messageBuffer, this.readOffset, scanOffset)
|
||||
.Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries);
|
||||
|
||||
// Read each header and store it in the dictionary
|
||||
foreach (var header in headers)
|
||||
try
|
||||
{
|
||||
int currentLength = header.IndexOf(':');
|
||||
if (currentLength == -1)
|
||||
// Read each header and store it in the dictionary
|
||||
this.messageHeaders = new Dictionary<string, string>();
|
||||
foreach (var header in headers)
|
||||
{
|
||||
throw new ArgumentException("Message header must separate key and value using :");
|
||||
int currentLength = header.IndexOf(':');
|
||||
if (currentLength == -1)
|
||||
{
|
||||
throw new ArgumentException("Message header must separate key and value using :");
|
||||
}
|
||||
|
||||
var key = header.Substring(0, currentLength);
|
||||
var value = header.Substring(currentLength + 1).Trim();
|
||||
this.messageHeaders[key] = value;
|
||||
}
|
||||
|
||||
var key = header.Substring(0, currentLength);
|
||||
var value = header.Substring(currentLength + 1).Trim();
|
||||
this.messageHeaders[key] = value;
|
||||
}
|
||||
// Parse out the content length as an int
|
||||
string contentLengthString;
|
||||
if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString))
|
||||
{
|
||||
throw new MessageParseException("", "Fatal error: Content-Length header must be provided.");
|
||||
}
|
||||
|
||||
// Make sure a Content-Length header was present, otherwise it
|
||||
// is a fatal error
|
||||
string contentLengthString = null;
|
||||
if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString))
|
||||
{
|
||||
throw new MessageParseException("", "Fatal error: Content-Length header must be provided.");
|
||||
// Parse the content length to an integer
|
||||
if (!int.TryParse(contentLengthString, out this.expectedContentLength))
|
||||
{
|
||||
throw new MessageParseException("", "Fatal error: Content-Length value is not an integer.");
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the content length to an integer
|
||||
if (!int.TryParse(contentLengthString, out this.expectedContentLength))
|
||||
catch (Exception)
|
||||
{
|
||||
throw new MessageParseException("", "Fatal error: Content-Length value is not an integer.");
|
||||
// The content length was invalid or missing. Trash the buffer we've read
|
||||
ShiftBufferBytesAndShrink(scanOffset + 4);
|
||||
throw;
|
||||
}
|
||||
|
||||
// Skip past the headers plus the newline characters
|
||||
@@ -234,31 +235,40 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
|
||||
}
|
||||
|
||||
// Convert the message contents to a string using the specified encoding
|
||||
messageContent =
|
||||
this.messageEncoding.GetString(
|
||||
this.messageBuffer,
|
||||
this.readOffset,
|
||||
this.expectedContentLength);
|
||||
messageContent = this.messageEncoding.GetString(
|
||||
this.messageBuffer,
|
||||
this.readOffset,
|
||||
this.expectedContentLength);
|
||||
|
||||
// Move the remaining bytes to the front of the buffer for the next message
|
||||
var remainingByteCount = this.bufferEndOffset - (this.expectedContentLength + this.readOffset);
|
||||
Buffer.BlockCopy(
|
||||
this.messageBuffer,
|
||||
this.expectedContentLength + this.readOffset,
|
||||
this.messageBuffer,
|
||||
0,
|
||||
remainingByteCount);
|
||||
readOffset += expectedContentLength;
|
||||
|
||||
// Reset the offsets for the next read
|
||||
this.readOffset = 0;
|
||||
this.bufferEndOffset = remainingByteCount;
|
||||
|
||||
// Done reading content, now look for headers
|
||||
// Done reading content, now look for headers for the next message
|
||||
this.readState = ReadState.Headers;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private void ShiftBufferBytesAndShrink(int bytesToRemove)
|
||||
{
|
||||
// Create a new buffer that is shrunken by the number of bytes to remove
|
||||
// Note: by using Max, we can guarantee a buffer of at least default buffer size
|
||||
byte[] newBuffer = new byte[Math.Max(messageBuffer.Length - bytesToRemove, DefaultBufferSize)];
|
||||
|
||||
// If we need to do shifting, do the shifting
|
||||
if (bytesToRemove <= messageBuffer.Length)
|
||||
{
|
||||
// Copy the existing buffer starting at the offset to remove
|
||||
Buffer.BlockCopy(messageBuffer, bytesToRemove, newBuffer, 0, bufferEndOffset - bytesToRemove);
|
||||
}
|
||||
|
||||
// Make the new buffer the message buffer
|
||||
messageBuffer = newBuffer;
|
||||
|
||||
// Reset the read offset and the end offset
|
||||
readOffset = 0;
|
||||
bufferEndOffset -= bytesToRemove;
|
||||
}
|
||||
|
||||
#endregion
|
||||
}
|
||||
}
|
||||
|
||||
@@ -337,7 +337,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
|
||||
{
|
||||
Logger.Write(
|
||||
LogLevel.Error,
|
||||
String.Format(
|
||||
string.Format(
|
||||
"Exception while cancelling analysis task:\n\n{0}",
|
||||
e.ToString()));
|
||||
|
||||
|
||||
@@ -182,7 +182,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace
|
||||
foreach (var textChange in textChangeParams.ContentChanges)
|
||||
{
|
||||
string fileUri = textChangeParams.Uri ?? textChangeParams.TextDocument.Uri;
|
||||
msg.AppendLine(String.Format(" File: {0}", fileUri));
|
||||
msg.AppendLine(string.Format(" File: {0}", fileUri));
|
||||
|
||||
ScriptFile changedFile = Workspace.GetFile(fileUri);
|
||||
|
||||
|
||||
@@ -1,178 +0,0 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
|
||||
using HostingMessage = Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts.Message;
|
||||
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Test.Message
|
||||
{
|
||||
public class MessageReaderWriterTests
|
||||
{
|
||||
const string TestEventString = "{\"type\":\"event\",\"event\":\"testEvent\",\"body\":null}";
|
||||
const string TestEventFormatString = "{{\"event\":\"testEvent\",\"body\":{{\"someString\":\"{0}\"}},\"seq\":0,\"type\":\"event\"}}";
|
||||
readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString);
|
||||
|
||||
private IMessageSerializer messageSerializer;
|
||||
|
||||
public MessageReaderWriterTests()
|
||||
{
|
||||
this.messageSerializer = new V8MessageSerializer();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WritesMessage()
|
||||
{
|
||||
MemoryStream outputStream = new MemoryStream();
|
||||
|
||||
MessageWriter messageWriter =
|
||||
new MessageWriter(
|
||||
outputStream,
|
||||
this.messageSerializer);
|
||||
|
||||
// Write the message and then roll back the stream to be read
|
||||
// TODO: This will need to be redone!
|
||||
await messageWriter.WriteMessage(HostingMessage.Event("testEvent", null));
|
||||
outputStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
string expectedHeaderString =
|
||||
string.Format(
|
||||
Constants.ContentLengthFormatString,
|
||||
ExpectedMessageByteCount);
|
||||
|
||||
byte[] buffer = new byte[128];
|
||||
await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length);
|
||||
|
||||
Assert.Equal(
|
||||
expectedHeaderString,
|
||||
Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length));
|
||||
|
||||
// Read the message
|
||||
await outputStream.ReadAsync(buffer, 0, ExpectedMessageByteCount);
|
||||
|
||||
Assert.Equal(
|
||||
TestEventString,
|
||||
Encoding.UTF8.GetString(buffer, 0, ExpectedMessageByteCount));
|
||||
|
||||
outputStream.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadsMessage()
|
||||
{
|
||||
MemoryStream inputStream = new MemoryStream();
|
||||
MessageReader messageReader =
|
||||
new MessageReader(
|
||||
inputStream,
|
||||
this.messageSerializer);
|
||||
|
||||
// Write a message to the stream
|
||||
byte[] messageBuffer = this.GetMessageBytes(TestEventString);
|
||||
inputStream.Write(
|
||||
this.GetMessageBytes(TestEventString),
|
||||
0,
|
||||
messageBuffer.Length);
|
||||
|
||||
inputStream.Flush();
|
||||
inputStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
HostingMessage messageResult = messageReader.ReadMessage().Result;
|
||||
Assert.Equal("testEvent", messageResult.Method);
|
||||
|
||||
inputStream.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadsManyBufferedMessages()
|
||||
{
|
||||
MemoryStream inputStream = new MemoryStream();
|
||||
MessageReader messageReader =
|
||||
new MessageReader(
|
||||
inputStream,
|
||||
this.messageSerializer);
|
||||
|
||||
// Get a message to use for writing to the stream
|
||||
byte[] messageBuffer = this.GetMessageBytes(TestEventString);
|
||||
|
||||
// How many messages of this size should we write to overflow the buffer?
|
||||
int overflowMessageCount =
|
||||
(int)Math.Ceiling(
|
||||
(MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length);
|
||||
|
||||
// Write the necessary number of messages to the stream
|
||||
for (int i = 0; i < overflowMessageCount; i++)
|
||||
{
|
||||
inputStream.Write(messageBuffer, 0, messageBuffer.Length);
|
||||
}
|
||||
|
||||
inputStream.Flush();
|
||||
inputStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
// Read the written messages from the stream
|
||||
for (int i = 0; i < overflowMessageCount; i++)
|
||||
{
|
||||
HostingMessage messageResult = messageReader.ReadMessage().Result;
|
||||
Assert.Equal("testEvent", messageResult.Method);
|
||||
}
|
||||
|
||||
inputStream.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReaderResizesBufferForLargeMessages()
|
||||
{
|
||||
MemoryStream inputStream = new MemoryStream();
|
||||
MessageReader messageReader =
|
||||
new MessageReader(
|
||||
inputStream,
|
||||
this.messageSerializer);
|
||||
|
||||
// Get a message with content so large that the buffer will need
|
||||
// to be resized to fit it all.
|
||||
byte[] messageBuffer =
|
||||
this.GetMessageBytes(
|
||||
string.Format(
|
||||
TestEventFormatString,
|
||||
new String('X', (int)(MessageReader.DefaultBufferSize * 3))));
|
||||
|
||||
inputStream.Write(messageBuffer, 0, messageBuffer.Length);
|
||||
inputStream.Flush();
|
||||
inputStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
HostingMessage messageResult = messageReader.ReadMessage().Result;
|
||||
Assert.Equal("testEvent", messageResult.Method);
|
||||
|
||||
inputStream.Dispose();
|
||||
}
|
||||
|
||||
private byte[] GetMessageBytes(string messageString, Encoding encoding = null)
|
||||
{
|
||||
if (encoding == null)
|
||||
{
|
||||
encoding = Encoding.UTF8;
|
||||
}
|
||||
|
||||
byte[] messageBytes = Encoding.UTF8.GetBytes(messageString);
|
||||
byte[] headerBytes =
|
||||
Encoding.ASCII.GetBytes(
|
||||
string.Format(
|
||||
Constants.ContentLengthFormatString,
|
||||
messageBytes.Length));
|
||||
|
||||
// Copy the bytes into a single buffer
|
||||
byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length];
|
||||
Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length);
|
||||
Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length);
|
||||
|
||||
return finalBytes;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
using System.Text;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging
|
||||
{
|
||||
public class Common
|
||||
{
|
||||
public const string TestEventString = @"{""type"":""event"",""event"":""testEvent"",""body"":null}";
|
||||
public const string TestEventFormatString = @"{{""event"":""testEvent"",""body"":{{""someString"":""{0}""}},""seq"":0,""type"":""event""}}";
|
||||
public static readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString);
|
||||
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Text;
|
||||
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
|
||||
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
|
||||
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers;
|
||||
using Newtonsoft.Json;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging
|
||||
{
|
||||
public class MessageReaderTests
|
||||
{
|
||||
|
||||
private readonly IMessageSerializer messageSerializer;
|
||||
|
||||
public MessageReaderTests()
|
||||
{
|
||||
this.messageSerializer = new V8MessageSerializer();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadsMessage()
|
||||
{
|
||||
MemoryStream inputStream = new MemoryStream();
|
||||
MessageReader messageReader = new MessageReader(inputStream, this.messageSerializer);
|
||||
|
||||
// Write a message to the stream
|
||||
byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString);
|
||||
inputStream.Write(this.GetMessageBytes(Common.TestEventString), 0, messageBuffer.Length);
|
||||
|
||||
inputStream.Flush();
|
||||
inputStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
Message messageResult = messageReader.ReadMessage().Result;
|
||||
Assert.Equal("testEvent", messageResult.Method);
|
||||
|
||||
inputStream.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadsManyBufferedMessages()
|
||||
{
|
||||
MemoryStream inputStream = new MemoryStream();
|
||||
MessageReader messageReader =
|
||||
new MessageReader(
|
||||
inputStream,
|
||||
this.messageSerializer);
|
||||
|
||||
// Get a message to use for writing to the stream
|
||||
byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString);
|
||||
|
||||
// How many messages of this size should we write to overflow the buffer?
|
||||
int overflowMessageCount =
|
||||
(int)Math.Ceiling(
|
||||
(MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length);
|
||||
|
||||
// Write the necessary number of messages to the stream
|
||||
for (int i = 0; i < overflowMessageCount; i++)
|
||||
{
|
||||
inputStream.Write(messageBuffer, 0, messageBuffer.Length);
|
||||
}
|
||||
|
||||
inputStream.Flush();
|
||||
inputStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
// Read the written messages from the stream
|
||||
for (int i = 0; i < overflowMessageCount; i++)
|
||||
{
|
||||
Message messageResult = messageReader.ReadMessage().Result;
|
||||
Assert.Equal("testEvent", messageResult.Method);
|
||||
}
|
||||
|
||||
inputStream.Dispose();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadMalformedMissingHeaderTest()
|
||||
{
|
||||
using (MemoryStream inputStream = new MemoryStream())
|
||||
{
|
||||
// If:
|
||||
// ... I create a new stream and pass it information that is malformed
|
||||
// ... and attempt to read a message from it
|
||||
MessageReader messageReader = new MessageReader(inputStream, messageSerializer);
|
||||
byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n");
|
||||
inputStream.Write(messageBuffer, 0, messageBuffer.Length);
|
||||
inputStream.Flush();
|
||||
inputStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
// Then:
|
||||
// ... An exception should be thrown while reading
|
||||
Assert.ThrowsAsync<ArgumentException>(() => messageReader.ReadMessage()).Wait();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadMalformedContentLengthNonIntegerTest()
|
||||
{
|
||||
using (MemoryStream inputStream = new MemoryStream())
|
||||
{
|
||||
// If:
|
||||
// ... I create a new stream and pass it a non-integer content-length header
|
||||
// ... and attempt to read a message from it
|
||||
MessageReader messageReader = new MessageReader(inputStream, messageSerializer);
|
||||
byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: asdf\r\n\r\n");
|
||||
inputStream.Write(messageBuffer, 0, messageBuffer.Length);
|
||||
inputStream.Flush();
|
||||
inputStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
// Then:
|
||||
// ... An exception should be thrown while reading
|
||||
Assert.ThrowsAsync<MessageParseException>(() => messageReader.ReadMessage()).Wait();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadMissingContentLengthHeaderTest()
|
||||
{
|
||||
using (MemoryStream inputStream = new MemoryStream())
|
||||
{
|
||||
// If:
|
||||
// ... I create a new stream and pass it a a message without a content-length header
|
||||
// ... and attempt to read a message from it
|
||||
MessageReader messageReader = new MessageReader(inputStream, messageSerializer);
|
||||
byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Type: asdf\r\n\r\n");
|
||||
inputStream.Write(messageBuffer, 0, messageBuffer.Length);
|
||||
inputStream.Flush();
|
||||
inputStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
// Then:
|
||||
// ... An exception should be thrown while reading
|
||||
Assert.ThrowsAsync<MessageParseException>(() => messageReader.ReadMessage()).Wait();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadMalformedContentLengthTooShortTest()
|
||||
{
|
||||
using (MemoryStream inputStream = new MemoryStream())
|
||||
{
|
||||
// If:
|
||||
// ... Pass in an event that has an incorrect content length
|
||||
// ... And pass in an event that is correct
|
||||
MessageReader messageReader = new MessageReader(inputStream, messageSerializer);
|
||||
byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: 10\r\n\r\n");
|
||||
inputStream.Write(messageBuffer, 0, messageBuffer.Length);
|
||||
messageBuffer = Encoding.UTF8.GetBytes(Common.TestEventString);
|
||||
inputStream.Write(messageBuffer, 0, messageBuffer.Length);
|
||||
messageBuffer = Encoding.ASCII.GetBytes("\r\n\r\n");
|
||||
inputStream.Write(messageBuffer, 0, messageBuffer.Length);
|
||||
inputStream.Flush();
|
||||
inputStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
// Then:
|
||||
// ... The first read should fail with an exception while deserializing
|
||||
Assert.ThrowsAsync<JsonReaderException>(() => messageReader.ReadMessage()).Wait();
|
||||
|
||||
// ... The second read should fail with an exception while reading headers
|
||||
Assert.ThrowsAsync<MessageParseException>(() => messageReader.ReadMessage()).Wait();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReadMalformedThenValidTest()
|
||||
{
|
||||
// If:
|
||||
// ... I create a new stream and pass it information that is malformed
|
||||
// ... and attempt to read a message from it
|
||||
// ... Then pass it information that is valid and attempt to read a message from it
|
||||
using (MemoryStream inputStream = new MemoryStream())
|
||||
{
|
||||
MessageReader messageReader = new MessageReader(inputStream, messageSerializer);
|
||||
byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n");
|
||||
inputStream.Write(messageBuffer, 0, messageBuffer.Length);
|
||||
messageBuffer = GetMessageBytes(Common.TestEventString);
|
||||
inputStream.Write(messageBuffer, 0, messageBuffer.Length);
|
||||
inputStream.Flush();
|
||||
inputStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
// Then:
|
||||
// ... An exception should be thrown while reading the first one
|
||||
Assert.ThrowsAsync<ArgumentException>(() => messageReader.ReadMessage()).Wait();
|
||||
|
||||
// ... A test event should be successfully read from the second one
|
||||
Message messageResult = messageReader.ReadMessage().Result;
|
||||
Assert.NotNull(messageResult);
|
||||
Assert.Equal("testEvent", messageResult.Method);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ReaderResizesBufferForLargeMessages()
|
||||
{
|
||||
MemoryStream inputStream = new MemoryStream();
|
||||
MessageReader messageReader =
|
||||
new MessageReader(
|
||||
inputStream,
|
||||
this.messageSerializer);
|
||||
|
||||
// Get a message with content so large that the buffer will need
|
||||
// to be resized to fit it all.
|
||||
byte[] messageBuffer = this.GetMessageBytes(
|
||||
string.Format(
|
||||
Common.TestEventFormatString,
|
||||
new String('X', (int) (MessageReader.DefaultBufferSize*3))));
|
||||
|
||||
inputStream.Write(messageBuffer, 0, messageBuffer.Length);
|
||||
inputStream.Flush();
|
||||
inputStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
Message messageResult = messageReader.ReadMessage().Result;
|
||||
Assert.Equal("testEvent", messageResult.Method);
|
||||
|
||||
inputStream.Dispose();
|
||||
}
|
||||
|
||||
private byte[] GetMessageBytes(string messageString, Encoding encoding = null)
|
||||
{
|
||||
if (encoding == null)
|
||||
{
|
||||
encoding = Encoding.UTF8;
|
||||
}
|
||||
|
||||
byte[] messageBytes = Encoding.UTF8.GetBytes(messageString);
|
||||
byte[] headerBytes = Encoding.ASCII.GetBytes(string.Format(Constants.ContentLengthFormatString, messageBytes.Length));
|
||||
|
||||
// Copy the bytes into a single buffer
|
||||
byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length];
|
||||
Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length);
|
||||
Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length);
|
||||
|
||||
return finalBytes;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
using System.IO;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
|
||||
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging
|
||||
{
|
||||
public class MessageWriterTests
|
||||
{
|
||||
private readonly IMessageSerializer messageSerializer;
|
||||
|
||||
public MessageWriterTests()
|
||||
{
|
||||
this.messageSerializer = new V8MessageSerializer();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WritesMessage()
|
||||
{
|
||||
MemoryStream outputStream = new MemoryStream();
|
||||
MessageWriter messageWriter = new MessageWriter(outputStream, this.messageSerializer);
|
||||
|
||||
// Write the message and then roll back the stream to be read
|
||||
// TODO: This will need to be redone!
|
||||
await messageWriter.WriteMessage(Hosting.Protocol.Contracts.Message.Event("testEvent", null));
|
||||
outputStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
string expectedHeaderString = string.Format(Constants.ContentLengthFormatString,
|
||||
Common.ExpectedMessageByteCount);
|
||||
|
||||
byte[] buffer = new byte[128];
|
||||
await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length);
|
||||
|
||||
Assert.Equal(
|
||||
expectedHeaderString,
|
||||
Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length));
|
||||
|
||||
// Read the message
|
||||
await outputStream.ReadAsync(buffer, 0, Common.ExpectedMessageByteCount);
|
||||
|
||||
Assert.Equal(Common.TestEventString,
|
||||
Encoding.UTF8.GetString(buffer, 0, Common.ExpectedMessageByteCount));
|
||||
|
||||
outputStream.Dispose();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Test.Message
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging
|
||||
{
|
||||
#region Request Types
|
||||
|
||||
@@ -37,7 +37,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
|
||||
Dictionary<string, string> rowDictionary = new Dictionary<string, string>();
|
||||
for (int column = 0; column < columns; column++)
|
||||
{
|
||||
rowDictionary.Add(String.Format("column{0}", column), String.Format("val{0}{1}", column, row));
|
||||
rowDictionary.Add(string.Format("column{0}", column), string.Format("val{0}{1}", column, row));
|
||||
}
|
||||
output[row] = rowDictionary;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user