Add support for firewall rule name in request (#1791)

This commit is contained in:
Cheena Malhotra
2022-12-16 09:04:14 -08:00
committed by GitHub
parent e20f64fa9a
commit eca0cc484c
5 changed files with 48 additions and 45 deletions

View File

@@ -31,7 +31,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
/// Per-tenant token mappings. Ideally would be set independently of this call, but for
/// now this allows us to get the tokens necessary to find a server and open a firewall rule
/// </summary>
public Dictionary<string,AccountSecurityToken> SecurityTokenMappings { get; set; }
public Dictionary<string, AccountSecurityToken> SecurityTokenMappings { get; set; }
/// <summary>
/// Fully qualified name of the server to create a new firewall rule on
@@ -48,6 +48,11 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
/// </summary>
public string EndIpAddress { get; set; }
/// <summary>
/// Firewall rule name to set
/// </summary>
public string FirewallRuleName { get; set; }
}
public class CreateFirewallRuleResponse : TokenReliantResponse
@@ -86,16 +91,15 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
public class HandleFirewallRuleResponse
{
/// <summary>
/// Can this be handled?
/// Whether or not request can be handled.
/// </summary>
public bool Result { get; set; }
/// <summary>
/// If not, why?
/// Contains error message, if request could not be handled.
/// </summary>
public string ErrorMessage { get; set; }
/// <summary>
/// If it can be handled, is there a default IP address to send back so users
/// can tell what their blocked IP is?
/// If handled, the default IP address to send back; so users can tell what their blocked IP is.
/// </summary>
public string IpAddress { get; set; }
}

View File

@@ -2,9 +2,6 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Globalization;
using System.Net;
namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
@@ -27,15 +24,6 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
/// <summary>
/// Firewall rule name
/// </summary>
public string FirewallRuleName
{
get
{
DateTime now = DateTime.UtcNow;
return string.Format(CultureInfo.InvariantCulture, "ClientIPAddress_{0}",
now.ToString("yyyy-MM-dd_hh:mm:ss", CultureInfo.CurrentCulture));
}
}
public string FirewallRuleName { get; set; }
}
}

View File

@@ -12,6 +12,7 @@ using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.Extensibility;
using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
using Microsoft.SqlTools.ResourceProvider.Core.Contracts;
namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
{
@@ -21,13 +22,13 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
/// <summary>
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
/// </summary>
Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, string startIpAddressValue, string endIpAddressValue);
Task<FirewallRuleResponse> CreateFirewallRuleAsync(CreateFirewallRuleParams firewallRuleParams);
/// <summary>
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
/// </summary>
Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, IPAddress startIpAddress, IPAddress endIpAddress);
Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, FirewallRuleRequest firewallRuleRequest);
/// <summary>
@@ -55,31 +56,40 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
/// <summary>
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
/// </summary>
public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, string startIpAddressValue, string endIpAddressValue)
public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(CreateFirewallRuleParams firewallRuleParams)
{
IPAddress startIpAddress = ConvertToIpAddress(startIpAddressValue);
IPAddress endIpAddress = ConvertToIpAddress(endIpAddressValue);
return await CreateFirewallRuleAsync(serverName, startIpAddress, endIpAddress);
IPAddress startIpAddress = ConvertToIpAddress(firewallRuleParams.StartIpAddress);
IPAddress endIpAddress = ConvertToIpAddress(firewallRuleParams.EndIpAddress);
FirewallRuleRequest firewallRuleRequest = new FirewallRuleRequest()
{
StartIpAddress = ConvertToIpAddress(firewallRuleParams.StartIpAddress),
EndIpAddress = ConvertToIpAddress(firewallRuleParams.EndIpAddress),
FirewallRuleName = string.Format(CultureInfo.InvariantCulture, firewallRuleParams.FirewallRuleName ?? "ClientIPAddress_{0}",
DateTime.UtcNow.ToString("yyyy-MM-dd_hh:mm:ss", CultureInfo.CurrentCulture))
};
return await CreateFirewallRuleAsync(firewallRuleParams.ServerName, firewallRuleRequest);
}
/// <summary>
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
/// </summary>
public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, IPAddress startIpAddress, IPAddress endIpAddress)
public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, FirewallRuleRequest firewallRuleRequest)
{
try
{
FirewallRuleResponse firewallRuleResponse = new FirewallRuleResponse() { Created = false };
CommonUtil.CheckStringForNullOrEmpty(serverName, "serverName");
CommonUtil.CheckForNull(startIpAddress, "startIpAddress");
CommonUtil.CheckForNull(endIpAddress, "endIpAddress");
CommonUtil.CheckStringForNullOrEmpty(serverName, nameof(serverName));
CommonUtil.CheckForNull(firewallRuleRequest, nameof(firewallRuleRequest));
CommonUtil.CheckForNull(firewallRuleRequest.FirewallRuleName, nameof(firewallRuleRequest.FirewallRuleName));
CommonUtil.CheckForNull(firewallRuleRequest.StartIpAddress, nameof(firewallRuleRequest.StartIpAddress));
CommonUtil.CheckForNull(firewallRuleRequest.EndIpAddress, nameof(firewallRuleRequest.EndIpAddress));
IAzureAuthenticationManager authenticationManager = AuthenticationManager;
if (authenticationManager != null && !await authenticationManager.GetUserNeedsReauthenticationAsync())
{
FirewallRuleResource firewallRuleResource = await FindAzureResourceAsync(serverName);
firewallRuleResponse = await CreateFirewallRule(firewallRuleResource, startIpAddress, endIpAddress);
firewallRuleResponse = await CreateFirewallRule(firewallRuleResource, firewallRuleRequest);
}
if (firewallRuleResponse == null || !firewallRuleResponse.Created)
{
@@ -110,7 +120,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
/// <summary>
/// Creates firewall rule for given subscription and IP address range
/// </summary>
private async Task<FirewallRuleResponse> CreateFirewallRule(FirewallRuleResource firewallRuleResource, IPAddress startIpAddress, IPAddress endIpAddress)
private async Task<FirewallRuleResponse> CreateFirewallRule(FirewallRuleResource firewallRuleResource, FirewallRuleRequest firewallRuleRequest)
{
CommonUtil.CheckForNull(firewallRuleResource, "firewallRuleResource");
@@ -118,18 +128,12 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
{
if (firewallRuleResource.IsValid)
{
FirewallRuleRequest request = new FirewallRuleRequest()
{
StartIpAddress = startIpAddress,
EndIpAddress = endIpAddress
};
using (IAzureResourceManagementSession session = await ResourceManager.CreateSessionAsync(firewallRuleResource.SubscriptionContext))
{
return await ResourceManager.CreateFirewallRuleAsync(
session,
firewallRuleResource.AzureResource,
request);
firewallRuleRequest);
}
}
}
@@ -162,8 +166,8 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
throw new FirewallRuleException(SR.NoSubscriptionsFound);
}
ServiceResponse<FirewallRuleResource> response = await AzureUtil.ExecuteGetAzureResourceAsParallel((object)null,
subscriptions, serverName, new CancellationToken(), TryFindAzureResourceForSubscriptionAsync);
ServiceResponse<FirewallRuleResource> response = await AzureUtil.ExecuteGetAzureResourceAsParallel((object)null,
subscriptions, serverName, new CancellationToken(), TryFindAzureResourceForSubscriptionAsync);
if (response != null)
{

View File

@@ -69,7 +69,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core
await HandleRequest(requestHandler, tokenExpiredHandler, requestContext, "HandleCreateFirewallRuleRequest");
}
private async Task<CreateFirewallRuleResponse> DoHandleCreateFirewallRuleRequest(CreateFirewallRuleParams firewallRule)
private async Task<CreateFirewallRuleResponse> DoHandleCreateFirewallRuleRequest(CreateFirewallRuleParams firewallRuleParams)
{
var result = new CreateFirewallRuleResponse();
// Note: currently not catching the exception. Expect the caller to this message to handle error cases by
@@ -77,11 +77,11 @@ namespace Microsoft.SqlTools.ResourceProvider.Core
try
{
AuthenticationService authService = ServiceProvider.GetService<AuthenticationService>();
IUserAccount account = await authService.SetCurrentAccountAsync(firewallRule.Account, firewallRule.SecurityTokenMappings);
FirewallRuleResponse response = await firewallRuleService.CreateFirewallRuleAsync(firewallRule.ServerName, firewallRule.StartIpAddress, firewallRule.EndIpAddress);
IUserAccount account = await authService.SetCurrentAccountAsync(firewallRuleParams.Account, firewallRuleParams.SecurityTokenMappings);
FirewallRuleResponse response = await firewallRuleService.CreateFirewallRuleAsync(firewallRuleParams);
result.Result = true;
}
catch(FirewallRuleException ex)
catch (FirewallRuleException ex)
{
result.Result = false;
result.ErrorMessage = ex.Message;
@@ -120,7 +120,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core
T result = await handler();
await requestContext.SendResult(result);
}
catch(ExpiredTokenException ex)
catch (ExpiredTokenException ex)
{
if (expiredTokenHandler != null)
{