diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs
new file mode 100644
index 00000000..31d0026d
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs
@@ -0,0 +1,53 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Data.Common;
+using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection
+{
+ ///
+ /// Information pertaining to a unique connection instance.
+ ///
+ public class ConnectionInfo
+ {
+ ///
+ /// Constructor
+ ///
+ public ConnectionInfo(ISqlConnectionFactory factory, string ownerUri, ConnectionDetails details)
+ {
+ Factory = factory;
+ OwnerUri = ownerUri;
+ ConnectionDetails = details;
+ ConnectionId = Guid.NewGuid();
+ }
+
+ ///
+ /// Unique Id, helpful to identify a connection info object
+ ///
+ public Guid ConnectionId { get; private set; }
+
+ ///
+ /// URI identifying the owner/user of the connection. Could be a file, service, resource, etc.
+ ///
+ public string OwnerUri { get; private set; }
+
+ ///
+ /// Factory used for creating the SQL connection associated with the connection info.
+ ///
+ public ISqlConnectionFactory Factory {get; private set;}
+
+ ///
+ /// Properties used for creating/opening the SQL connection.
+ ///
+ public ConnectionDetails ConnectionDetails { get; private set; }
+
+ ///
+ /// The connection to the SQL database that commands will be run against.
+ ///
+ public DbConnection SqlConnection { get; set; }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
index 8f430e29..8905ab92 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
@@ -5,42 +5,16 @@
using System;
using System.Collections.Generic;
-using System.Data.Common;
using System.Data.SqlClient;
using System.Threading.Tasks;
using Microsoft.SqlTools.EditorServices.Utility;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
-using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Workspace;
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
- public class ConnectionInfo
- {
- public ConnectionInfo(ISqlConnectionFactory factory, string ownerUri, ConnectionDetails details)
- {
- Factory = factory;
- OwnerUri = ownerUri;
- ConnectionDetails = details;
- ConnectionId = Guid.NewGuid();
- }
-
- ///
- /// Unique Id, helpful to identify a connection info object
- ///
- public Guid ConnectionId { get; private set; }
-
- public string OwnerUri { get; private set; }
-
- public ISqlConnectionFactory Factory {get; private set;}
-
- public ConnectionDetails ConnectionDetails { get; private set; }
-
- public DbConnection SqlConnection { get; set; }
- }
-
///
/// Main class for the Connection Management services
///
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessages.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessages.cs
deleted file mode 100644
index 543b18f5..00000000
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessages.cs
+++ /dev/null
@@ -1,89 +0,0 @@
-//
-// Copyright (c) Microsoft. All rights reserved.
-// Licensed under the MIT license. See LICENSE file in the project root for full license information.
-//
-
-using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
-
-namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
-{
- ///
- /// Parameters for the Connect Request.
- ///
- public class ConnectParams
- {
- ///
- /// A URI identifying the owner of the connection. This will most commonly be a file in the workspace
- /// or a virtual file representing an object in a database.
- ///
- public string OwnerUri { get; set; }
- ///
- /// Contains the required parameters to initialize a connection to a database.
- /// A connection will identified by its server name, database name and user name.
- /// This may be changed in the future to support multiple connections with different
- /// connection properties to the same database.
- ///
- public ConnectionDetails Connection { get; set; }
- }
-
- ///
- /// Message format for the connection result response
- ///
- public class ConnectResponse
- {
- ///
- /// A GUID representing a unique connection ID
- ///
- public string ConnectionId { get; set; }
-
- ///
- /// Gets or sets any connection error messages
- ///
- public string Messages { get; set; }
- }
-
- ///
- /// Provides high level information about a connection.
- ///
- public class ConnectionSummary
- {
- ///
- /// Gets or sets the connection server name
- ///
- public string ServerName { get; set; }
-
- ///
- /// Gets or sets the connection database name
- ///
- public string DatabaseName { get; set; }
-
- ///
- /// Gets or sets the connection user name
- ///
- public string UserName { get; set; }
- }
-
- ///
- /// Message format for the initial connection request
- ///
- public class ConnectionDetails : ConnectionSummary
- {
- ///
- /// Gets or sets the connection password
- ///
- ///
- public string Password { get; set; }
-
- // TODO Handle full set of properties
- }
-
- ///
- /// Connect request mapping entry
- ///
- public class ConnectionRequest
- {
- public static readonly
- RequestType Type =
- RequestType.Create("connection/connect");
- }
-}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs
new file mode 100644
index 00000000..31dad8c5
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs
@@ -0,0 +1,26 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
+{
+ ///
+ /// Parameters for the Connect Request.
+ ///
+ public class ConnectParams
+ {
+ ///
+ /// A URI identifying the owner of the connection. This will most commonly be a file in the workspace
+ /// or a virtual file representing an object in a database.
+ ///
+ public string OwnerUri { get; set; }
+ ///
+ /// Contains the required parameters to initialize a connection to a database.
+ /// A connection will identified by its server name, database name and user name.
+ /// This may be changed in the future to support multiple connections with different
+ /// connection properties to the same database.
+ ///
+ public ConnectionDetails Connection { get; set; }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessagesExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs
similarity index 100%
rename from src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectMessagesExtensions.cs
rename to src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs
new file mode 100644
index 00000000..c325c64f
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs
@@ -0,0 +1,23 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
+{
+ ///
+ /// Message format for the connection result response
+ ///
+ public class ConnectResponse
+ {
+ ///
+ /// A GUID representing a unique connection ID
+ ///
+ public string ConnectionId { get; set; }
+
+ ///
+ /// Gets or sets any connection error messages
+ ///
+ public string Messages { get; set; }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedNotification.cs
new file mode 100644
index 00000000..c0daee6d
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedNotification.cs
@@ -0,0 +1,19 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
+{
+ ///
+ /// ConnectionChanged notification mapping entry
+ ///
+ public class ConnectionChangedNotification
+ {
+ public static readonly
+ EventType Type =
+ EventType.Create("connection/connectionchanged");
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedMessages.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedParams.cs
similarity index 68%
rename from src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedMessages.cs
rename to src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedParams.cs
index 94454bc5..3db86f34 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedMessages.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionChangedParams.cs
@@ -3,8 +3,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
-using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
-
namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
{
///
@@ -22,15 +20,4 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
///
public ConnectionSummary Connection { get; set; }
}
-
- ///
- /// ConnectionChanged notification mapping entry
- ///
- public class ConnectionChangedNotification
- {
- public static readonly
- EventType Type =
- EventType.Create("connection/connectionchanged");
- }
-
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs
new file mode 100644
index 00000000..0acac867
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs
@@ -0,0 +1,21 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
+{
+ ///
+ /// Message format for the initial connection request
+ ///
+ public class ConnectionDetails : ConnectionSummary
+ {
+ ///
+ /// Gets or sets the connection password
+ ///
+ ///
+ public string Password { get; set; }
+
+ // TODO Handle full set of properties
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs
new file mode 100644
index 00000000..50251e12
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs
@@ -0,0 +1,19 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
+{
+ ///
+ /// Connect request mapping entry
+ ///
+ public class ConnectionRequest
+ {
+ public static readonly
+ RequestType Type =
+ RequestType.Create("connection/connect");
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummary.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummary.cs
new file mode 100644
index 00000000..11549e85
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummary.cs
@@ -0,0 +1,28 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
+{
+ ///
+ /// Provides high level information about a connection.
+ ///
+ public class ConnectionSummary
+ {
+ ///
+ /// Gets or sets the connection server name
+ ///
+ public string ServerName { get; set; }
+
+ ///
+ /// Gets or sets the connection database name
+ ///
+ public string DatabaseName { get; set; }
+
+ ///
+ /// Gets or sets the connection user name
+ ///
+ public string UserName { get; set; }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryComparer.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryComparer.cs
new file mode 100644
index 00000000..dfeb0ab4
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryComparer.cs
@@ -0,0 +1,53 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Collections.Generic;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
+{
+
+ ///
+ /// Treats connections as the same if their server, db and usernames all match
+ ///
+ public class ConnectionSummaryComparer : IEqualityComparer
+ {
+ public bool Equals(ConnectionSummary x, ConnectionSummary y)
+ {
+ if(x == y) { return true; }
+ else if(x != null)
+ {
+ if(y == null) { return false; }
+
+ // Compare server, db, username. Note: server is case-insensitive in the driver
+ return string.Compare(x.ServerName, y.ServerName, StringComparison.OrdinalIgnoreCase) == 0
+ && string.Compare(x.DatabaseName, y.DatabaseName, StringComparison.Ordinal) == 0
+ && string.Compare(x.UserName, y.UserName, StringComparison.Ordinal) == 0;
+ }
+ return false;
+ }
+
+ public int GetHashCode(ConnectionSummary obj)
+ {
+ int hashcode = 31;
+ if(obj != null)
+ {
+ if(obj.ServerName != null)
+ {
+ hashcode ^= obj.ServerName.GetHashCode();
+ }
+ if (obj.DatabaseName != null)
+ {
+ hashcode ^= obj.DatabaseName.GetHashCode();
+ }
+ if (obj.UserName != null)
+ {
+ hashcode ^= obj.UserName.GetHashCode();
+ }
+ }
+ return hashcode;
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryExtensions.cs
new file mode 100644
index 00000000..02bc7623
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionSummaryExtensions.cs
@@ -0,0 +1,26 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
+{
+ ///
+ /// Extension methods to ConnectionSummary
+ ///
+ public static class ConnectionSummaryExtensions
+ {
+ ///
+ /// Create a copy of a ConnectionSummary object
+ ///
+ public static ConnectionSummary Clone(this ConnectionSummary summary)
+ {
+ return new ConnectionSummary()
+ {
+ ServerName = summary.ServerName,
+ DatabaseName = summary.DatabaseName,
+ UserName = summary.UserName
+ };
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectMessages.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs
similarity index 63%
rename from src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectMessages.cs
rename to src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs
index c078b308..91bc7faf 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectMessages.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs
@@ -3,8 +3,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
-using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
-
namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
{
///
@@ -18,14 +16,4 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
///
public string OwnerUri { get; set; }
}
-
- ///
- /// Disconnect request mapping entry
- ///
- public class DisconnectRequest
- {
- public static readonly
- RequestType Type =
- RequestType.Create("connection/disconnect");
- }
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectRequest.cs
new file mode 100644
index 00000000..cbf67ef2
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectRequest.cs
@@ -0,0 +1,19 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
+{
+ ///
+ /// Disconnect request mapping entry
+ ///
+ public class DisconnectRequest
+ {
+ public static readonly
+ RequestType Type =
+ RequestType.Create("connection/disconnect");
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs
new file mode 100644
index 00000000..d6e65801
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Contracts/HostingErrorEvent.cs
@@ -0,0 +1,28 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.Hosting.Contracts
+{
+ ///
+ /// Parameters to be used for reporting hosting-level errors, such as protocol violations
+ ///
+ public class HostingErrorParams
+ {
+ ///
+ /// The message of the error
+ ///
+ public string Message { get; set; }
+ }
+
+ public class HostingErrorEvent
+ {
+ public static readonly
+ EventType Type =
+ EventType.Create("hosting/error");
+
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs
index a18fa806..c4cf5365 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageDispatcher.cs
@@ -8,6 +8,7 @@ using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.SqlTools.ServiceLayer.Hosting.Contracts;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Channel;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.EditorServices.Utility;
@@ -198,10 +199,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
this.SynchronizationContext = SynchronizationContext.Current;
// Run the message loop
- bool isRunning = true;
- while (isRunning && !cancellationToken.IsCancellationRequested)
+ while (!cancellationToken.IsCancellationRequested)
{
- Message newMessage = null;
+ Message newMessage;
try
{
@@ -210,12 +210,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
}
catch (MessageParseException e)
{
- // TODO: Write an error response
-
- Logger.Write(
- LogLevel.Error,
- "Could not parse a message that was received:\r\n\r\n" +
- e.ToString());
+ string message = string.Format("Exception occurred while parsing message: {0}", e.Message);
+ Logger.Write(LogLevel.Error, message);
+ await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams
+ {
+ Message = message
+ });
// Continue the loop
continue;
@@ -227,18 +227,29 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
}
catch (Exception e)
{
- var b = e.Message;
- newMessage = null;
+ // Log the error and send an error event to the client
+ string message = string.Format("Exception occurred while receiving message: {0}", e.Message);
+ Logger.Write(LogLevel.Error, message);
+ await MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams
+ {
+ Message = message
+ });
+
+ // Continue the loop
+ continue;
}
// The message could be null if there was an error parsing the
// previous message. In this case, do not try to dispatch it.
if (newMessage != null)
{
+ // Verbose logging
+ string logMessage = string.Format("Received message of type[{0}] and method[{1}]",
+ newMessage.MessageType, newMessage.Method);
+ Logger.Write(LogLevel.Verbose, logMessage);
+
// Process the message
- await this.DispatchMessage(
- newMessage,
- this.MessageWriter);
+ await this.DispatchMessage(newMessage, this.MessageWriter);
}
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs
index f3857710..17d4b5e0 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/MessageReader.cs
@@ -25,22 +25,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
private const int CR = 0x0D;
private const int LF = 0x0A;
- private static string[] NewLineDelimiters = new string[] { Environment.NewLine };
+ private static readonly string[] NewLineDelimiters = { Environment.NewLine };
- private Stream inputStream;
- private IMessageSerializer messageSerializer;
- private Encoding messageEncoding;
+ private readonly Stream inputStream;
+ private readonly IMessageSerializer messageSerializer;
+ private readonly Encoding messageEncoding;
private ReadState readState;
private bool needsMoreData = true;
private int readOffset;
private int bufferEndOffset;
- private byte[] messageBuffer = new byte[DefaultBufferSize];
+ private byte[] messageBuffer;
private int expectedContentLength;
private Dictionary messageHeaders;
- enum ReadState
+ private enum ReadState
{
Headers,
Content
@@ -85,7 +85,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
this.needsMoreData = false;
// Do we need to look for message headers?
- if (this.readState == ReadState.Headers &&
+ if (this.readState == ReadState.Headers &&
!this.TryReadMessageHeaders())
{
// If we don't have enough data to read headers yet, keep reading
@@ -94,7 +94,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
}
// Do we need to look for message content?
- if (this.readState == ReadState.Content &&
+ if (this.readState == ReadState.Content &&
!this.TryReadMessageContent(out messageContent))
{
// If we don't have enough data yet to construct the content, keep reading
@@ -106,16 +106,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
break;
}
+ // Now that we have a message, reset the buffer's state
+ ShiftBufferBytesAndShrink(readOffset);
+
// Get the JObject for the JSON content
JObject messageObject = JObject.Parse(messageContent);
- // Load the message
- Logger.Write(
- LogLevel.Verbose,
- string.Format(
- "READ MESSAGE:\r\n\r\n{0}",
- messageObject.ToString(Formatting.Indented)));
-
// Return the parsed message
return this.messageSerializer.DeserializeMessage(messageObject);
}
@@ -162,8 +158,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
{
int scanOffset = this.readOffset;
- // Scan for the final double-newline that marks the
- // end of the header lines
+ // Scan for the final double-newline that marks the end of the header lines
while (scanOffset + 3 < this.bufferEndOffset &&
(this.messageBuffer[scanOffset] != CR ||
this.messageBuffer[scanOffset + 1] != LF ||
@@ -173,45 +168,51 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
scanOffset++;
}
- // No header or body separator found (e.g CRLFCRLF)
+ // Make sure we haven't reached the end of the buffer without finding a separator (e.g CRLFCRLF)
if (scanOffset + 3 >= this.bufferEndOffset)
{
return false;
}
- this.messageHeaders = new Dictionary();
+ // Convert the header block into a array of lines
+ var headers = Encoding.ASCII.GetString(this.messageBuffer, this.readOffset, scanOffset)
+ .Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries);
- var headers =
- Encoding.ASCII
- .GetString(this.messageBuffer, this.readOffset, scanOffset)
- .Split(NewLineDelimiters, StringSplitOptions.RemoveEmptyEntries);
-
- // Read each header and store it in the dictionary
- foreach (var header in headers)
+ try
{
- int currentLength = header.IndexOf(':');
- if (currentLength == -1)
+ // Read each header and store it in the dictionary
+ this.messageHeaders = new Dictionary();
+ foreach (var header in headers)
{
- throw new ArgumentException("Message header must separate key and value using :");
+ int currentLength = header.IndexOf(':');
+ if (currentLength == -1)
+ {
+ throw new ArgumentException("Message header must separate key and value using :");
+ }
+
+ var key = header.Substring(0, currentLength);
+ var value = header.Substring(currentLength + 1).Trim();
+ this.messageHeaders[key] = value;
}
- var key = header.Substring(0, currentLength);
- var value = header.Substring(currentLength + 1).Trim();
- this.messageHeaders[key] = value;
- }
+ // Parse out the content length as an int
+ string contentLengthString;
+ if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString))
+ {
+ throw new MessageParseException("", "Fatal error: Content-Length header must be provided.");
+ }
- // Make sure a Content-Length header was present, otherwise it
- // is a fatal error
- string contentLengthString = null;
- if (!this.messageHeaders.TryGetValue("Content-Length", out contentLengthString))
- {
- throw new MessageParseException("", "Fatal error: Content-Length header must be provided.");
+ // Parse the content length to an integer
+ if (!int.TryParse(contentLengthString, out this.expectedContentLength))
+ {
+ throw new MessageParseException("", "Fatal error: Content-Length value is not an integer.");
+ }
}
-
- // Parse the content length to an integer
- if (!int.TryParse(contentLengthString, out this.expectedContentLength))
+ catch (Exception)
{
- throw new MessageParseException("", "Fatal error: Content-Length value is not an integer.");
+ // The content length was invalid or missing. Trash the buffer we've read
+ ShiftBufferBytesAndShrink(scanOffset + 4);
+ throw;
}
// Skip past the headers plus the newline characters
@@ -234,31 +235,40 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol
}
// Convert the message contents to a string using the specified encoding
- messageContent =
- this.messageEncoding.GetString(
- this.messageBuffer,
- this.readOffset,
- this.expectedContentLength);
+ messageContent = this.messageEncoding.GetString(
+ this.messageBuffer,
+ this.readOffset,
+ this.expectedContentLength);
- // Move the remaining bytes to the front of the buffer for the next message
- var remainingByteCount = this.bufferEndOffset - (this.expectedContentLength + this.readOffset);
- Buffer.BlockCopy(
- this.messageBuffer,
- this.expectedContentLength + this.readOffset,
- this.messageBuffer,
- 0,
- remainingByteCount);
+ readOffset += expectedContentLength;
- // Reset the offsets for the next read
- this.readOffset = 0;
- this.bufferEndOffset = remainingByteCount;
-
- // Done reading content, now look for headers
+ // Done reading content, now look for headers for the next message
this.readState = ReadState.Headers;
return true;
}
+ private void ShiftBufferBytesAndShrink(int bytesToRemove)
+ {
+ // Create a new buffer that is shrunken by the number of bytes to remove
+ // Note: by using Max, we can guarantee a buffer of at least default buffer size
+ byte[] newBuffer = new byte[Math.Max(messageBuffer.Length - bytesToRemove, DefaultBufferSize)];
+
+ // If we need to do shifting, do the shifting
+ if (bytesToRemove <= messageBuffer.Length)
+ {
+ // Copy the existing buffer starting at the offset to remove
+ Buffer.BlockCopy(messageBuffer, bytesToRemove, newBuffer, 0, bufferEndOffset - bytesToRemove);
+ }
+
+ // Make the new buffer the message buffer
+ messageBuffer = newBuffer;
+
+ // Reset the read offset and the end offset
+ readOffset = 0;
+ bufferEndOffset -= bytesToRemove;
+ }
+
#endregion
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs
index 14148778..a390eae2 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs
@@ -5,8 +5,6 @@
using System;
using System.Collections.Generic;
-using System.Data;
-using System.Data.Common;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
@@ -16,159 +14,6 @@ using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
{
- internal class IntellisenseCache
- {
- // connection used to query for intellisense info
- private DbConnection connection;
-
- // number of documents (URI's) that are using the cache for the same database
- // the autocomplete service uses this to remove unreferenced caches
- public int ReferenceCount { get; set; }
-
- public IntellisenseCache(ISqlConnectionFactory connectionFactory, ConnectionDetails connectionDetails)
- {
- ReferenceCount = 0;
- DatabaseInfo = CopySummary(connectionDetails);
-
- // TODO error handling on this. Intellisense should catch or else the service should handle
- connection = connectionFactory.CreateSqlConnection(ConnectionService.BuildConnectionString(connectionDetails));
- connection.Open();
- }
-
- ///
- /// Used to identify a database for which this cache is used
- ///
- public ConnectionSummary DatabaseInfo
- {
- get;
- private set;
- }
- ///
- /// Gets the current autocomplete candidate list
- ///
- public IEnumerable AutoCompleteList { get; private set; }
-
- public async Task UpdateCache()
- {
- DbCommand command = connection.CreateCommand();
- command.CommandText = "SELECT name FROM sys.tables";
- command.CommandTimeout = 15;
- command.CommandType = CommandType.Text;
- var reader = await command.ExecuteReaderAsync();
-
- List results = new List();
- while (await reader.ReadAsync())
- {
- results.Add(reader[0].ToString());
- }
-
- AutoCompleteList = results;
- await Task.FromResult(0);
- }
-
- public List GetAutoCompleteItems(TextDocumentPosition textDocumentPosition)
- {
- List completions = new List();
-
- int i = 0;
-
- // Take a reference to the list at a point in time in case we update and replace the list
- var suggestions = AutoCompleteList;
- // the completion list will be null is user not connected to server
- if (this.AutoCompleteList != null)
- {
-
- foreach (var autoCompleteItem in suggestions)
- {
- // convert the completion item candidates into CompletionItems
- completions.Add(new CompletionItem()
- {
- Label = autoCompleteItem,
- Kind = CompletionItemKind.Keyword,
- Detail = autoCompleteItem + " details",
- Documentation = autoCompleteItem + " documentation",
- TextEdit = new TextEdit
- {
- NewText = autoCompleteItem,
- Range = new Range
- {
- Start = new Position
- {
- Line = textDocumentPosition.Position.Line,
- Character = textDocumentPosition.Position.Character
- },
- End = new Position
- {
- Line = textDocumentPosition.Position.Line,
- Character = textDocumentPosition.Position.Character + 5
- }
- }
- }
- });
-
- // only show 50 items
- if (++i == 50)
- {
- break;
- }
- }
- }
-
- return completions;
- }
-
- private static ConnectionSummary CopySummary(ConnectionSummary summary)
- {
- return new ConnectionSummary()
- {
- ServerName = summary.ServerName,
- DatabaseName = summary.DatabaseName,
- UserName = summary.UserName
- };
- }
- }
-
- ///
- /// Treats connections as the same if their server, db and usernames all match
- ///
- public class ConnectionSummaryComparer : IEqualityComparer
- {
- public bool Equals(ConnectionSummary x, ConnectionSummary y)
- {
- if(x == y) { return true; }
- else if(x != null)
- {
- if(y == null) { return false; }
-
- // Compare server, db, username. Note: server is case-insensitive in the driver
- return string.Compare(x.ServerName, y.ServerName, StringComparison.OrdinalIgnoreCase) == 0
- && string.Compare(x.DatabaseName, y.DatabaseName, StringComparison.Ordinal) == 0
- && string.Compare(x.UserName, y.UserName, StringComparison.Ordinal) == 0;
- }
- return false;
- }
-
- public int GetHashCode(ConnectionSummary obj)
- {
- int hashcode = 31;
- if(obj != null)
- {
- if(obj.ServerName != null)
- {
- hashcode ^= obj.ServerName.GetHashCode();
- }
- if (obj.DatabaseName != null)
- {
- hashcode ^= obj.DatabaseName.GetHashCode();
- }
- if (obj.UserName != null)
- {
- hashcode ^= obj.UserName.GetHashCode();
- }
- }
- return hashcode;
- }
- }
///
/// Main class for Autocomplete functionality
///
@@ -235,13 +80,36 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
}
}
}
+
+ private ConnectionService connectionService = null;
+
+ ///
+ /// Internal for testing purposes only
+ ///
+ internal ConnectionService ConnectionServiceInstance
+ {
+ get
+ {
+ if(connectionService == null)
+ {
+ connectionService = ConnectionService.Instance;
+ }
+ return connectionService;
+ }
+
+ set
+ {
+ connectionService = value;
+ }
+ }
+
public void InitializeService(ServiceHost serviceHost)
{
// Register a callback for when a connection is created
- ConnectionService.Instance.RegisterOnConnectionTask(UpdateAutoCompleteCache);
+ ConnectionServiceInstance.RegisterOnConnectionTask(UpdateAutoCompleteCache);
// Register a callback for when a connection is closed
- ConnectionService.Instance.RegisterOnDisconnectTask(RemoveAutoCompleteCacheUriReference);
+ ConnectionServiceInstance.RegisterOnDisconnectTask(RemoveAutoCompleteCacheUriReference);
}
private async Task UpdateAutoCompleteCache(ConnectionInfo connectionInfo)
@@ -252,6 +120,14 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
}
}
+ ///
+ /// Intellisense cache count access for testing.
+ ///
+ internal int GetCacheCount()
+ {
+ return caches.Count;
+ }
+
///
/// Remove a reference to an autocomplete cache from a URI. If
/// it is the last URI connected to a particular connection,
@@ -312,7 +188,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
// that are not backed by a SQL connection
ConnectionInfo info;
IntellisenseCache cache;
- if (ConnectionService.Instance.TryFindConnection(textDocumentPosition.Uri, out info)
+ if (ConnectionServiceInstance.TryFindConnection(textDocumentPosition.Uri, out info)
&& caches.TryGetValue((ConnectionSummary)info.ConnectionDetails, out cache))
{
return cache.GetAutoCompleteItems(textDocumentPosition).ToArray();
diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs
new file mode 100644
index 00000000..eea72771
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/IntellisenseCache.cs
@@ -0,0 +1,122 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Collections.Generic;
+using System.Data;
+using System.Data.Common;
+using System.Threading.Tasks;
+using Microsoft.SqlTools.ServiceLayer.Connection;
+using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
+using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
+using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
+{
+ internal class IntellisenseCache
+ {
+ ///
+ /// connection used to query for intellisense info
+ ///
+ private DbConnection connection;
+
+ ///
+ /// Number of documents (URI's) that are using the cache for the same database.
+ /// The autocomplete service uses this to remove unreferenced caches.
+ ///
+ public int ReferenceCount { get; set; }
+
+ public IntellisenseCache(ISqlConnectionFactory connectionFactory, ConnectionDetails connectionDetails)
+ {
+ ReferenceCount = 0;
+ DatabaseInfo = connectionDetails.Clone();
+
+ // TODO error handling on this. Intellisense should catch or else the service should handle
+ connection = connectionFactory.CreateSqlConnection(ConnectionService.BuildConnectionString(connectionDetails));
+ connection.Open();
+ }
+
+ ///
+ /// Used to identify a database for which this cache is used
+ ///
+ public ConnectionSummary DatabaseInfo
+ {
+ get;
+ private set;
+ }
+ ///
+ /// Gets the current autocomplete candidate list
+ ///
+ public IEnumerable AutoCompleteList { get; private set; }
+
+ public async Task UpdateCache()
+ {
+ DbCommand command = connection.CreateCommand();
+ command.CommandText = "SELECT name FROM sys.tables";
+ command.CommandTimeout = 15;
+ command.CommandType = CommandType.Text;
+ var reader = await command.ExecuteReaderAsync();
+
+ List results = new List();
+ while (await reader.ReadAsync())
+ {
+ results.Add(reader[0].ToString());
+ }
+
+ AutoCompleteList = results;
+ await Task.FromResult(0);
+ }
+
+ public List GetAutoCompleteItems(TextDocumentPosition textDocumentPosition)
+ {
+ List completions = new List();
+
+ int i = 0;
+
+ // Take a reference to the list at a point in time in case we update and replace the list
+ var suggestions = AutoCompleteList;
+ // the completion list will be null is user not connected to server
+ if (this.AutoCompleteList != null)
+ {
+
+ foreach (var autoCompleteItem in suggestions)
+ {
+ // convert the completion item candidates into CompletionItems
+ completions.Add(new CompletionItem()
+ {
+ Label = autoCompleteItem,
+ Kind = CompletionItemKind.Keyword,
+ Detail = autoCompleteItem + " details",
+ Documentation = autoCompleteItem + " documentation",
+ TextEdit = new TextEdit
+ {
+ NewText = autoCompleteItem,
+ Range = new Range
+ {
+ Start = new Position
+ {
+ Line = textDocumentPosition.Position.Line,
+ Character = textDocumentPosition.Position.Character
+ },
+ End = new Position
+ {
+ Line = textDocumentPosition.Position.Line,
+ Character = textDocumentPosition.Position.Character + 5
+ }
+ }
+ }
+ });
+
+ // only show 50 items
+ if (++i == 50)
+ {
+ break;
+ }
+ }
+ }
+
+ return completions;
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs
index 6cdbd745..f380f1bb 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs
@@ -337,7 +337,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
{
Logger.Write(
LogLevel.Error,
- String.Format(
+ string.Format(
"Exception while cancelling analysis task:\n\n{0}",
e.ToString()));
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs
new file mode 100644
index 00000000..3eb87f4f
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs
@@ -0,0 +1,36 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts
+{
+ ///
+ /// Parameters for the query cancellation request
+ ///
+ public class QueryCancelParams
+ {
+ public string OwnerUri { get; set; }
+ }
+
+ ///
+ /// Parameters to return as the result of a query dispose request
+ ///
+ public class QueryCancelResult
+ {
+ ///
+ /// Any error messages that occurred during disposing the result set. Optional, can be set
+ /// to null if there were no errors.
+ ///
+ public string Messages { get; set; }
+ }
+
+ public class QueryCancelRequest
+ {
+ public static readonly
+ RequestType Type =
+ RequestType.Create("query/cancel");
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs
index 887bbbaf..d9a886d4 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs
@@ -177,12 +177,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
{
HasError = true;
UnwrapDbException(dbe);
- conn?.Dispose();
}
catch (Exception)
{
HasError = true;
- conn?.Dispose();
throw;
}
finally
@@ -233,6 +231,21 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
};
}
+ ///
+ /// Cancels the query by issuing the cancellation token
+ ///
+ public void Cancel()
+ {
+ // Make sure that the query hasn't completed execution
+ if (HasExecuted)
+ {
+ throw new InvalidOperationException("The query has already completed, it cannot be cancelled.");
+ }
+
+ // Issue the cancellation token for the query
+ cancellationSource.Cancel();
+ }
+
///
/// Delegate handler for storing messages that are returned from the server
/// NOTE: Only messages that are below a certain severity will be returned via this
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs
index 389be092..5b26eafc 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs
@@ -73,6 +73,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
serviceHost.SetRequestHandler(QueryExecuteRequest.Type, HandleExecuteRequest);
serviceHost.SetRequestHandler(QueryExecuteSubsetRequest.Type, HandleResultSubsetRequest);
serviceHost.SetRequestHandler(QueryDisposeRequest.Type, HandleDisposeRequest);
+ serviceHost.SetRequestHandler(QueryCancelRequest.Type, HandleCancelRequest);
// Register handler for shutdown event
serviceHost.RegisterShutdownTask((shutdownParams, requestContext) =>
@@ -178,6 +179,52 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
}
}
+ public async Task HandleCancelRequest(QueryCancelParams cancelParams,
+ RequestContext requestContext)
+ {
+ try
+ {
+ // Attempt to find the query for the owner uri
+ Query result;
+ if (!ActiveQueries.TryGetValue(cancelParams.OwnerUri, out result))
+ {
+ await requestContext.SendResult(new QueryCancelResult
+ {
+ Messages = "Failed to cancel query, ID not found."
+ });
+ return;
+ }
+
+ // Cancel the query
+ result.Cancel();
+
+ // Attempt to dispose the query
+ if (!ActiveQueries.TryRemove(cancelParams.OwnerUri, out result))
+ {
+ // It really shouldn't be possible to get to this scenario, but we'll cover it anyhow
+ await requestContext.SendResult(new QueryCancelResult
+ {
+ Messages = "Query successfully cancelled, failed to dispose query. ID not found."
+ });
+ return;
+ }
+
+ await requestContext.SendResult(new QueryCancelResult());
+ }
+ catch (InvalidOperationException e)
+ {
+ // If this exception occurred, we most likely were trying to cancel a completed query
+ await requestContext.SendResult(new QueryCancelResult
+ {
+ Messages = e.Message
+ });
+ }
+ catch (Exception e)
+ {
+ await requestContext.SendError(e.Message);
+ }
+ }
+
#endregion
#region Private Helpers
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs
index 701fa6f5..f47cacb9 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs
@@ -182,7 +182,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace
foreach (var textChange in textChangeParams.ContentChanges)
{
string fileUri = textChangeParams.Uri ?? textChangeParams.TextDocument.Uri;
- msg.AppendLine(String.Format(" File: {0}", fileUri));
+ msg.AppendLine(string.Format(" File: {0}", fileUri));
ScriptFile changedFile = Workspace.GetFile(fileUri);
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs
index 80ea3ec9..ddf82059 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs
@@ -3,11 +3,18 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
+using System.Collections.Generic;
+using System.Data;
+using System.Data.Common;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection;
+using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
+using Microsoft.SqlTools.ServiceLayer.Test.Utility;
using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
using Microsoft.SqlTools.Test.Utility;
+using Moq;
+using Moq.Protected;
using Xunit;
namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices
@@ -19,6 +26,29 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices
{
#region "Diagnostics tests"
+ ///
+ /// Verify that the latest SqlParser (2016 as of this writing) is used by default
+ ///
+ [Fact]
+ public void LatestSqlParserIsUsedByDefault()
+ {
+ // This should only parse correctly on SQL server 2016 or newer
+ const string sql2016Text =
+ @"CREATE SECURITY POLICY [FederatedSecurityPolicy]" + "\r\n" +
+ @"ADD FILTER PREDICATE [rls].[fn_securitypredicate]([CustomerId])" + "\r\n" +
+ @"ON [dbo].[Customer];";
+
+ LanguageService service = TestObjects.GetTestLanguageService();
+
+ // parse
+ var scriptFile = new ScriptFile();
+ scriptFile.SetFileContents(sql2016Text);
+ ScriptFileMarker[] fileMarkers = service.GetSemanticMarkers(scriptFile);
+
+ // verify that no errors are detected
+ Assert.Equal(0, fileMarkers.Length);
+ }
+
///
/// Verify that the SQL parser correctly detects errors in text
///
@@ -108,24 +138,179 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices
#region "Autocomplete Tests"
///
- /// Verify that the SQL parser correctly detects errors in text
+ /// Creates a mock db command that returns a predefined result set
+ ///
+ public static DbCommand CreateTestCommand(Dictionary[][] data)
+ {
+ var commandMock = new Mock { CallBase = true };
+ var commandMockSetup = commandMock.Protected()
+ .Setup("ExecuteDbDataReader", It.IsAny());
+
+ commandMockSetup.Returns(new TestDbDataReader(data));
+
+ return commandMock.Object;
+ }
+
+ ///
+ /// Creates a mock db connection that returns predefined data when queried for a result set
+ ///
+ public DbConnection CreateMockDbConnection(Dictionary[][] data)
+ {
+ var connectionMock = new Mock { CallBase = true };
+ connectionMock.Protected()
+ .Setup("CreateDbCommand")
+ .Returns(CreateTestCommand(data));
+
+ return connectionMock.Object;
+ }
+
+ ///
+ /// Verify that the autocomplete service returns tables for the current connection as suggestions
///
[Fact]
- public async Task AutocompleteTest()
+ public void TablesAreReturnedAsAutocompleteSuggestions()
{
- // TODO Re-enable this test once we have a way to hook up the right auto-complete and connection services.
- // Probably need a service provider channel so that we can mock service access. Otherwise everything accesses
- // static instances and cannot be properly tested.
-
- //var autocompleteService = TestObjects.GetAutoCompleteService();
- //var connectionService = TestObjects.GetTestConnectionService();
+ // Result set for the query of database tables
+ Dictionary[] data =
+ {
+ new Dictionary { {"name", "master" } },
+ new Dictionary { {"name", "model" } }
+ };
- //ConnectParams connectionRequest = TestObjects.GetTestConnectionParams();
- //var connectionResult = connectionService.Connect(connectionRequest);
+ var mockFactory = new Mock();
+ mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny()))
+ .Returns(CreateMockDbConnection(new[] {data}));
+
+ var connectionService = TestObjects.GetTestConnectionService();
+ var autocompleteService = new AutoCompleteService();
+ autocompleteService.ConnectionServiceInstance = connectionService;
+ autocompleteService.InitializeService(Microsoft.SqlTools.ServiceLayer.Hosting.ServiceHost.Instance);
+
+ autocompleteService.ConnectionFactory = mockFactory.Object;
- //var sqlConnection = connectionService.ActiveConnections[connectionResult.ConnectionId];
- //await autocompleteService.UpdateAutoCompleteCache(sqlConnection);
- await Task.Run(() => { return; });
+ // Open a connection
+ // The cache should get updated as part of this
+ ConnectParams connectionRequest = TestObjects.GetTestConnectionParams();
+ var connectionResult = connectionService.Connect(connectionRequest);
+ Assert.NotEmpty(connectionResult.ConnectionId);
+
+ // Check that there is one cache created in the auto complete service
+ Assert.Equal(1, autocompleteService.GetCacheCount());
+
+ // Check that we get table suggestions for an autocomplete request
+ TextDocumentPosition position = new TextDocumentPosition();
+ position.Uri = connectionRequest.OwnerUri;
+ position.Position = new Position();
+ position.Position.Line = 1;
+ position.Position.Character = 1;
+ var items = autocompleteService.GetCompletionItems(position);
+ Assert.Equal(2, items.Length);
+ Assert.Equal("master", items[0].Label);
+ Assert.Equal("model", items[1].Label);
+ }
+
+ ///
+ /// Verify that only one intellisense cache is created for two documents using
+ /// the autocomplete service when they share a common connection.
+ ///
+ [Fact]
+ public void OnlyOneCacheIsCreatedForTwoDocumentsWithSameConnection()
+ {
+ var connectionService = TestObjects.GetTestConnectionService();
+ var autocompleteService = new AutoCompleteService();
+ autocompleteService.ConnectionServiceInstance = connectionService;
+ autocompleteService.InitializeService(Microsoft.SqlTools.ServiceLayer.Hosting.ServiceHost.Instance);
+
+ // Open two connections
+ ConnectParams connectionRequest1 = TestObjects.GetTestConnectionParams();
+ connectionRequest1.OwnerUri = "file:///my/first/file.sql";
+ ConnectParams connectionRequest2 = TestObjects.GetTestConnectionParams();
+ connectionRequest2.OwnerUri = "file:///my/second/file.sql";
+ var connectionResult1 = connectionService.Connect(connectionRequest1);
+ Assert.NotEmpty(connectionResult1.ConnectionId);
+ var connectionResult2 = connectionService.Connect(connectionRequest2);
+ Assert.NotEmpty(connectionResult2.ConnectionId);
+
+ // Verify that only one intellisense cache is created to service both URI's
+ Assert.Equal(1, autocompleteService.GetCacheCount());
+ }
+
+ ///
+ /// Verify that two different intellisense caches and corresponding autocomplete
+ /// suggestions are provided for two documents with different connections.
+ ///
+ [Fact]
+ public void TwoCachesAreCreatedForTwoDocumentsWithDifferentConnections()
+ {
+ // Result set for the query of database tables
+ Dictionary[] data1 =
+ {
+ new Dictionary { {"name", "master" } },
+ new Dictionary { {"name", "model" } }
+ };
+
+ Dictionary[] data2 =
+ {
+ new Dictionary { {"name", "master" } },
+ new Dictionary { {"name", "my_table" } },
+ new Dictionary { {"name", "my_other_table" } }
+ };
+
+ var mockFactory = new Mock();
+ mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny()))
+ .Returns(CreateMockDbConnection(new[] {data1}))
+ .Returns(CreateMockDbConnection(new[] {data2}));
+
+ var connectionService = TestObjects.GetTestConnectionService();
+ var autocompleteService = new AutoCompleteService();
+ autocompleteService.ConnectionServiceInstance = connectionService;
+ autocompleteService.InitializeService(Microsoft.SqlTools.ServiceLayer.Hosting.ServiceHost.Instance);
+
+ autocompleteService.ConnectionFactory = mockFactory.Object;
+
+ // Open connections
+ // The cache should get updated as part of this
+ ConnectParams connectionRequest = TestObjects.GetTestConnectionParams();
+ connectionRequest.OwnerUri = "file:///my/first/sql/file.sql";
+ var connectionResult = connectionService.Connect(connectionRequest);
+ Assert.NotEmpty(connectionResult.ConnectionId);
+
+ // Check that there is one cache created in the auto complete service
+ Assert.Equal(1, autocompleteService.GetCacheCount());
+
+ // Open second connection
+ ConnectParams connectionRequest2 = TestObjects.GetTestConnectionParams();
+ connectionRequest2.OwnerUri = "file:///my/second/sql/file.sql";
+ connectionRequest2.Connection.DatabaseName = "my_other_db";
+ var connectionResult2 = connectionService.Connect(connectionRequest2);
+ Assert.NotEmpty(connectionResult2.ConnectionId);
+
+ // Check that there are now two caches in the auto complete service
+ Assert.Equal(2, autocompleteService.GetCacheCount());
+
+ // Check that we get 2 different table suggestions for autocomplete requests
+ TextDocumentPosition position = new TextDocumentPosition();
+ position.Uri = connectionRequest.OwnerUri;
+ position.Position = new Position();
+ position.Position.Line = 1;
+ position.Position.Character = 1;
+
+ var items = autocompleteService.GetCompletionItems(position);
+ Assert.Equal(2, items.Length);
+ Assert.Equal("master", items[0].Label);
+ Assert.Equal("model", items[1].Label);
+
+ TextDocumentPosition position2 = new TextDocumentPosition();
+ position2.Uri = connectionRequest2.OwnerUri;
+ position2.Position = new Position();
+ position2.Position.Line = 1;
+ position2.Position.Character = 1;
+
+ var items2 = autocompleteService.GetCompletionItems(position2);
+ Assert.Equal(3, items2.Length);
+ Assert.Equal("master", items2[0].Label);
+ Assert.Equal("my_table", items2[1].Label);
+ Assert.Equal("my_other_table", items2[2].Label);
}
#endregion
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs
deleted file mode 100644
index 54fbf01f..00000000
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/MessageReaderWriterTests.cs
+++ /dev/null
@@ -1,178 +0,0 @@
-//
-// Copyright (c) Microsoft. All rights reserved.
-// Licensed under the MIT license. See LICENSE file in the project root for full license information.
-//
-
-using System;
-using System.IO;
-using System.Text;
-using System.Threading.Tasks;
-using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
-using HostingMessage = Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts.Message;
-using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers;
-using Xunit;
-
-namespace Microsoft.SqlTools.ServiceLayer.Test.Message
-{
- public class MessageReaderWriterTests
- {
- const string TestEventString = "{\"type\":\"event\",\"event\":\"testEvent\",\"body\":null}";
- const string TestEventFormatString = "{{\"event\":\"testEvent\",\"body\":{{\"someString\":\"{0}\"}},\"seq\":0,\"type\":\"event\"}}";
- readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString);
-
- private IMessageSerializer messageSerializer;
-
- public MessageReaderWriterTests()
- {
- this.messageSerializer = new V8MessageSerializer();
- }
-
- [Fact]
- public async Task WritesMessage()
- {
- MemoryStream outputStream = new MemoryStream();
-
- MessageWriter messageWriter =
- new MessageWriter(
- outputStream,
- this.messageSerializer);
-
- // Write the message and then roll back the stream to be read
- // TODO: This will need to be redone!
- await messageWriter.WriteMessage(HostingMessage.Event("testEvent", null));
- outputStream.Seek(0, SeekOrigin.Begin);
-
- string expectedHeaderString =
- string.Format(
- Constants.ContentLengthFormatString,
- ExpectedMessageByteCount);
-
- byte[] buffer = new byte[128];
- await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length);
-
- Assert.Equal(
- expectedHeaderString,
- Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length));
-
- // Read the message
- await outputStream.ReadAsync(buffer, 0, ExpectedMessageByteCount);
-
- Assert.Equal(
- TestEventString,
- Encoding.UTF8.GetString(buffer, 0, ExpectedMessageByteCount));
-
- outputStream.Dispose();
- }
-
- [Fact]
- public void ReadsMessage()
- {
- MemoryStream inputStream = new MemoryStream();
- MessageReader messageReader =
- new MessageReader(
- inputStream,
- this.messageSerializer);
-
- // Write a message to the stream
- byte[] messageBuffer = this.GetMessageBytes(TestEventString);
- inputStream.Write(
- this.GetMessageBytes(TestEventString),
- 0,
- messageBuffer.Length);
-
- inputStream.Flush();
- inputStream.Seek(0, SeekOrigin.Begin);
-
- HostingMessage messageResult = messageReader.ReadMessage().Result;
- Assert.Equal("testEvent", messageResult.Method);
-
- inputStream.Dispose();
- }
-
- [Fact]
- public void ReadsManyBufferedMessages()
- {
- MemoryStream inputStream = new MemoryStream();
- MessageReader messageReader =
- new MessageReader(
- inputStream,
- this.messageSerializer);
-
- // Get a message to use for writing to the stream
- byte[] messageBuffer = this.GetMessageBytes(TestEventString);
-
- // How many messages of this size should we write to overflow the buffer?
- int overflowMessageCount =
- (int)Math.Ceiling(
- (MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length);
-
- // Write the necessary number of messages to the stream
- for (int i = 0; i < overflowMessageCount; i++)
- {
- inputStream.Write(messageBuffer, 0, messageBuffer.Length);
- }
-
- inputStream.Flush();
- inputStream.Seek(0, SeekOrigin.Begin);
-
- // Read the written messages from the stream
- for (int i = 0; i < overflowMessageCount; i++)
- {
- HostingMessage messageResult = messageReader.ReadMessage().Result;
- Assert.Equal("testEvent", messageResult.Method);
- }
-
- inputStream.Dispose();
- }
-
- [Fact]
- public void ReaderResizesBufferForLargeMessages()
- {
- MemoryStream inputStream = new MemoryStream();
- MessageReader messageReader =
- new MessageReader(
- inputStream,
- this.messageSerializer);
-
- // Get a message with content so large that the buffer will need
- // to be resized to fit it all.
- byte[] messageBuffer =
- this.GetMessageBytes(
- string.Format(
- TestEventFormatString,
- new String('X', (int)(MessageReader.DefaultBufferSize * 3))));
-
- inputStream.Write(messageBuffer, 0, messageBuffer.Length);
- inputStream.Flush();
- inputStream.Seek(0, SeekOrigin.Begin);
-
- HostingMessage messageResult = messageReader.ReadMessage().Result;
- Assert.Equal("testEvent", messageResult.Method);
-
- inputStream.Dispose();
- }
-
- private byte[] GetMessageBytes(string messageString, Encoding encoding = null)
- {
- if (encoding == null)
- {
- encoding = Encoding.UTF8;
- }
-
- byte[] messageBytes = Encoding.UTF8.GetBytes(messageString);
- byte[] headerBytes =
- Encoding.ASCII.GetBytes(
- string.Format(
- Constants.ContentLengthFormatString,
- messageBytes.Length));
-
- // Copy the bytes into a single buffer
- byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length];
- Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length);
- Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length);
-
- return finalBytes;
- }
- }
-}
-
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs
new file mode 100644
index 00000000..b575fe0c
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/Common.cs
@@ -0,0 +1,17 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Text;
+
+namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging
+{
+ public class Common
+ {
+ public const string TestEventString = @"{""type"":""event"",""event"":""testEvent"",""body"":null}";
+ public const string TestEventFormatString = @"{{""event"":""testEvent"",""body"":{{""someString"":""{0}""}},""seq"":0,""type"":""event""}}";
+ public static readonly int ExpectedMessageByteCount = Encoding.UTF8.GetByteCount(TestEventString);
+
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs
new file mode 100644
index 00000000..0a12dc3e
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageReaderTests.cs
@@ -0,0 +1,241 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.IO;
+using System.Text;
+using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
+using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
+using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers;
+using Newtonsoft.Json;
+using Xunit;
+
+namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging
+{
+ public class MessageReaderTests
+ {
+
+ private readonly IMessageSerializer messageSerializer;
+
+ public MessageReaderTests()
+ {
+ this.messageSerializer = new V8MessageSerializer();
+ }
+
+ [Fact]
+ public void ReadsMessage()
+ {
+ MemoryStream inputStream = new MemoryStream();
+ MessageReader messageReader = new MessageReader(inputStream, this.messageSerializer);
+
+ // Write a message to the stream
+ byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString);
+ inputStream.Write(this.GetMessageBytes(Common.TestEventString), 0, messageBuffer.Length);
+
+ inputStream.Flush();
+ inputStream.Seek(0, SeekOrigin.Begin);
+
+ Message messageResult = messageReader.ReadMessage().Result;
+ Assert.Equal("testEvent", messageResult.Method);
+
+ inputStream.Dispose();
+ }
+
+ [Fact]
+ public void ReadsManyBufferedMessages()
+ {
+ MemoryStream inputStream = new MemoryStream();
+ MessageReader messageReader =
+ new MessageReader(
+ inputStream,
+ this.messageSerializer);
+
+ // Get a message to use for writing to the stream
+ byte[] messageBuffer = this.GetMessageBytes(Common.TestEventString);
+
+ // How many messages of this size should we write to overflow the buffer?
+ int overflowMessageCount =
+ (int)Math.Ceiling(
+ (MessageReader.DefaultBufferSize * 1.5) / messageBuffer.Length);
+
+ // Write the necessary number of messages to the stream
+ for (int i = 0; i < overflowMessageCount; i++)
+ {
+ inputStream.Write(messageBuffer, 0, messageBuffer.Length);
+ }
+
+ inputStream.Flush();
+ inputStream.Seek(0, SeekOrigin.Begin);
+
+ // Read the written messages from the stream
+ for (int i = 0; i < overflowMessageCount; i++)
+ {
+ Message messageResult = messageReader.ReadMessage().Result;
+ Assert.Equal("testEvent", messageResult.Method);
+ }
+
+ inputStream.Dispose();
+ }
+
+ [Fact]
+ public void ReadMalformedMissingHeaderTest()
+ {
+ using (MemoryStream inputStream = new MemoryStream())
+ {
+ // If:
+ // ... I create a new stream and pass it information that is malformed
+ // ... and attempt to read a message from it
+ MessageReader messageReader = new MessageReader(inputStream, messageSerializer);
+ byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n");
+ inputStream.Write(messageBuffer, 0, messageBuffer.Length);
+ inputStream.Flush();
+ inputStream.Seek(0, SeekOrigin.Begin);
+
+ // Then:
+ // ... An exception should be thrown while reading
+ Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait();
+ }
+ }
+
+ [Fact]
+ public void ReadMalformedContentLengthNonIntegerTest()
+ {
+ using (MemoryStream inputStream = new MemoryStream())
+ {
+ // If:
+ // ... I create a new stream and pass it a non-integer content-length header
+ // ... and attempt to read a message from it
+ MessageReader messageReader = new MessageReader(inputStream, messageSerializer);
+ byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: asdf\r\n\r\n");
+ inputStream.Write(messageBuffer, 0, messageBuffer.Length);
+ inputStream.Flush();
+ inputStream.Seek(0, SeekOrigin.Begin);
+
+ // Then:
+ // ... An exception should be thrown while reading
+ Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait();
+ }
+ }
+
+ [Fact]
+ public void ReadMissingContentLengthHeaderTest()
+ {
+ using (MemoryStream inputStream = new MemoryStream())
+ {
+ // If:
+ // ... I create a new stream and pass it a a message without a content-length header
+ // ... and attempt to read a message from it
+ MessageReader messageReader = new MessageReader(inputStream, messageSerializer);
+ byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Type: asdf\r\n\r\n");
+ inputStream.Write(messageBuffer, 0, messageBuffer.Length);
+ inputStream.Flush();
+ inputStream.Seek(0, SeekOrigin.Begin);
+
+ // Then:
+ // ... An exception should be thrown while reading
+ Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait();
+ }
+ }
+
+ [Fact]
+ public void ReadMalformedContentLengthTooShortTest()
+ {
+ using (MemoryStream inputStream = new MemoryStream())
+ {
+ // If:
+ // ... Pass in an event that has an incorrect content length
+ // ... And pass in an event that is correct
+ MessageReader messageReader = new MessageReader(inputStream, messageSerializer);
+ byte[] messageBuffer = Encoding.ASCII.GetBytes("Content-Length: 10\r\n\r\n");
+ inputStream.Write(messageBuffer, 0, messageBuffer.Length);
+ messageBuffer = Encoding.UTF8.GetBytes(Common.TestEventString);
+ inputStream.Write(messageBuffer, 0, messageBuffer.Length);
+ messageBuffer = Encoding.ASCII.GetBytes("\r\n\r\n");
+ inputStream.Write(messageBuffer, 0, messageBuffer.Length);
+ inputStream.Flush();
+ inputStream.Seek(0, SeekOrigin.Begin);
+
+ // Then:
+ // ... The first read should fail with an exception while deserializing
+ Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait();
+
+ // ... The second read should fail with an exception while reading headers
+ Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait();
+ }
+ }
+
+ [Fact]
+ public void ReadMalformedThenValidTest()
+ {
+ // If:
+ // ... I create a new stream and pass it information that is malformed
+ // ... and attempt to read a message from it
+ // ... Then pass it information that is valid and attempt to read a message from it
+ using (MemoryStream inputStream = new MemoryStream())
+ {
+ MessageReader messageReader = new MessageReader(inputStream, messageSerializer);
+ byte[] messageBuffer = Encoding.ASCII.GetBytes("This is an invalid header\r\n\r\n");
+ inputStream.Write(messageBuffer, 0, messageBuffer.Length);
+ messageBuffer = GetMessageBytes(Common.TestEventString);
+ inputStream.Write(messageBuffer, 0, messageBuffer.Length);
+ inputStream.Flush();
+ inputStream.Seek(0, SeekOrigin.Begin);
+
+ // Then:
+ // ... An exception should be thrown while reading the first one
+ Assert.ThrowsAsync(() => messageReader.ReadMessage()).Wait();
+
+ // ... A test event should be successfully read from the second one
+ Message messageResult = messageReader.ReadMessage().Result;
+ Assert.NotNull(messageResult);
+ Assert.Equal("testEvent", messageResult.Method);
+ }
+ }
+
+ [Fact]
+ public void ReaderResizesBufferForLargeMessages()
+ {
+ MemoryStream inputStream = new MemoryStream();
+ MessageReader messageReader =
+ new MessageReader(
+ inputStream,
+ this.messageSerializer);
+
+ // Get a message with content so large that the buffer will need
+ // to be resized to fit it all.
+ byte[] messageBuffer = this.GetMessageBytes(
+ string.Format(
+ Common.TestEventFormatString,
+ new String('X', (int) (MessageReader.DefaultBufferSize*3))));
+
+ inputStream.Write(messageBuffer, 0, messageBuffer.Length);
+ inputStream.Flush();
+ inputStream.Seek(0, SeekOrigin.Begin);
+
+ Message messageResult = messageReader.ReadMessage().Result;
+ Assert.Equal("testEvent", messageResult.Method);
+
+ inputStream.Dispose();
+ }
+
+ private byte[] GetMessageBytes(string messageString, Encoding encoding = null)
+ {
+ if (encoding == null)
+ {
+ encoding = Encoding.UTF8;
+ }
+
+ byte[] messageBytes = Encoding.UTF8.GetBytes(messageString);
+ byte[] headerBytes = Encoding.ASCII.GetBytes(string.Format(Constants.ContentLengthFormatString, messageBytes.Length));
+
+ // Copy the bytes into a single buffer
+ byte[] finalBytes = new byte[headerBytes.Length + messageBytes.Length];
+ Buffer.BlockCopy(headerBytes, 0, finalBytes, 0, headerBytes.Length);
+ Buffer.BlockCopy(messageBytes, 0, finalBytes, headerBytes.Length, messageBytes.Length);
+
+ return finalBytes;
+ }
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs
new file mode 100644
index 00000000..3c007a85
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/MessageWriterTests.cs
@@ -0,0 +1,55 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.IO;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
+using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Serializers;
+using Xunit;
+
+namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging
+{
+ public class MessageWriterTests
+ {
+ private readonly IMessageSerializer messageSerializer;
+
+ public MessageWriterTests()
+ {
+ this.messageSerializer = new V8MessageSerializer();
+ }
+
+ [Fact]
+ public async Task WritesMessage()
+ {
+ MemoryStream outputStream = new MemoryStream();
+ MessageWriter messageWriter = new MessageWriter(outputStream, this.messageSerializer);
+
+ // Write the message and then roll back the stream to be read
+ // TODO: This will need to be redone!
+ await messageWriter.WriteMessage(Hosting.Protocol.Contracts.Message.Event("testEvent", null));
+ outputStream.Seek(0, SeekOrigin.Begin);
+
+ string expectedHeaderString = string.Format(Constants.ContentLengthFormatString,
+ Common.ExpectedMessageByteCount);
+
+ byte[] buffer = new byte[128];
+ await outputStream.ReadAsync(buffer, 0, expectedHeaderString.Length);
+
+ Assert.Equal(
+ expectedHeaderString,
+ Encoding.ASCII.GetString(buffer, 0, expectedHeaderString.Length));
+
+ // Read the message
+ await outputStream.ReadAsync(buffer, 0, Common.ExpectedMessageByteCount);
+
+ Assert.Equal(Common.TestEventString,
+ Encoding.UTF8.GetString(buffer, 0, Common.ExpectedMessageByteCount));
+
+ outputStream.Dispose();
+ }
+
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/TestMessageTypes.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs
similarity index 94%
rename from test/Microsoft.SqlTools.ServiceLayer.Test/Message/TestMessageTypes.cs
rename to test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs
index 0ba08056..89238098 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/Message/TestMessageTypes.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Messaging/TestMessageTypes.cs
@@ -6,7 +6,7 @@
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
-namespace Microsoft.SqlTools.ServiceLayer.Test.Message
+namespace Microsoft.SqlTools.ServiceLayer.Test.Messaging
{
#region Request Types
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs
new file mode 100644
index 00000000..dbc344f8
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs
@@ -0,0 +1,124 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Threading.Tasks;
+using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
+using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
+using Moq;
+using Xunit;
+
+namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
+{
+ public class CancelTests
+ {
+ [Fact]
+ public void CancelInProgressQueryTest()
+ {
+ // If:
+ // ... I request a query (doesn't matter what kind) and execute it
+ var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true);
+ var executeParams = new QueryExecuteParams { QueryText = "Doesn't Matter", OwnerUri = Common.OwnerUri };
+ var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null);
+ queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait();
+ queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Fake that it hasn't completed execution
+
+ // ... And then I request to cancel the query
+ var cancelParams = new QueryCancelParams {OwnerUri = Common.OwnerUri};
+ QueryCancelResult result = null;
+ var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null);
+ queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait();
+
+ // Then:
+ // ... I should have seen a successful event (no messages)
+ VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never());
+ Assert.Null(result.Messages);
+
+ // ... The query should have been disposed as well
+ Assert.Empty(queryService.ActiveQueries);
+ }
+
+ [Fact]
+ public void CancelExecutedQueryTest()
+ {
+ // If:
+ // ... I request a query (doesn't matter what kind) and wait for execution
+ var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true);
+ var executeParams = new QueryExecuteParams {QueryText = "Doesn't Matter", OwnerUri = Common.OwnerUri};
+ var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null);
+ queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait();
+
+ // ... And then I request to cancel the query
+ var cancelParams = new QueryCancelParams {OwnerUri = Common.OwnerUri};
+ QueryCancelResult result = null;
+ var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null);
+ queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait();
+
+ // Then:
+ // ... I should have seen a result event with an error message
+ VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never());
+ Assert.NotNull(result.Messages);
+
+ // ... The query should not have been disposed
+ Assert.NotEmpty(queryService.ActiveQueries);
+ }
+
+ [Fact]
+ public void CancelNonExistantTest()
+ {
+ // If:
+ // ... I request to cancel a query that doesn't exist
+ var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false);
+ var cancelParams = new QueryCancelParams {OwnerUri = "Doesn't Exist"};
+ QueryCancelResult result = null;
+ var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null);
+ queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait();
+
+ // Then:
+ // ... I should have seen a result event with an error message
+ VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never());
+ Assert.NotNull(result.Messages);
+ }
+
+ #region Mocking
+
+ private static Mock> GetQueryCancelResultContextMock(
+ Action resultCallback,
+ Action