Fixing infinite exception loop bug

Fixing issue where submitting a malformed JSON RPC request results in the
message reader entering into an infinite loop of throwing and catching
exceptions without reading anything from the input stream.

At the same time, this change also fixes a potential memory leak where the
message read buffer is never reinstantiated or shrunk. This issue is fixed
by shifting buffer contents after a message was read successfully, or if
an error occurs during parsing.
This commit is contained in:
Benjamin Russell
2016-08-12 15:37:16 -07:00
parent a2983539a7
commit fe79f6e85c
2 changed files with 91 additions and 68 deletions

View File

@@ -8,6 +8,7 @@ using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Hosting.Contracts;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.EditorServices.Utility;
@@ -198,10 +199,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
this.SynchronizationContext = SynchronizationContext.Current;
// Run the message loop
bool isRunning = true;
while (isRunning && !cancellationToken.IsCancellationRequested)
while (!cancellationToken.IsCancellationRequested)
{
Message newMessage = null;
Message newMessage;
try
{
@@ -210,12 +210,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
}
catch (MessageParseException e)
{
// TODO: Write an error response
Logger.Write(
LogLevel.Error,
"Could not parse a message that was received:\r\n\r\n" +
e.ToString());
string message = String.Format("Exception occurred while parsing message: {0}", e.Message);
Logger.Write(LogLevel.Error, message);
await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams
{
Message = message
});
// Continue the loop
continue;
@@ -227,8 +227,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
}
catch (Exception e)
{
var b = e.Message;
newMessage = null;
// Log the error and send an error event to the client
string message = String.Format("Exception occurred while receiving message: {0}", e.Message);
Logger.Write(LogLevel.Error, message);
await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams
{
Message = message
});
// Continue the loop
continue;
}
// The message could be null if there was an error parsing the
@@ -236,9 +244,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
if (newMessage != null)
{
// Process the message
await this.DispatchMessage(
newMessage,
this.MessageWriter);
await this.DispatchMessage(newMessage, this.MessageWriter);
}
}
}

View File

@@ -25,22 +25,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
private const int CR = 0x0D;
private const int LF = 0x0A;
private static string[] NewLineDelimiters = new string[] { Environment.NewLine };
private static readonly string[] NewLineDelimiters = { Environment.NewLine };
private Stream inputStream;
private IMessageSerializer messageSerializer;
private Encoding messageEncoding;
private readonly Stream inputStream;
private readonly IMessageSerializer messageSerializer;
private readonly Encoding messageEncoding;
private ReadState readState;
private bool needsMoreData = true;
private int readOffset;
private int bufferEndOffset;
private byte[] messageBuffer = new byte[DefaultBufferSize];
private byte[] messageBuffer;
private int expectedContentLength;
private Dictionary<string, string> messageHeaders;
enum ReadState
private enum ReadState
{
Headers,
Content
@@ -85,7 +85,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
this.needsMoreData = false;
// Do we need to look for message headers?
if (this.readState == ReadState.Headers &&
if (this.readState == ReadState.Headers &&
!this.TryReadMessageHeaders())
{
// If we don't have enough data to read headers yet, keep reading
@@ -94,7 +94,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
}
// Do we need to look for message content?
if (this.readState == ReadState.Content &&
if (this.readState == ReadState.Content &&
!this.TryReadMessageContent(out messageContent))
{
// If we don't have enough data yet to construct the content, keep reading
@@ -106,6 +106,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
break;
}
// Now that we have a message, reset the buffer's state
ShiftBufferBytesAndShrink(readOffset);
// Get the JObject for the JSON content
JObject messageObject = JObject.Parse(messageContent);
@@ -162,8 +165,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
{
int scanOffset = this.readOffset;
// Scan for the final double-newline that marks the
// end of the header lines
// Scan for the final double-newline that marks the end of the header lines
while (scanOffset + 3 < this.bufferEndOffset &&
(this.messageBuffer[scanOffset] != CR ||
this.messageBuffer[scanOffset + 1] != LF ||
@@ -173,45 +175,51 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
scanOffset++;
}
// No header or body separator found (e.g CRLFCRLF)
// Make sure we haven't reached the end of the buffer without finding a separator (e.g CRLFCRLF)
if (scanOffset + 3 >= this.bufferEndOffset)
{
return false;
}
this.messageHeaders = new Dictionary<string, string>();
// Convert the header block into a array of lines
var headers = Encoding.ASCII.GetString(this.messageBuffer, this.readOffset, scanOffset)
.Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries);
var headers =
Encoding.ASCII
.GetString(this.messageBuffer, this.readOffset, scanOffset)
.Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries);
// Read each header and store it in the dictionary
foreach (var header in headers)
try
{
int currentLength = header.IndexOf(':');
if (currentLength == -1)
// Read each header and store it in the dictionary
this.messageHeaders = new Dictionary<string, string>();
foreach (var header in headers)
{
throw new ArgumentException("Message header must separate key and value using :");
int currentLength = header.IndexOf(':');
if (currentLength == -1)
{
throw new ArgumentException("Message header must separate key and value using :");
}
var key = header.Substring(0, currentLength);
var value = header.Substring(currentLength + 1).Trim();
this.messageHeaders[key] = value;
}
var key = header.Substring(0, currentLength);
var value = header.Substring(currentLength + 1).Trim();
this.messageHeaders[key] = value;
}
// Parse out the content length as an int
string contentLengthString;
if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString))
{
throw new MessageParseException("", "Fatal error: Content-Length header must be provided.");
}
// Make sure a Content-Length header was present, otherwise it
// is a fatal error
string contentLengthString = null;
if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString))
{
throw new MessageParseException("", "Fatal error: Content-Length header must be provided.");
// Parse the content length to an integer
if (!int.TryParse(contentLengthString, out this.expectedContentLength))
{
throw new MessageParseException("", "Fatal error: Content-Length value is not an integer.");
}
}
// Parse the content length to an integer
if (!int.TryParse(contentLengthString, out this.expectedContentLength))
catch (Exception)
{
throw new MessageParseException("", "Fatal error: Content-Length value is not an integer.");
// The content length was invalid or missing. Trash the buffer we've read
ShiftBufferBytesAndShrink(scanOffset + 4);
throw;
}
// Skip past the headers plus the newline characters
@@ -234,31 +242,40 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
}
// Convert the message contents to a string using the specified encoding
messageContent =
this.messageEncoding.GetString(
this.messageBuffer,
this.readOffset,
this.expectedContentLength);
messageContent = this.messageEncoding.GetString(
this.messageBuffer,
this.readOffset,
this.expectedContentLength);
// Move the remaining bytes to the front of the buffer for the next message
var remainingByteCount = this.bufferEndOffset - (this.expectedContentLength + this.readOffset);
Buffer.BlockCopy(
this.messageBuffer,
this.expectedContentLength + this.readOffset,
this.messageBuffer,
0,
remainingByteCount);
readOffset += expectedContentLength;
// Reset the offsets for the next read
this.readOffset = 0;
this.bufferEndOffset = remainingByteCount;
// Done reading content, now look for headers
// Done reading content, now look for headers for the next message
this.readState = ReadState.Headers;
return true;
}
private void ShiftBufferBytesAndShrink(int bytesToRemove)
{
// Create a new buffer that is shrunken by the number of bytes to remove
// Note: by using Max, we can guarantee a buffer of at least default buffer size
byte[] newBuffer = new byte[Math.Max(messageBuffer.Length - bytesToRemove, DefaultBufferSize)];
// If we need to do shifting, do the shifting
if (bytesToRemove <= messageBuffer.Length)
{
// Copy the existing buffer starting at the offset to remove
Buffer.BlockCopy(messageBuffer, bytesToRemove, newBuffer, 0, bufferEndOffset - bytesToRemove);
}
// Make the new buffer the message buffer
messageBuffer = newBuffer;
// Reset the read offset and the end offset
readOffset = 0;
bufferEndOffset -= bytesToRemove;
}
#endregion
}
}