diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs new file mode 100644 index 00000000..31d0026d --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs @@ -0,0 +1,53 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Data.Common; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection +{ + /// + /// Information pertaining to a unique connection instance. + /// + public class ConnectionInfo + { + /// + /// Constructor + /// + public ConnectionInfo(ISqlConnectionFactory factory, string ownerUri, ConnectionDetails details) + { + Factory = factory; + OwnerUri = ownerUri; + ConnectionDetails = details; + ConnectionId = Guid.NewGuid(); + } + + /// + /// Unique Id, helpful to identify a connection info object + /// + public Guid ConnectionId { get; private set; } + + /// + /// URI identifying the owner/user of the connection. Could be a file, service, resource, etc. + /// + public string OwnerUri { get; private set; } + + /// + /// Factory used for creating the SQL connection associated with the connection info. + /// + public ISqlConnectionFactory Factory {get; private set;} + + /// + /// Properties used for creating/opening the SQL connection. + /// + public ConnectionDetails ConnectionDetails { get; private set; } + + /// + /// The connection to the SQL database that commands will be run against. + /// + public DbConnection SqlConnection { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 8f430e29..8905ab92 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -5,42 +5,16 @@ using System; using System.Collections.Generic; -using System.Data.Common; using System.Data.SqlClient; using System.Threading.Tasks; using Microsoft.SqlTools.EditorServices.Utility; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; -using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Workspace; namespace Microsoft.SqlTools.ServiceLayer.Connection { - public class ConnectionInfo - { - public ConnectionInfo(ISqlConnectionFactory factory, string ownerUri, ConnectionDetails details) - { - Factory = factory; - OwnerUri = ownerUri; - ConnectionDetails = details; - ConnectionId = Guid.NewGuid(); - } - - /// - /// Unique Id, helpful to identify a connection info object - /// - public Guid ConnectionId { get; private set; } - - public string OwnerUri { get; private set; } - - public ISqlConnectionFactory Factory {get; private set;} - - public ConnectionDetails ConnectionDetails { get; private set; } - - public DbConnection SqlConnection { get; set; } - } - /// /// Main class for the Connection Management services /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessages.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessages.cs deleted file mode 100644 index 543b18f5..00000000 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessages.cs +++ /dev/null @@ -1,89 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; - -namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts -{ - /// - /// Parameters for the Connect Request. - /// - public class ConnectParams - { - /// - /// A URI identifying the owner of the connection. This will most commonly be a file in the workspace - /// or a virtual file representing an object in a database. - /// - public string OwnerUri { get; set; } - /// - /// Contains the required parameters to initialize a connection to a database. - /// A connection will identified by its server name, database name and user name. - /// This may be changed in the future to support multiple connections with different - /// connection properties to the same database. - /// - public ConnectionDetails Connection { get; set; } - } - - /// - /// Message format for the connection result response - /// - public class ConnectResponse - { - /// - /// A GUID representing a unique connection ID - /// - public string ConnectionId { get; set; } - - /// - /// Gets or sets any connection error messages - /// - public string Messages { get; set; } - } - - /// - /// Provides high level information about a connection. - /// - public class ConnectionSummary - { - /// - /// Gets or sets the connection server name - /// - public string ServerName { get; set; } - - /// - /// Gets or sets the connection database name - /// - public string DatabaseName { get; set; } - - /// - /// Gets or sets the connection user name - /// - public string UserName { get; set; } - } - - /// - /// Message format for the initial connection request - /// - public class ConnectionDetails : ConnectionSummary - { - /// - /// Gets or sets the connection password - /// - /// - public string Password { get; set; } - - // TODO Handle full set of properties - } - - /// - /// Connect request mapping entry - /// - public class ConnectionRequest - { - public static readonly - RequestType Type = - RequestType.Create("connection/connect"); - } -} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs new file mode 100644 index 00000000..31dad8c5 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs @@ -0,0 +1,26 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Parameters for the Connect Request. + /// + public class ConnectParams + { + /// + /// A URI identifying the owner of the connection. This will most commonly be a file in the workspace + /// or a virtual file representing an object in a database. + /// + public string OwnerUri { get; set; } + /// + /// Contains the required parameters to initialize a connection to a database. + /// A connection will identified by its server name, database name and user name. + /// This may be changed in the future to support multiple connections with different + /// connection properties to the same database. + /// + public ConnectionDetails Connection { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessagesExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs similarity index 100% rename from src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessagesExtensions.cs rename to src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs new file mode 100644 index 00000000..c325c64f --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs @@ -0,0 +1,23 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Message format for the connection result response + /// + public class ConnectResponse + { + /// + /// A GUID representing a unique connection ID + /// + public string ConnectionId { get; set; } + + /// + /// Gets or sets any connection error messages + /// + public string Messages { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedNotification.cs new file mode 100644 index 00000000..c0daee6d --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedNotification.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// ConnectionChanged notification mapping entry + /// + public class ConnectionChangedNotification + { + public static readonly + EventType Type = + EventType.Create("connection/connectionchanged"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedMessages.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedParams.cs similarity index 68% rename from src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedMessages.cs rename to src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedParams.cs index 94454bc5..3db86f34 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedMessages.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedParams.cs @@ -3,8 +3,6 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; - namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts { /// @@ -22,15 +20,4 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts /// public ConnectionSummary Connection { get; set; } } - - /// - /// ConnectionChanged notification mapping entry - /// - public class ConnectionChangedNotification - { - public static readonly - EventType Type = - EventType.Create("connection/connectionchanged"); - } - } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs new file mode 100644 index 00000000..0acac867 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs @@ -0,0 +1,21 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Message format for the initial connection request + /// + public class ConnectionDetails : ConnectionSummary + { + /// + /// Gets or sets the connection password + /// + /// + public string Password { get; set; } + + // TODO Handle full set of properties + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs new file mode 100644 index 00000000..50251e12 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Connect request mapping entry + /// + public class ConnectionRequest + { + public static readonly + RequestType Type = + RequestType.Create("connection/connect"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummary.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummary.cs new file mode 100644 index 00000000..11549e85 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummary.cs @@ -0,0 +1,28 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Provides high level information about a connection. + /// + public class ConnectionSummary + { + /// + /// Gets or sets the connection server name + /// + public string ServerName { get; set; } + + /// + /// Gets or sets the connection database name + /// + public string DatabaseName { get; set; } + + /// + /// Gets or sets the connection user name + /// + public string UserName { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryComparer.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryComparer.cs new file mode 100644 index 00000000..dfeb0ab4 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryComparer.cs @@ -0,0 +1,53 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + + /// + /// Treats connections as the same if their server, db and usernames all match + /// + public class ConnectionSummaryComparer : IEqualityComparer + { + public bool Equals(ConnectionSummary x, ConnectionSummary y) + { + if(x == y) { return true; } + else if(x != null) + { + if(y == null) { return false; } + + // Compare server, db, username. Note: server is case-insensitive in the driver + return string.Compare(x.ServerName, y.ServerName, StringComparison.OrdinalIgnoreCase) == 0 + && string.Compare(x.DatabaseName, y.DatabaseName, StringComparison.Ordinal) == 0 + && string.Compare(x.UserName, y.UserName, StringComparison.Ordinal) == 0; + } + return false; + } + + public int GetHashCode(ConnectionSummary obj) + { + int hashcode = 31; + if(obj != null) + { + if(obj.ServerName != null) + { + hashcode ^= obj.ServerName.GetHashCode(); + } + if (obj.DatabaseName != null) + { + hashcode ^= obj.DatabaseName.GetHashCode(); + } + if (obj.UserName != null) + { + hashcode ^= obj.UserName.GetHashCode(); + } + } + return hashcode; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryExtensions.cs new file mode 100644 index 00000000..02bc7623 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryExtensions.cs @@ -0,0 +1,26 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Extension methods to ConnectionSummary + /// + public static class ConnectionSummaryExtensions + { + /// + /// Create a copy of a ConnectionSummary object + /// + public static ConnectionSummary Clone(this ConnectionSummary summary) + { + return new ConnectionSummary() + { + ServerName = summary.ServerName, + DatabaseName = summary.DatabaseName, + UserName = summary.UserName + }; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectMessages.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs similarity index 63% rename from src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectMessages.cs rename to src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs index c078b308..91bc7faf 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectMessages.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs @@ -3,8 +3,6 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; - namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts { /// @@ -18,14 +16,4 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts /// public string OwnerUri { get; set; } } - - /// - /// Disconnect request mapping entry - /// - public class DisconnectRequest - { - public static readonly - RequestType Type = - RequestType.Create("connection/disconnect"); - } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectRequest.cs new file mode 100644 index 00000000..cbf67ef2 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectRequest.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Disconnect request mapping entry + /// + public class DisconnectRequest + { + public static readonly + RequestType Type = + RequestType.Create("connection/disconnect"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs new file mode 100644 index 00000000..d6e65801 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs @@ -0,0 +1,28 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts +{ + /// + /// Parameters to be used for reporting hosting-level errors, such as protocol violations + /// + public class HostingErrorParams + { + /// + /// The message of the error + /// + public string Message { get; set; } + } + + public class HostingErrorEvent + { + public static readonly + EventType Type = + EventType.Create("hosting/error"); + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs index a18fa806..c4cf5365 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs @@ -8,6 +8,7 @@ using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Contracts; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; using Microsoft.SqlTools.EditorServices.Utility; @@ -198,10 +199,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol this.SynchronizationContext = SynchronizationContext.Current; // Run the message loop - bool isRunning = true; - while (isRunning && !cancellationToken.IsCancellationRequested) + while (!cancellationToken.IsCancellationRequested) { - Message newMessage = null; + Message newMessage; try { @@ -210,12 +210,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } catch (MessageParseException e) { - // TODO: Write an error response - - Logger.Write( - LogLevel.Error, - "Could not parse a message that was received:\r\n\r\n" + - e.ToString()); + string message = string.Format("Exception occurred while parsing message: {0}", e.Message); + Logger.Write(LogLevel.Error, message); + await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams + { + Message = message + }); // Continue the loop continue; @@ -227,18 +227,29 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } catch (Exception e) { - var b = e.Message; - newMessage = null; + // Log the error and send an error event to the client + string message = string.Format("Exception occurred while receiving message: {0}", e.Message); + Logger.Write(LogLevel.Error, message); + await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams + { + Message = message + }); + + // Continue the loop + continue; } // The message could be null if there was an error parsing the // previous message. In this case, do not try to dispatch it. if (newMessage != null) { + // Verbose logging + string logMessage = string.Format("Received message of type[{0}] and method[{1}]", + newMessage.MessageType, newMessage.Method); + Logger.Write(LogLevel.Verbose, logMessage); + // Process the message - await this.DispatchMessage( - newMessage, - this.MessageWriter); + await this.DispatchMessage(newMessage, this.MessageWriter); } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs index f3857710..17d4b5e0 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs @@ -25,22 +25,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol private const int CR = 0x0D; private const int LF = 0x0A; - private static string[] NewLineDelimiters = new string[] { Environment.NewLine }; + private static readonly string[] NewLineDelimiters = { Environment.NewLine }; - private Stream inputStream; - private IMessageSerializer messageSerializer; - private Encoding messageEncoding; + private readonly Stream inputStream; + private readonly IMessageSerializer messageSerializer; + private readonly Encoding messageEncoding; private ReadState readState; private bool needsMoreData = true; private int readOffset; private int bufferEndOffset; - private byte[] messageBuffer = new byte[DefaultBufferSize]; + private byte[] messageBuffer; private int expectedContentLength; private Dictionary messageHeaders; - enum ReadState + private enum ReadState { Headers, Content @@ -85,7 +85,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol this.needsMoreData = false; // Do we need to look for message headers? - if (this.readState == ReadState.Headers && + if (this.readState == ReadState.Headers && !this.TryReadMessageHeaders()) { // If we don't have enough data to read headers yet, keep reading @@ -94,7 +94,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } // Do we need to look for message content? - if (this.readState == ReadState.Content && + if (this.readState == ReadState.Content && !this.TryReadMessageContent(out messageContent)) { // If we don't have enough data yet to construct the content, keep reading @@ -106,16 +106,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol break; } + // Now that we have a message, reset the buffer's state + ShiftBufferBytesAndShrink(readOffset); + // Get the JObject for the JSON content JObject messageObject = JObject.Parse(messageContent); - // Load the message - Logger.Write( - LogLevel.Verbose, - string.Format( - "READ MESSAGE:\r\n\r\n{0}", - messageObject.ToString(Formatting.Indented))); - // Return the parsed message return this.messageSerializer.DeserializeMessage(messageObject); } @@ -162,8 +158,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { int scanOffset = this.readOffset; - // Scan for the final double-newline that marks the - // end of the header lines + // Scan for the final double-newline that marks the end of the header lines while (scanOffset + 3 < this.bufferEndOffset && (this.messageBuffer[scanOffset] != CR || this.messageBuffer[scanOffset + 1] != LF || @@ -173,45 +168,51 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol scanOffset++; } - // No header or body separator found (e.g CRLFCRLF) + // Make sure we haven't reached the end of the buffer without finding a separator (e.g CRLFCRLF) if (scanOffset + 3 >= this.bufferEndOffset) { return false; } - this.messageHeaders = new Dictionary(); + // Convert the header block into a array of lines + var headers = Encoding.ASCII.GetString(this.messageBuffer, this.readOffset, scanOffset) + .Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries); - var headers = - Encoding.ASCII - .GetString(this.messageBuffer, this.readOffset, scanOffset) - .Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries); - - // Read each header and store it in the dictionary - foreach (var header in headers) + try { - int currentLength = header.IndexOf(':'); - if (currentLength == -1) + // Read each header and store it in the dictionary + this.messageHeaders = new Dictionary(); + foreach (var header in headers) { - throw new ArgumentException("Message header must separate key and value using :"); + int currentLength = header.IndexOf(':'); + if (currentLength == -1) + { + throw new ArgumentException("Message header must separate key and value using :"); + } + + var key = header.Substring(0, currentLength); + var value = header.Substring(currentLength + 1).Trim(); + this.messageHeaders[key] = value; } - var key = header.Substring(0, currentLength); - var value = header.Substring(currentLength + 1).Trim(); - this.messageHeaders[key] = value; - } + // Parse out the content length as an int + string contentLengthString; + if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString)) + { + throw new MessageParseException("", "Fatal error: Content-Length header must be provided."); + } - // Make sure a Content-Length header was present, otherwise it - // is a fatal error - string contentLengthString = null; - if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString)) - { - throw new MessageParseException("", "Fatal error: Content-Length header must be provided."); + // Parse the content length to an integer + if (!int.TryParse(contentLengthString, out this.expectedContentLength)) + { + throw new MessageParseException("", "Fatal error: Content-Length value is not an integer."); + } } - - // Parse the content length to an integer - if (!int.TryParse(contentLengthString, out this.expectedContentLength)) + catch (Exception) { - throw new MessageParseException("", "Fatal error: Content-Length value is not an integer."); + // The content length was invalid or missing. Trash the buffer we've read + ShiftBufferBytesAndShrink(scanOffset + 4); + throw; } // Skip past the headers plus the newline characters @@ -234,31 +235,40 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol } // Convert the message contents to a string using the specified encoding - messageContent = - this.messageEncoding.GetString( - this.messageBuffer, - this.readOffset, - this.expectedContentLength); + messageContent = this.messageEncoding.GetString( + this.messageBuffer, + this.readOffset, + this.expectedContentLength); - // Move the remaining bytes to the front of the buffer for the next message - var remainingByteCount = this.bufferEndOffset - (this.expectedContentLength + this.readOffset); - Buffer.BlockCopy( - this.messageBuffer, - this.expectedContentLength + this.readOffset, - this.messageBuffer, - 0, - remainingByteCount); + readOffset += expectedContentLength; - // Reset the offsets for the next read - this.readOffset = 0; - this.bufferEndOffset = remainingByteCount; - - // Done reading content, now look for headers + // Done reading content, now look for headers for the next message this.readState = ReadState.Headers; return true; } + private void ShiftBufferBytesAndShrink(int bytesToRemove) + { + // Create a new buffer that is shrunken by the number of bytes to remove + // Note: by using Max, we can guarantee a buffer of at least default buffer size + byte[] newBuffer = new byte[Math.Max(messageBuffer.Length - bytesToRemove, DefaultBufferSize)]; + + // If we need to do shifting, do the shifting + if (bytesToRemove <= messageBuffer.Length) + { + // Copy the existing buffer starting at the offset to remove + Buffer.BlockCopy(messageBuffer, bytesToRemove, newBuffer, 0, bufferEndOffset - bytesToRemove); + } + + // Make the new buffer the message buffer + messageBuffer = newBuffer; + + // Reset the read offset and the end offset + readOffset = 0; + bufferEndOffset -= bytesToRemove; + } + #endregion } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs index 14148778..a390eae2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs @@ -5,8 +5,6 @@ using System; using System.Collections.Generic; -using System.Data; -using System.Data.Common; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; @@ -16,159 +14,6 @@ using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { - internal class IntellisenseCache - { - // connection used to query for intellisense info - private DbConnection connection; - - // number of documents (URI's) that are using the cache for the same database - // the autocomplete service uses this to remove unreferenced caches - public int ReferenceCount { get; set; } - - public IntellisenseCache(ISqlConnectionFactory connectionFactory, ConnectionDetails connectionDetails) - { - ReferenceCount = 0; - DatabaseInfo = CopySummary(connectionDetails); - - // TODO error handling on this. Intellisense should catch or else the service should handle - connection = connectionFactory.CreateSqlConnection(ConnectionService.BuildConnectionString(connectionDetails)); - connection.Open(); - } - - /// - /// Used to identify a database for which this cache is used - /// - public ConnectionSummary DatabaseInfo - { - get; - private set; - } - /// - /// Gets the current autocomplete candidate list - /// - public IEnumerable AutoCompleteList { get; private set; } - - public async Task UpdateCache() - { - DbCommand command = connection.CreateCommand(); - command.CommandText = "SELECT name FROM sys.tables"; - command.CommandTimeout = 15; - command.CommandType = CommandType.Text; - var reader = await command.ExecuteReaderAsync(); - - List results = new List(); - while (await reader.ReadAsync()) - { - results.Add(reader[0].ToString()); - } - - AutoCompleteList = results; - await Task.FromResult(0); - } - - public List GetAutoCompleteItems(TextDocumentPosition textDocumentPosition) - { - List completions = new List(); - - int i = 0; - - // Take a reference to the list at a point in time in case we update and replace the list - var suggestions = AutoCompleteList; - // the completion list will be null is user not connected to server - if (this.AutoCompleteList != null) - { - - foreach (var autoCompleteItem in suggestions) - { - // convert the completion item candidates into CompletionItems - completions.Add(new CompletionItem() - { - Label = autoCompleteItem, - Kind = CompletionItemKind.Keyword, - Detail = autoCompleteItem + " details", - Documentation = autoCompleteItem + " documentation", - TextEdit = new TextEdit - { - NewText = autoCompleteItem, - Range = new Range - { - Start = new Position - { - Line = textDocumentPosition.Position.Line, - Character = textDocumentPosition.Position.Character - }, - End = new Position - { - Line = textDocumentPosition.Position.Line, - Character = textDocumentPosition.Position.Character + 5 - } - } - } - }); - - // only show 50 items - if (++i == 50) - { - break; - } - } - } - - return completions; - } - - private static ConnectionSummary CopySummary(ConnectionSummary summary) - { - return new ConnectionSummary() - { - ServerName = summary.ServerName, - DatabaseName = summary.DatabaseName, - UserName = summary.UserName - }; - } - } - - /// - /// Treats connections as the same if their server, db and usernames all match - /// - public class ConnectionSummaryComparer : IEqualityComparer - { - public bool Equals(ConnectionSummary x, ConnectionSummary y) - { - if(x == y) { return true; } - else if(x != null) - { - if(y == null) { return false; } - - // Compare server, db, username. Note: server is case-insensitive in the driver - return string.Compare(x.ServerName, y.ServerName, StringComparison.OrdinalIgnoreCase) == 0 - && string.Compare(x.DatabaseName, y.DatabaseName, StringComparison.Ordinal) == 0 - && string.Compare(x.UserName, y.UserName, StringComparison.Ordinal) == 0; - } - return false; - } - - public int GetHashCode(ConnectionSummary obj) - { - int hashcode = 31; - if(obj != null) - { - if(obj.ServerName != null) - { - hashcode ^= obj.ServerName.GetHashCode(); - } - if (obj.DatabaseName != null) - { - hashcode ^= obj.DatabaseName.GetHashCode(); - } - if (obj.UserName != null) - { - hashcode ^= obj.UserName.GetHashCode(); - } - } - return hashcode; - } - } /// /// Main class for Autocomplete functionality /// @@ -235,13 +80,36 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } } } + + private ConnectionService connectionService = null; + + /// + /// Internal for testing purposes only + /// + internal ConnectionService ConnectionServiceInstance + { + get + { + if(connectionService == null) + { + connectionService = ConnectionService.Instance; + } + return connectionService; + } + + set + { + connectionService = value; + } + } + public void InitializeService(ServiceHost serviceHost) { // Register a callback for when a connection is created - ConnectionService.Instance.RegisterOnConnectionTask(UpdateAutoCompleteCache); + ConnectionServiceInstance.RegisterOnConnectionTask(UpdateAutoCompleteCache); // Register a callback for when a connection is closed - ConnectionService.Instance.RegisterOnDisconnectTask(RemoveAutoCompleteCacheUriReference); + ConnectionServiceInstance.RegisterOnDisconnectTask(RemoveAutoCompleteCacheUriReference); } private async Task UpdateAutoCompleteCache(ConnectionInfo connectionInfo) @@ -252,6 +120,14 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } } + /// + /// Intellisense cache count access for testing. + /// + internal int GetCacheCount() + { + return caches.Count; + } + /// /// Remove a reference to an autocomplete cache from a URI. If /// it is the last URI connected to a particular connection, @@ -312,7 +188,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices // that are not backed by a SQL connection ConnectionInfo info; IntellisenseCache cache; - if (ConnectionService.Instance.TryFindConnection(textDocumentPosition.Uri, out info) + if (ConnectionServiceInstance.TryFindConnection(textDocumentPosition.Uri, out info) && caches.TryGetValue((ConnectionSummary)info.ConnectionDetails, out cache)) { return cache.GetAutoCompleteItems(textDocumentPosition).ToArray(); diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs new file mode 100644 index 00000000..eea72771 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs @@ -0,0 +1,122 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.LanguageServices +{ + internal class IntellisenseCache + { + /// + /// connection used to query for intellisense info + /// + private DbConnection connection; + + /// + /// Number of documents (URI's) that are using the cache for the same database. + /// The autocomplete service uses this to remove unreferenced caches. + /// + public int ReferenceCount { get; set; } + + public IntellisenseCache(ISqlConnectionFactory connectionFactory, ConnectionDetails connectionDetails) + { + ReferenceCount = 0; + DatabaseInfo = connectionDetails.Clone(); + + // TODO error handling on this. Intellisense should catch or else the service should handle + connection = connectionFactory.CreateSqlConnection(ConnectionService.BuildConnectionString(connectionDetails)); + connection.Open(); + } + + /// + /// Used to identify a database for which this cache is used + /// + public ConnectionSummary DatabaseInfo + { + get; + private set; + } + /// + /// Gets the current autocomplete candidate list + /// + public IEnumerable AutoCompleteList { get; private set; } + + public async Task UpdateCache() + { + DbCommand command = connection.CreateCommand(); + command.CommandText = "SELECT name FROM sys.tables"; + command.CommandTimeout = 15; + command.CommandType = CommandType.Text; + var reader = await command.ExecuteReaderAsync(); + + List results = new List(); + while (await reader.ReadAsync()) + { + results.Add(reader[0].ToString()); + } + + AutoCompleteList = results; + await Task.FromResult(0); + } + + public List GetAutoCompleteItems(TextDocumentPosition textDocumentPosition) + { + List completions = new List(); + + int i = 0; + + // Take a reference to the list at a point in time in case we update and replace the list + var suggestions = AutoCompleteList; + // the completion list will be null is user not connected to server + if (this.AutoCompleteList != null) + { + + foreach (var autoCompleteItem in suggestions) + { + // convert the completion item candidates into CompletionItems + completions.Add(new CompletionItem() + { + Label = autoCompleteItem, + Kind = CompletionItemKind.Keyword, + Detail = autoCompleteItem + " details", + Documentation = autoCompleteItem + " documentation", + TextEdit = new TextEdit + { + NewText = autoCompleteItem, + Range = new Range + { + Start = new Position + { + Line = textDocumentPosition.Position.Line, + Character = textDocumentPosition.Position.Character + }, + End = new Position + { + Line = textDocumentPosition.Position.Line, + Character = textDocumentPosition.Position.Character + 5 + } + } + } + }); + + // only show 50 items + if (++i == 50) + { + break; + } + } + } + + return completions; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs index 6cdbd745..f380f1bb 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs @@ -337,7 +337,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { Logger.Write( LogLevel.Error, - String.Format( + string.Format( "Exception while cancelling analysis task:\n\n{0}", e.ToString())); diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs new file mode 100644 index 00000000..3eb87f4f --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs @@ -0,0 +1,36 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters for the query cancellation request + /// + public class QueryCancelParams + { + public string OwnerUri { get; set; } + } + + /// + /// Parameters to return as the result of a query dispose request + /// + public class QueryCancelResult + { + /// + /// Any error messages that occurred during disposing the result set. Optional, can be set + /// to null if there were no errors. + /// + public string Messages { get; set; } + } + + public class QueryCancelRequest + { + public static readonly + RequestType Type = + RequestType.Create("query/cancel"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs index 887bbbaf..d9a886d4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -177,12 +177,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { HasError = true; UnwrapDbException(dbe); - conn?.Dispose(); } catch (Exception) { HasError = true; - conn?.Dispose(); throw; } finally @@ -233,6 +231,21 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution }; } + /// + /// Cancels the query by issuing the cancellation token + /// + public void Cancel() + { + // Make sure that the query hasn't completed execution + if (HasExecuted) + { + throw new InvalidOperationException("The query has already completed, it cannot be cancelled."); + } + + // Issue the cancellation token for the query + cancellationSource.Cancel(); + } + /// /// Delegate handler for storing messages that are returned from the server /// NOTE: Only messages that are below a certain severity will be returned via this diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index 389be092..5b26eafc 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -73,6 +73,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution serviceHost.SetRequestHandler(QueryExecuteRequest.Type, HandleExecuteRequest); serviceHost.SetRequestHandler(QueryExecuteSubsetRequest.Type, HandleResultSubsetRequest); serviceHost.SetRequestHandler(QueryDisposeRequest.Type, HandleDisposeRequest); + serviceHost.SetRequestHandler(QueryCancelRequest.Type, HandleCancelRequest); // Register handler for shutdown event serviceHost.RegisterShutdownTask((shutdownParams, requestContext) => @@ -178,6 +179,52 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } } + public async Task HandleCancelRequest(QueryCancelParams cancelParams, + RequestContext requestContext) + { + try + { + // Attempt to find the query for the owner uri + Query result; + if (!ActiveQueries.TryGetValue(cancelParams.OwnerUri, out result)) + { + await requestContext.SendResult(new QueryCancelResult + { + Messages = "Failed to cancel query, ID not found." + }); + return; + } + + // Cancel the query + result.Cancel(); + + // Attempt to dispose the query + if (!ActiveQueries.TryRemove(cancelParams.OwnerUri, out result)) + { + // It really shouldn't be possible to get to this scenario, but we'll cover it anyhow + await requestContext.SendResult(new QueryCancelResult + { + Messages = "Query successfully cancelled, failed to dispose query. ID not found." + }); + return; + } + + await requestContext.SendResult(new QueryCancelResult()); + } + catch (InvalidOperationException e) + { + // If this exception occurred, we most likely were trying to cancel a completed query + await requestContext.SendResult(new QueryCancelResult + { + Messages = e.Message + }); + } + catch (Exception e) + { + await requestContext.SendError(e.Message); + } + } + #endregion #region Private Helpers diff --git a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs index 701fa6f5..f47cacb9 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs @@ -182,7 +182,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace foreach (var textChange in textChangeParams.ContentChanges) { string fileUri = textChangeParams.Uri ?? textChangeParams.TextDocument.Uri; - msg.AppendLine(String.Format(" File: {0}", fileUri)); + msg.AppendLine(string.Format(" File: {0}", fileUri)); ScriptFile changedFile = Workspace.GetFile(fileUri); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs index 80ea3ec9..ddf82059 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs @@ -3,11 +3,18 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using System.Collections.Generic; +using System.Data; +using System.Data.Common; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; using Microsoft.SqlTools.Test.Utility; +using Moq; +using Moq.Protected; using Xunit; namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices @@ -19,6 +26,29 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices { #region "Diagnostics tests" + /// + /// Verify that the latest SqlParser (2016 as of this writing) is used by default + /// + [Fact] + public void LatestSqlParserIsUsedByDefault() + { + // This should only parse correctly on SQL server 2016 or newer + const string sql2016Text = + @"CREATE SECURITY POLICY [FederatedSecurityPolicy]" + "\r\n" + + @"ADD FILTER PREDICATE [rls].[fn_securitypredicate]([CustomerId])" + "\r\n" + + @"ON [dbo].[Customer];"; + + LanguageService service = TestObjects.GetTestLanguageService(); + + // parse + var scriptFile = new ScriptFile(); + scriptFile.SetFileContents(sql2016Text); + ScriptFileMarker[] fileMarkers = service.GetSemanticMarkers(scriptFile); + + // verify that no errors are detected + Assert.Equal(0, fileMarkers.Length); + } + /// /// Verify that the SQL parser correctly detects errors in text /// @@ -108,24 +138,179 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices #region "Autocomplete Tests" /// - /// Verify that the SQL parser correctly detects errors in text + /// Creates a mock db command that returns a predefined result set + /// + public static DbCommand CreateTestCommand(Dictionary[][] data) + { + var commandMock = new Mock { CallBase = true }; + var commandMockSetup = commandMock.Protected() + .Setup("ExecuteDbDataReader", It.IsAny()); + + commandMockSetup.Returns(new TestDbDataReader(data)); + + return commandMock.Object; + } + + /// + /// Creates a mock db connection that returns predefined data when queried for a result set + /// + public DbConnection CreateMockDbConnection(Dictionary[][] data) + { + var connectionMock = new Mock { CallBase = true }; + connectionMock.Protected() + .Setup("CreateDbCommand") + .Returns(CreateTestCommand(data)); + + return connectionMock.Object; + } + + /// + /// Verify that the autocomplete service returns tables for the current connection as suggestions /// [Fact] - public async Task AutocompleteTest() + public void TablesAreReturnedAsAutocompleteSuggestions() { - // TODO Re-enable this test once we have a way to hook up the right auto-complete and connection services. - // Probably need a service provider channel so that we can mock service access. Otherwise everything accesses - // static instances and cannot be properly tested. - - //var autocompleteService = TestObjects.GetAutoCompleteService(); - //var connectionService = TestObjects.GetTestConnectionService(); + // Result set for the query of database tables + Dictionary[] data = + { + new Dictionary { {"name", "master" } }, + new Dictionary { {"name", "model" } } + }; - //ConnectParams connectionRequest = TestObjects.GetTestConnectionParams(); - //var connectionResult = connectionService.Connect(connectionRequest); + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns(CreateMockDbConnection(new[] {data})); + + var connectionService = TestObjects.GetTestConnectionService(); + var autocompleteService = new AutoCompleteService(); + autocompleteService.ConnectionServiceInstance = connectionService; + autocompleteService.InitializeService(Microsoft.SqlTools.ServiceLayer.Hosting.ServiceHost.Instance); + + autocompleteService.ConnectionFactory = mockFactory.Object; - //var sqlConnection = connectionService.ActiveConnections[connectionResult.ConnectionId]; - //await autocompleteService.UpdateAutoCompleteCache(sqlConnection); - await Task.Run(() => { return; }); + // Open a connection + // The cache should get updated as part of this + ConnectParams connectionRequest = TestObjects.GetTestConnectionParams(); + var connectionResult = connectionService.Connect(connectionRequest); + Assert.NotEmpty(connectionResult.ConnectionId); + + // Check that there is one cache created in the auto complete service + Assert.Equal(1, autocompleteService.GetCacheCount()); + + // Check that we get table suggestions for an autocomplete request + TextDocumentPosition position = new TextDocumentPosition(); + position.Uri = connectionRequest.OwnerUri; + position.Position = new Position(); + position.Position.Line = 1; + position.Position.Character = 1; + var items = autocompleteService.GetCompletionItems(position); + Assert.Equal(2, items.Length); + Assert.Equal("master", items[0].Label); + Assert.Equal("model", items[1].Label); + } + + /// + /// Verify that only one intellisense cache is created for two documents using + /// the autocomplete service when they share a common connection. + /// + [Fact] + public void OnlyOneCacheIsCreatedForTwoDocumentsWithSameConnection() + { + var connectionService = TestObjects.GetTestConnectionService(); + var autocompleteService = new AutoCompleteService(); + autocompleteService.ConnectionServiceInstance = connectionService; + autocompleteService.InitializeService(Microsoft.SqlTools.ServiceLayer.Hosting.ServiceHost.Instance); + + // Open two connections + ConnectParams connectionRequest1 = TestObjects.GetTestConnectionParams(); + connectionRequest1.OwnerUri = "file:///my/first/file.sql"; + ConnectParams connectionRequest2 = TestObjects.GetTestConnectionParams(); + connectionRequest2.OwnerUri = "file:///my/second/file.sql"; + var connectionResult1 = connectionService.Connect(connectionRequest1); + Assert.NotEmpty(connectionResult1.ConnectionId); + var connectionResult2 = connectionService.Connect(connectionRequest2); + Assert.NotEmpty(connectionResult2.ConnectionId); + + // Verify that only one intellisense cache is created to service both URI's + Assert.Equal(1, autocompleteService.GetCacheCount()); + } + + /// + /// Verify that two different intellisense caches and corresponding autocomplete + /// suggestions are provided for two documents with different connections. + /// + [Fact] + public void TwoCachesAreCreatedForTwoDocumentsWithDifferentConnections() + { + // Result set for the query of database tables + Dictionary[] data1 = + { + new Dictionary { {"name", "master" } }, + new Dictionary { {"name", "model" } } + }; + + Dictionary[] data2 = + { + new Dictionary { {"name", "master" } }, + new Dictionary { {"name", "my_table" } }, + new Dictionary { {"name", "my_other_table" } } + }; + + var mockFactory = new Mock(); + mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns(CreateMockDbConnection(new[] {data1})) + .Returns(CreateMockDbConnection(new[] {data2})); + + var connectionService = TestObjects.GetTestConnectionService(); + var autocompleteService = new AutoCompleteService(); + autocompleteService.ConnectionServiceInstance = connectionService; + autocompleteService.InitializeService(Microsoft.SqlTools.ServiceLayer.Hosting.ServiceHost.Instance); + + autocompleteService.ConnectionFactory = mockFactory.Object; + + // Open connections + // The cache should get updated as part of this + ConnectParams connectionRequest = TestObjects.GetTestConnectionParams(); + connectionRequest.OwnerUri = "file:///my/first/sql/file.sql"; + var connectionResult = connectionService.Connect(connectionRequest); + Assert.NotEmpty(connectionResult.ConnectionId); + + // Check that there is one cache created in the auto complete service + Assert.Equal(1, autocompleteService.GetCacheCount()); + + // Open second connection + ConnectParams connectionRequest2 = TestObjects.GetTestConnectionParams(); + connectionRequest2.OwnerUri = "file:///my/second/sql/file.sql"; + connectionRequest2.Connection.DatabaseName = "my_other_db"; + var connectionResult2 = connectionService.Connect(connectionRequest2); + Assert.NotEmpty(connectionResult2.ConnectionId); + + // Check that there are now two caches in the auto complete service + Assert.Equal(2, autocompleteService.GetCacheCount()); + + // Check that we get 2 different table suggestions for autocomplete requests + TextDocumentPosition position = new TextDocumentPosition(); + position.Uri = connectionRequest.OwnerUri; + position.Position = new Position(); + position.Position.Line = 1; + position.Position.Character = 1; + + var items = autocompleteService.GetCompletionItems(position); + Assert.Equal(2, items.Length); + Assert.Equal("master", items[0].Label); + Assert.Equal("model", items[1].Label); + + TextDocumentPosition position2 = new TextDocumentPosition(); + position2.Uri = connectionRequest2.OwnerUri; + position2.Position = new Position(); + position2.Position.Line = 1; + position2.Position.Character = 1; + + var items2 = autocompleteService.GetCompletionItems(position2); + Assert.Equal(3, items2.Length); + Assert.Equal("master", items2[0].Label); + Assert.Equal("my_table", items2[1].Label); + Assert.Equal("my_other_table", items2[2].Label); } #endregion diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs deleted file mode 100644 index 54fbf01f..00000000 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs +++ /dev/null @@ -1,178 +0,0 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using System; -using System.IO; -using System.Text; -using System.Threading.Tasks; -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; -using HostingMessage = Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts.Message; -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; -using Xunit; - -namespace Microsoft.SqlTools.ServiceLayer.Test.Message -{ - public class MessageReaderWriterTests - { - const string TestEventString = "{\"type\":\"event\",\"event\":\"testEvent\",\"body\":null}"; - const string TestEventFormatString = "{{\"event\":\"testEvent\",\"body\":{{\"someString\":\"{0}\"}},\"seq\":0,\"type\":\"event\"}}"; - readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString); - - private IMessageSerializer messageSerializer; - - public MessageReaderWriterTests() - { - this.messageSerializer = new V8MessageSerializer(); - } - - [Fact] - public async Task WritesMessage() - { - MemoryStream outputStream = new MemoryStream(); - - MessageWriter messageWriter = - new MessageWriter( - outputStream, - this.messageSerializer); - - // Write the message and then roll back the stream to be read - // TODO: This will need to be redone! - await messageWriter.WriteMessage(HostingMessage.Event("testEvent", null)); - outputStream.Seek(0, SeekOrigin.Begin); - - string expectedHeaderString = - string.Format( - Constants.ContentLengthFormatString, - ExpectedMessageByteCount); - - byte[] buffer = new byte[128]; - await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length); - - Assert.Equal( - expectedHeaderString, - Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length)); - - // Read the message - await outputStream.ReadAsync(buffer, 0, ExpectedMessageByteCount); - - Assert.Equal( - TestEventString, - Encoding.UTF8.GetString(buffer, 0, ExpectedMessageByteCount)); - - outputStream.Dispose(); - } - - [Fact] - public void ReadsMessage() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Write a message to the stream - byte[] messageBuffer = this.GetMessageBytes(TestEventString); - inputStream.Write( - this.GetMessageBytes(TestEventString), - 0, - messageBuffer.Length); - - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - HostingMessage messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - - inputStream.Dispose(); - } - - [Fact] - public void ReadsManyBufferedMessages() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Get a message to use for writing to the stream - byte[] messageBuffer = this.GetMessageBytes(TestEventString); - - // How many messages of this size should we write to overflow the buffer? - int overflowMessageCount = - (int)Math.Ceiling( - (MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length); - - // Write the necessary number of messages to the stream - for (int i = 0; i < overflowMessageCount; i++) - { - inputStream.Write(messageBuffer, 0, messageBuffer.Length); - } - - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - // Read the written messages from the stream - for (int i = 0; i < overflowMessageCount; i++) - { - HostingMessage messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - } - - inputStream.Dispose(); - } - - [Fact] - public void ReaderResizesBufferForLargeMessages() - { - MemoryStream inputStream = new MemoryStream(); - MessageReader messageReader = - new MessageReader( - inputStream, - this.messageSerializer); - - // Get a message with content so large that the buffer will need - // to be resized to fit it all. - byte[] messageBuffer = - this.GetMessageBytes( - string.Format( - TestEventFormatString, - new String('X', (int)(MessageReader.DefaultBufferSize * 3)))); - - inputStream.Write(messageBuffer, 0, messageBuffer.Length); - inputStream.Flush(); - inputStream.Seek(0, SeekOrigin.Begin); - - HostingMessage messageResult = messageReader.ReadMessage().Result; - Assert.Equal("testEvent", messageResult.Method); - - inputStream.Dispose(); - } - - private byte[] GetMessageBytes(string messageString, Encoding encoding = null) - { - if (encoding == null) - { - encoding = Encoding.UTF8; - } - - byte[] messageBytes = Encoding.UTF8.GetBytes(messageString); - byte[] headerBytes = - Encoding.ASCII.GetBytes( - string.Format( - Constants.ContentLengthFormatString, - messageBytes.Length)); - - // Copy the bytes into a single buffer - byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length]; - Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length); - Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length); - - return finalBytes; - } - } -} - diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs new file mode 100644 index 00000000..b575fe0c --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs @@ -0,0 +1,17 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Text; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class Common + { + public const string TestEventString = @"{""type"":""event"",""event"":""testEvent"",""body"":null}"; + public const string TestEventFormatString = @"{{""event"":""testEvent"",""body"":{{""someString"":""{0}""}},""seq"":0,""type"":""event""}}"; + public static readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString); + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs new file mode 100644 index 00000000..0a12dc3e --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs @@ -0,0 +1,241 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.IO; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; +using Newtonsoft.Json; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class MessageReaderTests + { + + private readonly IMessageSerializer messageSerializer; + + public MessageReaderTests() + { + this.messageSerializer = new V8MessageSerializer(); + } + + [Fact] + public void ReadsMessage() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = new MessageReader(inputStream, this.messageSerializer); + + // Write a message to the stream + byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString); + inputStream.Write(this.GetMessageBytes(Common.TestEventString), 0, messageBuffer.Length); + + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + + inputStream.Dispose(); + } + + [Fact] + public void ReadsManyBufferedMessages() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = + new MessageReader( + inputStream, + this.messageSerializer); + + // Get a message to use for writing to the stream + byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString); + + // How many messages of this size should we write to overflow the buffer? + int overflowMessageCount = + (int)Math.Ceiling( + (MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length); + + // Write the necessary number of messages to the stream + for (int i = 0; i < overflowMessageCount; i++) + { + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + } + + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Read the written messages from the stream + for (int i = 0; i < overflowMessageCount; i++) + { + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + } + + inputStream.Dispose(); + } + + [Fact] + public void ReadMalformedMissingHeaderTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it information that is malformed + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedContentLengthNonIntegerTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it a non-integer content-length header + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: asdf\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMissingContentLengthHeaderTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... I create a new stream and pass it a a message without a content-length header + // ... and attempt to read a message from it + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Type: asdf\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedContentLengthTooShortTest() + { + using (MemoryStream inputStream = new MemoryStream()) + { + // If: + // ... Pass in an event that has an incorrect content length + // ... And pass in an event that is correct + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: 10\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = Encoding.UTF8.GetBytes(Common.TestEventString); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = Encoding.ASCII.GetBytes("\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... The first read should fail with an exception while deserializing + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + + // ... The second read should fail with an exception while reading headers + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + } + } + + [Fact] + public void ReadMalformedThenValidTest() + { + // If: + // ... I create a new stream and pass it information that is malformed + // ... and attempt to read a message from it + // ... Then pass it information that is valid and attempt to read a message from it + using (MemoryStream inputStream = new MemoryStream()) + { + MessageReader messageReader = new MessageReader(inputStream, messageSerializer); + byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n"); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + messageBuffer = GetMessageBytes(Common.TestEventString); + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + // Then: + // ... An exception should be thrown while reading the first one + Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait(); + + // ... A test event should be successfully read from the second one + Message messageResult = messageReader.ReadMessage().Result; + Assert.NotNull(messageResult); + Assert.Equal("testEvent", messageResult.Method); + } + } + + [Fact] + public void ReaderResizesBufferForLargeMessages() + { + MemoryStream inputStream = new MemoryStream(); + MessageReader messageReader = + new MessageReader( + inputStream, + this.messageSerializer); + + // Get a message with content so large that the buffer will need + // to be resized to fit it all. + byte[] messageBuffer = this.GetMessageBytes( + string.Format( + Common.TestEventFormatString, + new String('X', (int) (MessageReader.DefaultBufferSize*3)))); + + inputStream.Write(messageBuffer, 0, messageBuffer.Length); + inputStream.Flush(); + inputStream.Seek(0, SeekOrigin.Begin); + + Message messageResult = messageReader.ReadMessage().Result; + Assert.Equal("testEvent", messageResult.Method); + + inputStream.Dispose(); + } + + private byte[] GetMessageBytes(string messageString, Encoding encoding = null) + { + if (encoding == null) + { + encoding = Encoding.UTF8; + } + + byte[] messageBytes = Encoding.UTF8.GetBytes(messageString); + byte[] headerBytes = Encoding.ASCII.GetBytes(string.Format(Constants.ContentLengthFormatString, messageBytes.Length)); + + // Copy the bytes into a single buffer + byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length]; + Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length); + Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length); + + return finalBytes; + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs new file mode 100644 index 00000000..3c007a85 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs @@ -0,0 +1,55 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.IO; +using System.Text; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging +{ + public class MessageWriterTests + { + private readonly IMessageSerializer messageSerializer; + + public MessageWriterTests() + { + this.messageSerializer = new V8MessageSerializer(); + } + + [Fact] + public async Task WritesMessage() + { + MemoryStream outputStream = new MemoryStream(); + MessageWriter messageWriter = new MessageWriter(outputStream, this.messageSerializer); + + // Write the message and then roll back the stream to be read + // TODO: This will need to be redone! + await messageWriter.WriteMessage(Hosting.Protocol.Contracts.Message.Event("testEvent", null)); + outputStream.Seek(0, SeekOrigin.Begin); + + string expectedHeaderString = string.Format(Constants.ContentLengthFormatString, + Common.ExpectedMessageByteCount); + + byte[] buffer = new byte[128]; + await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length); + + Assert.Equal( + expectedHeaderString, + Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length)); + + // Read the message + await outputStream.ReadAsync(buffer, 0, Common.ExpectedMessageByteCount); + + Assert.Equal(Common.TestEventString, + Encoding.UTF8.GetString(buffer, 0, Common.ExpectedMessageByteCount)); + + outputStream.Dispose(); + } + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/TestMessageTypes.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs similarity index 94% rename from test/Microsoft.SqlTools.ServiceLayer.Test/Message/TestMessageTypes.cs rename to test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs index 0ba08056..89238098 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/TestMessageTypes.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs @@ -6,7 +6,7 @@ using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; -namespace Microsoft.SqlTools.ServiceLayer.Test.Message +namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging { #region Request Types diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs new file mode 100644 index 00000000..dbc344f8 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs @@ -0,0 +1,124 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution +{ + public class CancelTests + { + [Fact] + public void CancelInProgressQueryTest() + { + // If: + // ... I request a query (doesn't matter what kind) and execute it + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams { QueryText = "Doesn't Matter", OwnerUri = Common.OwnerUri }; + var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Fake that it hasn't completed execution + + // ... And then I request to cancel the query + var cancelParams = new QueryCancelParams {OwnerUri = Common.OwnerUri}; + QueryCancelResult result = null; + var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); + queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + + // Then: + // ... I should have seen a successful event (no messages) + VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never()); + Assert.Null(result.Messages); + + // ... The query should have been disposed as well + Assert.Empty(queryService.ActiveQueries); + } + + [Fact] + public void CancelExecutedQueryTest() + { + // If: + // ... I request a query (doesn't matter what kind) and wait for execution + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams {QueryText = "Doesn't Matter", OwnerUri = Common.OwnerUri}; + var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + + // ... And then I request to cancel the query + var cancelParams = new QueryCancelParams {OwnerUri = Common.OwnerUri}; + QueryCancelResult result = null; + var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); + queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + + // Then: + // ... I should have seen a result event with an error message + VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Messages); + + // ... The query should not have been disposed + Assert.NotEmpty(queryService.ActiveQueries); + } + + [Fact] + public void CancelNonExistantTest() + { + // If: + // ... I request to cancel a query that doesn't exist + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false); + var cancelParams = new QueryCancelParams {OwnerUri = "Doesn't Exist"}; + QueryCancelResult result = null; + var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); + queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + + // Then: + // ... I should have seen a result event with an error message + VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Messages); + } + + #region Mocking + + private static Mock> GetQueryCancelResultContextMock( + Action resultCallback, + Action errorCallback) + { + var requestContext = new Mock>(); + + // Setup the mock for SendResult + var sendResultFlow = requestContext + .Setup(rc => rc.SendResult(It.IsAny())) + .Returns(Task.FromResult(0)); + if (resultCallback != null) + { + sendResultFlow.Callback(resultCallback); + } + + // Setup the mock for SendError + var sendErrorFlow = requestContext + .Setup(rc => rc.SendError(It.IsAny())) + .Returns(Task.FromResult(0)); + if (errorCallback != null) + { + sendErrorFlow.Callback(errorCallback); + } + + return requestContext; + } + + private static void VerifyQueryCancelCallCount(Mock> mock, + Times sendResultCalls, Times sendErrorCalls) + { + mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); + mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); + } + + #endregion + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index 49b329a0..f887b50f 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -37,7 +37,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Dictionary rowDictionary = new Dictionary(); for (int column = 0; column < columns; column++) { - rowDictionary.Add(String.Format("column{0}", column), String.Format("val{0}{1}", column, row)); + rowDictionary.Add(string.Format("column{0}", column), string.Format("val{0}{1}", column, row)); } output[row] = rowDictionary; }