//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.SqlTools.Extensibility;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ResourceProvider;
using Microsoft.SqlTools.ResourceProvider.Core;
using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
using Microsoft.SqlTools.ResourceProvider.Core.Contracts;
using Microsoft.SqlTools.ResourceProvider.Core.Firewall;
using Microsoft.SqlTools.ResourceProvider.DefaultImpl;
using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility;
using Moq;
using NUnit.Framework;

// NOTE(review): the namespace says "Formatter" although these are resource provider tests;
// left unchanged so existing test filters/namespaces keep matching.
namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Formatter
{
    /// <summary>
    /// Tests for the firewall-rule request handlers on <see cref="ResourceProviderService"/>.
    /// The Azure authentication manager and resource manager are replaced with Moq mocks
    /// that are registered in the service provider before the service is resolved.
    /// </summary>
    public class ResourceProviderServiceTests
    {
        private const int SqlAzureFirewallBlockedErrorNumber = 40615;
        private const int SqlAzureLoginFailedErrorNumber = 18456;
        private const string ErrorMessageWithIp = "error Message with 1.2.3.4 as IP address";

        /// <summary>
        /// Creates fresh mocks, a fresh service provider and a fresh
        /// <see cref="ResourceProviderService"/> before every test.
        /// </summary>
        /// <remarks>
        /// NUnit reuses a single fixture instance for all tests in the class, so doing this
        /// work in the constructor (the xUnit idiom) would share mock setups between tests.
        /// </remarks>
        [SetUp]
        public void SetUp()
        {
            HostMock = new Mock<IProtocolEndpoint>();
            AuthenticationManagerMock = new Mock<IAzureAuthenticationManager>();
            ResourceManagerMock = new Mock<IAzureResourceManager>();

            ServiceProvider = ExtensionServiceProvider.CreateFromAssembliesInDirectory(ResourceProviderHostLoader.GetResourceProviderExtensionDlls());
            // Register the mocks so the service resolves them instead of the real Azure implementations.
            ServiceProvider.RegisterSingleService(AuthenticationManagerMock.Object);
            ServiceProvider.RegisterSingleService(ResourceManagerMock.Object);
            HostLoader.InitializeHostedServices(ServiceProvider, HostMock.Object);

            ResourceProviderService = ServiceProvider.GetService<ResourceProviderService>();
        }

        protected RegisteredServiceProvider ServiceProvider { get; private set; }

        protected Mock<IProtocolEndpoint> HostMock { get; private set; }

        protected Mock<IAzureAuthenticationManager> AuthenticationManagerMock { get; set; }

        protected Mock<IAzureResourceManager> ResourceManagerMock { get; set; }

        protected ResourceProviderService ResourceProviderService { get; private set; }

        [Test]
        public async Task TestHandleFirewallRuleIgnoresNonMssqlProvider()
        {
            // Given a firewall error raised for a non-MSSQL provider
            var handleFirewallParams = new HandleFirewallRuleParams()
            {
                ErrorCode = SqlAzureFirewallBlockedErrorNumber,
                ErrorMessage = ErrorMessageWithIp,
                ConnectionTypeId = "Other"
            };

            // When I ask whether the service can process an error as a firewall rule request
            await TestUtils.RunAndVerify<HandleFirewallRuleResponse>(
                (context) => ResourceProviderService.ProcessHandleFirewallRuleRequest(handleFirewallParams, context),
                (response) =>
                {
                    // Then I expect the response to be false, no IP information to be sent,
                    // and the unsupported-connection-type message to be returned
                    Assert.NotNull(response);
                    Assert.False(response.Result);
                    Assert.Null(response.IpAddress);
                    Assert.AreEqual(Microsoft.SqlTools.ResourceProvider.Core.SR.FirewallRuleUnsupportedConnectionType, response.ErrorMessage);
                });
        }

        [Test]
        public async Task TestHandleFirewallRuleSupportsMssqlProvider()
        {
            // Given a firewall error for the MSSQL provider
            var handleFirewallParams = new HandleFirewallRuleParams()
            {
                ErrorCode = SqlAzureFirewallBlockedErrorNumber,
                ErrorMessage = ErrorMessageWithIp,
                ConnectionTypeId = "MSSQL"
            };

            // When I ask whether the service can process an error as a firewall rule request
            await TestUtils.RunAndVerify<HandleFirewallRuleResponse>(
                (context) => ResourceProviderService.ProcessHandleFirewallRuleRequest(handleFirewallParams, context),
                (response) =>
                {
                    // Then I expect the response to be true and the IP address to be extracted
                    Assert.NotNull(response);
                    Assert.True(response.Result);
                    Assert.AreEqual("1.2.3.4", response.IpAddress);
                    Assert.Null(response.ErrorMessage);
                });
        }

        [Test]
        public async Task TestHandleFirewallRuleIgnoresNonFirewallErrors()
        {
            // Given a login-failed (not firewall) error for the MSSQL provider
            var handleFirewallParams = new HandleFirewallRuleParams()
            {
                ErrorCode = SqlAzureLoginFailedErrorNumber,
                ErrorMessage = ErrorMessageWithIp,
                ConnectionTypeId = "MSSQL"
            };

            // When I ask whether the service can process an error as a firewall rule request
            await TestUtils.RunAndVerify<HandleFirewallRuleResponse>(
                (context) => ResourceProviderService.ProcessHandleFirewallRuleRequest(handleFirewallParams, context),
                (response) =>
                {
                    // Then I expect the response to be false and no IP address to be defined
                    Assert.NotNull(response);
                    Assert.False(response.Result);
                    Assert.AreEqual(string.Empty, response.IpAddress);
                    Assert.Null(response.ErrorMessage);
                });
        }

        [Test]
        public async Task TestHandleFirewallRuleDoesntBreakWithoutIp()
        {
            // Given a firewall error with no IP address in the error message
            var handleFirewallParams = new HandleFirewallRuleParams()
            {
                ErrorCode = SqlAzureFirewallBlockedErrorNumber,
                ErrorMessage = "No IP here",
                ConnectionTypeId = "MSSQL"
            };

            // When I ask whether the service can process an error as a firewall rule request
            await TestUtils.RunAndVerify<HandleFirewallRuleResponse>(
                (context) => ResourceProviderService.ProcessHandleFirewallRuleRequest(handleFirewallParams, context),
                (response) =>
                {
                    // Then I expect the request to be declined (a known IP address is required),
                    // an empty IP address, and no error message
                    Assert.NotNull(response);
                    Assert.False(response.Result);
                    Assert.AreEqual(string.Empty, response.IpAddress);
                    Assert.Null(response.ErrorMessage);
                });
        }

        [Test]
        public async Task TestCreateFirewallRuleBasicRequest()
        {
            // Given a firewall request for a server owned by the second of two subscriptions
            string serverName = "myserver.database.windows.net";
            var sub1Mock = new Mock<IAzureUserAccountSubscriptionContext>();
            var sub2Mock = new Mock<IAzureUserAccountSubscriptionContext>();
            var server = new SqlAzureResource(new Azure.Management.Sql.Models.Server("Somewhere",
                "1234", "myserver", "SQLServer",
                null, null, null, null, null, null, null,
                fullyQualifiedDomainName: serverName));
            var subsToServers = new List<Tuple<IAzureUserAccountSubscriptionContext, IEnumerable<IAzureSqlServerResource>>>()
            {
                Tuple.Create(sub1Mock.Object, Enumerable.Empty<IAzureSqlServerResource>()),
                Tuple.Create(sub2Mock.Object, new IAzureSqlServerResource[] { server }.AsEnumerable())
            };
            var azureRmResponse = new FirewallRuleResponse()
            {
                Created = true,
                StartIpAddress = null,
                EndIpAddress = null
            };
            SetupDependencies(subsToServers, azureRmResponse);

            // When I request the firewall be created
            var createFirewallParams = new CreateFirewallRuleParams()
            {
                ServerName = serverName,
                StartIpAddress = "1.1.1.1",
                EndIpAddress = "1.1.1.255",
                Account = CreateAccount(),
                SecurityTokenMappings = new Dictionary<string, AccountSecurityToken>()
            };
            await TestUtils.RunAndVerify<CreateFirewallRuleResponse>(
                (context) => ResourceProviderService.HandleCreateFirewallRuleRequest(createFirewallParams, context),
                (response) =>
                {
                    // Then I expect the rule to be created successfully with no error reported
                    Assert.NotNull(response);
                    Assert.Null(response.ErrorMessage);
                    Assert.True(response.Result);
                    Assert.False(response.IsTokenExpiredFailure);
                });
        }

        [Test]
        public async Task TestCreateFirewallRuleHandlesTokenExpiration()
        {
            // Given the token has expired
            string serverName = "myserver.database.windows.net";
            SetupCreateSession();
            string expectedErrorMsg = "Token is expired";
            AuthenticationManagerMock.Setup(a => a.GetSubscriptionsAsync()).ThrowsAsync(new ExpiredTokenException(expectedErrorMsg));

            // When I request the firewall be created
            var createFirewallParams = new CreateFirewallRuleParams()
            {
                ServerName = serverName,
                StartIpAddress = "1.1.1.1",
                EndIpAddress = "1.1.1.255",
                Account = CreateAccount(),
                SecurityTokenMappings = new Dictionary<string, AccountSecurityToken>()
            };
            await TestUtils.RunAndVerify<CreateFirewallRuleResponse>(
                (context) => ResourceProviderService.HandleCreateFirewallRuleRequest(createFirewallParams, context),
                (response) =>
                {
                    // Then I expect the response to indicate that we failed due to token expiration
                    Assert.NotNull(response);
                    Assert.AreEqual(expectedErrorMsg, response.ErrorMessage);
                    Assert.True(response.IsTokenExpiredFailure);
                    Assert.False(response.Result);
                });
        }

        /// <summary>
        /// Wires the mocks so that every subscription yields a session, each session lists the
        /// servers paired with its subscription, and any firewall-rule creation returns
        /// <paramref name="response"/>.
        /// </summary>
        /// <param name="subsToServers">Subscriptions paired with the servers each one owns.</param>
        /// <param name="response">Response returned by the mocked firewall-rule creation.</param>
        private void SetupDependencies(
            IList<Tuple<IAzureUserAccountSubscriptionContext, IEnumerable<IAzureSqlServerResource>>> subsToServers,
            FirewallRuleResponse response)
        {
            SetupCreateSession();
            SetupReturnsSubscriptions(subsToServers.Select(s => s.Item1));
            foreach (var s in subsToServers)
            {
                SetupAzureServers(s.Item1, s.Item2);
            }
            SetupFirewallResponse(response);
        }

        /// <summary>Makes the authentication manager report <paramref name="subs"/> as the available subscriptions.</summary>
        private void SetupReturnsSubscriptions(IEnumerable<IAzureUserAccountSubscriptionContext> subs)
        {
            AuthenticationManagerMock.Setup(a => a.GetSubscriptionsAsync()).Returns(() => Task.FromResult(subs));
        }

        /// <summary>
        /// Makes session creation return a new mock session whose
        /// <c>SubscriptionContext</c> is the subscription it was created for.
        /// </summary>
        private void SetupCreateSession()
        {
            ResourceManagerMock.Setup(r => r.CreateSessionAsync(It.IsAny<IAzureUserAccountSubscriptionContext>()))
                .Returns((IAzureUserAccountSubscriptionContext sub) =>
                {
                    var sessionMock = new Mock<IAzureResourceManagementSession>();
                    sessionMock.SetupProperty(s => s.SubscriptionContext, sub);
                    return Task.FromResult(sessionMock.Object);
                });
        }

        /// <summary>
        /// Makes server enumeration return <paramref name="servers"/> only for sessions
        /// whose subscription context is (reference-)equal to <paramref name="sub"/>.
        /// </summary>
        private void SetupAzureServers(IAzureSubscriptionContext sub, IEnumerable<IAzureSqlServerResource> servers)
        {
            ResourceManagerMock.Setup(r => r.GetSqlServerAzureResourcesAsync(
                It.Is<IAzureResourceManagementSession>(session => session.SubscriptionContext == sub)))
                .Returns(() => Task.FromResult(servers));
        }

        /// <summary>Makes any firewall-rule creation call return <paramref name="response"/>.</summary>
        private void SetupFirewallResponse(FirewallRuleResponse response)
        {
            ResourceManagerMock.Setup(r => r.CreateFirewallRuleAsync(
                It.IsAny<IAzureResourceManagementSession>(),
                It.IsAny<IAzureSqlServerResource>(),
                It.IsAny<FirewallRuleRequest>()))
                .Returns(() => Task.FromResult(response));
        }

        /// <summary>Builds a test MSSQL account.</summary>
        /// <param name="needsReauthentication">Value assigned to <c>Account.IsStale</c>.</param>
        private Account CreateAccount(bool needsReauthentication = false)
        {
            return new Account()
            {
                Key = new AccountKey()
                {
                    AccountId = "MyAccount",
                    ProviderId = "MSSQL"
                },
                IsStale = needsReauthentication
            };
        }
    }
}