mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-02-16 10:58:30 -05:00
Add support for firewall rule name in request (#1791)
This commit is contained in:
@@ -31,7 +31,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
|
|||||||
/// Per-tenant token mappings. Ideally would be set independently of this call, but for
|
/// Per-tenant token mappings. Ideally would be set independently of this call, but for
|
||||||
/// now this allows us to get the tokens necessary to find a server and open a firewall rule
|
/// now this allows us to get the tokens necessary to find a server and open a firewall rule
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public Dictionary<string,AccountSecurityToken> SecurityTokenMappings { get; set; }
|
public Dictionary<string, AccountSecurityToken> SecurityTokenMappings { get; set; }
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Fully qualified name of the server to create a new firewall rule on
|
/// Fully qualified name of the server to create a new firewall rule on
|
||||||
@@ -48,6 +48,11 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
|
|||||||
/// </summary>
|
/// </summary>
|
||||||
public string EndIpAddress { get; set; }
|
public string EndIpAddress { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Firewall rule name to set
|
||||||
|
/// </summary>
|
||||||
|
public string FirewallRuleName { get; set; }
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public class CreateFirewallRuleResponse : TokenReliantResponse
|
public class CreateFirewallRuleResponse : TokenReliantResponse
|
||||||
@@ -86,16 +91,15 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
|
|||||||
public class HandleFirewallRuleResponse
|
public class HandleFirewallRuleResponse
|
||||||
{
|
{
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Can this be handled?
|
/// Whether or not request can be handled.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public bool Result { get; set; }
|
public bool Result { get; set; }
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// If not, why?
|
/// Contains error message, if request could not be handled.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public string ErrorMessage { get; set; }
|
public string ErrorMessage { get; set; }
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// If it can be handled, is there a default IP address to send back so users
|
/// If handled, the default IP address to send back; so users can tell what their blocked IP is.
|
||||||
/// can tell what their blocked IP is?
|
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public string IpAddress { get; set; }
|
public string IpAddress { get; set; }
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,9 +2,6 @@
|
|||||||
// Copyright (c) Microsoft. All rights reserved.
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
//
|
//
|
||||||
|
|
||||||
using System;
|
|
||||||
using System.Globalization;
|
|
||||||
using System.Net;
|
using System.Net;
|
||||||
|
|
||||||
namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
|
namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
|
||||||
@@ -27,15 +24,6 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
|
|||||||
/// <summary>
|
/// <summary>
|
||||||
/// Firewall rule name
|
/// Firewall rule name
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public string FirewallRuleName
|
public string FirewallRuleName { get; set; }
|
||||||
{
|
|
||||||
get
|
|
||||||
{
|
|
||||||
DateTime now = DateTime.UtcNow;
|
|
||||||
|
|
||||||
return string.Format(CultureInfo.InvariantCulture, "ClientIPAddress_{0}",
|
|
||||||
now.ToString("yyyy-MM-dd_hh:mm:ss", CultureInfo.CurrentCulture));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ using System.Threading;
|
|||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
using Microsoft.SqlTools.Extensibility;
|
using Microsoft.SqlTools.Extensibility;
|
||||||
using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
|
using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
|
||||||
|
using Microsoft.SqlTools.ResourceProvider.Core.Contracts;
|
||||||
|
|
||||||
namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
|
namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
|
||||||
{
|
{
|
||||||
@@ -21,13 +22,13 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
|
|||||||
/// <summary>
|
/// <summary>
|
||||||
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
|
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
|
||||||
/// </summary>
|
/// </summary>
|
||||||
Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, string startIpAddressValue, string endIpAddressValue);
|
Task<FirewallRuleResponse> CreateFirewallRuleAsync(CreateFirewallRuleParams firewallRuleParams);
|
||||||
|
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
|
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
|
||||||
/// </summary>
|
/// </summary>
|
||||||
Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, IPAddress startIpAddress, IPAddress endIpAddress);
|
Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, FirewallRuleRequest firewallRuleRequest);
|
||||||
|
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
@@ -55,31 +56,40 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
|
|||||||
/// <summary>
|
/// <summary>
|
||||||
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
|
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, string startIpAddressValue, string endIpAddressValue)
|
public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(CreateFirewallRuleParams firewallRuleParams)
|
||||||
{
|
{
|
||||||
IPAddress startIpAddress = ConvertToIpAddress(startIpAddressValue);
|
IPAddress startIpAddress = ConvertToIpAddress(firewallRuleParams.StartIpAddress);
|
||||||
IPAddress endIpAddress = ConvertToIpAddress(endIpAddressValue);
|
IPAddress endIpAddress = ConvertToIpAddress(firewallRuleParams.EndIpAddress);
|
||||||
return await CreateFirewallRuleAsync(serverName, startIpAddress, endIpAddress);
|
FirewallRuleRequest firewallRuleRequest = new FirewallRuleRequest()
|
||||||
|
{
|
||||||
|
StartIpAddress = ConvertToIpAddress(firewallRuleParams.StartIpAddress),
|
||||||
|
EndIpAddress = ConvertToIpAddress(firewallRuleParams.EndIpAddress),
|
||||||
|
FirewallRuleName = string.Format(CultureInfo.InvariantCulture, firewallRuleParams.FirewallRuleName ?? "ClientIPAddress_{0}",
|
||||||
|
DateTime.UtcNow.ToString("yyyy-MM-dd_hh:mm:ss", CultureInfo.CurrentCulture))
|
||||||
|
};
|
||||||
|
return await CreateFirewallRuleAsync(firewallRuleParams.ServerName, firewallRuleRequest);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
|
/// Creates firewall rule for given server name and IP address range. Throws exception if operation fails
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, IPAddress startIpAddress, IPAddress endIpAddress)
|
public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(string serverName, FirewallRuleRequest firewallRuleRequest)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
FirewallRuleResponse firewallRuleResponse = new FirewallRuleResponse() { Created = false };
|
FirewallRuleResponse firewallRuleResponse = new FirewallRuleResponse() { Created = false };
|
||||||
CommonUtil.CheckStringForNullOrEmpty(serverName, "serverName");
|
CommonUtil.CheckStringForNullOrEmpty(serverName, nameof(serverName));
|
||||||
CommonUtil.CheckForNull(startIpAddress, "startIpAddress");
|
CommonUtil.CheckForNull(firewallRuleRequest, nameof(firewallRuleRequest));
|
||||||
CommonUtil.CheckForNull(endIpAddress, "endIpAddress");
|
CommonUtil.CheckForNull(firewallRuleRequest.FirewallRuleName, nameof(firewallRuleRequest.FirewallRuleName));
|
||||||
|
CommonUtil.CheckForNull(firewallRuleRequest.StartIpAddress, nameof(firewallRuleRequest.StartIpAddress));
|
||||||
|
CommonUtil.CheckForNull(firewallRuleRequest.EndIpAddress, nameof(firewallRuleRequest.EndIpAddress));
|
||||||
|
|
||||||
IAzureAuthenticationManager authenticationManager = AuthenticationManager;
|
IAzureAuthenticationManager authenticationManager = AuthenticationManager;
|
||||||
|
|
||||||
if (authenticationManager != null && !await authenticationManager.GetUserNeedsReauthenticationAsync())
|
if (authenticationManager != null && !await authenticationManager.GetUserNeedsReauthenticationAsync())
|
||||||
{
|
{
|
||||||
FirewallRuleResource firewallRuleResource = await FindAzureResourceAsync(serverName);
|
FirewallRuleResource firewallRuleResource = await FindAzureResourceAsync(serverName);
|
||||||
firewallRuleResponse = await CreateFirewallRule(firewallRuleResource, startIpAddress, endIpAddress);
|
firewallRuleResponse = await CreateFirewallRule(firewallRuleResource, firewallRuleRequest);
|
||||||
}
|
}
|
||||||
if (firewallRuleResponse == null || !firewallRuleResponse.Created)
|
if (firewallRuleResponse == null || !firewallRuleResponse.Created)
|
||||||
{
|
{
|
||||||
@@ -110,7 +120,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
|
|||||||
/// <summary>
|
/// <summary>
|
||||||
/// Creates firewall rule for given subscription and IP address range
|
/// Creates firewall rule for given subscription and IP address range
|
||||||
/// </summary>
|
/// </summary>
|
||||||
private async Task<FirewallRuleResponse> CreateFirewallRule(FirewallRuleResource firewallRuleResource, IPAddress startIpAddress, IPAddress endIpAddress)
|
private async Task<FirewallRuleResponse> CreateFirewallRule(FirewallRuleResource firewallRuleResource, FirewallRuleRequest firewallRuleRequest)
|
||||||
{
|
{
|
||||||
CommonUtil.CheckForNull(firewallRuleResource, "firewallRuleResource");
|
CommonUtil.CheckForNull(firewallRuleResource, "firewallRuleResource");
|
||||||
|
|
||||||
@@ -118,18 +128,12 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
|
|||||||
{
|
{
|
||||||
if (firewallRuleResource.IsValid)
|
if (firewallRuleResource.IsValid)
|
||||||
{
|
{
|
||||||
|
|
||||||
FirewallRuleRequest request = new FirewallRuleRequest()
|
|
||||||
{
|
|
||||||
StartIpAddress = startIpAddress,
|
|
||||||
EndIpAddress = endIpAddress
|
|
||||||
};
|
|
||||||
using (IAzureResourceManagementSession session = await ResourceManager.CreateSessionAsync(firewallRuleResource.SubscriptionContext))
|
using (IAzureResourceManagementSession session = await ResourceManager.CreateSessionAsync(firewallRuleResource.SubscriptionContext))
|
||||||
{
|
{
|
||||||
return await ResourceManager.CreateFirewallRuleAsync(
|
return await ResourceManager.CreateFirewallRuleAsync(
|
||||||
session,
|
session,
|
||||||
firewallRuleResource.AzureResource,
|
firewallRuleResource.AzureResource,
|
||||||
request);
|
firewallRuleRequest);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -162,8 +166,8 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall
|
|||||||
throw new FirewallRuleException(SR.NoSubscriptionsFound);
|
throw new FirewallRuleException(SR.NoSubscriptionsFound);
|
||||||
}
|
}
|
||||||
|
|
||||||
ServiceResponse<FirewallRuleResource> response = await AzureUtil.ExecuteGetAzureResourceAsParallel((object)null,
|
ServiceResponse<FirewallRuleResource> response = await AzureUtil.ExecuteGetAzureResourceAsParallel((object)null,
|
||||||
subscriptions, serverName, new CancellationToken(), TryFindAzureResourceForSubscriptionAsync);
|
subscriptions, serverName, new CancellationToken(), TryFindAzureResourceForSubscriptionAsync);
|
||||||
|
|
||||||
if (response != null)
|
if (response != null)
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core
|
|||||||
await HandleRequest(requestHandler, tokenExpiredHandler, requestContext, "HandleCreateFirewallRuleRequest");
|
await HandleRequest(requestHandler, tokenExpiredHandler, requestContext, "HandleCreateFirewallRuleRequest");
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task<CreateFirewallRuleResponse> DoHandleCreateFirewallRuleRequest(CreateFirewallRuleParams firewallRule)
|
private async Task<CreateFirewallRuleResponse> DoHandleCreateFirewallRuleRequest(CreateFirewallRuleParams firewallRuleParams)
|
||||||
{
|
{
|
||||||
var result = new CreateFirewallRuleResponse();
|
var result = new CreateFirewallRuleResponse();
|
||||||
// Note: currently not catching the exception. Expect the caller to this message to handle error cases by
|
// Note: currently not catching the exception. Expect the caller to this message to handle error cases by
|
||||||
@@ -77,11 +77,11 @@ namespace Microsoft.SqlTools.ResourceProvider.Core
|
|||||||
try
|
try
|
||||||
{
|
{
|
||||||
AuthenticationService authService = ServiceProvider.GetService<AuthenticationService>();
|
AuthenticationService authService = ServiceProvider.GetService<AuthenticationService>();
|
||||||
IUserAccount account = await authService.SetCurrentAccountAsync(firewallRule.Account, firewallRule.SecurityTokenMappings);
|
IUserAccount account = await authService.SetCurrentAccountAsync(firewallRuleParams.Account, firewallRuleParams.SecurityTokenMappings);
|
||||||
FirewallRuleResponse response = await firewallRuleService.CreateFirewallRuleAsync(firewallRule.ServerName, firewallRule.StartIpAddress, firewallRule.EndIpAddress);
|
FirewallRuleResponse response = await firewallRuleService.CreateFirewallRuleAsync(firewallRuleParams);
|
||||||
result.Result = true;
|
result.Result = true;
|
||||||
}
|
}
|
||||||
catch(FirewallRuleException ex)
|
catch (FirewallRuleException ex)
|
||||||
{
|
{
|
||||||
result.Result = false;
|
result.Result = false;
|
||||||
result.ErrorMessage = ex.Message;
|
result.ErrorMessage = ex.Message;
|
||||||
@@ -120,7 +120,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core
|
|||||||
T result = await handler();
|
T result = await handler();
|
||||||
await requestContext.SendResult(result);
|
await requestContext.SendResult(result);
|
||||||
}
|
}
|
||||||
catch(ExpiredTokenException ex)
|
catch (ExpiredTokenException ex)
|
||||||
{
|
{
|
||||||
if (expiredTokenHandler != null)
|
if (expiredTokenHandler != null)
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ using System.Linq;
|
|||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
using Microsoft.SqlTools.ResourceProvider.Core;
|
using Microsoft.SqlTools.ResourceProvider.Core;
|
||||||
using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
|
using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
|
||||||
|
using Microsoft.SqlTools.ResourceProvider.Core.Contracts;
|
||||||
using Microsoft.SqlTools.ResourceProvider.Core.Firewall;
|
using Microsoft.SqlTools.ResourceProvider.Core.Firewall;
|
||||||
using Moq;
|
using Moq;
|
||||||
using NUnit.Framework;
|
using NUnit.Framework;
|
||||||
@@ -309,9 +310,15 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ResourceProvider
|
|||||||
try
|
try
|
||||||
{
|
{
|
||||||
FirewallRuleService service = new FirewallRuleService();
|
FirewallRuleService service = new FirewallRuleService();
|
||||||
|
CreateFirewallRuleParams createFirewallRuleParams = new CreateFirewallRuleParams()
|
||||||
|
{
|
||||||
|
ServerName = serverName,
|
||||||
|
StartIpAddress = testContext.StartIpAddress,
|
||||||
|
EndIpAddress = testContext.EndIpAddress
|
||||||
|
};
|
||||||
service.AuthenticationManager = testContext.ApplicationAuthenticationManager;
|
service.AuthenticationManager = testContext.ApplicationAuthenticationManager;
|
||||||
service.ResourceManager = testContext.AzureResourceManager;
|
service.ResourceManager = testContext.AzureResourceManager;
|
||||||
FirewallRuleResponse response = await service.CreateFirewallRuleAsync(serverName, testContext.StartIpAddress, testContext.EndIpAddress);
|
FirewallRuleResponse response = await service.CreateFirewallRuleAsync(createFirewallRuleParams);
|
||||||
if (verifyFirewallRuleCreated)
|
if (verifyFirewallRuleCreated)
|
||||||
{
|
{
|
||||||
testContext.AzureResourceManagerMock.Verify(x => x.CreateFirewallRuleAsync(
|
testContext.AzureResourceManagerMock.Verify(x => x.CreateFirewallRuleAsync(
|
||||||
|
|||||||
Reference in New Issue
Block a user