//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.SqlTools.ResourceProvider.Core;
using Microsoft.SqlTools.ResourceProvider.Core.Authentication;
using Microsoft.SqlTools.ResourceProvider.Core.Contracts;
using Microsoft.SqlTools.ResourceProvider.Core.Extensibility;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl
{
/// <summary>
/// Implementation for <see cref="IAzureAuthenticationManager" />.
/// Provides functionality to authenticate to Azure and discover associated accounts and subscriptions.
/// </summary>
[Exportable(
ServerTypes.SqlServer,
Categories.Azure,
typeof(IAzureAuthenticationManager),
"Microsoft.SqlTools.ResourceProvider.DefaultImpl.AzureAuthenticationManager",
1)
]
public class AzureAuthenticationManager : ExportableBase, IAzureAuthenticationManager
{
// Known accounts, keyed by each account's UniqueId (entries are written in SetCurrentAccountAsync).
private Dictionary accountsMap;
// UniqueId of the account most recently passed to SetCurrentAccountAsync; null until one has been set.
private string currentAccountId = null;
// Subscriptions stored by SetSelectedSubscriptionsAsync; reset to null whenever the current account changes.
private IEnumerable _selectedSubscriptions = null;
// Held while _selectedSubscriptions is being replaced or cleared.
private readonly object _selectedSubscriptionsLockObject = new object();
// Subscription lists already fetched from the service, keyed by the owning account's UniqueId.
private readonly ConcurrentCache> _subscriptionCache =
new ConcurrentCache>();
/// <summary>
/// Creates the manager with an empty account map and sets its export metadata.
/// </summary>
public AzureAuthenticationManager()
{
    // Start with no known accounts; entries are added by SetCurrentAccountAsync.
    accountsMap = new Dictionary();

    // Same server type, category, id string and trailing value as the [Exportable] attribute on this class.
    Metadata = new ExportableMetadata(ServerTypes.SqlServer, Categories.Azure, "Microsoft.SqlTools.ResourceProvider.DefaultImpl.AzureAuthenticationManager", 1);
}
/// <summary>
/// All accounts currently held in the account map.
/// </summary>
public IEnumerable UserAccounts => accountsMap.Values;
/// <summary>
/// Always false for this implementation.
/// </summary>
public bool HasLoginDialog => false;
/// <summary>
/// Sets the current logged-in user.
/// </summary>
/// <param name="account">The account to make current; must be an AccountTokenWrapper.</param>
/// <returns>The current account as reported by GetCurrentAccountAsync after the change.</returns>
/// <exception cref="ServiceFailedException">Thrown when <paramref name="account"/> is not an AccountTokenWrapper.</exception>
public async Task SetCurrentAccountAsync(object account)
{
    CommonUtil.CheckForNull(account, nameof(account));

    // Only token-wrapper accounts are understood; reject everything else up front.
    AccountTokenWrapper wrapper = account as AccountTokenWrapper;
    if (wrapper == null)
    {
        string message = string.Format(CultureInfo.CurrentCulture, SR.UnsupportedAuthType, account.GetType().Name);
        throw new ServiceFailedException(message);
    }

    // Register (or overwrite) the account under its unique id and make it the current one.
    AzureUserAccount newAccount = CreateUserAccount(wrapper);
    var accountId = newAccount.UniqueId;
    accountsMap[accountId] = newAccount;
    currentAccountId = accountId;

    // Clears the selected subscriptions and notifies listeners before the account is returned.
    OnCurrentAccountChanged();
    return await GetCurrentAccountAsync();
}
/// <summary>
/// Public for testing purposes. Creates an Azure account with the correct set of mappings for tenants etc.
/// </summary>
/// <param name="accountTokenWrapper">
/// Wrapper carrying the account and the security token mappings looked up by tenant id; must not be null,
/// and neither may its Account, the account's Key, or its SecurityTokenMappings.
/// </param>
/// <returns>A new AzureUserAccount populated from the wrapped account.</returns>
public AzureUserAccount CreateUserAccount(AccountTokenWrapper accountTokenWrapper)
{
    // Validate the wrapper BEFORE reading any member from it; reading Account first would fail
    // with a NullReferenceException and the guard below would never get a chance to run.
    CommonUtil.CheckForNull(accountTokenWrapper, nameof(accountTokenWrapper));

    Account account = accountTokenWrapper.Account;
    CommonUtil.CheckForNull(account, nameof(account));
    CommonUtil.CheckForNull(account.Key, nameof(account) + ".Key");
    // SecurityTokenMappings lives on the wrapper, so the guard is labelled with the wrapper's name.
    CommonUtil.CheckForNull(accountTokenWrapper.SecurityTokenMappings, nameof(accountTokenWrapper) + ".SecurityTokenMappings");

    return new AzureUserAccount
    {
        UniqueId = account.Key.AccountId,
        DisplayInfo = ToDisplayInfo(account),
        NeedsReauthentication = account.IsStale,
        // Only tenants that have a matching security token are kept (see ProcessTenants).
        AllTenants = ProcessTenants(accountTokenWrapper, account),
        UnderlyingAccount = account
    };
}
/// <summary>
/// Builds the tenant list for an account, keeping only the tenants whose id has an entry
/// in the wrapper's security token mappings.
/// </summary>
/// <param name="accountTokenWrapper">Source of the security token mappings, looked up by tenant id.</param>
/// <param name="account">Account whose Properties.Tenants are processed; either may be null.</param>
/// <returns>The tenants that could be paired with a token; empty when the account lists no tenants.</returns>
private static IList ProcessTenants(AccountTokenWrapper accountTokenWrapper, Account account)
{
    IList tenants = new List();
    if (account.Properties == null || account.Properties.Tenants == null)
    {
        return tenants;
    }

    foreach (Tenant tenant in account.Properties.Tenants)
    {
        AccountSecurityToken token;
        if (!accountTokenWrapper.SecurityTokenMappings.TryGetValue(tenant.Id, out token))
        {
            // ignore for now as we can't handle a request to get a tenant without an access key
            continue;
        }

        tenants.Add(new AzureTenant()
        {
            TenantId = tenant.Id,
            AccountDisplayableId = tenant.DisplayName,
            Resource = token.Resource,
            AccessToken = token.Token,
            TokenType = token.TokenType
        });
    }

    return tenants;
}
/// <summary>
/// Builds the display information for an account.
/// </summary>
/// <param name="account">Account to describe; its Key is read unconditionally.</param>
/// <returns>Display info whose name falls back to the account id when the account has no DisplayInfo.</returns>
private AzureUserAccountDisplayInfo ToDisplayInfo(Account account)
{
    var displayInfo = new AzureUserAccountDisplayInfo();

    if (account.DisplayInfo == null)
    {
        // No display information supplied: show the raw account id instead.
        displayInfo.AccountDisplayName = account.Key.AccountId;
    }
    else
    {
        displayInfo.AccountDisplayName = account.DisplayInfo.DisplayName;
    }

    displayInfo.ProviderDisplayName = account.Key.ProviderId;
    return displayInfo;
}
/// <summary>
/// Clears the selected subscriptions and raises CurrentAccountChanged.
/// </summary>
private void OnCurrentAccountChanged()
{
    // The selection belonged to the previous account, so drop it under the same lock
    // that SetSelectedSubscriptionsAsync takes when replacing it.
    lock (_selectedSubscriptionsLockObject)
    {
        _selectedSubscriptions = null;
    }

    // ?.Invoke reads the delegate once, so a handler unsubscribing between a null check and the
    // call can no longer cause a NullReferenceException. EventArgs.Empty avoids a per-call allocation.
    CurrentAccountChanged?.Invoke(this, EventArgs.Empty);
}
/// <summary>
/// The event to be raised when the current account is changed
/// </summary>
public event EventHandler CurrentAccountChanged;
/// <summary>
/// Not implemented by this authentication manager.
/// </summary>
/// <exception cref="NotImplementedException">Always thrown, synchronously, before any task is returned.</exception>
public Task AddUserAccountAsync()
{
throw new NotImplementedException();
}
/// <summary>
/// Not implemented by this authentication manager.
/// </summary>
/// <exception cref="NotImplementedException">Always thrown, synchronously, before any task is returned.</exception>
public Task AuthenticateAsync()
{
throw new NotImplementedException();
}
/// <summary>
/// Returns the current account, or null when no current account is set or it is not in the account map.
/// </summary>
public async Task GetCurrentAccountAsync()
{
    return await GetCurrentAccountInternalAsync();
}
/// <summary>
/// Looks up the current account id in the account map.
/// </summary>
/// <returns>A completed task holding the matching account, or null when there is no current id or no entry for it.</returns>
private Task GetCurrentAccountInternalAsync()
{
    AzureUserAccount current = null;

    if (currentAccountId != null)
    {
        // On a miss TryGetValue leaves 'current' as null, which is exactly what gets returned.
        // TODO is there more needed here?
        accountsMap.TryGetValue(currentAccountId, out current);
    }

    return Task.FromResult(current);
}
/// <summary>
/// Returns the stored subscription selection, or every subscription of the current user when nothing is selected.
/// </summary>
public async Task> GetSelectedSubscriptionsAsync()
{
    var selected = _selectedSubscriptions;
    if (selected != null)
    {
        return selected;
    }

    // No explicit selection: fall back to the full subscription list.
    return await GetSubscriptionsAsync();
}
/// <summary>
/// Returns user's subscriptions
/// </summary>
/// <returns>
/// The current account's subscriptions; an empty sequence when the user needs to re-authenticate
/// or when there is no current account.
/// </returns>
/// <exception cref="ServiceFailedException">
/// Wraps any exception raised during the lookup that is not already a ServiceExceptionBase;
/// ServiceExceptionBase exceptions propagate unchanged.
/// </exception>
public async Task> GetSubscriptionsAsync()
{
    if (await GetUserNeedsReauthenticationAsync())
    {
        return Enumerable.Empty();
    }

    AzureUserAccount currentUser = await GetCurrentAccountInternalAsync();
    if (currentUser == null)
    {
        return Enumerable.Empty();
    }

    try
    {
        var subscriptions = await GetSubscriptionsFromCacheAsync(currentUser);
        // Callers never see null, even if the lookup produced one.
        return subscriptions ?? Enumerable.Empty();
    }
    catch (Exception ex) when (!(ex is ServiceExceptionBase))
    {
        throw new ServiceFailedException(SR.FailedToGetAzureSubscriptionsErrorMessage, ex);
    }
}
/// <summary>
/// Returns the subscriptions for a user, reading them from the subscription cache when present
/// and otherwise fetching them from the service and caching the result.
/// </summary>
/// <param name="user">Account whose subscriptions are wanted; null yields an empty sequence.</param>
/// <returns>The user's subscriptions; never null.</returns>
private async Task> GetSubscriptionsFromCacheAsync(AzureUserAccount user)
{
    if (user == null)
    {
        return Enumerable.Empty();
    }

    // An account without an id cannot be used as a cache key. Track that explicitly instead of
    // comparing against the Enumerable.Empty() instance, and never hand a null key to the cache.
    bool hasCacheKey = !string.IsNullOrEmpty(user.UniqueId);

    var result = hasCacheKey ? _subscriptionCache.Get(user.UniqueId) : null;
    if (result == null)
    {
        // Nothing usable in the cache: ask the service.
        result = await GetSubscriptionFromServiceAsync(user);

        // Entries without a key are never read back, so only keyed results are stored.
        if (hasCacheKey)
        {
            _subscriptionCache.UpdateCache(user.UniqueId, result);
        }
    }

    return result ?? Enumerable.Empty();
}
/// <summary>
/// Fetches the subscription contexts for an account from the Azure resource manager service.
/// </summary>
/// <param name="userAccount">Account to query; must not be null.</param>
/// <returns>The subscription contexts returned by the resource manager, materialized into a list.</returns>
/// <exception cref="ExpiredTokenException">Thrown when the account is flagged as needing re-authentication.</exception>
/// <exception cref="ServiceFailedException">
/// Wraps any exception raised by the service call that is not already a ServiceExceptionBase;
/// ServiceExceptionBase exceptions propagate unchanged.
/// </exception>
private async Task> GetSubscriptionFromServiceAsync(AzureUserAccount userAccount)
{
    CommonUtil.CheckForNull(userAccount, nameof(userAccount));

    if (userAccount.NeedsReauthentication)
    {
        throw new ExpiredTokenException(SR.UserNeedsAuthenticationError);
    }

    try
    {
        IAzureResourceManager resourceManager = ServiceProvider.GetService();
        var contexts = await resourceManager.GetSubscriptionContextsAsync(userAccount);
        return contexts.ToList();
    }
    catch (Exception ex) when (!(ex is ServiceExceptionBase))
    {
        throw new ServiceFailedException(SR.FailedToGetAzureSubscriptionsErrorMessage, ex);
    }
}
/// <summary>
/// Reports whether the user must re-authenticate. Always false for now,
/// because handling stale auth objects is not supported.
/// </summary>
public Task GetUserNeedsReauthenticationAsync() => Task.FromResult(false);
/// <summary>
/// Stores the selected subscriptions given the ids
/// </summary>
/// <param name="subscriptionIds">
/// Ids of the subscriptions to select. Null selects every subscription of the current user.
/// </param>
/// <returns>True when the stored selection was replaced; false when it already matched.</returns>
public async Task SetSelectedSubscriptionsAsync(IEnumerable subscriptionIds)
{
    var allSubscriptions = (await GetSubscriptionsAsync()).ToList();

    var chosen = subscriptionIds == null
        ? allSubscriptions
        : allSubscriptions.Where(context => subscriptionIds.Contains(context.Subscription.SubscriptionId)).ToList();

    // If the current account changes during setting selected subscription, none of the ids should be found
    // so we just reset the selected subscriptions
    bool requestedButNothingMatched = subscriptionIds != null && subscriptionIds.Any() && chosen.Count == 0;
    if (requestedButNothingMatched)
    {
        chosen = allSubscriptions;
    }

    lock (_selectedSubscriptionsLockObject)
    {
        if (SelectedSubscriptionsEquals(chosen))
        {
            return false;
        }

        _selectedSubscriptions = chosen;
        return true;
    }
}
/// <summary>
/// Checks whether the stored selection has the same number of items as the candidate list
/// and contains every one of the candidate's items.
/// </summary>
/// <param name="newSelectedSubscriptions">Candidate selection to compare against the stored one.</param>
/// <returns>False when nothing is stored, the sizes differ, or any candidate item is missing; otherwise true.</returns>
private bool SelectedSubscriptionsEquals(List newSelectedSubscriptions)
{
    // The only caller invokes this while holding the selection lock, so one read of the field is enough.
    var current = _selectedSubscriptions;
    if (current == null)
    {
        return false;
    }

    if (current.Count() != newSelectedSubscriptions.Count)
    {
        return false;
    }

    return newSelectedSubscriptions.All(candidate => current.Contains(candidate));
}
/// <summary>
/// Tries to find a subscription given subscription id.
/// Not implemented by this authentication manager.
/// </summary>
/// <param name="value">Unused; the method throws before reading it.</param>
/// <param name="subscription">Never assigned; the method throws before producing a result.</param>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public bool TryParseSubscriptionIdentifier(string value, out IAzureSubscriptionIdentifier subscription)
{
// TODO can port this over from the VS implementation if needed, but for now disabling as we don't serialize / deserialize subscriptions
throw new NotImplementedException();
}
}
}