//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.SqlTools.ResourceProvider.Core;
using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
using Microsoft.SqlTools.ResourceProvider.Core.Firewall;
using Moq;
using NUnit.Framework;

namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ResourceProvider
{
    /// <summary>
    /// Tests to verify FirewallRuleService by mocking the azure authentication and resource managers
    /// </summary>
    public class FirewallRuleServiceTest
    {
        // NOTE(review): the generic type arguments in this file were reconstructed. Those on the
        // ServiceTestContext members, the subscription/resource collections and the two manager mocks
        // follow directly from the declared property and return types. The following are inferred only
        // from how the mocked members are used and must be confirmed against the project:
        //   - FirewallRuleException as the exception expected by Assert.ThrowsAsync
        //   - IAzureResourceManagementSession (mock exposing SubscriptionContext)
        //   - FirewallRuleRequest (argument exposing StartIpAddress / EndIpAddress)
        //   - IUserAccount (mock exposing UniqueId)
        //   - IAzureSubscriptionIdentifier (mock exposing SubscriptionId)
        //
        // NUnit's Assert.ThrowsAsync runs the delegate to completion and returns the caught exception
        // (it does not return a Task), which is why the expected-failure tests do not await it.

        [Test]
        public async Task CreateShouldThrowExceptionGivenNullServerName()
        {
            string serverName = null;
            ServiceTestContext testContext = new ServiceTestContext();

            Assert.ThrowsAsync<FirewallRuleException>(() => VerifyCreateAsync(testContext, serverName));
        }

        [Test]
        public async Task CreateShouldThrowExceptionGivenNullStartIp()
        {
            string serverName = "serverName";
            ServiceTestContext testContext = new ServiceTestContext();
            testContext.StartIpAddress = null;

            Assert.ThrowsAsync<FirewallRuleException>(() => VerifyCreateAsync(testContext, serverName));
        }

        [Test]
        public async Task CreateShouldThrowExceptionGivenInvalidEndIp()
        {
            string serverName = "serverName";
            ServiceTestContext testContext = new ServiceTestContext();
            testContext.EndIpAddress = "invalid ip";

            Assert.ThrowsAsync<FirewallRuleException>(() => VerifyCreateAsync(testContext, serverName));
        }

        [Test]
        public async Task CreateShouldThrowExceptionGivenInvalidStartIp()
        {
            string serverName = "serverName";
            ServiceTestContext testContext = new ServiceTestContext();
            testContext.StartIpAddress = "invalid ip";

            Assert.ThrowsAsync<FirewallRuleException>(() => VerifyCreateAsync(testContext, serverName));
        }

        [Test]
        public async Task CreateShouldThrowExceptionGivenNullEndIp()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            testContext.EndIpAddress = null;

            Assert.ThrowsAsync<FirewallRuleException>(() => VerifyCreateAsync(testContext, testContext.ServerName));
        }

        [Test]
        public async Task CreateShouldThrowExceptionIfUserIsNotLoggedIn()
        {
            var applicationAuthenticationManagerMock = new Mock<IAzureAuthenticationManager>();
            applicationAuthenticationManagerMock.Setup(x => x.GetUserNeedsReauthenticationAsync()).Throws(new ApplicationException());
            var azureResourceManagerMock = new Mock<IAzureResourceManager>();

            ServiceTestContext testContext = new ServiceTestContext();
            testContext.ApplicationAuthenticationManagerMock = applicationAuthenticationManagerMock;
            testContext.AzureResourceManagerMock = azureResourceManagerMock;

            Assert.ThrowsAsync<FirewallRuleException>(() => VerifyCreateAsync(testContext, testContext.ServerName));
            VerifyFirewallRuleNeverCreated(azureResourceManagerMock);
        }

        [Test]
        public async Task CreateShouldThrowExceptionIfUserDoesNotHaveSubscriptions()
        {
            var applicationAuthenticationManagerMock = new Mock<IAzureAuthenticationManager>();
            applicationAuthenticationManagerMock.Setup(x => x.GetUserNeedsReauthenticationAsync()).Returns(Task.FromResult(false));
            applicationAuthenticationManagerMock.Setup(x => x.GetSubscriptionsAsync())
                .Returns(Task.FromResult(Enumerable.Empty<IAzureUserAccountSubscriptionContext>()));
            var azureResourceManagerMock = new Mock<IAzureResourceManager>();

            ServiceTestContext testContext = new ServiceTestContext();
            testContext.ApplicationAuthenticationManagerMock = applicationAuthenticationManagerMock;
            testContext.AzureResourceManagerMock = azureResourceManagerMock;

            Assert.ThrowsAsync<FirewallRuleException>(() => VerifyCreateAsync(testContext, testContext.ServerName));
            VerifyFirewallRuleNeverCreated(azureResourceManagerMock);
        }

        [Test]
        public async Task CreateShouldThrowExceptionIfAuthenticationManagerFailsToReturnSubscription()
        {
            var applicationAuthenticationManagerMock = new Mock<IAzureAuthenticationManager>();
            applicationAuthenticationManagerMock.Setup(x => x.GetUserNeedsReauthenticationAsync()).Returns(Task.FromResult(false));
            applicationAuthenticationManagerMock.Setup(x => x.GetSubscriptionsAsync()).Throws(new Exception());
            var azureResourceManagerMock = new Mock<IAzureResourceManager>();

            ServiceTestContext testContext = new ServiceTestContext();
            testContext.ApplicationAuthenticationManagerMock = applicationAuthenticationManagerMock;
            testContext.AzureResourceManagerMock = azureResourceManagerMock;

            Assert.ThrowsAsync<FirewallRuleException>(() => VerifyCreateAsync(testContext, "invalid server"));
            VerifyFirewallRuleNeverCreated(azureResourceManagerMock);
        }

        [Test]
        public async Task CreateShouldThrowExceptionGivenNoSubscriptionFound()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            testContext = CreateMocks(testContext);

            Assert.ThrowsAsync<FirewallRuleException>(() => VerifyCreateAsync(testContext, "invalid server"));
        }

        [Test]
        public async Task CreateShouldCreateFirewallSuccessfullyGivenValidUserAccount()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            testContext = CreateMocks(testContext);

            await VerifyCreateAsync(testContext, testContext.ServerName);
        }

        [Test]
        public async Task CreateShouldFindTheRightSubscriptionGivenValidSubscriptionInFirstPlace()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            testContext.Subscriptions = new List<IAzureUserAccountSubscriptionContext>
            {
                testContext.ValidSubscription,
                ServiceTestContext.CreateSubscriptionContext(),
                ServiceTestContext.CreateSubscriptionContext(),
            };
            testContext = CreateMocks(testContext);

            await VerifyCreateAsync(testContext, testContext.ServerName);
        }

        [Test]
        public async Task CreateShouldFindTheRightSubscriptionGivenValidSubscriptionInSecondPlace()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            testContext.Subscriptions = new List<IAzureUserAccountSubscriptionContext>
            {
                ServiceTestContext.CreateSubscriptionContext(),
                testContext.ValidSubscription,
                ServiceTestContext.CreateSubscriptionContext(),
            };
            // Rebuild the subscription-to-resources map for the replaced subscription list.
            testContext.Initialize();
            testContext = CreateMocks(testContext);

            await VerifyCreateAsync(testContext, testContext.ServerName);
        }

        [Test]
        public async Task CreateShouldFindTheRightSubscriptionGivenValidSubscriptionInLastPlace()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            testContext.Subscriptions = new List<IAzureUserAccountSubscriptionContext>
            {
                ServiceTestContext.CreateSubscriptionContext(),
                ServiceTestContext.CreateSubscriptionContext(),
                testContext.ValidSubscription
            };
            // Rebuild the subscription-to-resources map for the replaced subscription list.
            testContext.Initialize();
            testContext = CreateMocks(testContext);

            await VerifyCreateAsync(testContext, testContext.ServerName);
        }

        [Test]
        public async Task CreateShouldFindTheRightResourceGivenValidResourceInLastPlace()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            var resources = new List<IAzureSqlServerResource>
            {
                ServiceTestContext.CreateAzureSqlServer(Guid.NewGuid().ToString()),
                ServiceTestContext.CreateAzureSqlServer(testContext.ServerName),
            };
            testContext.SubscriptionToResourcesMap[testContext.ValidSubscription.Subscription.SubscriptionId] = resources;
            testContext = CreateMocks(testContext);

            await VerifyCreateAsync(testContext, testContext.ServerName);
        }

        [Test]
        public async Task CreateShouldFindTheRightResourceGivenValidResourceInFirstPlace()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            var resources = new List<IAzureSqlServerResource>
            {
                ServiceTestContext.CreateAzureSqlServer(testContext.ServerName),
                ServiceTestContext.CreateAzureSqlServer(Guid.NewGuid().ToString()),
            };
            testContext.SubscriptionToResourcesMap[testContext.ValidSubscription.Subscription.SubscriptionId] = resources;
            testContext = CreateMocks(testContext);

            await VerifyCreateAsync(testContext, testContext.ServerName);
        }

        [Test]
        public async Task CreateShouldFindTheRightResourceGivenValidResourceInMiddle()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            var resources = new List<IAzureSqlServerResource>
            {
                ServiceTestContext.CreateAzureSqlServer(Guid.NewGuid().ToString()),
                ServiceTestContext.CreateAzureSqlServer(testContext.ServerName),
                ServiceTestContext.CreateAzureSqlServer(Guid.NewGuid().ToString())
            };
            testContext.SubscriptionToResourcesMap[testContext.ValidSubscription.Subscription.SubscriptionId] = resources;
            testContext = CreateMocks(testContext);

            await VerifyCreateAsync(testContext, testContext.ServerName);
        }

        [Test]
        public async Task CreateThrowExceptionIfResourceNotFound()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            var resources = new List<IAzureSqlServerResource>
            {
                ServiceTestContext.CreateAzureSqlServer(Guid.NewGuid().ToString()),
                ServiceTestContext.CreateAzureSqlServer(Guid.NewGuid().ToString()),
            };
            testContext.SubscriptionToResourcesMap[testContext.ValidSubscription.Subscription.SubscriptionId] = resources;
            testContext = CreateMocks(testContext);

            Assert.ThrowsAsync<FirewallRuleException>(() => VerifyCreateAsync(testContext, testContext.ServerName));
        }

        [Test]
        public async Task CreateThrowExceptionIfResourcesIsEmpty()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            testContext.SubscriptionToResourcesMap[testContext.ValidSubscription.Subscription.SubscriptionId] = new List<IAzureSqlServerResource>();
            testContext = CreateMocks(testContext);

            Assert.ThrowsAsync<FirewallRuleException>(() => VerifyCreateAsync(testContext, testContext.ServerName, false));
        }

        [Test]
        public async Task CreateShouldThrowExceptionIfThereIsNoSubscriptionForUser()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            testContext.Subscriptions = new List<IAzureUserAccountSubscriptionContext>();
            testContext = CreateMocks(testContext);

            Assert.ThrowsAsync<FirewallRuleException>(() => VerifyCreateAsync(testContext, testContext.ServerName, false));
        }

        [Test]
        public async Task CreateShouldThrowExceptionIfSubscriptionIsInAnotherAccount()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            testContext.Subscriptions = new List<IAzureUserAccountSubscriptionContext>
            {
                ServiceTestContext.CreateSubscriptionContext(),
                ServiceTestContext.CreateSubscriptionContext(),
            };
            testContext = CreateMocks(testContext);

            Assert.ThrowsAsync<FirewallRuleException>(() => VerifyCreateAsync(testContext, testContext.ServerName, false));
        }

        [Test]
        public async Task CreateShouldCreateFirewallForTheRightServerFullyQualifiedName()
        {
            ServiceTestContext testContext = new ServiceTestContext();
            string serverNameWithDifferentDomain = testContext.ServerNameWithoutDomain + ".myaliased.domain.name";
            testContext.ServerName = serverNameWithDifferentDomain;
            // Regenerate the resources so the valid subscription contains the renamed server.
            testContext.Initialize();
            testContext = CreateMocks(testContext);

            await VerifyCreateAsync(testContext, testContext.ServerName);
        }

        /// <summary>
        /// Verifies that CreateFirewallRuleAsync was never invoked on the given resource manager mock,
        /// regardless of the arguments.
        /// </summary>
        /// <param name="azureResourceManagerMock">The resource manager mock to inspect.</param>
        private static void VerifyFirewallRuleNeverCreated(Mock<IAzureResourceManager> azureResourceManagerMock)
        {
            azureResourceManagerMock.Verify(x => x.CreateFirewallRuleAsync(
                    It.IsAny<IAzureResourceManagementSession>(),
                    It.IsAny<IAzureSqlServerResource>(),
                    It.IsAny<FirewallRuleRequest>()),
                Times.Never);
        }

        /// <summary>
        /// Creates a FirewallRuleService wired to the managers held by <paramref name="testContext"/>,
        /// calls CreateFirewallRuleAsync and then verifies how the resource manager mock was used.
        /// </summary>
        /// <param name="testContext">Supplies the managers, the IP range and the valid subscription.</param>
        /// <param name="serverName">Server name passed to the service and expected as the resource FQDN.</param>
        /// <param name="verifyFirewallRuleCreated">
        /// When true, the matching CreateFirewallRuleAsync call must have happened at least once;
        /// when false, it must never have happened.
        /// </param>
        /// <returns>The response returned by the service.</returns>
        private async Task<FirewallRuleResponse> VerifyCreateAsync(ServiceTestContext testContext, string serverName, bool verifyFirewallRuleCreated = true)
        {
            try
            {
                FirewallRuleService service = new FirewallRuleService();
                service.AuthenticationManager = testContext.ApplicationAuthenticationManager;
                service.ResourceManager = testContext.AzureResourceManager;

                FirewallRuleResponse response = await service.CreateFirewallRuleAsync(serverName, testContext.StartIpAddress, testContext.EndIpAddress);

                // Same argument matchers in both modes; only the expected call count differs.
                Times expectedCalls = verifyFirewallRuleCreated ? Times.AtLeastOnce() : Times.Never();
                testContext.AzureResourceManagerMock.Verify(x => x.CreateFirewallRuleAsync(
                        It.Is<IAzureResourceManagementSession>(s => s.SubscriptionContext.Subscription.SubscriptionId == testContext.ValidSubscription.Subscription.SubscriptionId),
                        It.Is<IAzureSqlServerResource>(r => r.FullyQualifiedDomainName == serverName),
                        It.Is<FirewallRuleRequest>(y => y.EndIpAddress.ToString().Equals(testContext.EndIpAddress)
                                                        && y.StartIpAddress.ToString().Equals(testContext.StartIpAddress))),
                    expectedCalls);

                return response;
            }
            catch (FirewallRuleException ex)
            {
                // A FirewallRuleException must not directly wrap another FirewallRuleException.
                // (A null InnerException also satisfies this check.)
                Assert.True(!(ex.InnerException is FirewallRuleException));
                throw;
            }
        }

        /// <summary>
        /// Builds the authentication and resource manager mocks from the subscriptions and resources
        /// held by <paramref name="testContext"/> and stores them on the context.
        /// </summary>
        /// <param name="testContext">The context to read from and attach the mocks to.</param>
        /// <returns>The same context instance, with both mocks assigned.</returns>
        private ServiceTestContext CreateMocks(ServiceTestContext testContext)
        {
            var accountMock = new Mock<IUserAccount>();
            accountMock.Setup(x => x.UniqueId).Returns(Guid.NewGuid().ToString());

            var applicationAuthenticationManagerMock = new Mock<IAzureAuthenticationManager>();
            applicationAuthenticationManagerMock.Setup(x => x.GetUserNeedsReauthenticationAsync())
                .Returns(Task.FromResult(false));
            applicationAuthenticationManagerMock.Setup(x => x.GetCurrentAccountAsync())
                .Returns(Task.FromResult(accountMock.Object));
            applicationAuthenticationManagerMock.Setup(x => x.GetSubscriptionsAsync())
                .Returns(Task.FromResult(testContext.Subscriptions as IEnumerable<IAzureUserAccountSubscriptionContext>));

            var azureResourceManagerMock = new Mock<IAzureResourceManager>();
            CreateMocksForResources(testContext, azureResourceManagerMock);

            testContext.ApplicationAuthenticationManagerMock = applicationAuthenticationManagerMock;
            testContext.AzureResourceManagerMock = azureResourceManagerMock;
            return testContext;
        }

        /// <summary>
        /// Configures <paramref name="azureResourceManagerMock"/> so that, for every subscription in the
        /// context, a session can be created and its SQL server resources can be listed, and so that
        /// creating a firewall rule with the context's IP range reports success.
        /// </summary>
        /// <param name="testContext">Supplies the subscriptions, their resources and the IP range.</param>
        /// <param name="azureResourceManagerMock">The mock to configure.</param>
        private void CreateMocksForResources(
            ServiceTestContext testContext,
            Mock<IAzureResourceManager> azureResourceManagerMock)
        {
            foreach (IAzureUserAccountSubscriptionContext subscription in testContext.Subscriptions)
            {
                var sessionMock = new Mock<IAzureResourceManagementSession>();
                sessionMock.Setup(x => x.SubscriptionContext).Returns(subscription);
                azureResourceManagerMock.Setup(x => x.CreateSessionAsync(subscription))
                    .Returns(Task.FromResult(sessionMock.Object));

                // TryGetValue leaves 'resources' null when the subscription has no entry in the map,
                // so the mock returns a null resource list for such subscriptions.
                List<IAzureSqlServerResource> resources;
                testContext.SubscriptionToResourcesMap.TryGetValue(subscription.Subscription.SubscriptionId, out resources);

                azureResourceManagerMock.Setup(x => x.GetSqlServerAzureResourcesAsync(
                        It.Is<IAzureResourceManagementSession>(
                            m => m.SubscriptionContext.Subscription.SubscriptionId == subscription.Subscription.SubscriptionId)))
                    .Returns(Task.FromResult<IEnumerable<IAzureSqlServerResource>>(resources));
            }

            azureResourceManagerMock
                .Setup(x => x.CreateFirewallRuleAsync(
                    It.IsAny<IAzureResourceManagementSession>(),
                    It.IsAny<IAzureSqlServerResource>(),
                    It.Is<FirewallRuleRequest>(
                        y => y.EndIpAddress.ToString().Equals(testContext.EndIpAddress)
                             && y.StartIpAddress.ToString().Equals(testContext.StartIpAddress))))
                .Returns(Task.FromResult(new FirewallRuleResponse() { Created = true }));
        }
    }

    /// <summary>
    /// Holds the inputs (server name, IP range), the mocked subscriptions and their resources, and the
    /// manager mocks shared by the FirewallRuleService tests.
    /// </summary>
    internal class ServiceTestContext
    {
        private const string DefaultServerName = "validServerName.database.windows.net";
        private const string DefaultStartIpAddress = "1.2.3.6";
        private const string DefaultEndIpAddress = "1.2.3.6";

        private Dictionary<string, List<IAzureSqlServerResource>> _subscriptionToResourcesMap;

        /// <summary>
        /// Creates a context with the default server name and IP range, then generates the default
        /// subscriptions and resources.
        /// </summary>
        public ServiceTestContext()
        {
            StartIpAddress = DefaultStartIpAddress;
            EndIpAddress = DefaultEndIpAddress;
            ServerName = DefaultServerName;
            Initialize();
        }

        /// <summary>
        /// Creates the default subscriptions if none are set and rebuilds the subscription-to-resources
        /// map for the current subscriptions and server name.
        /// </summary>
        internal void Initialize()
        {
            CreateSubscriptions();
            CreateAzureResources();
        }

        /// <summary>
        /// Creates a mocked subscription context whose subscription has a new random id.
        /// </summary>
        internal static IAzureUserAccountSubscriptionContext CreateSubscriptionContext()
        {
            var subscriptionContext = new Mock<IAzureUserAccountSubscriptionContext>();
            var subscriptionMock = new Mock<IAzureSubscriptionIdentifier>();
            subscriptionMock.Setup(x => x.SubscriptionId).Returns(Guid.NewGuid().ToString());
            subscriptionContext.Setup(x => x.Subscription).Returns(subscriptionMock.Object);
            return subscriptionContext.Object;
        }

        /// <summary>
        /// When no subscriptions are set, creates the valid subscription mock and a default list made of
        /// the valid subscription followed by two other subscriptions. Does nothing otherwise.
        /// </summary>
        private void CreateSubscriptions()
        {
            if (Subscriptions != null && Subscriptions.Count != 0)
            {
                return;
            }

            ValidSubscriptionMock = new Mock<IAzureUserAccountSubscriptionContext>();
            var subscriptionMock = new Mock<IAzureSubscriptionIdentifier>();
            subscriptionMock.Setup(x => x.SubscriptionId).Returns(Guid.NewGuid().ToString());
            ValidSubscriptionMock.Setup(x => x.Subscription).Returns(subscriptionMock.Object);

            Subscriptions = new List<IAzureUserAccountSubscriptionContext>
            {
                ValidSubscription,
                CreateSubscriptionContext(),
                CreateSubscriptionContext()
            };
        }

        /// <summary>
        /// Replaces the subscription-to-resources map. When <paramref name="subscriptionToResourcesMap"/>
        /// is given it is used as-is; otherwise every current subscription gets two randomly named
        /// servers, and the valid subscription additionally gets a server named <see cref="ServerName"/>
        /// as its last resource.
        /// </summary>
        /// <param name="subscriptionToResourcesMap">Optional map to use instead of generating one.</param>
        internal void CreateAzureResources(Dictionary<string, List<IAzureSqlServerResource>> subscriptionToResourcesMap = null)
        {
            if (subscriptionToResourcesMap != null)
            {
                _subscriptionToResourcesMap = subscriptionToResourcesMap;
                return;
            }

            _subscriptionToResourcesMap = new Dictionary<string, List<IAzureSqlServerResource>>();
            foreach (var subscriptionDetails in Subscriptions)
            {
                string subscriptionId = subscriptionDetails.Subscription.SubscriptionId;
                var resources = new List<IAzureSqlServerResource>
                {
                    CreateAzureSqlServer(Guid.NewGuid().ToString()),
                    CreateAzureSqlServer(Guid.NewGuid().ToString())
                };

                if (subscriptionId == ValidSubscription.Subscription.SubscriptionId)
                {
                    resources.Add(CreateAzureSqlServer(ServerName));
                }

                _subscriptionToResourcesMap.Add(subscriptionId, resources);
            }
        }

        /// <summary>
        /// Creates a mocked SQL server resource whose FullyQualifiedDomainName is
        /// <paramref name="serverName"/> and whose Name is the part before the first '.'.
        /// </summary>
        internal static IAzureSqlServerResource CreateAzureSqlServer(string serverName)
        {
            var azureSqlServer = new Mock<IAzureSqlServerResource>();
            azureSqlServer.Setup(x => x.Name).Returns(GetServerNameWithoutDomain(serverName));
            azureSqlServer.Setup(x => x.FullyQualifiedDomainName).Returns(serverName);
            return azureSqlServer.Object;
        }

        /// <summary>
        /// Resources keyed by subscription id, as last built by <see cref="CreateAzureResources"/>.
        /// </summary>
        internal Dictionary<string, List<IAzureSqlServerResource>> SubscriptionToResourcesMap
        {
            get { return _subscriptionToResourcesMap; }
        }

        /// <summary>
        /// Returns the part of <paramref name="serverName"/> before the first '.', or the whole name
        /// when it contains no '.' or starts with one.
        /// </summary>
        internal static string GetServerNameWithoutDomain(string serverName)
        {
            int index = serverName.IndexOf('.');
            return index > 0 ? serverName.Substring(0, index) : serverName;
        }

        internal string StartIpAddress { get; set; }

        internal string EndIpAddress { get; set; }

        internal IList<IAzureUserAccountSubscriptionContext> Subscriptions { get; set; }

        internal Mock<IAzureUserAccountSubscriptionContext> ValidSubscriptionMock { get; set; }

        internal IAzureUserAccountSubscriptionContext ValidSubscription
        {
            get { return ValidSubscriptionMock.Object; }
        }

        internal string ServerName { get; set; }

        internal string ServerNameWithoutDomain
        {
            get { return GetServerNameWithoutDomain(ServerName); }
        }

        internal Mock<IAzureAuthenticationManager> ApplicationAuthenticationManagerMock { get; set; }

        /// <summary>The mocked authentication manager, or null when no mock has been assigned.</summary>
        internal IAzureAuthenticationManager ApplicationAuthenticationManager
        {
            get { return ApplicationAuthenticationManagerMock?.Object; }
        }

        internal Mock<IAzureResourceManager> AzureResourceManagerMock { get; set; }

        /// <summary>The mocked resource manager, or null when no mock has been assigned.</summary>
        internal IAzureResourceManager AzureResourceManager
        {
            get { return AzureResourceManagerMock?.Object; }
        }
    }
}