Fixing infinite exception loop bug

Fixing issue where submitting a malformed JSON RPC request results in the
message reader entering into an infinite loop of throwing and catching
exceptions without reading anything from the input stream.

At the same time, this change also fixes a potential memory leak where the
message read buffer is never reinstantiated or shrunk. This issue is fixed
by shifting buffer contents after a message was read successfully, or if
an error occurs during parsing.
This commit is contained in:
Benjamin Russell
2016-08-12 15:37:16 -07:00
parent a2983539a7
commit fe79f6e85c
2 changed files with 91 additions and 68 deletions

View File

@@ -8,6 +8,7 @@ using System.Collections.Generic;
using System.IO; using System.IO;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Hosting.Contracts;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.EditorServices.Utility; using Microsoft.SqlTools.EditorServices.Utility;
@@ -198,10 +199,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
this.SynchronizationContext = SynchronizationContext.Current; this.SynchronizationContext = SynchronizationContext.Current;
// Run the message loop // Run the message loop
bool isRunning = true; while (!cancellationToken.IsCancellationRequested)
while (isRunning && !cancellationToken.IsCancellationRequested)
{ {
Message newMessage = null; Message newMessage;
try try
{ {
@@ -210,12 +210,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
} }
catch (MessageParseException e) catch (MessageParseException e)
{ {
// TODO: Write an error response string message = String.Format("Exception occurred while parsing message: {0}", e.Message);
Logger.Write(LogLevel.Error, message);
Logger.Write( await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams
LogLevel.Error, {
"Could not parse a message that was received:\r\n\r\n" + Message = message
e.ToString()); });
// Continue the loop // Continue the loop
continue; continue;
@@ -227,8 +227,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
} }
catch (Exception e) catch (Exception e)
{ {
var b = e.Message; // Log the error and send an error event to the client
newMessage = null; string message = String.Format("Exception occurred while receiving message: {0}", e.Message);
Logger.Write(LogLevel.Error, message);
await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams
{
Message = message
});
// Continue the loop
continue;
} }
// The message could be null if there was an error parsing the // The message could be null if there was an error parsing the
@@ -236,9 +244,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
if (newMessage != null) if (newMessage != null)
{ {
// Process the message // Process the message
await this.DispatchMessage( await this.DispatchMessage(newMessage, this.MessageWriter);
newMessage,
this.MessageWriter);
} }
} }
} }

View File

@@ -25,22 +25,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
private const int CR = 0x0D; private const int CR = 0x0D;
private const int LF = 0x0A; private const int LF = 0x0A;
private static string[] NewLineDelimiters = new string[] { Environment.NewLine }; private static readonly string[] NewLineDelimiters = { Environment.NewLine };
private Stream inputStream; private readonly Stream inputStream;
private IMessageSerializer messageSerializer; private readonly IMessageSerializer messageSerializer;
private Encoding messageEncoding; private readonly Encoding messageEncoding;
private ReadState readState; private ReadState readState;
private bool needsMoreData = true; private bool needsMoreData = true;
private int readOffset; private int readOffset;
private int bufferEndOffset; private int bufferEndOffset;
private byte[] messageBuffer = new byte[DefaultBufferSize]; private byte[] messageBuffer;
private int expectedContentLength; private int expectedContentLength;
private Dictionary<string, string> messageHeaders; private Dictionary<string, string> messageHeaders;
enum ReadState private enum ReadState
{ {
Headers, Headers,
Content Content
@@ -85,7 +85,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
this.needsMoreData = false; this.needsMoreData = false;
// Do we need to look for message headers? // Do we need to look for message headers?
if (this.readState == ReadState.Headers && if (this.readState == ReadState.Headers &&
!this.TryReadMessageHeaders()) !this.TryReadMessageHeaders())
{ {
// If we don't have enough data to read headers yet, keep reading // If we don't have enough data to read headers yet, keep reading
@@ -94,7 +94,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
} }
// Do we need to look for message content? // Do we need to look for message content?
if (this.readState == ReadState.Content && if (this.readState == ReadState.Content &&
!this.TryReadMessageContent(out messageContent)) !this.TryReadMessageContent(out messageContent))
{ {
// If we don't have enough data yet to construct the content, keep reading // If we don't have enough data yet to construct the content, keep reading
@@ -106,6 +106,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
break; break;
} }
// Now that we have a message, reset the buffer's state
ShiftBufferBytesAndShrink(readOffset);
// Get the JObject for the JSON content // Get the JObject for the JSON content
JObject messageObject = JObject.Parse(messageContent); JObject messageObject = JObject.Parse(messageContent);
@@ -162,8 +165,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
{ {
int scanOffset = this.readOffset; int scanOffset = this.readOffset;
// Scan for the final double-newline that marks the // Scan for the final double-newline that marks the end of the header lines
// end of the header lines
while (scanOffset + 3 < this.bufferEndOffset && while (scanOffset + 3 < this.bufferEndOffset &&
(this.messageBuffer[scanOffset] != CR || (this.messageBuffer[scanOffset] != CR ||
this.messageBuffer[scanOffset + 1] != LF || this.messageBuffer[scanOffset + 1] != LF ||
@@ -173,45 +175,51 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
scanOffset++; scanOffset++;
} }
// No header or body separator found (e.g CRLFCRLF) // Make sure we haven't reached the end of the buffer without finding a separator (e.g CRLFCRLF)
if (scanOffset + 3 >= this.bufferEndOffset) if (scanOffset + 3 >= this.bufferEndOffset)
{ {
return false; return false;
} }
this.messageHeaders = new Dictionary<string, string>(); // Convert the header block into a array of lines
var headers = Encoding.ASCII.GetString(this.messageBuffer, this.readOffset, scanOffset)
.Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries);
var headers = try
Encoding.ASCII
.GetString(this.messageBuffer, this.readOffset, scanOffset)
.Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries);
// Read each header and store it in the dictionary
foreach (var header in headers)
{ {
int currentLength = header.IndexOf(':'); // Read each header and store it in the dictionary
if (currentLength == -1) this.messageHeaders = new Dictionary<string, string>();
foreach (var header in headers)
{ {
throw new ArgumentException("Message header must separate key and value using :"); int currentLength = header.IndexOf(':');
if (currentLength == -1)
{
throw new ArgumentException("Message header must separate key and value using :");
}
var key = header.Substring(0, currentLength);
var value = header.Substring(currentLength + 1).Trim();
this.messageHeaders[key] = value;
} }
var key = header.Substring(0, currentLength); // Parse out the content length as an int
var value = header.Substring(currentLength + 1).Trim(); string contentLengthString;
this.messageHeaders[key] = value; if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString))
} {
throw new MessageParseException("", "Fatal error: Content-Length header must be provided.");
}
// Make sure a Content-Length header was present, otherwise it // Parse the content length to an integer
// is a fatal error if (!int.TryParse(contentLengthString, out this.expectedContentLength))
string contentLengthString = null; {
if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString)) throw new MessageParseException("", "Fatal error: Content-Length value is not an integer.");
{ }
throw new MessageParseException("", "Fatal error: Content-Length header must be provided.");
} }
catch (Exception)
// Parse the content length to an integer
if (!int.TryParse(contentLengthString, out this.expectedContentLength))
{ {
throw new MessageParseException("", "Fatal error: Content-Length value is not an integer."); // The content length was invalid or missing. Trash the buffer we've read
ShiftBufferBytesAndShrink(scanOffset + 4);
throw;
} }
// Skip past the headers plus the newline characters // Skip past the headers plus the newline characters
@@ -234,31 +242,40 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
} }
// Convert the message contents to a string using the specified encoding // Convert the message contents to a string using the specified encoding
messageContent = messageContent = this.messageEncoding.GetString(
this.messageEncoding.GetString( this.messageBuffer,
this.messageBuffer, this.readOffset,
this.readOffset, this.expectedContentLength);
this.expectedContentLength);
// Move the remaining bytes to the front of the buffer for the next message readOffset += expectedContentLength;
var remainingByteCount = this.bufferEndOffset - (this.expectedContentLength + this.readOffset);
Buffer.BlockCopy(
this.messageBuffer,
this.expectedContentLength + this.readOffset,
this.messageBuffer,
0,
remainingByteCount);
// Reset the offsets for the next read // Done reading content, now look for headers for the next message
this.readOffset = 0;
this.bufferEndOffset = remainingByteCount;
// Done reading content, now look for headers
this.readState = ReadState.Headers; this.readState = ReadState.Headers;
return true; return true;
} }
private void ShiftBufferBytesAndShrink(int bytesToRemove)
{
// Create a new buffer that is shrunken by the number of bytes to remove
// Note: by using Max, we can guarantee a buffer of at least default buffer size
byte[] newBuffer = new byte[Math.Max(messageBuffer.Length - bytesToRemove, DefaultBufferSize)];
// If we need to do shifting, do the shifting
if (bytesToRemove <= messageBuffer.Length)
{
// Copy the existing buffer starting at the offset to remove
Buffer.BlockCopy(messageBuffer, bytesToRemove, newBuffer, 0, bufferEndOffset - bytesToRemove);
}
// Make the new buffer the message buffer
messageBuffer = newBuffer;
// Reset the read offset and the end offset
readOffset = 0;
bufferEndOffset -= bytesToRemove;
}
#endregion #endregion
} }
} }