//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.SqlTools.ResourceProvider.Core;
using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
using Microsoft.SqlTools.ResourceProvider.Core.Firewall;
using Moq;
using NUnit.Framework;
namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ResourceProvider
{
// NOTE(review): generic type arguments (on Mock, List, It.IsAny, It.Is, Assert.ThrowsAsync,
// Enumerable.Empty, Task, ...) appear to be missing from this copy of the file, so it will not
// compile as shown. They are left exactly as found rather than guessed; restore them from history.
/// <summary>
/// Tests to verify FirewallRuleService by mocking the Azure authentication and resource managers.
/// </summary>
public class FirewallRuleServiceTest
{
    [Test]
    public async Task CreateShouldThrowExceptionGivenNullServerName()
    {
        string serverName = null;
        ServiceTestContext testContext = new ServiceTestContext();
        Assert.ThrowsAsync(() => VerifyCreateAsync(testContext, serverName));
    }

    [Test]
    public async Task CreateShouldThrowExceptionGivenNullStartIp()
    {
        string serverName = "serverName";
        ServiceTestContext testContext = new ServiceTestContext();
        testContext.StartIpAddress = null;
        Assert.ThrowsAsync(() => VerifyCreateAsync(testContext, serverName));
    }

    [Test]
    public async Task CreateShouldThrowExceptionGivenInvalidEndIp()
    {
        string serverName = "serverName";
        ServiceTestContext testContext = new ServiceTestContext();
        testContext.EndIpAddress = "invalid ip";
        Assert.ThrowsAsync(() => VerifyCreateAsync(testContext, serverName));
    }

    [Test]
    public async Task CreateShouldThrowExceptionGivenInvalidStartIp()
    {
        string serverName = "serverName";
        ServiceTestContext testContext = new ServiceTestContext();
        testContext.StartIpAddress = "invalid ip";
        Assert.ThrowsAsync(() => VerifyCreateAsync(testContext, serverName));
    }

    [Test]
    public async Task CreateShouldThrowExceptionGivenNullEndIp()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        testContext.EndIpAddress = null;
        Assert.ThrowsAsync(() => VerifyCreateAsync(testContext, testContext.ServerName));
    }

    [Test]
    public async Task CreateShouldThrowExceptionIfUserIsNotLoggedIn()
    {
        var applicationAuthenticationManagerMock = new Mock();
        applicationAuthenticationManagerMock.Setup(x => x.GetUserNeedsReauthenticationAsync()).Throws(new ApplicationException());
        var azureResourceManagerMock = new Mock();
        ServiceTestContext testContext = new ServiceTestContext();
        testContext.ApplicationAuthenticationManagerMock = applicationAuthenticationManagerMock;
        testContext.AzureResourceManagerMock = azureResourceManagerMock;
        Assert.ThrowsAsync(() => VerifyCreateAsync(testContext, testContext.ServerName));
        // The authentication failure must stop the service before any rule is created.
        azureResourceManagerMock.Verify(x => x.CreateFirewallRuleAsync(
            It.IsAny(), It.IsAny(), It.IsAny()),
            Times.Never);
    }

    [Test]
    public async Task CreateShouldThrowExceptionIfUserDoesNotHaveSubscriptions()
    {
        var applicationAuthenticationManagerMock =
            new Mock();
        applicationAuthenticationManagerMock.Setup(x => x.GetUserNeedsReauthenticationAsync()).Returns(Task.FromResult(false));
        applicationAuthenticationManagerMock.Setup(x => x.GetSubscriptionsAsync())
            .Returns(Task.FromResult(Enumerable.Empty()));
        var azureResourceManagerMock = new Mock();
        ServiceTestContext testContext = new ServiceTestContext();
        testContext.ApplicationAuthenticationManagerMock = applicationAuthenticationManagerMock;
        testContext.AzureResourceManagerMock = azureResourceManagerMock;
        Assert.ThrowsAsync(() => VerifyCreateAsync(testContext, testContext.ServerName));
        azureResourceManagerMock.Verify(x => x.CreateFirewallRuleAsync(
            It.IsAny(), It.IsAny(), It.IsAny()),
            Times.Never);
    }

    [Test]
    public async Task CreateShouldThrowExceptionIfAuthenticationManagerFailsToReturnSubscription()
    {
        var applicationAuthenticationManagerMock = new Mock();
        applicationAuthenticationManagerMock.Setup(x => x.GetUserNeedsReauthenticationAsync()).Returns(Task.FromResult(false));
        applicationAuthenticationManagerMock.Setup(x => x.GetSubscriptionsAsync()).Throws(new Exception());
        var azureResourceManagerMock = new Mock();
        ServiceTestContext testContext = new ServiceTestContext();
        testContext.ApplicationAuthenticationManagerMock = applicationAuthenticationManagerMock;
        testContext.AzureResourceManagerMock = azureResourceManagerMock;
        Assert.ThrowsAsync(() => VerifyCreateAsync(testContext, "invalid server"));
        azureResourceManagerMock.Verify(x => x.CreateFirewallRuleAsync(
            It.IsAny(), It.IsAny(), It.IsAny()),
            Times.Never);
    }

    [Test]
    public async Task CreateShouldThrowExceptionGivenNoSubscriptionFound()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        testContext = CreateMocks(testContext);
        // "invalid server" matches none of the mocked resources.
        Assert.ThrowsAsync(() => VerifyCreateAsync(testContext, "invalid server"));
    }

    [Test]
    public async Task CreateShouldCreateFirewallSuccessfullyGivenValidUserAccount()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        testContext = CreateMocks(testContext);
        await VerifyCreateAsync(testContext, testContext.ServerName);
    }

    [Test]
    public async Task CreateShouldFindTheRightSubscriptionGivenValidSubscriptionInFirstPlace()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        testContext.Subscriptions = new List
        {
            testContext.ValidSubscription,
            ServiceTestContext.CreateSubscriptionContext(),
            ServiceTestContext.CreateSubscriptionContext(),
        };
        // Rebuild the subscription-to-resources map for the replaced subscription list, as the
        // second-place and last-place variants do; otherwise the new subscriptions have no resources.
        testContext.Initialize();
        testContext = CreateMocks(testContext);
        await VerifyCreateAsync(testContext, testContext.ServerName);
    }

    [Test]
    public async Task CreateShouldFindTheRightSubscriptionGivenValidSubscriptionInSecondPlace()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        testContext.Subscriptions = new List
        {
            ServiceTestContext.CreateSubscriptionContext(),
            testContext.ValidSubscription,
            ServiceTestContext.CreateSubscriptionContext(),
        };
        testContext.Initialize();
        testContext = CreateMocks(testContext);
        await VerifyCreateAsync(testContext, testContext.ServerName);
    }

    [Test]
    public async Task CreateShouldFindTheRightSubscriptionGivenValidSubscriptionInLastPlace()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        testContext.Subscriptions = new List
        {
            ServiceTestContext.CreateSubscriptionContext(),
            ServiceTestContext.CreateSubscriptionContext(),
            testContext.ValidSubscription
        };
        testContext.Initialize();
        testContext = CreateMocks(testContext);
        await VerifyCreateAsync(testContext, testContext.ServerName);
    }

    [Test]
    public async Task CreateShouldFindTheRightResourceGivenValidResourceInLastPlace()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        var resources = new List
        {
            ServiceTestContext.CreateAzureSqlServer(Guid.NewGuid().ToString()),
            ServiceTestContext.CreateAzureSqlServer(testContext.ServerName),
        };
        testContext.SubscriptionToResourcesMap[testContext.ValidSubscription.Subscription.SubscriptionId] = resources;
        testContext = CreateMocks(testContext);
        await VerifyCreateAsync(testContext, testContext.ServerName);
    }

    [Test]
    public async Task CreateShouldFindTheRightResourceGivenValidResourceInFirstPlace()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        var resources = new List
        {
            ServiceTestContext.CreateAzureSqlServer(testContext.ServerName),
            ServiceTestContext.CreateAzureSqlServer(Guid.NewGuid().ToString()),
        };
        testContext.SubscriptionToResourcesMap[testContext.ValidSubscription.Subscription.SubscriptionId] = resources;
        testContext = CreateMocks(testContext);
        await VerifyCreateAsync(testContext, testContext.ServerName);
    }

    [Test]
    public async Task CreateShouldFindTheRightResourceGivenValidResourceInMiddle()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        var resources = new List
        {
            ServiceTestContext.CreateAzureSqlServer(Guid.NewGuid().ToString()),
            ServiceTestContext.CreateAzureSqlServer(testContext.ServerName),
            ServiceTestContext.CreateAzureSqlServer(Guid.NewGuid().ToString())
        };
        testContext.SubscriptionToResourcesMap[testContext.ValidSubscription.Subscription.SubscriptionId] = resources;
        testContext = CreateMocks(testContext);
        await VerifyCreateAsync(testContext, testContext.ServerName);
    }

    [Test]
    public async Task CreateThrowExceptionIfResourceNotFound()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        var resources = new List
        {
            ServiceTestContext.CreateAzureSqlServer(Guid.NewGuid().ToString()),
            ServiceTestContext.CreateAzureSqlServer(Guid.NewGuid().ToString()),
        };
        testContext.SubscriptionToResourcesMap[testContext.ValidSubscription.Subscription.SubscriptionId] = resources;
        testContext = CreateMocks(testContext);
        Assert.ThrowsAsync(() => VerifyCreateAsync(testContext, testContext.ServerName));
    }

    [Test]
    public async Task CreateThrowExceptionIfResourcesIsEmpty()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        testContext.SubscriptionToResourcesMap[testContext.ValidSubscription.Subscription.SubscriptionId] = new List();
        testContext = CreateMocks(testContext);
        Assert.ThrowsAsync(() => VerifyCreateAsync(testContext, testContext.ServerName, false));
    }

    [Test]
    public async Task CreateShouldThrowExceptionIfThereIsNoSubscriptionForUser()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        // Deliberately no Initialize() here: with an empty list it would regenerate the default subscriptions.
        testContext.Subscriptions = new List();
        testContext = CreateMocks(testContext);
        Assert.ThrowsAsync(() => VerifyCreateAsync(testContext, testContext.ServerName, false));
    }

    [Test]
    public async Task CreateShouldThrowExceptionIfSubscriptionIsInAnotherAccount()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        testContext.Subscriptions = new List
        {
            ServiceTestContext.CreateSubscriptionContext(),
            ServiceTestContext.CreateSubscriptionContext(),
        };
        testContext = CreateMocks(testContext);
        Assert.ThrowsAsync(() => VerifyCreateAsync(testContext, testContext.ServerName, false));
    }

    [Test]
    public async Task CreateShouldCreateFirewallForTheRightServerFullyQualifiedName()
    {
        ServiceTestContext testContext = new ServiceTestContext();
        string serverNameWithDifferentDomain = testContext.ServerNameWithoutDomain + ".myaliased.domain.name";
        testContext.ServerName = serverNameWithDifferentDomain;
        // Re-run Initialize so the valid subscription's resources include the aliased server name.
        testContext.Initialize();
        testContext = CreateMocks(testContext);
        await VerifyCreateAsync(testContext, testContext.ServerName);
    }

    /// <summary>
    /// Runs FirewallRuleService.CreateFirewallRuleAsync against the mocks in <paramref name="testContext"/>
    /// and verifies whether the resource manager was asked to create the expected rule.
    /// </summary>
    /// <param name="testContext">Supplies the mocked managers and the start/end IP addresses.</param>
    /// <param name="serverName">Server name passed to the service.</param>
    /// <param name="verifyFirewallRuleCreated">
    /// When true, the matching CreateFirewallRuleAsync call must have happened at least once;
    /// when false, it must never have happened.
    /// </param>
    /// <returns>The response returned by the service.</returns>
    /// <remarks>
    /// Any exception from the service propagates unchanged. For a FirewallRuleException, it first asserts
    /// that its inner exception is not another FirewallRuleException.
    /// </remarks>
    private async Task VerifyCreateAsync(ServiceTestContext testContext, string serverName, bool verifyFirewallRuleCreated = true)
    {
        try
        {
            FirewallRuleService service = new FirewallRuleService();
            service.AuthenticationManager = testContext.ApplicationAuthenticationManager;
            service.ResourceManager = testContext.AzureResourceManager;
            FirewallRuleResponse response = await service.CreateFirewallRuleAsync(serverName, testContext.StartIpAddress, testContext.EndIpAddress);

            // The match criteria are the same either way; only the expected call count differs.
            Times expectedCalls = verifyFirewallRuleCreated ? Times.AtLeastOnce() : Times.Never();
            testContext.AzureResourceManagerMock.Verify(x => x.CreateFirewallRuleAsync(
                It.Is(s => s.SubscriptionContext.Subscription.SubscriptionId == testContext.ValidSubscription.Subscription.SubscriptionId),
                It.Is(r => r.FullyQualifiedDomainName == serverName),
                It.Is(y => y.EndIpAddress.ToString().Equals(testContext.EndIpAddress) && y.StartIpAddress.ToString().Equals(testContext.StartIpAddress))),
                expectedCalls);
            return response;
        }
        catch (FirewallRuleException ex)
        {
            // A FirewallRuleException should not wrap another FirewallRuleException.
            Assert.False(ex.InnerException is FirewallRuleException,
                "FirewallRuleException should not wrap another FirewallRuleException");
            throw;
        }
    }

    /// <summary>
    /// Installs fresh authentication-manager and resource-manager mocks on <paramref name="testContext"/>.
    /// Reauthentication is not needed, the current account has a random unique id, and
    /// <c>GetSubscriptionsAsync</c> returns <see cref="ServiceTestContext.Subscriptions"/>.
    /// </summary>
    /// <returns>The same <paramref name="testContext"/> instance, for call chaining.</returns>
    private ServiceTestContext CreateMocks(ServiceTestContext testContext)
    {
        var accountMock = new Mock();
        accountMock.Setup(x => x.UniqueId).Returns(Guid.NewGuid().ToString());
        var applicationAuthenticationManagerMock = new Mock();
        applicationAuthenticationManagerMock.Setup(x => x.GetUserNeedsReauthenticationAsync())
            .Returns(Task.FromResult(false));
        applicationAuthenticationManagerMock.Setup(x => x.GetCurrentAccountAsync()).Returns(Task.FromResult(accountMock.Object));
        applicationAuthenticationManagerMock.Setup(x => x.GetSubscriptionsAsync()).Returns(Task.FromResult(testContext.Subscriptions as IEnumerable));
        var azureResourceManagerMock = new Mock();
        CreateMocksForResources(testContext, azureResourceManagerMock);
        testContext.ApplicationAuthenticationManagerMock = applicationAuthenticationManagerMock;
        testContext.AzureResourceManagerMock = azureResourceManagerMock;
        return testContext;
    }

    /// <summary>
    /// Configures the resource-manager mock for every subscription in <paramref name="testContext"/>.
    /// Each subscription gets a session whose SubscriptionContext is that subscription. Each session
    /// reports the subscription's resources from <see cref="ServiceTestContext.SubscriptionToResourcesMap"/>,
    /// or null when the subscription has no entry. CreateFirewallRuleAsync succeeds for requests that
    /// carry the context's start and end IP addresses.
    /// </summary>
    private void CreateMocksForResources(
        ServiceTestContext testContext,
        Mock azureResourceManagerMock)
    {
        foreach (IAzureUserAccountSubscriptionContext subscription in testContext.Subscriptions)
        {
            var sessionMock = new Mock();
            sessionMock.Setup(x => x.SubscriptionContext).Returns(subscription);
            azureResourceManagerMock.Setup(x => x.CreateSessionAsync(subscription)).Returns(Task.FromResult(sessionMock.Object));

            // When the subscription has no entry, TryGetValue leaves resources null, so the mock returns null.
            testContext.SubscriptionToResourcesMap.TryGetValue(subscription.Subscription.SubscriptionId, out var resources);
            azureResourceManagerMock.Setup(x => x.GetSqlServerAzureResourcesAsync(It.Is(
                m => m.SubscriptionContext.Subscription.SubscriptionId == subscription.Subscription.SubscriptionId)))
                .Returns(Task.FromResult(resources as IEnumerable));
        }

        azureResourceManagerMock
            .Setup(x => x.CreateFirewallRuleAsync(
                It.IsAny(),
                It.IsAny(),
                It.Is(
                    y => y.EndIpAddress.ToString().Equals(testContext.EndIpAddress)
                        && y.StartIpAddress.ToString().Equals(testContext.StartIpAddress))))
            .Returns(Task.FromResult(new FirewallRuleResponse() {Created = true}));
    }
}
// NOTE(review): generic type arguments (on Mock, List, Dictionary, IList) appear to be missing from this
// copy of the file, so it will not compile as shown. They are left exactly as found rather than guessed.
/// <summary>
/// Mutable fixture shared by the firewall rule tests. It holds the server name and IP range, the mocked
/// subscriptions, the Azure SQL server resources reported per subscription id, and the mocked
/// authentication and resource managers handed to FirewallRuleService.
/// </summary>
internal class ServiceTestContext
{
    // Defaults applied by the constructor; tests override the matching properties as needed.
    private const string DefaultServerName = "validServerName.database.windows.net";
    private const string DefaultStartIpAddress = "1.2.3.6";
    private const string DefaultEndIpAddress = "1.2.3.6";

    private Dictionary> _subscriptionToResourcesMap;

    /// <summary>
    /// Creates a context with the default server name and IP range. It then builds the default
    /// subscriptions and their resources via <see cref="Initialize"/>.
    /// </summary>
    public ServiceTestContext()
    {
        StartIpAddress = DefaultStartIpAddress;
        EndIpAddress = DefaultEndIpAddress;
        ServerName = DefaultServerName;
        Initialize();
    }

    /// <summary>
    /// Creates the default subscriptions if none are set, then rebuilds the subscription-to-resources map.
    /// Call it again after replacing <see cref="Subscriptions"/> or changing <see cref="ServerName"/>.
    /// </summary>
    internal void Initialize()
    {
        CreateSubscriptions();
        CreateAzureResources();
    }

    /// <summary>
    /// Creates a mocked subscription context whose Subscription.SubscriptionId is a new GUID string.
    /// </summary>
    internal static IAzureUserAccountSubscriptionContext CreateSubscriptionContext()
    {
        var subscriptionContext = new Mock();
        var subscriptionMock = new Mock();
        subscriptionMock.Setup(x => x.SubscriptionId).Returns(Guid.NewGuid().ToString());
        subscriptionContext.Setup(x => x.Subscription).Returns(subscriptionMock.Object);
        return subscriptionContext.Object;
    }

    /// <summary>
    /// When <see cref="Subscriptions"/> is null or empty, creates <see cref="ValidSubscriptionMock"/> and sets
    /// <see cref="Subscriptions"/> to the valid subscription followed by two other random subscriptions.
    /// Otherwise it does nothing.
    /// </summary>
    private void CreateSubscriptions()
    {
        if (Subscriptions != null && Subscriptions.Count > 0)
        {
            return;
        }

        ValidSubscriptionMock = new Mock();
        var subscriptionMock = new Mock();
        subscriptionMock.Setup(x => x.SubscriptionId).Returns(Guid.NewGuid().ToString());
        ValidSubscriptionMock.Setup(x => x.Subscription).Returns(subscriptionMock.Object);
        Subscriptions = new List
        {
            ValidSubscription,
            CreateSubscriptionContext(),
            CreateSubscriptionContext()
        };
    }

    /// <summary>
    /// Rebuilds the map from subscription id to the SQL server resources that subscription reports.
    /// </summary>
    /// <param name="subscriptionToResourcesMap">
    /// Optional map to use as-is. When null, every subscription gets two randomly named servers, and
    /// <see cref="ValidSubscription"/> additionally gets a server named <see cref="ServerName"/>.
    /// </param>
    internal void CreateAzureResources(Dictionary> subscriptionToResourcesMap = null)
    {
        if (subscriptionToResourcesMap != null)
        {
            _subscriptionToResourcesMap = subscriptionToResourcesMap;
            return;
        }

        _subscriptionToResourcesMap = new Dictionary>();
        foreach (var subscriptionDetails in Subscriptions)
        {
            var subscriptionId = subscriptionDetails.Subscription.SubscriptionId;
            var resources = new List();
            resources.Add(CreateAzureSqlServer(Guid.NewGuid().ToString()));
            resources.Add(CreateAzureSqlServer(Guid.NewGuid().ToString()));
            if (subscriptionId == ValidSubscription.Subscription.SubscriptionId)
            {
                // Only the valid subscription owns the server under test.
                resources.Add(CreateAzureSqlServer(ServerName));
            }
            _subscriptionToResourcesMap.Add(subscriptionId, resources);
        }
    }

    /// <summary>
    /// Creates a mocked SQL server resource. Its FullyQualifiedDomainName is <paramref name="serverName"/>,
    /// and its Name is that value up to the first '.'.
    /// </summary>
    internal static IAzureSqlServerResource CreateAzureSqlServer(string serverName)
    {
        var azureSqlServer =
            new Mock();
        azureSqlServer.Setup(x => x.Name).Returns(GetServerNameWithoutDomain(serverName));
        azureSqlServer.Setup(x => x.FullyQualifiedDomainName).Returns(serverName);
        return azureSqlServer.Object;
    }

    /// <summary>
    /// Resources reported per subscription id. Tests may replace an entry before creating the mocks.
    /// </summary>
    internal Dictionary> SubscriptionToResourcesMap => _subscriptionToResourcesMap;

    /// <summary>
    /// Returns the part of <paramref name="serverName"/> before the first '.'. The name is returned
    /// unchanged when it has no '.' after the first character, and when it is null or empty.
    /// </summary>
    internal static string GetServerNameWithoutDomain(string serverName)
    {
        if (string.IsNullOrEmpty(serverName))
        {
            return serverName;
        }

        int dotIndex = serverName.IndexOf('.');
        return dotIndex > 0 ? serverName.Substring(0, dotIndex) : serverName;
    }

    internal string StartIpAddress { get; set; }

    internal string EndIpAddress { get; set; }

    internal IList Subscriptions { get; set; }

    internal Mock ValidSubscriptionMock { get; set; }

    /// <summary>The subscription that owns the server named <see cref="ServerName"/> by default.</summary>
    internal IAzureUserAccountSubscriptionContext ValidSubscription => ValidSubscriptionMock.Object;

    internal string ServerName { get; set; }

    internal string ServerNameWithoutDomain => GetServerNameWithoutDomain(ServerName);

    internal Mock ApplicationAuthenticationManagerMock { get; set; }

    internal IAzureAuthenticationManager ApplicationAuthenticationManager => ApplicationAuthenticationManagerMock?.Object;

    internal Mock AzureResourceManagerMock { get; set; }

    internal IAzureResourceManager AzureResourceManager => AzureResourceManagerMock?.Object;
}
}