diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/FirewallRule.cs b/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/FirewallRule.cs index 9e2557cf..8e5213f4 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/FirewallRule.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/FirewallRule.cs @@ -31,7 +31,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts /// Per-tenant token mappings. Ideally would be set independently of this call, but for /// now this allows us to get the tokens necessary to find a server and open a firewall rule /// - public Dictionary SecurityTokenMappings { get; set; } + public Dictionary SecurityTokenMappings { get; set; } /// /// Fully qualified name of the server to create a new firewall rule on @@ -48,6 +48,11 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts /// public string EndIpAddress { get; set; } + /// + /// Firewall rule name to set + /// + public string FirewallRuleName { get; set; } + } public class CreateFirewallRuleResponse : TokenReliantResponse @@ -86,16 +91,15 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts public class HandleFirewallRuleResponse { /// - /// Can this be handled? + /// Whether or not request can be handled. /// public bool Result { get; set; } /// - /// If not, why? + /// Contains error message, if request could not be handled. /// public string ErrorMessage { get; set; } /// - /// If it can be handled, is there a default IP address to send back so users - /// can tell what their blocked IP is? + /// If handled, the default IP address to send back; so users can tell what their blocked IP is. /// public string IpAddress { get; set; } } diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Firewall/FirewallRuleRequest.cs b/src/Microsoft.SqlTools.ResourceProvider.Core/Firewall/FirewallRuleRequest.cs index 0c8ab9d9..2f37a20e 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Firewall/FirewallRuleRequest.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Firewall/FirewallRuleRequest.cs @@ -2,9 +2,6 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. // - -using System; -using System.Globalization; using System.Net; namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall @@ -27,15 +24,6 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall /// /// Firewall rule name /// - public string FirewallRuleName - { - get - { - DateTime now = DateTime.UtcNow; - - return string.Format(CultureInfo.InvariantCulture, "ClientIPAddress_{0}", - now.ToString("yyyy-MM-dd_hh:mm:ss", CultureInfo.CurrentCulture)); - } - } + public string FirewallRuleName { get; set; } } } diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Firewall/FirewallRuleService.cs b/src/Microsoft.SqlTools.ResourceProvider.Core/Firewall/FirewallRuleService.cs index 3e1229ee..17bf6ce5 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Firewall/FirewallRuleService.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Firewall/FirewallRuleService.cs @@ -12,6 +12,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.SqlTools.Extensibility; using Microsoft.SqlTools.ResourceProvider.Core.Authentication; +using Microsoft.SqlTools.ResourceProvider.Core.Contracts; namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall { @@ -21,13 +22,13 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall /// /// Creates firewall rule for given server name and IP address range. Throws exception if operation fails /// - Task CreateFirewallRuleAsync(string serverName, string startIpAddressValue, string endIpAddressValue); + Task CreateFirewallRuleAsync(CreateFirewallRuleParams firewallRuleParams); /// /// Creates firewall rule for given server name and IP address range. Throws exception if operation fails /// - Task CreateFirewallRuleAsync(string serverName, IPAddress startIpAddress, IPAddress endIpAddress); + Task CreateFirewallRuleAsync(string serverName, FirewallRuleRequest firewallRuleRequest); /// @@ -55,31 +56,40 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall /// /// Creates firewall rule for given server name and IP address range. Throws exception if operation fails /// - public async Task CreateFirewallRuleAsync(string serverName, string startIpAddressValue, string endIpAddressValue) + public async Task CreateFirewallRuleAsync(CreateFirewallRuleParams firewallRuleParams) { - IPAddress startIpAddress = ConvertToIpAddress(startIpAddressValue); - IPAddress endIpAddress = ConvertToIpAddress(endIpAddressValue); - return await CreateFirewallRuleAsync(serverName, startIpAddress, endIpAddress); + IPAddress startIpAddress = ConvertToIpAddress(firewallRuleParams.StartIpAddress); + IPAddress endIpAddress = ConvertToIpAddress(firewallRuleParams.EndIpAddress); + FirewallRuleRequest firewallRuleRequest = new FirewallRuleRequest() + { + StartIpAddress = ConvertToIpAddress(firewallRuleParams.StartIpAddress), + EndIpAddress = ConvertToIpAddress(firewallRuleParams.EndIpAddress), + FirewallRuleName = string.Format(CultureInfo.InvariantCulture, firewallRuleParams.FirewallRuleName ?? "ClientIPAddress_{0}", + DateTime.UtcNow.ToString("yyyy-MM-dd_hh:mm:ss", CultureInfo.CurrentCulture)) + }; + return await CreateFirewallRuleAsync(firewallRuleParams.ServerName, firewallRuleRequest); } /// /// Creates firewall rule for given server name and IP address range. Throws exception if operation fails /// - public async Task CreateFirewallRuleAsync(string serverName, IPAddress startIpAddress, IPAddress endIpAddress) + public async Task CreateFirewallRuleAsync(string serverName, FirewallRuleRequest firewallRuleRequest) { try { FirewallRuleResponse firewallRuleResponse = new FirewallRuleResponse() { Created = false }; - CommonUtil.CheckStringForNullOrEmpty(serverName, "serverName"); - CommonUtil.CheckForNull(startIpAddress, "startIpAddress"); - CommonUtil.CheckForNull(endIpAddress, "endIpAddress"); + CommonUtil.CheckStringForNullOrEmpty(serverName, nameof(serverName)); + CommonUtil.CheckForNull(firewallRuleRequest, nameof(firewallRuleRequest)); + CommonUtil.CheckForNull(firewallRuleRequest.FirewallRuleName, nameof(firewallRuleRequest.FirewallRuleName)); + CommonUtil.CheckForNull(firewallRuleRequest.StartIpAddress, nameof(firewallRuleRequest.StartIpAddress)); + CommonUtil.CheckForNull(firewallRuleRequest.EndIpAddress, nameof(firewallRuleRequest.EndIpAddress)); IAzureAuthenticationManager authenticationManager = AuthenticationManager; if (authenticationManager != null && !await authenticationManager.GetUserNeedsReauthenticationAsync()) { FirewallRuleResource firewallRuleResource = await FindAzureResourceAsync(serverName); - firewallRuleResponse = await CreateFirewallRule(firewallRuleResource, startIpAddress, endIpAddress); + firewallRuleResponse = await CreateFirewallRule(firewallRuleResource, firewallRuleRequest); } if (firewallRuleResponse == null || !firewallRuleResponse.Created) { @@ -110,7 +120,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall /// /// Creates firewall rule for given subscription and IP address range /// - private async Task CreateFirewallRule(FirewallRuleResource firewallRuleResource, IPAddress startIpAddress, IPAddress endIpAddress) + private async Task CreateFirewallRule(FirewallRuleResource firewallRuleResource, FirewallRuleRequest firewallRuleRequest) { CommonUtil.CheckForNull(firewallRuleResource, "firewallRuleResource"); @@ -118,18 +128,12 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall { if (firewallRuleResource.IsValid) { - - FirewallRuleRequest request = new FirewallRuleRequest() - { - StartIpAddress = startIpAddress, - EndIpAddress = endIpAddress - }; using (IAzureResourceManagementSession session = await ResourceManager.CreateSessionAsync(firewallRuleResource.SubscriptionContext)) { return await ResourceManager.CreateFirewallRuleAsync( session, firewallRuleResource.AzureResource, - request); + firewallRuleRequest); } } } @@ -162,8 +166,8 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall throw new FirewallRuleException(SR.NoSubscriptionsFound); } - ServiceResponse response = await AzureUtil.ExecuteGetAzureResourceAsParallel((object)null, - subscriptions, serverName, new CancellationToken(), TryFindAzureResourceForSubscriptionAsync); + ServiceResponse response = await AzureUtil.ExecuteGetAzureResourceAsParallel((object)null, + subscriptions, serverName, new CancellationToken(), TryFindAzureResourceForSubscriptionAsync); if (response != null) { diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/ResourceProviderService.cs b/src/Microsoft.SqlTools.ResourceProvider.Core/ResourceProviderService.cs index 3d990b3b..c54d9d46 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/ResourceProviderService.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/ResourceProviderService.cs @@ -69,7 +69,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core await HandleRequest(requestHandler, tokenExpiredHandler, requestContext, "HandleCreateFirewallRuleRequest"); } - private async Task DoHandleCreateFirewallRuleRequest(CreateFirewallRuleParams firewallRule) + private async Task DoHandleCreateFirewallRuleRequest(CreateFirewallRuleParams firewallRuleParams) { var result = new CreateFirewallRuleResponse(); // Note: currently not catching the exception. Expect the caller to this message to handle error cases by @@ -77,11 +77,11 @@ namespace Microsoft.SqlTools.ResourceProvider.Core try { AuthenticationService authService = ServiceProvider.GetService(); - IUserAccount account = await authService.SetCurrentAccountAsync(firewallRule.Account, firewallRule.SecurityTokenMappings); - FirewallRuleResponse response = await firewallRuleService.CreateFirewallRuleAsync(firewallRule.ServerName, firewallRule.StartIpAddress, firewallRule.EndIpAddress); + IUserAccount account = await authService.SetCurrentAccountAsync(firewallRuleParams.Account, firewallRuleParams.SecurityTokenMappings); + FirewallRuleResponse response = await firewallRuleService.CreateFirewallRuleAsync(firewallRuleParams); result.Result = true; } - catch(FirewallRuleException ex) + catch (FirewallRuleException ex) { result.Result = false; result.ErrorMessage = ex.Message; @@ -120,7 +120,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core T result = await handler(); await requestContext.SendResult(result); } - catch(ExpiredTokenException ex) + catch (ExpiredTokenException ex) { if (expiredTokenHandler != null) { diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/FirewallRuleServiceTest.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/FirewallRuleServiceTest.cs index 284eaa05..a7f55bd4 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/FirewallRuleServiceTest.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/FirewallRuleServiceTest.cs @@ -9,6 +9,7 @@ using System.Linq; using System.Threading.Tasks; using Microsoft.SqlTools.ResourceProvider.Core; using Microsoft.SqlTools.ResourceProvider.Core.Authentication; +using Microsoft.SqlTools.ResourceProvider.Core.Contracts; using Microsoft.SqlTools.ResourceProvider.Core.Firewall; using Moq; using NUnit.Framework; @@ -309,9 +310,15 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ResourceProvider try { FirewallRuleService service = new FirewallRuleService(); + CreateFirewallRuleParams createFirewallRuleParams = new CreateFirewallRuleParams() + { + ServerName = serverName, + StartIpAddress = testContext.StartIpAddress, + EndIpAddress = testContext.EndIpAddress + }; service.AuthenticationManager = testContext.ApplicationAuthenticationManager; service.ResourceManager = testContext.AzureResourceManager; - FirewallRuleResponse response = await service.CreateFirewallRuleAsync(serverName, testContext.StartIpAddress, testContext.EndIpAddress); + FirewallRuleResponse response = await service.CreateFirewallRuleAsync(createFirewallRuleParams); if (verifyFirewallRuleCreated) { testContext.AzureResourceManagerMock.Verify(x => x.CreateFirewallRuleAsync(