diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs new file mode 100644 index 00000000..d6e65801 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs @@ -0,0 +1,28 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts +{ + /// + /// Parameters to be used for reporting hosting-level errors, such as protocol violations + /// + public class HostingErrorParams + { + /// + /// The message of the error + /// + public string Message { get; set; } + } + + public class HostingErrorEvent + { + public static readonly + EventType Type = + EventType.Create("hosting/error"); + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs index a18fa806..c4cf5365 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs @@ -8,6 +8,7 @@ using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Contracts; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; using Microsoft.SqlTools.EditorServices.Utility; @@ -198,10 +199,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol this.SynchronizationContext = SynchronizationContext.Current; // Run the message loop - bool isRunning = true; - while (isRunning && !cancellationToken.IsCancellationRequested) + while (!cancellationToken.IsCancellationRequested) { - Message newMessage = null; + Message newMessage; try { @@ -210,12 +210,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } catch (MessageParseException e) { - // TODO: Write an error response - - Logger.Write( - LogLevel.Error, - "Could not parse a message that was received:\r\n\r\n" + - e.ToString()); + string message = string.Format("Exception occurred while parsing message: {0}", e.Message); + Logger.Write(LogLevel.Error, message); + await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams + { + Message = message + }); // Continue the loop continue; @@ -227,18 +227,29 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } catch (Exception e) { - var b = e.Message; - newMessage = null; + // Log the error and send an error event to the client + string message = string.Format("Exception occurred while receiving message: {0}", e.Message); + Logger.Write(LogLevel.Error, message); + await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams + { + Message = message + }); + + // Continue the loop + continue; } // The message could be null if there was an error parsing the // previous message. In this case, do not try to dispatch it. if (newMessage != null) { + // Verbose logging + string logMessage = string.Format("Received message of type[{0}] and method[{1}]", + newMessage.MessageType, newMessage.Method); + Logger.Write(LogLevel.Verbose, logMessage); + // Process the message - await this.DispatchMessage( - newMessage, - this.MessageWriter); + await this.DispatchMessage(newMessage, this.MessageWriter); } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs index f3857710..17d4b5e0 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs @@ -25,22 +25,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol private const int CR = 0x0D; private const int LF = 0x0A; - private static string[] NewLineDelimiters = new string[] { Environment.NewLine }; + private static readonly string[] NewLineDelimiters = { Environment.NewLine }; - private Stream inputStream; - private IMessageSerializer messageSerializer; - private Encoding messageEncoding; + private readonly Stream inputStream; + private readonly IMessageSerializer messageSerializer; + private readonly Encoding messageEncoding; private ReadState readState; private bool needsMoreData = true; private int readOffset; private int bufferEndOffset; - private byte[] messageBuffer = new byte[DefaultBufferSize]; + private byte[] messageBuffer; private int expectedContentLength; private Dictionary messageHeaders; - enum ReadState + private enum ReadState { Headers, Content @@ -85,7 +85,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol this.needsMoreData = false; // Do we need to look for message headers? - if (this.readState == ReadState.Headers && + if (this.readState == ReadState.Headers && !this.TryReadMessageHeaders()) { // If we don't have enough data to read headers yet, keep reading @@ -94,7 +94,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } // Do we need to look for message content? - if (this.readState == ReadState.Content && + if (this.readState == ReadState.Content && !this.TryReadMessageContent(out messageContent)) { // If we don't have enough data yet to construct the content, keep reading @@ -106,16 +106,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol break; } + // Now that we have a message, reset the buffer's state + ShiftBufferBytesAndShrink(readOffset); + // Get the JObject for the JSON content JObject messageObject = JObject.Parse(messageContent); - // Load the message - Logger.Write( - LogLevel.Verbose, - string.Format( - "READ MESSAGE:\r\n\r\n{0}", - messageObject.ToString(Formatting.Indented))); - // Return the parsed message return this.messageSerializer.DeserializeMessage(messageObject); } @@ -162,8 +158,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { int scanOffset = this.readOffset; - // Scan for the final double-newline that marks the - // end of the header lines + // Scan for the final double-newline that marks the end of the header lines while (scanOffset + 3 < this.bufferEndOffset && (this.messageBuffer[scanOffset] != CR || this.messageBuffer[scanOffset + 1] != LF || @@ -173,45 +168,51 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol scanOffset++; } - // No header or body separator found (e.g CRLFCRLF) + // Make sure we haven't reached the end of the buffer without finding a separator (e.g CRLFCRLF) if (scanOffset + 3 >= this.bufferEndOffset) { return false; } - this.messageHeaders = new Dictionary(); + // Convert the header block into a array of lines + var headers = Encoding.ASCII.GetString(this.messageBuffer, this.readOffset, scanOffset) + .Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries); - var headers = - Encoding.ASCII - .GetString(this.messageBuffer, this.readOffset, scanOffset) - .Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries); - - // Read each header and store it in the dictionary - foreach (var header in headers) + try { - int currentLength = header.IndexOf(':'); - if (currentLength == -1) + // Read each header and store it in the dictionary + this.messageHeaders = new Dictionary(); + foreach (var header in headers) { - throw new ArgumentException("Message header must separate key and value using :"); + int currentLength = header.IndexOf(':'); + if (currentLength == -1) + { + throw new ArgumentException("Message header must separate key and value using :"); + } + + var key = header.Substring(0, currentLength); + var value = header.Substring(currentLength + 1).Trim(); + this.messageHeaders[key] = value; } - var key = header.Substring(0, currentLength); - var value = header.Substring(currentLength + 1).Trim(); - this.messageHeaders[key] = value; - } + // Parse out the content length as an int + string contentLengthString; + if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString)) + { + throw new MessageParseException("", "Fatal error: Content-Length header must be provided."); + } - // Make sure a Content-Length header was present, otherwise it - // is a fatal error - string contentLengthString = null; - if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString)) - { - throw new MessageParseException("", "Fatal error: Content-Length header must be provided."); + // Parse the content length to an integer + if (!int.TryParse(contentLengthString, out this.expectedContentLength)) + { + throw new MessageParseException("", "Fatal error: Content-Length value is not an integer."); + } } - - // Parse the content length to an integer - if (!int.TryParse(contentLengthString, out this.expectedContentLength)) + catch (Exception) { - throw new MessageParseException("", "Fatal error: Content-Length value is not an integer."); + // The content length was invalid or missing. Trash the buffer we've read + ShiftBufferBytesAndShrink(scanOffset + 4); + throw; } // Skip past the headers plus the newline characters @@ -234,31 +235,40 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } // Convert the message contents to a string using the specified encoding - messageContent = - this.messageEncoding.GetString( - this.messageBuffer, - this.readOffset, - this.expectedContentLength); + messageContent = this.messageEncoding.GetString( + this.messageBuffer, + this.readOffset, + this.expectedContentLength); - // Move the remaining bytes to the front of the buffer for the next message - var remainingByteCount = this.bufferEndOffset - (this.expectedContentLength + this.readOffset); - Buffer.BlockCopy( - this.messageBuffer, - this.expectedContentLength + this.readOffset, - this.messageBuffer, - 0, - remainingByteCount); + readOffset += expectedContentLength; - // Reset the offsets for the next read - this.readOffset = 0; - this.bufferEndOffset = remainingByteCount; - - // Done reading content, now look for headers + // Done reading content, now look for headers for the next message this.readState = ReadState.Headers; return true; } + private void ShiftBufferBytesAndShrink(int bytesToRemove) + { + // Create a new buffer that is shrunken by the number of bytes to remove + // Note: by using Max, we can guarantee a buffer of at least default buffer size + byte[] newBuffer = new byte[Math.Max(messageBuffer.Length - bytesToRemove, DefaultBufferSize)]; + + // If we need to do shifting, do the shifting + if (bytesToRemove <= messageBuffer.Length) + { + // Copy the existing buffer starting at the offset to remove + Buffer.BlockCopy(messageBuffer, bytesToRemove, newBuffer, 0, bufferEndOffset - bytesToRemove); + } + + // Make the new buffer the message buffer + messageBuffer = newBuffer; + + // Reset the read offset and the end offset + readOffset = 0; + bufferEndOffset -= bytesToRemove; + } + #endregion } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs index 6cdbd745..f380f1bb 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs @@ -337,7 +337,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { Logger.Write( LogLevel.Error, - String.Format( + string.Format( "Exception while cancelling analysis task:\n\n{0}", e.ToString())); diff --git a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs index 701fa6f5..f47cacb9 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs @@ -182,7 +182,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace foreach (var textChange in textChangeParams.ContentChanges) { string fileUri = textChangeParams.Uri ?? textChangeParams.TextDocument.Uri; - msg.AppendLine(String.Format(" File: {0}", fileUri)); + msg.AppendLine(string.Format(" File: {0}", fileUri)); ScriptFile changedFile = Workspace.GetFile(fileUri); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs deleted file mode 100644 index 54fbf01f..00000000 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs +++ /dev/null @@ -1,178 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using System; -using System.IO; -using System.Text; -using System.Threading.Tasks; -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; -using HostingMessage = Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts.Message; -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; -using Xunit; - -namespace Microsoft.SqlTools.ServiceLayer.Test.Message -{ - public class MessageReaderWriterTests - { - const string TestEventString = "{\"type\":\"event\",\"event\":\"testEvent\",\"body\":null}"; - const string TestEventFormatString = "{{\"event\":\"testEvent\",\"body\":{{\"someString\":\"{0}\"}},\"seq\":0,\"type\":\"event\"}}"; - readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString); - - private IMessageSerializer messageSerializer; - - public MessageReaderWriterTests() - { - this.messageSerializer = new V8MessageSerializer(); - } - - [Fact] - public async Task WritesMessage() - { - MemoryStream outputStream = new MemoryStream(); - - MessageWriter messageWriter = - new MessageWriter( - outputStream, - this.messageSerializer); - - // Write the message and then roll back the stream to be read - // TODO: This will need to be redone! - await messageWriter.WriteMessage(HostingMessage.Event("testEvent", null)); - outputStream.Seek(0, SeekOrigin.Begin); - - string expectedHeaderString = - string.Format( - Constants.ContentLengthFormatString, - ExpectedMessageByteCount); - - byte[] buffer = new byte[128]; - await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length); - - Assert.Equal( - expectedHeaderString, - Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length)); - - // Read the message - await outputStream.ReadAsync(buffer, 0, ExpectedMessageByteCount); - - Assert.Equal( - TestEventString, - Encoding.UTF8.GetString(buffer, 0, ExpectedMessageByteCount)); - - outputStream.Dispose(); - } - - [Fact] - public void ReadsMessage() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Write a message to the stream - byte[] messageBuffer = this.GetMessageBytes(TestEventString); - inputStream.Write( - this.GetMessageBytes(TestEventString), - 0, - messageBuffer.Length); - - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - HostingMessage messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - - inputStream.Dispose(); - } - - [Fact] - public void ReadsManyBufferedMessages() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Get a message to use for writing to the stream - byte[] messageBuffer = this.GetMessageBytes(TestEventString); - - // How many messages of this size should we write to overflow the buffer? - int overflowMessageCount = - (int)Math.Ceiling( - (MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length); - - // Write the necessary number of messages to the stream - for (int i = 0; i < overflowMessageCount; i++) - { - inputStream.Write(messageBuffer, 0, messageBuffer.Length); - } - - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - // Read the written messages from the stream - for (int i = 0; i < overflowMessageCount; i++) - { - HostingMessage messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - } - - inputStream.Dispose(); - } - - [Fact] - public void ReaderResizesBufferForLargeMessages() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Get a message with content so large that the buffer will need - // to be resized to fit it all. - byte[] messageBuffer = - this.GetMessageBytes( - string.Format( - TestEventFormatString, - new String('X', (int)(MessageReader.DefaultBufferSize * 3)))); - - inputStream.Write(messageBuffer, 0, messageBuffer.Length); - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - HostingMessage messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - - inputStream.Dispose(); - } - - private byte[] GetMessageBytes(string messageString, Encoding encoding = null) - { - if (encoding == null) - { - encoding = Encoding.UTF8; - } - - byte[] messageBytes = Encoding.UTF8.GetBytes(messageString); - byte[] headerBytes = - Encoding.ASCII.GetBytes( - string.Format( - Constants.ContentLengthFormatString, - messageBytes.Length)); - - // Copy the bytes into a single buffer - byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length]; - Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length); - Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length); - - return finalBytes; - } - } -} - diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs new file mode 100644 index 00000000..b575fe0c --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs @@ -0,0 +1,17 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Text; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class Common + { + public const string TestEventString = @"{""type"":""event"",""event"":""testEvent"",""body"":null}"; + public const string TestEventFormatString = @"{{""event"":""testEvent"",""body"":{{""someString"":""{0}""}},""seq"":0,""type"":""event""}}"; + public static readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString); + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs new file mode 100644 index 00000000..0a12dc3e --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs @@ -0,0 +1,241 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.IO; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; +using Newtonsoft.Json; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class MessageReaderTests + { + + private readonly IMessageSerializer messageSerializer; + + public MessageReaderTests() + { + this.messageSerializer = new V8MessageSerializer(); + } + + [Fact] + public void ReadsMessage() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = new MessageReader(inputStream, this.messageSerializer); + + // Write a message to the stream + byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString); + inputStream.Write(this.GetMessageBytes(Common.TestEventString), 0, messageBuffer.Length); + + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + + inputStream.Dispose(); + } + + [Fact] + public void ReadsManyBufferedMessages() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = + new MessageReader( + inputStream, + this.messageSerializer); + + // Get a message to use for writing to the stream + byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString); + + // How many messages of this size should we write to overflow the buffer? + int overflowMessageCount = + (int)Math.Ceiling( + (MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length); + + // Write the necessary number of messages to the stream + for (int i = 0; i < overflowMessageCount; i++) + { + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + } + + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Read the written messages from the stream + for (int i = 0; i < overflowMessageCount; i++) + { + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + } + + inputStream.Dispose(); + } + + [Fact] + public void ReadMalformedMissingHeaderTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it information that is malformed + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedContentLengthNonIntegerTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it a non-integer content-length header + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: asdf\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMissingContentLengthHeaderTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it a a message without a content-length header + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Type: asdf\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedContentLengthTooShortTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... Pass in an event that has an incorrect content length + // ... And pass in an event that is correct + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: 10\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = Encoding.UTF8.GetBytes(Common.TestEventString); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = Encoding.ASCII.GetBytes("\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... The first read should fail with an exception while deserializing + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + + // ... The second read should fail with an exception while reading headers + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedThenValidTest() + { + // If: + // ... I create a new stream and pass it information that is malformed + // ... and attempt to read a message from it + // ... Then pass it information that is valid and attempt to read a message from it + using (MemoryStream inputStream = new MemoryStream()) + { + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = GetMessageBytes(Common.TestEventString); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading the first one + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + + // ... A test event should be successfully read from the second one + Message messageResult = messageReader.ReadMessage().Result; + Assert.NotNull(messageResult); + Assert.Equal("testEvent", messageResult.Method); + } + } + + [Fact] + public void ReaderResizesBufferForLargeMessages() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = + new MessageReader( + inputStream, + this.messageSerializer); + + // Get a message with content so large that the buffer will need + // to be resized to fit it all. + byte[] messageBuffer = this.GetMessageBytes( + string.Format( + Common.TestEventFormatString, + new String('X', (int) (MessageReader.DefaultBufferSize*3)))); + + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + + inputStream.Dispose(); + } + + private byte[] GetMessageBytes(string messageString, Encoding encoding = null) + { + if (encoding == null) + { + encoding = Encoding.UTF8; + } + + byte[] messageBytes = Encoding.UTF8.GetBytes(messageString); + byte[] headerBytes = Encoding.ASCII.GetBytes(string.Format(Constants.ContentLengthFormatString, messageBytes.Length)); + + // Copy the bytes into a single buffer + byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length]; + Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length); + Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length); + + return finalBytes; + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs new file mode 100644 index 00000000..3c007a85 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs @@ -0,0 +1,55 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.IO; +using System.Text; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class MessageWriterTests + { + private readonly IMessageSerializer messageSerializer; + + public MessageWriterTests() + { + this.messageSerializer = new V8MessageSerializer(); + } + + [Fact] + public async Task WritesMessage() + { + MemoryStream outputStream = new MemoryStream(); + MessageWriter messageWriter = new MessageWriter(outputStream, this.messageSerializer); + + // Write the message and then roll back the stream to be read + // TODO: This will need to be redone! + await messageWriter.WriteMessage(Hosting.Protocol.Contracts.Message.Event("testEvent", null)); + outputStream.Seek(0, SeekOrigin.Begin); + + string expectedHeaderString = string.Format(Constants.ContentLengthFormatString, + Common.ExpectedMessageByteCount); + + byte[] buffer = new byte[128]; + await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length); + + Assert.Equal( + expectedHeaderString, + Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length)); + + // Read the message + await outputStream.ReadAsync(buffer, 0, Common.ExpectedMessageByteCount); + + Assert.Equal(Common.TestEventString, + Encoding.UTF8.GetString(buffer, 0, Common.ExpectedMessageByteCount)); + + outputStream.Dispose(); + } + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/TestMessageTypes.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs similarity index 94% rename from test/Microsoft.SqlTools.ServiceLayer.Test/Message/TestMessageTypes.cs rename to test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs index 0ba08056..89238098 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/TestMessageTypes.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs @@ -6,7 +6,7 @@ using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; -namespace Microsoft.SqlTools.ServiceLayer.Test.Message +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging { #region Request Types diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index 49b329a0..f887b50f 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -37,7 +37,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Dictionary rowDictionary = new Dictionary(); for (int column = 0; column < columns; column++) { - rowDictionary.Add(String.Format("column{0}", column), String.Format("val{0}{1}", column, row)); + rowDictionary.Add(string.Format("column{0}", column), string.Format("val{0}{1}", column, row)); } output[row] = rowDictionary; }