// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. // #nullable disable using Microsoft.Data.SqlClient; using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.Admin.Contracts; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.Test.Common; using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility; using static Microsoft.SqlTools.Utility.SqlConstants; using Moq; using Moq.Protected; using NUnit.Framework; using System; using System.Collections.Generic; using System.Data; using System.Data.Common; using System.Linq; using System.Reflection; using System.Threading; using System.Threading.Tasks; namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection { [TestFixture] /// /// Tests for the ServiceHost Connection Service tests /// public class ConnectionServiceTests { /// /// Creates a mock db command that returns a predefined result set /// public static DbCommand CreateTestCommand(TestResultSet[] data) { var commandMock = new Mock { CallBase = true }; var commandMockSetup = commandMock.Protected() .Setup("ExecuteDbDataReader", It.IsAny()); commandMockSetup.Returns(() => new TestDbDataReader(data, false)); return commandMock.Object; } /// /// Creates a mock db connection that returns predefined data when queried for a result set /// public DbConnection CreateMockDbConnection(TestResultSet[] data) { var connectionMock = new Mock { CallBase = true }; connectionMock.Protected() .Setup("CreateDbCommand") .Returns(CreateTestCommand(data)); return connectionMock.Object; } [Test] public void CanCancelConnectRequest() { const string testFile = "file:///my/test/file.sql"; // Given a connection that times out and responds to cancellation var mockConnection = new Mock { CallBase = true }; CancellationToken token = new CancellationToken(); bool ready = false; mockConnection.Setup(x => x.OpenAsync(It.IsAny())) .Callback(t => { // Pass the token to the return handler and signal the main thread to cancel token = t; ready = true; }) .Returns(() => { if (TestUtils.WaitFor(() => token.IsCancellationRequested)) { throw new OperationCanceledException(); } return Task.FromResult(true); }); var mockFactory = new Mock(); mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(mockConnection.Object); var connectionService = new ConnectionService(mockFactory.Object); // Connect the connection asynchronously in a background thread var connectionDetails = TestObjects.GetTestConnectionDetails(); var connectTask = Task.Run(async () => await connectionService .Connect(new ConnectParams { OwnerUri = testFile, Connection = connectionDetails })); // Wait for the connection to call OpenAsync() Assert.True(TestUtils.WaitFor(() => ready)); // Send a cancellation request var cancelResult = connectionService .CancelConnect(new CancelConnectParams() { OwnerUri = testFile }); // Wait for the connection task to finish connectTask.Wait(); // Verify that the connection was cancelled (no connection was created) Assert.Null(connectTask.Result.ConnectionId); // Verify that the cancel succeeded Assert.True(cancelResult); } [Test] public async Task CanCancelConnectRequestByConnecting() { const string testFile = "file:///my/test/file.sql"; // Given a connection that times out and responds to cancellation var mockConnection = new Mock { CallBase = true }; CancellationToken token = new CancellationToken(); bool ready = false; mockConnection.Setup(x => x.OpenAsync(It.IsAny())) .Callback(t => { // Pass the token to the return handler and signal the main thread to cancel token = t; ready = true; }) .Returns(() => { if (TestUtils.WaitFor(() => token.IsCancellationRequested)) { throw new OperationCanceledException(); } return Task.FromResult(true); }); // Given a second connection that succeeds var mockConnection2 = new Mock { CallBase = true }; mockConnection2.Setup(x => x.OpenAsync(It.IsAny())) .Returns(() => Task.Run(() => { })); var mockFactory = new Mock(); mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(mockConnection.Object) .Returns(mockConnection2.Object); var connectionService = new ConnectionService(mockFactory.Object); // Connect the first connection asynchronously in a background thread var connectionDetails = TestObjects.GetTestConnectionDetails(); var connectTask = Task.Run(async () => await connectionService .Connect(new ConnectParams() { OwnerUri = testFile, Connection = connectionDetails })); // Wait for the connection to call OpenAsync() Assert.True(TestUtils.WaitFor(() => ready)); // Send a cancellation by trying to connect again var connectResult = await connectionService .Connect(new ConnectParams() { OwnerUri = testFile, Connection = connectionDetails }); // Wait for the first connection task to finish connectTask.Wait(); // Verify that the first connection was cancelled (no connection was created) Assert.Null(connectTask.Result.ConnectionId); // Verify that the second connection succeeded Assert.That(connectResult.ConnectionId, Is.Not.Empty); } [Test] public void CanCancelConnectRequestByDisconnecting() { const string testFile = "file:///my/test/file.sql"; // Given a connection that times out and responds to cancellation var mockConnection = new Mock { CallBase = true }; CancellationToken token = new CancellationToken(); bool ready = false; mockConnection.Setup(x => x.OpenAsync(It.IsAny())) .Callback(t => { // Pass the token to the return handler and signal the main thread to cancel token = t; ready = true; }) .Returns(() => { if (TestUtils.WaitFor(() => token.IsCancellationRequested)) { throw new OperationCanceledException(); } return Task.FromResult(true); }); var mockFactory = new Mock(); mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(mockConnection.Object); var connectionService = new ConnectionService(mockFactory.Object); // Connect the first connection asynchronously in a background thread var connectionDetails = TestObjects.GetTestConnectionDetails(); var connectTask = Task.Run(async () => await connectionService .Connect(new ConnectParams { OwnerUri = testFile, Connection = connectionDetails })); // Wait for the connection to call OpenAsync() Assert.True(TestUtils.WaitFor(() => ready)); // Send a cancellation by trying to disconnect var disconnectResult = connectionService .Disconnect(new DisconnectParams { OwnerUri = testFile }); // Wait for the first connection task to finish connectTask.Wait(); // Verify that the first connection was cancelled (no connection was created) Assert.Null(connectTask.Result.ConnectionId); // Verify that the disconnect failed (since it caused a cancellation) Assert.False(disconnectResult); } /// /// Verify that we cannot connect to the default database when no username /// is provided and the authentication type is basic auth. /// [Test] public async Task CantConnectWithoutUsername() { // Connect var connectionDetails = TestObjects.GetTestConnectionDetails(); connectionDetails.UserName = ""; var connectionResult = await TestObjects.GetTestConnectionService() .Connect(new ConnectParams() { OwnerUri = "file:///my/test/file.sql", Connection = connectionDetails }); Assert.That(connectionResult.ErrorMessage, Is.EqualTo(SR.ConnectionParamsValidateNullSqlAuth("UserName"))); } /// /// Verify that we can connect to the default database when no database name is /// provided as a parameter. /// [Test] public async Task CanConnectWithEmptyDatabaseName([Values(null, "")] string databaseName) { // Connect var connectionDetails = TestObjects.GetTestConnectionDetails(); connectionDetails.DatabaseName = databaseName; var connectionResult = await TestObjects.GetTestConnectionService() .Connect(new ConnectParams() { OwnerUri = "file:///my/test/file.sql", Connection = connectionDetails }); Assert.That(connectionResult.ConnectionId, Is.Not.Empty, "check that a connection was created"); } /// /// Verify that we can connect to the default database when no database name is /// provided as a parameter. /// [Test] public async Task ConnectToDefaultDatabaseRespondsWithActualDbName([Values("master", "nonMasterDb")] string expectedDbName) { // Given connecting with empty database name will return the expected DB name var connectionMock = new Mock { CallBase = true }; connectionMock.Setup(c => c.Database).Returns(expectedDbName); var mockFactory = new Mock(); mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(connectionMock.Object); var connectionService = new ConnectionService(mockFactory.Object); // When I connect with an empty DB name var connectionDetails = TestObjects.GetTestConnectionDetails(); connectionDetails.DatabaseName = string.Empty; var connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = "file:///my/test/file.sql", Connection = connectionDetails }); Assert.Multiple(() => { Assert.That(connectionResult.ConnectionId, Is.Not.Empty, "ConnectionId"); Assert.NotNull(connectionResult.ConnectionSummary, "ConnectionSummary"); Assert.AreEqual(expectedDbName, connectionResult.ConnectionSummary.DatabaseName, "I expect connection to succeed and the Summary to include the correct DB name"); }); } /// /// Verify that when a connection is started for a URI with an already existing /// connection, we disconnect first before connecting. /// [Test] public async Task ConnectingWhenConnectionExistCausesDisconnectThenConnect() { bool callbackInvoked = false; string ownerUri = "file://my/sample/file.sql"; const string masterDbName = "master"; const string otherDbName = "other"; // Given a connection that returns the database name var dummySqlConnection = new TestSqlConnection(null); var mockFactory = new Mock(); mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny(), It.IsAny())) .Returns((string connString, string azureAccountToken, SqlRetryLogicBaseProvider retryProvider) => { dummySqlConnection.ConnectionString = connString; SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); // Database name is respected. Follow heuristic where empty DB name really means Master var dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog; dummySqlConnection.SetDatabase(dbName); return dummySqlConnection; }); var connectionService = new ConnectionService(mockFactory.Object); // register disconnect callback connectionService.RegisterOnDisconnectTask( (result, uri) => { callbackInvoked = true; Assert.True(uri.Equals(ownerUri)); return Task.FromResult(true); } ); // When I connect to default var connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, Connection = TestObjects.GetTestConnectionDetails() }); // Then I expect to be connected to master Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful"); // And when I then connect to another DB var updatedConnectionDetails = TestObjects.GetTestConnectionDetails(); updatedConnectionDetails.DatabaseName = otherDbName; connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, Connection = updatedConnectionDetails }); // Then I expect to be disconnected from master, and connected to the new DB // verify that the event was fired (we disconnected first before connecting) Assert.True(callbackInvoked); // verify that we connected again Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful"); Assert.AreEqual(otherDbName, connectionResult.ConnectionSummary.DatabaseName); } /// /// Verify that when connecting with invalid credentials, an error is thrown. /// [Test] public async Task ConnectingWithInvalidCredentialsYieldsErrorMessage() { var testConnectionDetails = TestObjects.GetTestConnectionDetails(); var invalidConnectionDetails = new ConnectionDetails { ServerName = testConnectionDetails.ServerName, DatabaseName = testConnectionDetails.DatabaseName, UserName = "invalidUsername", Password = Guid.NewGuid().ToString() }; // triggers exception when opening mock connection // Connect to test db with invalid credentials var connectionResult = await TestObjects.GetTestConnectionService() .Connect(new ConnectParams { OwnerUri = "file://my/sample/file.sql", Connection = invalidConnectionDetails }); Assert.That(connectionResult.ErrorMessage, Is.Not.Null.Or.Empty, "check that an error was caught"); } static readonly object[] invalidParameters = { new object[] { "SqlLogin", null, "my-server", "test", "sa", "123456" }, new object[] { "SqlLogin", "file://my/sample/file.sql", null, "test", "sa", "123456" }, new object[] {"SqlLogin", "file://my/sample/file.sql", "my-server", "test", null, "123456"}, new object[] {"SqlLogin", "file://my/sample/file.sql", "my-server", "test", "sa", null}, new object[] {"SqlLogin", "", "my-server", "test", "sa", "123456" }, new object[] {"SqlLogin", "file://my/sample/file.sql", "", "test", "sa", "123456"}, new object[] {"SqlLogin", "file://my/sample/file.sql", "my-server", "test", "", "123456"}, new object[] {"SqlLogin", "file://my/sample/file.sql", "my-server", "test", "sa", ""}, new object[] {"Integrated", null, "my-server", "test", "sa", "123456"}, new object[] {"Integrated", "file://my/sample/file.sql", null, "test", "sa", "123456"}, new object[] {"Integrated", "", "my-server", "test", "sa", "123456"}, new object[] {"Integrated", "file://my/sample/file.sql", "", "test", "sa", "123456"} }; /// /// Verify that when connecting with invalid parameters, an error is thrown. /// [Test, TestCaseSource(nameof(invalidParameters))] public async Task ConnectingWithInvalidParametersYieldsErrorMessage(string authType, string ownerUri, string server, string database, string userName, string password) { // Connect with invalid parameters var connectionResult = await TestObjects.GetTestConnectionService() .Connect(new ConnectParams() { OwnerUri = ownerUri, Connection = new ConnectionDetails() { ServerName = server, DatabaseName = database, UserName = userName, Password = password, AuthenticationType = authType } }); Assert.That(connectionResult.ErrorMessage, Is.Not.Null.Or.Empty, "check that an error was caught"); } static readonly object[] noPassword = { new object[] {"sa", null}, new object[] {"sa", ""}, }; /// /// Verify that when using sql logins, the password can be empty. /// [Test, TestCaseSource(nameof(noPassword))] public void ConnectingWithNoPasswordWorksForSqlLogin(string userName, string password) { // Connect var connectionResult = ConnectionService.CreateConnectionStringBuilder(new ConnectionDetails() { ServerName = "my-server", DatabaseName = "test", UserName = userName, Password = password, AuthenticationType = SqlLogin }); Assert.That(connectionResult, Is.Not.Null.Or.Empty, "check that the connection was successful"); } static readonly object[] noUserNameOrPassword = { new object[] {null, null}, new object[] {null, ""}, new object[] {"", null}, new object[] {"", ""}, new object[] {"sa", null}, new object[] {"sa", ""}, new object[] {null, "12345678"}, new object[] {"", "12345678"}, }; /// /// Verify that when using integrated authentication, the username and/or password can be empty. /// [Test, TestCaseSource(nameof(noUserNameOrPassword))] public async Task ConnectingWithNoUsernameOrPasswordWorksForIntegratedAuth(string userName, string password) { // Connect var connectionResult = await TestObjects.GetTestConnectionService() .Connect(new ConnectParams() { OwnerUri = "file:///my/test/file.sql", Connection = new ConnectionDetails() { ServerName = "my-server", DatabaseName = "test", UserName = userName, Password = password, AuthenticationType = Integrated } }); Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful"); } /// /// Verify that username is required when using Active Directory Interactive authentication. /// Both AzureMFA and ActiveDirectoryInteractive should work same way, when SqlAuthenticationProvider is enabled. /// [TestCase(null, AzureMFA)] [TestCase(null, ActiveDirectoryInteractive)] public async Task ConnectingWitNoUsernameFailsForAADInteractiveAuth(string userName, string authMode) { // This is an exception scenario test, therefore using ConnectionService instance instead of TestConnectionService directly. ConnectionService.Instance.EnableSqlAuthenticationProvider = true; // Connect var connectionResult = await TestObjects.GetTestConnectionService() .Connect(new ConnectParams() { OwnerUri = "file:///my/test/file.sql", Connection = new ConnectionDetails() { ServerName = "my-server", DatabaseName = "test", UserName = userName, AuthenticationType = authMode } }); Assert.That(connectionResult.ConnectionId, Is.Null.Or.Empty, "Connection should not be successful"); } /// /// Verify that password is ignored when using Microsoft Entra Interactive authentication. /// [TestCase("user", "anything", AzureMFA)] [TestCase("user", "anything", ActiveDirectoryInteractive)] public async Task ConnectingWitPasswordIsIgnoredForAADInteractiveAuth(string username, string password, string authMode) { // Connect var connectionResult = await TestObjects.GetTestConnectionService() .Connect(new ConnectParams() { OwnerUri = "file:///my/test/file.sql", Connection = new ConnectionDetails() { ServerName = "my-server", DatabaseName = "test", UserName = username, Password = password, AuthenticationType = authMode } }); Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "Connection should be successful"); } /// /// Verify that when connecting with a null parameters object, an error is thrown. /// [Test] public async Task ConnectingWithNullParametersObjectYieldsErrorMessage() { // Connect with null parameters var connectionResult = await TestObjects.GetTestConnectionService() .Connect(null); Assert.That(connectionResult.ErrorMessage, Is.Not.Null.Or.Empty, "check that an error was caught"); } private static readonly object[] optionalParameters = { new object[] {"AuthenticationType", "Integrated", "Integrated Security" }, new object[] {"AuthenticationType", "SqlLogin", ""}, new object[] {"Encrypt", "Mandatory", "Encrypt"}, new object[] {"Encrypt", "Optional", "Encrypt"}, new object[] {"Encrypt", "Strict", "Encrypt"}, new object[] {"ColumnEncryptionSetting", "Enabled", "Column Encryption Setting=Enabled"}, new object[] {"ColumnEncryptionSetting", "Disabled", "Column Encryption Setting=Disabled"}, new object[] {"ColumnEncryptionSetting", "enabled", "Column Encryption Setting=Enabled"}, new object[] {"ColumnEncryptionSetting", "disabled", "Column Encryption Setting=Disabled"}, new object[] {"ColumnEncryptionSetting", "ENABLED", "Column Encryption Setting=Enabled"}, new object[] {"ColumnEncryptionSetting", "DISABLED", "Column Encryption Setting=Disabled"}, new object[] {"ColumnEncryptionSetting", "eNaBlEd", "Column Encryption Setting=Enabled"}, new object[] {"ColumnEncryptionSetting", "DiSaBlEd", "Column Encryption Setting=Disabled"}, new object[] {"TrustServerCertificate", true, "Trust Server Certificate"}, new object[] {"TrustServerCertificate", false, "Trust Server Certificate"}, new object[] {"HostNameInCertificate", "hostname", "Host Name In Certificate"}, new object[] {"PersistSecurityInfo", true, "Persist Security Info"}, new object[] {"PersistSecurityInfo", false, "Persist Security Info"}, new object[] {"ConnectTimeout", 15, "Connect Timeout"}, new object[] {"CommandTimeout", 30, "Command Timeout"}, new object[] {"ConnectRetryCount", 1, "Connect Retry Count"}, new object[] {"ConnectRetryInterval", 10, "Connect Retry Interval"}, new object[] {"ApplicationName", "vscode-mssql", "Application Name"}, new object[] {"WorkstationId", "mycomputer", "Workstation ID"}, new object[] {"ApplicationIntent", "ReadWrite", "Application Intent"}, new object[] {"ApplicationIntent", "ReadOnly", "Application Intent"}, new object[] {"CurrentLanguage", "test", "Current Language"}, new object[] {"Pooling", false, "Pooling"}, new object[] {"Pooling", true, "Pooling"}, new object[] {"MaxPoolSize", 100, "Max Pool Size"}, new object[] {"MinPoolSize", 0, "Min Pool Size"}, new object[] {"LoadBalanceTimeout", 0, "Load Balance Timeout"}, new object[] {"Replication", true, "Replication"}, new object[] {"Replication", false, "Replication"}, new object[] {"AttachDbFilename", "myfile", "AttachDbFilename"}, new object[] {"FailoverPartner", "partner", "Failover Partner"}, new object[] {"MultiSubnetFailover", true, "Multi Subnet Failover"}, new object[] {"MultiSubnetFailover", false, "Multi Subnet Failover"}, new object[] {"MultipleActiveResultSets", false, "Multiple Active Result Sets"}, new object[] {"MultipleActiveResultSets", true, "Multiple Active Result Sets"}, new object[] {"PacketSize", 8192, "Packet Size"}, new object[] {"TypeSystemVersion", "Latest", "Type System Version"}, }; /// /// Verify that optional parameters can be built into a connection string for connecting. /// [Test, TestCaseSource(nameof(optionalParameters))] public void ConnectingWithOptionalParametersBuildsConnectionString(string propertyName, object propertyValue, string connectionStringMarker) { // Create a test connection details object and set the property to a specific value ConnectionDetails details = TestObjects.GetTestConnectionDetails(); PropertyInfo info = details.GetType().GetProperty(propertyName); info.SetValue(details, propertyValue); // Test that a connection string can be created without exceptions string connectionString = ConnectionService.BuildConnectionString(details); Assert.That(connectionString, Contains.Substring(connectionStringMarker), "Verify that the parameter is in the connection string"); } private static readonly object[] optionalEnclaveParameters = { new object[] {"AAS", "https://attestation.us.attest.azure.net/attest/SgxEnclave", "Enclave Attestation Url=https://attestation.us.attest.azure.net/attest/SgxEnclave;Attestation Protocol=AAS"}, new object[] {"HGS", "https://attestation.us.attest.azure.net/attest/SgxEnclave", "Enclave Attestation Url=https://attestation.us.attest.azure.net/attest/SgxEnclave;Attestation Protocol=HGS"}, new object[] {"aas", "https://attestation.us.attest.azure.net/attest/SgxEnclave", "Enclave Attestation Url=https://attestation.us.attest.azure.net/attest/SgxEnclave;Attestation Protocol=AAS"}, new object[] {"hgs", "https://attestation.us.attest.azure.net/attest/SgxEnclave", "Enclave Attestation Url=https://attestation.us.attest.azure.net/attest/SgxEnclave;Attestation Protocol=HGS"}, new object[] {"AaS", "https://attestation.us.attest.azure.net/attest/SgxEnclave", "Enclave Attestation Url=https://attestation.us.attest.azure.net/attest/SgxEnclave;Attestation Protocol=AAS"}, new object[] {"hGs", "https://attestation.us.attest.azure.net/attest/SgxEnclave", "Enclave Attestation Url=https://attestation.us.attest.azure.net/attest/SgxEnclave;Attestation Protocol=HGS"}, new object[] {"NONE", null, "Attestation Protocol=None"}, new object[] {"None", null, "Attestation Protocol=None" }, }; /// /// Verify that optional parameters which require ColumnEncryptionSetting to be enabled /// can be built into a connection string for connecting. /// [Test, TestCaseSource(nameof(optionalEnclaveParameters))] public void ConnectingWithOptionalEnclaveParametersBuildsConnectionString(string attestationProtocol, string attestationUrl, string connectionStringMarker) { // Create a test connection details object ConnectionDetails details = TestObjects.GetTestConnectionDetails(); //Enable Secure Enclaves details.ColumnEncryptionSetting = "Enabled"; details.SecureEnclaves = "Enabled"; // Set Attestation Protocol details.GetType() .GetProperty("EnclaveAttestationProtocol") .SetValue(details, attestationProtocol); // Set Attestation URL details.GetType() .GetProperty("EnclaveAttestationUrl") .SetValue(details, attestationUrl); // Test that a connection string can be created without exceptions with provided combinations. string connectionString = ConnectionService.BuildConnectionString(details); Assert.That(connectionString, Contains.Substring(connectionStringMarker), "Verify that the parameters are in the connection string"); } private static readonly object[] invalidOptions = { new object[] {"AuthenticationType", "NotAValidAuthType" }, new object[] {"ColumnEncryptionSetting", "NotAValidColumnEncryptionSetting" }, new object[] {"EnclaveAttestationProtocol", "NotAValidEnclaveAttestationProtocol" }, new object[] {"EnclaveAttestationProtocol", "AAS" }, // Without Attestation Url new object[] {"EnclaveAttestationProtocol", "hgs" }, // Without Attestation Url new object[] { "EnclaveAttestationUrl", "https://attestation.us.attest.azure.net/attest/SgxEnclave" }, // Without Attestation Protocol }; /// /// Build connection string with an invalid property type /// [Test, TestCaseSource(nameof(invalidOptions))] public void BuildConnectionStringWithInvalidOptions(string propertyName, object propertyValue) { ConnectionDetails details = TestObjects.GetTestConnectionDetails(); PropertyInfo info = details.GetType().GetProperty(propertyName); info.SetValue(details, propertyValue); Assert.Throws(() => ConnectionService.BuildConnectionString(details)); } private static readonly Tuple[][] optionCombos = { new [] { Tuple.Create("ColumnEncryptionSetting", null), Tuple.Create("EnclaveAttestationProtocol", "AAS"), Tuple.Create("EnclaveAttestationUrl", "https://attestation.us.attest.azure.net/attest/SgxEnclave") }, new [] { Tuple.Create("ColumnEncryptionSetting", "Enabled"), Tuple.Create("SecureEnclaves", null), Tuple.Create("EnclaveAttestationProtocol", "AAS"), Tuple.Create("EnclaveAttestationUrl", "https://attestation.us.attest.azure.net/attest/SgxEnclave") }, new [] { Tuple.Create("ColumnEncryptionSetting", "Disabled"), Tuple.Create("EnclaveAttestationProtocol", "AAS"), Tuple.Create("EnclaveAttestationUrl", "https://attestation.us.attest.azure.net/attest/SgxEnclave") }, new [] { Tuple.Create("ColumnEncryptionSetting", "Enabled"), Tuple.Create("SecureEnclaves", "Disabled"), Tuple.Create("EnclaveAttestationProtocol", "AAS"), Tuple.Create("EnclaveAttestationUrl", "https://attestation.us.attest.azure.net/attest/SgxEnclave") }, new [] { Tuple.Create("ColumnEncryptionSetting", ""), Tuple.Create("EnclaveAttestationProtocol", "AAS"), Tuple.Create("EnclaveAttestationUrl", "https://attestation.us.attest.azure.net/attest/SgxEnclave") } }; private static readonly object[] EncryptionCombinations = { new object[] { SqlConnectionEncryptOption.Optional, SqlConnectionEncryptOption.Optional }, new object[] { SqlConnectionEncryptOption.Mandatory, SqlConnectionEncryptOption.Mandatory }, new object[] { SqlConnectionEncryptOption.Strict, SqlConnectionEncryptOption.Strict }, }; /// /// Verify that Strict Encryption parameters can be built into a connection string for connecting. /// [Test, TestCaseSource(nameof(EncryptionCombinations))] public void ConnectingWithStrictEncryptionBuildsConnectionString(SqlConnectionEncryptOption encryptValue, SqlConnectionEncryptOption expected) { // Create a test connection details object and set the property to a specific value ConnectionDetails details = TestObjects.GetTestConnectionDetails(); details.Encrypt = encryptValue.ToString(); // Test that a connection string can be created without exceptions string connectionString = ConnectionService.BuildConnectionString(details); Assert.That(connectionString, Contains.Substring("Encrypt=" + expected.ToString()), "Encrypt not as expected."); } /// /// Verify that Strict Encryption parameters can be built into a connection string for connecting. /// [Test] [TestCase(true, "True")] [TestCase(false, "False")] public void ConnectingWithBoolEncryptBuildsConnectionString(bool encryptValue, string expected) { // Create a test connection details object and set the property to a specific value ConnectionDetails details = TestObjects.GetTestConnectionDetails(); details.Options["encrypt"] = encryptValue; // Test that a connection string can be created without exceptions string connectionString = ConnectionService.BuildConnectionString(details); Assert.That(connectionString, Contains.Substring("Encrypt=" + expected), "Encrypt not as expected."); } /// /// Build connection string with an invalid property combinations /// [Test] public void ConnStrWithInvalidOptions() { ConnectionDetails details = TestObjects.GetTestConnectionDetails(); foreach (var options in optionCombos) { options.ToList().ForEach(tuple => { PropertyInfo info = details.GetType().GetProperty(tuple.Item1); info.SetValue(details, tuple.Item2); }); Assert.Throws(() => ConnectionService.BuildConnectionString(details)); } } /// /// Verify that a connection changed event is fired when the database context changes. /// [Test] public async Task ConnectionChangedEventIsFiredWhenDatabaseContextChanges() { var serviceHostMock = new Mock(); var connectionService = TestObjects.GetTestConnectionService(); connectionService.ServiceHost = serviceHostMock.Object; // Set up an initial connection const string ownerUri = "file://my/sample/file.sql"; var connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, Connection = TestObjects.GetTestConnectionDetails() }); Assert.That(connectionResult.ConnectionId, Is.Not.Empty, "verify that a valid connection id was returned"); ConnectionInfo info; Assert.True(connectionService.TryFindConnection(ownerUri, out info)); // Tell the connection manager that the database change occurred connectionService.ChangeConnectionDatabaseContext(ownerUri, "myOtherDb"); // Verify that the connection changed event was fired serviceHostMock.Verify(x => x.SendEvent(ConnectionChangedNotification.Type, It.IsAny()), Times.Once()); } /// /// Verify that the SQL parser correctly detects errors in text /// [Test] public async Task ConnectToDatabaseTest() { // connect to a database instance string ownerUri = "file://my/sample/file.sql"; var connectionResult = await TestObjects.GetTestConnectionService() .Connect(new ConnectParams() { OwnerUri = ownerUri, Connection = TestObjects.GetTestConnectionDetails() }); Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful"); } /// /// Verify that we can disconnect from an active connection successfully /// [Test] public async Task DisconnectFromDatabaseTest() { // first connect string ownerUri = "file://my/sample/file.sql"; var connectionService = TestObjects.GetTestConnectionService(); var connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, Connection = TestObjects.GetTestConnectionDetails() }); Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful"); // send disconnect request var disconnectResult = connectionService .Disconnect(new DisconnectParams() { OwnerUri = ownerUri }); Assert.True(disconnectResult); } /// /// Test that when a disconnect is performed, the callback event is fired /// [Test] public async Task DisconnectFiresCallbackEvent() { bool callbackInvoked = false; // first connect string ownerUri = "file://my/sample/file.sql"; var connectionService = TestObjects.GetTestConnectionService(); var connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, Connection = TestObjects.GetTestConnectionDetails() }); Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful"); // register disconnect callback connectionService.RegisterOnDisconnectTask( (result, uri) => { callbackInvoked = true; Assert.True(uri.Equals(ownerUri)); return Task.FromResult(true); } ); // send disconnect request var disconnectResult = connectionService .Disconnect(new DisconnectParams() { OwnerUri = ownerUri }); Assert.True(disconnectResult); // verify that the event was fired Assert.True(callbackInvoked); } /// /// Test that disconnecting an active connection removes the Owner URI -> ConnectionInfo mapping /// [Test] public async Task DisconnectRemovesOwnerMapping() { // first connect string ownerUri = "file://my/sample/file.sql"; var connectionService = TestObjects.GetTestConnectionService(); var connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, Connection = TestObjects.GetTestConnectionDetails() }); // verify that we are connected Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful"); // check that the owner mapping exists ConnectionInfo info; Assert.True(connectionService.TryFindConnection(ownerUri, out info)); // send disconnect request var disconnectResult = connectionService .Disconnect(new DisconnectParams() { OwnerUri = ownerUri }); Assert.True(disconnectResult); // check that the owner mapping no longer exists Assert.False(connectionService.TryFindConnection(ownerUri, out info)); } /// /// Test that disconnecting validates parameters and doesn't succeed when they are invalid /// [Test] public async Task DisconnectValidatesParameters([Values("", null)] string disconnectUri) { // first connect string ownerUri = "file://my/sample/file.sql"; var connectionService = TestObjects.GetTestConnectionService(); var connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, Connection = TestObjects.GetTestConnectionDetails() }); // verify that we are connected Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful"); // send disconnect request var disconnectResult = connectionService .Disconnect(new DisconnectParams() { OwnerUri = disconnectUri }); // verify that disconnect failed Assert.False(disconnectResult); } private async Task RunListDatabasesRequestHandler(TestResultSet testdata, bool? includeDetails) { // Setup mock connection factory to inject query results var mockFactory = new Mock(); mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(CreateMockDbConnection(new[] { testdata })); var connectionService = new ConnectionService(mockFactory.Object); // connect to a database instance string ownerUri = "file://my/sample/file.sql"; var connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, Connection = TestObjects.GetTestConnectionDetails() }); // verify that a valid connection id was returned Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful"); // list databases for the connection ListDatabasesParams parameters = new ListDatabasesParams { OwnerUri = ownerUri, IncludeDetails = includeDetails }; return connectionService.ListDatabases(parameters); } /// /// Verifies the the list databases operation lists database names for the server used by a connection. /// [Test] public async Task ListDatabasesOnServerForCurrentConnectionReturnsDatabaseNames() { // Result set for the query of database names TestDbColumn[] cols = { new TestDbColumn("name") }; object[][] rows = { new object[] {"mydatabase"}, new object[] {"master"}, new object[] {"model"}, new object[] {"msdb"}, new object[] {"tempdb"} }; TestResultSet data = new TestResultSet(cols, rows); var response = await RunListDatabasesRequestHandler(testdata: data, includeDetails: null); string[] databaseNames = response.DatabaseNames; Assert.AreEqual(5, databaseNames.Length); Assert.AreEqual("mydatabase", databaseNames[0]); Assert.AreEqual("master", databaseNames[1]); Assert.AreEqual("model", databaseNames[2]); Assert.AreEqual("msdb", databaseNames[3]); Assert.AreEqual("tempdb", databaseNames[4]); } /// /// Verifies the the list databases operation lists database names for the server used by a connection. /// [Test] public async Task ListDatabasesOnServerForCurrentConnectionReturnsDatabaseDetails() { // Result set for the query of database names TestDbColumn[] cols = { new TestDbColumn("name"), new TestDbColumn("state"), new TestDbColumn("size"), new TestDbColumn("last_backup") }; object[][] rows = { new object[] {"mydatabase", "Online", "10", "2010-01-01 11:11:11"}, new object[] {"master", "Online", "11", "2010-01-01 11:11:12"}, new object[] {"model", "Offline", "12", "2010-01-01 11:11:13"}, new object[] {"msdb", "Online", "13", "2010-01-01 11:11:14"}, new object[] {"tempdb", "Online", "14", "2010-01-01 11:11:15"} }; TestResultSet data = new TestResultSet(cols, rows); var response = await RunListDatabasesRequestHandler(testdata: data, includeDetails: true); Assert.AreEqual(5, response.Databases.Length); VerifyDatabaseDetail(rows[0], response.Databases[0]); VerifyDatabaseDetail(rows[1], response.Databases[1]); VerifyDatabaseDetail(rows[2], response.Databases[2]); VerifyDatabaseDetail(rows[3], response.Databases[3]); VerifyDatabaseDetail(rows[4], response.Databases[4]); } private void VerifyDatabaseDetail(object[] expected, DatabaseInfo actual) { Assert.AreEqual(expected[0], actual.Options[ListDatabasesRequestDatabaseProperties.Name]); Assert.AreEqual(expected[1], actual.Options[ListDatabasesRequestDatabaseProperties.State]); Assert.AreEqual(expected[2], actual.Options[ListDatabasesRequestDatabaseProperties.SizeInMB]); Assert.AreEqual(expected[3], actual.Options[ListDatabasesRequestDatabaseProperties.LastBackup]); } /// /// Verify that the factory is returning DatabaseNamesHandler /// [Test] public void ListDatabaseRequestFactoryReturnsDatabaseNamesHandler() { var handler = ListDatabaseRequestHandlerFactory.getHandler(includeDetails: false, isSqlDB: true); Assert.That(handler, Is.InstanceOf()); handler = ListDatabaseRequestHandlerFactory.getHandler(includeDetails: false, isSqlDB: false); Assert.That(handler, Is.InstanceOf()); } /// /// Verify that the factory is returning SqlDBDatabaseDetailHandler /// [Test] public void ListDatabaseRequestFactoryReturnsSqlDBHandler() { var handler = ListDatabaseRequestHandlerFactory.getHandler(includeDetails: true, isSqlDB: true); Assert.That(handler, Is.InstanceOf()); } /// /// Verify that the factory is returning SqlServerDatabaseDetailHandler /// [Test] public void ListDatabaseRequestFactoryReturnsSqlServerHandler() { var handler = ListDatabaseRequestHandlerFactory.getHandler(includeDetails: true, isSqlDB: false); Assert.That(handler, Is.InstanceOf()); } /// /// Verify that the SQL parser correctly detects errors in text /// [Test] public async Task OnConnectionCallbackHandlerTest() { bool callbackInvoked = false; // setup connection service with callback var connectionService = TestObjects.GetTestConnectionService(); connectionService.RegisterOnConnectionTask( (sqlConnection) => { callbackInvoked = true; return Task.FromResult(true); } ); // connect to a database instance await connectionService.Connect(TestObjects.GetTestConnectionParams()); // verify that a valid connection id was returned Assert.True(callbackInvoked); } /// /// Test ConnectionSummaryComparer /// [Test] public void TestConnectionSummaryComparer() { var summary1 = new ConnectionSummary() { ServerName = "localhost", DatabaseName = "master", UserName = "user" }; var summary2 = new ConnectionSummary() { ServerName = "localhost", DatabaseName = "master", UserName = "user" }; var comparer = new ConnectionSummaryComparer(); Assert.True(comparer.Equals(summary1, summary2)); summary2.DatabaseName = "tempdb"; Assert.False(comparer.Equals(summary1, summary2)); Assert.False(comparer.Equals(null, summary2)); Assert.False(summary1.GetHashCode() == summary2.GetHashCode()); } /// /// Verify when a connection is created that the URI -> Connection mapping is created in the connection service. /// [Test] public async Task TestConnectRequestRegistersOwner() { // Given a request to connect to a database var service = TestObjects.GetTestConnectionService(); var connectParams = TestObjects.GetTestConnectionParams(); // connect to a database instance var connectionResult = await service.Connect(connectParams); // verify that a valid connection id was returned Assert.NotNull(connectionResult.ConnectionId); Assert.That(connectionResult.ConnectionId, Is.Not.EqualTo(string.Empty)); Assert.NotNull(new Guid(connectionResult.ConnectionId)); // verify that the (URI -> connection) mapping was created ConnectionInfo info; Assert.True(service.TryFindConnection(connectParams.OwnerUri, out info)); } /// /// Verify that Linux/OSX SqlExceptions thrown do not contain an error code. /// This is a bug in .NET core (see https://github.com/dotnet/corefx/issues/12472). /// If this test ever fails, it means that this bug has been fixed. When this is /// the case, look at RetryPolicyUtils.cs in IsRetryableNetworkConnectivityError(), /// and remove the code block specific to Linux/OSX. /// [Test] public void TestThatLinuxAndOsxSqlExceptionHasNoErrorCode() { RunIfWrapper.RunIfLinuxOrOSX(() => { try { SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder { DataSource = "bad-server-name", UserID = "sa", Password = "bad password" }; SqlConnection connection = new SqlConnection(builder.ConnectionString); connection.Open(); // This should fail } catch (SqlException ex) { // Error code should be 0 due to bug Assert.AreEqual(0, ex.Number); } }); } /// /// Test that cancel connection with a null connection parameter /// [Test] public void TestCancelConnectionNullParam() { var service = TestObjects.GetTestConnectionService(); Assert.False(service.CancelConnect(null)); } /// /// Test that cancel connection with a null connection parameter /// [Test] public void TestListDatabasesInvalidParams() { var service = TestObjects.GetTestConnectionService(); var listParams = new ListDatabasesParams(); Assert.Throws(() => service.ListDatabases(listParams)); listParams.OwnerUri = "file://notmyfile.sql"; Assert.Throws(() => service.ListDatabases(listParams)); } /// /// Test that the connection complete notification type can be created. /// [Test] public void TestConnectionCompleteNotificationIsCreated() { Assert.NotNull(ConnectionCompleteNotification.Type); } /// /// Test that the connection summary comparer creates a hash code correctly /// [Test] public void TestConnectionSummaryComparerHashCode([Values] bool objectNull, [Values(null, "server")] string serverName, [Values(null, "test")] string databaseName, [Values(null, "sa")] string userName) { // Given a connection summary and comparer object ConnectionSummary summary = null; if (!objectNull) { summary = new ConnectionSummary() { ServerName = serverName, DatabaseName = databaseName, UserName = userName }; } ConnectionSummaryComparer comparer = new ConnectionSummaryComparer(); // If I compute a hash code int hashCode = comparer.GetHashCode(summary); if (summary == null || (serverName == null && databaseName == null && userName == null)) { Assert.AreEqual(31, hashCode, "I expect it to be 31 for a null summary"); } else { Assert.That(hashCode, Is.Not.EqualTo(31), "And not 31 otherwise"); } } [Test] public void ConnectParamsAreInvalidIfConnectionIsNull() { // Given connection parameters where the connection property is null ConnectParams parameters = new ConnectParams { OwnerUri = "my/sql/file.sql", Connection = null }; string errorMessage; // If I check if the parameters are valid Assert.False(parameters.IsValid(out errorMessage)); Assert.That(errorMessage, Is.Not.Null.Or.Empty, "Then I expect an error message"); } [Test] public async Task ConnectingTwiceWithTheSameUriDoesNotCreateAnotherDbConnection() { // Setup the connect and disconnect params var connectParamsSame1 = new ConnectParams() { OwnerUri = "connectParamsSame", Connection = TestObjects.GetTestConnectionDetails() }; var connectParamsSame2 = new ConnectParams() { OwnerUri = "connectParamsSame", Connection = TestObjects.GetTestConnectionDetails() }; var disconnectParamsSame = new DisconnectParams() { OwnerUri = connectParamsSame1.OwnerUri }; var connectParamsDifferent = new ConnectParams() { OwnerUri = "connectParamsDifferent", Connection = TestObjects.GetTestConnectionDetails() }; var disconnectParamsDifferent = new DisconnectParams() { OwnerUri = connectParamsDifferent.OwnerUri }; // Given a request to connect to a database, there should be no initial connections in the map var service = TestObjects.GetTestConnectionService(); Dictionary ownerToConnectionMap = service.OwnerToConnectionMap; Assert.AreEqual(0, ownerToConnectionMap.Count); // If we connect to the service, there should be 1 connection await service.Connect(connectParamsSame1); Assert.AreEqual(1, ownerToConnectionMap.Count); // If we connect again with the same URI, there should still be 1 connection await service.Connect(connectParamsSame2); Assert.AreEqual(1, ownerToConnectionMap.Count); // If we connect with a different URI, there should be 2 connections await service.Connect(connectParamsDifferent); Assert.AreEqual(2, ownerToConnectionMap.Count); // If we disconnect with the unique URI, there should be 1 connection service.Disconnect(disconnectParamsDifferent); Assert.AreEqual(1, ownerToConnectionMap.Count); // If we disconnect with the duplicate URI, there should be 0 connections service.Disconnect(disconnectParamsSame); Assert.AreEqual(0, ownerToConnectionMap.Count); } [Test] public async Task DbConnectionDoesntLeakUponDisconnect() { // If we connect with a single URI and 2 connection types var connectParamsDefault = new ConnectParams() { OwnerUri = "connectParams", Connection = TestObjects.GetTestConnectionDetails(), Type = ConnectionType.Default }; var connectParamsQuery = new ConnectParams() { OwnerUri = "connectParams", Connection = TestObjects.GetTestConnectionDetails(), Type = ConnectionType.Query }; var disconnectParams = new DisconnectParams() { OwnerUri = connectParamsDefault.OwnerUri }; var service = TestObjects.GetTestConnectionService(); await service.Connect(connectParamsDefault); await service.Connect(connectParamsQuery); // We should have one ConnectionInfo and 2 DbConnections ConnectionInfo connectionInfo = service.OwnerToConnectionMap[connectParamsDefault.OwnerUri]; Assert.AreEqual(2, connectionInfo.CountConnections); Assert.AreEqual(1, service.OwnerToConnectionMap.Count); // If we record when the Default connecton calls Close() bool defaultDisconnectCalled = false; var mockDefaultConnection = new Mock { CallBase = true }; mockDefaultConnection.Setup(x => x.Close()) .Callback(() => { defaultDisconnectCalled = true; }); connectionInfo.ConnectionTypeToConnectionMap[ConnectionType.Default] = mockDefaultConnection.Object; // And when the Query connecton calls Close() bool queryDisconnectCalled = false; var mockQueryConnection = new Mock { CallBase = true }; mockQueryConnection.Setup(x => x.Close()) .Callback(() => { queryDisconnectCalled = true; }); connectionInfo.ConnectionTypeToConnectionMap[ConnectionType.Query] = mockQueryConnection.Object; // If we disconnect all open connections with the same URI as used above service.Disconnect(disconnectParams); // Close() should have gotten called for both DbConnections Assert.True(defaultDisconnectCalled); Assert.True(queryDisconnectCalled); // And the maps that hold connection data should be empty Assert.AreEqual(0, connectionInfo.CountConnections); Assert.AreEqual(0, service.OwnerToConnectionMap.Count); } [Test] public async Task ClosingQueryConnectionShouldLeaveDefaultConnectionOpen() { // Setup the connect and disconnect params var connectParamsDefault = new ConnectParams() { OwnerUri = "connectParamsSame", Connection = TestObjects.GetTestConnectionDetails(), Type = ConnectionType.Default }; var connectParamsQuery = new ConnectParams() { OwnerUri = connectParamsDefault.OwnerUri, Connection = TestObjects.GetTestConnectionDetails(), Type = ConnectionType.Query }; var disconnectParamsQuery = new DisconnectParams() { OwnerUri = connectParamsDefault.OwnerUri, Type = connectParamsQuery.Type }; // If I connect a Default and a Query connection var service = TestObjects.GetTestConnectionService(); await service.Connect(connectParamsDefault); await service.Connect(connectParamsQuery); ConnectionInfo connectionInfo = service.OwnerToConnectionMap[connectParamsDefault.OwnerUri]; // There should be 2 connections in the map Assert.AreEqual(2, connectionInfo.CountConnections); // If I Disconnect only the Query connection, there should be 1 connection in the map service.Disconnect(disconnectParamsQuery); Assert.AreEqual(1, connectionInfo.CountConnections); // If I reconnect, there should be 2 again await service.Connect(connectParamsQuery); Assert.AreEqual(2, connectionInfo.CountConnections); } [Test] public void GetOrOpenNullOwnerUri([Values(null, "")] string ownerUri) { // If: I have a connection service and I ask for a connection with an invalid ownerUri // Then: An exception should be thrown var service = TestObjects.GetTestConnectionService(); Assert.ThrowsAsync( () => service.GetOrOpenConnection(ownerUri, ConnectionType.Default)); } [Test] public void GetOrOpenNullConnectionType([Values(null, "")] string connType) { // If: I have a connection service and I ask for a connection with an invalid connectionType // Then: An exception should be thrown var service = TestObjects.GetTestConnectionService(); Assert.ThrowsAsync( () => service.GetOrOpenConnection(TestObjects.ScriptUri, connType)); } [Test] public void GetOrOpenNoConnection() { // If: I have a connection service and I ask for a connection for an unconnected uri // Then: An exception should be thrown var service = TestObjects.GetTestConnectionService(); Assert.ThrowsAsync( () => service.GetOrOpenConnection(TestObjects.ScriptUri, ConnectionType.Query)); } [Test] public void GetOrOpenNoDefaultConnection() { // Setup: Create a connection service with an empty connection info obj var service = TestObjects.GetTestConnectionService(); var connInfo = new ConnectionInfo(null, null, null); service.OwnerToConnectionMap[TestObjects.ScriptUri] = connInfo; // If: I ask for a connection on a connection that doesn't have a default connection // Then: An exception should be thrown Assert.ThrowsAsync( () => service.GetOrOpenConnection(TestObjects.ScriptUri, ConnectionType.Query)); } [Test] public void GetOrOpenAdminDefaultConnection() { // Setup: Create a connection service with an empty connection info obj var service = TestObjects.GetTestConnectionService(); var connInfo = new ConnectionInfo(null, null, null); service.OwnerToConnectionMap[TestObjects.ScriptUri] = connInfo; // If: I ask for a connection on a connection that doesn't have a default connection // Then: An exception should be thrown Assert.ThrowsAsync( () => service.GetOrOpenConnection(TestObjects.ScriptUri, ConnectionType.Query)); } [Test] public async Task ConnectionWithAdminConnectionEnsuresOnlyOneConnectionCreated() { // If I try to connect using a connection string, it overrides the server name and username for the connection ConnectParams connectionParameters = TestObjects.GetTestConnectionParams(); string serverName = "ADMIN:overriddenServerName"; string userName = "overriddenUserName"; connectionParameters.Connection.ServerName = serverName; connectionParameters.Connection.UserName = userName; // Connect ConnectionService service = TestObjects.GetTestConnectionService(); var connectionResult = await service.Connect(connectionParameters); // Verify you can get the connection for default DbConnection defaultConn = await service.GetOrOpenConnection(connectionParameters.OwnerUri, ConnectionType.Default); ConnectionInfo connInfo = service.OwnerToConnectionMap[connectionParameters.OwnerUri]; Assert.NotNull(defaultConn); Assert.AreEqual(1, connInfo.AllConnections.Count); // Verify that for the Query, no new connection is created DbConnection queryConn = await service.GetOrOpenConnection(connectionParameters.OwnerUri, ConnectionType.Query); connInfo = service.OwnerToConnectionMap[connectionParameters.OwnerUri]; Assert.NotNull(defaultConn); Assert.AreEqual(1, connInfo.AllConnections.Count); // Verify that if the query connection was closed, it will be reopened on requesting the connection again Assert.AreEqual(ConnectionState.Open, queryConn.State); queryConn.Close(); Assert.AreEqual(ConnectionState.Closed, queryConn.State); queryConn = await service.GetOrOpenConnection(connectionParameters.OwnerUri, ConnectionType.Query); Assert.AreEqual(ConnectionState.Open, queryConn.State); } [Test] public async Task ConnectionWithConnectionStringSucceeds() { // If I connect using a connection string instead of the normal parameters, the connection succeeds var connectionParameters = TestObjects.GetTestConnectionParams(true); var connectionResult = await TestObjects.GetTestConnectionService().Connect(connectionParameters); Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful"); } [Test] public async Task ConnectionWithBadConnectionStringFails() { // If I try to connect using an invalid connection string, the connection fails var connectionParameters = TestObjects.GetTestConnectionParams(true); connectionParameters.Connection.ConnectionString = "thisisnotavalidconnectionstring"; var connectionResult = await TestObjects.GetTestConnectionService().Connect(connectionParameters); Assert.That(connectionResult.ErrorMessage, Is.Not.Empty); } [Test] public async Task ConnectionWithConnectionStringOverridesServerInfo() { // If I try to connect using a connection string, it overrides the server name and username for the connection var connectionParameters = TestObjects.GetTestConnectionParams(); var serverName = "overriddenServerName"; var userName = "overriddenUserName"; connectionParameters.Connection.ServerName = serverName; connectionParameters.Connection.UserName = userName; var connectionString = TestObjects.GetTestConnectionParams(true).Connection.ConnectionString; connectionParameters.Connection.ConnectionString = connectionString; // Connect and verify that the connectionParameters object's server name and username have been overridden var connectionResult = await TestObjects.GetTestConnectionService().Connect(connectionParameters); Assert.That(connectionResult.ConnectionSummary.ServerName, Is.Not.EqualTo(serverName)); Assert.That(connectionResult.ConnectionSummary.UserName, Is.Not.EqualTo(userName)); } [Test] public async Task OtherParametersOverrideConnectionString() { // If I try to connect using a connection string, and set parameters other than the server name, username, or password, // they override the values in the connection string. var connectionParameters = TestObjects.GetTestConnectionParams(); var databaseName = "overriddenDatabaseName"; connectionParameters.Connection.DatabaseName = databaseName; var connectionString = TestObjects.GetTestConnectionParams(true).Connection.ConnectionString; connectionParameters.Connection.ConnectionString = connectionString; // Connect and verify that the connection string's database name has been overridden var connectionResult = await TestObjects.GetTestConnectionService().Connect(connectionParameters); Assert.AreEqual(databaseName, connectionResult.ConnectionSummary.DatabaseName); } [Test] public async Task CanChangeDatabase() { string ownerUri = "file://my/sample/file.sql"; const string masterDbName = "master"; const string otherDbName = "other"; // Given a connection that returns the database name var connection = new TestSqlConnection(null); var mockFactory = new Mock(); mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny(), It.IsAny())) .Returns((string connString, string azureAccountToken, SqlRetryLogicBaseProvider retryProvider) => { connection.ConnectionString = connString; SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); // Database name is respected. Follow heuristic where empty DB name really means Master var dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog; connection.SetDatabase(dbName); return connection; }); var connectionService = new ConnectionService(mockFactory.Object); // When I connect to default var connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, Connection = TestObjects.GetTestConnectionDetails() }); connection.SetState(ConnectionState.Open); connectionService.ChangeConnectionDatabaseContext(ownerUri, otherDbName); Assert.AreEqual(otherDbName, connection.Database); } [Test] public async Task CanChangeDatabaseAzure() { string ownerUri = "file://my/sample/file.sql"; const string masterDbName = "master"; const string otherDbName = "other"; string dbName = masterDbName; // Given a connection that returns the database name var mockConnection = new Mock(); mockConnection.Setup(conn => conn.ChangeDatabase(It.IsAny())) .Throws(new Exception()); mockConnection.SetupGet(conn => conn.Database).Returns(dbName); mockConnection.SetupGet(conn => conn.State).Returns(ConnectionState.Open); mockConnection.Setup(conn => conn.Close()); mockConnection.Setup(conn => conn.Open()); var connection = mockConnection.Object; var mockFactory = new Mock(); mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny(), It.IsAny())) .Returns((string connString, string azureAccountToken, SqlRetryLogicBaseProvider retryProvider) => { connection.ConnectionString = connString; SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); // Database name is respected. Follow heuristic where empty DB name really means Master dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog; return connection; }); var connectionService = new ConnectionService(mockFactory.Object); // When I connect to default var connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, Connection = TestObjects.GetTestConnectionDetails() }); ConnectionInfo testInfo; connectionService.TryFindConnection(ownerUri, out testInfo); Assert.NotNull(testInfo); testInfo.IsCloud = true; connectionService.ChangeConnectionDatabaseContext(ownerUri, otherDbName, true); Assert.AreEqual(otherDbName, dbName); } [Test] public async Task ReturnsFalseIfNotForced() { string ownerUri = "file://my/sample/file.sql"; const string defaultDbName = "databaseName"; const string otherDbName = "other"; string dbName = defaultDbName; // Given a connection that returns the database name var mockConnection = new Mock(); mockConnection.Setup(conn => conn.ChangeDatabase(It.IsAny())) .Throws(new Exception()); mockConnection.SetupGet(conn => conn.Database).Returns(dbName); mockConnection.SetupGet(conn => conn.State).Returns(ConnectionState.Open); mockConnection.Setup(conn => conn.Close()); mockConnection.Setup(conn => conn.Open()); var connection = mockConnection.Object; var mockFactory = new Mock(); mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny(), It.IsAny())) .Returns((string connString, string azureAccountToken, SqlRetryLogicBaseProvider retryProvider) => { connection.ConnectionString = connString; SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); // Database name is respected. Follow heuristic where empty DB name really means Master dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? defaultDbName : scsb.InitialCatalog; return connection; }); var connectionService = new ConnectionService(mockFactory.Object); // When I connect to default var connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, Connection = TestObjects.GetTestConnectionDetails() }); ConnectionInfo testInfo; connectionService.TryFindConnection(ownerUri, out testInfo); Assert.NotNull(testInfo); testInfo.IsCloud = true; Assert.False(connectionService.ChangeConnectionDatabaseContext(ownerUri, otherDbName)); Assert.AreEqual(defaultDbName, dbName); } /// /// Test ParseConnectionString /// [Test] public void ParseConnectionStringTest() { // If we make a connection to a live database ConnectionService service = ConnectionService.Instance; var connectionString = "Server=tcp:{servername},1433;Initial Catalog={databasename};Persist Security Info=False;User ID={your_username};Password={your_password};MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=False;Connection Timeout=30;Command Timeout=30;HostNameInCertificate={servername}"; var details = service.ParseConnectionString(connectionString); Assert.That(details.ServerName, Is.EqualTo("tcp:{servername},1433"), "Unexpected server name"); Assert.That(details.DatabaseName, Is.EqualTo("{databasename}"), "Unexpected database name"); Assert.That(details.UserName, Is.EqualTo("{your_username}"), "Unexpected username"); Assert.That(details.Password, Is.EqualTo("{your_password}"), "Unexpected password"); Assert.That(details.PersistSecurityInfo, Is.Null, "Unexpected Persist Security Info"); Assert.That(details.MultipleActiveResultSets, Is.Null, "Unexpected Multiple Active Result Sets value"); Assert.That(details.Encrypt, Is.EqualTo("true"), "Unexpected Encrypt value"); Assert.That(details.TrustServerCertificate, Is.False, "Unexpected database name value"); Assert.That(details.HostNameInCertificate, Is.EqualTo("{servername}"), "Unexpected Host Name in Certificate value"); Assert.That(details.ConnectTimeout, Is.EqualTo(30), "Unexpected Connect Timeout value"); Assert.That(details.CommandTimeout, Is.EqualTo(30), "Unexpected CommandTimeout Timeout value"); } /// /// Test ParseConnectionString /// [Test] public void ParseConnectionStringTest_StrictEncryption() { // If we make a connection to a live database ConnectionService service = ConnectionService.Instance; var connectionString = "Server=tcp:{servername},1433;Initial Catalog={databasename};Persist Security Info=False;User ID={your_username};Password={your_password};MultipleActiveResultSets=False;Encrypt=Strict;TrustServerCertificate=False;Connection Timeout=30;Command Timeout=30;HostNameInCertificate={servername}"; var details = service.ParseConnectionString(connectionString); Assert.That(details.ServerName, Is.EqualTo("tcp:{servername},1433"), "Unexpected server name"); Assert.That(details.DatabaseName, Is.EqualTo("{databasename}"), "Unexpected database name"); Assert.That(details.UserName, Is.EqualTo("{your_username}"), "Unexpected username"); Assert.That(details.Password, Is.EqualTo("{your_password}"), "Unexpected password"); Assert.That(details.PersistSecurityInfo, Is.Null, "Unexpected Persist Security Info"); Assert.That(details.MultipleActiveResultSets, Is.Null, "Unexpected Multiple Active Result Sets value"); Assert.That(details.Encrypt, Is.EqualTo(SqlConnectionEncryptOption.Strict.ToString()), "Unexpected Encrypt value"); Assert.That(details.TrustServerCertificate, Is.False, "Unexpected database name value"); Assert.That(details.HostNameInCertificate, Is.EqualTo("{servername}"), "Unexpected Host Name in Certificate value"); Assert.That(details.ConnectTimeout, Is.EqualTo(30), "Unexpected Connect Timeout value"); Assert.That(details.CommandTimeout, Is.EqualTo(30), "Unexpected Command Timeout value"); } [Test] public async Task ConnectingWithAzureAccountUsesToken() { // Set up mock connection factory var mockFactory = new Mock(); mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(new TestSqlConnection(null)); var connectionService = new ConnectionService(mockFactory.Object); var details = TestObjects.GetTestConnectionDetails(); var azureAccountToken = "testAzureAccountToken"; details.AzureAccountToken = azureAccountToken; details.UserName = ""; details.Password = ""; details.AuthenticationType = AzureMFA; // Open a connection using connection details that include an account token await connectionService.Connect(new ConnectParams { OwnerUri = "testURI", Connection = details }); // Validate that the connection factory gets called with details NOT including an account token mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny(), It.Is(accountToken => accountToken == azureAccountToken), It.IsAny()), Times.Once()); } /// /// Test is IsDbPool method correctly works for various database names /// [Test] public void CheckIsDbPool() { Assert.IsTrue(ConnectionService.IsDbPool("db@pool")); Assert.IsFalse(ConnectionService.IsDbPool("db")); Assert.IsFalse(ConnectionService.IsDbPool(null)); } /// /// Verify that providing an empty password to change password will fire an error. /// [Test] public void ConnectionEmptyPasswordChange() { var serviceHostMock = new Mock(); var connectionService = ConnectionService.Instance; connectionService.ServiceHost = serviceHostMock.Object; // Set up an initial connection const string ownerUri = "file://my/sample/file.sql"; ChangePasswordParams testConnectionParams = new ChangePasswordParams() { OwnerUri = ownerUri, Connection = TestObjects.GetTestConnectionDetails(), NewPassword = "" }; Assert.Throws(() => connectionService.ChangePassword(testConnectionParams)); } /// /// Verify that providing an invalid connection parameter value to change password will fire an error. /// [Test] public void ConnectionInvalidParamPasswordChange() { var serviceHostMock = new Mock(); var connectionService = ConnectionService.Instance; connectionService.ServiceHost = serviceHostMock.Object; // Set up an initial connection const string ownerUri = "file://my/sample/file.sql"; ChangePasswordParams testConnectionParams = new ChangePasswordParams() { OwnerUri = ownerUri, Connection = { }, NewPassword = "TestPassword" }; Assert.Throws(() => connectionService.ChangePassword(testConnectionParams)); } /// /// Verify that providing a non actual connection and a fake password to change password will throw an error. /// [Test] public void InvalidConnectionPasswordChange() { var serviceHostMock = new Mock(); var connectionService = ConnectionService.Instance; connectionService.ServiceHost = serviceHostMock.Object; // Set up an initial connection const string ownerUri = "file://my/sample/file.sql"; ChangePasswordParams testConnectionParams = new ChangePasswordParams() { OwnerUri = ownerUri, Connection = TestObjects.GetTestConnectionDetails(), NewPassword = "TestPassword" }; Assert.Throws( () => connectionService.ChangePassword(testConnectionParams)); } } }