diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Firewall/FirewallRuleService.cs b/src/Microsoft.SqlTools.ResourceProvider.Core/Firewall/FirewallRuleService.cs index 52f55816..f9714bab 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Firewall/FirewallRuleService.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Firewall/FirewallRuleService.cs @@ -158,7 +158,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core.Firewall if (subscriptions == null) { - throw new FirewallRuleException(SR.FirewallRuleCreationFailed); + throw new FirewallRuleException(SR.NoSubscriptionsFound); } ServiceResponse response = await AzureUtil.ExecuteGetAzureResourceAsParallel((object)null, diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.cs b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.cs index 0064be60..c04bdeee 100755 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.cs @@ -27,7 +27,15 @@ namespace Microsoft.SqlTools.ResourceProvider.Core Keys.Culture = value; } } - + + + public static string NoSubscriptionsFound + { + get + { + return Keys.GetString(Keys.NoSubscriptionsFound); + } + } public static string AzureServerNotFound { @@ -35,7 +43,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core { return Keys.GetString(Keys.AzureServerNotFound); } - } + } public static string AzureSubscriptionFailedErrorMessage { @@ -43,7 +51,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core { return Keys.GetString(Keys.AzureSubscriptionFailedErrorMessage); } - } + } public static string DatabaseDiscoveryFailedErrorMessage { @@ -51,7 +59,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core { return Keys.GetString(Keys.DatabaseDiscoveryFailedErrorMessage); } - } + } public static string FirewallRuleAccessForbidden { @@ -59,7 +67,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core { return Keys.GetString(Keys.FirewallRuleAccessForbidden); } - } + } public static string FirewallRuleCreationFailed { @@ -67,7 +75,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core { return Keys.GetString(Keys.FirewallRuleCreationFailed); } - } + } public static string FirewallRuleCreationFailedWithError { @@ -75,7 +83,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core { return Keys.GetString(Keys.FirewallRuleCreationFailedWithError); } - } + } public static string InvalidIpAddress { @@ -83,7 +91,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core { return Keys.GetString(Keys.InvalidIpAddress); } - } + } public static string InvalidServerTypeErrorMessage { @@ -91,7 +99,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core { return Keys.GetString(Keys.InvalidServerTypeErrorMessage); } - } + } public static string LoadingExportableFailedGeneralErrorMessage { @@ -99,7 +107,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core { return Keys.GetString(Keys.LoadingExportableFailedGeneralErrorMessage); } - } + } public static string FirewallRuleUnsupportedConnectionType { @@ -107,7 +115,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core { return Keys.GetString(Keys.FirewallRuleUnsupportedConnectionType); } - } + } [System.Runtime.CompilerServices.CompilerGeneratedAttribute()] public class Keys @@ -115,37 +123,40 @@ namespace Microsoft.SqlTools.ResourceProvider.Core static ResourceManager resourceManager = new ResourceManager("Microsoft.SqlTools.ResourceProvider.Core.Localization.SR", typeof(SR).GetTypeInfo().Assembly); static CultureInfo _culture = null; - - - public const string AzureServerNotFound = "AzureServerNotFound"; - - - public const string AzureSubscriptionFailedErrorMessage = "AzureSubscriptionFailedErrorMessage"; - - - public const string DatabaseDiscoveryFailedErrorMessage = "DatabaseDiscoveryFailedErrorMessage"; - - - public const string FirewallRuleAccessForbidden = "FirewallRuleAccessForbidden"; - - - public const string FirewallRuleCreationFailed = "FirewallRuleCreationFailed"; - - - public const string FirewallRuleCreationFailedWithError = "FirewallRuleCreationFailedWithError"; - - - public const string InvalidIpAddress = "InvalidIpAddress"; - - - public const string InvalidServerTypeErrorMessage = "InvalidServerTypeErrorMessage"; - - - public const string LoadingExportableFailedGeneralErrorMessage = "LoadingExportableFailedGeneralErrorMessage"; - - - public const string FirewallRuleUnsupportedConnectionType = "FirewallRuleUnsupportedConnectionType"; - + + + public const string NoSubscriptionsFound = "NoSubscriptionsFound"; + + + public const string AzureServerNotFound = "AzureServerNotFound"; + + + public const string AzureSubscriptionFailedErrorMessage = "AzureSubscriptionFailedErrorMessage"; + + + public const string DatabaseDiscoveryFailedErrorMessage = "DatabaseDiscoveryFailedErrorMessage"; + + + public const string FirewallRuleAccessForbidden = "FirewallRuleAccessForbidden"; + + + public const string FirewallRuleCreationFailed = "FirewallRuleCreationFailed"; + + + public const string FirewallRuleCreationFailedWithError = "FirewallRuleCreationFailedWithError"; + + + public const string InvalidIpAddress = "InvalidIpAddress"; + + + public const string InvalidServerTypeErrorMessage = "InvalidServerTypeErrorMessage"; + + + public const string LoadingExportableFailedGeneralErrorMessage = "LoadingExportableFailedGeneralErrorMessage"; + + + public const string FirewallRuleUnsupportedConnectionType = "FirewallRuleUnsupportedConnectionType"; + private Keys() { } @@ -166,7 +177,7 @@ namespace Microsoft.SqlTools.ResourceProvider.Core { return resourceManager.GetString(key, _culture); } - - } - } -} + + } + } +} diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.resx b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.resx index 1f16faa4..55bc2567 100755 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.resx +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.resx @@ -117,44 +117,48 @@ System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + No subscriptions were found for the currently logged in user account. + + The server you specified {0} does not exist in any subscription in {1}. Either you have signed in with an incorrect account or your server was removed from subscription(s) in this account. Please check your account and try again. - + An error occurred while getting Azure subscriptions - + An error occurred while getting databases from servers of type {0} from {1} - + {0} does not have permission to change the server firewall rule. Try again with a different account that is an Owner or Contributor of the Azure subscription or the server. - + An error occurred while creating a new firewall rule. - + An error occurred while creating a new firewall rule: '{0}' - + Invalid IP address - + Server Type is invalid. - + A required dll cannot be loaded. Please repair your application. - + Cannot open a firewall rule for the specified connection type - - + + diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.strings b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.strings index d0b41403..eab975f1 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.strings +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.strings @@ -22,6 +22,7 @@ ############################################################################ # Azure Core DLL +NoSubscriptionsFound = No subscriptions were found for the currently logged in user account. AzureServerNotFound = The server you specified {0} does not exist in any subscription in {1}. Either you have signed in with an incorrect account or your server was removed from subscription(s) in this account. Please check your account and try again. AzureSubscriptionFailedErrorMessage = An error occurred while getting Azure subscriptions DatabaseDiscoveryFailedErrorMessage = An error occurred while getting databases from servers of type {0} from {1} diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.xlf b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.xlf index 3e061f47..8479b2e8 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.xlf +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.xlf @@ -52,6 +52,11 @@ An error occurred while creating a new firewall rule: '{0}' + + No subscriptions were found for the currently logged in user account. + No subscriptions were found for the currently logged in user account. + + \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureAuthenticationManager.cs b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureAuthenticationManager.cs index 3c212593..7dd2d8ee 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureAuthenticationManager.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureAuthenticationManager.cs @@ -27,7 +27,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl "Microsoft.SqlTools.ResourceProvider.DefaultImpl.AzureAuthenticationManager", 1) ] - class AzureAuthenticationManager : ExportableBase, IAzureAuthenticationManager + public class AzureAuthenticationManager : ExportableBase, IAzureAuthenticationManager { private Dictionary accountsMap; private string currentAccountId = null; @@ -88,38 +88,49 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl public AzureUserAccount CreateUserAccount(AccountTokenWrapper accountTokenWrapper) { Account account = accountTokenWrapper.Account; - CommonUtil.CheckForNull(accountTokenWrapper.Account, nameof(account)); + CommonUtil.CheckForNull(accountTokenWrapper, nameof(accountTokenWrapper)); + CommonUtil.CheckForNull(account, nameof(account)); + CommonUtil.CheckForNull(account.Key, nameof(account) + ".Key"); CommonUtil.CheckForNull(accountTokenWrapper.SecurityTokenMappings, nameof(account) + ".SecurityTokenMappings"); AzureUserAccount userAccount = new AzureUserAccount(); userAccount.UniqueId = account.Key.AccountId; userAccount.DisplayInfo = ToDisplayInfo(account); - IList tenants = new List(); - foreach (Tenant tenant in account.Properties.Tenants) - { - AccountSecurityToken token; - if (accountTokenWrapper.SecurityTokenMappings.TryGetValue(tenant.Id, out token)) - { - AzureTenant azureTenant = new AzureTenant() - { - TenantId = tenant.Id, - AccountDisplayableId = tenant.DisplayName, - Resource = token.Resource, - AccessToken = token.Token, - TokenType = token.TokenType - }; - tenants.Add(azureTenant); - } - // else ignore for now as we can't handle a request to get a tenant without an access key - } - userAccount.AllTenants = tenants; + userAccount.NeedsReauthentication = account.IsStale; + userAccount.AllTenants = ProcessTenants(accountTokenWrapper, account); return userAccount; } + private static IList ProcessTenants(AccountTokenWrapper accountTokenWrapper, Account account) + { + IList tenants = new List(); + if (account.Properties != null && account.Properties.Tenants != null) + { + foreach (Tenant tenant in account.Properties.Tenants) + { + AccountSecurityToken token; + if (accountTokenWrapper.SecurityTokenMappings.TryGetValue(tenant.Id, out token)) + { + AzureTenant azureTenant = new AzureTenant() + { + TenantId = tenant.Id, + AccountDisplayableId = tenant.DisplayName, + Resource = token.Resource, + AccessToken = token.Token, + TokenType = token.TokenType + }; + tenants.Add(azureTenant); + } + // else ignore for now as we can't handle a request to get a tenant without an access key + } + } + return tenants; + } + private AzureUserAccountDisplayInfo ToDisplayInfo(Account account) { return new AzureUserAccountDisplayInfo() { - AccountDisplayName = account.DisplayInfo.DisplayName, + AccountDisplayName = account.DisplayInfo != null ? account.DisplayInfo.DisplayName : account.Key.AccountId, ProviderDisplayName = account.Key.ProviderId }; } diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceManager.cs b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceManager.cs index 06e5a8bd..54c06d5d 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceManager.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceManager.cs @@ -20,7 +20,6 @@ using Microsoft.Rest; using System.Globalization; using Microsoft.Rest.Azure; using Microsoft.SqlTools.ResourceProvider.Core; -using System.Collections; using System.Threading; namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceWrapper.cs b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceWrapper.cs index eda2ec1f..8943dcf9 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceWrapper.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceWrapper.cs @@ -81,7 +81,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { this.resourceGroupName = ParseResourceGroupNameFromId(); } - return this.resourceGroupName; + return this.resourceGroupName ?? string.Empty; } set { diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.cs b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.cs index a7abd08c..44f60907 100755 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.cs @@ -27,7 +27,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl Keys.Culture = value; } } - + public static string FailedToGetAzureDatabasesErrorMessage { @@ -35,7 +35,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { return Keys.GetString(Keys.FailedToGetAzureDatabasesErrorMessage); } - } + } public static string FailedToGetAzureSubscriptionsErrorMessage { @@ -43,7 +43,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { return Keys.GetString(Keys.FailedToGetAzureSubscriptionsErrorMessage); } - } + } public static string FailedToGetAzureResourceGroupsErrorMessage { @@ -51,7 +51,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { return Keys.GetString(Keys.FailedToGetAzureResourceGroupsErrorMessage); } - } + } public static string FailedToGetAzureSqlServersErrorMessage { @@ -59,7 +59,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { return Keys.GetString(Keys.FailedToGetAzureSqlServersErrorMessage); } - } + } public static string FailedToGetAzureSqlServersWithError { @@ -67,7 +67,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { return Keys.GetString(Keys.FailedToGetAzureSqlServersWithError); } - } + } public static string FirewallRuleCreationFailed { @@ -75,7 +75,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { return Keys.GetString(Keys.FirewallRuleCreationFailed); } - } + } public static string FirewallRuleCreationFailedWithError { @@ -83,7 +83,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { return Keys.GetString(Keys.FirewallRuleCreationFailedWithError); } - } + } public static string AzureSubscriptionFailedErrorMessage { @@ -91,7 +91,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { return Keys.GetString(Keys.AzureSubscriptionFailedErrorMessage); } - } + } public static string UnsupportedAuthType { @@ -99,7 +99,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { return Keys.GetString(Keys.UnsupportedAuthType); } - } + } [System.Runtime.CompilerServices.CompilerGeneratedAttribute()] public class Keys @@ -107,34 +107,34 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl static ResourceManager resourceManager = new ResourceManager("Microsoft.SqlTools.ResourceProvider.DefaultImpl.Localization.SR", typeof(SR).GetTypeInfo().Assembly); static CultureInfo _culture = null; - - - public const string FailedToGetAzureDatabasesErrorMessage = "FailedToGetAzureDatabasesErrorMessage"; - - - public const string FailedToGetAzureSubscriptionsErrorMessage = "FailedToGetAzureSubscriptionsErrorMessage"; - - - public const string FailedToGetAzureResourceGroupsErrorMessage = "FailedToGetAzureResourceGroupsErrorMessage"; - - - public const string FailedToGetAzureSqlServersErrorMessage = "FailedToGetAzureSqlServersErrorMessage"; - - - public const string FailedToGetAzureSqlServersWithError = "FailedToGetAzureSqlServersWithError"; - - - public const string FirewallRuleCreationFailed = "FirewallRuleCreationFailed"; - - - public const string FirewallRuleCreationFailedWithError = "FirewallRuleCreationFailedWithError"; - - - public const string AzureSubscriptionFailedErrorMessage = "AzureSubscriptionFailedErrorMessage"; - - - public const string UnsupportedAuthType = "UnsupportedAuthType"; - + + + public const string FailedToGetAzureDatabasesErrorMessage = "FailedToGetAzureDatabasesErrorMessage"; + + + public const string FailedToGetAzureSubscriptionsErrorMessage = "FailedToGetAzureSubscriptionsErrorMessage"; + + + public const string FailedToGetAzureResourceGroupsErrorMessage = "FailedToGetAzureResourceGroupsErrorMessage"; + + + public const string FailedToGetAzureSqlServersErrorMessage = "FailedToGetAzureSqlServersErrorMessage"; + + + public const string FailedToGetAzureSqlServersWithError = "FailedToGetAzureSqlServersWithError"; + + + public const string FirewallRuleCreationFailed = "FirewallRuleCreationFailed"; + + + public const string FirewallRuleCreationFailedWithError = "FirewallRuleCreationFailedWithError"; + + + public const string AzureSubscriptionFailedErrorMessage = "AzureSubscriptionFailedErrorMessage"; + + + public const string UnsupportedAuthType = "UnsupportedAuthType"; + private Keys() { } @@ -155,7 +155,7 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl { return resourceManager.GetString(key, _culture); } - - } - } -} + + } + } +} diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.resx b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.resx index 2130cb0b..81c293a4 100755 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.resx +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.resx @@ -120,37 +120,37 @@ An error occurred while getting Azure databases - + An error occurred while getting Azure subscriptions: {0} - + An error occurred while getting Azure resource groups: {0} - + An error occurred while getting Azure Sql Servers - + An error occurred while getting Azure Sql Servers: '{0}' - + An error occurred while creating a new firewall rule. - + An error occurred while creating a new firewall rule: '{0}' - + An error occurred while getting Azure subscriptions - + Unsupported account type '{0}' for this provider - - + + diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureAuthenticationManagerTest.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureAuthenticationManagerTest.cs new file mode 100644 index 00000000..a685ee32 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureAuthenticationManagerTest.cs @@ -0,0 +1,113 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlTools.Extensibility; +using Microsoft.SqlTools.ResourceProvider.Core; +using Microsoft.SqlTools.ResourceProvider.Core.Authentication; +using Microsoft.SqlTools.ResourceProvider.Core.Contracts; +using Microsoft.SqlTools.ResourceProvider.DefaultImpl; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ResourceProvider.Azure +{ + public class AzureAuthenticationManagerTest + { + private Mock resourceManagerMock; + private RegisteredServiceProvider serviceProvider; + + public AzureAuthenticationManagerTest() + { + resourceManagerMock = new Mock(); + serviceProvider = new RegisteredServiceProvider(); + serviceProvider.RegisterSingleService(resourceManagerMock.Object); + } + + [Fact] + public async Task CurrentUserShouldBeNullWhenUserIsNotSignedIn() + { + IAzureAuthenticationManager accountManager = await CreateAccountManager(null, null); + Assert.Null(await accountManager.GetCurrentAccountAsync()); + } + + [Fact] + public async Task GetSubscriptionShouldReturnEmptyWhenUserIsNotSignedIn() + { + IAzureAuthenticationManager accountManager = await CreateAccountManager(null, null); + IEnumerable result = + await accountManager.GetSelectedSubscriptionsAsync(); + Assert.False(result.Any()); + } + + [Fact] + public async Task GetSubscriptionShouldThrowWhenUserNeedsAuthentication() + { + var currentUserAccount = CreateAccount(); + currentUserAccount.Account.IsStale = true; + IAzureAuthenticationManager accountManager = await CreateAccountManager(currentUserAccount, null); + await Assert.ThrowsAsync(() => accountManager.GetSelectedSubscriptionsAsync()); + } + + [Fact] + public async Task GetSubscriptionShouldThrowIfFailed() + { + var currentUserAccount = CreateAccount(); + IAzureAuthenticationManager accountManager = await CreateAccountManager(currentUserAccount, null, true); + await Assert.ThrowsAsync(() => accountManager.GetSelectedSubscriptionsAsync()); + } + + [Fact] + public async Task GetSubscriptionShouldReturnTheListSuccessfully() + { + List subscriptions = new List { + new Mock().Object + }; + var currentUserAccount = CreateAccount(); + IAzureAuthenticationManager accountManager = await CreateAccountManager(currentUserAccount, subscriptions, false); + IEnumerable result = + await accountManager.GetSelectedSubscriptionsAsync(); + Assert.True(result.Any()); + } + + private AccountTokenWrapper CreateAccount(bool needsReauthentication = false) + { + return new AccountTokenWrapper(new Account() + { + Key = new AccountKey() + { + AccountId = "MyAccount", + ProviderId = "MSSQL" + }, + IsStale = needsReauthentication + }, + new Dictionary()); + } + private async Task CreateAccountManager(AccountTokenWrapper currentAccount, + IEnumerable subscriptions, bool shouldFail = false) + { + AzureAuthenticationManager azureAuthenticationManager = new AzureAuthenticationManager(); + azureAuthenticationManager.SetServiceProvider(serviceProvider); + if (currentAccount != null) + { + await azureAuthenticationManager.SetCurrentAccountAsync(currentAccount); + } + + if (!shouldFail) + { + resourceManagerMock.Setup(x => x.GetSubscriptionContextsAsync(It.IsAny())).Returns(Task.FromResult(subscriptions)); + } + else + { + resourceManagerMock.Setup(x => x.GetSubscriptionContextsAsync(It.IsAny())).Throws(new Exception()); + } + + return azureAuthenticationManager; + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureResourceWrapperTest.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureResourceWrapperTest.cs new file mode 100644 index 00000000..5b0df2a1 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureResourceWrapperTest.cs @@ -0,0 +1,53 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.Azure.Management.Sql.Models; +using Microsoft.SqlTools.ResourceProvider.DefaultImpl; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ResourceProvider.Azure +{ + public class AzureResourceWrapperTest + { + [Fact] + public void ShouldParseResourceGroupFromId() + { + // Given a resource with a known resource group + TrackedResource trackedResource = CreateMockResource( + "/subscriptions/aaaaaaaa-1234-cccc-dddd-a1234v12c23/resourceGroups/myresourcegroup/providers/Microsoft.Sql/servers/my-server", + "my-server", + "Microsoft.Sql"); + + // When I get the resource group name + AzureResourceWrapper resource = new AzureResourceWrapper(trackedResource); + string rgName = resource.ResourceGroupName; + + // then I get it as expected + Assert.Equal("myresourcegroup", rgName); + } + + [Fact] + public void ShouldHandleMissingResourceGroup() + { + // Given a resource without resource group in the ID + TrackedResource trackedResource = CreateMockResource( + "/subscriptions/aaaaaaaa-1234-cccc-dddd-a1234v12c23", + "my-server", + "Microsoft.Sql"); + + // When I get the resource group name + AzureResourceWrapper resource = new AzureResourceWrapper(trackedResource); + string rgName = resource.ResourceGroupName; + + // then I get string.Empty + Assert.Equal(string.Empty, rgName); + } + + private TrackedResource CreateMockResource(string id = null, string name = null, string type = null) + { + return new TrackedResource("Somewhere", id, name, type); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureSqlServerDiscoveryProviderTest.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureSqlServerDiscoveryProviderTest.cs index 820041b4..83b993dd 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureSqlServerDiscoveryProviderTest.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureSqlServerDiscoveryProviderTest.cs @@ -45,7 +45,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ResourceProvider.Azure } [Fact] - public async Task GetShouldReturnEmptyGivenNotSubscriptionFound() + public async Task GetShouldReturnEmptyGivenNoSubscriptionFound() { Dictionary> subscriptionToDatabaseMap = new Dictionary>(); diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/ResourceProviderServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/ResourceProviderServiceTests.cs index 17082c65..832fd91a 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/ResourceProviderServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/ResourceProviderServiceTests.cs @@ -3,17 +3,30 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using Microsoft.SqlTools.Hosting.Protocol; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; using Microsoft.SqlTools.Extensibility; -using Microsoft.SqlTools.ResourceProvider.Core; -using Moq; +using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.ResourceProvider; +using Microsoft.SqlTools.ResourceProvider.Core; using Microsoft.SqlTools.ResourceProvider.Core.Authentication; +using Microsoft.SqlTools.ResourceProvider.Core.Contracts; +using Microsoft.SqlTools.ResourceProvider.Core.Firewall; +using Microsoft.SqlTools.ResourceProvider.DefaultImpl; +using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility; +using Moq; +using Xunit; namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Formatter { public class ResourceProviderServiceTests { + private const int SqlAzureFirewallBlockedErrorNumber = 40615; + private const int SqlAzureLoginFailedErrorNumber = 18456; + private string errorMessageWithIp = "error Message with 1.2.3.4 as IP address"; + public ResourceProviderServiceTests() { HostMock = new Mock(); @@ -34,7 +47,194 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Formatter protected ResourceProviderService ResourceProviderService { get; private set; } - + [Fact] + public async Task TestHandleFirewallRuleIgnoresNonMssqlProvider() + { + // Given a non-MSSQL provider + var handleFirewallParams = new HandleFirewallRuleParams() + { + ErrorCode = SqlAzureFirewallBlockedErrorNumber, + ErrorMessage = errorMessageWithIp, + ConnectionTypeId = "Other" + }; + // When I ask whether the service can process an error as a firewall rule request + await TestUtils.RunAndVerify((context) => ResourceProviderService.ProcessHandleFirewallRuleRequest(handleFirewallParams, context), (response) => + { + // Then I expect the response to be false and no IP information to be sent + Assert.NotNull(response); + Assert.False(response.Result); + Assert.Null(response.IpAddress); + Assert.Equal(Microsoft.SqlTools.ResourceProvider.Core.SR.FirewallRuleUnsupportedConnectionType, response.ErrorMessage); + }); + } + [Fact] + public async Task TestHandleFirewallRuleSupportsMssqlProvider() + { + // Given a firewall error for the MSSQL provider + var handleFirewallParams = new HandleFirewallRuleParams() + { + ErrorCode = SqlAzureFirewallBlockedErrorNumber, + ErrorMessage = errorMessageWithIp, + ConnectionTypeId = "MSSQL" + }; + // When I ask whether the service can process an error as a firewall rule request + await TestUtils.RunAndVerify((context) => ResourceProviderService.ProcessHandleFirewallRuleRequest(handleFirewallParams, context), (response) => + { + // Then I expect the response to be true and the IP address to be extracted + Assert.NotNull(response); + Assert.True(response.Result); + Assert.Equal("1.2.3.4", response.IpAddress); + Assert.Null(response.ErrorMessage); + }); + } + + [Fact] + public async Task TestHandleFirewallRuleIgnoresNonFirewallErrors() + { + // Given a login error for the MSSQL provider + var handleFirewallParams = new HandleFirewallRuleParams() + { + ErrorCode = SqlAzureLoginFailedErrorNumber, + ErrorMessage = errorMessageWithIp, + ConnectionTypeId = "MSSQL" + }; + // When I ask whether the service can process an error as a firewall rule request + await TestUtils.RunAndVerify((context) => ResourceProviderService.ProcessHandleFirewallRuleRequest(handleFirewallParams, context), (response) => + { + // Then I expect the response to be false and no IP address to be defined + Assert.NotNull(response); + Assert.False(response.Result); + Assert.Equal(string.Empty, response.IpAddress); + Assert.Null(response.ErrorMessage); + }); + } + + [Fact] + public async Task TestHandleFirewallRuleDoesntBreakWithoutIp() + { + // Given a firewall error with no IP address in the error message + var handleFirewallParams = new HandleFirewallRuleParams() + { + ErrorCode = SqlAzureFirewallBlockedErrorNumber, + ErrorMessage = "No IP here", + ConnectionTypeId = "MSSQL" + }; + // When I ask whether the service can process an error as a firewall rule request + await TestUtils.RunAndVerify((context) => ResourceProviderService.ProcessHandleFirewallRuleRequest(handleFirewallParams, context), (response) => + { + // Then I expect the response to be fakse as we require the known IP address to function + Assert.NotNull(response); + Assert.False(response.Result); + Assert.Equal(string.Empty, response.IpAddress); + Assert.Null(response.ErrorMessage); + }); + } + + [Fact] + public async Task TestCreateFirewallRuleBasicRequest() + { + // Given a firewall request for a valid subscription + string serverName = "myserver.database.windows.net"; + var sub1Mock = new Mock(); + var sub2Mock = new Mock(); + var server = new SqlAzureResource(new Azure.Management.Sql.Models.Server("Somewhere", + "1234", "myserver", "SQLServer", + null, null, null, null, null, null, null, + fullyQualifiedDomainName: serverName)); + var subsToServers = new List>>() + { + Tuple.Create(sub1Mock.Object, Enumerable.Empty()), + Tuple.Create(sub2Mock.Object, new IAzureSqlServerResource[] { server }.AsEnumerable()) + }; + var azureRmResponse = new FirewallRuleResponse() + { + Created = true, + StartIpAddress = null, + EndIpAddress = null + }; + SetupDependencies(subsToServers, azureRmResponse); + + // When I request the firewall be created + var createFirewallParams = new CreateFirewallRuleParams() + { + ServerName = serverName, + StartIpAddress = "1.1.1.1", + EndIpAddress = "1.1.1.255", + Account = CreateAccount(), + SecurityTokenMappings = new Dictionary() + }; + await TestUtils.RunAndVerify( + (context) => ResourceProviderService.HandleCreateFirewallRuleRequest(createFirewallParams, context), + (response) => + { + // Then I expect the response to be fakse as we require the known IP address to function + Assert.NotNull(response); + Assert.Null(response.ErrorMessage); + Assert.True(response.Result); + }); + } + + private void SetupDependencies( + IList>> subsToServers, + FirewallRuleResponse response) + { + SetupCreateSession(); + SetupReturnsSubscriptions(subsToServers.Select(s => s.Item1)); + foreach(var s in subsToServers) + { + SetupAzureServers(s.Item1, s.Item2); + } + SetupFirewallResponse(response); + } + + private void SetupReturnsSubscriptions(IEnumerable subs) + { + AuthenticationManagerMock.Setup(a => a.GetSubscriptionsAsync()).Returns(() => Task.FromResult(subs)); + } + + private void SetupCreateSession() + { + ResourceManagerMock.Setup(r => r.CreateSessionAsync(It.IsAny())) + .Returns((IAzureUserAccountSubscriptionContext sub) => + { + var sessionMock = new Mock(); + sessionMock.SetupProperty(s => s.SubscriptionContext, sub); + return Task.FromResult(sessionMock.Object); + }); + } + + private void SetupAzureServers(IAzureSubscriptionContext sub, IEnumerable servers) + { + Func isExpectedSub = (session) => + { + return session.SubscriptionContext == sub; + }; + ResourceManagerMock.Setup(r => r.GetSqlServerAzureResourcesAsync( + It.Is((session) => isExpectedSub(session)) + )).Returns(() => Task.FromResult(servers)); + } + + private void SetupFirewallResponse(FirewallRuleResponse response) + { + ResourceManagerMock.Setup(r => r.CreateFirewallRuleAsync( + It.IsAny(), + It.IsAny(), + It.IsAny()) + ).Returns(() => Task.FromResult(response)); + } + + private Account CreateAccount(bool needsReauthentication = false) + { + return new Account() + { + Key = new AccountKey() + { + AccountId = "MyAccount", + ProviderId = "MSSQL" + }, + IsStale = needsReauthentication + }; + } } } \ No newline at end of file