mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-01-14 01:25:40 -05:00
Introduce AAD interactive auth mode (#1860)
This commit is contained in:
@@ -15,6 +15,8 @@
|
||||
<PackageReference Update="Microsoft.Extensions.DependencyModel" Version="3.1.4" />
|
||||
<PackageReference Update="Microsoft.Extensions.FileSystemGlobbing" Version="7.0.0" />
|
||||
|
||||
<PackageReference Update="Microsoft.Identity.Client" Version="4.49.1" />
|
||||
<PackageReference Update="Microsoft.Identity.Client.Extensions.Msal" Version="2.25.3" />
|
||||
<PackageReference Update="Microsoft.Azure.Management.ResourceManager" Version="3.7.1-preview" />
|
||||
<PackageReference Update="Microsoft.Azure.Management.Sql" Version="1.41.0-preview" />
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
Microsoft Visual Studio Solution File, Format Version 12.00
|
||||
# Visual Studio Version 16
|
||||
VisualStudioVersion = 16.0.29409.12
|
||||
# Visual Studio Version 17
|
||||
VisualStudioVersion = 17.5.33209.295
|
||||
MinimumVisualStudioVersion = 10.0.40219.1
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{2BBD7364-054F-4693-97CD-1C395E3E84A9}"
|
||||
ProjectSection(SolutionItems) = preProject
|
||||
@@ -96,6 +96,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.SqlTools.Migratio
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.SqlTools.Migration.IntegrationTests", "test\Microsoft.SqlTools.Migration.IntegrationTests\Microsoft.SqlTools.Migration.IntegrationTests.csproj", "{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}"
|
||||
EndProject
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.SqlTools.Authentication", "src\Microsoft.SqlTools.Authentication\Microsoft.SqlTools.Authentication.csproj", "{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}"
|
||||
EndProject
|
||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.SqlTools.Shared", "src\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj", "{531EC0E0-F400-42C5-BCFA-6691313B5F3E}"
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
Debug|Any CPU = Debug|Any CPU
|
||||
@@ -229,6 +233,18 @@ Global
|
||||
{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}.Integration|Any CPU.Build.0 = Debug|Any CPU
|
||||
{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Integration|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Integration|Any CPU.Build.0 = Debug|Any CPU
|
||||
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Integration|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Integration|Any CPU.Build.0 = Debug|Any CPU
|
||||
{531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
EndGlobalSection
|
||||
GlobalSection(SolutionProperties) = preSolution
|
||||
HideSolutionNode = FALSE
|
||||
@@ -256,6 +272,8 @@ Global
|
||||
{07296730-DAB7-4B0B-9D09-ABD9A5025D68} = {2BBD7364-054F-4693-97CD-1C395E3E84A9}
|
||||
{22DB0C12-6848-4503-AD1C-DAD6A1D631AE} = {2BBD7364-054F-4693-97CD-1C395E3E84A9}
|
||||
{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9} = {AB9CA2B8-6F70-431C-8A1D-67479D8A7BE4}
|
||||
{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC} = {2BBD7364-054F-4693-97CD-1C395E3E84A9}
|
||||
{531EC0E0-F400-42C5-BCFA-6691313B5F3E} = {2BBD7364-054F-4693-97CD-1C395E3E84A9}
|
||||
EndGlobalSection
|
||||
GlobalSection(ExtensibilityGlobals) = postSolution
|
||||
SolutionGuid = {B31CDF4B-2851-45E5-8C5F-BE97125D9DD8}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
//
|
||||
|
||||
using Microsoft.SqlTools.Hosting.Contracts;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
|
||||
namespace Microsoft.Kusto.ServiceLayer.Connection
|
||||
{
|
||||
@@ -50,7 +51,7 @@ namespace Microsoft.Kusto.ServiceLayer.Connection
|
||||
CategoryValues = new CategoryValue[]
|
||||
{ new CategoryValue { DisplayName = "SQL Login", Name = "SqlLogin" },
|
||||
new CategoryValue { DisplayName = "Windows Authentication", Name = "Integrated" },
|
||||
new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = "AzureMFA" }
|
||||
new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = AzureMFA }
|
||||
},
|
||||
IsIdentity = true,
|
||||
IsRequired = true,
|
||||
|
||||
@@ -19,6 +19,7 @@ using System.Diagnostics;
|
||||
using Microsoft.Kusto.ServiceLayer.DataSource;
|
||||
using Microsoft.Kusto.ServiceLayer.DataSource.Metadata;
|
||||
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
|
||||
namespace Microsoft.Kusto.ServiceLayer.Connection
|
||||
{
|
||||
@@ -916,7 +917,7 @@ namespace Microsoft.Kusto.ServiceLayer.Connection
|
||||
return new ConnectionDetails
|
||||
{
|
||||
ApplicationName = builder.ApplicationNameForTracing,
|
||||
AuthenticationType = "AzureMFA",
|
||||
AuthenticationType = AzureMFA,
|
||||
DatabaseName = builder.InitialCatalog,
|
||||
ServerName = builder.DataSource,
|
||||
UserName = builder.UserID,
|
||||
|
||||
@@ -20,6 +20,7 @@ using Microsoft.Kusto.ServiceLayer.LanguageServices;
|
||||
using Microsoft.Kusto.ServiceLayer.Utility;
|
||||
using Microsoft.Kusto.ServiceLayer.Workspace.Contracts;
|
||||
using CompletionItem = Microsoft.Kusto.ServiceLayer.LanguageServices.Contracts.CompletionItem;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
|
||||
namespace Microsoft.Kusto.ServiceLayer.DataSource
|
||||
{
|
||||
@@ -63,7 +64,7 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource
|
||||
|
||||
private DataSourceConnectionDetails MapKustoConnectionDetails(ConnectionDetails connectionDetails)
|
||||
{
|
||||
if (connectionDetails.AuthenticationType == "dstsAuth" || connectionDetails.AuthenticationType == "AzureMFA")
|
||||
if (connectionDetails.AuthenticationType == dstsAuth || connectionDetails.AuthenticationType == AzureMFA)
|
||||
{
|
||||
ValidationUtils.IsTrue<ArgumentException>(!string.IsNullOrWhiteSpace(connectionDetails.AccountToken),
|
||||
$"The Kusto User Token is not specified - set {nameof(connectionDetails.AccountToken)}");
|
||||
|
||||
@@ -21,6 +21,7 @@ using Kusto.Language.Editor;
|
||||
using Microsoft.Kusto.ServiceLayer.Connection;
|
||||
using Microsoft.Kusto.ServiceLayer.DataSource.Contracts;
|
||||
using Microsoft.Kusto.ServiceLayer.Utility;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
|
||||
namespace Microsoft.Kusto.ServiceLayer.DataSource.Kusto
|
||||
{
|
||||
@@ -74,7 +75,7 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource.Kusto
|
||||
ServerName = ClusterName,
|
||||
DatabaseName = DatabaseName,
|
||||
UserToken = accountToken,
|
||||
AuthenticationType = "AzureMFA"
|
||||
AuthenticationType = AzureMFA
|
||||
};
|
||||
|
||||
Initialize(connectionDetails);
|
||||
@@ -96,8 +97,8 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource.Kusto
|
||||
|
||||
switch (connectionDetails.AuthenticationType)
|
||||
{
|
||||
case "AzureMFA": return stringBuilder.WithAadUserTokenAuthentication(connectionDetails.UserToken);
|
||||
case "dstsAuth": return stringBuilder.WithDstsUserTokenAuthentication(connectionDetails.UserToken);
|
||||
case AzureMFA: return stringBuilder.WithAadUserTokenAuthentication(connectionDetails.UserToken);
|
||||
case dstsAuth: return stringBuilder.WithDstsUserTokenAuthentication(connectionDetails.UserToken);
|
||||
default:
|
||||
return string.IsNullOrWhiteSpace(connectionDetails.UserName) && string.IsNullOrWhiteSpace(connectionDetails.Password)
|
||||
? stringBuilder
|
||||
|
||||
@@ -51,6 +51,7 @@
|
||||
<ProjectReference Include="../Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj" />
|
||||
<ProjectReference Include="../Microsoft.SqlTools.Credentials/Microsoft.SqlTools.Credentials.csproj" />
|
||||
<ProjectReference Include="../Microsoft.SqlTools.ManagedBatchParser/Microsoft.SqlTools.ManagedBatchParser.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<EmbeddedResource Include="ObjectExplorer\DataSourceModel\TreeNodeDefinition.xml" />
|
||||
|
||||
@@ -41,7 +41,7 @@ namespace Microsoft.Kusto.ServiceLayer
|
||||
logFilePath = Logger.GenerateLogFilePath("kustoservice");
|
||||
}
|
||||
|
||||
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "kustoservice", commandOptions.AutoFlushLog);
|
||||
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, piiEnabled: commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "kustoservice", commandOptions.AutoFlushLog);
|
||||
|
||||
// set up the host details and profile paths
|
||||
var hostDetails = new HostDetails(version: new Version(1, 0));
|
||||
|
||||
33
src/Microsoft.SqlTools.Authentication/AccessToken.cs
Normal file
33
src/Microsoft.SqlTools.Authentication/AccessToken.cs
Normal file
@@ -0,0 +1,33 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
namespace Microsoft.SqlTools.Authentication
|
||||
{
|
||||
/// <summary>
|
||||
/// Represents an access token data object.
|
||||
/// </summary>
|
||||
public class AccessToken
|
||||
{
|
||||
/// <summary>
|
||||
/// OAuth 2.0 JWT encoded access token string
|
||||
/// </summary>
|
||||
public string Token { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Expiry date of token
|
||||
/// </summary>
|
||||
public DateTimeOffset ExpiresOn { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Default constructor for Access Token object
|
||||
/// </summary>
|
||||
/// <param name="token">Access token as string</param>
|
||||
/// <param name="expiresOn">Expiry date</param>
|
||||
public AccessToken(string token, DateTimeOffset expiresOn) {
|
||||
this.Token = token;
|
||||
this.ExpiresOn = expiresOn;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
namespace Microsoft.SqlTools.Authentication
|
||||
{
|
||||
/// <summary>
|
||||
/// Supported Active Directory authentication modes
|
||||
/// </summary>
|
||||
public enum AuthenticationMethod
|
||||
{
|
||||
ActiveDirectoryInteractive
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
namespace Microsoft.SqlTools.Authentication
|
||||
{
|
||||
/// <summary>
|
||||
/// Parameters to be passed to <see cref="Authenticator"/> to request an access token
|
||||
/// </summary>
|
||||
public class AuthenticationParams
|
||||
{
|
||||
/// <summary>
|
||||
/// Authentication method to be used by <see cref="Authenticator"/>.
|
||||
/// </summary>
|
||||
public AuthenticationMethod AuthenticationMethod { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Authority URL, e.g. https://login.microsoftonline.com/
|
||||
/// </summary>
|
||||
public string Authority { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Audience for which access token should be acquired, e.g. common, organizations, consumers.
|
||||
/// It can also be a tenant Id when authenticating multi-tenant application accounts.
|
||||
/// </summary>
|
||||
public string Audience { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Resource URL, e.g. https://database.windows.net/
|
||||
/// </summary>
|
||||
public string Resource { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Array of scopes for which access token is requested.
|
||||
/// </summary>
|
||||
public string[] Scopes { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// <see cref="Guid"/> Connection Id, that will be passed to Azure AD when requesting access token.
|
||||
/// It can be used for tracking AAD request status if needed.
|
||||
/// </summary>
|
||||
public Guid ConnectionId { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// User name to be provided as userhint when acquiring access token.
|
||||
/// </summary>
|
||||
public string UserName { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Default constructor
|
||||
/// </summary>
|
||||
/// <param name="authMethod">Authentication Method to be used.</param>
|
||||
/// <param name="authority">Authority URL</param>
|
||||
/// <param name="audience">Audience</param>
|
||||
/// <param name="resource">Resource for which token is requested.</param>
|
||||
/// <param name="scopes">Scopes for access token</param>
|
||||
/// <param name="userName">User hint information</param>
|
||||
/// <param name="connectionId">Connection Id for tracing AAD request</param>
|
||||
public AuthenticationParams(AuthenticationMethod authMethod, string authority, string audience,
|
||||
string resource, string[] scopes, string userName, Guid connectionId) {
|
||||
this.AuthenticationMethod = authMethod;
|
||||
this.Authority = authority;
|
||||
this.Audience = audience;
|
||||
this.Resource = resource;
|
||||
this.Scopes = scopes;
|
||||
this.UserName = userName;
|
||||
this.ConnectionId = connectionId;
|
||||
}
|
||||
|
||||
public string ToLogString(bool piiEnabled)
|
||||
{
|
||||
return $"\tAuthenticationMethod: {AuthenticationMethod}" +
|
||||
$"\n\tAuthority: {Authority}" +
|
||||
$"\n\tAudience: {Audience}" +
|
||||
$"\n\tResource: {Resource}" +
|
||||
$"\n\tScopes: {Scopes.ToString()}" +
|
||||
(piiEnabled ? $"\n\tUsername: {UserName}" : "") +
|
||||
$"\n\tConnection Id: {ConnectionId}";
|
||||
}
|
||||
}
|
||||
}
|
||||
158
src/Microsoft.SqlTools.Authentication/Authenticator.cs
Normal file
158
src/Microsoft.SqlTools.Authentication/Authenticator.cs
Normal file
@@ -0,0 +1,158 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
using System.Collections.Concurrent;
|
||||
using Microsoft.Identity.Client;
|
||||
using Microsoft.Identity.Client.Extensions.Msal;
|
||||
using Microsoft.SqlTools.Authentication.Utility;
|
||||
using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger;
|
||||
|
||||
namespace Microsoft.SqlTools.Authentication
|
||||
{
|
||||
/// <summary>
|
||||
/// Provides APIs to acquire access token using MSAL.NET v4 with provided <see cref="AuthenticationParams"/>.
|
||||
/// </summary>
|
||||
public class Authenticator
|
||||
{
|
||||
private string applicationClientId;
|
||||
private string applicationName;
|
||||
private string cacheFolderPath;
|
||||
private string cacheFileName;
|
||||
private MsalCacheHelper cacheHelper;
|
||||
private static ConcurrentDictionary<string, IPublicClientApplication> PublicClientAppMap
|
||||
= new ConcurrentDictionary<string, IPublicClientApplication>();
|
||||
|
||||
#region Public APIs
|
||||
public Authenticator(string appClientId, string appName, string cacheFolderPath, string cacheFileName)
|
||||
{
|
||||
this.applicationClientId = appClientId;
|
||||
this.applicationName = appName;
|
||||
this.cacheFolderPath = cacheFolderPath;
|
||||
this.cacheFileName = cacheFileName;
|
||||
|
||||
// Storage creation properties are used to enable file system caching with MSAL
|
||||
var storageCreationProperties = new StorageCreationPropertiesBuilder(this.cacheFileName, this.cacheFolderPath)
|
||||
.WithUnprotectedFile().Build();
|
||||
|
||||
// This hooks up the cross-platform cache into MSAL
|
||||
this.cacheHelper = MsalCacheHelper.CreateAsync(storageCreationProperties).ConfigureAwait(false).GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Acquires access token synchronously.
|
||||
/// </summary>
|
||||
/// <param name="params">Authentication parameters to be used for access token request.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
/// <returns>Access Token with expiry date</returns>
|
||||
public AccessToken? GetToken(AuthenticationParams @params, CancellationToken cancellationToken) =>
|
||||
GetTokenAsync(@params, cancellationToken).GetAwaiter().GetResult();
|
||||
|
||||
/// <summary>
|
||||
/// Acquires access token asynchronously.
|
||||
/// </summary>
|
||||
/// <param name="params">Authentication parameters to be used for access token request.</param>
|
||||
/// <param name="cancellationToken">Cancellation token.</param>
|
||||
/// <returns>Access Token with expiry date</returns>
|
||||
/// <exception cref="Exception"></exception>
|
||||
public async Task<AccessToken?> GetTokenAsync(AuthenticationParams @params, CancellationToken cancellationToken)
|
||||
{
|
||||
SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Received @params: {@params.ToLogString(SqlToolsLogger.IsPiiEnabled)}");
|
||||
|
||||
IPublicClientApplication publicClientApplication = GetPublicClientAppInstance(@params.Authority, @params.Audience);
|
||||
|
||||
AccessToken? accessToken;
|
||||
if (@params.AuthenticationMethod == AuthenticationMethod.ActiveDirectoryInteractive)
|
||||
{
|
||||
// Find account
|
||||
IEnumerator<IAccount>? accounts = (await publicClientApplication.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();
|
||||
IAccount? account = default;
|
||||
|
||||
if (!string.IsNullOrEmpty(@params.UserName) && accounts.MoveNext())
|
||||
{
|
||||
// Handle username format to extract email: "John Doe - johndoe@constoso.com"
|
||||
string username = @params.UserName.Contains(" - ") ? @params.UserName.Split(" - ")[1] : @params.UserName;
|
||||
if (!Utils.isValidEmail(username))
|
||||
{
|
||||
SqlToolsLogger.Pii($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Unexpected username format, email not retreived: {@params.UserName}. " +
|
||||
$"Accepted formats are: 'johndoe@org.com' or 'John Doe - johndoe@org.com'.");
|
||||
throw new Exception($"Invalid email address format for user: [{username}] received for Azure Active Directory authentication.");
|
||||
}
|
||||
|
||||
do
|
||||
{
|
||||
IAccount? currentVal = accounts.Current;
|
||||
if (string.Compare(username, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
|
||||
{
|
||||
account = currentVal;
|
||||
SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | User account found in MSAL Cache: {account.HomeAccountId}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
while (accounts.MoveNext());
|
||||
|
||||
if (null != account)
|
||||
{
|
||||
try
|
||||
{
|
||||
// Fetch token silently
|
||||
var result = await publicClientApplication.AcquireTokenSilent(@params.Scopes, account)
|
||||
.ExecuteAsync(cancellationToken: cancellationToken)
|
||||
.ConfigureAwait(false);
|
||||
accessToken = new AccessToken(result!.AccessToken, result!.ExpiresOn);
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Silent authentication failed for resource {@params.Resource} for ConnectionId {@params.ConnectionId}.");
|
||||
SqlToolsLogger.Error(e);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Account not found in MSAL cache for user.");
|
||||
throw new Exception($"User account '{username}' not found in MSAL cache, please add linked account or refresh account credentials.");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | User account not received.");
|
||||
throw new Exception($"User account not received.");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Authentication Method ${@params.AuthenticationMethod} is not supported.");
|
||||
throw new Exception($"Authentication Method ${@params.AuthenticationMethod} is not supported.");
|
||||
}
|
||||
|
||||
return accessToken;
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region Private methods
|
||||
private IPublicClientApplication GetPublicClientAppInstance(string authority, string audience)
|
||||
{
|
||||
string authorityUrl = authority + '/' + audience;
|
||||
if (!PublicClientAppMap.TryGetValue(authorityUrl, out IPublicClientApplication? clientApplicationInstance))
|
||||
{
|
||||
clientApplicationInstance = CreatePublicClientAppInstance(authority, audience);
|
||||
this.cacheHelper.RegisterCache(clientApplicationInstance.UserTokenCache);
|
||||
PublicClientAppMap.TryAdd(authorityUrl, clientApplicationInstance);
|
||||
}
|
||||
return clientApplicationInstance;
|
||||
}
|
||||
|
||||
private IPublicClientApplication CreatePublicClientAppInstance(string authority, string audience) =>
|
||||
PublicClientApplicationBuilder.Create(this.applicationClientId)
|
||||
.WithAuthority(authority, audience)
|
||||
.WithClientName(this.applicationName)
|
||||
.WithLogging(Utils.MSALLogCallback)
|
||||
.WithDefaultRedirectUri()
|
||||
.Build();
|
||||
|
||||
#endregion
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<Nullable>enable</Nullable>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.Data.SqlClient" />
|
||||
<PackageReference Include="Microsoft.Identity.Client" />
|
||||
<PackageReference Include="Microsoft.Identity.Client.Extensions.Msal" />
|
||||
<PackageReference Include="System.Configuration.ConfigurationManager" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
@@ -0,0 +1,110 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
using Microsoft.Data.SqlClient;
|
||||
using Microsoft.SqlTools.Authentication.Utility;
|
||||
|
||||
namespace Microsoft.SqlTools.Authentication.Sql
|
||||
{
|
||||
/// <summary>
|
||||
/// Provides an implementation of <see cref="SqlAuthenticationProvider"/> for SQL Tools to be able to perform Federated authentication
|
||||
/// silently with Microsoft.Data.SqlClient integration only for "ActiveDirectory" authentication modes.
|
||||
/// When registered, the SqlClient driver calls the <see cref="AcquireTokenAsync(SqlAuthenticationParameters)"/> API
|
||||
/// with server-sent authority information to request access token when needed.
|
||||
/// </summary>
|
||||
public class AuthenticationProvider : SqlAuthenticationProvider
|
||||
{
|
||||
private const string ApplicationClientId = "a69788c6-1d43-44ed-9ca3-b83e194da255";
|
||||
private const string AzureTokenFolder = "Azure Accounts";
|
||||
private const string MsalCacheName = "azureTokenCacheMsal-azure_publicCloud";
|
||||
private const string s_defaultScopeSuffix = "/.default";
|
||||
|
||||
private Authenticator authenticator;
|
||||
|
||||
/// <summary>
|
||||
/// Instantiates AuthenticationProvider to be used for AAD authentication with MSAL.NET and MSAL.js co-ordinated.
|
||||
/// </summary>
|
||||
/// <param name="applicationName">Application Name that identifies user folder path location for reading/writing to shared cache.</param>
|
||||
/// <param name="authCallback">Callback that handles AAD authentication when user interaction is needed.</param>
|
||||
public AuthenticationProvider(string applicationName)
|
||||
{
|
||||
if(string.IsNullOrEmpty(applicationName)) {
|
||||
applicationName = nameof(SqlTools);
|
||||
}
|
||||
var cachePath = Path.Combine(Utils.BuildAppDirectoryPath(), applicationName, AzureTokenFolder);
|
||||
this.authenticator = new Authenticator(ApplicationClientId, applicationName, cachePath, MsalCacheName);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Acquires access token with provided <paramref name="parameters"/>
|
||||
/// </summary>
|
||||
/// <param name="parameters">Authentication parameters</param>
|
||||
/// <returns>Authentication token containing access token and expiry date.</returns>
|
||||
public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters)
|
||||
{
|
||||
// Setup scope
|
||||
string resource = parameters.Resource;
|
||||
string scope;
|
||||
|
||||
if (parameters.Resource.EndsWith(s_defaultScopeSuffix))
|
||||
{
|
||||
scope = parameters.Resource;
|
||||
resource = parameters.Resource.Substring(0, parameters.Resource.LastIndexOf('/') + 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
scope = parameters.Resource + s_defaultScopeSuffix;
|
||||
}
|
||||
|
||||
string[] scopes = new string[] { scope };
|
||||
|
||||
CancellationTokenSource cts = new CancellationTokenSource();
|
||||
|
||||
// Use Connection timeout value to cancel token acquire request after certain period of time.
|
||||
cts.CancelAfter(parameters.ConnectionTimeout * 1000); // Convert to milliseconds
|
||||
|
||||
/* We split audience from Authority URL here. Audience can be one of the following:
|
||||
* The Azure AD authority audience enumeration
|
||||
* The tenant ID, which can be:
|
||||
* - A GUID (the ID of your Azure AD instance), for single-tenant applications
|
||||
* - A domain name associated with your Azure AD instance (also for single-tenant applications)
|
||||
* One of these placeholders as a tenant ID in place of the Azure AD authority audience enumeration:
|
||||
* - `organizations` for a multitenant application
|
||||
* - `consumers` to sign in users only with their personal accounts
|
||||
* - `common` to sign in users with their work and school accounts or their personal Microsoft accounts
|
||||
*
|
||||
* MSAL will throw a meaningful exception if you specify both the Azure AD authority audience and the tenant ID.
|
||||
* If you don't specify an audience, your app will target Azure AD and personal Microsoft accounts as an audience. (That is, it will behave as though `common` were specified.)
|
||||
* More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration
|
||||
**/
|
||||
|
||||
int seperatorIndex = parameters.Authority.LastIndexOf('/');
|
||||
string authority = parameters.Authority.Remove(seperatorIndex + 1);
|
||||
string audience = parameters.Authority.Substring(seperatorIndex + 1);
|
||||
string? userName = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId;
|
||||
|
||||
AuthenticationParams @params = new AuthenticationParams(
|
||||
AuthenticationMethod.ActiveDirectoryInteractive,
|
||||
authority,
|
||||
audience,
|
||||
resource,
|
||||
scopes,
|
||||
userName!,
|
||||
parameters.ConnectionId);
|
||||
|
||||
AccessToken? result = await authenticator.GetTokenAsync(@params, cts.Token).ConfigureAwait(false);
|
||||
return new SqlAuthenticationToken(result!.Token, result!.ExpiresOn);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Whether or not provided <paramref name="authenticationMethod"/> is supported.
|
||||
/// </summary>
|
||||
/// <param name="authenticationMethod">SQL Authentication method</param>
|
||||
/// <returns></returns>
|
||||
public override bool IsSupported(SqlAuthenticationMethod authenticationMethod)
|
||||
=> authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive;
|
||||
|
||||
}
|
||||
}
|
||||
119
src/Microsoft.SqlTools.Authentication/Utility/Utils.cs
Normal file
119
src/Microsoft.SqlTools.Authentication/Utility/Utils.cs
Normal file
@@ -0,0 +1,119 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
using System.Net.Mail;
|
||||
using System.Runtime.InteropServices;
|
||||
using Microsoft.Identity.Client;
|
||||
using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger;
|
||||
|
||||
namespace Microsoft.SqlTools.Authentication.Utility
|
||||
{
|
||||
internal sealed class Utils
|
||||
{
|
||||
/// <summary>
|
||||
/// Validates provided <paramref name="userEmail"/> follows email format.
|
||||
/// </summary>
|
||||
/// <param name="useremail">Email address</param>
|
||||
/// <returns>Whether email is in correct format.</returns>
|
||||
public static bool isValidEmail(string userEmail)
|
||||
{
|
||||
try
|
||||
{
|
||||
new MailAddress(userEmail);
|
||||
return true;
|
||||
}
|
||||
catch (FormatException)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Builds directory path based on environment settings.
|
||||
/// </summary>
|
||||
/// <returns>Application directory path</returns>
|
||||
/// <exception cref="Exception">When called on unsupported platform.</exception>
|
||||
public static string BuildAppDirectoryPath()
|
||||
{
|
||||
var homedir = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile);
|
||||
|
||||
// Windows
|
||||
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
|
||||
{
|
||||
var appData = Environment.GetEnvironmentVariable("APPDATA");
|
||||
var userProfile = Environment.GetEnvironmentVariable("USERPROFILE");
|
||||
if (appData != null)
|
||||
{
|
||||
return appData;
|
||||
}
|
||||
else if (userProfile != null)
|
||||
{
|
||||
return string.Join(Environment.GetEnvironmentVariable("USERPROFILE"), "AppData", "Roaming");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new Exception("Not able to find APPDATA or USERPROFILE");
|
||||
}
|
||||
}
|
||||
|
||||
// Mac
|
||||
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
|
||||
{
|
||||
return string.Join(homedir, "Library", "Application Support");
|
||||
}
|
||||
|
||||
// Linux
|
||||
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
|
||||
{
|
||||
var xdgConfigHome = Environment.GetEnvironmentVariable("XDG_CONFIG_HOME");
|
||||
if (xdgConfigHome != null)
|
||||
{
|
||||
return xdgConfigHome;
|
||||
}
|
||||
else
|
||||
{
|
||||
return string.Join(homedir, ".config");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new Exception("Platform not supported");
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Log callback handler used for MSAL Client applications.
|
||||
/// </summary>
|
||||
/// <param name="logLevel">Log level</param>
|
||||
/// <param name="message">Log message</param>
|
||||
/// <param name="pii">Whether message contains PII information.</param>
|
||||
public static void MSALLogCallback(LogLevel logLevel, string message, bool pii)
|
||||
{
|
||||
switch (logLevel)
|
||||
{
|
||||
case LogLevel.Error:
|
||||
if (pii) SqlToolsLogger.Pii(message);
|
||||
else SqlToolsLogger.Error(message);
|
||||
break;
|
||||
case LogLevel.Warning:
|
||||
if (pii) SqlToolsLogger.Pii(message);
|
||||
else SqlToolsLogger.Warning(message);
|
||||
break;
|
||||
case LogLevel.Info:
|
||||
if (pii) SqlToolsLogger.Pii(message);
|
||||
else SqlToolsLogger.Information(message);
|
||||
break;
|
||||
case LogLevel.Verbose:
|
||||
if (pii) SqlToolsLogger.Pii(message);
|
||||
else SqlToolsLogger.Verbose(message);
|
||||
break;
|
||||
case LogLevel.Always:
|
||||
if (pii) SqlToolsLogger.Pii(message);
|
||||
else SqlToolsLogger.Critical(message);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -38,7 +38,7 @@ namespace Microsoft.SqlTools.Credentials
|
||||
logFilePath = Logger.GenerateLogFilePath("credentials");
|
||||
}
|
||||
|
||||
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "credentials", commandOptions.AutoFlushLog);
|
||||
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, piiEnabled: commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "credentials", commandOptions.AutoFlushLog);
|
||||
|
||||
// set up the host details and profile paths
|
||||
var hostDetails = new HostDetails(
|
||||
|
||||
@@ -25,4 +25,7 @@
|
||||
<EmbeddedResource Include="Localization\*.resx" />
|
||||
<None Include="Localization\sr.strings" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
@@ -28,6 +28,7 @@ namespace Microsoft.SqlTools.Utility
|
||||
ServiceName = serviceName;
|
||||
ErrorMessage = string.Empty;
|
||||
Locale = string.Empty;
|
||||
ApplicationName = string.Empty;
|
||||
|
||||
try
|
||||
{
|
||||
@@ -42,12 +43,18 @@ namespace Microsoft.SqlTools.Utility
|
||||
|
||||
switch (argName)
|
||||
{
|
||||
case "-application-name":
|
||||
ApplicationName = args[++i];
|
||||
break;
|
||||
case "-autoflush-log":
|
||||
AutoFlushLog = true;
|
||||
break;
|
||||
case "-tracing-level":
|
||||
TracingLevel = args[++i];
|
||||
break;
|
||||
case "-pii-logging":
|
||||
PiiLogging = true;
|
||||
break;
|
||||
case "-log-file":
|
||||
LogFilePath = args[++i];
|
||||
break;
|
||||
@@ -66,6 +73,9 @@ namespace Microsoft.SqlTools.Utility
|
||||
case "-parallel-message-processing":
|
||||
ParallelMessageProcessing = true;
|
||||
break;
|
||||
case "-enable-sql-authentication-provider":
|
||||
EnableSqlAuthenticationProvider = true;
|
||||
break;
|
||||
case "-parent-pid":
|
||||
string nextArg = args[++i];
|
||||
if (Int32.TryParse(nextArg, out int parsedInt))
|
||||
@@ -95,6 +105,11 @@ namespace Microsoft.SqlTools.Utility
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Name of application that is sending command options
|
||||
/// </summary>
|
||||
public string ApplicationName { get; private set; }
|
||||
|
||||
/// <summary>
|
||||
/// Contains any error messages during execution
|
||||
/// </summary>
|
||||
@@ -137,6 +152,8 @@ namespace Microsoft.SqlTools.Utility
|
||||
|
||||
public string TracingLevel { get; private set; }
|
||||
|
||||
public bool PiiLogging { get; private set; }
|
||||
|
||||
public string LogFilePath { get; private set; }
|
||||
|
||||
public bool AutoFlushLog { get; private set; } = false;
|
||||
@@ -147,6 +164,13 @@ namespace Microsoft.SqlTools.Utility
|
||||
/// </summary>
|
||||
public bool ParallelMessageProcessing { get; private set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// Enables configured 'Sql Authentication Provider' for 'Active Directory Interactive' authentication mode to be used
|
||||
/// when user chooses 'Azure MFA'. This setting enables MSAL.NET to acquire token with SqlClient integration.
|
||||
/// Currently this option is disabled by default, it's planned to be enabled by default in future releases.
|
||||
/// </summary>
|
||||
public bool EnableSqlAuthenticationProvider { get; private set; } = false;
|
||||
|
||||
/// <summary>
|
||||
/// The ID of the process that started this service. This is used to check when the parent
|
||||
/// process exits so that the service process can exit at the same time.
|
||||
|
||||
@@ -70,7 +70,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
||||
_connectionRetryPolicy.RetryOccurred += RetryConnectionCallback;
|
||||
_commandRetryPolicy.RetryOccurred += RetryCommandCallback;
|
||||
|
||||
if (azureAccountToken != null)
|
||||
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(_underlyingConnection.ConnectionString);
|
||||
if (builder.Authentication is SqlAuthenticationMethod.NotSpecified && azureAccountToken is not null)
|
||||
{
|
||||
_underlyingConnection.AccessToken = azureAccountToken;
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ namespace Microsoft.SqlTools.Migration
|
||||
logFilePath = Logger.GenerateLogFilePath(logFilePath);
|
||||
}
|
||||
|
||||
Logger.Initialize(SourceLevels.Verbose, logFilePath, "Migration", commandOptions.AutoFlushLog);
|
||||
Logger.Initialize(SourceLevels.Verbose, piiEnabled: commandOptions.PiiLogging, logFilePath, "Migration", commandOptions.AutoFlushLog);
|
||||
|
||||
Logger.Verbose("Starting SqlTools Migration Server...");
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ namespace Microsoft.SqlTools.ResourceProvider
|
||||
}
|
||||
|
||||
// we need to switch to Information when preparing for public preview
|
||||
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "resourceprovider", commandOptions.AutoFlushLog);
|
||||
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "resourceprovider", commandOptions.AutoFlushLog);
|
||||
Logger.Write(TraceEventType.Information, "Starting SqlTools Resource Provider");
|
||||
|
||||
// set up the host details and profile paths
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#nullable disable
|
||||
|
||||
using Microsoft.SqlTools.Hosting.Contracts;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
{
|
||||
@@ -52,7 +53,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
CategoryValues = new CategoryValue[]
|
||||
{ new CategoryValue { DisplayName = "SQL Login", Name = "SqlLogin" },
|
||||
new CategoryValue { DisplayName = "Windows Authentication", Name = "Integrated" },
|
||||
new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = "AzureMFA" }
|
||||
new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = AzureMFA }
|
||||
},
|
||||
IsIdentity = true,
|
||||
IsRequired = true,
|
||||
|
||||
@@ -23,7 +23,9 @@ using Microsoft.SqlTools.ServiceLayer.LanguageServices;
|
||||
using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
|
||||
using Microsoft.SqlTools.ServiceLayer.Utility;
|
||||
using Microsoft.SqlTools.Utility;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
using System.Diagnostics;
|
||||
using Microsoft.SqlTools.Authentication.Sql;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
{
|
||||
@@ -143,17 +145,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
{
|
||||
RequestSecurityTokenParams message = new RequestSecurityTokenParams()
|
||||
{
|
||||
Authority = authority,
|
||||
Provider = "Azure",
|
||||
Authority = authority,
|
||||
Resource = resource,
|
||||
Scope = scope
|
||||
Scopes = new string[] { scope }
|
||||
};
|
||||
|
||||
RequestSecurityTokenResponse response = await Instance.ServiceHost.SendRequest(SecurityTokenRequest.Type, message, true);
|
||||
RequestSecurityTokenResponse response = await Instance.ServiceHost.SendRequest(SecurityTokenRequest.Type, message, true).ConfigureAwait(false);
|
||||
|
||||
return response.Token;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Enables configured 'Sql Authentication Provider' for 'Active Directory Interactive' authentication mode to be used
|
||||
/// when user chooses 'Azure MFA'.
|
||||
/// </summary>
|
||||
public bool EnableSqlAuthenticationProvider { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Returns a connection queue for given type
|
||||
/// </summary>
|
||||
@@ -248,7 +256,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
if (this.TryFindConnection(ownerUri, out connInfo))
|
||||
{
|
||||
// If not an azure connection, no need to refresh token
|
||||
if (connInfo.ConnectionDetails.AuthenticationType != "AzureMFA")
|
||||
if (connInfo.ConnectionDetails.AuthenticationType != AzureMFA)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -594,7 +602,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
connectionInfo.IsSqlDb = serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDatabase;
|
||||
connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse);
|
||||
// Determines that access token is used for creating connection.
|
||||
connectionInfo.IsAzureAuth = connectionInfo.ConnectionDetails.AuthenticationType == "AzureMFA";
|
||||
connectionInfo.IsAzureAuth = connectionInfo.ConnectionDetails.AuthenticationType == AzureMFA;
|
||||
connectionInfo.EngineEdition = (DatabaseEngineEdition)serverInfo.EngineEditionId;
|
||||
// Azure Data Studio supports SQL Server 2014 and later releases.
|
||||
response.IsSupportedVersion = serverInfo.IsCloud || serverInfo.ServerMajorVersion >= 12;
|
||||
@@ -1061,10 +1069,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
return handler.HandleRequest(this.connectionFactory, info);
|
||||
}
|
||||
|
||||
public void InitializeService(IProtocolEndpoint serviceHost)
|
||||
public void InitializeService(IProtocolEndpoint serviceHost, ServiceLayerCommandOptions commandOptions)
|
||||
{
|
||||
this.ServiceHost = serviceHost;
|
||||
|
||||
if (commandOptions != null && commandOptions.EnableSqlAuthenticationProvider)
|
||||
{
|
||||
// Register SqlAuthenticationProvider with SqlConnection for AAD Interactive (MFA) authentication.
|
||||
var provider = new AuthenticationProvider(commandOptions.ApplicationName);
|
||||
SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryInteractive, provider);
|
||||
|
||||
this.EnableSqlAuthenticationProvider = true;
|
||||
Logger.Information("Registering implementation of SQL Authentication provider for 'Active Directory Interactive' authentication mode.");
|
||||
}
|
||||
|
||||
// Register request and event handlers with the Service Host
|
||||
serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest, true);
|
||||
serviceHost.SetRequestHandler(CancelConnectRequest.Type, HandleCancelConnectRequest, true);
|
||||
@@ -1304,31 +1322,46 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
|
||||
connectionBuilder = new SqlConnectionStringBuilder
|
||||
{
|
||||
["Data Source"] = dataSource,
|
||||
["User Id"] = connectionDetails.UserName,
|
||||
["Password"] = connectionDetails.Password
|
||||
DataSource = dataSource
|
||||
};
|
||||
}
|
||||
|
||||
// Check for any optional parameters
|
||||
if (!string.IsNullOrEmpty(connectionDetails.DatabaseName))
|
||||
{
|
||||
connectionBuilder["Initial Catalog"] = connectionDetails.DatabaseName;
|
||||
connectionBuilder.InitialCatalog = connectionDetails.DatabaseName;
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.AuthenticationType))
|
||||
{
|
||||
switch (connectionDetails.AuthenticationType)
|
||||
{
|
||||
case "Integrated":
|
||||
case Integrated:
|
||||
connectionBuilder.IntegratedSecurity = true;
|
||||
break;
|
||||
case "SqlLogin":
|
||||
case SqlLogin:
|
||||
connectionBuilder.UserID = connectionDetails.UserName;
|
||||
connectionBuilder.Password = connectionDetails.Password;
|
||||
connectionBuilder.Authentication = SqlAuthenticationMethod.SqlPassword;
|
||||
break;
|
||||
case "AzureMFA":
|
||||
connectionBuilder.UserID = "";
|
||||
connectionBuilder.Password = "";
|
||||
case AzureMFA:
|
||||
if (Instance.EnableSqlAuthenticationProvider)
|
||||
{
|
||||
connectionBuilder.UserID = connectionDetails.UserName;
|
||||
connectionDetails.AuthenticationType = ActiveDirectoryInteractive;
|
||||
connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive;
|
||||
}
|
||||
else
|
||||
{
|
||||
connectionBuilder.UserID = "";
|
||||
}
|
||||
break;
|
||||
case "ActiveDirectoryPassword":
|
||||
case ActiveDirectoryInteractive:
|
||||
connectionBuilder.UserID = connectionDetails.UserName;
|
||||
connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive;
|
||||
break;
|
||||
case ActiveDirectoryPassword:
|
||||
connectionBuilder.UserID = connectionDetails.UserName;
|
||||
connectionBuilder.Password = connectionDetails.Password;
|
||||
connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryPassword;
|
||||
break;
|
||||
default:
|
||||
@@ -1337,16 +1370,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting))
|
||||
{
|
||||
switch (connectionDetails.ColumnEncryptionSetting.ToUpper())
|
||||
if (Enum.TryParse<SqlConnectionColumnEncryptionSetting>(connectionDetails.ColumnEncryptionSetting, true, out var value))
|
||||
{
|
||||
case "ENABLED":
|
||||
connectionBuilder.ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Enabled;
|
||||
break;
|
||||
case "DISABLED":
|
||||
connectionBuilder.ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Disabled;
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidColumnEncryptionSetting(connectionDetails.ColumnEncryptionSetting));
|
||||
connectionBuilder.ColumnEncryptionSetting = value;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidColumnEncryptionSetting(connectionDetails.ColumnEncryptionSetting));
|
||||
}
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.SecureEnclaves))
|
||||
@@ -1365,36 +1395,30 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.EnclaveAttestationProtocol))
|
||||
{
|
||||
if (string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting) || connectionDetails.ColumnEncryptionSetting.ToUpper() == "DISABLED"
|
||||
if (connectionBuilder.ColumnEncryptionSetting != SqlConnectionColumnEncryptionSetting.Enabled
|
||||
|| string.IsNullOrEmpty(connectionDetails.SecureEnclaves) || connectionDetails.SecureEnclaves.ToUpper() == "DISABLED")
|
||||
{
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAlwaysEncryptedOptionCombination);
|
||||
}
|
||||
|
||||
switch (connectionDetails.EnclaveAttestationProtocol.ToUpper())
|
||||
if (Enum.TryParse<SqlConnectionAttestationProtocol>(connectionDetails.EnclaveAttestationProtocol, true, out var value))
|
||||
{
|
||||
case "AAS":
|
||||
connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.AAS;
|
||||
break;
|
||||
case "HGS":
|
||||
connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.HGS;
|
||||
break;
|
||||
case "NONE":
|
||||
connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.None;
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(connectionDetails.EnclaveAttestationProtocol));
|
||||
connectionBuilder.AttestationProtocol = value;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(connectionDetails.EnclaveAttestationProtocol));
|
||||
}
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.EnclaveAttestationUrl))
|
||||
{
|
||||
if (string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting) || connectionDetails.ColumnEncryptionSetting.ToUpper() == "DISABLED"
|
||||
if (connectionBuilder.ColumnEncryptionSetting != SqlConnectionColumnEncryptionSetting.Enabled
|
||||
|| string.IsNullOrEmpty(connectionDetails.SecureEnclaves) || connectionDetails.SecureEnclaves.ToUpper() == "DISABLED")
|
||||
{
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAlwaysEncryptedOptionCombination);
|
||||
}
|
||||
|
||||
if(connectionBuilder.AttestationProtocol == SqlConnectionAttestationProtocol.None)
|
||||
if (connectionBuilder.AttestationProtocol == SqlConnectionAttestationProtocol.None)
|
||||
{
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAttestationProtocolNoneWithUrl);
|
||||
}
|
||||
@@ -1456,19 +1480,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.ApplicationIntent))
|
||||
{
|
||||
ApplicationIntent intent;
|
||||
switch (connectionDetails.ApplicationIntent)
|
||||
if (Enum.TryParse<ApplicationIntent>(connectionDetails.ApplicationIntent, true, out ApplicationIntent value))
|
||||
{
|
||||
case "ReadOnly":
|
||||
intent = ApplicationIntent.ReadOnly;
|
||||
break;
|
||||
case "ReadWrite":
|
||||
intent = ApplicationIntent.ReadWrite;
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent));
|
||||
connectionBuilder.ApplicationIntent = value;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent));
|
||||
}
|
||||
connectionBuilder.ApplicationIntent = intent;
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.CurrentLanguage))
|
||||
{
|
||||
@@ -1585,7 +1604,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
ApplicationIntent = builder.ApplicationIntent.ToString(),
|
||||
ApplicationName = builder.ApplicationName,
|
||||
AttachDbFilename = builder.AttachDBFilename,
|
||||
AuthenticationType = builder.IntegratedSecurity ? "Integrated" : "SqlLogin",
|
||||
AuthenticationType = builder.IntegratedSecurity ? "Integrated" :
|
||||
(builder.Authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive
|
||||
? "ActiveDirectoryInteractive" : "SqlLogin"),
|
||||
ConnectRetryCount = builder.ConnectRetryCount,
|
||||
ConnectRetryInterval = builder.ConnectRetryInterval,
|
||||
ConnectTimeout = builder.ConnectTimeout,
|
||||
@@ -1793,7 +1814,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
SqlConnection sqlConn = new SqlConnection(connectionString);
|
||||
|
||||
// Fill in Azure authentication token if needed
|
||||
if (connInfo.ConnectionDetails.AzureAccountToken != null)
|
||||
if (connInfo.ConnectionDetails.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA)
|
||||
{
|
||||
sqlConn.AccessToken = connInfo.ConnectionDetails.AzureAccountToken;
|
||||
}
|
||||
@@ -1824,7 +1845,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
{
|
||||
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo, featureName);
|
||||
ServerConnection serverConnection;
|
||||
if (connInfo.ConnectionDetails.AzureAccountToken != null)
|
||||
if (connInfo.ConnectionDetails.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA)
|
||||
{
|
||||
serverConnection = new ServerConnection(sqlConnection, new AzureAccessToken(connInfo.ConnectionDetails.AzureAccountToken));
|
||||
}
|
||||
|
||||
@@ -11,25 +11,25 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
|
||||
{
|
||||
class RequestSecurityTokenParams
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets the address of the authority to issue token.
|
||||
/// </summary>
|
||||
public string Authority { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the provider that indicates the type of linked account to query.
|
||||
/// </summary>
|
||||
public string Provider { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the authority URL from where token is requested.
|
||||
/// </summary>
|
||||
public string Authority { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the identifier of the target resource that is the recipient of the requested token.
|
||||
/// </summary>
|
||||
public string Resource { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the scope of the authentication request.
|
||||
/// Gets or sets the scope array of the authentication request.
|
||||
/// </summary>
|
||||
public string Scope { get; set; }
|
||||
public string [] Scopes { get; set; }
|
||||
}
|
||||
|
||||
class RequestSecurityTokenResponse
|
||||
|
||||
@@ -8,6 +8,7 @@ using Microsoft.SqlServer.Dac;
|
||||
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||
using Microsoft.SqlTools.ServiceLayer.TaskServices;
|
||||
using Microsoft.SqlTools.ServiceLayer.Utility;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
using Microsoft.SqlTools.Utility;
|
||||
using System;
|
||||
using System.Diagnostics;
|
||||
@@ -86,7 +87,9 @@ namespace Microsoft.SqlTools.ServiceLayer.DacFx
|
||||
try
|
||||
{
|
||||
// Pass in Azure authentication token if needed
|
||||
this.DacServices = this.ConnInfo.ConnectionDetails.AzureAccountToken != null ? new DacServices(this.ConnectionString, new AccessTokenProvider(this.ConnInfo.ConnectionDetails.AzureAccountToken)) : new DacServices(this.ConnectionString);
|
||||
this.DacServices = this.ConnInfo.ConnectionDetails.AzureAccountToken != null && this.ConnInfo.ConnectionDetails.AuthenticationType == AzureMFA
|
||||
? new DacServices(this.ConnectionString, new AccessTokenProvider(this.ConnInfo.ConnectionDetails.AzureAccountToken))
|
||||
: new DacServices(this.ConnectionString);
|
||||
Execute();
|
||||
}
|
||||
catch (Exception e)
|
||||
|
||||
@@ -40,6 +40,7 @@ using Microsoft.SqlTools.ServiceLayer.SqlAssessment;
|
||||
using Microsoft.SqlTools.ServiceLayer.SqlContext;
|
||||
using Microsoft.SqlTools.ServiceLayer.SqlProjects;
|
||||
using Microsoft.SqlTools.ServiceLayer.TableDesigner;
|
||||
using Microsoft.SqlTools.ServiceLayer.Utility;
|
||||
using Microsoft.SqlTools.ServiceLayer.Workspace;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer
|
||||
@@ -53,7 +54,7 @@ namespace Microsoft.SqlTools.ServiceLayer
|
||||
private static object lockObject = new object();
|
||||
private static bool isLoaded;
|
||||
|
||||
internal static ServiceHost CreateAndStartServiceHost(SqlToolsContext sqlToolsContext, Stream? inputStream = null, Stream? outputStream = null)
|
||||
internal static ServiceHost CreateAndStartServiceHost(SqlToolsContext sqlToolsContext, ServiceLayerCommandOptions? commandOptions, Stream? inputStream = null, Stream? outputStream = null)
|
||||
{
|
||||
ServiceHost serviceHost = ServiceHost.Instance;
|
||||
lock (lockObject)
|
||||
@@ -63,7 +64,7 @@ namespace Microsoft.SqlTools.ServiceLayer
|
||||
// Grab the instance of the service host
|
||||
serviceHost.Initialize(inputStream, outputStream);
|
||||
|
||||
InitializeRequestHandlersAndServices(serviceHost, sqlToolsContext);
|
||||
InitializeRequestHandlersAndServices(serviceHost, sqlToolsContext, commandOptions);
|
||||
|
||||
// Start the service only after all request handlers are setup. This is vital
|
||||
// as otherwise the Initialize event can be lost - it's processed and discarded before the handler
|
||||
@@ -75,7 +76,7 @@ namespace Microsoft.SqlTools.ServiceLayer
|
||||
return serviceHost;
|
||||
}
|
||||
|
||||
private static void InitializeRequestHandlersAndServices(ServiceHost serviceHost, SqlToolsContext sqlToolsContext)
|
||||
private static void InitializeRequestHandlersAndServices(ServiceHost serviceHost, SqlToolsContext sqlToolsContext, ServiceLayerCommandOptions? commandOptions)
|
||||
{
|
||||
// Load extension provider, which currently finds all exports in current DLL. Can be changed to find based
|
||||
// on directory or assembly list quite easily in the future
|
||||
@@ -96,7 +97,7 @@ namespace Microsoft.SqlTools.ServiceLayer
|
||||
LanguageService.Instance.InitializeService(serviceHost, sqlToolsContext);
|
||||
serviceProvider.RegisterSingleService(LanguageService.Instance);
|
||||
|
||||
ConnectionService.Instance.InitializeService(serviceHost);
|
||||
ConnectionService.Instance.InitializeService(serviceHost, commandOptions);
|
||||
serviceProvider.RegisterSingleService(ConnectionService.Instance);
|
||||
|
||||
CredentialService.Instance.InitializeService(serviceHost);
|
||||
|
||||
@@ -57,6 +57,8 @@
|
||||
<ProjectReference Include="../Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj" />
|
||||
<ProjectReference Include="../Microsoft.SqlTools.Credentials/Microsoft.SqlTools.Credentials.csproj" />
|
||||
<ProjectReference Include="../Microsoft.SqlTools.ManagedBatchParser/Microsoft.SqlTools.ManagedBatchParser.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.SqlTools.Authentication\Microsoft.SqlTools.Authentication.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Content Include="..\..\Notice.txt">
|
||||
|
||||
@@ -400,7 +400,9 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer
|
||||
var builder = ConnectionService.CreateConnectionStringBuilder(session.ConnectionInfo.ConnectionDetails);
|
||||
builder.InitialCatalog = node.NodeValue;
|
||||
builder.ApplicationName = TableDesignerService.TableDesignerApplicationName;
|
||||
var azureToken = session.ConnectionInfo.ConnectionDetails.AzureAccountToken;
|
||||
// Set Access Token only when authentication mode is not specified.
|
||||
var azureToken = builder.Authentication == SqlAuthenticationMethod.NotSpecified
|
||||
? session.ConnectionInfo.ConnectionDetails.AzureAccountToken : null;
|
||||
TableDesignerCacheManager.StartDatabaseModelInitialization(builder.ToString(), azureToken);
|
||||
}
|
||||
catch (Exception ex)
|
||||
|
||||
@@ -41,7 +41,16 @@ namespace Microsoft.SqlTools.ServiceLayer
|
||||
logFilePath = Logger.GenerateLogFilePath("sqltools");
|
||||
}
|
||||
|
||||
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "sqltools", commandOptions.AutoFlushLog);
|
||||
Logger.Initialize(tracingLevel: commandOptions.TracingLevel, commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "sqltools", commandOptions.AutoFlushLog);
|
||||
|
||||
// Register PII Logging configuration change callback
|
||||
Workspace.WorkspaceService<SqlToolsSettings>.Instance.RegisterConfigChangeCallback((newSettings, oldSettings, context) =>
|
||||
{
|
||||
Logger.IsPiiEnabled = newSettings?.MssqlTools?.PiiLogging ?? false;
|
||||
Logger.Information(Logger.IsPiiEnabled ? "PII Logging enabled" : "PII Logging disabled");
|
||||
return Task.FromResult(true);
|
||||
});
|
||||
|
||||
// Only enable SQL Client logging when verbose or higher to avoid extra overhead when the
|
||||
// detailed logging it provides isn't needed
|
||||
if (Logger.TracingLevel.HasFlag(SourceLevels.Verbose))
|
||||
@@ -53,7 +62,7 @@ namespace Microsoft.SqlTools.ServiceLayer
|
||||
var hostDetails = new HostDetails(version: new Version(1, 0));
|
||||
|
||||
SqlToolsContext sqlToolsContext = new SqlToolsContext(hostDetails);
|
||||
ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext);
|
||||
ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext, commandOptions);
|
||||
serviceHost.MessageDispatcher.ParallelMessageProcessing = commandOptions.ParallelMessageProcessing;
|
||||
|
||||
// If this service was started by another process, then it should shutdown when that parent process does.
|
||||
|
||||
@@ -17,6 +17,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||
using Microsoft.SqlTools.ServiceLayer.DacFx.Contracts;
|
||||
using Microsoft.SqlTools.ServiceLayer.SchemaCompare.Contracts;
|
||||
using Microsoft.SqlTools.ServiceLayer.Utility;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
using Microsoft.SqlTools.Utility;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.SchemaCompare
|
||||
@@ -187,7 +188,9 @@ namespace Microsoft.SqlTools.ServiceLayer.SchemaCompare
|
||||
case SchemaCompareEndpointType.Database:
|
||||
{
|
||||
string connectionString = GetConnectionString(connInfo, endpointInfo.DatabaseName);
|
||||
return connInfo.ConnectionDetails?.AzureAccountToken != null
|
||||
|
||||
// Set Access Token only when authentication mode is not specified.
|
||||
return connInfo.ConnectionDetails?.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA
|
||||
? new SchemaCompareDatabaseEndpoint(connectionString, new AccessTokenProvider(connInfo.ConnectionDetails.AzureAccountToken))
|
||||
: new SchemaCompareDatabaseEndpoint(connectionString);
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ using Microsoft.SqlTools.ServiceLayer.Hosting;
|
||||
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
|
||||
using Microsoft.SqlTools.Utility;
|
||||
using Microsoft.SqlTools.ServiceLayer.Utility;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
using System.Linq;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
||||
@@ -109,7 +110,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
||||
if (connInfo != null)
|
||||
{
|
||||
parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails);
|
||||
accessToken = connInfo.ConnectionDetails.AzureAccountToken;
|
||||
// Set Access Token only when authentication type is AzureMFA.
|
||||
if (connInfo.ConnectionDetails.AuthenticationType == AzureMFA)
|
||||
{
|
||||
accessToken = connInfo.ConnectionDetails.AzureAccountToken;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -18,7 +18,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
|
||||
/// group such as Intellisense is defined on a serialized setting it's used in the order of mssql, then sql, then
|
||||
/// falls back to a default value.
|
||||
/// </summary>
|
||||
public class CompoundToolsSettingsValues: ISqlToolsSettingsValues
|
||||
public class CompoundToolsSettingsValues : ISqlToolsSettingsValues
|
||||
{
|
||||
private List<ISqlToolsSettingsValues> priorityList = new List<ISqlToolsSettingsValues>();
|
||||
private SqlToolsSettingsValues defaultValues;
|
||||
@@ -44,11 +44,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
|
||||
/// Gets or sets the detailed IntelliSense settings
|
||||
/// </summary>
|
||||
public IntelliSenseSettings IntelliSense
|
||||
{
|
||||
{
|
||||
get
|
||||
{
|
||||
return GetSettingOrDefault((settings) => settings.IntelliSense);
|
||||
}
|
||||
}
|
||||
set
|
||||
{
|
||||
priorityList[0].IntelliSense = value;
|
||||
@@ -59,11 +59,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
|
||||
/// Gets or sets the query execution settings
|
||||
/// </summary>
|
||||
public QueryExecutionSettings QueryExecutionSettings
|
||||
{
|
||||
{
|
||||
get
|
||||
{
|
||||
return GetSettingOrDefault((settings) => settings.QueryExecutionSettings);
|
||||
}
|
||||
}
|
||||
set
|
||||
{
|
||||
priorityList[0].QueryExecutionSettings = value;
|
||||
@@ -74,11 +74,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
|
||||
/// Gets or sets the formatter settings
|
||||
/// </summary>
|
||||
public FormatterSettings Format
|
||||
{
|
||||
{
|
||||
get
|
||||
{
|
||||
return GetSettingOrDefault((settings) => settings.Format);
|
||||
}
|
||||
}
|
||||
set
|
||||
{
|
||||
priorityList[0].Format = value;
|
||||
@@ -89,15 +89,24 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
|
||||
/// Gets or sets the object explorer settings
|
||||
/// </summary>
|
||||
public ObjectExplorerSettings ObjectExplorer
|
||||
{
|
||||
{
|
||||
get
|
||||
{
|
||||
return GetSettingOrDefault((settings) => settings.ObjectExplorer);
|
||||
}
|
||||
}
|
||||
set
|
||||
{
|
||||
priorityList[0].ObjectExplorer = value;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets PII Logging setting.
|
||||
/// </summary>
|
||||
public bool PiiLogging
|
||||
{
|
||||
get => GetSettingOrDefault((settings) => settings.PiiLogging);
|
||||
set => priorityList[0].PiiLogging = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
#nullable disable
|
||||
|
||||
using System;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.SqlContext
|
||||
{
|
||||
/// <summary>
|
||||
@@ -31,5 +33,10 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
|
||||
/// Object Explorer specific settings
|
||||
/// </summary>
|
||||
ObjectExplorerSettings ObjectExplorer { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// PII Logging setting
|
||||
/// </summary>
|
||||
Boolean PiiLogging { get; set; }
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#nullable disable
|
||||
|
||||
using System;
|
||||
using Newtonsoft.Json;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.SqlContext
|
||||
@@ -25,6 +26,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
|
||||
QueryExecutionSettings = new QueryExecutionSettings();
|
||||
Format = new FormatterSettings();
|
||||
TableDesigner = new TableDesignerSettings();
|
||||
PiiLogging = false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,5 +59,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
|
||||
/// </summary>
|
||||
[JsonProperty("tableDesigner")]
|
||||
public TableDesignerSettings TableDesigner { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the setting to enable PII Logging.
|
||||
/// </summary>
|
||||
[JsonProperty("piiLogging")]
|
||||
public Boolean PiiLogging { get; set; }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1800,7 +1800,11 @@ namespace Microsoft.SqlTools.ServiceLayer.TableDesigner
|
||||
connectionStringBuilder.InitialCatalog = tableInfo.Database;
|
||||
connectionStringBuilder.ApplicationName = TableDesignerService.TableDesignerApplicationName;
|
||||
var connectionString = connectionStringBuilder.ToString();
|
||||
tableDesigner = new Dac.TableDesigner(connectionString, tableInfo.AccessToken, tableInfo.Schema, tableInfo.Name, tableInfo.IsNewTable);
|
||||
|
||||
// Set Access Token only when authentication mode is not specified.
|
||||
var accessToken = connectionStringBuilder.Authentication == SqlAuthenticationMethod.NotSpecified
|
||||
? tableInfo.AccessToken : null;
|
||||
tableDesigner = new Dac.TableDesigner(connectionString, accessToken, tableInfo.Schema, tableInfo.Name, tableInfo.IsNewTable);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -12,7 +12,7 @@ using Microsoft.SqlTools.Utility;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Utility
|
||||
{
|
||||
class ServiceLayerCommandOptions : CommandOptions
|
||||
public class ServiceLayerCommandOptions : CommandOptions
|
||||
{
|
||||
internal const string ServiceLayerServiceName = "MicrosoftSqlToolsServiceLayer.exe";
|
||||
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<ImplicitUsings>enable</ImplicitUsings>
|
||||
<Nullable>enable</Nullable>
|
||||
</PropertyGroup>
|
||||
|
||||
</Project>
|
||||
18
src/Microsoft.SqlTools.Shared/Utility/Constants.cs
Normal file
18
src/Microsoft.SqlTools.Shared/Utility/Constants.cs
Normal file
@@ -0,0 +1,18 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
namespace Microsoft.SqlTools.Shared.Utility
|
||||
{
|
||||
public class Constants
|
||||
{
|
||||
// Authentication Types
|
||||
public const string Integrated = "Integrated";
|
||||
public const string SqlLogin = "SqlLogin";
|
||||
public const string AzureMFA = "AzureMFA";
|
||||
public const string dstsAuth = "dstsAuth";
|
||||
public const string ActiveDirectoryInteractive = "ActiveDirectoryInteractive";
|
||||
public const string ActiveDirectoryPassword = "ActiveDirectoryPassword";
|
||||
}
|
||||
}
|
||||
@@ -3,11 +3,8 @@
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
using System;
|
||||
using System.Diagnostics;
|
||||
using System.Globalization;
|
||||
using System.IO;
|
||||
using System.Threading;
|
||||
|
||||
namespace Microsoft.SqlTools.Utility
|
||||
{
|
||||
@@ -29,41 +26,45 @@ namespace Microsoft.SqlTools.Utility
|
||||
/// </summary>
|
||||
public static class Logger
|
||||
{
|
||||
internal const SourceLevels defaultTracingLevel = SourceLevels.Critical;
|
||||
internal const string defaultTraceSource = "sqltools";
|
||||
public const SourceLevels defaultTracingLevel = SourceLevels.Critical;
|
||||
public const string defaultTraceSource = "sqltools";
|
||||
private static SourceLevels tracingLevel = defaultTracingLevel;
|
||||
private static string logFileFullPath;
|
||||
private static string? logFileFullPath;
|
||||
|
||||
public static TraceSource? TraceSource { get; set; }
|
||||
|
||||
internal static TraceSource TraceSource { get; set; }
|
||||
public static string LogFileFullPath
|
||||
{
|
||||
get => logFileFullPath;
|
||||
get => logFileFullPath!;
|
||||
private set
|
||||
{
|
||||
//If the log file path has a directory component then ensure that the directory exists.
|
||||
if (!string.IsNullOrEmpty(Path.GetDirectoryName(value)) && !Directory.Exists(Path.GetDirectoryName(value)))
|
||||
if (value != null)
|
||||
{
|
||||
Directory.CreateDirectory(Path.GetDirectoryName(value));
|
||||
}
|
||||
//If the log file path has a directory component then ensure that the directory exists.
|
||||
if (!string.IsNullOrEmpty(Path.GetDirectoryName(value)) && !Directory.Exists(Path.GetDirectoryName(value)))
|
||||
{
|
||||
Directory.CreateDirectory(Path.GetDirectoryName(value)!);
|
||||
}
|
||||
|
||||
logFileFullPath = value;
|
||||
logFileFullPath = value;
|
||||
}
|
||||
ConfigureListener();
|
||||
}
|
||||
}
|
||||
|
||||
private static SqlToolsTraceListener Listener { get; set; }
|
||||
private static SqlToolsTraceListener? Listener { get; set; }
|
||||
|
||||
private static void ConfigureLogFile(string logFilePrefix) => LogFileFullPath = GenerateLogFilePath(logFilePrefix);
|
||||
|
||||
/// <summary>
|
||||
/// Calling this method will turn on inclusion CallStack in the log for all future traces
|
||||
/// </summary>
|
||||
public static void StartCallStack() => Listener.TraceOutputOptions |= TraceOptions.Callstack;
|
||||
public static void StartCallStack() => Listener!.TraceOutputOptions |= TraceOptions.Callstack;
|
||||
|
||||
/// <summary>
|
||||
/// Calling this method will turn off inclusion of CallStack in the log for all future traces
|
||||
/// </summary>
|
||||
public static void StopCallStack() => Listener.TraceOutputOptions &= ~TraceOptions.Callstack;
|
||||
public static void StopCallStack() => Listener!.TraceOutputOptions &= ~TraceOptions.Callstack;
|
||||
|
||||
/// <summary>
|
||||
/// Calls flush on defaultTracingLevel configured listeners.
|
||||
@@ -86,17 +87,19 @@ namespace Microsoft.SqlTools.Utility
|
||||
get => tracingLevel;
|
||||
set
|
||||
{
|
||||
if(TraceSource != null)
|
||||
if (TraceSource != null)
|
||||
{
|
||||
// configures the source level filter. This alone is not enough for tracing that is done via "Trace" class instead of "TraceSource" object
|
||||
TraceSource.Switch = new SourceSwitch(TraceSource.Name, value.ToString());
|
||||
}
|
||||
// configure the listener level filter
|
||||
tracingLevel = value;
|
||||
Listener.Filter = new EventTypeFilter(tracingLevel);
|
||||
Listener!.Filter = new EventTypeFilter(tracingLevel);
|
||||
}
|
||||
}
|
||||
|
||||
public static bool IsPiiEnabled { get; set; } = false;
|
||||
|
||||
public static bool AutoFlush { get; set; } = false;
|
||||
|
||||
/// <summary>
|
||||
@@ -115,11 +118,13 @@ namespace Microsoft.SqlTools.Utility
|
||||
/// </param>
|
||||
public static void Initialize(
|
||||
SourceLevels tracingLevel = defaultTracingLevel,
|
||||
string logFilePath = null,
|
||||
bool piiEnabled = false,
|
||||
string? logFilePath = null,
|
||||
string traceSource = defaultTraceSource,
|
||||
bool autoFlush = false)
|
||||
{
|
||||
Logger.tracingLevel = tracingLevel;
|
||||
Logger.IsPiiEnabled = piiEnabled;
|
||||
Logger.AutoFlush = autoFlush;
|
||||
TraceSource = new TraceSource(traceSource, Logger.tracingLevel);
|
||||
if (string.IsNullOrWhiteSpace(logFilePath))
|
||||
@@ -146,11 +151,12 @@ namespace Microsoft.SqlTools.Utility
|
||||
/// <param name="autoFlush">
|
||||
/// Optional. Specifies whether the log is flushed after every message
|
||||
/// </param>
|
||||
public static void Initialize(string tracingLevel, string logFilePath = null, string traceSource = defaultTraceSource, bool autoFlush = false)
|
||||
public static void Initialize(string tracingLevel, bool piiEnabled, string? logFilePath = null, string traceSource = defaultTraceSource, bool autoFlush = false)
|
||||
{
|
||||
Initialize(Enum.TryParse<SourceLevels>(tracingLevel, out SourceLevels sourceTracingLevel)
|
||||
? sourceTracingLevel
|
||||
: defaultTracingLevel
|
||||
, piiEnabled
|
||||
, logFilePath
|
||||
, traceSource
|
||||
, autoFlush);
|
||||
@@ -169,7 +175,7 @@ namespace Microsoft.SqlTools.Utility
|
||||
throw new ArgumentOutOfRangeException(nameof(logFilePrefix), $"LogfilePath cannot be configured if argument {nameof(logFilePrefix)} has not been set");
|
||||
}
|
||||
// Create the log directory
|
||||
string logDir = Path.GetDirectoryName(logFilePrefix);
|
||||
string? logDir = Path.GetDirectoryName(logFilePrefix);
|
||||
if (!string.IsNullOrWhiteSpace(logDir))
|
||||
{
|
||||
if (!Directory.Exists(logDir))
|
||||
@@ -199,7 +205,7 @@ namespace Microsoft.SqlTools.Utility
|
||||
}
|
||||
|
||||
string fileName;
|
||||
try
|
||||
try
|
||||
{
|
||||
var now = DateTime.Now;
|
||||
fileName = string.Format(CultureInfo.InvariantCulture,
|
||||
@@ -239,6 +245,16 @@ namespace Microsoft.SqlTools.Utility
|
||||
/// <param name="logMessage">The message text to be written.</param>
|
||||
public static void Write(TraceEventType eventType, string logMessage) => Write(eventType, LogEvent.Default, logMessage);
|
||||
|
||||
/// <summary>
|
||||
/// Writes a PII message to the log file with the Verbose event level when PII flag is enabled.
|
||||
/// </summary>
|
||||
/// <param name="logMessage">The message text to be written.</param>
|
||||
public static void Pii(string logMessage) {
|
||||
if (IsPiiEnabled) {
|
||||
Write(TraceEventType.Verbose, logMessage);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Writes a message to the log file with the Verbose event level
|
||||
/// </summary>
|
||||
@@ -386,9 +402,9 @@ namespace Microsoft.SqlTools.Utility
|
||||
);
|
||||
}
|
||||
#region forward actual write/close/flush/dispose calls to the underlying listener.
|
||||
public override void Write(string message) => Listener.Write(message);
|
||||
public override void Write(string? message) => Listener.Write(message);
|
||||
|
||||
public override void WriteLine(string message) => Listener.WriteLine(message);
|
||||
public override void WriteLine(string? message) => Listener.WriteLine(message);
|
||||
|
||||
/// <Summary>
|
||||
/// Closes the <see cref="System.Diagnostics.TextWriterTraceListener.Writer"> so that it no longer
|
||||
@@ -430,41 +446,41 @@ namespace Microsoft.SqlTools.Utility
|
||||
}
|
||||
#endregion
|
||||
|
||||
public override void TraceEvent(TraceEventCache eventCache, String source, TraceEventType eventType, int id)
|
||||
public override void TraceEvent(TraceEventCache? eventCache, String source, TraceEventType eventType, int id)
|
||||
{
|
||||
TraceEvent(eventCache, source, eventType, id, String.Empty);
|
||||
}
|
||||
|
||||
// All other TraceEvent methods come through this one.
|
||||
public override void TraceEvent(TraceEventCache eventCache, String source, TraceEventType eventType, int id, string message)
|
||||
public override void TraceEvent(TraceEventCache? eventCache, String source, TraceEventType eventType, int id, string? message)
|
||||
{
|
||||
if (Filter != null && !Filter.ShouldTrace(eventCache, source, eventType, id, message, null, null, null))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
WriteHeader(eventCache, source, eventType, id);
|
||||
WriteLine(message);
|
||||
WriteFooter(eventCache);
|
||||
WriteHeader(eventCache!, source, eventType, id);
|
||||
WriteLine(message!);
|
||||
WriteFooter(eventCache!);
|
||||
}
|
||||
|
||||
public override void TraceEvent(TraceEventCache eventCache, String source, TraceEventType eventType, int id, string format, params object[] args)
|
||||
public override void TraceEvent(TraceEventCache? eventCache, String source, TraceEventType eventType, int id, string? format, params object?[]? args)
|
||||
{
|
||||
if (Filter != null && !Filter.ShouldTrace(eventCache, source, eventType, id, format, args, null, null))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
WriteHeader(eventCache, source, eventType, id);
|
||||
WriteHeader(eventCache!, source, eventType, id);
|
||||
if (args != null)
|
||||
{
|
||||
WriteLine(String.Format(CultureInfo.InvariantCulture, format, args));
|
||||
WriteLine(String.Format(CultureInfo.InvariantCulture, format!, args));
|
||||
}
|
||||
else
|
||||
{
|
||||
WriteLine(format);
|
||||
WriteLine(format!);
|
||||
}
|
||||
WriteFooter(eventCache);
|
||||
WriteFooter(eventCache!);
|
||||
}
|
||||
|
||||
private void WriteHeader(TraceEventCache eventCache, String source, TraceEventType eventType, int id)
|
||||
@@ -12,14 +12,15 @@ using Microsoft.Kusto.ServiceLayer.DataSource;
|
||||
using Microsoft.Kusto.ServiceLayer.DataSource.Intellisense;
|
||||
using Microsoft.Kusto.ServiceLayer.LanguageServices;
|
||||
using Microsoft.Kusto.ServiceLayer.Workspace.Contracts;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
using NUnit.Framework;
|
||||
|
||||
namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
|
||||
{
|
||||
public class DataSourceFactoryTests
|
||||
{
|
||||
[TestCase(typeof(ArgumentException), "ConnectionString", "", "AzureMFA")]
|
||||
[TestCase(typeof(ArgumentException), "ConnectionString", "", "dstsAuth")]
|
||||
[TestCase(typeof(ArgumentException), "ConnectionString", "", AzureMFA)]
|
||||
[TestCase(typeof(ArgumentException), "ConnectionString", "", dstsAuth)]
|
||||
public void Create_Throws_Exceptions_For_InvalidAzureAccountToken(Type exceptionType, string connectionString, string azureAccountToken, string authType)
|
||||
{
|
||||
Program.ServiceName = "Kusto";
|
||||
|
||||
@@ -8,14 +8,15 @@
|
||||
using System;
|
||||
using Microsoft.Kusto.ServiceLayer.DataSource.Contracts;
|
||||
using Microsoft.Kusto.ServiceLayer.DataSource.Kusto;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
using NUnit.Framework;
|
||||
|
||||
namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
|
||||
{
|
||||
public class KustoClientTests
|
||||
{
|
||||
[TestCase("dstsAuth")]
|
||||
[TestCase("AzureMFA")]
|
||||
[TestCase(dstsAuth)]
|
||||
[TestCase(AzureMFA)]
|
||||
public void Constructor_Throws_ArgumentException_For_MissingToken(string authType)
|
||||
{
|
||||
var connectionDetails = new DataSourceConnectionDetails
|
||||
@@ -37,7 +38,7 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
|
||||
UserToken = "UserToken",
|
||||
ServerName = clusterName,
|
||||
DatabaseName = "",
|
||||
AuthenticationType = "AzureMFA"
|
||||
AuthenticationType = AzureMFA
|
||||
};
|
||||
|
||||
var client = new KustoClient(connectionDetails, "ownerUri");
|
||||
@@ -46,10 +47,10 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
|
||||
Assert.AreEqual("NetDefaultDB", client.DatabaseName);
|
||||
}
|
||||
|
||||
[TestCase("dstsAuth")]
|
||||
[TestCase("AzureMFA")]
|
||||
[TestCase(dstsAuth)]
|
||||
[TestCase(AzureMFA)]
|
||||
[TestCase("NoAuth")]
|
||||
[TestCase("SqlLogin")]
|
||||
[TestCase(SqlLogin)]
|
||||
public void Constructor_Creates_Client_With_Valid_AuthenticationType(string authenticationType)
|
||||
{
|
||||
string clusterName = "https://fake.url.com";
|
||||
@@ -59,8 +60,8 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
|
||||
ServerName = clusterName,
|
||||
DatabaseName = "FakeDatabaseName",
|
||||
AuthenticationType = authenticationType,
|
||||
UserName = authenticationType == "SqlLogin" ? "username": null,
|
||||
Password = authenticationType == "SqlLogin" ? "password": null
|
||||
UserName = authenticationType == SqlLogin ? "username": null,
|
||||
Password = authenticationType == SqlLogin ? "password": null
|
||||
};
|
||||
|
||||
var client = new KustoClient(connectionDetails, "ownerUri");
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
<IsPackable>false</IsPackable>
|
||||
</PropertyGroup>
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Moq" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" />
|
||||
<PackageReference Include="Moq" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" />
|
||||
<PackageReference Include="nunit" />
|
||||
<PackageReference Include="nunit3testadapter" />
|
||||
<PackageReference Include="nunit.console" />
|
||||
@@ -25,5 +25,6 @@
|
||||
<ProjectReference Include="../../src/Microsoft.SqlTools.Credentials/Microsoft.SqlTools.Credentials.csproj" />
|
||||
<ProjectReference Include="../../src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj" />
|
||||
<ProjectReference Include="../../test/Microsoft.SqlTools.ServiceLayer.TestDriver/Microsoft.SqlTools.ServiceLayer.TestDriver.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
@@ -32,6 +32,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common
|
||||
public SourceLevels TracingLevel { get; set; } = SourceLevels.Critical;
|
||||
public bool DoNotUseTraceSource { get; set; } = false;
|
||||
|
||||
public bool IsPiiEnabled { get; set; } = false;
|
||||
|
||||
public bool AutoFlush { get; set; } = false;
|
||||
|
||||
private List<Action> pendingVerifications;
|
||||
@@ -43,7 +45,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common
|
||||
|
||||
public string LogFileName { get => logFileName ?? Logger.LogFileFullPath; set => logFileName = value; }
|
||||
public void Initialize() =>
|
||||
Logger.Initialize(TracingLevel, LogFilePath, TraceSource, AutoFlush); // initialize the logger
|
||||
Logger.Initialize(TracingLevel, IsPiiEnabled, LogFilePath, TraceSource, AutoFlush); // initialize the logger
|
||||
public string LogContents
|
||||
{
|
||||
get
|
||||
|
||||
@@ -195,7 +195,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common
|
||||
|
||||
// Initialize the ServiceHost, using a MemoryStream for the output stream so that we don't fill up the logs
|
||||
// with a bunch of outgoing messages (which aren't used for anything during tests)
|
||||
ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext, null, new MemoryStream());
|
||||
ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext, null, null, new MemoryStream());
|
||||
|
||||
// Set up our logger to write to Console for tests to help debug issues
|
||||
Logger.Initialize(autoFlush: true);
|
||||
|
||||
@@ -12,6 +12,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
|
||||
using Microsoft.SqlTools.ServiceLayer.Test.Common;
|
||||
using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
using Moq;
|
||||
using Moq.Protected;
|
||||
using NUnit.Framework;
|
||||
@@ -479,6 +480,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
new object[] {null, "12345678"},
|
||||
new object[] {"", "12345678"},
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Verify that when using integrated authentication, the username and/or password can be empty.
|
||||
/// </summary>
|
||||
@@ -497,13 +499,67 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
DatabaseName = "test",
|
||||
UserName = userName,
|
||||
Password = password,
|
||||
AuthenticationType = "Integrated"
|
||||
AuthenticationType = Integrated
|
||||
}
|
||||
});
|
||||
|
||||
Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Verify that username is required when using Active Directory Interactive authentication.
|
||||
/// Both AzureMFA and ActiveDirectoryInteractive should work same way, when SqlAuthenticationProvider is enabled.
|
||||
/// </summary>
|
||||
[TestCase(null, AzureMFA)]
|
||||
[TestCase(null, ActiveDirectoryInteractive)]
|
||||
public async Task ConnectingWitNoUsernameFailsForAADInteractiveAuth(string userName, string authMode)
|
||||
{
|
||||
// This is an exception scenario test, therefore using ConnectionService instance instead of TestConnectionService directly.
|
||||
ConnectionService.Instance.EnableSqlAuthenticationProvider = true;
|
||||
// Connect
|
||||
var connectionResult = await
|
||||
TestObjects.GetTestConnectionService()
|
||||
.Connect(new ConnectParams()
|
||||
{
|
||||
OwnerUri = "file:///my/test/file.sql",
|
||||
Connection = new ConnectionDetails()
|
||||
{
|
||||
ServerName = "my-server",
|
||||
DatabaseName = "test",
|
||||
UserName = userName,
|
||||
AuthenticationType = authMode
|
||||
}
|
||||
});
|
||||
|
||||
Assert.That(connectionResult.ConnectionId, Is.Null.Or.Empty, "Connection should not be successful");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Verify that password is ignored when using Active Directory Interactive authentication.
|
||||
/// </summary>
|
||||
[TestCase("user", "anything", AzureMFA)]
|
||||
[TestCase("user", "anything", ActiveDirectoryInteractive)]
|
||||
public async Task ConnectingWitPasswordIsIgnoredForAADInteractiveAuth(string username, string password, string authMode)
|
||||
{
|
||||
// Connect
|
||||
var connectionResult = await
|
||||
TestObjects.GetTestConnectionService()
|
||||
.Connect(new ConnectParams()
|
||||
{
|
||||
OwnerUri = "file:///my/test/file.sql",
|
||||
Connection = new ConnectionDetails()
|
||||
{
|
||||
ServerName = "my-server",
|
||||
DatabaseName = "test",
|
||||
UserName = username,
|
||||
Password = password,
|
||||
AuthenticationType = authMode
|
||||
}
|
||||
});
|
||||
|
||||
Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "Connection should be successful");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Verify that when connecting with a null parameters object, an error is thrown.
|
||||
/// </summary>
|
||||
@@ -1773,16 +1829,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
details.AzureAccountToken = azureAccountToken;
|
||||
details.UserName = "";
|
||||
details.Password = "";
|
||||
details.AuthenticationType = "AzureMFA";
|
||||
details.AuthenticationType = AzureMFA;
|
||||
|
||||
// If I open a connection using connection details that include an account token
|
||||
// Open a connection using connection details that include an account token
|
||||
await connectionService.Connect(new ConnectParams
|
||||
{
|
||||
OwnerUri = "testURI",
|
||||
Connection = details
|
||||
});
|
||||
|
||||
// Then the connection factory got called with details including an account token
|
||||
// Validate that the connection factory gets called with details NOT including an account token
|
||||
mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.Is<string>(accountToken => accountToken == azureAccountToken)), Times.Once());
|
||||
}
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
<ProjectReference Include="../../src/Microsoft.SqlTools.ResourceProvider.Core/Microsoft.SqlTools.ResourceProvider.Core.csproj" />
|
||||
<ProjectReference Include="../../src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Microsoft.SqlTools.ResourceProvider.DefaultImpl.csproj" />
|
||||
<ProjectReference Include="../../test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Service Include="{82a7f48d-3b50-4b1e-b82e-3ada8210c358}" />
|
||||
|
||||
@@ -109,7 +109,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ServiceHost
|
||||
string tracingLevel = null;
|
||||
SourceLevels expectedTracingLevel = Logger.defaultTracingLevel;
|
||||
string expectedTraceSource = Logger.defaultTraceSource;
|
||||
Logger.Initialize(tracingLevel);
|
||||
Logger.Initialize(tracingLevel, false);
|
||||
bool isLogFileExpectedToExist = false;
|
||||
TestLogger.VerifyInitialization(expectedTracingLevel, expectedTraceSource, Logger.LogFileFullPath, isLogFileExpectedToExist, testNo++);
|
||||
TestLogger.Cleanup(Logger.LogFileFullPath);
|
||||
@@ -120,7 +120,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ServiceHost
|
||||
string tracingLevel = null;
|
||||
SourceLevels expectedTracingLevel = Logger.defaultTracingLevel;
|
||||
string expectedTraceSource = Logger.defaultTraceSource;
|
||||
Logger.Initialize(tracingLevel);
|
||||
Logger.Initialize(tracingLevel, false);
|
||||
bool isLogFileExpectedToExist = false;
|
||||
TestLogger.VerifyInitialization(expectedTracingLevel, expectedTraceSource, Logger.LogFileFullPath, isLogFileExpectedToExist, testNo++);
|
||||
TestLogger.Cleanup(Logger.LogFileFullPath);
|
||||
@@ -131,7 +131,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ServiceHost
|
||||
string tracingLevel = "invalid";
|
||||
SourceLevels expectedTracingLevel = Logger.defaultTracingLevel;
|
||||
string expectedTraceSource = Logger.defaultTraceSource;
|
||||
Logger.Initialize(tracingLevel);
|
||||
Logger.Initialize(tracingLevel, false);
|
||||
bool isLogFileExpectedToExist = false;
|
||||
TestLogger.VerifyInitialization(expectedTracingLevel, expectedTraceSource, Logger.LogFileFullPath, isLogFileExpectedToExist, testNo++);
|
||||
TestLogger.Cleanup(Logger.LogFileFullPath);
|
||||
|
||||
Reference in New Issue
Block a user