diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Authentication/IAzureUserAccount.cs b/src/Microsoft.SqlTools.ResourceProvider.Core/Authentication/IAzureUserAccount.cs index e82befa0..eeebd9cc 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Authentication/IAzureUserAccount.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Authentication/IAzureUserAccount.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using Microsoft.SqlTools.ResourceProvider.Core.Contracts; namespace Microsoft.SqlTools.ResourceProvider.Core.Authentication { @@ -12,6 +13,11 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Authentication /// public interface IAzureUserAccount : IEquatable, IUserAccount { + Account UnderlyingAccount + { + get; + set; + } /// /// User Account Display Info /// diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/Account.cs b/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/Account.cs index 6f85964d..1f0a4312 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/Account.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/Account.cs @@ -31,7 +31,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts /// Indicates if the account needs refreshing /// public bool IsStale { get; set; } - + } /// @@ -58,9 +58,57 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts get; set; } + /// + /// Information about the auth provider + /// + public ProviderSettings ProviderSettings; } + public class ProviderSettings + { + /// + /// Display name of the provider + /// + public string DisplayName; + + /// + /// ID of the provider + /// + public string Id; + /// + /// Settings for the provider itself + /// + public ProviderSettingsObject Settings; + } + + public class ProviderSettingsObject + { + public ResourceSetting ArmResource; + public ResourceSetting GraphResource; + public ResourceSetting OssRdbmsResource; + public ResourceSetting SqlResource; + + /// + /// Actual sign in link + /// + public string Host; + /// + /// ClientID used + /// + public string ClientId; + } + + public class ResourceSetting + { + + /// + /// Endpoint of the resource + /// + public string Endpoint; + public string Id; + } + /// /// Represents a key that identifies an account. /// @@ -87,7 +135,6 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts /// /// A display name that offers context for the account, such as "Contoso". /// - public string ContextualDisplayName { get; set; } // Note: ignoring ContextualLogo as it's not needed diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureAuthenticationManager.cs b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureAuthenticationManager.cs index 7702f6fb..43a07b05 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureAuthenticationManager.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureAuthenticationManager.cs @@ -97,6 +97,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl userAccount.DisplayInfo = ToDisplayInfo(account); userAccount.NeedsReauthentication = account.IsStale; userAccount.AllTenants = ProcessTenants(accountTokenWrapper, account); + userAccount.UnderlyingAccount = account; return userAccount; } diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceManager.cs b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceManager.cs index 824de076..32ec7e3d 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceManager.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceManager.cs @@ -39,7 +39,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { private readonly Uri _resourceManagementUri = new Uri("https://management.azure.com/"); private const string ExpiredTokenCode = "ExpiredAuthenticationToken"; - + public AzureResourceManager() { // Duplicate the exportable attribute as at present we do not support filtering using extensiondescriptor. @@ -55,12 +55,26 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl CommonUtil.CheckForNull(subscriptionContext, "subscriptionContext"); try { + string armEndpoint = subscriptionContext.UserAccount.UnderlyingAccount.Properties.ProviderSettings?.Settings?.ArmResource?.Endpoint; + Uri armUri = null; + if (armEndpoint != null) + { + try + { + armUri = new Uri(armEndpoint); + } + catch (Exception e) + { + Console.WriteLine($"Exception while parsing URI: {e.Message}"); + } + } ServiceClientCredentials credentials = CreateCredentials(subscriptionContext); - SqlManagementClient sqlManagementClient = new SqlManagementClient(credentials) + SqlManagementClient sqlManagementClient = new SqlManagementClient(armUri ?? _resourceManagementUri, credentials) { SubscriptionId = subscriptionContext.Subscription.SubscriptionId }; - ResourceManagementClient resourceManagementClient = new ResourceManagementClient(credentials) + + ResourceManagementClient resourceManagementClient = new ResourceManagementClient(armUri ?? _resourceManagementUri, credentials) { SubscriptionId = subscriptionContext.Subscription.SubscriptionId }; @@ -81,7 +95,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl /// Server name /// The list of databases public async Task> GetAzureDatabasesAsync( - IAzureResourceManagementSession azureResourceManagementSession, + IAzureResourceManagementSession azureResourceManagementSession, string resourceGroupName, string serverName) { @@ -115,12 +129,12 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl public async Task> GetSqlServerAzureResourcesAsync( IAzureResourceManagementSession azureResourceManagementSession) { - CommonUtil.CheckForNull(azureResourceManagementSession, "azureResourceManagerSession"); + CommonUtil.CheckForNull(azureResourceManagementSession, "azureResourceManagerSession"); List sqlServers = new List(); try { AzureResourceManagementSession vsAzureResourceManagementSession = azureResourceManagementSession as AzureResourceManagementSession; - if(vsAzureResourceManagementSession != null) + if (vsAzureResourceManagementSession != null) { IServersOperations serverOperations = vsAzureResourceManagementSession.SqlManagementClient.Servers; IPage servers = await ExecuteCloudRequest( @@ -128,7 +142,8 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl SR.FailedToGetAzureSqlServersWithError); if (servers != null) { - sqlServers.AddRange(servers.Select(server => { + sqlServers.AddRange(servers.Select(server => + { var serverResource = new SqlAzureResource(server); // TODO ResourceGroup name return serverResource; @@ -136,9 +151,9 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl } } } - catch(Exception ex) + catch (Exception ex) { - TraceException(TraceEventType.Error, (int) TraceId.AzureResource, ex, "Failed to get servers"); + TraceException(TraceEventType.Error, (int)TraceId.AzureResource, ex, "Failed to get servers"); throw; } @@ -147,7 +162,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl public async Task CreateFirewallRuleAsync( IAzureResourceManagementSession azureResourceManagementSession, - IAzureSqlServerResource azureSqlServer, + IAzureSqlServerResource azureSqlServer, FirewallRuleRequest firewallRuleRequest) { CommonUtil.CheckForNull(azureResourceManagementSession, "azureResourceManagerSession"); @@ -183,23 +198,23 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl }; } // else respond with failure case - return new FirewallRuleResponse() - { + return new FirewallRuleResponse() + { Created = false }; } catch (Exception ex) { - TraceException(TraceEventType.Error, (int) TraceId.AzureResource, ex, "Failed to create firewall rule"); + TraceException(TraceEventType.Error, (int)TraceId.AzureResource, ex, "Failed to create firewall rule"); throw; } } - private Dictionary> GetCustomHeaders() + private Dictionary> GetCustomHeaders() { // For some unknown reason the firewall rule method defaults to returning XML. Fixes this by adding an Accept header // ensuring it's always JSON - var headers = new Dictionary>(); + var headers = new Dictionary>(); headers["Accept"] = new List() { "application/json" }; @@ -217,7 +232,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl stopwatch.Start(); ServiceResponse response = await AzureUtil.ExecuteGetAzureResourceAsParallel( userAccount, userAccount.AllTenants, string.Empty, CancellationToken.None, GetSubscriptionsForTentantAsync); - + if (response.HasError) { var ex = response.Errors.First(); @@ -239,12 +254,25 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl if (azureTenant != null) { ServiceClientCredentials credentials = CreateCredentials(azureTenant); - using (SubscriptionClient client = new SubscriptionClient(_resourceManagementUri, credentials)) + string armEndpoint = userAccount.UnderlyingAccount.Properties.ProviderSettings?.Settings?.ArmResource?.Endpoint; + Uri armUri = null; + if (armEndpoint != null) + { + try + { + armUri = new Uri(armEndpoint); + } + catch (Exception e) + { + Console.WriteLine($"Exception while parsing URI: {e.Message}"); + } + } + using (SubscriptionClient client = new SubscriptionClient(armUri ?? _resourceManagementUri, credentials)) { IEnumerable subs = await GetSubscriptionsAsync(client); return new ServiceResponse(subs.Select(sub => { - AzureSubscriptionIdentifier subId = new AzureSubscriptionIdentifier(userAccount, azureTenant.TenantId, sub.SubscriptionId, _resourceManagementUri); + AzureSubscriptionIdentifier subId = new AzureSubscriptionIdentifier(userAccount, azureTenant.TenantId, sub.SubscriptionId, armUri ?? _resourceManagementUri); AzureUserAccountSubscriptionContext context = new AzureUserAccountSubscriptionContext(subId, credentials); return context; })); @@ -325,7 +353,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { return await operation(); } - catch(CloudException ex) + catch (CloudException ex) { if (ex.Body != null && string.Equals(ExpiredTokenCode, ex.Body.Code, StringComparison.OrdinalIgnoreCase)) { diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureUserAccount.cs b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureUserAccount.cs index 7b559954..774a5fd0 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureUserAccount.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureUserAccount.cs @@ -6,6 +6,7 @@ using System; using System.Collections.Generic; using Microsoft.SqlTools.ResourceProvider.Core; using Microsoft.SqlTools.ResourceProvider.Core.Authentication; +using Microsoft.SqlTools.ResourceProvider.Core.Contracts; namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { @@ -39,6 +40,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl this.TenantId = azureUserAccount.TenantId; this.AllTenants = azureUserAccount.AllTenants; this.UniqueId = azureUserAccount.UniqueId; + this.UnderlyingAccount = azureUserAccount.UnderlyingAccount; AzureUserAccount account = azureUserAccount as AzureUserAccount; } /// @@ -99,5 +101,11 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl get; set; } + + public Account UnderlyingAccount + { + get; + set; + } } }