Do not update the DB name when connecting to DB pool (#1186)

* Do not update the DB name when connecting to DB pool

* Fix typo in the connection service
This commit is contained in:
Karl Burtram
2021-04-15 14:48:02 -07:00
committed by GitHub
parent 6bafed10eb
commit 6fe715d2d8
2 changed files with 66 additions and 45 deletions

View File

@@ -423,7 +423,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
// Update with the actual database name in connectionInfo and result // Update with the actual database name in connectionInfo and result
// Doing this here as we know the connection is open - expect to do this only on connecting // Doing this here as we know the connection is open - expect to do this only on connecting
connectionInfo.ConnectionDetails.DatabaseName = connection.Database; // Do not update the DB name if it is a DB Pool database name (e.g. "db@pool")
if (!ConnectionService.IsDbPool(connectionInfo.ConnectionDetails.DatabaseName))
{
connectionInfo.ConnectionDetails.DatabaseName = connection.Database;
}
if (!string.IsNullOrEmpty(connectionInfo.ConnectionDetails.ConnectionString)) if (!string.IsNullOrEmpty(connectionInfo.ConnectionDetails.ConnectionString))
{ {
// If the connection was set up with a connection string, use the connection string to get the details // If the connection was set up with a connection string, use the connection string to get the details
@@ -1612,6 +1617,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
} }
} }
} }
public static bool IsDbPool(string databaseName)
{
return databaseName != null ? databaseName.IndexOf('@') != -1 : false;
}
} }
public class AzureAccessToken : IRenewableToken public class AzureAccessToken : IRenewableToken

View File

@@ -21,7 +21,7 @@ using System.Linq;
using System.Reflection; using System.Reflection;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
{ {
[TestFixture] [TestFixture]
@@ -297,11 +297,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Connection = connectionDetails Connection = connectionDetails
}); });
Assert.Multiple(() => Assert.Multiple(() =>
{ {
Assert.That(connectionResult.ConnectionId, Is.Not.Empty, "ConnectionId"); Assert.That(connectionResult.ConnectionId, Is.Not.Empty, "ConnectionId");
Assert.NotNull(connectionResult.ConnectionSummary, "ConnectionSummary"); Assert.NotNull(connectionResult.ConnectionSummary, "ConnectionSummary");
Assert.AreEqual(expectedDbName, connectionResult.ConnectionSummary.DatabaseName, "I expect connection to succeed and the Summary to include the correct DB name"); Assert.AreEqual(expectedDbName, connectionResult.ConnectionSummary.DatabaseName, "I expect connection to succeed and the Summary to include the correct DB name");
}); });
} }
@@ -405,10 +405,10 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Assert.That(connectionResult.Messages, Is.Not.Null.Or.Empty, "check that an error was caught"); Assert.That(connectionResult.Messages, Is.Not.Null.Or.Empty, "check that an error was caught");
} }
static readonly object[] invalidParameters = static readonly object[] invalidParameters =
{ {
new object[] { "SqlLogin", null, "my-server", "test", "sa", "123456" }, new object[] { "SqlLogin", null, "my-server", "test", "sa", "123456" },
new object[] { "SqlLogin", "file://my/sample/file.sql", null, "test", "sa", "123456" }, new object[] { "SqlLogin", "file://my/sample/file.sql", null, "test", "sa", "123456" },
new object[] {"SqlLogin", "file://my/sample/file.sql", "my-server", "test", null, "123456"}, new object[] {"SqlLogin", "file://my/sample/file.sql", "my-server", "test", null, "123456"},
new object[] {"SqlLogin", "file://my/sample/file.sql", "my-server", "test", "sa", null}, new object[] {"SqlLogin", "file://my/sample/file.sql", "my-server", "test", "sa", null},
new object[] {"SqlLogin", "", "my-server", "test", "sa", "123456" }, new object[] {"SqlLogin", "", "my-server", "test", "sa", "123456" },
@@ -418,7 +418,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
new object[] {"Integrated", null, "my-server", "test", "sa", "123456"}, new object[] {"Integrated", null, "my-server", "test", "sa", "123456"},
new object[] {"Integrated", "file://my/sample/file.sql", null, "test", "sa", "123456"}, new object[] {"Integrated", "file://my/sample/file.sql", null, "test", "sa", "123456"},
new object[] {"Integrated", "", "my-server", "test", "sa", "123456"}, new object[] {"Integrated", "", "my-server", "test", "sa", "123456"},
new object[] {"Integrated", "file://my/sample/file.sql", "", "test", "sa", "123456"} new object[] {"Integrated", "file://my/sample/file.sql", "", "test", "sa", "123456"}
}; };
/// <summary> /// <summary>
/// Verify that when connecting with invalid parameters, an error is thrown. /// Verify that when connecting with invalid parameters, an error is thrown.
@@ -445,16 +445,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Assert.That(connectionResult.Messages, Is.Not.Null.Or.Empty, "check that an error was caught"); Assert.That(connectionResult.Messages, Is.Not.Null.Or.Empty, "check that an error was caught");
} }
static readonly object[] noUserNameOrPassword = static readonly object[] noUserNameOrPassword =
{ {
new object[] {null, null}, new object[] {null, null},
new object[] {null, ""}, new object[] {null, ""},
new object[] {"", null}, new object[] {"", null},
new object[] {"", ""}, new object[] {"", ""},
new object[] {"sa", null}, new object[] {"sa", null},
new object[] {"sa", ""}, new object[] {"sa", ""},
new object[] {null, "12345678"}, new object[] {null, "12345678"},
new object[] {"", "12345678"}, new object[] {"", "12345678"},
}; };
/// <summary> /// <summary>
/// Verify that when using integrated authentication, the username and/or password can be empty. /// Verify that when using integrated authentication, the username and/or password can be empty.
@@ -496,8 +496,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
} }
private static readonly object[] optionalParameters = private static readonly object[] optionalParameters =
{ {
new object[] {"AuthenticationType", "Integrated", "Integrated Security" }, new object[] {"AuthenticationType", "Integrated", "Integrated Security" },
new object[] {"AuthenticationType", "SqlLogin", ""}, new object[] {"AuthenticationType", "SqlLogin", ""},
new object[] {"Encrypt", true, "Encrypt"}, new object[] {"Encrypt", true, "Encrypt"},
@@ -536,13 +536,13 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
new object[] {"MultipleActiveResultSets", false, "Multiple Active Result Sets"}, new object[] {"MultipleActiveResultSets", false, "Multiple Active Result Sets"},
new object[] {"MultipleActiveResultSets", true, "Multiple Active Result Sets"}, new object[] {"MultipleActiveResultSets", true, "Multiple Active Result Sets"},
new object[] {"PacketSize", 8192, "Packet Size"}, new object[] {"PacketSize", 8192, "Packet Size"},
new object[] {"TypeSystemVersion", "Latest", "Type System Version"}, new object[] {"TypeSystemVersion", "Latest", "Type System Version"},
}; };
/// <summary> /// <summary>
/// Verify that optional parameters can be built into a connection string for connecting. /// Verify that optional parameters can be built into a connection string for connecting.
/// </summary> /// </summary>
[Test, TestCaseSource(nameof(optionalParameters))] [Test, TestCaseSource(nameof(optionalParameters))]
public void ConnectingWithOptionalParametersBuildsConnectionString(string propertyName, object propertyValue, string connectionStringMarker) public void ConnectingWithOptionalParametersBuildsConnectionString(string propertyName, object propertyValue, string connectionStringMarker)
{ {
// Create a test connection details object and set the property to a specific value // Create a test connection details object and set the property to a specific value
@@ -556,15 +556,15 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Assert.That(connectionString, Contains.Substring(connectionStringMarker), "Verify that the parameter is in the connection string"); Assert.That(connectionString, Contains.Substring(connectionStringMarker), "Verify that the parameter is in the connection string");
} }
private static readonly object[] optionalEnclaveParameters = private static readonly object[] optionalEnclaveParameters =
{ {
new object[] {"EnclaveAttestationProtocol", "AAS", "Attestation Protocol=AAS"}, new object[] {"EnclaveAttestationProtocol", "AAS", "Attestation Protocol=AAS"},
new object[] {"EnclaveAttestationProtocol", "HGS", "Attestation Protocol=HGS"}, new object[] {"EnclaveAttestationProtocol", "HGS", "Attestation Protocol=HGS"},
new object[] {"EnclaveAttestationProtocol", "aas", "Attestation Protocol=AAS"}, new object[] {"EnclaveAttestationProtocol", "aas", "Attestation Protocol=AAS"},
new object[] {"EnclaveAttestationProtocol", "hgs", "Attestation Protocol=HGS"}, new object[] {"EnclaveAttestationProtocol", "hgs", "Attestation Protocol=HGS"},
new object[] {"EnclaveAttestationProtocol", "AaS", "Attestation Protocol=AAS"}, new object[] {"EnclaveAttestationProtocol", "AaS", "Attestation Protocol=AAS"},
new object[] {"EnclaveAttestationProtocol", "hGs", "Attestation Protocol=HGS"}, new object[] {"EnclaveAttestationProtocol", "hGs", "Attestation Protocol=HGS"},
new object[] {"EnclaveAttestationUrl", "https://attestation.us.attest.azure.net/attest/SgxEnclave", "Enclave Attestation Url=https://attestation.us.attest.azure.net/attest/SgxEnclave" }, new object[] {"EnclaveAttestationUrl", "https://attestation.us.attest.azure.net/attest/SgxEnclave", "Enclave Attestation Url=https://attestation.us.attest.azure.net/attest/SgxEnclave" },
}; };
/// <summary> /// <summary>
@@ -586,11 +586,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Assert.That(connectionString, Contains.Substring(connectionStringMarker), "Verify that the parameter is in the connection string"); Assert.That(connectionString, Contains.Substring(connectionStringMarker), "Verify that the parameter is in the connection string");
} }
private static readonly object[] invalidOptions = private static readonly object[] invalidOptions =
{ {
new object[] {"AuthenticationType", "NotAValidAuthType" }, new object[] {"AuthenticationType", "NotAValidAuthType" },
new object[] {"ColumnEncryptionSetting", "NotAValidColumnEncryptionSetting" }, new object[] {"ColumnEncryptionSetting", "NotAValidColumnEncryptionSetting" },
new object[] {"EnclaveAttestationProtocol", "NotAValidEnclaveAttestationProtocol" }, new object[] {"EnclaveAttestationProtocol", "NotAValidEnclaveAttestationProtocol" },
}; };
/// <summary> /// <summary>
/// Build connection string with an invalid property type /// Build connection string with an invalid property type
@@ -604,8 +604,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Assert.Throws<ArgumentException>(() => ConnectionService.BuildConnectionString(details)); Assert.Throws<ArgumentException>(() => ConnectionService.BuildConnectionString(details));
} }
private static readonly Tuple<string,object>[][] optionCombos = private static readonly Tuple<string,object>[][] optionCombos =
{ {
new [] new []
{ {
Tuple.Create<string, object>("ColumnEncryptionSetting", null), Tuple.Create<string, object>("ColumnEncryptionSetting", null),
@@ -623,7 +623,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Tuple.Create<string, object>("ColumnEncryptionSetting", ""), Tuple.Create<string, object>("ColumnEncryptionSetting", ""),
Tuple.Create<string, object>("EnclaveAttestationProtocol", "AAS"), Tuple.Create<string, object>("EnclaveAttestationProtocol", "AAS"),
Tuple.Create<string, object>("EnclaveAttestationUrl", "https://attestation.us.attest.azure.net/attest/SgxEnclave") Tuple.Create<string, object>("EnclaveAttestationUrl", "https://attestation.us.attest.azure.net/attest/SgxEnclave")
} }
}; };
/// <summary> /// <summary>
@@ -633,14 +633,14 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
public void ConnStrWithInvalidOptions() public void ConnStrWithInvalidOptions()
{ {
ConnectionDetails details = TestObjects.GetTestConnectionDetails(); ConnectionDetails details = TestObjects.GetTestConnectionDetails();
foreach (var options in optionCombos) foreach (var options in optionCombos)
{ {
options.ToList().ForEach(tuple => options.ToList().ForEach(tuple =>
{ {
PropertyInfo info = details.GetType().GetProperty(tuple.Item1); PropertyInfo info = details.GetType().GetProperty(tuple.Item1);
info.SetValue(details, tuple.Item2); info.SetValue(details, tuple.Item2);
}); });
Assert.Throws<ArgumentException>(() => ConnectionService.BuildConnectionString(details)); Assert.Throws<ArgumentException>(() => ConnectionService.BuildConnectionString(details));
} }
} }
@@ -1663,5 +1663,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
// Then the connection factory got called with details including an account token // Then the connection factory got called with details including an account token
mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.Is<string>(accountToken => accountToken == azureAccountToken)), Times.Once()); mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.Is<string>(accountToken => accountToken == azureAccountToken)), Times.Once());
} }
/// <summary>
/// Test is IsDbPool method correctly works for various database names
/// </summary>
[Test]
public void CheckIsDbPool()
{
Assert.IsTrue(ConnectionService.IsDbPool("db@pool"));
Assert.IsFalse(ConnectionService.IsDbPool("db"));
Assert.IsFalse(ConnectionService.IsDbPool(null));
}
} }
} }