//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.SqlTools.ResourceProvider.Core;
using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
using Microsoft.SqlTools.ResourceProvider.Core.Contracts;
using Microsoft.SqlTools.ResourceProvider.Core.Extensibility;
using Microsoft.SqlTools.Utility;

namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
{
    /// <summary>
    /// Implementation for <see cref="IAzureAuthenticationManager"/>.
    /// Provides functionality to authenticate to Azure and discover associated accounts and subscriptions.
    /// </summary>
    [Exportable(
        ServerTypes.SqlServer,
        Categories.Azure,
        typeof(IAzureAuthenticationManager),
        "Microsoft.SqlTools.ResourceProvider.DefaultImpl.AzureAuthenticationManager",
        1)
    ]
    public class AzureAuthenticationManager : ExportableBase, IAzureAuthenticationManager
    {
        // Accounts registered through SetCurrentAccountAsync, keyed by AzureUserAccount.UniqueId
        // (which CreateUserAccount takes from Account.Key.AccountId).
        private readonly Dictionary<string, AzureUserAccount> accountsMap;

        // UniqueId of the account most recently set as current, or null if none has been set.
        private string currentAccountId = null;

        // Subscriptions chosen via SetSelectedSubscriptionsAsync. Null means "nothing chosen",
        // in which case GetSelectedSubscriptionsAsync falls back to all subscriptions.
        // Assigned under _selectedSubscriptionsLockObject.
        private IEnumerable<IAzureUserAccountSubscriptionContext> _selectedSubscriptions = null;
        private readonly object _selectedSubscriptionsLockObject = new object();

        // Subscription lists fetched from the resource manager, keyed by account UniqueId.
        private readonly ConcurrentCache<IEnumerable<IAzureUserAccountSubscriptionContext>> _subscriptionCache =
            new ConcurrentCache<IEnumerable<IAzureUserAccountSubscriptionContext>>();

        /// <summary>
        /// Creates the manager with its export metadata and an empty account map.
        /// </summary>
        public AzureAuthenticationManager()
        {
            Metadata = new ExportableMetadata(
                ServerTypes.SqlServer,
                Categories.Azure,
                "Microsoft.SqlTools.ResourceProvider.DefaultImpl.AzureAuthenticationManager",
                1);
            accountsMap = new Dictionary<string, AzureUserAccount>();
        }

        /// <summary>
        /// All accounts registered so far through <see cref="SetCurrentAccountAsync"/>.
        /// </summary>
        public IEnumerable<IAzureUserAccount> UserAccounts
        {
            get { return accountsMap.Values; }
        }

        /// <summary>
        /// Always false: this implementation does not show a login dialog.
        /// </summary>
        public bool HasLoginDialog
        {
            get { return false; }
        }

        /// <summary>
        /// Set current logged in user.
        /// </summary>
        /// <param name="account">Must be an <see cref="AccountTokenWrapper"/>; other types are rejected.</param>
        /// <returns>The account that is now current.</returns>
        /// <exception cref="ServiceFailedException">When <paramref name="account"/> is not an <see cref="AccountTokenWrapper"/>.</exception>
        public async Task<IAzureUserAccount> SetCurrentAccountAsync(object account)
        {
            CommonUtil.CheckForNull(account, nameof(account));

            AccountTokenWrapper accountTokenWrapper = account as AccountTokenWrapper;
            if (accountTokenWrapper == null)
            {
                throw new ServiceFailedException(string.Format(CultureInfo.CurrentCulture, SR.UnsupportedAuthType, account.GetType().Name));
            }

            // Re-registering an account with the same id replaces the earlier entry.
            AzureUserAccount userAccount = CreateUserAccount(accountTokenWrapper);
            accountsMap[userAccount.UniqueId] = userAccount;
            currentAccountId = userAccount.UniqueId;

            OnCurrentAccountChanged();
            return await GetCurrentAccountAsync();
        }

        /// <summary>
        /// Public for testing purposes. Creates an Azure account with the correct set of mappings for tenants etc.
        /// </summary>
        /// <param name="accountTokenWrapper">Account plus the security tokens keyed by tenant id.</param>
        /// <returns>A new <see cref="AzureUserAccount"/> built from the wrapped account.</returns>
        public AzureUserAccount CreateUserAccount(AccountTokenWrapper accountTokenWrapper)
        {
            // Validate the wrapper before dereferencing it; reading .Account first would turn a
            // null argument into a NullReferenceException instead of the intended null check failure.
            CommonUtil.CheckForNull(accountTokenWrapper, nameof(accountTokenWrapper));
            Account account = accountTokenWrapper.Account;
            CommonUtil.CheckForNull(account, nameof(account));
            CommonUtil.CheckForNull(account.Key, nameof(account) + ".Key");
            CommonUtil.CheckForNull(accountTokenWrapper.SecurityTokenMappings, nameof(account) + ".SecurityTokenMappings");

            AzureUserAccount userAccount = new AzureUserAccount();
            userAccount.UniqueId = account.Key.AccountId;
            userAccount.DisplayInfo = ToDisplayInfo(account);
            userAccount.NeedsReauthentication = account.IsStale;
            userAccount.AllTenants = ProcessTenants(accountTokenWrapper, account);
            userAccount.UnderlyingAccount = account;
            return userAccount;
        }

        /// <summary>
        /// Builds one <see cref="AzureTenant"/> per tenant of <paramref name="account"/> that has a matching
        /// entry in the wrapper's security token mappings. Tenants without a token are skipped.
        /// </summary>
        private static IList<IAzureTenant> ProcessTenants(AccountTokenWrapper accountTokenWrapper, Account account)
        {
            IList<IAzureTenant> tenants = new List<IAzureTenant>();
            if (account.Properties != null && account.Properties.Tenants != null)
            {
                foreach (Tenant tenant in account.Properties.Tenants)
                {
                    AccountSecurityToken token;
                    if (accountTokenWrapper.SecurityTokenMappings.TryGetValue(tenant.Id, out token))
                    {
                        AzureTenant azureTenant = new AzureTenant()
                        {
                            TenantId = tenant.Id,
                            AccountDisplayableId = tenant.DisplayName,
                            Resource = token.Resource,
                            AccessToken = token.Token,
                            TokenType = token.TokenType
                        };
                        tenants.Add(azureTenant);
                    }
                    // else ignore for now as we can't handle a request to get a tenant without an access key
                }
            }
            return tenants;
        }

        /// <summary>
        /// Display info for an account: its display name (or the account id when no display info exists)
        /// and its provider id.
        /// </summary>
        private static AzureUserAccountDisplayInfo ToDisplayInfo(Account account)
        {
            return new AzureUserAccountDisplayInfo()
            {
                AccountDisplayName = account.DisplayInfo != null ? account.DisplayInfo.DisplayName : account.Key.AccountId,
                ProviderDisplayName = account.Key.ProviderId
            };
        }

        /// <summary>
        /// Clears the selected subscriptions (they belonged to the previous account) and raises
        /// <see cref="CurrentAccountChanged"/>.
        /// </summary>
        private void OnCurrentAccountChanged()
        {
            lock (_selectedSubscriptionsLockObject)
            {
                _selectedSubscriptions = null;
            }

            // ?.Invoke reads the delegate once, so a handler unsubscribing concurrently
            // cannot cause a NullReferenceException between the check and the call.
            CurrentAccountChanged?.Invoke(this, EventArgs.Empty);
        }

        /// <summary>
        /// The event to be raised when the current account is changed
        /// </summary>
        public event EventHandler CurrentAccountChanged;

        /// <summary>
        /// Not implemented in this provider.
        /// </summary>
        public Task<IAzureUserAccount> AddUserAccountAsync()
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// Not implemented in this provider.
        /// </summary>
        public Task<IAzureUserAccount> AuthenticateAsync()
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// Returns the current account, or null if none has been set.
        /// </summary>
        public async Task<IAzureUserAccount> GetCurrentAccountAsync()
        {
            var account = await GetCurrentAccountInternalAsync();
            return account;
        }

        /// <summary>
        /// Looks up the current account in the account map; completes synchronously.
        /// Returns null when no current account is set or its id is not in the map.
        /// </summary>
        private Task<AzureUserAccount> GetCurrentAccountInternalAsync()
        {
            AzureUserAccount account = null;
            if (currentAccountId != null)
            {
                // TryGetValue leaves account null when the id is not present.
                accountsMap.TryGetValue(currentAccountId, out account);
            }
            return Task.FromResult(account);
        }

        /// <summary>
        /// Returns the subscriptions chosen via <see cref="SetSelectedSubscriptionsAsync"/>, or all
        /// subscriptions of the current account when none have been chosen.
        /// </summary>
        public async Task<IEnumerable<IAzureUserAccountSubscriptionContext>> GetSelectedSubscriptionsAsync()
        {
            return _selectedSubscriptions ?? await GetSubscriptionsAsync();
        }

        /// <summary>
        /// Returns user's subscriptions. Never returns null; returns an empty sequence when there is
        /// no current account.
        /// </summary>
        /// <exception cref="ServiceFailedException">Wraps any non-service exception raised while fetching subscriptions.</exception>
        public async Task<IEnumerable<IAzureUserAccountSubscriptionContext>> GetSubscriptionsAsync()
        {
            var result = Enumerable.Empty<IAzureUserAccountSubscriptionContext>();
            bool userNeedsAuthentication = await GetUserNeedsReauthenticationAsync();
            if (!userNeedsAuthentication)
            {
                AzureUserAccount currentUser = await GetCurrentAccountInternalAsync();
                if (currentUser != null)
                {
                    try
                    {
                        result = await GetSubscriptionsFromCacheAsync(currentUser);
                    }
                    catch (ServiceExceptionBase)
                    {
                        // Already a service-level error; propagate unchanged.
                        throw;
                    }
                    catch (Exception ex)
                    {
                        throw new ServiceFailedException(SR.FailedToGetAzureSubscriptionsErrorMessage, ex);
                    }
                }
                result = result ?? Enumerable.Empty<IAzureUserAccountSubscriptionContext>();
            }
            return result;
        }

        /// <summary>
        /// Returns the cached subscriptions for <paramref name="user"/>, fetching and caching them from
        /// the service when the cached value is missing (null) or is the shared empty sequence.
        /// </summary>
        private async Task<IEnumerable<IAzureUserAccountSubscriptionContext>> GetSubscriptionsFromCacheAsync(AzureUserAccount user)
        {
            var result = Enumerable.Empty<IAzureUserAccountSubscriptionContext>();

            if (user != null)
            {
                if (user.UniqueId != "")
                {
                    result = _subscriptionCache.Get(user.UniqueId);
                }

                // Reference comparison with the Enumerable.Empty instance assigned above: it is true when
                // the cache lookup was skipped (empty UniqueId), null when nothing was cached.
                // NOTE(review): an empty UniqueId still writes the cache under the "" key — confirm intended.
                if (result == null || result == Enumerable.Empty<IAzureUserAccountSubscriptionContext>())
                {
                    result = await GetSubscriptionFromServiceAsync(user);
                    _subscriptionCache.UpdateCache(user.UniqueId, result);
                }
            }
            result = result ?? Enumerable.Empty<IAzureUserAccountSubscriptionContext>();
            return result;
        }

        /// <summary>
        /// Fetches the subscription contexts for <paramref name="userAccount"/> through the
        /// <see cref="IAzureResourceManager"/> service.
        /// </summary>
        /// <exception cref="ExpiredTokenException">When the account needs reauthentication.</exception>
        /// <exception cref="ServiceFailedException">Wraps any non-service exception from the resource manager.</exception>
        private async Task<IEnumerable<IAzureUserAccountSubscriptionContext>> GetSubscriptionFromServiceAsync(AzureUserAccount userAccount)
        {
            CommonUtil.CheckForNull(userAccount, nameof(userAccount));
            List<IAzureUserAccountSubscriptionContext> subscriptionList = new List<IAzureUserAccountSubscriptionContext>();

            if (userAccount.NeedsReauthentication)
            {
                throw new ExpiredTokenException(SR.UserNeedsAuthenticationError);
            }

            try
            {
                IAzureResourceManager resourceManager = ServiceProvider.GetService<IAzureResourceManager>();
                IEnumerable<IAzureUserAccountSubscriptionContext> contexts = await resourceManager.GetSubscriptionContextsAsync(userAccount);
                subscriptionList = contexts.ToList();
            }
            catch (ServiceExceptionBase)
            {
                throw;
            }
            catch (Exception ex)
            {
                throw new ServiceFailedException(SR.FailedToGetAzureSubscriptionsErrorMessage, ex);
            }
            return subscriptionList;
        }

        /// <summary>
        /// Always false: for now, stale auth objects are not handled.
        /// </summary>
        public Task<bool> GetUserNeedsReauthenticationAsync()
        {
            // for now, we don't support handling stale auth objects
            return Task.FromResult(false);
        }

        /// <summary>
        /// Stores the selected subscriptions given the ids.
        /// </summary>
        /// <param name="subscriptionIds">Ids to select; null selects every subscription.</param>
        /// <returns>True if the stored selection changed; otherwise false.</returns>
        public async Task<bool> SetSelectedSubscriptionsAsync(IEnumerable<string> subscriptionIds)
        {
            IEnumerable<IAzureUserAccountSubscriptionContext> subscriptions = await GetSubscriptionsAsync();
            List<IAzureUserAccountSubscriptionContext> subscriptionList = subscriptions.ToList();

            List<IAzureUserAccountSubscriptionContext> newSelectedSubscriptions = subscriptionIds == null
                ? subscriptionList
                : subscriptionList.Where(x => subscriptionIds.Contains(x.Subscription.SubscriptionId)).ToList();

            //If the current account changes during setting selected subscription, none of the ids should be found
            //so we just reset the selected subscriptions
            if (subscriptionIds != null && subscriptionIds.Any() && newSelectedSubscriptions.Count == 0)
            {
                newSelectedSubscriptions = subscriptionList;
            }

            lock (_selectedSubscriptionsLockObject)
            {
                if (!SelectedSubscriptionsEquals(newSelectedSubscriptions))
                {
                    _selectedSubscriptions = newSelectedSubscriptions;
                    return true;
                }
            }
            return false;
        }

        /// <summary>
        /// True when the stored selection has the same count as <paramref name="newSelectedSubscriptions"/>
        /// and contains each of its elements. Called while holding _selectedSubscriptionsLockObject.
        /// </summary>
        private bool SelectedSubscriptionsEquals(List<IAzureUserAccountSubscriptionContext> newSelectedSubscriptions)
        {
            if (_selectedSubscriptions != null && _selectedSubscriptions.Count() == newSelectedSubscriptions.Count)
            {
                return newSelectedSubscriptions.All(subscription => _selectedSubscriptions.Contains(subscription));
            }
            return false;
        }

        /// <summary>
        /// Tries to find a subscription given subscription id. Not implemented; always throws.
        /// </summary>
        public bool TryParseSubscriptionIdentifier(string value, out IAzureSubscriptionIdentifier subscription)
        {
            // TODO can port this over from the VS implementation if needed, but for now disabling as we don't serialize / deserialize subscriptions
            throw new NotImplementedException();
        }
    }
}