From fe79f6e85c60be70726589e8e23c6962a3000032 Mon Sep 17 00:00:00 2001 From: Benjamin Russell Date: Fri, 12 Aug 2016 15:37:16 -0700 Subject: [PATCH 1/5] Fixing infinite exception loop bug Fixing issue where submitting a malformed JSON RPC request results in the message reader entering into an infinite loop of throwing and catching exceptions without reading anything from the input stream. At the same time, this change also fixes a potential memory leak where the message read buffer is never reinstantiated or shrunk. This issue is fixed by shifting buffer contents after a message was read successfully, or if an error occurs during parsing. --- .../Hosting/Protocol/MessageDispatcher.cs | 34 +++-- .../Hosting/Protocol/MessageReader.cs | 125 ++++++++++-------- 2 files changed, 91 insertions(+), 68 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs index a18fa806..7cf1f2dd 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs @@ -8,6 +8,7 @@ using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Contracts; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; using Microsoft.SqlTools.EditorServices.Utility; @@ -198,10 +199,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol this.SynchronizationContext = SynchronizationContext.Current; // Run the message loop - bool isRunning = true; - while (isRunning && !cancellationToken.IsCancellationRequested) + while (!cancellationToken.IsCancellationRequested) { - Message newMessage = null; + Message newMessage; try { @@ -210,12 +210,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } catch (MessageParseException e) { - // TODO: Write an error response - - Logger.Write( - LogLevel.Error, - "Could not parse a message that was received:\r\n\r\n" + - e.ToString()); + string message = String.Format("Exception occurred while parsing message: {0}", e.Message); + Logger.Write(LogLevel.Error, message); + await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams + { + Message = message + }); // Continue the loop continue; @@ -227,8 +227,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } catch (Exception e) { - var b = e.Message; - newMessage = null; + // Log the error and send an error event to the client + string message = String.Format("Exception occurred while receiving message: {0}", e.Message); + Logger.Write(LogLevel.Error, message); + await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams + { + Message = message + }); + + // Continue the loop + continue; } // The message could be null if there was an error parsing the @@ -236,9 +244,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol if (newMessage != null) { // Process the message - await this.DispatchMessage( - newMessage, - this.MessageWriter); + await this.DispatchMessage(newMessage, this.MessageWriter); } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs index f3857710..4351361f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs @@ -25,22 +25,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol private const int CR = 0x0D; private const int LF = 0x0A; - private static string[] NewLineDelimiters = new string[] { Environment.NewLine }; + private static readonly string[] NewLineDelimiters = { Environment.NewLine }; - private Stream inputStream; - private IMessageSerializer messageSerializer; - private Encoding messageEncoding; + private readonly Stream inputStream; + private readonly IMessageSerializer messageSerializer; + private readonly Encoding messageEncoding; private ReadState readState; private bool needsMoreData = true; private int readOffset; private int bufferEndOffset; - private byte[] messageBuffer = new byte[DefaultBufferSize]; + private byte[] messageBuffer; private int expectedContentLength; private Dictionary messageHeaders; - enum ReadState + private enum ReadState { Headers, Content @@ -85,7 +85,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol this.needsMoreData = false; // Do we need to look for message headers? - if (this.readState == ReadState.Headers && + if (this.readState == ReadState.Headers && !this.TryReadMessageHeaders()) { // If we don't have enough data to read headers yet, keep reading @@ -94,7 +94,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } // Do we need to look for message content? - if (this.readState == ReadState.Content && + if (this.readState == ReadState.Content && !this.TryReadMessageContent(out messageContent)) { // If we don't have enough data yet to construct the content, keep reading @@ -106,6 +106,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol break; } + // Now that we have a message, reset the buffer's state + ShiftBufferBytesAndShrink(readOffset); + // Get the JObject for the JSON content JObject messageObject = JObject.Parse(messageContent); @@ -162,8 +165,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { int scanOffset = this.readOffset; - // Scan for the final double-newline that marks the - // end of the header lines + // Scan for the final double-newline that marks the end of the header lines while (scanOffset + 3 < this.bufferEndOffset && (this.messageBuffer[scanOffset] != CR || this.messageBuffer[scanOffset + 1] != LF || @@ -173,45 +175,51 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol scanOffset++; } - // No header or body separator found (e.g CRLFCRLF) + // Make sure we haven't reached the end of the buffer without finding a separator (e.g CRLFCRLF) if (scanOffset + 3 >= this.bufferEndOffset) { return false; } - this.messageHeaders = new Dictionary(); + // Convert the header block into a array of lines + var headers = Encoding.ASCII.GetString(this.messageBuffer, this.readOffset, scanOffset) + .Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries); - var headers = - Encoding.ASCII - .GetString(this.messageBuffer, this.readOffset, scanOffset) - .Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries); - - // Read each header and store it in the dictionary - foreach (var header in headers) + try { - int currentLength = header.IndexOf(':'); - if (currentLength == -1) + // Read each header and store it in the dictionary + this.messageHeaders = new Dictionary(); + foreach (var header in headers) { - throw new ArgumentException("Message header must separate key and value using :"); + int currentLength = header.IndexOf(':'); + if (currentLength == -1) + { + throw new ArgumentException("Message header must separate key and value using :"); + } + + var key = header.Substring(0, currentLength); + var value = header.Substring(currentLength + 1).Trim(); + this.messageHeaders[key] = value; } - var key = header.Substring(0, currentLength); - var value = header.Substring(currentLength + 1).Trim(); - this.messageHeaders[key] = value; - } + // Parse out the content length as an int + string contentLengthString; + if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString)) + { + throw new MessageParseException("", "Fatal error: Content-Length header must be provided."); + } - // Make sure a Content-Length header was present, otherwise it - // is a fatal error - string contentLengthString = null; - if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString)) - { - throw new MessageParseException("", "Fatal error: Content-Length header must be provided."); + // Parse the content length to an integer + if (!int.TryParse(contentLengthString, out this.expectedContentLength)) + { + throw new MessageParseException("", "Fatal error: Content-Length value is not an integer."); + } } - - // Parse the content length to an integer - if (!int.TryParse(contentLengthString, out this.expectedContentLength)) + catch (Exception) { - throw new MessageParseException("", "Fatal error: Content-Length value is not an integer."); + // The content length was invalid or missing. Trash the buffer we've read + ShiftBufferBytesAndShrink(scanOffset + 4); + throw; } // Skip past the headers plus the newline characters @@ -234,31 +242,40 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } // Convert the message contents to a string using the specified encoding - messageContent = - this.messageEncoding.GetString( - this.messageBuffer, - this.readOffset, - this.expectedContentLength); + messageContent = this.messageEncoding.GetString( + this.messageBuffer, + this.readOffset, + this.expectedContentLength); - // Move the remaining bytes to the front of the buffer for the next message - var remainingByteCount = this.bufferEndOffset - (this.expectedContentLength + this.readOffset); - Buffer.BlockCopy( - this.messageBuffer, - this.expectedContentLength + this.readOffset, - this.messageBuffer, - 0, - remainingByteCount); + readOffset += expectedContentLength; - // Reset the offsets for the next read - this.readOffset = 0; - this.bufferEndOffset = remainingByteCount; - - // Done reading content, now look for headers + // Done reading content, now look for headers for the next message this.readState = ReadState.Headers; return true; } + private void ShiftBufferBytesAndShrink(int bytesToRemove) + { + // Create a new buffer that is shrunken by the number of bytes to remove + // Note: by using Max, we can guarantee a buffer of at least default buffer size + byte[] newBuffer = new byte[Math.Max(messageBuffer.Length - bytesToRemove, DefaultBufferSize)]; + + // If we need to do shifting, do the shifting + if (bytesToRemove <= messageBuffer.Length) + { + // Copy the existing buffer starting at the offset to remove + Buffer.BlockCopy(messageBuffer, bytesToRemove, newBuffer, 0, bufferEndOffset - bytesToRemove); + } + + // Make the new buffer the message buffer + messageBuffer = newBuffer; + + // Reset the read offset and the end offset + readOffset = 0; + bufferEndOffset -= bytesToRemove; + } + #endregion } } From ba144bd5d0c3bc35e5b92aeed076f9e3eb6cdeef Mon Sep 17 00:00:00 2001 From: Benjamin Russell Date: Fri, 12 Aug 2016 17:37:07 -0700 Subject: [PATCH 2/5] Unit tests for the message reader --- .../Message/MessageReaderWriterTests.cs | 178 ------------- .../Messaging/Common.cs | 17 ++ .../Messaging/MessageReaderTests.cs | 241 ++++++++++++++++++ .../Messaging/MessageWriterTests.cs | 55 ++++ .../TestMessageTypes.cs | 0 5 files changed, 313 insertions(+), 178 deletions(-) delete mode 100644 test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs create mode 100644 test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs create mode 100644 test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs create mode 100644 test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs rename test/Microsoft.SqlTools.ServiceLayer.Test/{Message => Messaging}/TestMessageTypes.cs (100%) diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs deleted file mode 100644 index 54fbf01f..00000000 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs +++ /dev/null @@ -1,178 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using System; -using System.IO; -using System.Text; -using System.Threading.Tasks; -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; -using HostingMessage = Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts.Message; -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; -using Xunit; - -namespace Microsoft.SqlTools.ServiceLayer.Test.Message -{ - public class MessageReaderWriterTests - { - const string TestEventString = "{\"type\":\"event\",\"event\":\"testEvent\",\"body\":null}"; - const string TestEventFormatString = "{{\"event\":\"testEvent\",\"body\":{{\"someString\":\"{0}\"}},\"seq\":0,\"type\":\"event\"}}"; - readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString); - - private IMessageSerializer messageSerializer; - - public MessageReaderWriterTests() - { - this.messageSerializer = new V8MessageSerializer(); - } - - [Fact] - public async Task WritesMessage() - { - MemoryStream outputStream = new MemoryStream(); - - MessageWriter messageWriter = - new MessageWriter( - outputStream, - this.messageSerializer); - - // Write the message and then roll back the stream to be read - // TODO: This will need to be redone! - await messageWriter.WriteMessage(HostingMessage.Event("testEvent", null)); - outputStream.Seek(0, SeekOrigin.Begin); - - string expectedHeaderString = - string.Format( - Constants.ContentLengthFormatString, - ExpectedMessageByteCount); - - byte[] buffer = new byte[128]; - await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length); - - Assert.Equal( - expectedHeaderString, - Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length)); - - // Read the message - await outputStream.ReadAsync(buffer, 0, ExpectedMessageByteCount); - - Assert.Equal( - TestEventString, - Encoding.UTF8.GetString(buffer, 0, ExpectedMessageByteCount)); - - outputStream.Dispose(); - } - - [Fact] - public void ReadsMessage() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Write a message to the stream - byte[] messageBuffer = this.GetMessageBytes(TestEventString); - inputStream.Write( - this.GetMessageBytes(TestEventString), - 0, - messageBuffer.Length); - - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - HostingMessage messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - - inputStream.Dispose(); - } - - [Fact] - public void ReadsManyBufferedMessages() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Get a message to use for writing to the stream - byte[] messageBuffer = this.GetMessageBytes(TestEventString); - - // How many messages of this size should we write to overflow the buffer? - int overflowMessageCount = - (int)Math.Ceiling( - (MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length); - - // Write the necessary number of messages to the stream - for (int i = 0; i < overflowMessageCount; i++) - { - inputStream.Write(messageBuffer, 0, messageBuffer.Length); - } - - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - // Read the written messages from the stream - for (int i = 0; i < overflowMessageCount; i++) - { - HostingMessage messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - } - - inputStream.Dispose(); - } - - [Fact] - public void ReaderResizesBufferForLargeMessages() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Get a message with content so large that the buffer will need - // to be resized to fit it all. - byte[] messageBuffer = - this.GetMessageBytes( - string.Format( - TestEventFormatString, - new String('X', (int)(MessageReader.DefaultBufferSize * 3)))); - - inputStream.Write(messageBuffer, 0, messageBuffer.Length); - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - HostingMessage messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - - inputStream.Dispose(); - } - - private byte[] GetMessageBytes(string messageString, Encoding encoding = null) - { - if (encoding == null) - { - encoding = Encoding.UTF8; - } - - byte[] messageBytes = Encoding.UTF8.GetBytes(messageString); - byte[] headerBytes = - Encoding.ASCII.GetBytes( - string.Format( - Constants.ContentLengthFormatString, - messageBytes.Length)); - - // Copy the bytes into a single buffer - byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length]; - Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length); - Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length); - - return finalBytes; - } - } -} - diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs new file mode 100644 index 00000000..b575fe0c --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs @@ -0,0 +1,17 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Text; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class Common + { + public const string TestEventString = @"{""type"":""event"",""event"":""testEvent"",""body"":null}"; + public const string TestEventFormatString = @"{{""event"":""testEvent"",""body"":{{""someString"":""{0}""}},""seq"":0,""type"":""event""}}"; + public static readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString); + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs new file mode 100644 index 00000000..0a12dc3e --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs @@ -0,0 +1,241 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.IO; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; +using Newtonsoft.Json; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class MessageReaderTests + { + + private readonly IMessageSerializer messageSerializer; + + public MessageReaderTests() + { + this.messageSerializer = new V8MessageSerializer(); + } + + [Fact] + public void ReadsMessage() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = new MessageReader(inputStream, this.messageSerializer); + + // Write a message to the stream + byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString); + inputStream.Write(this.GetMessageBytes(Common.TestEventString), 0, messageBuffer.Length); + + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + + inputStream.Dispose(); + } + + [Fact] + public void ReadsManyBufferedMessages() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = + new MessageReader( + inputStream, + this.messageSerializer); + + // Get a message to use for writing to the stream + byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString); + + // How many messages of this size should we write to overflow the buffer? + int overflowMessageCount = + (int)Math.Ceiling( + (MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length); + + // Write the necessary number of messages to the stream + for (int i = 0; i < overflowMessageCount; i++) + { + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + } + + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Read the written messages from the stream + for (int i = 0; i < overflowMessageCount; i++) + { + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + } + + inputStream.Dispose(); + } + + [Fact] + public void ReadMalformedMissingHeaderTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it information that is malformed + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedContentLengthNonIntegerTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it a non-integer content-length header + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: asdf\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMissingContentLengthHeaderTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it a a message without a content-length header + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Type: asdf\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedContentLengthTooShortTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... Pass in an event that has an incorrect content length + // ... And pass in an event that is correct + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: 10\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = Encoding.UTF8.GetBytes(Common.TestEventString); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = Encoding.ASCII.GetBytes("\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... The first read should fail with an exception while deserializing + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + + // ... The second read should fail with an exception while reading headers + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedThenValidTest() + { + // If: + // ... I create a new stream and pass it information that is malformed + // ... and attempt to read a message from it + // ... Then pass it information that is valid and attempt to read a message from it + using (MemoryStream inputStream = new MemoryStream()) + { + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = GetMessageBytes(Common.TestEventString); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading the first one + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + + // ... A test event should be successfully read from the second one + Message messageResult = messageReader.ReadMessage().Result; + Assert.NotNull(messageResult); + Assert.Equal("testEvent", messageResult.Method); + } + } + + [Fact] + public void ReaderResizesBufferForLargeMessages() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = + new MessageReader( + inputStream, + this.messageSerializer); + + // Get a message with content so large that the buffer will need + // to be resized to fit it all. + byte[] messageBuffer = this.GetMessageBytes( + string.Format( + Common.TestEventFormatString, + new String('X', (int) (MessageReader.DefaultBufferSize*3)))); + + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + + inputStream.Dispose(); + } + + private byte[] GetMessageBytes(string messageString, Encoding encoding = null) + { + if (encoding == null) + { + encoding = Encoding.UTF8; + } + + byte[] messageBytes = Encoding.UTF8.GetBytes(messageString); + byte[] headerBytes = Encoding.ASCII.GetBytes(string.Format(Constants.ContentLengthFormatString, messageBytes.Length)); + + // Copy the bytes into a single buffer + byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length]; + Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length); + Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length); + + return finalBytes; + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs new file mode 100644 index 00000000..3c007a85 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs @@ -0,0 +1,55 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.IO; +using System.Text; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class MessageWriterTests + { + private readonly IMessageSerializer messageSerializer; + + public MessageWriterTests() + { + this.messageSerializer = new V8MessageSerializer(); + } + + [Fact] + public async Task WritesMessage() + { + MemoryStream outputStream = new MemoryStream(); + MessageWriter messageWriter = new MessageWriter(outputStream, this.messageSerializer); + + // Write the message and then roll back the stream to be read + // TODO: This will need to be redone! + await messageWriter.WriteMessage(Hosting.Protocol.Contracts.Message.Event("testEvent", null)); + outputStream.Seek(0, SeekOrigin.Begin); + + string expectedHeaderString = string.Format(Constants.ContentLengthFormatString, + Common.ExpectedMessageByteCount); + + byte[] buffer = new byte[128]; + await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length); + + Assert.Equal( + expectedHeaderString, + Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length)); + + // Read the message + await outputStream.ReadAsync(buffer, 0, Common.ExpectedMessageByteCount); + + Assert.Equal(Common.TestEventString, + Encoding.UTF8.GetString(buffer, 0, Common.ExpectedMessageByteCount)); + + outputStream.Dispose(); + } + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/TestMessageTypes.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs similarity index 100% rename from test/Microsoft.SqlTools.ServiceLayer.Test/Message/TestMessageTypes.cs rename to test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs From 1acc8c91228eeb2dbf295787b5840b2de5344229 Mon Sep 17 00:00:00 2001 From: Benjamin Russell Date: Fri, 12 Aug 2016 17:38:41 -0700 Subject: [PATCH 3/5] Fixing up error logging Fixing issue where plaintext passwords could be written to logs Fixing up todo issues where error events needed to be passed back from the service layer when the hosting component fails. --- .../Hosting/Contracts/HostingErrorEvent.cs | 28 +++++++++++++++++++ .../Hosting/Protocol/MessageDispatcher.cs | 5 ++++ .../Hosting/Protocol/MessageReader.cs | 7 ----- .../Messaging/TestMessageTypes.cs | 2 +- 4 files changed, 34 insertions(+), 8 deletions(-) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs new file mode 100644 index 00000000..4d221acc --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs @@ -0,0 +1,28 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts +{ + /// + /// Parameters to be used for reporting hosting-level errors, such as protocol violations + /// + public class HostingErrorParams + { + /// + /// + /// + public string Message { get; set; } + } + + public class HostingErrorEvent + { + public static readonly + EventType Type = + EventType.Create("hostingError"); + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs index 7cf1f2dd..d7e5812d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs @@ -243,6 +243,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol // previous message. In this case, do not try to dispatch it. if (newMessage != null) { + // Verbose logging + string logMessage = String.Format("Received message of type[{0}] and method[{1}]", + newMessage.MessageType, newMessage.Method); + Logger.Write(LogLevel.Verbose, logMessage); + // Process the message await this.DispatchMessage(newMessage, this.MessageWriter); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs index 4351361f..17d4b5e0 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs @@ -112,13 +112,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol // Get the JObject for the JSON content JObject messageObject = JObject.Parse(messageContent); - // Load the message - Logger.Write( - LogLevel.Verbose, - string.Format( - "READ MESSAGE:\r\n\r\n{0}", - messageObject.ToString(Formatting.Indented))); - // Return the parsed message return this.messageSerializer.DeserializeMessage(messageObject); } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs index 0ba08056..89238098 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs @@ -6,7 +6,7 @@ using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; -namespace Microsoft.SqlTools.ServiceLayer.Test.Message +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging { #region Request Types From 9fa183ea6dddf48c117c08758a5c8e48988f4017 Mon Sep 17 00:00:00 2001 From: benrr101 Date: Tue, 16 Aug 2016 12:28:52 -0700 Subject: [PATCH 4/5] Fixing String.Format to string.Format --- .../Hosting/Protocol/MessageDispatcher.cs | 6 +++--- .../LanguageServices/LanguageService.cs | 2 +- .../Workspace/WorkspaceService.cs | 2 +- .../QueryExecution/Common.cs | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs index d7e5812d..c4cf5365 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs @@ -210,7 +210,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } catch (MessageParseException e) { - string message = String.Format("Exception occurred while parsing message: {0}", e.Message); + string message = string.Format("Exception occurred while parsing message: {0}", e.Message); Logger.Write(LogLevel.Error, message); await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams { @@ -228,7 +228,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol catch (Exception e) { // Log the error and send an error event to the client - string message = String.Format("Exception occurred while receiving message: {0}", e.Message); + string message = string.Format("Exception occurred while receiving message: {0}", e.Message); Logger.Write(LogLevel.Error, message); await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams { @@ -244,7 +244,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol if (newMessage != null) { // Verbose logging - string logMessage = String.Format("Received message of type[{0}] and method[{1}]", + string logMessage = string.Format("Received message of type[{0}] and method[{1}]", newMessage.MessageType, newMessage.Method); Logger.Write(LogLevel.Verbose, logMessage); diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs index 6cdbd745..f380f1bb 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs @@ -337,7 +337,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { Logger.Write( LogLevel.Error, - String.Format( + string.Format( "Exception while cancelling analysis task:\n\n{0}", e.ToString())); diff --git a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs index 701fa6f5..f47cacb9 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs @@ -182,7 +182,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace foreach (var textChange in textChangeParams.ContentChanges) { string fileUri = textChangeParams.Uri ?? textChangeParams.TextDocument.Uri; - msg.AppendLine(String.Format(" File: {0}", fileUri)); + msg.AppendLine(string.Format(" File: {0}", fileUri)); ScriptFile changedFile = Workspace.GetFile(fileUri); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index 9bc8053b..292e3264 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -36,7 +36,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Dictionary rowDictionary = new Dictionary(); for (int column = 0; column < columns; column++) { - rowDictionary.Add(String.Format("column{0}", column), String.Format("val{0}{1}", column, row)); + rowDictionary.Add(string.Format("column{0}", column), string.Format("val{0}{1}", column, row)); } output[row] = rowDictionary; } From 71d852ba076eb1ea44d797db28ddd2f50f139d9c Mon Sep 17 00:00:00 2001 From: benrr101 Date: Tue, 16 Aug 2016 12:35:39 -0700 Subject: [PATCH 5/5] Renaming the hosting error method --- .../Hosting/Contracts/HostingErrorEvent.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs index 4d221acc..d6e65801 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs @@ -13,7 +13,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts public class HostingErrorParams { /// - /// + /// The message of the error /// public string Message { get; set; } } @@ -22,7 +22,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts { public static readonly EventType Type = - EventType.Create("hostingError"); + EventType.Create("hosting/error"); } }