Add support for firewall rule name in request (#1791)

This commit is contained in:
Cheena Malhotra
2022-12-16 09:04:14 -08:00
committed by GitHub
parent e20f64fa9a
commit eca0cc484c
5 changed files with 48 additions and 45 deletions

View File

@@ -31,7 +31,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
/// Per-tenant token mappings. Ideally would be set independently of this call, but for /// Per-tenant token mappings. Ideally would be set independently of this call, but for
/// now this allows us to get the tokens necessary to find a server and open a firewall rule /// now this allows us to get the tokens necessary to find a server and open a firewall rule
/// </summary> /// </summary>
public Dictionary<string,AccountSecurityToken> SecurityTokenMappings { get; set; } public Dictionary<string, AccountSecurityToken> SecurityTokenMappings { get; set; }
/// <summary> /// <summary>
/// Fully qualified name of the server to create a new firewall rule on /// Fully qualified name of the server to create a new firewall rule on
@@ -48,6 +48,11 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
/// </summary> /// </summary>
public string EndIpAddress { get; set; } public string EndIpAddress { get; set; }
/// <summary>
/// Firewall rule name to set
/// </summary>
public string FirewallRuleName { get; set; }
} }
public class CreateFirewallRuleResponse : TokenReliantResponse public class CreateFirewallRuleResponse : TokenReliantResponse
@@ -86,16 +91,15 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
public class HandleFirewallRuleResponse public class HandleFirewallRuleResponse
{ {
/// <summary> /// <summary>
/// Can this be handled? /// Whether or not request can be handled.
/// </summary> /// </summary>
public bool Result { get; set; } public bool Result { get; set; }
/// <summary> /// <summary>
/// If not, why? /// Contains error message, if request could not be handled.
/// </summary> /// </summary>
public string ErrorMessage { get; set; } public string ErrorMessage { get; set; }
/// <summary> /// <summary>
/// If it can be handled, is there a default IP address to send back so users /// If handled, the default IP address to send back; so users can tell what their blocked IP is.
/// can tell what their blocked IP is?
/// </summary> /// </summary>
public string IpAddress { get; set; } public string IpAddress { get; set; }
} }

View File

@@ -2,9 +2,6 @@
// Copyright (c) Microsoft. All rights reserved. // Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information. // Licensed under the MIT license. See LICENSE file in the project root for full license information.
// //
using System;
using System.Globalization;
using System.Net; using System.Net;
namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
@@ -27,15 +24,6 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
/// <summary> /// <summary>
/// Firewall rule name /// Firewall rule name
/// </summary> /// </summary>
public string FirewallRuleName public string FirewallRuleName { get; set; }
{
get
{
DateTime now = DateTime.UtcNow;
return string.Format(CultureInfo.InvariantCulture, "ClientIPAddress_{0}",
now.ToString("yyyy-MM-dd_hh:mm:ss", CultureInfo.CurrentCulture));
}
}
} }
} }

View File

@@ -12,6 +12,7 @@ using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.SqlTools.Extensibility; using Microsoft.SqlTools.Extensibility;
using Microsoft.SqlTools.ResourceProvider.Core.Authentication; using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
using Microsoft.SqlTools.ResourceProvider.Core.Contracts;
namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
{ {
@@ -21,13 +22,13 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
/// <summary> /// <summary>
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails /// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
/// </summary> /// </summary>
Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, string startIpAddressValue, string endIpAddressValue); Task<FirewallRuleResponse> CreateFirewallRuleAsync(CreateFirewallRuleParams firewallRuleParams);
/// <summary> /// <summary>
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails /// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
/// </summary> /// </summary>
Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, IPAddress startIpAddress, IPAddress endIpAddress); Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, FirewallRuleRequest firewallRuleRequest);
/// <summary> /// <summary>
@@ -55,31 +56,40 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
/// <summary> /// <summary>
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails /// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
/// </summary> /// </summary>
public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, string startIpAddressValue, string endIpAddressValue) public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(CreateFirewallRuleParams firewallRuleParams)
{ {
IPAddress startIpAddress = ConvertToIpAddress(startIpAddressValue); IPAddress startIpAddress = ConvertToIpAddress(firewallRuleParams.StartIpAddress);
IPAddress endIpAddress = ConvertToIpAddress(endIpAddressValue); IPAddress endIpAddress = ConvertToIpAddress(firewallRuleParams.EndIpAddress);
return await CreateFirewallRuleAsync(serverName, startIpAddress, endIpAddress); FirewallRuleRequest firewallRuleRequest = new FirewallRuleRequest()
{
StartIpAddress = ConvertToIpAddress(firewallRuleParams.StartIpAddress),
EndIpAddress = ConvertToIpAddress(firewallRuleParams.EndIpAddress),
FirewallRuleName = string.Format(CultureInfo.InvariantCulture, firewallRuleParams.FirewallRuleName ?? "ClientIPAddress_{0}",
DateTime.UtcNow.ToString("yyyy-MM-dd_hh:mm:ss", CultureInfo.CurrentCulture))
};
return await CreateFirewallRuleAsync(firewallRuleParams.ServerName, firewallRuleRequest);
} }
/// <summary> /// <summary>
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails /// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
/// </summary> /// </summary>
public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, IPAddress startIpAddress, IPAddress endIpAddress) public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, FirewallRuleRequest firewallRuleRequest)
{ {
try try
{ {
FirewallRuleResponse firewallRuleResponse = new FirewallRuleResponse() { Created = false }; FirewallRuleResponse firewallRuleResponse = new FirewallRuleResponse() { Created = false };
CommonUtil.CheckStringForNullOrEmpty(serverName, "serverName"); CommonUtil.CheckStringForNullOrEmpty(serverName, nameof(serverName));
CommonUtil.CheckForNull(startIpAddress, "startIpAddress"); CommonUtil.CheckForNull(firewallRuleRequest, nameof(firewallRuleRequest));
CommonUtil.CheckForNull(endIpAddress, "endIpAddress"); CommonUtil.CheckForNull(firewallRuleRequest.FirewallRuleName, nameof(firewallRuleRequest.FirewallRuleName));
CommonUtil.CheckForNull(firewallRuleRequest.StartIpAddress, nameof(firewallRuleRequest.StartIpAddress));
CommonUtil.CheckForNull(firewallRuleRequest.EndIpAddress, nameof(firewallRuleRequest.EndIpAddress));
IAzureAuthenticationManager authenticationManager = AuthenticationManager; IAzureAuthenticationManager authenticationManager = AuthenticationManager;
if (authenticationManager != null && !await authenticationManager.GetUserNeedsReauthenticationAsync()) if (authenticationManager != null && !await authenticationManager.GetUserNeedsReauthenticationAsync())
{ {
FirewallRuleResource firewallRuleResource = await FindAzureResourceAsync(serverName); FirewallRuleResource firewallRuleResource = await FindAzureResourceAsync(serverName);
firewallRuleResponse = await CreateFirewallRule(firewallRuleResource, startIpAddress, endIpAddress); firewallRuleResponse = await CreateFirewallRule(firewallRuleResource, firewallRuleRequest);
} }
if (firewallRuleResponse == null || !firewallRuleResponse.Created) if (firewallRuleResponse == null || !firewallRuleResponse.Created)
{ {
@@ -110,7 +120,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
/// <summary> /// <summary>
/// Creates firewall rule for given subscription and IP address range /// Creates firewall rule for given subscription and IP address range
/// </summary> /// </summary>
private async Task<FirewallRuleResponse> CreateFirewallRule(FirewallRuleResource firewallRuleResource, IPAddress startIpAddress, IPAddress endIpAddress) private async Task<FirewallRuleResponse> CreateFirewallRule(FirewallRuleResource firewallRuleResource, FirewallRuleRequest firewallRuleRequest)
{ {
CommonUtil.CheckForNull(firewallRuleResource, "firewallRuleResource"); CommonUtil.CheckForNull(firewallRuleResource, "firewallRuleResource");
@@ -118,18 +128,12 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
{ {
if (firewallRuleResource.IsValid) if (firewallRuleResource.IsValid)
{ {
FirewallRuleRequest request = new FirewallRuleRequest()
{
StartIpAddress = startIpAddress,
EndIpAddress = endIpAddress
};
using (IAzureResourceManagementSession session = await ResourceManager.CreateSessionAsync(firewallRuleResource.SubscriptionContext)) using (IAzureResourceManagementSession session = await ResourceManager.CreateSessionAsync(firewallRuleResource.SubscriptionContext))
{ {
return await ResourceManager.CreateFirewallRuleAsync( return await ResourceManager.CreateFirewallRuleAsync(
session, session,
firewallRuleResource.AzureResource, firewallRuleResource.AzureResource,
request); firewallRuleRequest);
} }
} }
} }
@@ -162,8 +166,8 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
throw new FirewallRuleException(SR.NoSubscriptionsFound); throw new FirewallRuleException(SR.NoSubscriptionsFound);
} }
ServiceResponse<FirewallRuleResource> response = await AzureUtil.ExecuteGetAzureResourceAsParallel((object)null, ServiceResponse<FirewallRuleResource> response = await AzureUtil.ExecuteGetAzureResourceAsParallel((object)null,
subscriptions, serverName, new CancellationToken(), TryFindAzureResourceForSubscriptionAsync); subscriptions, serverName, new CancellationToken(), TryFindAzureResourceForSubscriptionAsync);
if (response != null) if (response != null)
{ {

View File

@@ -69,7 +69,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core
await HandleRequest(requestHandler, tokenExpiredHandler, requestContext, "HandleCreateFirewallRuleRequest"); await HandleRequest(requestHandler, tokenExpiredHandler, requestContext, "HandleCreateFirewallRuleRequest");
} }
private async Task<CreateFirewallRuleResponse> DoHandleCreateFirewallRuleRequest(CreateFirewallRuleParams firewallRule) private async Task<CreateFirewallRuleResponse> DoHandleCreateFirewallRuleRequest(CreateFirewallRuleParams firewallRuleParams)
{ {
var result = new CreateFirewallRuleResponse(); var result = new CreateFirewallRuleResponse();
// Note: currently not catching the exception. Expect the caller to this message to handle error cases by // Note: currently not catching the exception. Expect the caller to this message to handle error cases by
@@ -77,11 +77,11 @@ namespace Microsoft.SqlTools.ResourceProvider.Core
try try
{ {
AuthenticationService authService = ServiceProvider.GetService<AuthenticationService>(); AuthenticationService authService = ServiceProvider.GetService<AuthenticationService>();
IUserAccount account = await authService.SetCurrentAccountAsync(firewallRule.Account, firewallRule.SecurityTokenMappings); IUserAccount account = await authService.SetCurrentAccountAsync(firewallRuleParams.Account, firewallRuleParams.SecurityTokenMappings);
FirewallRuleResponse response = await firewallRuleService.CreateFirewallRuleAsync(firewallRule.ServerName, firewallRule.StartIpAddress, firewallRule.EndIpAddress); FirewallRuleResponse response = await firewallRuleService.CreateFirewallRuleAsync(firewallRuleParams);
result.Result = true; result.Result = true;
} }
catch(FirewallRuleException ex) catch (FirewallRuleException ex)
{ {
result.Result = false; result.Result = false;
result.ErrorMessage = ex.Message; result.ErrorMessage = ex.Message;
@@ -120,7 +120,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core
T result = await handler(); T result = await handler();
await requestContext.SendResult(result); await requestContext.SendResult(result);
} }
catch(ExpiredTokenException ex) catch (ExpiredTokenException ex)
{ {
if (expiredTokenHandler != null) if (expiredTokenHandler != null)
{ {

View File

@@ -9,6 +9,7 @@ using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.SqlTools.ResourceProvider.Core; using Microsoft.SqlTools.ResourceProvider.Core;
using Microsoft.SqlTools.ResourceProvider.Core.Authentication; using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
using Microsoft.SqlTools.ResourceProvider.Core.Contracts;
using Microsoft.SqlTools.ResourceProvider.Core.Firewall; using Microsoft.SqlTools.ResourceProvider.Core.Firewall;
using Moq; using Moq;
using NUnit.Framework; using NUnit.Framework;
@@ -309,9 +310,15 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ResourceProvider
try try
{ {
FirewallRuleService service = new FirewallRuleService(); FirewallRuleService service = new FirewallRuleService();
CreateFirewallRuleParams createFirewallRuleParams = new CreateFirewallRuleParams()
{
ServerName = serverName,
StartIpAddress = testContext.StartIpAddress,
EndIpAddress = testContext.EndIpAddress
};
service.AuthenticationManager = testContext.ApplicationAuthenticationManager; service.AuthenticationManager = testContext.ApplicationAuthenticationManager;
service.ResourceManager = testContext.AzureResourceManager; service.ResourceManager = testContext.AzureResourceManager;
FirewallRuleResponse response = await service.CreateFirewallRuleAsync(serverName, testContext.StartIpAddress, testContext.EndIpAddress); FirewallRuleResponse response = await service.CreateFirewallRuleAsync(createFirewallRuleParams);
if (verifyFirewallRuleCreated) if (verifyFirewallRuleCreated)
{ {
testContext.AzureResourceManagerMock.Verify(x => x.CreateFirewallRuleAsync( testContext.AzureResourceManagerMock.Verify(x => x.CreateFirewallRuleAsync(