From fe79f6e85c60be70726589e8e23c6962a3000032 Mon Sep 17 00:00:00 2001 From: Benjamin Russell Date: Fri, 12 Aug 2016 15:37:16 -0700 Subject: [PATCH 01/12] Fixing infinite exception loop bug Fixing issue where submitting a malformed JSON RPC request results in the message reader entering into an infinite loop of throwing and catching exceptions without reading anything from the input stream. At the same time, this change also fixes a potential memory leak where the message read buffer is never reinstantiated or shrunk. This issue is fixed by shifting buffer contents after a message was read successfully, or if an error occurs during parsing. --- .../Hosting/Protocol/MessageDispatcher.cs | 34 +++-- .../Hosting/Protocol/MessageReader.cs | 125 ++++++++++-------- 2 files changed, 91 insertions(+), 68 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs index a18fa806..7cf1f2dd 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs @@ -8,6 +8,7 @@ using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Contracts; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; using Microsoft.SqlTools.EditorServices.Utility; @@ -198,10 +199,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol this.SynchronizationContext = SynchronizationContext.Current; // Run the message loop - bool isRunning = true; - while (isRunning && !cancellationToken.IsCancellationRequested) + while (!cancellationToken.IsCancellationRequested) { - Message newMessage = null; + Message newMessage; try { @@ -210,12 +210,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } catch (MessageParseException e) { - // TODO: Write an error response - - Logger.Write( - LogLevel.Error, - "Could not parse a message that was received:\r\n\r\n" + - e.ToString()); + string message = String.Format("Exception occurred while parsing message: {0}", e.Message); + Logger.Write(LogLevel.Error, message); + await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams + { + Message = message + }); // Continue the loop continue; @@ -227,8 +227,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } catch (Exception e) { - var b = e.Message; - newMessage = null; + // Log the error and send an error event to the client + string message = String.Format("Exception occurred while receiving message: {0}", e.Message); + Logger.Write(LogLevel.Error, message); + await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams + { + Message = message + }); + + // Continue the loop + continue; } // The message could be null if there was an error parsing the @@ -236,9 +244,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol if (newMessage != null) { // Process the message - await this.DispatchMessage( - newMessage, - this.MessageWriter); + await this.DispatchMessage(newMessage, this.MessageWriter); } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs index f3857710..4351361f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs @@ -25,22 +25,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol private const int CR = 0x0D; private const int LF = 0x0A; - private static string[] NewLineDelimiters = new string[] { Environment.NewLine }; + private static readonly string[] NewLineDelimiters = { Environment.NewLine }; - private Stream inputStream; - private IMessageSerializer messageSerializer; - private Encoding messageEncoding; + private readonly Stream inputStream; + private readonly IMessageSerializer messageSerializer; + private readonly Encoding messageEncoding; private ReadState readState; private bool needsMoreData = true; private int readOffset; private int bufferEndOffset; - private byte[] messageBuffer = new byte[DefaultBufferSize]; + private byte[] messageBuffer; private int expectedContentLength; private Dictionary messageHeaders; - enum ReadState + private enum ReadState { Headers, Content @@ -85,7 +85,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol this.needsMoreData = false; // Do we need to look for message headers? - if (this.readState == ReadState.Headers && + if (this.readState == ReadState.Headers && !this.TryReadMessageHeaders()) { // If we don't have enough data to read headers yet, keep reading @@ -94,7 +94,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } // Do we need to look for message content? - if (this.readState == ReadState.Content && + if (this.readState == ReadState.Content && !this.TryReadMessageContent(out messageContent)) { // If we don't have enough data yet to construct the content, keep reading @@ -106,6 +106,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol break; } + // Now that we have a message, reset the buffer's state + ShiftBufferBytesAndShrink(readOffset); + // Get the JObject for the JSON content JObject messageObject = JObject.Parse(messageContent); @@ -162,8 +165,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { int scanOffset = this.readOffset; - // Scan for the final double-newline that marks the - // end of the header lines + // Scan for the final double-newline that marks the end of the header lines while (scanOffset + 3 < this.bufferEndOffset && (this.messageBuffer[scanOffset] != CR || this.messageBuffer[scanOffset + 1] != LF || @@ -173,45 +175,51 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol scanOffset++; } - // No header or body separator found (e.g CRLFCRLF) + // Make sure we haven't reached the end of the buffer without finding a separator (e.g CRLFCRLF) if (scanOffset + 3 >= this.bufferEndOffset) { return false; } - this.messageHeaders = new Dictionary(); + // Convert the header block into a array of lines + var headers = Encoding.ASCII.GetString(this.messageBuffer, this.readOffset, scanOffset) + .Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries); - var headers = - Encoding.ASCII - .GetString(this.messageBuffer, this.readOffset, scanOffset) - .Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries); - - // Read each header and store it in the dictionary - foreach (var header in headers) + try { - int currentLength = header.IndexOf(':'); - if (currentLength == -1) + // Read each header and store it in the dictionary + this.messageHeaders = new Dictionary(); + foreach (var header in headers) { - throw new ArgumentException("Message header must separate key and value using :"); + int currentLength = header.IndexOf(':'); + if (currentLength == -1) + { + throw new ArgumentException("Message header must separate key and value using :"); + } + + var key = header.Substring(0, currentLength); + var value = header.Substring(currentLength + 1).Trim(); + this.messageHeaders[key] = value; } - var key = header.Substring(0, currentLength); - var value = header.Substring(currentLength + 1).Trim(); - this.messageHeaders[key] = value; - } + // Parse out the content length as an int + string contentLengthString; + if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString)) + { + throw new MessageParseException("", "Fatal error: Content-Length header must be provided."); + } - // Make sure a Content-Length header was present, otherwise it - // is a fatal error - string contentLengthString = null; - if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString)) - { - throw new MessageParseException("", "Fatal error: Content-Length header must be provided."); + // Parse the content length to an integer + if (!int.TryParse(contentLengthString, out this.expectedContentLength)) + { + throw new MessageParseException("", "Fatal error: Content-Length value is not an integer."); + } } - - // Parse the content length to an integer - if (!int.TryParse(contentLengthString, out this.expectedContentLength)) + catch (Exception) { - throw new MessageParseException("", "Fatal error: Content-Length value is not an integer."); + // The content length was invalid or missing. Trash the buffer we've read + ShiftBufferBytesAndShrink(scanOffset + 4); + throw; } // Skip past the headers plus the newline characters @@ -234,31 +242,40 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } // Convert the message contents to a string using the specified encoding - messageContent = - this.messageEncoding.GetString( - this.messageBuffer, - this.readOffset, - this.expectedContentLength); + messageContent = this.messageEncoding.GetString( + this.messageBuffer, + this.readOffset, + this.expectedContentLength); - // Move the remaining bytes to the front of the buffer for the next message - var remainingByteCount = this.bufferEndOffset - (this.expectedContentLength + this.readOffset); - Buffer.BlockCopy( - this.messageBuffer, - this.expectedContentLength + this.readOffset, - this.messageBuffer, - 0, - remainingByteCount); + readOffset += expectedContentLength; - // Reset the offsets for the next read - this.readOffset = 0; - this.bufferEndOffset = remainingByteCount; - - // Done reading content, now look for headers + // Done reading content, now look for headers for the next message this.readState = ReadState.Headers; return true; } + private void ShiftBufferBytesAndShrink(int bytesToRemove) + { + // Create a new buffer that is shrunken by the number of bytes to remove + // Note: by using Max, we can guarantee a buffer of at least default buffer size + byte[] newBuffer = new byte[Math.Max(messageBuffer.Length - bytesToRemove, DefaultBufferSize)]; + + // If we need to do shifting, do the shifting + if (bytesToRemove <= messageBuffer.Length) + { + // Copy the existing buffer starting at the offset to remove + Buffer.BlockCopy(messageBuffer, bytesToRemove, newBuffer, 0, bufferEndOffset - bytesToRemove); + } + + // Make the new buffer the message buffer + messageBuffer = newBuffer; + + // Reset the read offset and the end offset + readOffset = 0; + bufferEndOffset -= bytesToRemove; + } + #endregion } } From ba144bd5d0c3bc35e5b92aeed076f9e3eb6cdeef Mon Sep 17 00:00:00 2001 From: Benjamin Russell Date: Fri, 12 Aug 2016 17:37:07 -0700 Subject: [PATCH 02/12] Unit tests for the message reader --- .../Message/MessageReaderWriterTests.cs | 178 ------------- .../Messaging/Common.cs | 17 ++ .../Messaging/MessageReaderTests.cs | 241 ++++++++++++++++++ .../Messaging/MessageWriterTests.cs | 55 ++++ .../TestMessageTypes.cs | 0 5 files changed, 313 insertions(+), 178 deletions(-) delete mode 100644 test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs create mode 100644 test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs create mode 100644 test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs create mode 100644 test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs rename test/Microsoft.SqlTools.ServiceLayer.Test/{Message => Messaging}/TestMessageTypes.cs (100%) diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs deleted file mode 100644 index 54fbf01f..00000000 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs +++ /dev/null @@ -1,178 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using System; -using System.IO; -using System.Text; -using System.Threading.Tasks; -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; -using HostingMessage = Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts.Message; -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; -using Xunit; - -namespace Microsoft.SqlTools.ServiceLayer.Test.Message -{ - public class MessageReaderWriterTests - { - const string TestEventString = "{\"type\":\"event\",\"event\":\"testEvent\",\"body\":null}"; - const string TestEventFormatString = "{{\"event\":\"testEvent\",\"body\":{{\"someString\":\"{0}\"}},\"seq\":0,\"type\":\"event\"}}"; - readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString); - - private IMessageSerializer messageSerializer; - - public MessageReaderWriterTests() - { - this.messageSerializer = new V8MessageSerializer(); - } - - [Fact] - public async Task WritesMessage() - { - MemoryStream outputStream = new MemoryStream(); - - MessageWriter messageWriter = - new MessageWriter( - outputStream, - this.messageSerializer); - - // Write the message and then roll back the stream to be read - // TODO: This will need to be redone! - await messageWriter.WriteMessage(HostingMessage.Event("testEvent", null)); - outputStream.Seek(0, SeekOrigin.Begin); - - string expectedHeaderString = - string.Format( - Constants.ContentLengthFormatString, - ExpectedMessageByteCount); - - byte[] buffer = new byte[128]; - await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length); - - Assert.Equal( - expectedHeaderString, - Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length)); - - // Read the message - await outputStream.ReadAsync(buffer, 0, ExpectedMessageByteCount); - - Assert.Equal( - TestEventString, - Encoding.UTF8.GetString(buffer, 0, ExpectedMessageByteCount)); - - outputStream.Dispose(); - } - - [Fact] - public void ReadsMessage() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Write a message to the stream - byte[] messageBuffer = this.GetMessageBytes(TestEventString); - inputStream.Write( - this.GetMessageBytes(TestEventString), - 0, - messageBuffer.Length); - - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - HostingMessage messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - - inputStream.Dispose(); - } - - [Fact] - public void ReadsManyBufferedMessages() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Get a message to use for writing to the stream - byte[] messageBuffer = this.GetMessageBytes(TestEventString); - - // How many messages of this size should we write to overflow the buffer? - int overflowMessageCount = - (int)Math.Ceiling( - (MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length); - - // Write the necessary number of messages to the stream - for (int i = 0; i < overflowMessageCount; i++) - { - inputStream.Write(messageBuffer, 0, messageBuffer.Length); - } - - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - // Read the written messages from the stream - for (int i = 0; i < overflowMessageCount; i++) - { - HostingMessage messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - } - - inputStream.Dispose(); - } - - [Fact] - public void ReaderResizesBufferForLargeMessages() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Get a message with content so large that the buffer will need - // to be resized to fit it all. - byte[] messageBuffer = - this.GetMessageBytes( - string.Format( - TestEventFormatString, - new String('X', (int)(MessageReader.DefaultBufferSize * 3)))); - - inputStream.Write(messageBuffer, 0, messageBuffer.Length); - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - HostingMessage messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - - inputStream.Dispose(); - } - - private byte[] GetMessageBytes(string messageString, Encoding encoding = null) - { - if (encoding == null) - { - encoding = Encoding.UTF8; - } - - byte[] messageBytes = Encoding.UTF8.GetBytes(messageString); - byte[] headerBytes = - Encoding.ASCII.GetBytes( - string.Format( - Constants.ContentLengthFormatString, - messageBytes.Length)); - - // Copy the bytes into a single buffer - byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length]; - Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length); - Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length); - - return finalBytes; - } - } -} - diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs new file mode 100644 index 00000000..b575fe0c --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs @@ -0,0 +1,17 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Text; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class Common + { + public const string TestEventString = @"{""type"":""event"",""event"":""testEvent"",""body"":null}"; + public const string TestEventFormatString = @"{{""event"":""testEvent"",""body"":{{""someString"":""{0}""}},""seq"":0,""type"":""event""}}"; + public static readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString); + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs new file mode 100644 index 00000000..0a12dc3e --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs @@ -0,0 +1,241 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.IO; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; +using Newtonsoft.Json; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class MessageReaderTests + { + + private readonly IMessageSerializer messageSerializer; + + public MessageReaderTests() + { + this.messageSerializer = new V8MessageSerializer(); + } + + [Fact] + public void ReadsMessage() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = new MessageReader(inputStream, this.messageSerializer); + + // Write a message to the stream + byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString); + inputStream.Write(this.GetMessageBytes(Common.TestEventString), 0, messageBuffer.Length); + + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + + inputStream.Dispose(); + } + + [Fact] + public void ReadsManyBufferedMessages() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = + new MessageReader( + inputStream, + this.messageSerializer); + + // Get a message to use for writing to the stream + byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString); + + // How many messages of this size should we write to overflow the buffer? + int overflowMessageCount = + (int)Math.Ceiling( + (MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length); + + // Write the necessary number of messages to the stream + for (int i = 0; i < overflowMessageCount; i++) + { + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + } + + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Read the written messages from the stream + for (int i = 0; i < overflowMessageCount; i++) + { + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + } + + inputStream.Dispose(); + } + + [Fact] + public void ReadMalformedMissingHeaderTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it information that is malformed + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedContentLengthNonIntegerTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it a non-integer content-length header + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: asdf\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMissingContentLengthHeaderTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it a a message without a content-length header + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Type: asdf\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedContentLengthTooShortTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... Pass in an event that has an incorrect content length + // ... And pass in an event that is correct + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: 10\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = Encoding.UTF8.GetBytes(Common.TestEventString); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = Encoding.ASCII.GetBytes("\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... The first read should fail with an exception while deserializing + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + + // ... The second read should fail with an exception while reading headers + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedThenValidTest() + { + // If: + // ... I create a new stream and pass it information that is malformed + // ... and attempt to read a message from it + // ... Then pass it information that is valid and attempt to read a message from it + using (MemoryStream inputStream = new MemoryStream()) + { + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = GetMessageBytes(Common.TestEventString); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading the first one + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + + // ... A test event should be successfully read from the second one + Message messageResult = messageReader.ReadMessage().Result; + Assert.NotNull(messageResult); + Assert.Equal("testEvent", messageResult.Method); + } + } + + [Fact] + public void ReaderResizesBufferForLargeMessages() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = + new MessageReader( + inputStream, + this.messageSerializer); + + // Get a message with content so large that the buffer will need + // to be resized to fit it all. + byte[] messageBuffer = this.GetMessageBytes( + string.Format( + Common.TestEventFormatString, + new String('X', (int) (MessageReader.DefaultBufferSize*3)))); + + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + + inputStream.Dispose(); + } + + private byte[] GetMessageBytes(string messageString, Encoding encoding = null) + { + if (encoding == null) + { + encoding = Encoding.UTF8; + } + + byte[] messageBytes = Encoding.UTF8.GetBytes(messageString); + byte[] headerBytes = Encoding.ASCII.GetBytes(string.Format(Constants.ContentLengthFormatString, messageBytes.Length)); + + // Copy the bytes into a single buffer + byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length]; + Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length); + Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length); + + return finalBytes; + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs new file mode 100644 index 00000000..3c007a85 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs @@ -0,0 +1,55 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.IO; +using System.Text; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class MessageWriterTests + { + private readonly IMessageSerializer messageSerializer; + + public MessageWriterTests() + { + this.messageSerializer = new V8MessageSerializer(); + } + + [Fact] + public async Task WritesMessage() + { + MemoryStream outputStream = new MemoryStream(); + MessageWriter messageWriter = new MessageWriter(outputStream, this.messageSerializer); + + // Write the message and then roll back the stream to be read + // TODO: This will need to be redone! + await messageWriter.WriteMessage(Hosting.Protocol.Contracts.Message.Event("testEvent", null)); + outputStream.Seek(0, SeekOrigin.Begin); + + string expectedHeaderString = string.Format(Constants.ContentLengthFormatString, + Common.ExpectedMessageByteCount); + + byte[] buffer = new byte[128]; + await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length); + + Assert.Equal( + expectedHeaderString, + Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length)); + + // Read the message + await outputStream.ReadAsync(buffer, 0, Common.ExpectedMessageByteCount); + + Assert.Equal(Common.TestEventString, + Encoding.UTF8.GetString(buffer, 0, Common.ExpectedMessageByteCount)); + + outputStream.Dispose(); + } + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/TestMessageTypes.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs similarity index 100% rename from test/Microsoft.SqlTools.ServiceLayer.Test/Message/TestMessageTypes.cs rename to test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs From 1acc8c91228eeb2dbf295787b5840b2de5344229 Mon Sep 17 00:00:00 2001 From: Benjamin Russell Date: Fri, 12 Aug 2016 17:38:41 -0700 Subject: [PATCH 03/12] Fixing up error logging Fixing issue where plaintext passwords could be written to logs Fixing up todo issues where error events needed to be passed back from the service layer when the hosting component fails. --- .../Hosting/Contracts/HostingErrorEvent.cs | 28 +++++++++++++++++++ .../Hosting/Protocol/MessageDispatcher.cs | 5 ++++ .../Hosting/Protocol/MessageReader.cs | 7 ----- .../Messaging/TestMessageTypes.cs | 2 +- 4 files changed, 34 insertions(+), 8 deletions(-) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs new file mode 100644 index 00000000..4d221acc --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs @@ -0,0 +1,28 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts +{ + /// + /// Parameters to be used for reporting hosting-level errors, such as protocol violations + /// + public class HostingErrorParams + { + /// + /// + /// + public string Message { get; set; } + } + + public class HostingErrorEvent + { + public static readonly + EventType Type = + EventType.Create("hostingError"); + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs index 7cf1f2dd..d7e5812d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs @@ -243,6 +243,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol // previous message. In this case, do not try to dispatch it. if (newMessage != null) { + // Verbose logging + string logMessage = String.Format("Received message of type[{0}] and method[{1}]", + newMessage.MessageType, newMessage.Method); + Logger.Write(LogLevel.Verbose, logMessage); + // Process the message await this.DispatchMessage(newMessage, this.MessageWriter); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs index 4351361f..17d4b5e0 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs @@ -112,13 +112,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol // Get the JObject for the JSON content JObject messageObject = JObject.Parse(messageContent); - // Load the message - Logger.Write( - LogLevel.Verbose, - string.Format( - "READ MESSAGE:\r\n\r\n{0}", - messageObject.ToString(Formatting.Indented))); - // Return the parsed message return this.messageSerializer.DeserializeMessage(messageObject); } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs index 0ba08056..89238098 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs @@ -6,7 +6,7 @@ using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; -namespace Microsoft.SqlTools.ServiceLayer.Test.Message +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging { #region Request Types From 062c40368d7a133db4a90cf12e28f0a260f84074 Mon Sep 17 00:00:00 2001 From: Benjamin Russell Date: Mon, 15 Aug 2016 15:23:07 -0700 Subject: [PATCH 04/12] Adding support for query cancellation Query cancellation support is added via CancellationToken mechanisms that were implemented previously. This change adds a new request type "query/cancel" that will issue the cancellation token. Unit tests were also added. --- .../Contracts/QueryCancelRequest.cs | 36 +++++ .../QueryExecution/Query.cs | 15 +++ .../QueryExecution/QueryExecutionService.cs | 46 +++++++ .../QueryExecution/CancelTests.cs | 124 ++++++++++++++++++ 4 files changed, 221 insertions(+) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs create mode 100644 test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs new file mode 100644 index 00000000..3eb87f4f --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs @@ -0,0 +1,36 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters for the query cancellation request + /// + public class QueryCancelParams + { + public string OwnerUri { get; set; } + } + + /// + /// Parameters to return as the result of a query dispose request + /// + public class QueryCancelResult + { + /// + /// Any error messages that occurred during disposing the result set. Optional, can be set + /// to null if there were no errors. + /// + public string Messages { get; set; } + } + + public class QueryCancelRequest + { + public static readonly + RequestType Type = + RequestType.Create("query/cancel"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs index 887bbbaf..ec25fd6c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -233,6 +233,21 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution }; } + /// + /// Cancels the query by issuing the cancellation token + /// + public void Cancel() + { + // Make sure that the query hasn't completed execution + if (HasExecuted) + { + throw new InvalidOperationException("The query has already completed, it cannot be cancelled."); + } + + // Issue the cancellation token for the query + cancellationSource.Cancel(); + } + /// /// Delegate handler for storing messages that are returned from the server /// NOTE: Only messages that are below a certain severity will be returned via this diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index 389be092..942beef3 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -73,6 +73,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution serviceHost.SetRequestHandler(QueryExecuteRequest.Type, HandleExecuteRequest); serviceHost.SetRequestHandler(QueryExecuteSubsetRequest.Type, HandleResultSubsetRequest); serviceHost.SetRequestHandler(QueryDisposeRequest.Type, HandleDisposeRequest); + serviceHost.SetRequestHandler(QueryCancelRequest.Type, HandleCancelRequest); // Register handler for shutdown event serviceHost.RegisterShutdownTask((shutdownParams, requestContext) => @@ -178,6 +179,51 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } } + public async Task HandleCancelRequest(QueryCancelParams cancelParams, + RequestContext requestContext) + { + try + { + // Attempt to find the query for the owner uri + Query result; + if (!ActiveQueries.TryGetValue(cancelParams.OwnerUri, out result)) + { + await requestContext.SendResult(new QueryCancelResult + { + Messages = "Failed to cancel query, ID not found." + }); + return; + } + + // Cancel the query + result.Cancel(); + + // Attempt to dispose the query + if (!ActiveQueries.TryRemove(cancelParams.OwnerUri, out result)) + { + // It really shouldn't be possible to get to this scenario, but we'll cover it anyhow + await requestContext.SendResult(new QueryCancelResult + { + Messages = "Query successfully cancelled, failed to dispose query. ID not found." + }); + return; + } + + await requestContext.SendResult(new QueryCancelResult()); + } + catch (InvalidOperationException e) + { + await requestContext.SendResult(new QueryCancelResult + { + Messages = e.Message + }); + } + catch (Exception e) + { + await requestContext.SendError(e.Message); + } + } + #endregion #region Private Helpers diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs new file mode 100644 index 00000000..dbc344f8 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs @@ -0,0 +1,124 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution +{ + public class CancelTests + { + [Fact] + public void CancelInProgressQueryTest() + { + // If: + // ... I request a query (doesn't matter what kind) and execute it + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams { QueryText = "Doesn't Matter", OwnerUri = Common.OwnerUri }; + var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Fake that it hasn't completed execution + + // ... And then I request to cancel the query + var cancelParams = new QueryCancelParams {OwnerUri = Common.OwnerUri}; + QueryCancelResult result = null; + var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); + queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + + // Then: + // ... I should have seen a successful event (no messages) + VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never()); + Assert.Null(result.Messages); + + // ... The query should have been disposed as well + Assert.Empty(queryService.ActiveQueries); + } + + [Fact] + public void CancelExecutedQueryTest() + { + // If: + // ... I request a query (doesn't matter what kind) and wait for execution + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams {QueryText = "Doesn't Matter", OwnerUri = Common.OwnerUri}; + var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + + // ... And then I request to cancel the query + var cancelParams = new QueryCancelParams {OwnerUri = Common.OwnerUri}; + QueryCancelResult result = null; + var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); + queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + + // Then: + // ... I should have seen a result event with an error message + VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Messages); + + // ... The query should not have been disposed + Assert.NotEmpty(queryService.ActiveQueries); + } + + [Fact] + public void CancelNonExistantTest() + { + // If: + // ... I request to cancel a query that doesn't exist + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false); + var cancelParams = new QueryCancelParams {OwnerUri = "Doesn't Exist"}; + QueryCancelResult result = null; + var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); + queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + + // Then: + // ... I should have seen a result event with an error message + VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Messages); + } + + #region Mocking + + private static Mock> GetQueryCancelResultContextMock( + Action resultCallback, + Action errorCallback) + { + var requestContext = new Mock>(); + + // Setup the mock for SendResult + var sendResultFlow = requestContext + .Setup(rc => rc.SendResult(It.IsAny())) + .Returns(Task.FromResult(0)); + if (resultCallback != null) + { + sendResultFlow.Callback(resultCallback); + } + + // Setup the mock for SendError + var sendErrorFlow = requestContext + .Setup(rc => rc.SendError(It.IsAny())) + .Returns(Task.FromResult(0)); + if (errorCallback != null) + { + sendErrorFlow.Callback(errorCallback); + } + + return requestContext; + } + + private static void VerifyQueryCancelCallCount(Mock> mock, + Times sendResultCalls, Times sendErrorCalls) + { + mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); + mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); + } + + #endregion + + } +} From b7f88084c0dc6e6f3a111e0dec9641cae2d45217 Mon Sep 17 00:00:00 2001 From: Mitchell Sternke Date: Tue, 16 Aug 2016 12:21:42 -0700 Subject: [PATCH 05/12] Added initial tests for the connection manager's intellisense cache --- .../LanguageServices/AutoCompleteService.cs | 37 ++- .../LanguageServer/LanguageServiceTests.cs | 211 ++++++++++++++++-- 2 files changed, 232 insertions(+), 16 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs index 14148778..bae71d86 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs @@ -235,13 +235,36 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } } } + + private ConnectionService connectionService = null; + + /// + /// Internal for testing purposes only + /// + internal ConnectionService ConnectionServiceInstance + { + get + { + if(connectionService == null) + { + connectionService = ConnectionService.Instance; + } + return connectionService; + } + + set + { + connectionService = value; + } + } + public void InitializeService(ServiceHost serviceHost) { // Register a callback for when a connection is created - ConnectionService.Instance.RegisterOnConnectionTask(UpdateAutoCompleteCache); + ConnectionServiceInstance.RegisterOnConnectionTask(UpdateAutoCompleteCache); // Register a callback for when a connection is closed - ConnectionService.Instance.RegisterOnDisconnectTask(RemoveAutoCompleteCacheUriReference); + ConnectionServiceInstance.RegisterOnDisconnectTask(RemoveAutoCompleteCacheUriReference); } private async Task UpdateAutoCompleteCache(ConnectionInfo connectionInfo) @@ -252,6 +275,14 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } } + /// + /// Intellisense cache count access for testing. + /// + public int GetCacheCount() + { + return caches.Count; + } + /// /// Remove a reference to an autocomplete cache from a URI. If /// it is the last URI connected to a particular connection, @@ -312,7 +343,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices // that are not backed by a SQL connection ConnectionInfo info; IntellisenseCache cache; - if (ConnectionService.Instance.TryFindConnection(textDocumentPosition.Uri, out info) + if (ConnectionServiceInstance.TryFindConnection(textDocumentPosition.Uri, out info) && caches.TryGetValue((ConnectionSummary)info.ConnectionDetails, out cache)) { return cache.GetAutoCompleteItems(textDocumentPosition).ToArray(); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs index 80ea3ec9..ddf82059 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs @@ -3,11 +3,18 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using System.Collections.Generic; +using System.Data; +using System.Data.Common; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; using Microsoft.SqlTools.Test.Utility; +using Moq; +using Moq.Protected; using Xunit; namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices @@ -19,6 +26,29 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices { #region "Diagnostics tests" + /// + /// Verify that the latest SqlParser (2016 as of this writing) is used by default + /// + [Fact] + public void LatestSqlParserIsUsedByDefault() + { + // This should only parse correctly on SQL server 2016 or newer + const string sql2016Text = + @"CREATE SECURITY POLICY [FederatedSecurityPolicy]" + "\r\n" + + @"ADD FILTER PREDICATE [rls].[fn_securitypredicate]([CustomerId])" + "\r\n" + + @"ON [dbo].[Customer];"; + + LanguageService service = TestObjects.GetTestLanguageService(); + + // parse + var scriptFile = new ScriptFile(); + scriptFile.SetFileContents(sql2016Text); + ScriptFileMarker[] fileMarkers = service.GetSemanticMarkers(scriptFile); + + // verify that no errors are detected + Assert.Equal(0, fileMarkers.Length); + } + /// /// Verify that the SQL parser correctly detects errors in text /// @@ -108,24 +138,179 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices #region "Autocomplete Tests" /// - /// Verify that the SQL parser correctly detects errors in text + /// Creates a mock db command that returns a predefined result set + /// + public static DbCommand CreateTestCommand(Dictionary[][] data) + { + var commandMock = new Mock { CallBase = true }; + var commandMockSetup = commandMock.Protected() + .Setup("ExecuteDbDataReader", It.IsAny()); + + commandMockSetup.Returns(new TestDbDataReader(data)); + + return commandMock.Object; + } + + /// + /// Creates a mock db connection that returns predefined data when queried for a result set + /// + public DbConnection CreateMockDbConnection(Dictionary[][] data) + { + var connectionMock = new Mock { CallBase = true }; + connectionMock.Protected() + .Setup("CreateDbCommand") + .Returns(CreateTestCommand(data)); + + return connectionMock.Object; + } + + /// + /// Verify that the autocomplete service returns tables for the current connection as suggestions /// [Fact] - public async Task AutocompleteTest() + public void TablesAreReturnedAsAutocompleteSuggestions() { - // TODO Re-enable this test once we have a way to hook up the right auto-complete and connection services. - // Probably need a service provider channel so that we can mock service access. Otherwise everything accesses - // static instances and cannot be properly tested. - - //var autocompleteService = TestObjects.GetAutoCompleteService(); - //var connectionService = TestObjects.GetTestConnectionService(); + // Result set for the query of database tables + Dictionary[] data = + { + new Dictionary { {"name", "master" } }, + new Dictionary { {"name", "model" } } + }; - //ConnectParams connectionRequest = TestObjects.GetTestConnectionParams(); - //var connectionResult = connectionService.Connect(connectionRequest); + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns(CreateMockDbConnection(new[] {data})); + + var connectionService = TestObjects.GetTestConnectionService(); + var autocompleteService = new AutoCompleteService(); + autocompleteService.ConnectionServiceInstance = connectionService; + autocompleteService.InitializeService(Microsoft.SqlTools.ServiceLayer.Hosting.ServiceHost.Instance); + + autocompleteService.ConnectionFactory = mockFactory.Object; - //var sqlConnection = connectionService.ActiveConnections[connectionResult.ConnectionId]; - //await autocompleteService.UpdateAutoCompleteCache(sqlConnection); - await Task.Run(() => { return; }); + // Open a connection + // The cache should get updated as part of this + ConnectParams connectionRequest = TestObjects.GetTestConnectionParams(); + var connectionResult = connectionService.Connect(connectionRequest); + Assert.NotEmpty(connectionResult.ConnectionId); + + // Check that there is one cache created in the auto complete service + Assert.Equal(1, autocompleteService.GetCacheCount()); + + // Check that we get table suggestions for an autocomplete request + TextDocumentPosition position = new TextDocumentPosition(); + position.Uri = connectionRequest.OwnerUri; + position.Position = new Position(); + position.Position.Line = 1; + position.Position.Character = 1; + var items = autocompleteService.GetCompletionItems(position); + Assert.Equal(2, items.Length); + Assert.Equal("master", items[0].Label); + Assert.Equal("model", items[1].Label); + } + + /// + /// Verify that only one intellisense cache is created for two documents using + /// the autocomplete service when they share a common connection. + /// + [Fact] + public void OnlyOneCacheIsCreatedForTwoDocumentsWithSameConnection() + { + var connectionService = TestObjects.GetTestConnectionService(); + var autocompleteService = new AutoCompleteService(); + autocompleteService.ConnectionServiceInstance = connectionService; + autocompleteService.InitializeService(Microsoft.SqlTools.ServiceLayer.Hosting.ServiceHost.Instance); + + // Open two connections + ConnectParams connectionRequest1 = TestObjects.GetTestConnectionParams(); + connectionRequest1.OwnerUri = "file:///my/first/file.sql"; + ConnectParams connectionRequest2 = TestObjects.GetTestConnectionParams(); + connectionRequest2.OwnerUri = "file:///my/second/file.sql"; + var connectionResult1 = connectionService.Connect(connectionRequest1); + Assert.NotEmpty(connectionResult1.ConnectionId); + var connectionResult2 = connectionService.Connect(connectionRequest2); + Assert.NotEmpty(connectionResult2.ConnectionId); + + // Verify that only one intellisense cache is created to service both URI's + Assert.Equal(1, autocompleteService.GetCacheCount()); + } + + /// + /// Verify that two different intellisense caches and corresponding autocomplete + /// suggestions are provided for two documents with different connections. + /// + [Fact] + public void TwoCachesAreCreatedForTwoDocumentsWithDifferentConnections() + { + // Result set for the query of database tables + Dictionary[] data1 = + { + new Dictionary { {"name", "master" } }, + new Dictionary { {"name", "model" } } + }; + + Dictionary[] data2 = + { + new Dictionary { {"name", "master" } }, + new Dictionary { {"name", "my_table" } }, + new Dictionary { {"name", "my_other_table" } } + }; + + var mockFactory = new Mock(); + mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns(CreateMockDbConnection(new[] {data1})) + .Returns(CreateMockDbConnection(new[] {data2})); + + var connectionService = TestObjects.GetTestConnectionService(); + var autocompleteService = new AutoCompleteService(); + autocompleteService.ConnectionServiceInstance = connectionService; + autocompleteService.InitializeService(Microsoft.SqlTools.ServiceLayer.Hosting.ServiceHost.Instance); + + autocompleteService.ConnectionFactory = mockFactory.Object; + + // Open connections + // The cache should get updated as part of this + ConnectParams connectionRequest = TestObjects.GetTestConnectionParams(); + connectionRequest.OwnerUri = "file:///my/first/sql/file.sql"; + var connectionResult = connectionService.Connect(connectionRequest); + Assert.NotEmpty(connectionResult.ConnectionId); + + // Check that there is one cache created in the auto complete service + Assert.Equal(1, autocompleteService.GetCacheCount()); + + // Open second connection + ConnectParams connectionRequest2 = TestObjects.GetTestConnectionParams(); + connectionRequest2.OwnerUri = "file:///my/second/sql/file.sql"; + connectionRequest2.Connection.DatabaseName = "my_other_db"; + var connectionResult2 = connectionService.Connect(connectionRequest2); + Assert.NotEmpty(connectionResult2.ConnectionId); + + // Check that there are now two caches in the auto complete service + Assert.Equal(2, autocompleteService.GetCacheCount()); + + // Check that we get 2 different table suggestions for autocomplete requests + TextDocumentPosition position = new TextDocumentPosition(); + position.Uri = connectionRequest.OwnerUri; + position.Position = new Position(); + position.Position.Line = 1; + position.Position.Character = 1; + + var items = autocompleteService.GetCompletionItems(position); + Assert.Equal(2, items.Length); + Assert.Equal("master", items[0].Label); + Assert.Equal("model", items[1].Label); + + TextDocumentPosition position2 = new TextDocumentPosition(); + position2.Uri = connectionRequest2.OwnerUri; + position2.Position = new Position(); + position2.Position.Line = 1; + position2.Position.Character = 1; + + var items2 = autocompleteService.GetCompletionItems(position2); + Assert.Equal(3, items2.Length); + Assert.Equal("master", items2[0].Label); + Assert.Equal("my_table", items2[1].Label); + Assert.Equal("my_other_table", items2[2].Label); } #endregion From 9fa183ea6dddf48c117c08758a5c8e48988f4017 Mon Sep 17 00:00:00 2001 From: benrr101 Date: Tue, 16 Aug 2016 12:28:52 -0700 Subject: [PATCH 06/12] Fixing String.Format to string.Format --- .../Hosting/Protocol/MessageDispatcher.cs | 6 +++--- .../LanguageServices/LanguageService.cs | 2 +- .../Workspace/WorkspaceService.cs | 2 +- .../QueryExecution/Common.cs | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs index d7e5812d..c4cf5365 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs @@ -210,7 +210,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } catch (MessageParseException e) { - string message = String.Format("Exception occurred while parsing message: {0}", e.Message); + string message = string.Format("Exception occurred while parsing message: {0}", e.Message); Logger.Write(LogLevel.Error, message); await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams { @@ -228,7 +228,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol catch (Exception e) { // Log the error and send an error event to the client - string message = String.Format("Exception occurred while receiving message: {0}", e.Message); + string message = string.Format("Exception occurred while receiving message: {0}", e.Message); Logger.Write(LogLevel.Error, message); await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams { @@ -244,7 +244,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol if (newMessage != null) { // Verbose logging - string logMessage = String.Format("Received message of type[{0}] and method[{1}]", + string logMessage = string.Format("Received message of type[{0}] and method[{1}]", newMessage.MessageType, newMessage.Method); Logger.Write(LogLevel.Verbose, logMessage); diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs index 6cdbd745..f380f1bb 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs @@ -337,7 +337,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { Logger.Write( LogLevel.Error, - String.Format( + string.Format( "Exception while cancelling analysis task:\n\n{0}", e.ToString())); diff --git a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs index 701fa6f5..f47cacb9 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs @@ -182,7 +182,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace foreach (var textChange in textChangeParams.ContentChanges) { string fileUri = textChangeParams.Uri ?? textChangeParams.TextDocument.Uri; - msg.AppendLine(String.Format(" File: {0}", fileUri)); + msg.AppendLine(string.Format(" File: {0}", fileUri)); ScriptFile changedFile = Workspace.GetFile(fileUri); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index 9bc8053b..292e3264 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -36,7 +36,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Dictionary rowDictionary = new Dictionary(); for (int column = 0; column < columns; column++) { - rowDictionary.Add(String.Format("column{0}", column), String.Format("val{0}{1}", column, row)); + rowDictionary.Add(string.Format("column{0}", column), string.Format("val{0}{1}", column, row)); } output[row] = rowDictionary; } From 71d852ba076eb1ea44d797db28ddd2f50f139d9c Mon Sep 17 00:00:00 2001 From: benrr101 Date: Tue, 16 Aug 2016 12:35:39 -0700 Subject: [PATCH 07/12] Renaming the hosting error method --- .../Hosting/Contracts/HostingErrorEvent.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs index 4d221acc..d6e65801 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs @@ -13,7 +13,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts public class HostingErrorParams { /// - /// + /// The message of the error /// public string Message { get; set; } } @@ -22,7 +22,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts { public static readonly EventType Type = - EventType.Create("hostingError"); + EventType.Create("hosting/error"); } } From a6cc14d31f3c173bfe8190a43130ad7e858119e2 Mon Sep 17 00:00:00 2001 From: Mitchell Sternke Date: Tue, 16 Aug 2016 13:48:36 -0700 Subject: [PATCH 08/12] Cleaned up connection management code --- .../Connection/ConnectionInfo.cs | 38 +++++ .../Connection/ConnectionService.cs | 26 --- .../Contracts/ConnectMessagesExtensions.cs | 19 +++ .../Contracts/ConnectionSummaryComparer.cs | 53 ++++++ .../LanguageServices/AutoCompleteService.cs | 155 ------------------ .../LanguageServices/IntellisenseCache.cs | 119 ++++++++++++++ 6 files changed, 229 insertions(+), 181 deletions(-) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryComparer.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs new file mode 100644 index 00000000..7490b33c --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs @@ -0,0 +1,38 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Data.Common; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection +{ + /// + /// Information pertaining to a unique connection instance. + /// + public class ConnectionInfo + { + public ConnectionInfo(ISqlConnectionFactory factory, string ownerUri, ConnectionDetails details) + { + Factory = factory; + OwnerUri = ownerUri; + ConnectionDetails = details; + ConnectionId = Guid.NewGuid(); + } + + /// + /// Unique Id, helpful to identify a connection info object + /// + public Guid ConnectionId { get; private set; } + + public string OwnerUri { get; private set; } + + public ISqlConnectionFactory Factory {get; private set;} + + public ConnectionDetails ConnectionDetails { get; private set; } + + public DbConnection SqlConnection { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 8f430e29..8905ab92 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -5,42 +5,16 @@ using System; using System.Collections.Generic; -using System.Data.Common; using System.Data.SqlClient; using System.Threading.Tasks; using Microsoft.SqlTools.EditorServices.Utility; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; -using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Workspace; namespace Microsoft.SqlTools.ServiceLayer.Connection { - public class ConnectionInfo - { - public ConnectionInfo(ISqlConnectionFactory factory, string ownerUri, ConnectionDetails details) - { - Factory = factory; - OwnerUri = ownerUri; - ConnectionDetails = details; - ConnectionId = Guid.NewGuid(); - } - - /// - /// Unique Id, helpful to identify a connection info object - /// - public Guid ConnectionId { get; private set; } - - public string OwnerUri { get; private set; } - - public ISqlConnectionFactory Factory {get; private set;} - - public ConnectionDetails ConnectionDetails { get; private set; } - - public DbConnection SqlConnection { get; set; } - } - /// /// Main class for the Connection Management services /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessagesExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessagesExtensions.cs index b9e73e09..d15adae6 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessagesExtensions.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessagesExtensions.cs @@ -27,4 +27,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts ); } } + + /// + /// Extension methods to ConnectionSummary + /// + public static class ConnectionSummaryExtensions + { + /// + /// Create a copy of a ConnectionSummary object + /// + public static ConnectionSummary Clone(this ConnectionSummary summary) + { + return new ConnectionSummary() + { + ServerName = summary.ServerName, + DatabaseName = summary.DatabaseName, + UserName = summary.UserName + }; + } + } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryComparer.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryComparer.cs new file mode 100644 index 00000000..dfeb0ab4 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryComparer.cs @@ -0,0 +1,53 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + + /// + /// Treats connections as the same if their server, db and usernames all match + /// + public class ConnectionSummaryComparer : IEqualityComparer + { + public bool Equals(ConnectionSummary x, ConnectionSummary y) + { + if(x == y) { return true; } + else if(x != null) + { + if(y == null) { return false; } + + // Compare server, db, username. Note: server is case-insensitive in the driver + return string.Compare(x.ServerName, y.ServerName, StringComparison.OrdinalIgnoreCase) == 0 + && string.Compare(x.DatabaseName, y.DatabaseName, StringComparison.Ordinal) == 0 + && string.Compare(x.UserName, y.UserName, StringComparison.Ordinal) == 0; + } + return false; + } + + public int GetHashCode(ConnectionSummary obj) + { + int hashcode = 31; + if(obj != null) + { + if(obj.ServerName != null) + { + hashcode ^= obj.ServerName.GetHashCode(); + } + if (obj.DatabaseName != null) + { + hashcode ^= obj.DatabaseName.GetHashCode(); + } + if (obj.UserName != null) + { + hashcode ^= obj.UserName.GetHashCode(); + } + } + return hashcode; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs index 14148778..616fc2c6 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs @@ -5,8 +5,6 @@ using System; using System.Collections.Generic; -using System.Data; -using System.Data.Common; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; @@ -16,159 +14,6 @@ using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { - internal class IntellisenseCache - { - // connection used to query for intellisense info - private DbConnection connection; - - // number of documents (URI's) that are using the cache for the same database - // the autocomplete service uses this to remove unreferenced caches - public int ReferenceCount { get; set; } - - public IntellisenseCache(ISqlConnectionFactory connectionFactory, ConnectionDetails connectionDetails) - { - ReferenceCount = 0; - DatabaseInfo = CopySummary(connectionDetails); - - // TODO error handling on this. Intellisense should catch or else the service should handle - connection = connectionFactory.CreateSqlConnection(ConnectionService.BuildConnectionString(connectionDetails)); - connection.Open(); - } - - /// - /// Used to identify a database for which this cache is used - /// - public ConnectionSummary DatabaseInfo - { - get; - private set; - } - /// - /// Gets the current autocomplete candidate list - /// - public IEnumerable AutoCompleteList { get; private set; } - - public async Task UpdateCache() - { - DbCommand command = connection.CreateCommand(); - command.CommandText = "SELECT name FROM sys.tables"; - command.CommandTimeout = 15; - command.CommandType = CommandType.Text; - var reader = await command.ExecuteReaderAsync(); - - List results = new List(); - while (await reader.ReadAsync()) - { - results.Add(reader[0].ToString()); - } - - AutoCompleteList = results; - await Task.FromResult(0); - } - - public List GetAutoCompleteItems(TextDocumentPosition textDocumentPosition) - { - List completions = new List(); - - int i = 0; - - // Take a reference to the list at a point in time in case we update and replace the list - var suggestions = AutoCompleteList; - // the completion list will be null is user not connected to server - if (this.AutoCompleteList != null) - { - - foreach (var autoCompleteItem in suggestions) - { - // convert the completion item candidates into CompletionItems - completions.Add(new CompletionItem() - { - Label = autoCompleteItem, - Kind = CompletionItemKind.Keyword, - Detail = autoCompleteItem + " details", - Documentation = autoCompleteItem + " documentation", - TextEdit = new TextEdit - { - NewText = autoCompleteItem, - Range = new Range - { - Start = new Position - { - Line = textDocumentPosition.Position.Line, - Character = textDocumentPosition.Position.Character - }, - End = new Position - { - Line = textDocumentPosition.Position.Line, - Character = textDocumentPosition.Position.Character + 5 - } - } - } - }); - - // only show 50 items - if (++i == 50) - { - break; - } - } - } - - return completions; - } - - private static ConnectionSummary CopySummary(ConnectionSummary summary) - { - return new ConnectionSummary() - { - ServerName = summary.ServerName, - DatabaseName = summary.DatabaseName, - UserName = summary.UserName - }; - } - } - - /// - /// Treats connections as the same if their server, db and usernames all match - /// - public class ConnectionSummaryComparer : IEqualityComparer - { - public bool Equals(ConnectionSummary x, ConnectionSummary y) - { - if(x == y) { return true; } - else if(x != null) - { - if(y == null) { return false; } - - // Compare server, db, username. Note: server is case-insensitive in the driver - return string.Compare(x.ServerName, y.ServerName, StringComparison.OrdinalIgnoreCase) == 0 - && string.Compare(x.DatabaseName, y.DatabaseName, StringComparison.Ordinal) == 0 - && string.Compare(x.UserName, y.UserName, StringComparison.Ordinal) == 0; - } - return false; - } - - public int GetHashCode(ConnectionSummary obj) - { - int hashcode = 31; - if(obj != null) - { - if(obj.ServerName != null) - { - hashcode ^= obj.ServerName.GetHashCode(); - } - if (obj.DatabaseName != null) - { - hashcode ^= obj.DatabaseName.GetHashCode(); - } - if (obj.UserName != null) - { - hashcode ^= obj.UserName.GetHashCode(); - } - } - return hashcode; - } - } /// /// Main class for Autocomplete functionality /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs new file mode 100644 index 00000000..ca4c155a --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs @@ -0,0 +1,119 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices +{ + + internal class IntellisenseCache + { + // connection used to query for intellisense info + private DbConnection connection; + + // number of documents (URI's) that are using the cache for the same database + // the autocomplete service uses this to remove unreferenced caches + public int ReferenceCount { get; set; } + + public IntellisenseCache(ISqlConnectionFactory connectionFactory, ConnectionDetails connectionDetails) + { + ReferenceCount = 0; + DatabaseInfo = connectionDetails.Clone(); + + // TODO error handling on this. Intellisense should catch or else the service should handle + connection = connectionFactory.CreateSqlConnection(ConnectionService.BuildConnectionString(connectionDetails)); + connection.Open(); + } + + /// + /// Used to identify a database for which this cache is used + /// + public ConnectionSummary DatabaseInfo + { + get; + private set; + } + /// + /// Gets the current autocomplete candidate list + /// + public IEnumerable AutoCompleteList { get; private set; } + + public async Task UpdateCache() + { + DbCommand command = connection.CreateCommand(); + command.CommandText = "SELECT name FROM sys.tables"; + command.CommandTimeout = 15; + command.CommandType = CommandType.Text; + var reader = await command.ExecuteReaderAsync(); + + List results = new List(); + while (await reader.ReadAsync()) + { + results.Add(reader[0].ToString()); + } + + AutoCompleteList = results; + await Task.FromResult(0); + } + + public List GetAutoCompleteItems(TextDocumentPosition textDocumentPosition) + { + List completions = new List(); + + int i = 0; + + // Take a reference to the list at a point in time in case we update and replace the list + var suggestions = AutoCompleteList; + // the completion list will be null is user not connected to server + if (this.AutoCompleteList != null) + { + + foreach (var autoCompleteItem in suggestions) + { + // convert the completion item candidates into CompletionItems + completions.Add(new CompletionItem() + { + Label = autoCompleteItem, + Kind = CompletionItemKind.Keyword, + Detail = autoCompleteItem + " details", + Documentation = autoCompleteItem + " documentation", + TextEdit = new TextEdit + { + NewText = autoCompleteItem, + Range = new Range + { + Start = new Position + { + Line = textDocumentPosition.Position.Line, + Character = textDocumentPosition.Position.Character + }, + End = new Position + { + Line = textDocumentPosition.Position.Line, + Character = textDocumentPosition.Position.Character + 5 + } + } + } + }); + + // only show 50 items + if (++i == 50) + { + break; + } + } + } + + return completions; + } + } +} From 6ffdf644baf6c1a6b4b4fce744506d660f526bc0 Mon Sep 17 00:00:00 2001 From: Mitchell Sternke Date: Tue, 16 Aug 2016 15:27:26 -0700 Subject: [PATCH 09/12] Addressing code review feedback --- .../Connection/ConnectionInfo.cs | 15 ++++ .../Connection/Contracts/ConnectMessages.cs | 89 ------------------- .../Connection/Contracts/ConnectParams.cs | 26 ++++++ ...tensions.cs => ConnectParamsExtensions.cs} | 19 ---- .../Connection/Contracts/ConnectResponse.cs | 23 +++++ .../ConnectionChangedNotification.cs | 19 ++++ ...Messages.cs => ConnectionChangedParams.cs} | 13 --- .../Connection/Contracts/ConnectionDetails.cs | 21 +++++ .../Connection/Contracts/ConnectionRequest.cs | 19 ++++ .../Connection/Contracts/ConnectionSummary.cs | 28 ++++++ .../Contracts/ConnectionSummaryExtensions.cs | 26 ++++++ ...connectMessages.cs => DisconnectParams.cs} | 12 --- .../Connection/Contracts/DisconnectRequest.cs | 19 ++++ .../LanguageServices/IntellisenseCache.cs | 11 ++- 14 files changed, 203 insertions(+), 137 deletions(-) delete mode 100644 src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessages.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs rename src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/{ConnectMessagesExtensions.cs => ConnectParamsExtensions.cs} (64%) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedNotification.cs rename src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/{ConnectionChangedMessages.cs => ConnectionChangedParams.cs} (68%) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummary.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryExtensions.cs rename src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/{DisconnectMessages.cs => DisconnectParams.cs} (63%) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectRequest.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs index 7490b33c..31d0026d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs @@ -14,6 +14,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// public class ConnectionInfo { + /// + /// Constructor + /// public ConnectionInfo(ISqlConnectionFactory factory, string ownerUri, ConnectionDetails details) { Factory = factory; @@ -27,12 +30,24 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// public Guid ConnectionId { get; private set; } + /// + /// URI identifying the owner/user of the connection. Could be a file, service, resource, etc. + /// public string OwnerUri { get; private set; } + /// + /// Factory used for creating the SQL connection associated with the connection info. + /// public ISqlConnectionFactory Factory {get; private set;} + /// + /// Properties used for creating/opening the SQL connection. + /// public ConnectionDetails ConnectionDetails { get; private set; } + /// + /// The connection to the SQL database that commands will be run against. + /// public DbConnection SqlConnection { get; set; } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessages.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessages.cs deleted file mode 100644 index 543b18f5..00000000 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessages.cs +++ /dev/null @@ -1,89 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; - -namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts -{ - /// - /// Parameters for the Connect Request. - /// - public class ConnectParams - { - /// - /// A URI identifying the owner of the connection. This will most commonly be a file in the workspace - /// or a virtual file representing an object in a database. - /// - public string OwnerUri { get; set; } - /// - /// Contains the required parameters to initialize a connection to a database. - /// A connection will identified by its server name, database name and user name. - /// This may be changed in the future to support multiple connections with different - /// connection properties to the same database. - /// - public ConnectionDetails Connection { get; set; } - } - - /// - /// Message format for the connection result response - /// - public class ConnectResponse - { - /// - /// A GUID representing a unique connection ID - /// - public string ConnectionId { get; set; } - - /// - /// Gets or sets any connection error messages - /// - public string Messages { get; set; } - } - - /// - /// Provides high level information about a connection. - /// - public class ConnectionSummary - { - /// - /// Gets or sets the connection server name - /// - public string ServerName { get; set; } - - /// - /// Gets or sets the connection database name - /// - public string DatabaseName { get; set; } - - /// - /// Gets or sets the connection user name - /// - public string UserName { get; set; } - } - - /// - /// Message format for the initial connection request - /// - public class ConnectionDetails : ConnectionSummary - { - /// - /// Gets or sets the connection password - /// - /// - public string Password { get; set; } - - // TODO Handle full set of properties - } - - /// - /// Connect request mapping entry - /// - public class ConnectionRequest - { - public static readonly - RequestType Type = - RequestType.Create("connection/connect"); - } -} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs new file mode 100644 index 00000000..31dad8c5 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs @@ -0,0 +1,26 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Parameters for the Connect Request. + /// + public class ConnectParams + { + /// + /// A URI identifying the owner of the connection. This will most commonly be a file in the workspace + /// or a virtual file representing an object in a database. + /// + public string OwnerUri { get; set; } + /// + /// Contains the required parameters to initialize a connection to a database. + /// A connection will identified by its server name, database name and user name. + /// This may be changed in the future to support multiple connections with different + /// connection properties to the same database. + /// + public ConnectionDetails Connection { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessagesExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs similarity index 64% rename from src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessagesExtensions.cs rename to src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs index d15adae6..b9e73e09 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessagesExtensions.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs @@ -27,23 +27,4 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts ); } } - - /// - /// Extension methods to ConnectionSummary - /// - public static class ConnectionSummaryExtensions - { - /// - /// Create a copy of a ConnectionSummary object - /// - public static ConnectionSummary Clone(this ConnectionSummary summary) - { - return new ConnectionSummary() - { - ServerName = summary.ServerName, - DatabaseName = summary.DatabaseName, - UserName = summary.UserName - }; - } - } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs new file mode 100644 index 00000000..c325c64f --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs @@ -0,0 +1,23 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Message format for the connection result response + /// + public class ConnectResponse + { + /// + /// A GUID representing a unique connection ID + /// + public string ConnectionId { get; set; } + + /// + /// Gets or sets any connection error messages + /// + public string Messages { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedNotification.cs new file mode 100644 index 00000000..c0daee6d --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedNotification.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// ConnectionChanged notification mapping entry + /// + public class ConnectionChangedNotification + { + public static readonly + EventType Type = + EventType.Create("connection/connectionchanged"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedMessages.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedParams.cs similarity index 68% rename from src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedMessages.cs rename to src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedParams.cs index 94454bc5..3db86f34 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedMessages.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedParams.cs @@ -3,8 +3,6 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; - namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts { /// @@ -22,15 +20,4 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts /// public ConnectionSummary Connection { get; set; } } - - /// - /// ConnectionChanged notification mapping entry - /// - public class ConnectionChangedNotification - { - public static readonly - EventType Type = - EventType.Create("connection/connectionchanged"); - } - } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs new file mode 100644 index 00000000..0acac867 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs @@ -0,0 +1,21 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Message format for the initial connection request + /// + public class ConnectionDetails : ConnectionSummary + { + /// + /// Gets or sets the connection password + /// + /// + public string Password { get; set; } + + // TODO Handle full set of properties + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs new file mode 100644 index 00000000..50251e12 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Connect request mapping entry + /// + public class ConnectionRequest + { + public static readonly + RequestType Type = + RequestType.Create("connection/connect"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummary.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummary.cs new file mode 100644 index 00000000..11549e85 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummary.cs @@ -0,0 +1,28 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Provides high level information about a connection. + /// + public class ConnectionSummary + { + /// + /// Gets or sets the connection server name + /// + public string ServerName { get; set; } + + /// + /// Gets or sets the connection database name + /// + public string DatabaseName { get; set; } + + /// + /// Gets or sets the connection user name + /// + public string UserName { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryExtensions.cs new file mode 100644 index 00000000..02bc7623 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryExtensions.cs @@ -0,0 +1,26 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Extension methods to ConnectionSummary + /// + public static class ConnectionSummaryExtensions + { + /// + /// Create a copy of a ConnectionSummary object + /// + public static ConnectionSummary Clone(this ConnectionSummary summary) + { + return new ConnectionSummary() + { + ServerName = summary.ServerName, + DatabaseName = summary.DatabaseName, + UserName = summary.UserName + }; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectMessages.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs similarity index 63% rename from src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectMessages.cs rename to src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs index c078b308..91bc7faf 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectMessages.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs @@ -3,8 +3,6 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; - namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts { /// @@ -18,14 +16,4 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts /// public string OwnerUri { get; set; } } - - /// - /// Disconnect request mapping entry - /// - public class DisconnectRequest - { - public static readonly - RequestType Type = - RequestType.Create("connection/disconnect"); - } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectRequest.cs new file mode 100644 index 00000000..cbf67ef2 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectRequest.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Disconnect request mapping entry + /// + public class DisconnectRequest + { + public static readonly + RequestType Type = + RequestType.Create("connection/disconnect"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs index ca4c155a..eea72771 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs @@ -14,14 +14,17 @@ using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { - internal class IntellisenseCache { - // connection used to query for intellisense info + /// + /// connection used to query for intellisense info + /// private DbConnection connection; - // number of documents (URI's) that are using the cache for the same database - // the autocomplete service uses this to remove unreferenced caches + /// + /// Number of documents (URI's) that are using the cache for the same database. + /// The autocomplete service uses this to remove unreferenced caches. + /// public int ReferenceCount { get; set; } public IntellisenseCache(ISqlConnectionFactory connectionFactory, ConnectionDetails connectionDetails) From c80a90331df771801907d3fa5eb35d30dc26217a Mon Sep 17 00:00:00 2001 From: Mitchell Sternke Date: Tue, 16 Aug 2016 12:21:42 -0700 Subject: [PATCH 10/12] Added initial tests for the connection manager's intellisense cache --- .../LanguageServices/AutoCompleteService.cs | 37 ++- .../LanguageServer/LanguageServiceTests.cs | 211 ++++++++++++++++-- 2 files changed, 232 insertions(+), 16 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs index 616fc2c6..f8feae6a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs @@ -80,13 +80,36 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } } } + + private ConnectionService connectionService = null; + + /// + /// Internal for testing purposes only + /// + internal ConnectionService ConnectionServiceInstance + { + get + { + if(connectionService == null) + { + connectionService = ConnectionService.Instance; + } + return connectionService; + } + + set + { + connectionService = value; + } + } + public void InitializeService(ServiceHost serviceHost) { // Register a callback for when a connection is created - ConnectionService.Instance.RegisterOnConnectionTask(UpdateAutoCompleteCache); + ConnectionServiceInstance.RegisterOnConnectionTask(UpdateAutoCompleteCache); // Register a callback for when a connection is closed - ConnectionService.Instance.RegisterOnDisconnectTask(RemoveAutoCompleteCacheUriReference); + ConnectionServiceInstance.RegisterOnDisconnectTask(RemoveAutoCompleteCacheUriReference); } private async Task UpdateAutoCompleteCache(ConnectionInfo connectionInfo) @@ -97,6 +120,14 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } } + /// + /// Intellisense cache count access for testing. + /// + public int GetCacheCount() + { + return caches.Count; + } + /// /// Remove a reference to an autocomplete cache from a URI. If /// it is the last URI connected to a particular connection, @@ -157,7 +188,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices // that are not backed by a SQL connection ConnectionInfo info; IntellisenseCache cache; - if (ConnectionService.Instance.TryFindConnection(textDocumentPosition.Uri, out info) + if (ConnectionServiceInstance.TryFindConnection(textDocumentPosition.Uri, out info) && caches.TryGetValue((ConnectionSummary)info.ConnectionDetails, out cache)) { return cache.GetAutoCompleteItems(textDocumentPosition).ToArray(); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs index 80ea3ec9..ddf82059 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs @@ -3,11 +3,18 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using System.Collections.Generic; +using System.Data; +using System.Data.Common; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; using Microsoft.SqlTools.Test.Utility; +using Moq; +using Moq.Protected; using Xunit; namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices @@ -19,6 +26,29 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices { #region "Diagnostics tests" + /// + /// Verify that the latest SqlParser (2016 as of this writing) is used by default + /// + [Fact] + public void LatestSqlParserIsUsedByDefault() + { + // This should only parse correctly on SQL server 2016 or newer + const string sql2016Text = + @"CREATE SECURITY POLICY [FederatedSecurityPolicy]" + "\r\n" + + @"ADD FILTER PREDICATE [rls].[fn_securitypredicate]([CustomerId])" + "\r\n" + + @"ON [dbo].[Customer];"; + + LanguageService service = TestObjects.GetTestLanguageService(); + + // parse + var scriptFile = new ScriptFile(); + scriptFile.SetFileContents(sql2016Text); + ScriptFileMarker[] fileMarkers = service.GetSemanticMarkers(scriptFile); + + // verify that no errors are detected + Assert.Equal(0, fileMarkers.Length); + } + /// /// Verify that the SQL parser correctly detects errors in text /// @@ -108,24 +138,179 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices #region "Autocomplete Tests" /// - /// Verify that the SQL parser correctly detects errors in text + /// Creates a mock db command that returns a predefined result set + /// + public static DbCommand CreateTestCommand(Dictionary[][] data) + { + var commandMock = new Mock { CallBase = true }; + var commandMockSetup = commandMock.Protected() + .Setup("ExecuteDbDataReader", It.IsAny()); + + commandMockSetup.Returns(new TestDbDataReader(data)); + + return commandMock.Object; + } + + /// + /// Creates a mock db connection that returns predefined data when queried for a result set + /// + public DbConnection CreateMockDbConnection(Dictionary[][] data) + { + var connectionMock = new Mock { CallBase = true }; + connectionMock.Protected() + .Setup("CreateDbCommand") + .Returns(CreateTestCommand(data)); + + return connectionMock.Object; + } + + /// + /// Verify that the autocomplete service returns tables for the current connection as suggestions /// [Fact] - public async Task AutocompleteTest() + public void TablesAreReturnedAsAutocompleteSuggestions() { - // TODO Re-enable this test once we have a way to hook up the right auto-complete and connection services. - // Probably need a service provider channel so that we can mock service access. Otherwise everything accesses - // static instances and cannot be properly tested. - - //var autocompleteService = TestObjects.GetAutoCompleteService(); - //var connectionService = TestObjects.GetTestConnectionService(); + // Result set for the query of database tables + Dictionary[] data = + { + new Dictionary { {"name", "master" } }, + new Dictionary { {"name", "model" } } + }; - //ConnectParams connectionRequest = TestObjects.GetTestConnectionParams(); - //var connectionResult = connectionService.Connect(connectionRequest); + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns(CreateMockDbConnection(new[] {data})); + + var connectionService = TestObjects.GetTestConnectionService(); + var autocompleteService = new AutoCompleteService(); + autocompleteService.ConnectionServiceInstance = connectionService; + autocompleteService.InitializeService(Microsoft.SqlTools.ServiceLayer.Hosting.ServiceHost.Instance); + + autocompleteService.ConnectionFactory = mockFactory.Object; - //var sqlConnection = connectionService.ActiveConnections[connectionResult.ConnectionId]; - //await autocompleteService.UpdateAutoCompleteCache(sqlConnection); - await Task.Run(() => { return; }); + // Open a connection + // The cache should get updated as part of this + ConnectParams connectionRequest = TestObjects.GetTestConnectionParams(); + var connectionResult = connectionService.Connect(connectionRequest); + Assert.NotEmpty(connectionResult.ConnectionId); + + // Check that there is one cache created in the auto complete service + Assert.Equal(1, autocompleteService.GetCacheCount()); + + // Check that we get table suggestions for an autocomplete request + TextDocumentPosition position = new TextDocumentPosition(); + position.Uri = connectionRequest.OwnerUri; + position.Position = new Position(); + position.Position.Line = 1; + position.Position.Character = 1; + var items = autocompleteService.GetCompletionItems(position); + Assert.Equal(2, items.Length); + Assert.Equal("master", items[0].Label); + Assert.Equal("model", items[1].Label); + } + + /// + /// Verify that only one intellisense cache is created for two documents using + /// the autocomplete service when they share a common connection. + /// + [Fact] + public void OnlyOneCacheIsCreatedForTwoDocumentsWithSameConnection() + { + var connectionService = TestObjects.GetTestConnectionService(); + var autocompleteService = new AutoCompleteService(); + autocompleteService.ConnectionServiceInstance = connectionService; + autocompleteService.InitializeService(Microsoft.SqlTools.ServiceLayer.Hosting.ServiceHost.Instance); + + // Open two connections + ConnectParams connectionRequest1 = TestObjects.GetTestConnectionParams(); + connectionRequest1.OwnerUri = "file:///my/first/file.sql"; + ConnectParams connectionRequest2 = TestObjects.GetTestConnectionParams(); + connectionRequest2.OwnerUri = "file:///my/second/file.sql"; + var connectionResult1 = connectionService.Connect(connectionRequest1); + Assert.NotEmpty(connectionResult1.ConnectionId); + var connectionResult2 = connectionService.Connect(connectionRequest2); + Assert.NotEmpty(connectionResult2.ConnectionId); + + // Verify that only one intellisense cache is created to service both URI's + Assert.Equal(1, autocompleteService.GetCacheCount()); + } + + /// + /// Verify that two different intellisense caches and corresponding autocomplete + /// suggestions are provided for two documents with different connections. + /// + [Fact] + public void TwoCachesAreCreatedForTwoDocumentsWithDifferentConnections() + { + // Result set for the query of database tables + Dictionary[] data1 = + { + new Dictionary { {"name", "master" } }, + new Dictionary { {"name", "model" } } + }; + + Dictionary[] data2 = + { + new Dictionary { {"name", "master" } }, + new Dictionary { {"name", "my_table" } }, + new Dictionary { {"name", "my_other_table" } } + }; + + var mockFactory = new Mock(); + mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns(CreateMockDbConnection(new[] {data1})) + .Returns(CreateMockDbConnection(new[] {data2})); + + var connectionService = TestObjects.GetTestConnectionService(); + var autocompleteService = new AutoCompleteService(); + autocompleteService.ConnectionServiceInstance = connectionService; + autocompleteService.InitializeService(Microsoft.SqlTools.ServiceLayer.Hosting.ServiceHost.Instance); + + autocompleteService.ConnectionFactory = mockFactory.Object; + + // Open connections + // The cache should get updated as part of this + ConnectParams connectionRequest = TestObjects.GetTestConnectionParams(); + connectionRequest.OwnerUri = "file:///my/first/sql/file.sql"; + var connectionResult = connectionService.Connect(connectionRequest); + Assert.NotEmpty(connectionResult.ConnectionId); + + // Check that there is one cache created in the auto complete service + Assert.Equal(1, autocompleteService.GetCacheCount()); + + // Open second connection + ConnectParams connectionRequest2 = TestObjects.GetTestConnectionParams(); + connectionRequest2.OwnerUri = "file:///my/second/sql/file.sql"; + connectionRequest2.Connection.DatabaseName = "my_other_db"; + var connectionResult2 = connectionService.Connect(connectionRequest2); + Assert.NotEmpty(connectionResult2.ConnectionId); + + // Check that there are now two caches in the auto complete service + Assert.Equal(2, autocompleteService.GetCacheCount()); + + // Check that we get 2 different table suggestions for autocomplete requests + TextDocumentPosition position = new TextDocumentPosition(); + position.Uri = connectionRequest.OwnerUri; + position.Position = new Position(); + position.Position.Line = 1; + position.Position.Character = 1; + + var items = autocompleteService.GetCompletionItems(position); + Assert.Equal(2, items.Length); + Assert.Equal("master", items[0].Label); + Assert.Equal("model", items[1].Label); + + TextDocumentPosition position2 = new TextDocumentPosition(); + position2.Uri = connectionRequest2.OwnerUri; + position2.Position = new Position(); + position2.Position.Line = 1; + position2.Position.Character = 1; + + var items2 = autocompleteService.GetCompletionItems(position2); + Assert.Equal(3, items2.Length); + Assert.Equal("master", items2[0].Label); + Assert.Equal("my_table", items2[1].Label); + Assert.Equal("my_other_table", items2[2].Label); } #endregion From e33df61dc33a5bb68b45ddc596bbd0e6448d4024 Mon Sep 17 00:00:00 2001 From: Mitchell Sternke Date: Tue, 16 Aug 2016 15:53:00 -0700 Subject: [PATCH 11/12] Addressing code review feedback --- .../LanguageServices/AutoCompleteService.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs index f8feae6a..a390eae2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs @@ -123,7 +123,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices /// /// Intellisense cache count access for testing. /// - public int GetCacheCount() + internal int GetCacheCount() { return caches.Count; } From 709123eaafcd02322bff84bbed895fe6e4a2dac0 Mon Sep 17 00:00:00 2001 From: benrr101 Date: Tue, 16 Aug 2016 16:14:27 -0700 Subject: [PATCH 12/12] Final iteration, fixing a couple mistakes for query exceptions --- src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs | 2 -- .../QueryExecution/QueryExecutionService.cs | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs index ec25fd6c..d9a886d4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -177,12 +177,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { HasError = true; UnwrapDbException(dbe); - conn?.Dispose(); } catch (Exception) { HasError = true; - conn?.Dispose(); throw; } finally diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index 942beef3..5b26eafc 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -213,6 +213,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } catch (InvalidOperationException e) { + // If this exception occurred, we most likely were trying to cancel a completed query await requestContext.SendResult(new QueryCancelResult { Messages = e.Message