Introduce AAD interactive auth mode (#1860)

This commit is contained in:
Cheena Malhotra
2023-03-02 09:39:54 -08:00
committed by GitHub
parent 98e50c98fe
commit 187b6ecc14
47 changed files with 918 additions and 151 deletions

View File

@@ -15,6 +15,8 @@
<PackageReference Update="Microsoft.Extensions.DependencyModel" Version="3.1.4" />
<PackageReference Update="Microsoft.Extensions.FileSystemGlobbing" Version="7.0.0" />
<PackageReference Update="Microsoft.Identity.Client" Version="4.49.1" />
<PackageReference Update="Microsoft.Identity.Client.Extensions.Msal" Version="2.25.3" />
<PackageReference Update="Microsoft.Azure.Management.ResourceManager" Version="3.7.1-preview" />
<PackageReference Update="Microsoft.Azure.Management.Sql" Version="1.41.0-preview" />

View File

@@ -1,6 +1,6 @@
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 16
VisualStudioVersion = 16.0.29409.12
# Visual Studio Version 17
VisualStudioVersion = 17.5.33209.295
MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{2BBD7364-054F-4693-97CD-1C395E3E84A9}"
ProjectSection(SolutionItems) = preProject
@@ -96,6 +96,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.SqlTools.Migratio
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.SqlTools.Migration.IntegrationTests", "test\Microsoft.SqlTools.Migration.IntegrationTests\Microsoft.SqlTools.Migration.IntegrationTests.csproj", "{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.SqlTools.Authentication", "src\Microsoft.SqlTools.Authentication\Microsoft.SqlTools.Authentication.csproj", "{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.SqlTools.Shared", "src\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj", "{531EC0E0-F400-42C5-BCFA-6691313B5F3E}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -229,6 +233,18 @@ Global
{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}.Integration|Any CPU.Build.0 = Debug|Any CPU
{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}.Release|Any CPU.Build.0 = Release|Any CPU
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Debug|Any CPU.Build.0 = Debug|Any CPU
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Integration|Any CPU.ActiveCfg = Debug|Any CPU
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Integration|Any CPU.Build.0 = Debug|Any CPU
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Release|Any CPU.ActiveCfg = Release|Any CPU
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Release|Any CPU.Build.0 = Release|Any CPU
{531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Integration|Any CPU.ActiveCfg = Debug|Any CPU
{531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Integration|Any CPU.Build.0 = Debug|Any CPU
{531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -256,6 +272,8 @@ Global
{07296730-DAB7-4B0B-9D09-ABD9A5025D68} = {2BBD7364-054F-4693-97CD-1C395E3E84A9}
{22DB0C12-6848-4503-AD1C-DAD6A1D631AE} = {2BBD7364-054F-4693-97CD-1C395E3E84A9}
{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9} = {AB9CA2B8-6F70-431C-8A1D-67479D8A7BE4}
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC} = {2BBD7364-054F-4693-97CD-1C395E3E84A9}
{531EC0E0-F400-42C5-BCFA-6691313B5F3E} = {2BBD7364-054F-4693-97CD-1C395E3E84A9}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {B31CDF4B-2851-45E5-8C5F-BE97125D9DD8}

View File

@@ -4,6 +4,7 @@
//
using Microsoft.SqlTools.Hosting.Contracts;
using static Microsoft.SqlTools.Shared.Utility.Constants;
namespace Microsoft.Kusto.ServiceLayer.Connection
{
@@ -50,7 +51,7 @@ namespace Microsoft.Kusto.ServiceLayer.Connection
CategoryValues = new CategoryValue[]
{ new CategoryValue { DisplayName = "SQL Login", Name = "SqlLogin" },
new CategoryValue { DisplayName = "Windows Authentication", Name = "Integrated" },
new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = "AzureMFA" }
new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = AzureMFA }
},
IsIdentity = true,
IsRequired = true,

View File

@@ -19,6 +19,7 @@ using System.Diagnostics;
using Microsoft.Kusto.ServiceLayer.DataSource;
using Microsoft.Kusto.ServiceLayer.DataSource.Metadata;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using static Microsoft.SqlTools.Shared.Utility.Constants;
namespace Microsoft.Kusto.ServiceLayer.Connection
{
@@ -916,7 +917,7 @@ namespace Microsoft.Kusto.ServiceLayer.Connection
return new ConnectionDetails
{
ApplicationName = builder.ApplicationNameForTracing,
AuthenticationType = "AzureMFA",
AuthenticationType = AzureMFA,
DatabaseName = builder.InitialCatalog,
ServerName = builder.DataSource,
UserName = builder.UserID,

View File

@@ -20,6 +20,7 @@ using Microsoft.Kusto.ServiceLayer.LanguageServices;
using Microsoft.Kusto.ServiceLayer.Utility;
using Microsoft.Kusto.ServiceLayer.Workspace.Contracts;
using CompletionItem = Microsoft.Kusto.ServiceLayer.LanguageServices.Contracts.CompletionItem;
using static Microsoft.SqlTools.Shared.Utility.Constants;
namespace Microsoft.Kusto.ServiceLayer.DataSource
{
@@ -63,7 +64,7 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource
private DataSourceConnectionDetails MapKustoConnectionDetails(ConnectionDetails connectionDetails)
{
if (connectionDetails.AuthenticationType == "dstsAuth" || connectionDetails.AuthenticationType == "AzureMFA")
if (connectionDetails.AuthenticationType == dstsAuth || connectionDetails.AuthenticationType == AzureMFA)
{
ValidationUtils.IsTrue<ArgumentException>(!string.IsNullOrWhiteSpace(connectionDetails.AccountToken),
$"The Kusto User Token is not specified - set {nameof(connectionDetails.AccountToken)}");

View File

@@ -21,6 +21,7 @@ using Kusto.Language.Editor;
using Microsoft.Kusto.ServiceLayer.Connection;
using Microsoft.Kusto.ServiceLayer.DataSource.Contracts;
using Microsoft.Kusto.ServiceLayer.Utility;
using static Microsoft.SqlTools.Shared.Utility.Constants;
namespace Microsoft.Kusto.ServiceLayer.DataSource.Kusto
{
@@ -74,7 +75,7 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource.Kusto
ServerName = ClusterName,
DatabaseName = DatabaseName,
UserToken = accountToken,
AuthenticationType = "AzureMFA"
AuthenticationType = AzureMFA
};
Initialize(connectionDetails);
@@ -96,8 +97,8 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource.Kusto
switch (connectionDetails.AuthenticationType)
{
case "AzureMFA": return stringBuilder.WithAadUserTokenAuthentication(connectionDetails.UserToken);
case "dstsAuth": return stringBuilder.WithDstsUserTokenAuthentication(connectionDetails.UserToken);
case AzureMFA: return stringBuilder.WithAadUserTokenAuthentication(connectionDetails.UserToken);
case dstsAuth: return stringBuilder.WithDstsUserTokenAuthentication(connectionDetails.UserToken);
default:
return string.IsNullOrWhiteSpace(connectionDetails.UserName) && string.IsNullOrWhiteSpace(connectionDetails.Password)
? stringBuilder

View File

@@ -51,6 +51,7 @@
<ProjectReference Include="../Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj" />
<ProjectReference Include="../Microsoft.SqlTools.Credentials/Microsoft.SqlTools.Credentials.csproj" />
<ProjectReference Include="../Microsoft.SqlTools.ManagedBatchParser/Microsoft.SqlTools.ManagedBatchParser.csproj" />
<ProjectReference Include="..\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj" />
</ItemGroup>
<ItemGroup>
<EmbeddedResource Include="ObjectExplorer\DataSourceModel\TreeNodeDefinition.xml" />

View File

@@ -41,7 +41,7 @@ namespace Microsoft.Kusto.ServiceLayer
logFilePath = Logger.GenerateLogFilePath("kustoservice");
}
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "kustoservice", commandOptions.AutoFlushLog);
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, piiEnabled: commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "kustoservice", commandOptions.AutoFlushLog);
// set up the host details and profile paths
var hostDetails = new HostDetails(version: new Version(1, 0));

View File

@@ -0,0 +1,33 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.Authentication
{
/// <summary>
/// Represents an access token data object.
/// </summary>
public class AccessToken
{
/// <summary>
/// OAuth 2.0 JWT encoded access token string
/// </summary>
public string Token { get; set; }
/// <summary>
/// Expiry date of token
/// </summary>
public DateTimeOffset ExpiresOn { get; set; }
/// <summary>
/// Default constructor for Access Token object
/// </summary>
/// <param name="token">Access token as string</param>
/// <param name="expiresOn">Expiry date</param>
public AccessToken(string token, DateTimeOffset expiresOn) {
this.Token = token;
this.ExpiresOn = expiresOn;
}
}
}

View File

@@ -0,0 +1,15 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.Authentication
{
/// <summary>
/// Supported Active Directory authentication modes
/// </summary>
public enum AuthenticationMethod
{
ActiveDirectoryInteractive
}
}

View File

@@ -0,0 +1,82 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.Authentication
{
/// <summary>
/// Parameters to be passed to <see cref="Authenticator"/> to request an access token
/// </summary>
public class AuthenticationParams
{
/// <summary>
/// Authentication method to be used by <see cref="Authenticator"/>.
/// </summary>
public AuthenticationMethod AuthenticationMethod { get; set; }
/// <summary>
/// Authority URL, e.g. https://login.microsoftonline.com/
/// </summary>
public string Authority { get; set; }
/// <summary>
/// Audience for which access token should be acquired, e.g. common, organizations, consumers.
/// It can also be a tenant Id when authenticating multi-tenant application accounts.
/// </summary>
public string Audience { get; set; }
/// <summary>
/// Resource URL, e.g. https://database.windows.net/
/// </summary>
public string Resource { get; set; }
/// <summary>
/// Array of scopes for which access token is requested.
/// </summary>
public string[] Scopes { get; set; }
/// <summary>
/// <see cref="Guid"/> Connection Id, that will be passed to Azure AD when requesting access token.
/// It can be used for tracking AAD request status if needed.
/// </summary>
public Guid ConnectionId { get; set; }
/// <summary>
/// User name to be provided as userhint when acquiring access token.
/// </summary>
public string UserName { get; set; }
/// <summary>
/// Default constructor
/// </summary>
/// <param name="authMethod">Authentication Method to be used.</param>
/// <param name="authority">Authority URL</param>
/// <param name="audience">Audience</param>
/// <param name="resource">Resource for which token is requested.</param>
/// <param name="scopes">Scopes for access token</param>
/// <param name="userName">User hint information</param>
/// <param name="connectionId">Connection Id for tracing AAD request</param>
public AuthenticationParams(AuthenticationMethod authMethod, string authority, string audience,
string resource, string[] scopes, string userName, Guid connectionId) {
this.AuthenticationMethod = authMethod;
this.Authority = authority;
this.Audience = audience;
this.Resource = resource;
this.Scopes = scopes;
this.UserName = userName;
this.ConnectionId = connectionId;
}
public string ToLogString(bool piiEnabled)
{
return $"\tAuthenticationMethod: {AuthenticationMethod}" +
$"\n\tAuthority: {Authority}" +
$"\n\tAudience: {Audience}" +
$"\n\tResource: {Resource}" +
$"\n\tScopes: {Scopes.ToString()}" +
(piiEnabled ? $"\n\tUsername: {UserName}" : "") +
$"\n\tConnection Id: {ConnectionId}";
}
}
}

View File

@@ -0,0 +1,158 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System.Collections.Concurrent;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.Extensions.Msal;
using Microsoft.SqlTools.Authentication.Utility;
using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger;
namespace Microsoft.SqlTools.Authentication
{
/// <summary>
/// Provides APIs to acquire access token using MSAL.NET v4 with provided <see cref="AuthenticationParams"/>.
/// </summary>
public class Authenticator
{
private string applicationClientId;
private string applicationName;
private string cacheFolderPath;
private string cacheFileName;
private MsalCacheHelper cacheHelper;
private static ConcurrentDictionary<string, IPublicClientApplication> PublicClientAppMap
= new ConcurrentDictionary<string, IPublicClientApplication>();
#region Public APIs
public Authenticator(string appClientId, string appName, string cacheFolderPath, string cacheFileName)
{
this.applicationClientId = appClientId;
this.applicationName = appName;
this.cacheFolderPath = cacheFolderPath;
this.cacheFileName = cacheFileName;
// Storage creation properties are used to enable file system caching with MSAL
var storageCreationProperties = new StorageCreationPropertiesBuilder(this.cacheFileName, this.cacheFolderPath)
.WithUnprotectedFile().Build();
// This hooks up the cross-platform cache into MSAL
this.cacheHelper = MsalCacheHelper.CreateAsync(storageCreationProperties).ConfigureAwait(false).GetAwaiter().GetResult();
}
/// <summary>
/// Acquires access token synchronously.
/// </summary>
/// <param name="params">Authentication parameters to be used for access token request.</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>Access Token with expiry date</returns>
public AccessToken? GetToken(AuthenticationParams @params, CancellationToken cancellationToken) =>
GetTokenAsync(@params, cancellationToken).GetAwaiter().GetResult();
/// <summary>
/// Acquires access token asynchronously.
/// </summary>
/// <param name="params">Authentication parameters to be used for access token request.</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>Access Token with expiry date</returns>
/// <exception cref="Exception"></exception>
public async Task<AccessToken?> GetTokenAsync(AuthenticationParams @params, CancellationToken cancellationToken)
{
SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Received @params: {@params.ToLogString(SqlToolsLogger.IsPiiEnabled)}");
IPublicClientApplication publicClientApplication = GetPublicClientAppInstance(@params.Authority, @params.Audience);
AccessToken? accessToken;
if (@params.AuthenticationMethod == AuthenticationMethod.ActiveDirectoryInteractive)
{
// Find account
IEnumerator<IAccount>? accounts = (await publicClientApplication.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();
IAccount? account = default;
if (!string.IsNullOrEmpty(@params.UserName) && accounts.MoveNext())
{
// Handle username format to extract email: "John Doe - johndoe@constoso.com"
string username = @params.UserName.Contains(" - ") ? @params.UserName.Split(" - ")[1] : @params.UserName;
if (!Utils.isValidEmail(username))
{
SqlToolsLogger.Pii($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Unexpected username format, email not retreived: {@params.UserName}. " +
$"Accepted formats are: 'johndoe@org.com' or 'John Doe - johndoe@org.com'.");
throw new Exception($"Invalid email address format for user: [{username}] received for Azure Active Directory authentication.");
}
do
{
IAccount? currentVal = accounts.Current;
if (string.Compare(username, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
{
account = currentVal;
SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | User account found in MSAL Cache: {account.HomeAccountId}");
break;
}
}
while (accounts.MoveNext());
if (null != account)
{
try
{
// Fetch token silently
var result = await publicClientApplication.AcquireTokenSilent(@params.Scopes, account)
.ExecuteAsync(cancellationToken: cancellationToken)
.ConfigureAwait(false);
accessToken = new AccessToken(result!.AccessToken, result!.ExpiresOn);
}
catch (Exception e)
{
SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Silent authentication failed for resource {@params.Resource} for ConnectionId {@params.ConnectionId}.");
SqlToolsLogger.Error(e);
throw;
}
}
else
{
SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Account not found in MSAL cache for user.");
throw new Exception($"User account '{username}' not found in MSAL cache, please add linked account or refresh account credentials.");
}
}
else
{
SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | User account not received.");
throw new Exception($"User account not received.");
}
}
else
{
SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Authentication Method ${@params.AuthenticationMethod} is not supported.");
throw new Exception($"Authentication Method ${@params.AuthenticationMethod} is not supported.");
}
return accessToken;
}
#endregion
#region Private methods
private IPublicClientApplication GetPublicClientAppInstance(string authority, string audience)
{
string authorityUrl = authority + '/' + audience;
if (!PublicClientAppMap.TryGetValue(authorityUrl, out IPublicClientApplication? clientApplicationInstance))
{
clientApplicationInstance = CreatePublicClientAppInstance(authority, audience);
this.cacheHelper.RegisterCache(clientApplicationInstance.UserTokenCache);
PublicClientAppMap.TryAdd(authorityUrl, clientApplicationInstance);
}
return clientApplicationInstance;
}
private IPublicClientApplication CreatePublicClientAppInstance(string authority, string audience) =>
PublicClientApplicationBuilder.Create(this.applicationClientId)
.WithAuthority(authority, audience)
.WithClientName(this.applicationName)
.WithLogging(Utils.MSALLogCallback)
.WithDefaultRedirectUri()
.Build();
#endregion
}
}

View File

@@ -0,0 +1,18 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Data.SqlClient" />
<PackageReference Include="Microsoft.Identity.Client" />
<PackageReference Include="Microsoft.Identity.Client.Extensions.Msal" />
<PackageReference Include="System.Configuration.ConfigurationManager" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,110 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.Data.SqlClient;
using Microsoft.SqlTools.Authentication.Utility;
namespace Microsoft.SqlTools.Authentication.Sql
{
/// <summary>
/// Provides an implementation of <see cref="SqlAuthenticationProvider"/> for SQL Tools to be able to perform Federated authentication
/// silently with Microsoft.Data.SqlClient integration only for "ActiveDirectory" authentication modes.
/// When registered, the SqlClient driver calls the <see cref="AcquireTokenAsync(SqlAuthenticationParameters)"/> API
/// with server-sent authority information to request access token when needed.
/// </summary>
public class AuthenticationProvider : SqlAuthenticationProvider
{
private const string ApplicationClientId = "a69788c6-1d43-44ed-9ca3-b83e194da255";
private const string AzureTokenFolder = "Azure Accounts";
private const string MsalCacheName = "azureTokenCacheMsal-azure_publicCloud";
private const string s_defaultScopeSuffix = "/.default";
private Authenticator authenticator;
/// <summary>
/// Instantiates AuthenticationProvider to be used for AAD authentication with MSAL.NET and MSAL.js co-ordinated.
/// </summary>
/// <param name="applicationName">Application Name that identifies user folder path location for reading/writing to shared cache.</param>
/// <param name="authCallback">Callback that handles AAD authentication when user interaction is needed.</param>
public AuthenticationProvider(string applicationName)
{
if(string.IsNullOrEmpty(applicationName)) {
applicationName = nameof(SqlTools);
}
var cachePath = Path.Combine(Utils.BuildAppDirectoryPath(), applicationName, AzureTokenFolder);
this.authenticator = new Authenticator(ApplicationClientId, applicationName, cachePath, MsalCacheName);
}
/// <summary>
/// Acquires access token with provided <paramref name="parameters"/>
/// </summary>
/// <param name="parameters">Authentication parameters</param>
/// <returns>Authentication token containing access token and expiry date.</returns>
public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters)
{
// Setup scope
string resource = parameters.Resource;
string scope;
if (parameters.Resource.EndsWith(s_defaultScopeSuffix))
{
scope = parameters.Resource;
resource = parameters.Resource.Substring(0, parameters.Resource.LastIndexOf('/') + 1);
}
else
{
scope = parameters.Resource + s_defaultScopeSuffix;
}
string[] scopes = new string[] { scope };
CancellationTokenSource cts = new CancellationTokenSource();
// Use Connection timeout value to cancel token acquire request after certain period of time.
cts.CancelAfter(parameters.ConnectionTimeout * 1000); // Convert to milliseconds
/* We split audience from Authority URL here. Audience can be one of the following:
* The Azure AD authority audience enumeration
* The tenant ID, which can be:
* - A GUID (the ID of your Azure AD instance), for single-tenant applications
* - A domain name associated with your Azure AD instance (also for single-tenant applications)
* One of these placeholders as a tenant ID in place of the Azure AD authority audience enumeration:
* - `organizations` for a multitenant application
* - `consumers` to sign in users only with their personal accounts
* - `common` to sign in users with their work and school accounts or their personal Microsoft accounts
*
* MSAL will throw a meaningful exception if you specify both the Azure AD authority audience and the tenant ID.
* If you don't specify an audience, your app will target Azure AD and personal Microsoft accounts as an audience. (That is, it will behave as though `common` were specified.)
* More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration
**/
int seperatorIndex = parameters.Authority.LastIndexOf('/');
string authority = parameters.Authority.Remove(seperatorIndex + 1);
string audience = parameters.Authority.Substring(seperatorIndex + 1);
string? userName = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId;
AuthenticationParams @params = new AuthenticationParams(
AuthenticationMethod.ActiveDirectoryInteractive,
authority,
audience,
resource,
scopes,
userName!,
parameters.ConnectionId);
AccessToken? result = await authenticator.GetTokenAsync(@params, cts.Token).ConfigureAwait(false);
return new SqlAuthenticationToken(result!.Token, result!.ExpiresOn);
}
/// <summary>
/// Whether or not provided <paramref name="authenticationMethod"/> is supported.
/// </summary>
/// <param name="authenticationMethod">SQL Authentication method</param>
/// <returns></returns>
public override bool IsSupported(SqlAuthenticationMethod authenticationMethod)
=> authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive;
}
}

View File

@@ -0,0 +1,119 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System.Net.Mail;
using System.Runtime.InteropServices;
using Microsoft.Identity.Client;
using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger;
namespace Microsoft.SqlTools.Authentication.Utility
{
internal sealed class Utils
{
/// <summary>
/// Validates provided <paramref name="userEmail"/> follows email format.
/// </summary>
/// <param name="useremail">Email address</param>
/// <returns>Whether email is in correct format.</returns>
public static bool isValidEmail(string userEmail)
{
try
{
new MailAddress(userEmail);
return true;
}
catch (FormatException)
{
return false;
}
}
/// <summary>
/// Builds directory path based on environment settings.
/// </summary>
/// <returns>Application directory path</returns>
/// <exception cref="Exception">When called on unsupported platform.</exception>
public static string BuildAppDirectoryPath()
{
var homedir = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile);
// Windows
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
var appData = Environment.GetEnvironmentVariable("APPDATA");
var userProfile = Environment.GetEnvironmentVariable("USERPROFILE");
if (appData != null)
{
return appData;
}
else if (userProfile != null)
{
return string.Join(Environment.GetEnvironmentVariable("USERPROFILE"), "AppData", "Roaming");
}
else
{
throw new Exception("Not able to find APPDATA or USERPROFILE");
}
}
// Mac
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
return string.Join(homedir, "Library", "Application Support");
}
// Linux
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
var xdgConfigHome = Environment.GetEnvironmentVariable("XDG_CONFIG_HOME");
if (xdgConfigHome != null)
{
return xdgConfigHome;
}
else
{
return string.Join(homedir, ".config");
}
}
else
{
throw new Exception("Platform not supported");
}
}
/// <summary>
/// Log callback handler used for MSAL Client applications.
/// </summary>
/// <param name="logLevel">Log level</param>
/// <param name="message">Log message</param>
/// <param name="pii">Whether message contains PII information.</param>
public static void MSALLogCallback(LogLevel logLevel, string message, bool pii)
{
switch (logLevel)
{
case LogLevel.Error:
if (pii) SqlToolsLogger.Pii(message);
else SqlToolsLogger.Error(message);
break;
case LogLevel.Warning:
if (pii) SqlToolsLogger.Pii(message);
else SqlToolsLogger.Warning(message);
break;
case LogLevel.Info:
if (pii) SqlToolsLogger.Pii(message);
else SqlToolsLogger.Information(message);
break;
case LogLevel.Verbose:
if (pii) SqlToolsLogger.Pii(message);
else SqlToolsLogger.Verbose(message);
break;
case LogLevel.Always:
if (pii) SqlToolsLogger.Pii(message);
else SqlToolsLogger.Critical(message);
break;
}
}
}
}

View File

@@ -38,7 +38,7 @@ namespace Microsoft.SqlTools.Credentials
logFilePath = Logger.GenerateLogFilePath("credentials");
}
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "credentials", commandOptions.AutoFlushLog);
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, piiEnabled: commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "credentials", commandOptions.AutoFlushLog);
// set up the host details and profile paths
var hostDetails = new HostDetails(

View File

@@ -25,4 +25,7 @@
<EmbeddedResource Include="Localization\*.resx" />
<None Include="Localization\sr.strings" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj" />
</ItemGroup>
</Project>

View File

@@ -28,6 +28,7 @@ namespace Microsoft.SqlTools.Utility
ServiceName = serviceName;
ErrorMessage = string.Empty;
Locale = string.Empty;
ApplicationName = string.Empty;
try
{
@@ -42,12 +43,18 @@ namespace Microsoft.SqlTools.Utility
switch (argName)
{
case "-application-name":
ApplicationName = args[++i];
break;
case "-autoflush-log":
AutoFlushLog = true;
break;
case "-tracing-level":
TracingLevel = args[++i];
break;
case "-pii-logging":
PiiLogging = true;
break;
case "-log-file":
LogFilePath = args[++i];
break;
@@ -66,6 +73,9 @@ namespace Microsoft.SqlTools.Utility
case "-parallel-message-processing":
ParallelMessageProcessing = true;
break;
case "-enable-sql-authentication-provider":
EnableSqlAuthenticationProvider = true;
break;
case "-parent-pid":
string nextArg = args[++i];
if (Int32.TryParse(nextArg, out int parsedInt))
@@ -95,6 +105,11 @@ namespace Microsoft.SqlTools.Utility
}
}
/// <summary>
/// Name of application that is sending command options
/// </summary>
public string ApplicationName { get; private set; }
/// <summary>
/// Contains any error messages during execution
/// </summary>
@@ -137,6 +152,8 @@ namespace Microsoft.SqlTools.Utility
public string TracingLevel { get; private set; }
public bool PiiLogging { get; private set; }
public string LogFilePath { get; private set; }
public bool AutoFlushLog { get; private set; } = false;
@@ -147,6 +164,13 @@ namespace Microsoft.SqlTools.Utility
/// </summary>
public bool ParallelMessageProcessing { get; private set; } = false;
/// <summary>
/// Enables configured 'Sql Authentication Provider' for 'Active Directory Interactive' authentication mode to be used
/// when user chooses 'Azure MFA'. This setting enables MSAL.NET to acquire token with SqlClient integration.
/// Currently this option is disabled by default, it's planned to be enabled by default in future releases.
/// </summary>
public bool EnableSqlAuthenticationProvider { get; private set; } = false;
/// <summary>
/// The ID of the process that started this service. This is used to check when the parent
/// process exits so that the service process can exit at the same time.

View File

@@ -70,7 +70,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
_connectionRetryPolicy.RetryOccurred += RetryConnectionCallback;
_commandRetryPolicy.RetryOccurred += RetryCommandCallback;
if (azureAccountToken != null)
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(_underlyingConnection.ConnectionString);
if (builder.Authentication is SqlAuthenticationMethod.NotSpecified && azureAccountToken is not null)
{
_underlyingConnection.AccessToken = azureAccountToken;
}

View File

@@ -38,7 +38,7 @@ namespace Microsoft.SqlTools.Migration
logFilePath = Logger.GenerateLogFilePath(logFilePath);
}
Logger.Initialize(SourceLevels.Verbose, logFilePath, "Migration", commandOptions.AutoFlushLog);
Logger.Initialize(SourceLevels.Verbose, piiEnabled: commandOptions.PiiLogging, logFilePath, "Migration", commandOptions.AutoFlushLog);
Logger.Verbose("Starting SqlTools Migration Server...");

View File

@@ -41,7 +41,7 @@ namespace Microsoft.SqlTools.ResourceProvider
}
// we need to switch to Information when preparing for public preview
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "resourceprovider", commandOptions.AutoFlushLog);
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "resourceprovider", commandOptions.AutoFlushLog);
Logger.Write(TraceEventType.Information, "Starting SqlTools Resource Provider");
// set up the host details and profile paths

View File

@@ -6,6 +6,7 @@
#nullable disable
using Microsoft.SqlTools.Hosting.Contracts;
using static Microsoft.SqlTools.Shared.Utility.Constants;
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
@@ -52,7 +53,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
CategoryValues = new CategoryValue[]
{ new CategoryValue { DisplayName = "SQL Login", Name = "SqlLogin" },
new CategoryValue { DisplayName = "Windows Authentication", Name = "Integrated" },
new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = "AzureMFA" }
new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = AzureMFA }
},
IsIdentity = true,
IsRequired = true,

View File

@@ -23,7 +23,9 @@ using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.Utility;
using static Microsoft.SqlTools.Shared.Utility.Constants;
using System.Diagnostics;
using Microsoft.SqlTools.Authentication.Sql;
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
@@ -143,17 +145,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
{
RequestSecurityTokenParams message = new RequestSecurityTokenParams()
{
Authority = authority,
Provider = "Azure",
Authority = authority,
Resource = resource,
Scope = scope
Scopes = new string[] { scope }
};
RequestSecurityTokenResponse response = await Instance.ServiceHost.SendRequest(SecurityTokenRequest.Type, message, true);
RequestSecurityTokenResponse response = await Instance.ServiceHost.SendRequest(SecurityTokenRequest.Type, message, true).ConfigureAwait(false);
return response.Token;
}
/// <summary>
/// Enables configured 'Sql Authentication Provider' for 'Active Directory Interactive' authentication mode to be used
/// when user chooses 'Azure MFA'.
/// </summary>
public bool EnableSqlAuthenticationProvider { get; set; }
/// <summary>
/// Returns a connection queue for given type
/// </summary>
@@ -248,7 +256,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
if (this.TryFindConnection(ownerUri, out connInfo))
{
// If not an azure connection, no need to refresh token
if (connInfo.ConnectionDetails.AuthenticationType != "AzureMFA")
if (connInfo.ConnectionDetails.AuthenticationType != AzureMFA)
{
return false;
}
@@ -594,7 +602,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
connectionInfo.IsSqlDb = serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDatabase;
connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse);
// Determines that access token is used for creating connection.
connectionInfo.IsAzureAuth = connectionInfo.ConnectionDetails.AuthenticationType == "AzureMFA";
connectionInfo.IsAzureAuth = connectionInfo.ConnectionDetails.AuthenticationType == AzureMFA;
connectionInfo.EngineEdition = (DatabaseEngineEdition)serverInfo.EngineEditionId;
// Azure Data Studio supports SQL Server 2014 and later releases.
response.IsSupportedVersion = serverInfo.IsCloud || serverInfo.ServerMajorVersion >= 12;
@@ -1061,10 +1069,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
return handler.HandleRequest(this.connectionFactory, info);
}
public void InitializeService(IProtocolEndpoint serviceHost)
public void InitializeService(IProtocolEndpoint serviceHost, ServiceLayerCommandOptions commandOptions)
{
this.ServiceHost = serviceHost;
if (commandOptions != null && commandOptions.EnableSqlAuthenticationProvider)
{
// Register SqlAuthenticationProvider with SqlConnection for AAD Interactive (MFA) authentication.
var provider = new AuthenticationProvider(commandOptions.ApplicationName);
SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryInteractive, provider);
this.EnableSqlAuthenticationProvider = true;
Logger.Information("Registering implementation of SQL Authentication provider for 'Active Directory Interactive' authentication mode.");
}
// Register request and event handlers with the Service Host
serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest, true);
serviceHost.SetRequestHandler(CancelConnectRequest.Type, HandleCancelConnectRequest, true);
@@ -1304,31 +1322,46 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
connectionBuilder = new SqlConnectionStringBuilder
{
["Data Source"] = dataSource,
["User Id"] = connectionDetails.UserName,
["Password"] = connectionDetails.Password
DataSource = dataSource
};
}
// Check for any optional parameters
if (!string.IsNullOrEmpty(connectionDetails.DatabaseName))
{
connectionBuilder["Initial Catalog"] = connectionDetails.DatabaseName;
connectionBuilder.InitialCatalog = connectionDetails.DatabaseName;
}
if (!string.IsNullOrEmpty(connectionDetails.AuthenticationType))
{
switch (connectionDetails.AuthenticationType)
{
case "Integrated":
case Integrated:
connectionBuilder.IntegratedSecurity = true;
break;
case "SqlLogin":
case SqlLogin:
connectionBuilder.UserID = connectionDetails.UserName;
connectionBuilder.Password = connectionDetails.Password;
connectionBuilder.Authentication = SqlAuthenticationMethod.SqlPassword;
break;
case "AzureMFA":
connectionBuilder.UserID = "";
connectionBuilder.Password = "";
case AzureMFA:
if (Instance.EnableSqlAuthenticationProvider)
{
connectionBuilder.UserID = connectionDetails.UserName;
connectionDetails.AuthenticationType = ActiveDirectoryInteractive;
connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive;
}
else
{
connectionBuilder.UserID = "";
}
break;
case "ActiveDirectoryPassword":
case ActiveDirectoryInteractive:
connectionBuilder.UserID = connectionDetails.UserName;
connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive;
break;
case ActiveDirectoryPassword:
connectionBuilder.UserID = connectionDetails.UserName;
connectionBuilder.Password = connectionDetails.Password;
connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryPassword;
break;
default:
@@ -1337,16 +1370,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
if (!string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting))
{
switch (connectionDetails.ColumnEncryptionSetting.ToUpper())
if (Enum.TryParse<SqlConnectionColumnEncryptionSetting>(connectionDetails.ColumnEncryptionSetting, true, out var value))
{
case "ENABLED":
connectionBuilder.ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Enabled;
break;
case "DISABLED":
connectionBuilder.ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Disabled;
break;
default:
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidColumnEncryptionSetting(connectionDetails.ColumnEncryptionSetting));
connectionBuilder.ColumnEncryptionSetting = value;
}
else
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidColumnEncryptionSetting(connectionDetails.ColumnEncryptionSetting));
}
}
if (!string.IsNullOrEmpty(connectionDetails.SecureEnclaves))
@@ -1365,36 +1395,30 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
if (!string.IsNullOrEmpty(connectionDetails.EnclaveAttestationProtocol))
{
if (string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting) || connectionDetails.ColumnEncryptionSetting.ToUpper() == "DISABLED"
if (connectionBuilder.ColumnEncryptionSetting != SqlConnectionColumnEncryptionSetting.Enabled
|| string.IsNullOrEmpty(connectionDetails.SecureEnclaves) || connectionDetails.SecureEnclaves.ToUpper() == "DISABLED")
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAlwaysEncryptedOptionCombination);
}
switch (connectionDetails.EnclaveAttestationProtocol.ToUpper())
if (Enum.TryParse<SqlConnectionAttestationProtocol>(connectionDetails.EnclaveAttestationProtocol, true, out var value))
{
case "AAS":
connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.AAS;
break;
case "HGS":
connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.HGS;
break;
case "NONE":
connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.None;
break;
default:
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(connectionDetails.EnclaveAttestationProtocol));
connectionBuilder.AttestationProtocol = value;
}
else
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(connectionDetails.EnclaveAttestationProtocol));
}
}
if (!string.IsNullOrEmpty(connectionDetails.EnclaveAttestationUrl))
{
if (string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting) || connectionDetails.ColumnEncryptionSetting.ToUpper() == "DISABLED"
if (connectionBuilder.ColumnEncryptionSetting != SqlConnectionColumnEncryptionSetting.Enabled
|| string.IsNullOrEmpty(connectionDetails.SecureEnclaves) || connectionDetails.SecureEnclaves.ToUpper() == "DISABLED")
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAlwaysEncryptedOptionCombination);
}
if(connectionBuilder.AttestationProtocol == SqlConnectionAttestationProtocol.None)
if (connectionBuilder.AttestationProtocol == SqlConnectionAttestationProtocol.None)
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAttestationProtocolNoneWithUrl);
}
@@ -1456,19 +1480,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
if (!string.IsNullOrEmpty(connectionDetails.ApplicationIntent))
{
ApplicationIntent intent;
switch (connectionDetails.ApplicationIntent)
if (Enum.TryParse<ApplicationIntent>(connectionDetails.ApplicationIntent, true, out ApplicationIntent value))
{
case "ReadOnly":
intent = ApplicationIntent.ReadOnly;
break;
case "ReadWrite":
intent = ApplicationIntent.ReadWrite;
break;
default:
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent));
connectionBuilder.ApplicationIntent = value;
}
else
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent));
}
connectionBuilder.ApplicationIntent = intent;
}
if (!string.IsNullOrEmpty(connectionDetails.CurrentLanguage))
{
@@ -1585,7 +1604,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
ApplicationIntent = builder.ApplicationIntent.ToString(),
ApplicationName = builder.ApplicationName,
AttachDbFilename = builder.AttachDBFilename,
AuthenticationType = builder.IntegratedSecurity ? "Integrated" : "SqlLogin",
AuthenticationType = builder.IntegratedSecurity ? "Integrated" :
(builder.Authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive
? "ActiveDirectoryInteractive" : "SqlLogin"),
ConnectRetryCount = builder.ConnectRetryCount,
ConnectRetryInterval = builder.ConnectRetryInterval,
ConnectTimeout = builder.ConnectTimeout,
@@ -1793,7 +1814,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
SqlConnection sqlConn = new SqlConnection(connectionString);
// Fill in Azure authentication token if needed
if (connInfo.ConnectionDetails.AzureAccountToken != null)
if (connInfo.ConnectionDetails.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA)
{
sqlConn.AccessToken = connInfo.ConnectionDetails.AzureAccountToken;
}
@@ -1824,7 +1845,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
{
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo, featureName);
ServerConnection serverConnection;
if (connInfo.ConnectionDetails.AzureAccountToken != null)
if (connInfo.ConnectionDetails.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA)
{
serverConnection = new ServerConnection(sqlConnection, new AzureAccessToken(connInfo.ConnectionDetails.AzureAccountToken));
}

View File

@@ -11,25 +11,25 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
{
class RequestSecurityTokenParams
{
/// <summary>
/// Gets or sets the address of the authority to issue token.
/// </summary>
public string Authority { get; set; }
/// <summary>
/// Gets or sets the provider that indicates the type of linked account to query.
/// </summary>
public string Provider { get; set; }
/// <summary>
/// Gets or sets the authority URL from where token is requested.
/// </summary>
public string Authority { get; set; }
/// <summary>
/// Gets or sets the identifier of the target resource that is the recipient of the requested token.
/// </summary>
public string Resource { get; set; }
/// <summary>
/// Gets or sets the scope of the authentication request.
/// Gets or sets the scope array of the authentication request.
/// </summary>
public string Scope { get; set; }
public string [] Scopes { get; set; }
}
class RequestSecurityTokenResponse

View File

@@ -8,6 +8,7 @@ using Microsoft.SqlServer.Dac;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.TaskServices;
using Microsoft.SqlTools.ServiceLayer.Utility;
using static Microsoft.SqlTools.Shared.Utility.Constants;
using Microsoft.SqlTools.Utility;
using System;
using System.Diagnostics;
@@ -86,7 +87,9 @@ namespace Microsoft.SqlTools.ServiceLayer.DacFx
try
{
// Pass in Azure authentication token if needed
this.DacServices = this.ConnInfo.ConnectionDetails.AzureAccountToken != null ? new DacServices(this.ConnectionString, new AccessTokenProvider(this.ConnInfo.ConnectionDetails.AzureAccountToken)) : new DacServices(this.ConnectionString);
this.DacServices = this.ConnInfo.ConnectionDetails.AzureAccountToken != null && this.ConnInfo.ConnectionDetails.AuthenticationType == AzureMFA
? new DacServices(this.ConnectionString, new AccessTokenProvider(this.ConnInfo.ConnectionDetails.AzureAccountToken))
: new DacServices(this.ConnectionString);
Execute();
}
catch (Exception e)

View File

@@ -40,6 +40,7 @@ using Microsoft.SqlTools.ServiceLayer.SqlAssessment;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.SqlProjects;
using Microsoft.SqlTools.ServiceLayer.TableDesigner;
using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.ServiceLayer.Workspace;
namespace Microsoft.SqlTools.ServiceLayer
@@ -53,7 +54,7 @@ namespace Microsoft.SqlTools.ServiceLayer
private static object lockObject = new object();
private static bool isLoaded;
internal static ServiceHost CreateAndStartServiceHost(SqlToolsContext sqlToolsContext, Stream? inputStream = null, Stream? outputStream = null)
internal static ServiceHost CreateAndStartServiceHost(SqlToolsContext sqlToolsContext, ServiceLayerCommandOptions? commandOptions, Stream? inputStream = null, Stream? outputStream = null)
{
ServiceHost serviceHost = ServiceHost.Instance;
lock (lockObject)
@@ -63,7 +64,7 @@ namespace Microsoft.SqlTools.ServiceLayer
// Grab the instance of the service host
serviceHost.Initialize(inputStream, outputStream);
InitializeRequestHandlersAndServices(serviceHost, sqlToolsContext);
InitializeRequestHandlersAndServices(serviceHost, sqlToolsContext, commandOptions);
// Start the service only after all request handlers are setup. This is vital
// as otherwise the Initialize event can be lost - it's processed and discarded before the handler
@@ -75,7 +76,7 @@ namespace Microsoft.SqlTools.ServiceLayer
return serviceHost;
}
private static void InitializeRequestHandlersAndServices(ServiceHost serviceHost, SqlToolsContext sqlToolsContext)
private static void InitializeRequestHandlersAndServices(ServiceHost serviceHost, SqlToolsContext sqlToolsContext, ServiceLayerCommandOptions? commandOptions)
{
// Load extension provider, which currently finds all exports in current DLL. Can be changed to find based
// on directory or assembly list quite easily in the future
@@ -96,7 +97,7 @@ namespace Microsoft.SqlTools.ServiceLayer
LanguageService.Instance.InitializeService(serviceHost, sqlToolsContext);
serviceProvider.RegisterSingleService(LanguageService.Instance);
ConnectionService.Instance.InitializeService(serviceHost);
ConnectionService.Instance.InitializeService(serviceHost, commandOptions);
serviceProvider.RegisterSingleService(ConnectionService.Instance);
CredentialService.Instance.InitializeService(serviceHost);

View File

@@ -57,6 +57,8 @@
<ProjectReference Include="../Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj" />
<ProjectReference Include="../Microsoft.SqlTools.Credentials/Microsoft.SqlTools.Credentials.csproj" />
<ProjectReference Include="../Microsoft.SqlTools.ManagedBatchParser/Microsoft.SqlTools.ManagedBatchParser.csproj" />
<ProjectReference Include="..\Microsoft.SqlTools.Authentication\Microsoft.SqlTools.Authentication.csproj" />
<ProjectReference Include="..\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj" />
</ItemGroup>
<ItemGroup>
<Content Include="..\..\Notice.txt">

View File

@@ -400,7 +400,9 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer
var builder = ConnectionService.CreateConnectionStringBuilder(session.ConnectionInfo.ConnectionDetails);
builder.InitialCatalog = node.NodeValue;
builder.ApplicationName = TableDesignerService.TableDesignerApplicationName;
var azureToken = session.ConnectionInfo.ConnectionDetails.AzureAccountToken;
// Set Access Token only when authentication mode is not specified.
var azureToken = builder.Authentication == SqlAuthenticationMethod.NotSpecified
? session.ConnectionInfo.ConnectionDetails.AzureAccountToken : null;
TableDesignerCacheManager.StartDatabaseModelInitialization(builder.ToString(), azureToken);
}
catch (Exception ex)

View File

@@ -41,7 +41,16 @@ namespace Microsoft.SqlTools.ServiceLayer
logFilePath = Logger.GenerateLogFilePath("sqltools");
}
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "sqltools", commandOptions.AutoFlushLog);
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "sqltools", commandOptions.AutoFlushLog);
// Register PII Logging configuration change callback
Workspace.WorkspaceService<SqlToolsSettings>.Instance.RegisterConfigChangeCallback((newSettings, oldSettings, context) =>
{
Logger.IsPiiEnabled = newSettings?.MssqlTools?.PiiLogging ?? false;
Logger.Information(Logger.IsPiiEnabled ? "PII Logging enabled" : "PII Logging disabled");
return Task.FromResult(true);
});
// Only enable SQL Client logging when verbose or higher to avoid extra overhead when the
// detailed logging it provides isn't needed
if (Logger.TracingLevel.HasFlag(SourceLevels.Verbose))
@@ -53,7 +62,7 @@ namespace Microsoft.SqlTools.ServiceLayer
var hostDetails = new HostDetails(version: new Version(1, 0));
SqlToolsContext sqlToolsContext = new SqlToolsContext(hostDetails);
ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext);
ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext, commandOptions);
serviceHost.MessageDispatcher.ParallelMessageProcessing = commandOptions.ParallelMessageProcessing;
// If this service was started by another process, then it should shutdown when that parent process does.

View File

@@ -17,6 +17,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.DacFx.Contracts;
using Microsoft.SqlTools.ServiceLayer.SchemaCompare.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility;
using static Microsoft.SqlTools.Shared.Utility.Constants;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.SchemaCompare
@@ -187,7 +188,9 @@ namespace Microsoft.SqlTools.ServiceLayer.SchemaCompare
case SchemaCompareEndpointType.Database:
{
string connectionString = GetConnectionString(connInfo, endpointInfo.DatabaseName);
return connInfo.ConnectionDetails?.AzureAccountToken != null
// Set Access Token only when authentication mode is not specified.
return connInfo.ConnectionDetails?.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA
? new SchemaCompareDatabaseEndpoint(connectionString, new AccessTokenProvider(connInfo.ConnectionDetails.AzureAccountToken))
: new SchemaCompareDatabaseEndpoint(connectionString);
}

View File

@@ -17,6 +17,7 @@ using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
using Microsoft.SqlTools.Utility;
using Microsoft.SqlTools.ServiceLayer.Utility;
using static Microsoft.SqlTools.Shared.Utility.Constants;
using System.Linq;
namespace Microsoft.SqlTools.ServiceLayer.Scripting
@@ -109,7 +110,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
if (connInfo != null)
{
parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails);
accessToken = connInfo.ConnectionDetails.AzureAccountToken;
// Set Access Token only when authentication type is AzureMFA.
if (connInfo.ConnectionDetails.AuthenticationType == AzureMFA)
{
accessToken = connInfo.ConnectionDetails.AzureAccountToken;
}
}
else
{

View File

@@ -18,7 +18,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
/// group such as Intellisense is defined on a serialized setting it's used in the order of mssql, then sql, then
/// falls back to a default value.
/// </summary>
public class CompoundToolsSettingsValues: ISqlToolsSettingsValues
public class CompoundToolsSettingsValues : ISqlToolsSettingsValues
{
private List<ISqlToolsSettingsValues> priorityList = new List<ISqlToolsSettingsValues>();
private SqlToolsSettingsValues defaultValues;
@@ -44,11 +44,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
/// Gets or sets the detailed IntelliSense settings
/// </summary>
public IntelliSenseSettings IntelliSense
{
{
get
{
return GetSettingOrDefault((settings) => settings.IntelliSense);
}
}
set
{
priorityList[0].IntelliSense = value;
@@ -59,11 +59,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
/// Gets or sets the query execution settings
/// </summary>
public QueryExecutionSettings QueryExecutionSettings
{
{
get
{
return GetSettingOrDefault((settings) => settings.QueryExecutionSettings);
}
}
set
{
priorityList[0].QueryExecutionSettings = value;
@@ -74,11 +74,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
/// Gets or sets the formatter settings
/// </summary>
public FormatterSettings Format
{
{
get
{
return GetSettingOrDefault((settings) => settings.Format);
}
}
set
{
priorityList[0].Format = value;
@@ -89,15 +89,24 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
/// Gets or sets the object explorer settings
/// </summary>
public ObjectExplorerSettings ObjectExplorer
{
{
get
{
return GetSettingOrDefault((settings) => settings.ObjectExplorer);
}
}
set
{
priorityList[0].ObjectExplorer = value;
}
}
/// <summary>
/// Gets or sets PII Logging setting.
/// </summary>
public bool PiiLogging
{
get => GetSettingOrDefault((settings) => settings.PiiLogging);
set => priorityList[0].PiiLogging = value;
}
}
}

View File

@@ -5,6 +5,8 @@
#nullable disable
using System;
namespace Microsoft.SqlTools.ServiceLayer.SqlContext
{
/// <summary>
@@ -31,5 +33,10 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
/// Object Explorer specific settings
/// </summary>
ObjectExplorerSettings ObjectExplorer { get; set; }
/// <summary>
/// PII Logging setting
/// </summary>
Boolean PiiLogging { get; set; }
}
}

View File

@@ -5,6 +5,7 @@
#nullable disable
using System;
using Newtonsoft.Json;
namespace Microsoft.SqlTools.ServiceLayer.SqlContext
@@ -25,6 +26,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
QueryExecutionSettings = new QueryExecutionSettings();
Format = new FormatterSettings();
TableDesigner = new TableDesignerSettings();
PiiLogging = false;
}
}
@@ -57,5 +59,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
/// </summary>
[JsonProperty("tableDesigner")]
public TableDesignerSettings TableDesigner { get; set; }
/// <summary>
/// Gets or sets the setting to enable PII Logging.
/// </summary>
[JsonProperty("piiLogging")]
public Boolean PiiLogging { get; set; }
}
}

View File

@@ -1800,7 +1800,11 @@ namespace Microsoft.SqlTools.ServiceLayer.TableDesigner
connectionStringBuilder.InitialCatalog = tableInfo.Database;
connectionStringBuilder.ApplicationName = TableDesignerService.TableDesignerApplicationName;
var connectionString = connectionStringBuilder.ToString();
tableDesigner = new Dac.TableDesigner(connectionString, tableInfo.AccessToken, tableInfo.Schema, tableInfo.Name, tableInfo.IsNewTable);
// Set Access Token only when authentication mode is not specified.
var accessToken = connectionStringBuilder.Authentication == SqlAuthenticationMethod.NotSpecified
? tableInfo.AccessToken : null;
tableDesigner = new Dac.TableDesigner(connectionString, accessToken, tableInfo.Schema, tableInfo.Name, tableInfo.IsNewTable);
}
else
{

View File

@@ -12,7 +12,7 @@ using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Utility
{
class ServiceLayerCommandOptions : CommandOptions
public class ServiceLayerCommandOptions : CommandOptions
{
internal const string ServiceLayerServiceName = "MicrosoftSqlToolsServiceLayer.exe";

View File

@@ -0,0 +1,8 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
</Project>

View File

@@ -0,0 +1,18 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.Shared.Utility
{
public class Constants
{
// Authentication Types
public const string Integrated = "Integrated";
public const string SqlLogin = "SqlLogin";
public const string AzureMFA = "AzureMFA";
public const string dstsAuth = "dstsAuth";
public const string ActiveDirectoryInteractive = "ActiveDirectoryInteractive";
public const string ActiveDirectoryPassword = "ActiveDirectoryPassword";
}
}

View File

@@ -3,11 +3,8 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Threading;
namespace Microsoft.SqlTools.Utility
{
@@ -29,41 +26,45 @@ namespace Microsoft.SqlTools.Utility
/// </summary>
public static class Logger
{
internal const SourceLevels defaultTracingLevel = SourceLevels.Critical;
internal const string defaultTraceSource = "sqltools";
public const SourceLevels defaultTracingLevel = SourceLevels.Critical;
public const string defaultTraceSource = "sqltools";
private static SourceLevels tracingLevel = defaultTracingLevel;
private static string logFileFullPath;
private static string? logFileFullPath;
public static TraceSource? TraceSource { get; set; }
internal static TraceSource TraceSource { get; set; }
public static string LogFileFullPath
{
get => logFileFullPath;
get => logFileFullPath!;
private set
{
//If the log file path has a directory component then ensure that the directory exists.
if (!string.IsNullOrEmpty(Path.GetDirectoryName(value)) && !Directory.Exists(Path.GetDirectoryName(value)))
if (value != null)
{
Directory.CreateDirectory(Path.GetDirectoryName(value));
}
//If the log file path has a directory component then ensure that the directory exists.
if (!string.IsNullOrEmpty(Path.GetDirectoryName(value)) && !Directory.Exists(Path.GetDirectoryName(value)))
{
Directory.CreateDirectory(Path.GetDirectoryName(value)!);
}
logFileFullPath = value;
logFileFullPath = value;
}
ConfigureListener();
}
}
private static SqlToolsTraceListener Listener { get; set; }
private static SqlToolsTraceListener? Listener { get; set; }
private static void ConfigureLogFile(string logFilePrefix) => LogFileFullPath = GenerateLogFilePath(logFilePrefix);
/// <summary>
/// Calling this method will turn on inclusion CallStack in the log for all future traces
/// </summary>
public static void StartCallStack() => Listener.TraceOutputOptions |= TraceOptions.Callstack;
public static void StartCallStack() => Listener!.TraceOutputOptions |= TraceOptions.Callstack;
/// <summary>
/// Calling this method will turn off inclusion of CallStack in the log for all future traces
/// </summary>
public static void StopCallStack() => Listener.TraceOutputOptions &= ~TraceOptions.Callstack;
public static void StopCallStack() => Listener!.TraceOutputOptions &= ~TraceOptions.Callstack;
/// <summary>
/// Calls flush on defaultTracingLevel configured listeners.
@@ -86,17 +87,19 @@ namespace Microsoft.SqlTools.Utility
get => tracingLevel;
set
{
if(TraceSource != null)
if (TraceSource != null)
{
// configures the source level filter. This alone is not enough for tracing that is done via "Trace" class instead of "TraceSource" object
TraceSource.Switch = new SourceSwitch(TraceSource.Name, value.ToString());
}
// configure the listener level filter
tracingLevel = value;
Listener.Filter = new EventTypeFilter(tracingLevel);
Listener!.Filter = new EventTypeFilter(tracingLevel);
}
}
public static bool IsPiiEnabled { get; set; } = false;
public static bool AutoFlush { get; set; } = false;
/// <summary>
@@ -115,11 +118,13 @@ namespace Microsoft.SqlTools.Utility
/// </param>
public static void Initialize(
SourceLevels tracingLevel = defaultTracingLevel,
string logFilePath = null,
bool piiEnabled = false,
string? logFilePath = null,
string traceSource = defaultTraceSource,
bool autoFlush = false)
{
Logger.tracingLevel = tracingLevel;
Logger.IsPiiEnabled = piiEnabled;
Logger.AutoFlush = autoFlush;
TraceSource = new TraceSource(traceSource, Logger.tracingLevel);
if (string.IsNullOrWhiteSpace(logFilePath))
@@ -146,11 +151,12 @@ namespace Microsoft.SqlTools.Utility
/// <param name="autoFlush">
/// Optional. Specifies whether the log is flushed after every message
/// </param>
public static void Initialize(string tracingLevel, string logFilePath = null, string traceSource = defaultTraceSource, bool autoFlush = false)
public static void Initialize(string tracingLevel, bool piiEnabled, string? logFilePath = null, string traceSource = defaultTraceSource, bool autoFlush = false)
{
Initialize(Enum.TryParse<SourceLevels>(tracingLevel, out SourceLevels sourceTracingLevel)
? sourceTracingLevel
: defaultTracingLevel
, piiEnabled
, logFilePath
, traceSource
, autoFlush);
@@ -169,7 +175,7 @@ namespace Microsoft.SqlTools.Utility
throw new ArgumentOutOfRangeException(nameof(logFilePrefix), $"LogfilePath cannot be configured if argument {nameof(logFilePrefix)} has not been set");
}
// Create the log directory
string logDir = Path.GetDirectoryName(logFilePrefix);
string? logDir = Path.GetDirectoryName(logFilePrefix);
if (!string.IsNullOrWhiteSpace(logDir))
{
if (!Directory.Exists(logDir))
@@ -199,7 +205,7 @@ namespace Microsoft.SqlTools.Utility
}
string fileName;
try
try
{
var now = DateTime.Now;
fileName = string.Format(CultureInfo.InvariantCulture,
@@ -239,6 +245,16 @@ namespace Microsoft.SqlTools.Utility
/// <param name="logMessage">The message text to be written.</param>
public static void Write(TraceEventType eventType, string logMessage) => Write(eventType, LogEvent.Default, logMessage);
/// <summary>
/// Writes a PII message to the log file with the Verbose event level when PII flag is enabled.
/// </summary>
/// <param name="logMessage">The message text to be written.</param>
public static void Pii(string logMessage) {
if (IsPiiEnabled) {
Write(TraceEventType.Verbose, logMessage);
}
}
/// <summary>
/// Writes a message to the log file with the Verbose event level
/// </summary>
@@ -386,9 +402,9 @@ namespace Microsoft.SqlTools.Utility
);
}
#region forward actual write/close/flush/dispose calls to the underlying listener.
public override void Write(string message) => Listener.Write(message);
public override void Write(string? message) => Listener.Write(message);
public override void WriteLine(string message) => Listener.WriteLine(message);
public override void WriteLine(string? message) => Listener.WriteLine(message);
/// <Summary>
/// Closes the <see cref="System.Diagnostics.TextWriterTraceListener.Writer"> so that it no longer
@@ -430,41 +446,41 @@ namespace Microsoft.SqlTools.Utility
}
#endregion
public override void TraceEvent(TraceEventCache eventCache, String source, TraceEventType eventType, int id)
public override void TraceEvent(TraceEventCache? eventCache, String source, TraceEventType eventType, int id)
{
TraceEvent(eventCache, source, eventType, id, String.Empty);
}
// All other TraceEvent methods come through this one.
public override void TraceEvent(TraceEventCache eventCache, String source, TraceEventType eventType, int id, string message)
public override void TraceEvent(TraceEventCache? eventCache, String source, TraceEventType eventType, int id, string? message)
{
if (Filter != null && !Filter.ShouldTrace(eventCache, source, eventType, id, message, null, null, null))
{
return;
}
WriteHeader(eventCache, source, eventType, id);
WriteLine(message);
WriteFooter(eventCache);
WriteHeader(eventCache!, source, eventType, id);
WriteLine(message!);
WriteFooter(eventCache!);
}
public override void TraceEvent(TraceEventCache eventCache, String source, TraceEventType eventType, int id, string format, params object[] args)
public override void TraceEvent(TraceEventCache? eventCache, String source, TraceEventType eventType, int id, string? format, params object?[]? args)
{
if (Filter != null && !Filter.ShouldTrace(eventCache, source, eventType, id, format, args, null, null))
{
return;
}
WriteHeader(eventCache, source, eventType, id);
WriteHeader(eventCache!, source, eventType, id);
if (args != null)
{
WriteLine(String.Format(CultureInfo.InvariantCulture, format, args));
WriteLine(String.Format(CultureInfo.InvariantCulture, format!, args));
}
else
{
WriteLine(format);
WriteLine(format!);
}
WriteFooter(eventCache);
WriteFooter(eventCache!);
}
private void WriteHeader(TraceEventCache eventCache, String source, TraceEventType eventType, int id)

View File

@@ -12,14 +12,15 @@ using Microsoft.Kusto.ServiceLayer.DataSource;
using Microsoft.Kusto.ServiceLayer.DataSource.Intellisense;
using Microsoft.Kusto.ServiceLayer.LanguageServices;
using Microsoft.Kusto.ServiceLayer.Workspace.Contracts;
using static Microsoft.SqlTools.Shared.Utility.Constants;
using NUnit.Framework;
namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
{
public class DataSourceFactoryTests
{
[TestCase(typeof(ArgumentException), "ConnectionString", "", "AzureMFA")]
[TestCase(typeof(ArgumentException), "ConnectionString", "", "dstsAuth")]
[TestCase(typeof(ArgumentException), "ConnectionString", "", AzureMFA)]
[TestCase(typeof(ArgumentException), "ConnectionString", "", dstsAuth)]
public void Create_Throws_Exceptions_For_InvalidAzureAccountToken(Type exceptionType, string connectionString, string azureAccountToken, string authType)
{
Program.ServiceName = "Kusto";

View File

@@ -8,14 +8,15 @@
using System;
using Microsoft.Kusto.ServiceLayer.DataSource.Contracts;
using Microsoft.Kusto.ServiceLayer.DataSource.Kusto;
using static Microsoft.SqlTools.Shared.Utility.Constants;
using NUnit.Framework;
namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
{
public class KustoClientTests
{
[TestCase("dstsAuth")]
[TestCase("AzureMFA")]
[TestCase(dstsAuth)]
[TestCase(AzureMFA)]
public void Constructor_Throws_ArgumentException_For_MissingToken(string authType)
{
var connectionDetails = new DataSourceConnectionDetails
@@ -37,7 +38,7 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
UserToken = "UserToken",
ServerName = clusterName,
DatabaseName = "",
AuthenticationType = "AzureMFA"
AuthenticationType = AzureMFA
};
var client = new KustoClient(connectionDetails, "ownerUri");
@@ -46,10 +47,10 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
Assert.AreEqual("NetDefaultDB", client.DatabaseName);
}
[TestCase("dstsAuth")]
[TestCase("AzureMFA")]
[TestCase(dstsAuth)]
[TestCase(AzureMFA)]
[TestCase("NoAuth")]
[TestCase("SqlLogin")]
[TestCase(SqlLogin)]
public void Constructor_Creates_Client_With_Valid_AuthenticationType(string authenticationType)
{
string clusterName = "https://fake.url.com";
@@ -59,8 +60,8 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
ServerName = clusterName,
DatabaseName = "FakeDatabaseName",
AuthenticationType = authenticationType,
UserName = authenticationType == "SqlLogin" ? "username": null,
Password = authenticationType == "SqlLogin" ? "password": null
UserName = authenticationType == SqlLogin ? "username": null,
Password = authenticationType == SqlLogin ? "password": null
};
var client = new KustoClient(connectionDetails, "ownerUri");

View File

@@ -7,8 +7,8 @@
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Moq" />
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="Moq" />
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="nunit" />
<PackageReference Include="nunit3testadapter" />
<PackageReference Include="nunit.console" />
@@ -25,5 +25,6 @@
<ProjectReference Include="../../src/Microsoft.SqlTools.Credentials/Microsoft.SqlTools.Credentials.csproj" />
<ProjectReference Include="../../src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj" />
<ProjectReference Include="../../test/Microsoft.SqlTools.ServiceLayer.TestDriver/Microsoft.SqlTools.ServiceLayer.TestDriver.csproj" />
<ProjectReference Include="..\..\src\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj" />
</ItemGroup>
</Project>

View File

@@ -32,6 +32,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common
public SourceLevels TracingLevel { get; set; } = SourceLevels.Critical;
public bool DoNotUseTraceSource { get; set; } = false;
public bool IsPiiEnabled { get; set; } = false;
public bool AutoFlush { get; set; } = false;
private List<Action> pendingVerifications;
@@ -43,7 +45,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common
public string LogFileName { get => logFileName ?? Logger.LogFileFullPath; set => logFileName = value; }
public void Initialize() =>
Logger.Initialize(TracingLevel, LogFilePath, TraceSource, AutoFlush); // initialize the logger
Logger.Initialize(TracingLevel, IsPiiEnabled, LogFilePath, TraceSource, AutoFlush); // initialize the logger
public string LogContents
{
get

View File

@@ -195,7 +195,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common
// Initialize the ServiceHost, using a MemoryStream for the output stream so that we don't fill up the logs
// with a bunch of outgoing messages (which aren't used for anything during tests)
ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext, null, new MemoryStream());
ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext, null, null, new MemoryStream());
// Set up our logger to write to Console for tests to help debug issues
Logger.Initialize(autoFlush: true);

View File

@@ -12,6 +12,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Test.Common;
using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility;
using static Microsoft.SqlTools.Shared.Utility.Constants;
using Moq;
using Moq.Protected;
using NUnit.Framework;
@@ -479,6 +480,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
new object[] {null, "12345678"},
new object[] {"", "12345678"},
};
/// <summary>
/// Verify that when using integrated authentication, the username and/or password can be empty.
/// </summary>
@@ -497,13 +499,67 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
DatabaseName = "test",
UserName = userName,
Password = password,
AuthenticationType = "Integrated"
AuthenticationType = Integrated
}
});
Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful");
}
/// <summary>
/// Verify that username is required when using Active Directory Interactive authentication.
/// Both AzureMFA and ActiveDirectoryInteractive should work same way, when SqlAuthenticationProvider is enabled.
/// </summary>
[TestCase(null, AzureMFA)]
[TestCase(null, ActiveDirectoryInteractive)]
public async Task ConnectingWitNoUsernameFailsForAADInteractiveAuth(string userName, string authMode)
{
// This is an exception scenario test, therefore using ConnectionService instance instead of TestConnectionService directly.
ConnectionService.Instance.EnableSqlAuthenticationProvider = true;
// Connect
var connectionResult = await
TestObjects.GetTestConnectionService()
.Connect(new ConnectParams()
{
OwnerUri = "file:///my/test/file.sql",
Connection = new ConnectionDetails()
{
ServerName = "my-server",
DatabaseName = "test",
UserName = userName,
AuthenticationType = authMode
}
});
Assert.That(connectionResult.ConnectionId, Is.Null.Or.Empty, "Connection should not be successful");
}
/// <summary>
/// Verify that password is ignored when using Active Directory Interactive authentication.
/// </summary>
[TestCase("user", "anything", AzureMFA)]
[TestCase("user", "anything", ActiveDirectoryInteractive)]
public async Task ConnectingWitPasswordIsIgnoredForAADInteractiveAuth(string username, string password, string authMode)
{
// Connect
var connectionResult = await
TestObjects.GetTestConnectionService()
.Connect(new ConnectParams()
{
OwnerUri = "file:///my/test/file.sql",
Connection = new ConnectionDetails()
{
ServerName = "my-server",
DatabaseName = "test",
UserName = username,
Password = password,
AuthenticationType = authMode
}
});
Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "Connection should be successful");
}
/// <summary>
/// Verify that when connecting with a null parameters object, an error is thrown.
/// </summary>
@@ -1773,16 +1829,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
details.AzureAccountToken = azureAccountToken;
details.UserName = "";
details.Password = "";
details.AuthenticationType = "AzureMFA";
details.AuthenticationType = AzureMFA;
// If I open a connection using connection details that include an account token
// Open a connection using connection details that include an account token
await connectionService.Connect(new ConnectParams
{
OwnerUri = "testURI",
Connection = details
});
// Then the connection factory got called with details including an account token
// Validate that the connection factory gets called with details NOT including an account token
mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.Is<string>(accountToken => accountToken == azureAccountToken)), Times.Once());
}

View File

@@ -34,6 +34,7 @@
<ProjectReference Include="../../src/Microsoft.SqlTools.ResourceProvider.Core/Microsoft.SqlTools.ResourceProvider.Core.csproj" />
<ProjectReference Include="../../src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Microsoft.SqlTools.ResourceProvider.DefaultImpl.csproj" />
<ProjectReference Include="../../test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj" />
<ProjectReference Include="..\..\src\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj" />
</ItemGroup>
<ItemGroup>
<Service Include="{82a7f48d-3b50-4b1e-b82e-3ada8210c358}" />

View File

@@ -109,7 +109,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ServiceHost
string tracingLevel = null;
SourceLevels expectedTracingLevel = Logger.defaultTracingLevel;
string expectedTraceSource = Logger.defaultTraceSource;
Logger.Initialize(tracingLevel);
Logger.Initialize(tracingLevel, false);
bool isLogFileExpectedToExist = false;
TestLogger.VerifyInitialization(expectedTracingLevel, expectedTraceSource, Logger.LogFileFullPath, isLogFileExpectedToExist, testNo++);
TestLogger.Cleanup(Logger.LogFileFullPath);
@@ -120,7 +120,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ServiceHost
string tracingLevel = null;
SourceLevels expectedTracingLevel = Logger.defaultTracingLevel;
string expectedTraceSource = Logger.defaultTraceSource;
Logger.Initialize(tracingLevel);
Logger.Initialize(tracingLevel, false);
bool isLogFileExpectedToExist = false;
TestLogger.VerifyInitialization(expectedTracingLevel, expectedTraceSource, Logger.LogFileFullPath, isLogFileExpectedToExist, testNo++);
TestLogger.Cleanup(Logger.LogFileFullPath);
@@ -131,7 +131,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ServiceHost
string tracingLevel = "invalid";
SourceLevels expectedTracingLevel = Logger.defaultTracingLevel;
string expectedTraceSource = Logger.defaultTraceSource;
Logger.Initialize(tracingLevel);
Logger.Initialize(tracingLevel, false);
bool isLogFileExpectedToExist = false;
TestLogger.VerifyInitialization(expectedTracingLevel, expectedTraceSource, Logger.LogFileFullPath, isLogFileExpectedToExist, testNo++);
TestLogger.Cleanup(Logger.LogFileFullPath);