mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-02-04 01:25:43 -05:00
Allow the arm baseURI to be set dynamically (#920)
* Allow the arm baseURI to be set dynamically * Defensive programming
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using Microsoft.SqlTools.ResourceProvider.Core.Contracts;
|
||||
|
||||
namespace Microsoft.SqlTools.ResourceProvider.Core.Authentication
|
||||
{
|
||||
@@ -12,6 +13,11 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Authentication
|
||||
/// </summary>
|
||||
public interface IAzureUserAccount : IEquatable<IAzureUserAccount>, IUserAccount
|
||||
{
|
||||
Account UnderlyingAccount
|
||||
{
|
||||
get;
|
||||
set;
|
||||
}
|
||||
/// <summary>
|
||||
/// User Account Display Info
|
||||
/// </summary>
|
||||
|
||||
@@ -31,7 +31,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
|
||||
/// Indicates if the account needs refreshing
|
||||
/// </summary>
|
||||
public bool IsStale { get; set; }
|
||||
|
||||
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -58,9 +58,57 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
|
||||
get;
|
||||
set;
|
||||
}
|
||||
/// <summary>
|
||||
/// Information about the auth provider
|
||||
/// </summary>
|
||||
public ProviderSettings ProviderSettings;
|
||||
|
||||
}
|
||||
|
||||
public class ProviderSettings
|
||||
{
|
||||
/// <summary>
|
||||
/// Display name of the provider
|
||||
/// </summary>
|
||||
public string DisplayName;
|
||||
|
||||
/// <summary>
|
||||
/// ID of the provider
|
||||
/// </summary>
|
||||
public string Id;
|
||||
/// <summary>
|
||||
/// Settings for the provider itself
|
||||
/// </summary>
|
||||
public ProviderSettingsObject Settings;
|
||||
}
|
||||
|
||||
public class ProviderSettingsObject
|
||||
{
|
||||
public ResourceSetting ArmResource;
|
||||
public ResourceSetting GraphResource;
|
||||
public ResourceSetting OssRdbmsResource;
|
||||
public ResourceSetting SqlResource;
|
||||
|
||||
/// <summary>
|
||||
/// Actual sign in link
|
||||
/// </summary>
|
||||
public string Host;
|
||||
/// <summary>
|
||||
/// ClientID used
|
||||
/// </summary>
|
||||
public string ClientId;
|
||||
}
|
||||
|
||||
public class ResourceSetting
|
||||
{
|
||||
|
||||
/// <summary>
|
||||
/// Endpoint of the resource
|
||||
/// </summary>
|
||||
public string Endpoint;
|
||||
public string Id;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Represents a key that identifies an account.
|
||||
/// </summary>
|
||||
@@ -87,7 +135,6 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts
|
||||
/// <summary>
|
||||
/// A display name that offers context for the account, such as "Contoso".
|
||||
/// </summary>
|
||||
|
||||
public string ContextualDisplayName { get; set; }
|
||||
|
||||
// Note: ignoring ContextualLogo as it's not needed
|
||||
|
||||
@@ -97,6 +97,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
userAccount.DisplayInfo = ToDisplayInfo(account);
|
||||
userAccount.NeedsReauthentication = account.IsStale;
|
||||
userAccount.AllTenants = ProcessTenants(accountTokenWrapper, account);
|
||||
userAccount.UnderlyingAccount = account;
|
||||
return userAccount;
|
||||
}
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
{
|
||||
private readonly Uri _resourceManagementUri = new Uri("https://management.azure.com/");
|
||||
private const string ExpiredTokenCode = "ExpiredAuthenticationToken";
|
||||
|
||||
|
||||
public AzureResourceManager()
|
||||
{
|
||||
// Duplicate the exportable attribute as at present we do not support filtering using extensiondescriptor.
|
||||
@@ -55,12 +55,26 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
CommonUtil.CheckForNull(subscriptionContext, "subscriptionContext");
|
||||
try
|
||||
{
|
||||
string armEndpoint = subscriptionContext.UserAccount.UnderlyingAccount.Properties.ProviderSettings?.Settings?.ArmResource?.Endpoint;
|
||||
Uri armUri = null;
|
||||
if (armEndpoint != null)
|
||||
{
|
||||
try
|
||||
{
|
||||
armUri = new Uri(armEndpoint);
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
Console.WriteLine($"Exception while parsing URI: {e.Message}");
|
||||
}
|
||||
}
|
||||
ServiceClientCredentials credentials = CreateCredentials(subscriptionContext);
|
||||
SqlManagementClient sqlManagementClient = new SqlManagementClient(credentials)
|
||||
SqlManagementClient sqlManagementClient = new SqlManagementClient(armUri ?? _resourceManagementUri, credentials)
|
||||
{
|
||||
SubscriptionId = subscriptionContext.Subscription.SubscriptionId
|
||||
};
|
||||
ResourceManagementClient resourceManagementClient = new ResourceManagementClient(credentials)
|
||||
|
||||
ResourceManagementClient resourceManagementClient = new ResourceManagementClient(armUri ?? _resourceManagementUri, credentials)
|
||||
{
|
||||
SubscriptionId = subscriptionContext.Subscription.SubscriptionId
|
||||
};
|
||||
@@ -81,7 +95,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
/// <param name="serverName">Server name</param>
|
||||
/// <returns>The list of databases</returns>
|
||||
public async Task<IEnumerable<IAzureResource>> GetAzureDatabasesAsync(
|
||||
IAzureResourceManagementSession azureResourceManagementSession,
|
||||
IAzureResourceManagementSession azureResourceManagementSession,
|
||||
string resourceGroupName,
|
||||
string serverName)
|
||||
{
|
||||
@@ -115,12 +129,12 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
public async Task<IEnumerable<IAzureSqlServerResource>> GetSqlServerAzureResourcesAsync(
|
||||
IAzureResourceManagementSession azureResourceManagementSession)
|
||||
{
|
||||
CommonUtil.CheckForNull(azureResourceManagementSession, "azureResourceManagerSession");
|
||||
CommonUtil.CheckForNull(azureResourceManagementSession, "azureResourceManagerSession");
|
||||
List<IAzureSqlServerResource> sqlServers = new List<IAzureSqlServerResource>();
|
||||
try
|
||||
{
|
||||
AzureResourceManagementSession vsAzureResourceManagementSession = azureResourceManagementSession as AzureResourceManagementSession;
|
||||
if(vsAzureResourceManagementSession != null)
|
||||
if (vsAzureResourceManagementSession != null)
|
||||
{
|
||||
IServersOperations serverOperations = vsAzureResourceManagementSession.SqlManagementClient.Servers;
|
||||
IPage<Server> servers = await ExecuteCloudRequest(
|
||||
@@ -128,7 +142,8 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
SR.FailedToGetAzureSqlServersWithError);
|
||||
if (servers != null)
|
||||
{
|
||||
sqlServers.AddRange(servers.Select(server => {
|
||||
sqlServers.AddRange(servers.Select(server =>
|
||||
{
|
||||
var serverResource = new SqlAzureResource(server);
|
||||
// TODO ResourceGroup name
|
||||
return serverResource;
|
||||
@@ -136,9 +151,9 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
}
|
||||
}
|
||||
}
|
||||
catch(Exception ex)
|
||||
catch (Exception ex)
|
||||
{
|
||||
TraceException(TraceEventType.Error, (int) TraceId.AzureResource, ex, "Failed to get servers");
|
||||
TraceException(TraceEventType.Error, (int)TraceId.AzureResource, ex, "Failed to get servers");
|
||||
throw;
|
||||
}
|
||||
|
||||
@@ -147,7 +162,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
|
||||
public async Task<FirewallRuleResponse> CreateFirewallRuleAsync(
|
||||
IAzureResourceManagementSession azureResourceManagementSession,
|
||||
IAzureSqlServerResource azureSqlServer,
|
||||
IAzureSqlServerResource azureSqlServer,
|
||||
FirewallRuleRequest firewallRuleRequest)
|
||||
{
|
||||
CommonUtil.CheckForNull(azureResourceManagementSession, "azureResourceManagerSession");
|
||||
@@ -183,23 +198,23 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
};
|
||||
}
|
||||
// else respond with failure case
|
||||
return new FirewallRuleResponse()
|
||||
{
|
||||
return new FirewallRuleResponse()
|
||||
{
|
||||
Created = false
|
||||
};
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
TraceException(TraceEventType.Error, (int) TraceId.AzureResource, ex, "Failed to create firewall rule");
|
||||
TraceException(TraceEventType.Error, (int)TraceId.AzureResource, ex, "Failed to create firewall rule");
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
private Dictionary<string,List<string>> GetCustomHeaders()
|
||||
private Dictionary<string, List<string>> GetCustomHeaders()
|
||||
{
|
||||
// For some unknown reason the firewall rule method defaults to returning XML. Fixes this by adding an Accept header
|
||||
// ensuring it's always JSON
|
||||
var headers = new Dictionary<string,List<string>>();
|
||||
var headers = new Dictionary<string, List<string>>();
|
||||
headers["Accept"] = new List<string>() {
|
||||
"application/json"
|
||||
};
|
||||
@@ -217,7 +232,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
stopwatch.Start();
|
||||
ServiceResponse<IAzureUserAccountSubscriptionContext> response = await AzureUtil.ExecuteGetAzureResourceAsParallel(
|
||||
userAccount, userAccount.AllTenants, string.Empty, CancellationToken.None, GetSubscriptionsForTentantAsync);
|
||||
|
||||
|
||||
if (response.HasError)
|
||||
{
|
||||
var ex = response.Errors.First();
|
||||
@@ -239,12 +254,25 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
if (azureTenant != null)
|
||||
{
|
||||
ServiceClientCredentials credentials = CreateCredentials(azureTenant);
|
||||
using (SubscriptionClient client = new SubscriptionClient(_resourceManagementUri, credentials))
|
||||
string armEndpoint = userAccount.UnderlyingAccount.Properties.ProviderSettings?.Settings?.ArmResource?.Endpoint;
|
||||
Uri armUri = null;
|
||||
if (armEndpoint != null)
|
||||
{
|
||||
try
|
||||
{
|
||||
armUri = new Uri(armEndpoint);
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
Console.WriteLine($"Exception while parsing URI: {e.Message}");
|
||||
}
|
||||
}
|
||||
using (SubscriptionClient client = new SubscriptionClient(armUri ?? _resourceManagementUri, credentials))
|
||||
{
|
||||
IEnumerable<Subscription> subs = await GetSubscriptionsAsync(client);
|
||||
return new ServiceResponse<IAzureUserAccountSubscriptionContext>(subs.Select(sub =>
|
||||
{
|
||||
AzureSubscriptionIdentifier subId = new AzureSubscriptionIdentifier(userAccount, azureTenant.TenantId, sub.SubscriptionId, _resourceManagementUri);
|
||||
AzureSubscriptionIdentifier subId = new AzureSubscriptionIdentifier(userAccount, azureTenant.TenantId, sub.SubscriptionId, armUri ?? _resourceManagementUri);
|
||||
AzureUserAccountSubscriptionContext context = new AzureUserAccountSubscriptionContext(subId, credentials);
|
||||
return context;
|
||||
}));
|
||||
@@ -325,7 +353,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
{
|
||||
return await operation();
|
||||
}
|
||||
catch(CloudException ex)
|
||||
catch (CloudException ex)
|
||||
{
|
||||
if (ex.Body != null && string.Equals(ExpiredTokenCode, ex.Body.Code, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
|
||||
@@ -6,6 +6,7 @@ using System;
|
||||
using System.Collections.Generic;
|
||||
using Microsoft.SqlTools.ResourceProvider.Core;
|
||||
using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
|
||||
using Microsoft.SqlTools.ResourceProvider.Core.Contracts;
|
||||
|
||||
namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
{
|
||||
@@ -39,6 +40,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
this.TenantId = azureUserAccount.TenantId;
|
||||
this.AllTenants = azureUserAccount.AllTenants;
|
||||
this.UniqueId = azureUserAccount.UniqueId;
|
||||
this.UnderlyingAccount = azureUserAccount.UnderlyingAccount;
|
||||
AzureUserAccount account = azureUserAccount as AzureUserAccount;
|
||||
}
|
||||
/// <summary>
|
||||
@@ -99,5 +101,11 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
|
||||
get;
|
||||
set;
|
||||
}
|
||||
|
||||
public Account UnderlyingAccount
|
||||
{
|
||||
get;
|
||||
set;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user