diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 86050aed..dd4d434e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -388,25 +388,24 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// creates a new connection. This cannot be used to create a default connection or to create a /// connection if a default connection does not exist. /// + /// A DB connection for the connection type requested public async Task GetOrOpenConnection(string ownerUri, string connectionType) { - if (string.IsNullOrEmpty(ownerUri) || string.IsNullOrEmpty(connectionType)) - { - return null; - } + Validate.IsNotNullOrEmptyString(nameof(ownerUri), ownerUri); + Validate.IsNotNullOrEmptyString(nameof(connectionType), connectionType); // Try to get the ConnectionInfo, if it exists - ConnectionInfo connectionInfo = ownerToConnectionMap[ownerUri]; - if (connectionInfo == null) + ConnectionInfo connectionInfo; + if (!ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo)) { - return null; + throw new ArgumentOutOfRangeException(SR.ConnectionServiceListDbErrorNotConnected(ownerUri)); } // Make sure a default connection exists DbConnection defaultConnection; if (!connectionInfo.TryGetConnection(ConnectionType.Default, out defaultConnection)) { - return null; + throw new InvalidOperationException(SR.ConnectionServiceDbErrorDefaultNotConnected(ownerUri)); } // Try to get the DbConnection @@ -416,7 +415,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // If the DbConnection does not exist and is not the default connection, create one. // We can't create the default (initial) connection here because we won't have a ConnectionDetails // if Connect() has not yet been called. - ConnectParams connectParams = new ConnectParams() + ConnectParams connectParams = new ConnectParams { OwnerUri = ownerUri, Connection = connectionInfo.ConnectionDetails, diff --git a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.cs b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.cs index 221bfe95..3fe5880f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.cs @@ -730,6 +730,11 @@ namespace Microsoft.SqlTools.ServiceLayer return Keys.GetString(Keys.ConnectionServiceListDbErrorNotConnected, uri); } + public static string ConnectionServiceDbErrorDefaultNotConnected(string uri) + { + return Keys.GetString(Keys.ConnectionServiceDbErrorDefaultNotConnected, uri); + } + public static string ConnectionServiceConnStringInvalidAuthType(string authType) { return Keys.GetString(Keys.ConnectionServiceConnStringInvalidAuthType, authType); @@ -802,6 +807,9 @@ namespace Microsoft.SqlTools.ServiceLayer public const string ConnectionServiceListDbErrorNotConnected = "ConnectionServiceListDbErrorNotConnected"; + public const string ConnectionServiceDbErrorDefaultNotConnected = "ConnectionServiceDbErrorDefaultNotConnected"; + + public const string ConnectionServiceConnStringInvalidAuthType = "ConnectionServiceConnStringInvalidAuthType"; diff --git a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.resx b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.resx index d9af1add..3cdcb15a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.resx +++ b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.resx @@ -128,6 +128,11 @@ SpecifiedUri '{0}' does not have existing connection . + Parameters: 0 - uri (string) + + + Specified URI '{0}' does not have a default connection + . Parameters: 0 - uri (string) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings index 12108e6a..be49988a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings +++ b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings @@ -29,6 +29,8 @@ ConnectionServiceListDbErrorNullOwnerUri = OwnerUri cannot be null or empty ConnectionServiceListDbErrorNotConnected(string uri) = SpecifiedUri '{0}' does not have existing connection +ConnectionServiceDbErrorDefaultNotConnected(string uri) = Specified URI '{0}' does not have a default connection + ConnectionServiceConnStringInvalidAuthType(string authType) = Invalid value '{0}' for AuthenticationType. Valid values are 'Integrated' and 'SqlLogin'. ConnectionServiceConnStringInvalidIntent(string intent) = Invalid value '{0}' for ApplicationIntent. Valid values are 'ReadWrite' and 'ReadOnly'. diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs index 9703603f..15567b40 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs @@ -1139,5 +1139,53 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection await service.Connect(connectParamsQuery); Assert.Equal(2, connectionInfo.CountConnections); } + + [Theory] + [InlineData(null)] + [InlineData("")] + public async Task GetOrOpenNullOwnerUri(string ownerUri) + { + // If: I have a connection service and I ask for a connection with an invalid ownerUri + // Then: An exception should be thrown + var service = TestObjects.GetTestConnectionService(); + await Assert.ThrowsAsync( + () => service.GetOrOpenConnection(ownerUri, ConnectionType.Default)); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + public async Task GetOrOpenNullConnectionType(string connType) + { + // If: I have a connection service and I ask for a connection with an invalid connectionType + // Then: An exception should be thrown + var service = TestObjects.GetTestConnectionService(); + await Assert.ThrowsAsync( + () => service.GetOrOpenConnection(TestObjects.ScriptUri, connType)); + } + + [Fact] + public async Task GetOrOpenNoConnection() + { + // If: I have a connection service and I ask for a connection for an unconnected uri + // Then: An exception should be thrown + var service = TestObjects.GetTestConnectionService(); + await Assert.ThrowsAsync( + () => service.GetOrOpenConnection(TestObjects.ScriptUri, ConnectionType.Query)); + } + + [Fact] + public async Task GetOrOpenNoDefaultConnection() + { + // Setup: Create a connection service with an empty connection info obj + var service = TestObjects.GetTestConnectionService(); + var connInfo = new ConnectionInfo(null, null, null); + service.OwnerToConnectionMap[TestObjects.ScriptUri] = connInfo; + + // If: I ask for a connection on a connection that doesn't have a default connection + // Then: An exception should be thrown + await Assert.ThrowsAsync( + () => service.GetOrOpenConnection(TestObjects.ScriptUri, ConnectionType.Query)); + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index ba4228bf..53fe76ec 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -206,7 +206,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution public static ConnectionInfo CreateTestConnectionInfo(TestResultSet[] data, bool throwOnRead) { - return new ConnectionInfo(CreateMockFactory(data, throwOnRead), OwnerUri, StandardConnectionDetails); + // Create a connection info and add the default connection to it + ISqlConnectionFactory factory = CreateMockFactory(data, throwOnRead); + ConnectionInfo ci = new ConnectionInfo(factory, OwnerUri, StandardConnectionDetails); + ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null); + return ci; } public static ConnectionInfo CreateConnectedConnectionInfo(TestResultSet[] data, bool throwOnRead, string type = ConnectionType.Default)