diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 2ffa682e..a1b5cd82 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -423,7 +423,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // Update with the actual database name in connectionInfo and result // Doing this here as we know the connection is open - expect to do this only on connecting - connectionInfo.ConnectionDetails.DatabaseName = connection.Database; + // Do not update the DB name if it is a DB Pool database name (e.g. "db@pool") + if (!ConnectionService.IsDbPool(connectionInfo.ConnectionDetails.DatabaseName)) + { + connectionInfo.ConnectionDetails.DatabaseName = connection.Database; + } + if (!string.IsNullOrEmpty(connectionInfo.ConnectionDetails.ConnectionString)) { // If the connection was set up with a connection string, use the connection string to get the details @@ -1612,6 +1617,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } } } + + public static bool IsDbPool(string databaseName) + { + return databaseName != null ? databaseName.IndexOf('@') != -1 : false; + } } public class AzureAccessToken : IRenewableToken diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs index 185de1fe..0570a1c6 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs @@ -21,7 +21,7 @@ using System.Linq; using System.Reflection; using System.Threading; using System.Threading.Tasks; - + namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection { [TestFixture] @@ -297,11 +297,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Connection = connectionDetails }); - Assert.Multiple(() => - { - Assert.That(connectionResult.ConnectionId, Is.Not.Empty, "ConnectionId"); - Assert.NotNull(connectionResult.ConnectionSummary, "ConnectionSummary"); - Assert.AreEqual(expectedDbName, connectionResult.ConnectionSummary.DatabaseName, "I expect connection to succeed and the Summary to include the correct DB name"); + Assert.Multiple(() => + { + Assert.That(connectionResult.ConnectionId, Is.Not.Empty, "ConnectionId"); + Assert.NotNull(connectionResult.ConnectionSummary, "ConnectionSummary"); + Assert.AreEqual(expectedDbName, connectionResult.ConnectionSummary.DatabaseName, "I expect connection to succeed and the Summary to include the correct DB name"); }); } @@ -405,10 +405,10 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.That(connectionResult.Messages, Is.Not.Null.Or.Empty, "check that an error was caught"); } - static readonly object[] invalidParameters = - { - new object[] { "SqlLogin", null, "my-server", "test", "sa", "123456" }, - new object[] { "SqlLogin", "file://my/sample/file.sql", null, "test", "sa", "123456" }, + static readonly object[] invalidParameters = + { + new object[] { "SqlLogin", null, "my-server", "test", "sa", "123456" }, + new object[] { "SqlLogin", "file://my/sample/file.sql", null, "test", "sa", "123456" }, new object[] {"SqlLogin", "file://my/sample/file.sql", "my-server", "test", null, "123456"}, new object[] {"SqlLogin", "file://my/sample/file.sql", "my-server", "test", "sa", null}, new object[] {"SqlLogin", "", "my-server", "test", "sa", "123456" }, @@ -418,7 +418,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection new object[] {"Integrated", null, "my-server", "test", "sa", "123456"}, new object[] {"Integrated", "file://my/sample/file.sql", null, "test", "sa", "123456"}, new object[] {"Integrated", "", "my-server", "test", "sa", "123456"}, - new object[] {"Integrated", "file://my/sample/file.sql", "", "test", "sa", "123456"} + new object[] {"Integrated", "file://my/sample/file.sql", "", "test", "sa", "123456"} }; /// /// Verify that when connecting with invalid parameters, an error is thrown. @@ -445,16 +445,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.That(connectionResult.Messages, Is.Not.Null.Or.Empty, "check that an error was caught"); } - static readonly object[] noUserNameOrPassword = - { - new object[] {null, null}, - new object[] {null, ""}, - new object[] {"", null}, - new object[] {"", ""}, - new object[] {"sa", null}, - new object[] {"sa", ""}, - new object[] {null, "12345678"}, - new object[] {"", "12345678"}, + static readonly object[] noUserNameOrPassword = + { + new object[] {null, null}, + new object[] {null, ""}, + new object[] {"", null}, + new object[] {"", ""}, + new object[] {"sa", null}, + new object[] {"sa", ""}, + new object[] {null, "12345678"}, + new object[] {"", "12345678"}, }; /// /// Verify that when using integrated authentication, the username and/or password can be empty. @@ -496,8 +496,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection } - private static readonly object[] optionalParameters = - { + private static readonly object[] optionalParameters = + { new object[] {"AuthenticationType", "Integrated", "Integrated Security" }, new object[] {"AuthenticationType", "SqlLogin", ""}, new object[] {"Encrypt", true, "Encrypt"}, @@ -536,13 +536,13 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection new object[] {"MultipleActiveResultSets", false, "Multiple Active Result Sets"}, new object[] {"MultipleActiveResultSets", true, "Multiple Active Result Sets"}, new object[] {"PacketSize", 8192, "Packet Size"}, - new object[] {"TypeSystemVersion", "Latest", "Type System Version"}, - }; - + new object[] {"TypeSystemVersion", "Latest", "Type System Version"}, + }; + /// /// Verify that optional parameters can be built into a connection string for connecting. /// - [Test, TestCaseSource(nameof(optionalParameters))] + [Test, TestCaseSource(nameof(optionalParameters))] public void ConnectingWithOptionalParametersBuildsConnectionString(string propertyName, object propertyValue, string connectionStringMarker) { // Create a test connection details object and set the property to a specific value @@ -556,15 +556,15 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.That(connectionString, Contains.Substring(connectionStringMarker), "Verify that the parameter is in the connection string"); } - private static readonly object[] optionalEnclaveParameters = - { + private static readonly object[] optionalEnclaveParameters = + { new object[] {"EnclaveAttestationProtocol", "AAS", "Attestation Protocol=AAS"}, new object[] {"EnclaveAttestationProtocol", "HGS", "Attestation Protocol=HGS"}, new object[] {"EnclaveAttestationProtocol", "aas", "Attestation Protocol=AAS"}, new object[] {"EnclaveAttestationProtocol", "hgs", "Attestation Protocol=HGS"}, new object[] {"EnclaveAttestationProtocol", "AaS", "Attestation Protocol=AAS"}, new object[] {"EnclaveAttestationProtocol", "hGs", "Attestation Protocol=HGS"}, - new object[] {"EnclaveAttestationUrl", "https://attestation.us.attest.azure.net/attest/SgxEnclave", "Enclave Attestation Url=https://attestation.us.attest.azure.net/attest/SgxEnclave" }, + new object[] {"EnclaveAttestationUrl", "https://attestation.us.attest.azure.net/attest/SgxEnclave", "Enclave Attestation Url=https://attestation.us.attest.azure.net/attest/SgxEnclave" }, }; /// @@ -586,11 +586,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.That(connectionString, Contains.Substring(connectionStringMarker), "Verify that the parameter is in the connection string"); } - private static readonly object[] invalidOptions = - { + private static readonly object[] invalidOptions = + { new object[] {"AuthenticationType", "NotAValidAuthType" }, new object[] {"ColumnEncryptionSetting", "NotAValidColumnEncryptionSetting" }, - new object[] {"EnclaveAttestationProtocol", "NotAValidEnclaveAttestationProtocol" }, + new object[] {"EnclaveAttestationProtocol", "NotAValidEnclaveAttestationProtocol" }, }; /// /// Build connection string with an invalid property type @@ -604,8 +604,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.Throws(() => ConnectionService.BuildConnectionString(details)); } - private static readonly Tuple[][] optionCombos = - { + private static readonly Tuple[][] optionCombos = + { new [] { Tuple.Create("ColumnEncryptionSetting", null), @@ -623,7 +623,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Tuple.Create("ColumnEncryptionSetting", ""), Tuple.Create("EnclaveAttestationProtocol", "AAS"), Tuple.Create("EnclaveAttestationUrl", "https://attestation.us.attest.azure.net/attest/SgxEnclave") - } + } }; /// @@ -633,14 +633,14 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection public void ConnStrWithInvalidOptions() { ConnectionDetails details = TestObjects.GetTestConnectionDetails(); - foreach (var options in optionCombos) - { - options.ToList().ForEach(tuple => - { - PropertyInfo info = details.GetType().GetProperty(tuple.Item1); - info.SetValue(details, tuple.Item2); - }); - Assert.Throws(() => ConnectionService.BuildConnectionString(details)); + foreach (var options in optionCombos) + { + options.ToList().ForEach(tuple => + { + PropertyInfo info = details.GetType().GetProperty(tuple.Item1); + info.SetValue(details, tuple.Item2); + }); + Assert.Throws(() => ConnectionService.BuildConnectionString(details)); } } @@ -1663,5 +1663,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection // Then the connection factory got called with details including an account token mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny(), It.Is(accountToken => accountToken == azureAccountToken)), Times.Once()); } + + /// + /// Test is IsDbPool method correctly works for various database names + /// + [Test] + public void CheckIsDbPool() + { + Assert.IsTrue(ConnectionService.IsDbPool("db@pool")); + Assert.IsFalse(ConnectionService.IsDbPool("db")); + Assert.IsFalse(ConnectionService.IsDbPool(null)); + } } }