Allow the arm baseURI to be set dynamically (#920)

* Allow the arm baseURI to be set dynamically

* Defensive programming
This commit is contained in:
Amir Omidi
2020-03-23 14:41:45 -07:00
committed by GitHub
parent 6920c34570
commit 50a666f794
5 changed files with 111 additions and 21 deletions

View File

@@ -4,6 +4,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using Microsoft.SqlTools.ResourceProvider.Core.Contracts;
namespace Microsoft.SqlTools.ResourceProvider.Core.Authentication namespace Microsoft.SqlTools.ResourceProvider.Core.Authentication
{ {
@@ -12,6 +13,11 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Authentication
/// </summary> /// </summary>
public interface IAzureUserAccount : IEquatable<IAzureUserAccount>, IUserAccount public interface IAzureUserAccount : IEquatable<IAzureUserAccount>, IUserAccount
{ {
Account UnderlyingAccount
{
get;
set;
}
/// <summary> /// <summary>
/// User Account Display Info /// User Account Display Info
/// </summary> /// </summary>

View File

@@ -31,7 +31,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
/// Indicates if the account needs refreshing /// Indicates if the account needs refreshing
/// </summary> /// </summary>
public bool IsStale { get; set; } public bool IsStale { get; set; }
} }
/// <summary> /// <summary>
@@ -58,9 +58,57 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
get; get;
set; set;
} }
/// <summary>
/// Information about the auth provider
/// </summary>
public ProviderSettings ProviderSettings;
} }
public class ProviderSettings
{
/// <summary>
/// Display name of the provider
/// </summary>
public string DisplayName;
/// <summary>
/// ID of the provider
/// </summary>
public string Id;
/// <summary>
/// Settings for the provider itself
/// </summary>
public ProviderSettingsObject Settings;
}
public class ProviderSettingsObject
{
public ResourceSetting ArmResource;
public ResourceSetting GraphResource;
public ResourceSetting OssRdbmsResource;
public ResourceSetting SqlResource;
/// <summary>
/// Actual sign in link
/// </summary>
public string Host;
/// <summary>
/// ClientID used
/// </summary>
public string ClientId;
}
public class ResourceSetting
{
/// <summary>
/// Endpoint of the resource
/// </summary>
public string Endpoint;
public string Id;
}
/// <summary> /// <summary>
/// Represents a key that identifies an account. /// Represents a key that identifies an account.
/// </summary> /// </summary>
@@ -87,7 +135,6 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
/// <summary> /// <summary>
/// A display name that offers context for the account, such as "Contoso". /// A display name that offers context for the account, such as "Contoso".
/// </summary> /// </summary>
public string ContextualDisplayName { get; set; } public string ContextualDisplayName { get; set; }
// Note: ignoring ContextualLogo as it's not needed // Note: ignoring ContextualLogo as it's not needed

View File

@@ -97,6 +97,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
userAccount.DisplayInfo = ToDisplayInfo(account); userAccount.DisplayInfo = ToDisplayInfo(account);
userAccount.NeedsReauthentication = account.IsStale; userAccount.NeedsReauthentication = account.IsStale;
userAccount.AllTenants = ProcessTenants(accountTokenWrapper, account); userAccount.AllTenants = ProcessTenants(accountTokenWrapper, account);
userAccount.UnderlyingAccount = account;
return userAccount; return userAccount;
} }

View File

@@ -39,7 +39,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
{ {
private readonly Uri _resourceManagementUri = new Uri("https://management.azure.com/"); private readonly Uri _resourceManagementUri = new Uri("https://management.azure.com/");
private const string ExpiredTokenCode = "ExpiredAuthenticationToken"; private const string ExpiredTokenCode = "ExpiredAuthenticationToken";
public AzureResourceManager() public AzureResourceManager()
{ {
// Duplicate the exportable attribute as at present we do not support filtering using extensiondescriptor. // Duplicate the exportable attribute as at present we do not support filtering using extensiondescriptor.
@@ -55,12 +55,26 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
CommonUtil.CheckForNull(subscriptionContext, "subscriptionContext"); CommonUtil.CheckForNull(subscriptionContext, "subscriptionContext");
try try
{ {
string armEndpoint = subscriptionContext.UserAccount.UnderlyingAccount.Properties.ProviderSettings?.Settings?.ArmResource?.Endpoint;
Uri armUri = null;
if (armEndpoint != null)
{
try
{
armUri = new Uri(armEndpoint);
}
catch (Exception e)
{
Console.WriteLine($"Exception while parsing URI: {e.Message}");
}
}
ServiceClientCredentials credentials = CreateCredentials(subscriptionContext); ServiceClientCredentials credentials = CreateCredentials(subscriptionContext);
SqlManagementClient sqlManagementClient = new SqlManagementClient(credentials) SqlManagementClient sqlManagementClient = new SqlManagementClient(armUri ?? _resourceManagementUri, credentials)
{ {
SubscriptionId = subscriptionContext.Subscription.SubscriptionId SubscriptionId = subscriptionContext.Subscription.SubscriptionId
}; };
ResourceManagementClient resourceManagementClient = new ResourceManagementClient(credentials)
ResourceManagementClient resourceManagementClient = new ResourceManagementClient(armUri ?? _resourceManagementUri, credentials)
{ {
SubscriptionId = subscriptionContext.Subscription.SubscriptionId SubscriptionId = subscriptionContext.Subscription.SubscriptionId
}; };
@@ -81,7 +95,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
/// <param name="serverName">Server name</param> /// <param name="serverName">Server name</param>
/// <returns>The list of databases</returns> /// <returns>The list of databases</returns>
public async Task<IEnumerable<IAzureResource>> GetAzureDatabasesAsync( public async Task<IEnumerable<IAzureResource>> GetAzureDatabasesAsync(
IAzureResourceManagementSession azureResourceManagementSession, IAzureResourceManagementSession azureResourceManagementSession,
string resourceGroupName, string resourceGroupName,
string serverName) string serverName)
{ {
@@ -115,12 +129,12 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
public async Task<IEnumerable<IAzureSqlServerResource>> GetSqlServerAzureResourcesAsync( public async Task<IEnumerable<IAzureSqlServerResource>> GetSqlServerAzureResourcesAsync(
IAzureResourceManagementSession azureResourceManagementSession) IAzureResourceManagementSession azureResourceManagementSession)
{ {
CommonUtil.CheckForNull(azureResourceManagementSession, "azureResourceManagerSession"); CommonUtil.CheckForNull(azureResourceManagementSession, "azureResourceManagerSession");
List<IAzureSqlServerResource> sqlServers = new List<IAzureSqlServerResource>(); List<IAzureSqlServerResource> sqlServers = new List<IAzureSqlServerResource>();
try try
{ {
AzureResourceManagementSession vsAzureResourceManagementSession = azureResourceManagementSession as AzureResourceManagementSession; AzureResourceManagementSession vsAzureResourceManagementSession = azureResourceManagementSession as AzureResourceManagementSession;
if(vsAzureResourceManagementSession != null) if (vsAzureResourceManagementSession != null)
{ {
IServersOperations serverOperations = vsAzureResourceManagementSession.SqlManagementClient.Servers; IServersOperations serverOperations = vsAzureResourceManagementSession.SqlManagementClient.Servers;
IPage<Server> servers = await ExecuteCloudRequest( IPage<Server> servers = await ExecuteCloudRequest(
@@ -128,7 +142,8 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
SR.FailedToGetAzureSqlServersWithError); SR.FailedToGetAzureSqlServersWithError);
if (servers != null) if (servers != null)
{ {
sqlServers.AddRange(servers.Select(server => { sqlServers.AddRange(servers.Select(server =>
{
var serverResource = new SqlAzureResource(server); var serverResource = new SqlAzureResource(server);
// TODO ResourceGroup name // TODO ResourceGroup name
return serverResource; return serverResource;
@@ -136,9 +151,9 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
} }
} }
} }
catch(Exception ex) catch (Exception ex)
{ {
TraceException(TraceEventType.Error, (int) TraceId.AzureResource, ex, "Failed to get servers"); TraceException(TraceEventType.Error, (int)TraceId.AzureResource, ex, "Failed to get servers");
throw; throw;
} }
@@ -147,7 +162,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
public async Task<FirewallRuleResponse> CreateFirewallRuleAsync( public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(
IAzureResourceManagementSession azureResourceManagementSession, IAzureResourceManagementSession azureResourceManagementSession,
IAzureSqlServerResource azureSqlServer, IAzureSqlServerResource azureSqlServer,
FirewallRuleRequest firewallRuleRequest) FirewallRuleRequest firewallRuleRequest)
{ {
CommonUtil.CheckForNull(azureResourceManagementSession, "azureResourceManagerSession"); CommonUtil.CheckForNull(azureResourceManagementSession, "azureResourceManagerSession");
@@ -183,23 +198,23 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
}; };
} }
// else respond with failure case // else respond with failure case
return new FirewallRuleResponse() return new FirewallRuleResponse()
{ {
Created = false Created = false
}; };
} }
catch (Exception ex) catch (Exception ex)
{ {
TraceException(TraceEventType.Error, (int) TraceId.AzureResource, ex, "Failed to create firewall rule"); TraceException(TraceEventType.Error, (int)TraceId.AzureResource, ex, "Failed to create firewall rule");
throw; throw;
} }
} }
private Dictionary<string,List<string>> GetCustomHeaders() private Dictionary<string, List<string>> GetCustomHeaders()
{ {
// For some unknown reason the firewall rule method defaults to returning XML. Fixes this by adding an Accept header // For some unknown reason the firewall rule method defaults to returning XML. Fixes this by adding an Accept header
// ensuring it's always JSON // ensuring it's always JSON
var headers = new Dictionary<string,List<string>>(); var headers = new Dictionary<string, List<string>>();
headers["Accept"] = new List<string>() { headers["Accept"] = new List<string>() {
"application/json" "application/json"
}; };
@@ -217,7 +232,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
stopwatch.Start(); stopwatch.Start();
ServiceResponse<IAzureUserAccountSubscriptionContext> response = await AzureUtil.ExecuteGetAzureResourceAsParallel( ServiceResponse<IAzureUserAccountSubscriptionContext> response = await AzureUtil.ExecuteGetAzureResourceAsParallel(
userAccount, userAccount.AllTenants, string.Empty, CancellationToken.None, GetSubscriptionsForTentantAsync); userAccount, userAccount.AllTenants, string.Empty, CancellationToken.None, GetSubscriptionsForTentantAsync);
if (response.HasError) if (response.HasError)
{ {
var ex = response.Errors.First(); var ex = response.Errors.First();
@@ -239,12 +254,25 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
if (azureTenant != null) if (azureTenant != null)
{ {
ServiceClientCredentials credentials = CreateCredentials(azureTenant); ServiceClientCredentials credentials = CreateCredentials(azureTenant);
using (SubscriptionClient client = new SubscriptionClient(_resourceManagementUri, credentials)) string armEndpoint = userAccount.UnderlyingAccount.Properties.ProviderSettings?.Settings?.ArmResource?.Endpoint;
Uri armUri = null;
if (armEndpoint != null)
{
try
{
armUri = new Uri(armEndpoint);
}
catch (Exception e)
{
Console.WriteLine($"Exception while parsing URI: {e.Message}");
}
}
using (SubscriptionClient client = new SubscriptionClient(armUri ?? _resourceManagementUri, credentials))
{ {
IEnumerable<Subscription> subs = await GetSubscriptionsAsync(client); IEnumerable<Subscription> subs = await GetSubscriptionsAsync(client);
return new ServiceResponse<IAzureUserAccountSubscriptionContext>(subs.Select(sub => return new ServiceResponse<IAzureUserAccountSubscriptionContext>(subs.Select(sub =>
{ {
AzureSubscriptionIdentifier subId = new AzureSubscriptionIdentifier(userAccount, azureTenant.TenantId, sub.SubscriptionId, _resourceManagementUri); AzureSubscriptionIdentifier subId = new AzureSubscriptionIdentifier(userAccount, azureTenant.TenantId, sub.SubscriptionId, armUri ?? _resourceManagementUri);
AzureUserAccountSubscriptionContext context = new AzureUserAccountSubscriptionContext(subId, credentials); AzureUserAccountSubscriptionContext context = new AzureUserAccountSubscriptionContext(subId, credentials);
return context; return context;
})); }));
@@ -325,7 +353,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
{ {
return await operation(); return await operation();
} }
catch(CloudException ex) catch (CloudException ex)
{ {
if (ex.Body != null && string.Equals(ExpiredTokenCode, ex.Body.Code, StringComparison.OrdinalIgnoreCase)) if (ex.Body != null && string.Equals(ExpiredTokenCode, ex.Body.Code, StringComparison.OrdinalIgnoreCase))
{ {

View File

@@ -6,6 +6,7 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using Microsoft.SqlTools.ResourceProvider.Core; using Microsoft.SqlTools.ResourceProvider.Core;
using Microsoft.SqlTools.ResourceProvider.Core.Authentication; using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
using Microsoft.SqlTools.ResourceProvider.Core.Contracts;
namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
{ {
@@ -39,6 +40,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
this.TenantId = azureUserAccount.TenantId; this.TenantId = azureUserAccount.TenantId;
this.AllTenants = azureUserAccount.AllTenants; this.AllTenants = azureUserAccount.AllTenants;
this.UniqueId = azureUserAccount.UniqueId; this.UniqueId = azureUserAccount.UniqueId;
this.UnderlyingAccount = azureUserAccount.UnderlyingAccount;
AzureUserAccount account = azureUserAccount as AzureUserAccount; AzureUserAccount account = azureUserAccount as AzureUserAccount;
} }
/// <summary> /// <summary>
@@ -99,5 +101,11 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
get; get;
set; set;
} }
public Account UnderlyingAccount
{
get;
set;
}
} }
} }