diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
index b11aa168..83b860a8 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
@@ -34,7 +34,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
public string OwnerUri { get; private set; }
- private ISqlConnectionFactory Factory {get; set;}
+ public ISqlConnectionFactory Factory {get; private set;}
public ConnectionDetails ConnectionDetails { get; private set; }
@@ -123,16 +123,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
// Attempts to link a URI to an actively used connection for this URI
- public bool TryFindConnection(string ownerUri, out ConnectionSummary connectionSummary)
+ public bool TryFindConnection(string ownerUri, out ConnectionInfo connectionInfo)
{
- connectionSummary = null;
- ConnectionInfo connectionInfo;
- if (this.ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo))
- {
- connectionSummary = CopySummary(connectionInfo.ConnectionDetails);
- return true;
- }
- return false;
+ return this.ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo);
}
private static ConnectionSummary CopySummary(ConnectionSummary summary)
@@ -151,16 +144,33 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
///
public ConnectResponse Connect(ConnectParams connectionParams)
{
+ // Validate parameters
+ if(connectionParams == null || !connectionParams.IsValid())
+ {
+ return new ConnectResponse()
+ {
+ Messages = "Error: Invalid connection parameters provided."
+ };
+ }
+
ConnectionInfo connectionInfo;
if (ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo) )
{
// TODO disconnect
}
- connectionInfo = new ConnectionInfo(this.connectionFactory, connectionParams.OwnerUri, connectionParams.Connection);
+ connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection);
// try to connect
- connectionInfo.OpenConnection();
- // TODO: check that connection worked
+ var response = new ConnectResponse();
+ try
+ {
+ connectionInfo.OpenConnection();
+ }
+ catch(Exception ex)
+ {
+ response.Messages = ex.Message;
+ return response;
+ }
ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo;
@@ -171,10 +181,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
// return the connection result
- return new ConnectResponse()
- {
- ConnectionId = connectionInfo.ConnectionId.ToString()
- };
+ response.ConnectionId = connectionInfo.ConnectionId.ToString();
+ return response;
}
public void InitializeService(IProtocolEndpoint serviceHost)
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionMessagesExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionMessagesExtensions.cs
new file mode 100644
index 00000000..b9e73e09
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionMessagesExtensions.cs
@@ -0,0 +1,30 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
+{
+ ///
+ /// Extension methods to ConnectParams
+ ///
+ public static class ConnectParamsExtensions
+ {
+ ///
+ /// Check that the fields in ConnectParams are all valid
+ ///
+ public static bool IsValid(this ConnectParams parameters)
+ {
+ return !(
+ String.IsNullOrEmpty(parameters.OwnerUri) ||
+ parameters.Connection == null ||
+ String.IsNullOrEmpty(parameters.Connection.DatabaseName) ||
+ String.IsNullOrEmpty(parameters.Connection.Password) ||
+ String.IsNullOrEmpty(parameters.Connection.ServerName) ||
+ String.IsNullOrEmpty(parameters.Connection.UserName)
+ );
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs
index 5f62c37b..7ecc7106 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/AutoCompleteService.cs
@@ -265,10 +265,10 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
// If we have a connection but no cache, we don't care - assuming the OnConnect and OnDisconnect listeners
// behave well, there should be a cache for any actively connected document. This also helps skip documents
// that are not backed by a SQL connection
- ConnectionSummary connectionSummary;
+ ConnectionInfo connectionInfo;
IntellisenseCache cache;
- if (ConnectionService.Instance.TryFindConnection(textDocumentPosition.Uri, out connectionSummary)
- && caches.TryGetValue(connectionSummary, out cache))
+ if (ConnectionService.Instance.TryFindConnection(textDocumentPosition.Uri, out connectionInfo)
+ && caches.TryGetValue(connectionInfo.ConnectionDetails, out cache))
{
return cache.GetAutoCompleteItems(textDocumentPosition).ToArray();
}
@@ -278,3 +278,4 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
}
}
+
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs
index 95190752..d0c991f8 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs
@@ -19,6 +19,83 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
///
public class ConnectionServiceTests
{
+ ///
+ /// Verify that when connecting with invalid credentials, an error is thrown.
+ ///
+ [Fact]
+ public void ConnectingWithInvalidCredentialsYieldsErrorMessage()
+ {
+ var testConnectionDetails = TestObjects.GetTestConnectionDetails();
+ var invalidConnectionDetails = new ConnectionDetails();
+ invalidConnectionDetails.ServerName = testConnectionDetails.ServerName;
+ invalidConnectionDetails.DatabaseName = testConnectionDetails.DatabaseName;
+ invalidConnectionDetails.UserName = "invalidUsername"; // triggers exception when opening mock connection
+ invalidConnectionDetails.Password = "invalidPassword";
+
+ // Connect to test db with invalid credentials
+ var connectionResult =
+ TestObjects.GetTestConnectionService()
+ .Connect(new ConnectParams()
+ {
+ OwnerUri = "file://my/sample/file.sql",
+ Connection = invalidConnectionDetails
+ });
+
+ // check that an error was caught
+ Assert.NotNull(connectionResult.Messages);
+ Assert.NotEqual(String.Empty, connectionResult.Messages);
+ }
+
+ ///
+ /// Verify that when connecting with invalid parameters, an error is thrown.
+ ///
+ [Theory]
+ [InlineDataAttribute(null, "my-server", "test", "sa", "123456")]
+ [InlineDataAttribute("file://my/sample/file.sql", null, "test", "sa", "123456")]
+ [InlineDataAttribute("file://my/sample/file.sql", "my-server", null, "sa", "123456")]
+ [InlineDataAttribute("file://my/sample/file.sql", "my-server", "test", null, "123456")]
+ [InlineDataAttribute("file://my/sample/file.sql", "my-server", "test", "sa", null)]
+ [InlineDataAttribute("", "my-server", "test", "sa", "123456")]
+ [InlineDataAttribute("file://my/sample/file.sql", "", "test", "sa", "123456")]
+ [InlineDataAttribute("file://my/sample/file.sql", "my-server", "", "sa", "123456")]
+ [InlineDataAttribute("file://my/sample/file.sql", "my-server", "test", "", "123456")]
+ [InlineDataAttribute("file://my/sample/file.sql", "my-server", "test", "sa", "")]
+ public void ConnectingWithInvalidParametersYieldsErrorMessage(string ownerUri, string server, string database, string userName, string password)
+ {
+ // Connect with invalid parameters
+ var connectionResult =
+ TestObjects.GetTestConnectionService()
+ .Connect(new ConnectParams()
+ {
+ OwnerUri = ownerUri,
+ Connection = new ConnectionDetails() {
+ ServerName = server,
+ DatabaseName = database,
+ UserName = userName,
+ Password = password
+ }
+ });
+
+ // check that an error was caught
+ Assert.NotNull(connectionResult.Messages);
+ Assert.NotEqual(String.Empty, connectionResult.Messages);
+ }
+
+ ///
+ /// Verify that when connecting with a null parameters object, an error is thrown.
+ ///
+ [Fact]
+ public void ConnectingWithNullParametersObjectYieldsErrorMessage()
+ {
+ // Connect with null parameters
+ var connectionResult =
+ TestObjects.GetTestConnectionService()
+ .Connect(null);
+
+ // check that an error was caught
+ Assert.NotNull(connectionResult.Messages);
+ Assert.NotEqual(String.Empty, connectionResult.Messages);
+ }
///
/// Verify that the SQL parser correctly detects errors in text
@@ -64,43 +141,45 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
Assert.True(callbackInvoked);
}
- //[Fact]
- //public void TestConnectRequestRegistersOwner()
- //{
- // // Given a request to connect to a database
- // var service = new ConnectionService(new TestSqlConnectionFactory());
- // ConnectionDetails connectionDetails = TestObjects.GetTestConnectionDetails();
- // var connectParams = new ConnectParams()
- // {
- // OwnerUri = "file://path/to/my.sql",
- // Connection = connectionDetails
- // };
+ ///
+ /// Verify when a connection is created that the URI -> Connection mapping is created in the connection service.
+ ///
+ [Fact]
+ public void TestConnectRequestRegistersOwner()
+ {
+ // Given a request to connect to a database
+ var service = TestObjects.GetTestConnectionService();
+ var connectParams = TestObjects.GetTestConnectionParams();
- // var endpoint = new Mock();
- // Func, Task> connectRequestHandler = null;
- // endpoint.Setup(e => e.SetRequestHandler(ConnectionRequest.Type, It.IsAny, Task>>()))
- // .Callback, Task>>(handler => connectRequestHandler = handler);
+ //var endpoint = new Mock();
+ //Func, Task> connectRequestHandler = null;
+ //endpoint.Setup(e => e.SetRequestHandler(ConnectionRequest.Type, It.IsAny, Task>>()))
+ // .Callback, Task>>(handler => connectRequestHandler = handler);
- // // when I initialize the service
- // service.InitializeService(endpoint.Object);
+ // when I initialize the service
+ //service.InitializeService(endpoint.Object);
- // // then I expect the handler to be captured
- // Assert.NotNull(connectRequestHandler);
+ // then I expect the handler to be captured
+ //Assert.NotNull(connectRequestHandler);
- // // when I call the service
- // var requestContext = new Mock>();
+ // when I call the service
+ //var requestContext = new Mock>();
- // connectRequestHandler(connectParams, requestContext);
- // // then I should get a live connection
+ //connectRequestHandler(connectParams, requestContext.Object);
+ // then I should get a live connection
- // // and then I should have
- // // connect to a database instance
- // var connectionResult =
- // TestObjects.GetTestConnectionService()
- // .Connect(TestObjects.GetTestConnectionDetails());
+ // and then I should have
+ // connect to a database instance
+ var connectionResult = service.Connect(connectParams);
- // // verify that a valid connection id was returned
- // Assert.True(connectionResult.ConnectionId > 0);
- //}
+ // verify that a valid connection id was returned
+ Assert.NotNull(connectionResult.ConnectionId);
+ Assert.NotEqual(String.Empty, connectionResult.ConnectionId);
+ Assert.NotNull(new Guid(connectionResult.ConnectionId));
+
+ // verify that the (URI -> connection) mapping was created
+ ConnectionInfo info;
+ Assert.True(service.TryFindConnection(connectParams.OwnerUri, out info));
+ }
}
}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs
index 669b830b..cda0ed5a 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs
@@ -41,13 +41,13 @@ namespace Microsoft.SqlTools.Test.Utility
#endif
}
- public static ConnectParams GetTestConnectionParams()
- {
+ public static ConnectParams GetTestConnectionParams()
+ {
return new ConnectParams()
{
OwnerUri = "file://some/file.sql",
Connection = GetTestConnectionDetails()
- };
+ };
}
///
@@ -327,7 +327,11 @@ namespace Microsoft.SqlTools.Test.Utility
public override void Open()
{
- // No Op
+ // No Op, unless credentials are bad
+ if(ConnectionString.Contains("invalidUsername"))
+ {
+ throw new Exception("Invalid credentials provided");
+ }
}
public override string ConnectionString { get; set; }