From 187b6ecc14d3035939d406b8097017e4587ad4ae Mon Sep 17 00:00:00 2001
From: Cheena Malhotra <13396919+cheenamalhotra@users.noreply.github.com>
Date: Thu, 2 Mar 2023 09:39:54 -0800
Subject: [PATCH] Introduce AAD interactive auth mode (#1860)
---
Packages.props | 2 +
sqltoolsservice.sln | 22 ++-
.../ConnectionProviderOptionsHelper.cs | 3 +-
.../Connection/ConnectionService.cs | 3 +-
.../DataSource/DataSourceFactory.cs | 3 +-
.../DataSource/Kusto/KustoClient.cs | 7 +-
.../Microsoft.Kusto.ServiceLayer.csproj | 1 +
src/Microsoft.Kusto.ServiceLayer/Program.cs | 2 +-
.../AccessToken.cs | 33 ++++
.../AuthenticationMethod.cs | 15 ++
.../AuthenticationParams.cs | 82 +++++++++
.../Authenticator.cs | 158 ++++++++++++++++++
.../Microsoft.SqlTools.Authentication.csproj | 18 ++
.../Sql/AuthenticationProvider.cs | 110 ++++++++++++
.../Utility/Utils.cs | 119 +++++++++++++
src/Microsoft.SqlTools.Credentials/Program.cs | 2 +-
.../Microsoft.SqlTools.Hosting.csproj | 3 +
.../Utility/CommandOptions.cs | 24 +++
.../ReliableSqlConnection.cs | 3 +-
src/Microsoft.SqlTools.Migration/Program.cs | 2 +-
.../Program.cs | 2 +-
.../ConnectionProviderOptionsHelper.cs | 3 +-
.../Connection/ConnectionService.cs | 129 ++++++++------
.../Contracts/SecurityTokenRequest.cs | 14 +-
.../DacFx/DacFxOperation.cs | 5 +-
.../HostLoader.cs | 9 +-
.../Microsoft.SqlTools.ServiceLayer.csproj | 2 +
.../ObjectExplorer/ObjectExplorerService.cs | 4 +-
.../Program.cs | 13 +-
.../SchemaCompare/SchemaCompareUtils.cs | 5 +-
.../Scripting/ScriptingService.cs | 7 +-
.../CompoundSqlToolsSettingsValues.cs | 27 ++-
.../SqlContext/ISqlToolsSettingsValues.cs | 7 +
.../SqlContext/SqlToolsSettingsValues.cs | 8 +
.../TableDesigner/TableDesignerService.cs | 6 +-
.../Utility/ServiceLayerCommandOptions.cs | 2 +-
.../Microsoft.SqlTools.Shared.csproj | 8 +
.../Utility/Constants.cs | 18 ++
.../Utility/Logger.cs | 84 ++++++----
.../DataSource/DataSourceFactoryTests.cs | 5 +-
.../DataSource/KustoClientTests.cs | 17 +-
...t.SqlTools.ServiceLayer.Test.Common.csproj | 5 +-
.../TestLogger.cs | 4 +-
.../TestServiceProvider.cs | 2 +-
.../Connection/ConnectionServiceTests.cs | 64 ++++++-
...oft.SqlTools.ServiceLayer.UnitTests.csproj | 1 +
.../ServiceHost/LoggerTests.cs | 6 +-
47 files changed, 918 insertions(+), 151 deletions(-)
create mode 100644 src/Microsoft.SqlTools.Authentication/AccessToken.cs
create mode 100644 src/Microsoft.SqlTools.Authentication/AuthenticationMethod.cs
create mode 100644 src/Microsoft.SqlTools.Authentication/AuthenticationParams.cs
create mode 100644 src/Microsoft.SqlTools.Authentication/Authenticator.cs
create mode 100644 src/Microsoft.SqlTools.Authentication/Microsoft.SqlTools.Authentication.csproj
create mode 100644 src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs
create mode 100644 src/Microsoft.SqlTools.Authentication/Utility/Utils.cs
create mode 100644 src/Microsoft.SqlTools.Shared/Microsoft.SqlTools.Shared.csproj
create mode 100644 src/Microsoft.SqlTools.Shared/Utility/Constants.cs
rename src/{Microsoft.SqlTools.Hosting => Microsoft.SqlTools.Shared}/Utility/Logger.cs (87%)
diff --git a/Packages.props b/Packages.props
index 5b389bf4..aeb1d6e7 100644
--- a/Packages.props
+++ b/Packages.props
@@ -15,6 +15,8 @@
+
+
diff --git a/sqltoolsservice.sln b/sqltoolsservice.sln
index 0364baa9..47941438 100644
--- a/sqltoolsservice.sln
+++ b/sqltoolsservice.sln
@@ -1,6 +1,6 @@
Microsoft Visual Studio Solution File, Format Version 12.00
-# Visual Studio Version 16
-VisualStudioVersion = 16.0.29409.12
+# Visual Studio Version 17
+VisualStudioVersion = 17.5.33209.295
MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{2BBD7364-054F-4693-97CD-1C395E3E84A9}"
ProjectSection(SolutionItems) = preProject
@@ -96,6 +96,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.SqlTools.Migratio
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.SqlTools.Migration.IntegrationTests", "test\Microsoft.SqlTools.Migration.IntegrationTests\Microsoft.SqlTools.Migration.IntegrationTests.csproj", "{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}"
EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.SqlTools.Authentication", "src\Microsoft.SqlTools.Authentication\Microsoft.SqlTools.Authentication.csproj", "{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.SqlTools.Shared", "src\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj", "{531EC0E0-F400-42C5-BCFA-6691313B5F3E}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -229,6 +233,18 @@ Global
{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}.Integration|Any CPU.Build.0 = Debug|Any CPU
{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}.Release|Any CPU.Build.0 = Release|Any CPU
+ {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Integration|Any CPU.ActiveCfg = Debug|Any CPU
+ {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Integration|Any CPU.Build.0 = Debug|Any CPU
+ {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Release|Any CPU.Build.0 = Release|Any CPU
+ {531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Integration|Any CPU.ActiveCfg = Debug|Any CPU
+ {531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Integration|Any CPU.Build.0 = Debug|Any CPU
+ {531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -256,6 +272,8 @@ Global
{07296730-DAB7-4B0B-9D09-ABD9A5025D68} = {2BBD7364-054F-4693-97CD-1C395E3E84A9}
{22DB0C12-6848-4503-AD1C-DAD6A1D631AE} = {2BBD7364-054F-4693-97CD-1C395E3E84A9}
{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9} = {AB9CA2B8-6F70-431C-8A1D-67479D8A7BE4}
+ {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC} = {2BBD7364-054F-4693-97CD-1C395E3E84A9}
+ {531EC0E0-F400-42C5-BCFA-6691313B5F3E} = {2BBD7364-054F-4693-97CD-1C395E3E84A9}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {B31CDF4B-2851-45E5-8C5F-BE97125D9DD8}
diff --git a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs
index cdd9b04a..de9e4a73 100644
--- a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs
+++ b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs
@@ -4,6 +4,7 @@
//
using Microsoft.SqlTools.Hosting.Contracts;
+using static Microsoft.SqlTools.Shared.Utility.Constants;
namespace Microsoft.Kusto.ServiceLayer.Connection
{
@@ -50,7 +51,7 @@ namespace Microsoft.Kusto.ServiceLayer.Connection
CategoryValues = new CategoryValue[]
{ new CategoryValue { DisplayName = "SQL Login", Name = "SqlLogin" },
new CategoryValue { DisplayName = "Windows Authentication", Name = "Integrated" },
- new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = "AzureMFA" }
+ new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = AzureMFA }
},
IsIdentity = true,
IsRequired = true,
diff --git a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionService.cs
index 54b5aba2..c7cc33c0 100644
--- a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionService.cs
+++ b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionService.cs
@@ -19,6 +19,7 @@ using System.Diagnostics;
using Microsoft.Kusto.ServiceLayer.DataSource;
using Microsoft.Kusto.ServiceLayer.DataSource.Metadata;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
+using static Microsoft.SqlTools.Shared.Utility.Constants;
namespace Microsoft.Kusto.ServiceLayer.Connection
{
@@ -916,7 +917,7 @@ namespace Microsoft.Kusto.ServiceLayer.Connection
return new ConnectionDetails
{
ApplicationName = builder.ApplicationNameForTracing,
- AuthenticationType = "AzureMFA",
+ AuthenticationType = AzureMFA,
DatabaseName = builder.InitialCatalog,
ServerName = builder.DataSource,
UserName = builder.UserID,
diff --git a/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceFactory.cs b/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceFactory.cs
index 3a3e337c..1f140297 100644
--- a/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceFactory.cs
+++ b/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceFactory.cs
@@ -20,6 +20,7 @@ using Microsoft.Kusto.ServiceLayer.LanguageServices;
using Microsoft.Kusto.ServiceLayer.Utility;
using Microsoft.Kusto.ServiceLayer.Workspace.Contracts;
using CompletionItem = Microsoft.Kusto.ServiceLayer.LanguageServices.Contracts.CompletionItem;
+using static Microsoft.SqlTools.Shared.Utility.Constants;
namespace Microsoft.Kusto.ServiceLayer.DataSource
{
@@ -63,7 +64,7 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource
private DataSourceConnectionDetails MapKustoConnectionDetails(ConnectionDetails connectionDetails)
{
- if (connectionDetails.AuthenticationType == "dstsAuth" || connectionDetails.AuthenticationType == "AzureMFA")
+ if (connectionDetails.AuthenticationType == dstsAuth || connectionDetails.AuthenticationType == AzureMFA)
{
ValidationUtils.IsTrue(!string.IsNullOrWhiteSpace(connectionDetails.AccountToken),
$"The Kusto User Token is not specified - set {nameof(connectionDetails.AccountToken)}");
diff --git a/src/Microsoft.Kusto.ServiceLayer/DataSource/Kusto/KustoClient.cs b/src/Microsoft.Kusto.ServiceLayer/DataSource/Kusto/KustoClient.cs
index 660937b5..ac671977 100644
--- a/src/Microsoft.Kusto.ServiceLayer/DataSource/Kusto/KustoClient.cs
+++ b/src/Microsoft.Kusto.ServiceLayer/DataSource/Kusto/KustoClient.cs
@@ -21,6 +21,7 @@ using Kusto.Language.Editor;
using Microsoft.Kusto.ServiceLayer.Connection;
using Microsoft.Kusto.ServiceLayer.DataSource.Contracts;
using Microsoft.Kusto.ServiceLayer.Utility;
+using static Microsoft.SqlTools.Shared.Utility.Constants;
namespace Microsoft.Kusto.ServiceLayer.DataSource.Kusto
{
@@ -74,7 +75,7 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource.Kusto
ServerName = ClusterName,
DatabaseName = DatabaseName,
UserToken = accountToken,
- AuthenticationType = "AzureMFA"
+ AuthenticationType = AzureMFA
};
Initialize(connectionDetails);
@@ -96,8 +97,8 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource.Kusto
switch (connectionDetails.AuthenticationType)
{
- case "AzureMFA": return stringBuilder.WithAadUserTokenAuthentication(connectionDetails.UserToken);
- case "dstsAuth": return stringBuilder.WithDstsUserTokenAuthentication(connectionDetails.UserToken);
+ case AzureMFA: return stringBuilder.WithAadUserTokenAuthentication(connectionDetails.UserToken);
+ case dstsAuth: return stringBuilder.WithDstsUserTokenAuthentication(connectionDetails.UserToken);
default:
return string.IsNullOrWhiteSpace(connectionDetails.UserName) && string.IsNullOrWhiteSpace(connectionDetails.Password)
? stringBuilder
diff --git a/src/Microsoft.Kusto.ServiceLayer/Microsoft.Kusto.ServiceLayer.csproj b/src/Microsoft.Kusto.ServiceLayer/Microsoft.Kusto.ServiceLayer.csproj
index e646918e..80501dcf 100644
--- a/src/Microsoft.Kusto.ServiceLayer/Microsoft.Kusto.ServiceLayer.csproj
+++ b/src/Microsoft.Kusto.ServiceLayer/Microsoft.Kusto.ServiceLayer.csproj
@@ -51,6 +51,7 @@
+
diff --git a/src/Microsoft.Kusto.ServiceLayer/Program.cs b/src/Microsoft.Kusto.ServiceLayer/Program.cs
index 91cb0749..02573e65 100644
--- a/src/Microsoft.Kusto.ServiceLayer/Program.cs
+++ b/src/Microsoft.Kusto.ServiceLayer/Program.cs
@@ -41,7 +41,7 @@ namespace Microsoft.Kusto.ServiceLayer
logFilePath = Logger.GenerateLogFilePath("kustoservice");
}
- Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "kustoservice", commandOptions.AutoFlushLog);
+ Logger.Initialize(tracingLevel: commandOptions.TracingLevel, piiEnabled: commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "kustoservice", commandOptions.AutoFlushLog);
// set up the host details and profile paths
var hostDetails = new HostDetails(version: new Version(1, 0));
diff --git a/src/Microsoft.SqlTools.Authentication/AccessToken.cs b/src/Microsoft.SqlTools.Authentication/AccessToken.cs
new file mode 100644
index 00000000..b698ce87
--- /dev/null
+++ b/src/Microsoft.SqlTools.Authentication/AccessToken.cs
@@ -0,0 +1,33 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.Authentication
+{
+ ///
+ /// Represents an access token data object.
+ ///
+ public class AccessToken
+ {
+ ///
+ /// OAuth 2.0 JWT encoded access token string
+ ///
+ public string Token { get; set; }
+
+ ///
+ /// Expiry date of token
+ ///
+ public DateTimeOffset ExpiresOn { get; set; }
+
+ ///
+ /// Default constructor for Access Token object
+ ///
+ /// Access token as string
+ /// Expiry date
+ public AccessToken(string token, DateTimeOffset expiresOn) {
+ this.Token = token;
+ this.ExpiresOn = expiresOn;
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.Authentication/AuthenticationMethod.cs b/src/Microsoft.SqlTools.Authentication/AuthenticationMethod.cs
new file mode 100644
index 00000000..606507dc
--- /dev/null
+++ b/src/Microsoft.SqlTools.Authentication/AuthenticationMethod.cs
@@ -0,0 +1,15 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.Authentication
+{
+ ///
+ /// Supported Active Directory authentication modes
+ ///
+ public enum AuthenticationMethod
+ {
+ ActiveDirectoryInteractive
+ }
+}
diff --git a/src/Microsoft.SqlTools.Authentication/AuthenticationParams.cs b/src/Microsoft.SqlTools.Authentication/AuthenticationParams.cs
new file mode 100644
index 00000000..cbfa5995
--- /dev/null
+++ b/src/Microsoft.SqlTools.Authentication/AuthenticationParams.cs
@@ -0,0 +1,82 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.Authentication
+{
+ ///
+ /// Parameters to be passed to to request an access token
+ ///
+ public class AuthenticationParams
+ {
+ ///
+ /// Authentication method to be used by .
+ ///
+ public AuthenticationMethod AuthenticationMethod { get; set; }
+
+ ///
+ /// Authority URL, e.g. https://login.microsoftonline.com/
+ ///
+ public string Authority { get; set; }
+
+ ///
+ /// Audience for which access token should be acquired, e.g. common, organizations, consumers.
+ /// It can also be a tenant Id when authenticating multi-tenant application accounts.
+ ///
+ public string Audience { get; set; }
+
+ ///
+ /// Resource URL, e.g. https://database.windows.net/
+ ///
+ public string Resource { get; set; }
+
+ ///
+ /// Array of scopes for which access token is requested.
+ ///
+ public string[] Scopes { get; set; }
+
+ ///
+ /// Connection Id, that will be passed to Azure AD when requesting access token.
+ /// It can be used for tracking AAD request status if needed.
+ ///
+ public Guid ConnectionId { get; set; }
+
+ ///
+ /// User name to be provided as userhint when acquiring access token.
+ ///
+ public string UserName { get; set; }
+
+ ///
+ /// Default constructor
+ ///
+ /// Authentication Method to be used.
+ /// Authority URL
+ /// Audience
+ /// Resource for which token is requested.
+ /// Scopes for access token
+ /// User hint information
+ /// Connection Id for tracing AAD request
+ public AuthenticationParams(AuthenticationMethod authMethod, string authority, string audience,
+ string resource, string[] scopes, string userName, Guid connectionId) {
+ this.AuthenticationMethod = authMethod;
+ this.Authority = authority;
+ this.Audience = audience;
+ this.Resource = resource;
+ this.Scopes = scopes;
+ this.UserName = userName;
+ this.ConnectionId = connectionId;
+ }
+
+ public string ToLogString(bool piiEnabled)
+ {
+ return $"\tAuthenticationMethod: {AuthenticationMethod}" +
+ $"\n\tAuthority: {Authority}" +
+ $"\n\tAudience: {Audience}" +
+ $"\n\tResource: {Resource}" +
+ $"\n\tScopes: {Scopes.ToString()}" +
+ (piiEnabled ? $"\n\tUsername: {UserName}" : "") +
+ $"\n\tConnection Id: {ConnectionId}";
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.Authentication/Authenticator.cs b/src/Microsoft.SqlTools.Authentication/Authenticator.cs
new file mode 100644
index 00000000..d149a85f
--- /dev/null
+++ b/src/Microsoft.SqlTools.Authentication/Authenticator.cs
@@ -0,0 +1,158 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Collections.Concurrent;
+using Microsoft.Identity.Client;
+using Microsoft.Identity.Client.Extensions.Msal;
+using Microsoft.SqlTools.Authentication.Utility;
+using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger;
+
+namespace Microsoft.SqlTools.Authentication
+{
+ ///
+ /// Provides APIs to acquire access token using MSAL.NET v4 with provided .
+ ///
+ public class Authenticator
+ {
+ private string applicationClientId;
+ private string applicationName;
+ private string cacheFolderPath;
+ private string cacheFileName;
+ private MsalCacheHelper cacheHelper;
+ private static ConcurrentDictionary PublicClientAppMap
+ = new ConcurrentDictionary();
+
+ #region Public APIs
+ public Authenticator(string appClientId, string appName, string cacheFolderPath, string cacheFileName)
+ {
+ this.applicationClientId = appClientId;
+ this.applicationName = appName;
+ this.cacheFolderPath = cacheFolderPath;
+ this.cacheFileName = cacheFileName;
+
+ // Storage creation properties are used to enable file system caching with MSAL
+ var storageCreationProperties = new StorageCreationPropertiesBuilder(this.cacheFileName, this.cacheFolderPath)
+ .WithUnprotectedFile().Build();
+
+ // This hooks up the cross-platform cache into MSAL
+ this.cacheHelper = MsalCacheHelper.CreateAsync(storageCreationProperties).ConfigureAwait(false).GetAwaiter().GetResult();
+ }
+
+ ///
+ /// Acquires access token synchronously.
+ ///
+ /// Authentication parameters to be used for access token request.
+ /// Cancellation token.
+ /// Access Token with expiry date
+ public AccessToken? GetToken(AuthenticationParams @params, CancellationToken cancellationToken) =>
+ GetTokenAsync(@params, cancellationToken).GetAwaiter().GetResult();
+
+ ///
+ /// Acquires access token asynchronously.
+ ///
+ /// Authentication parameters to be used for access token request.
+ /// Cancellation token.
+ /// Access Token with expiry date
+ ///
+ public async Task GetTokenAsync(AuthenticationParams @params, CancellationToken cancellationToken)
+ {
+ SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Received @params: {@params.ToLogString(SqlToolsLogger.IsPiiEnabled)}");
+
+ IPublicClientApplication publicClientApplication = GetPublicClientAppInstance(@params.Authority, @params.Audience);
+
+ AccessToken? accessToken;
+ if (@params.AuthenticationMethod == AuthenticationMethod.ActiveDirectoryInteractive)
+ {
+ // Find account
+ IEnumerator? accounts = (await publicClientApplication.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();
+ IAccount? account = default;
+
+ if (!string.IsNullOrEmpty(@params.UserName) && accounts.MoveNext())
+ {
+ // Handle username format to extract email: "John Doe - johndoe@constoso.com"
+ string username = @params.UserName.Contains(" - ") ? @params.UserName.Split(" - ")[1] : @params.UserName;
+ if (!Utils.isValidEmail(username))
+ {
+ SqlToolsLogger.Pii($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Unexpected username format, email not retreived: {@params.UserName}. " +
+ $"Accepted formats are: 'johndoe@org.com' or 'John Doe - johndoe@org.com'.");
+ throw new Exception($"Invalid email address format for user: [{username}] received for Azure Active Directory authentication.");
+ }
+
+ do
+ {
+ IAccount? currentVal = accounts.Current;
+ if (string.Compare(username, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
+ {
+ account = currentVal;
+ SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | User account found in MSAL Cache: {account.HomeAccountId}");
+ break;
+ }
+ }
+ while (accounts.MoveNext());
+
+ if (null != account)
+ {
+ try
+ {
+ // Fetch token silently
+ var result = await publicClientApplication.AcquireTokenSilent(@params.Scopes, account)
+ .ExecuteAsync(cancellationToken: cancellationToken)
+ .ConfigureAwait(false);
+ accessToken = new AccessToken(result!.AccessToken, result!.ExpiresOn);
+ }
+ catch (Exception e)
+ {
+ SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Silent authentication failed for resource {@params.Resource} for ConnectionId {@params.ConnectionId}.");
+ SqlToolsLogger.Error(e);
+ throw;
+ }
+ }
+ else
+ {
+ SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Account not found in MSAL cache for user.");
+ throw new Exception($"User account '{username}' not found in MSAL cache, please add linked account or refresh account credentials.");
+ }
+ }
+ else
+ {
+ SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | User account not received.");
+ throw new Exception($"User account not received.");
+ }
+ }
+ else
+ {
+ SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Authentication Method ${@params.AuthenticationMethod} is not supported.");
+ throw new Exception($"Authentication Method ${@params.AuthenticationMethod} is not supported.");
+ }
+
+ return accessToken;
+ }
+
+ #endregion
+
+ #region Private methods
+ private IPublicClientApplication GetPublicClientAppInstance(string authority, string audience)
+ {
+ string authorityUrl = authority + '/' + audience;
+ if (!PublicClientAppMap.TryGetValue(authorityUrl, out IPublicClientApplication? clientApplicationInstance))
+ {
+ clientApplicationInstance = CreatePublicClientAppInstance(authority, audience);
+ this.cacheHelper.RegisterCache(clientApplicationInstance.UserTokenCache);
+ PublicClientAppMap.TryAdd(authorityUrl, clientApplicationInstance);
+ }
+ return clientApplicationInstance;
+ }
+
+ private IPublicClientApplication CreatePublicClientAppInstance(string authority, string audience) =>
+ PublicClientApplicationBuilder.Create(this.applicationClientId)
+ .WithAuthority(authority, audience)
+ .WithClientName(this.applicationName)
+ .WithLogging(Utils.MSALLogCallback)
+ .WithDefaultRedirectUri()
+ .Build();
+
+ #endregion
+ }
+}
diff --git a/src/Microsoft.SqlTools.Authentication/Microsoft.SqlTools.Authentication.csproj b/src/Microsoft.SqlTools.Authentication/Microsoft.SqlTools.Authentication.csproj
new file mode 100644
index 00000000..517cd82b
--- /dev/null
+++ b/src/Microsoft.SqlTools.Authentication/Microsoft.SqlTools.Authentication.csproj
@@ -0,0 +1,18 @@
+
+
+
+ enable
+ enable
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs b/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs
new file mode 100644
index 00000000..74dbbcbe
--- /dev/null
+++ b/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs
@@ -0,0 +1,110 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.Data.SqlClient;
+using Microsoft.SqlTools.Authentication.Utility;
+
+namespace Microsoft.SqlTools.Authentication.Sql
+{
+ ///
+ /// Provides an implementation of for SQL Tools to be able to perform Federated authentication
+ /// silently with Microsoft.Data.SqlClient integration only for "ActiveDirectory" authentication modes.
+ /// When registered, the SqlClient driver calls the API
+ /// with server-sent authority information to request access token when needed.
+ ///
+ public class AuthenticationProvider : SqlAuthenticationProvider
+ {
+ private const string ApplicationClientId = "a69788c6-1d43-44ed-9ca3-b83e194da255";
+ private const string AzureTokenFolder = "Azure Accounts";
+ private const string MsalCacheName = "azureTokenCacheMsal-azure_publicCloud";
+ private const string s_defaultScopeSuffix = "/.default";
+
+ private Authenticator authenticator;
+
+ ///
+ /// Instantiates AuthenticationProvider to be used for AAD authentication with MSAL.NET and MSAL.js co-ordinated.
+ ///
+ /// Application Name that identifies user folder path location for reading/writing to shared cache.
+ /// Callback that handles AAD authentication when user interaction is needed.
+ public AuthenticationProvider(string applicationName)
+ {
+ if(string.IsNullOrEmpty(applicationName)) {
+ applicationName = nameof(SqlTools);
+ }
+ var cachePath = Path.Combine(Utils.BuildAppDirectoryPath(), applicationName, AzureTokenFolder);
+ this.authenticator = new Authenticator(ApplicationClientId, applicationName, cachePath, MsalCacheName);
+ }
+
+ ///
+ /// Acquires access token with provided
+ ///
+ /// Authentication parameters
+ /// Authentication token containing access token and expiry date.
+ public override async Task AcquireTokenAsync(SqlAuthenticationParameters parameters)
+ {
+ // Setup scope
+ string resource = parameters.Resource;
+ string scope;
+
+ if (parameters.Resource.EndsWith(s_defaultScopeSuffix))
+ {
+ scope = parameters.Resource;
+ resource = parameters.Resource.Substring(0, parameters.Resource.LastIndexOf('/') + 1);
+ }
+ else
+ {
+ scope = parameters.Resource + s_defaultScopeSuffix;
+ }
+
+ string[] scopes = new string[] { scope };
+
+ CancellationTokenSource cts = new CancellationTokenSource();
+
+ // Use Connection timeout value to cancel token acquire request after certain period of time.
+ cts.CancelAfter(parameters.ConnectionTimeout * 1000); // Convert to milliseconds
+
+ /* We split audience from Authority URL here. Audience can be one of the following:
+ * The Azure AD authority audience enumeration
+ * The tenant ID, which can be:
+ * - A GUID (the ID of your Azure AD instance), for single-tenant applications
+ * - A domain name associated with your Azure AD instance (also for single-tenant applications)
+ * One of these placeholders as a tenant ID in place of the Azure AD authority audience enumeration:
+ * - `organizations` for a multitenant application
+ * - `consumers` to sign in users only with their personal accounts
+ * - `common` to sign in users with their work and school accounts or their personal Microsoft accounts
+ *
+ * MSAL will throw a meaningful exception if you specify both the Azure AD authority audience and the tenant ID.
+ * If you don't specify an audience, your app will target Azure AD and personal Microsoft accounts as an audience. (That is, it will behave as though `common` were specified.)
+ * More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration
+ **/
+
+ int seperatorIndex = parameters.Authority.LastIndexOf('/');
+ string authority = parameters.Authority.Remove(seperatorIndex + 1);
+ string audience = parameters.Authority.Substring(seperatorIndex + 1);
+ string? userName = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId;
+
+ AuthenticationParams @params = new AuthenticationParams(
+ AuthenticationMethod.ActiveDirectoryInteractive,
+ authority,
+ audience,
+ resource,
+ scopes,
+ userName!,
+ parameters.ConnectionId);
+
+ AccessToken? result = await authenticator.GetTokenAsync(@params, cts.Token).ConfigureAwait(false);
+ return new SqlAuthenticationToken(result!.Token, result!.ExpiresOn);
+ }
+
+ ///
+ /// Whether or not provided is supported.
+ ///
+ /// SQL Authentication method
+ ///
+ public override bool IsSupported(SqlAuthenticationMethod authenticationMethod)
+ => authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive;
+
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs b/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs
new file mode 100644
index 00000000..bc3b6712
--- /dev/null
+++ b/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs
@@ -0,0 +1,119 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Net.Mail;
+using System.Runtime.InteropServices;
+using Microsoft.Identity.Client;
+using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger;
+
+namespace Microsoft.SqlTools.Authentication.Utility
+{
+ internal sealed class Utils
+ {
+ ///
+ /// Validates provided follows email format.
+ ///
+ /// Email address
+ /// Whether email is in correct format.
+ public static bool isValidEmail(string userEmail)
+ {
+ try
+ {
+ new MailAddress(userEmail);
+ return true;
+ }
+ catch (FormatException)
+ {
+ return false;
+ }
+ }
+
+ ///
+ /// Builds directory path based on environment settings.
+ ///
+ /// Application directory path
+ /// When called on unsupported platform.
+ public static string BuildAppDirectoryPath()
+ {
+ var homedir = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile);
+
+ // Windows
+ if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+ {
+ var appData = Environment.GetEnvironmentVariable("APPDATA");
+ var userProfile = Environment.GetEnvironmentVariable("USERPROFILE");
+ if (appData != null)
+ {
+ return appData;
+ }
+ else if (userProfile != null)
+ {
+ return string.Join(Environment.GetEnvironmentVariable("USERPROFILE"), "AppData", "Roaming");
+ }
+ else
+ {
+ throw new Exception("Not able to find APPDATA or USERPROFILE");
+ }
+ }
+
+ // Mac
+ else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
+ {
+ return string.Join(homedir, "Library", "Application Support");
+ }
+
+ // Linux
+ else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
+ {
+ var xdgConfigHome = Environment.GetEnvironmentVariable("XDG_CONFIG_HOME");
+ if (xdgConfigHome != null)
+ {
+ return xdgConfigHome;
+ }
+ else
+ {
+ return string.Join(homedir, ".config");
+ }
+ }
+ else
+ {
+ throw new Exception("Platform not supported");
+ }
+ }
+
+ ///
+ /// Log callback handler used for MSAL Client applications.
+ ///
+ /// Log level
+ /// Log message
+ /// Whether message contains PII information.
+ public static void MSALLogCallback(LogLevel logLevel, string message, bool pii)
+ {
+ switch (logLevel)
+ {
+ case LogLevel.Error:
+ if (pii) SqlToolsLogger.Pii(message);
+ else SqlToolsLogger.Error(message);
+ break;
+ case LogLevel.Warning:
+ if (pii) SqlToolsLogger.Pii(message);
+ else SqlToolsLogger.Warning(message);
+ break;
+ case LogLevel.Info:
+ if (pii) SqlToolsLogger.Pii(message);
+ else SqlToolsLogger.Information(message);
+ break;
+ case LogLevel.Verbose:
+ if (pii) SqlToolsLogger.Pii(message);
+ else SqlToolsLogger.Verbose(message);
+ break;
+ case LogLevel.Always:
+ if (pii) SqlToolsLogger.Pii(message);
+ else SqlToolsLogger.Critical(message);
+ break;
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.Credentials/Program.cs b/src/Microsoft.SqlTools.Credentials/Program.cs
index b270df5c..5dc87a38 100644
--- a/src/Microsoft.SqlTools.Credentials/Program.cs
+++ b/src/Microsoft.SqlTools.Credentials/Program.cs
@@ -38,7 +38,7 @@ namespace Microsoft.SqlTools.Credentials
logFilePath = Logger.GenerateLogFilePath("credentials");
}
- Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "credentials", commandOptions.AutoFlushLog);
+ Logger.Initialize(tracingLevel: commandOptions.TracingLevel, piiEnabled: commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "credentials", commandOptions.AutoFlushLog);
// set up the host details and profile paths
var hostDetails = new HostDetails(
diff --git a/src/Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj b/src/Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj
index 5749dde3..4efcce20 100644
--- a/src/Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj
+++ b/src/Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj
@@ -25,4 +25,7 @@
+
+
+
\ No newline at end of file
diff --git a/src/Microsoft.SqlTools.Hosting/Utility/CommandOptions.cs b/src/Microsoft.SqlTools.Hosting/Utility/CommandOptions.cs
index 0095cc8c..b56c084b 100644
--- a/src/Microsoft.SqlTools.Hosting/Utility/CommandOptions.cs
+++ b/src/Microsoft.SqlTools.Hosting/Utility/CommandOptions.cs
@@ -28,6 +28,7 @@ namespace Microsoft.SqlTools.Utility
ServiceName = serviceName;
ErrorMessage = string.Empty;
Locale = string.Empty;
+ ApplicationName = string.Empty;
try
{
@@ -42,12 +43,18 @@ namespace Microsoft.SqlTools.Utility
switch (argName)
{
+ case "-application-name":
+ ApplicationName = args[++i];
+ break;
case "-autoflush-log":
AutoFlushLog = true;
break;
case "-tracing-level":
TracingLevel = args[++i];
break;
+ case "-pii-logging":
+ PiiLogging = true;
+ break;
case "-log-file":
LogFilePath = args[++i];
break;
@@ -66,6 +73,9 @@ namespace Microsoft.SqlTools.Utility
case "-parallel-message-processing":
ParallelMessageProcessing = true;
break;
+ case "-enable-sql-authentication-provider":
+ EnableSqlAuthenticationProvider = true;
+ break;
case "-parent-pid":
string nextArg = args[++i];
if (Int32.TryParse(nextArg, out int parsedInt))
@@ -95,6 +105,11 @@ namespace Microsoft.SqlTools.Utility
}
}
+ ///
+ /// Name of application that is sending command options
+ ///
+ public string ApplicationName { get; private set; }
+
///
/// Contains any error messages during execution
///
@@ -137,6 +152,8 @@ namespace Microsoft.SqlTools.Utility
public string TracingLevel { get; private set; }
+ public bool PiiLogging { get; private set; }
+
public string LogFilePath { get; private set; }
public bool AutoFlushLog { get; private set; } = false;
@@ -147,6 +164,13 @@ namespace Microsoft.SqlTools.Utility
///
public bool ParallelMessageProcessing { get; private set; } = false;
+ ///
+ /// Enables configured 'Sql Authentication Provider' for 'Active Directory Interactive' authentication mode to be used
+ /// when user chooses 'Azure MFA'. This setting enables MSAL.NET to acquire token with SqlClient integration.
+ /// Currently this option is disabled by default, it's planned to be enabled by default in future releases.
+ ///
+ public bool EnableSqlAuthenticationProvider { get; private set; } = false;
+
///
/// The ID of the process that started this service. This is used to check when the parent
/// process exits so that the service process can exit at the same time.
diff --git a/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableSqlConnection.cs b/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableSqlConnection.cs
index 5b2d3b26..8d207dad 100644
--- a/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableSqlConnection.cs
+++ b/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableSqlConnection.cs
@@ -70,7 +70,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
_connectionRetryPolicy.RetryOccurred += RetryConnectionCallback;
_commandRetryPolicy.RetryOccurred += RetryCommandCallback;
- if (azureAccountToken != null)
+ SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(_underlyingConnection.ConnectionString);
+ if (builder.Authentication is SqlAuthenticationMethod.NotSpecified && azureAccountToken is not null)
{
_underlyingConnection.AccessToken = azureAccountToken;
}
diff --git a/src/Microsoft.SqlTools.Migration/Program.cs b/src/Microsoft.SqlTools.Migration/Program.cs
index 10e1d73c..d262f8d2 100644
--- a/src/Microsoft.SqlTools.Migration/Program.cs
+++ b/src/Microsoft.SqlTools.Migration/Program.cs
@@ -38,7 +38,7 @@ namespace Microsoft.SqlTools.Migration
logFilePath = Logger.GenerateLogFilePath(logFilePath);
}
- Logger.Initialize(SourceLevels.Verbose, logFilePath, "Migration", commandOptions.AutoFlushLog);
+ Logger.Initialize(SourceLevels.Verbose, piiEnabled: commandOptions.PiiLogging, logFilePath, "Migration", commandOptions.AutoFlushLog);
Logger.Verbose("Starting SqlTools Migration Server...");
diff --git a/src/Microsoft.SqlTools.ResourceProvider/Program.cs b/src/Microsoft.SqlTools.ResourceProvider/Program.cs
index 5609bcae..2716e65f 100644
--- a/src/Microsoft.SqlTools.ResourceProvider/Program.cs
+++ b/src/Microsoft.SqlTools.ResourceProvider/Program.cs
@@ -41,7 +41,7 @@ namespace Microsoft.SqlTools.ResourceProvider
}
// we need to switch to Information when preparing for public preview
- Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "resourceprovider", commandOptions.AutoFlushLog);
+ Logger.Initialize(tracingLevel: commandOptions.TracingLevel, commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "resourceprovider", commandOptions.AutoFlushLog);
Logger.Write(TraceEventType.Information, "Starting SqlTools Resource Provider");
// set up the host details and profile paths
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs
index bdcd58e0..1a09551a 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs
@@ -6,6 +6,7 @@
#nullable disable
using Microsoft.SqlTools.Hosting.Contracts;
+using static Microsoft.SqlTools.Shared.Utility.Constants;
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
@@ -52,7 +53,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
CategoryValues = new CategoryValue[]
{ new CategoryValue { DisplayName = "SQL Login", Name = "SqlLogin" },
new CategoryValue { DisplayName = "Windows Authentication", Name = "Integrated" },
- new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = "AzureMFA" }
+ new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = AzureMFA }
},
IsIdentity = true,
IsRequired = true,
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
index 067feeff..6485ef65 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
@@ -23,7 +23,9 @@ using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.Utility;
+using static Microsoft.SqlTools.Shared.Utility.Constants;
using System.Diagnostics;
+using Microsoft.SqlTools.Authentication.Sql;
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
@@ -143,17 +145,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
{
RequestSecurityTokenParams message = new RequestSecurityTokenParams()
{
- Authority = authority,
Provider = "Azure",
+ Authority = authority,
Resource = resource,
- Scope = scope
+ Scopes = new string[] { scope }
};
- RequestSecurityTokenResponse response = await Instance.ServiceHost.SendRequest(SecurityTokenRequest.Type, message, true);
+ RequestSecurityTokenResponse response = await Instance.ServiceHost.SendRequest(SecurityTokenRequest.Type, message, true).ConfigureAwait(false);
return response.Token;
}
+ ///
+ /// Enables configured 'Sql Authentication Provider' for 'Active Directory Interactive' authentication mode to be used
+ /// when user chooses 'Azure MFA'.
+ ///
+ public bool EnableSqlAuthenticationProvider { get; set; }
+
///
/// Returns a connection queue for given type
///
@@ -248,7 +256,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
if (this.TryFindConnection(ownerUri, out connInfo))
{
// If not an azure connection, no need to refresh token
- if (connInfo.ConnectionDetails.AuthenticationType != "AzureMFA")
+ if (connInfo.ConnectionDetails.AuthenticationType != AzureMFA)
{
return false;
}
@@ -594,7 +602,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
connectionInfo.IsSqlDb = serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDatabase;
connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse);
// Determines that access token is used for creating connection.
- connectionInfo.IsAzureAuth = connectionInfo.ConnectionDetails.AuthenticationType == "AzureMFA";
+ connectionInfo.IsAzureAuth = connectionInfo.ConnectionDetails.AuthenticationType == AzureMFA;
connectionInfo.EngineEdition = (DatabaseEngineEdition)serverInfo.EngineEditionId;
// Azure Data Studio supports SQL Server 2014 and later releases.
response.IsSupportedVersion = serverInfo.IsCloud || serverInfo.ServerMajorVersion >= 12;
@@ -1061,10 +1069,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
return handler.HandleRequest(this.connectionFactory, info);
}
- public void InitializeService(IProtocolEndpoint serviceHost)
+ public void InitializeService(IProtocolEndpoint serviceHost, ServiceLayerCommandOptions commandOptions)
{
this.ServiceHost = serviceHost;
+ if (commandOptions != null && commandOptions.EnableSqlAuthenticationProvider)
+ {
+ // Register SqlAuthenticationProvider with SqlConnection for AAD Interactive (MFA) authentication.
+ var provider = new AuthenticationProvider(commandOptions.ApplicationName);
+ SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryInteractive, provider);
+
+ this.EnableSqlAuthenticationProvider = true;
+ Logger.Information("Registering implementation of SQL Authentication provider for 'Active Directory Interactive' authentication mode.");
+ }
+
// Register request and event handlers with the Service Host
serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest, true);
serviceHost.SetRequestHandler(CancelConnectRequest.Type, HandleCancelConnectRequest, true);
@@ -1304,31 +1322,46 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
connectionBuilder = new SqlConnectionStringBuilder
{
- ["Data Source"] = dataSource,
- ["User Id"] = connectionDetails.UserName,
- ["Password"] = connectionDetails.Password
+ DataSource = dataSource
};
}
// Check for any optional parameters
if (!string.IsNullOrEmpty(connectionDetails.DatabaseName))
{
- connectionBuilder["Initial Catalog"] = connectionDetails.DatabaseName;
+ connectionBuilder.InitialCatalog = connectionDetails.DatabaseName;
}
if (!string.IsNullOrEmpty(connectionDetails.AuthenticationType))
{
switch (connectionDetails.AuthenticationType)
{
- case "Integrated":
+ case Integrated:
connectionBuilder.IntegratedSecurity = true;
break;
- case "SqlLogin":
+ case SqlLogin:
+ connectionBuilder.UserID = connectionDetails.UserName;
+ connectionBuilder.Password = connectionDetails.Password;
+ connectionBuilder.Authentication = SqlAuthenticationMethod.SqlPassword;
break;
- case "AzureMFA":
- connectionBuilder.UserID = "";
- connectionBuilder.Password = "";
+ case AzureMFA:
+ if (Instance.EnableSqlAuthenticationProvider)
+ {
+ connectionBuilder.UserID = connectionDetails.UserName;
+ connectionDetails.AuthenticationType = ActiveDirectoryInteractive;
+ connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive;
+ }
+ else
+ {
+ connectionBuilder.UserID = "";
+ }
break;
- case "ActiveDirectoryPassword":
+ case ActiveDirectoryInteractive:
+ connectionBuilder.UserID = connectionDetails.UserName;
+ connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive;
+ break;
+ case ActiveDirectoryPassword:
+ connectionBuilder.UserID = connectionDetails.UserName;
+ connectionBuilder.Password = connectionDetails.Password;
connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryPassword;
break;
default:
@@ -1337,16 +1370,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
if (!string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting))
{
- switch (connectionDetails.ColumnEncryptionSetting.ToUpper())
+ if (Enum.TryParse(connectionDetails.ColumnEncryptionSetting, true, out var value))
{
- case "ENABLED":
- connectionBuilder.ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Enabled;
- break;
- case "DISABLED":
- connectionBuilder.ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Disabled;
- break;
- default:
- throw new ArgumentException(SR.ConnectionServiceConnStringInvalidColumnEncryptionSetting(connectionDetails.ColumnEncryptionSetting));
+ connectionBuilder.ColumnEncryptionSetting = value;
+ }
+ else
+ {
+ throw new ArgumentException(SR.ConnectionServiceConnStringInvalidColumnEncryptionSetting(connectionDetails.ColumnEncryptionSetting));
}
}
if (!string.IsNullOrEmpty(connectionDetails.SecureEnclaves))
@@ -1365,36 +1395,30 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
if (!string.IsNullOrEmpty(connectionDetails.EnclaveAttestationProtocol))
{
- if (string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting) || connectionDetails.ColumnEncryptionSetting.ToUpper() == "DISABLED"
+ if (connectionBuilder.ColumnEncryptionSetting != SqlConnectionColumnEncryptionSetting.Enabled
|| string.IsNullOrEmpty(connectionDetails.SecureEnclaves) || connectionDetails.SecureEnclaves.ToUpper() == "DISABLED")
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAlwaysEncryptedOptionCombination);
}
- switch (connectionDetails.EnclaveAttestationProtocol.ToUpper())
+ if (Enum.TryParse(connectionDetails.EnclaveAttestationProtocol, true, out var value))
{
- case "AAS":
- connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.AAS;
- break;
- case "HGS":
- connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.HGS;
- break;
- case "NONE":
- connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.None;
- break;
- default:
- throw new ArgumentException(SR.ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(connectionDetails.EnclaveAttestationProtocol));
+ connectionBuilder.AttestationProtocol = value;
+ }
+ else
+ {
+ throw new ArgumentException(SR.ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(connectionDetails.EnclaveAttestationProtocol));
}
}
if (!string.IsNullOrEmpty(connectionDetails.EnclaveAttestationUrl))
{
- if (string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting) || connectionDetails.ColumnEncryptionSetting.ToUpper() == "DISABLED"
+ if (connectionBuilder.ColumnEncryptionSetting != SqlConnectionColumnEncryptionSetting.Enabled
|| string.IsNullOrEmpty(connectionDetails.SecureEnclaves) || connectionDetails.SecureEnclaves.ToUpper() == "DISABLED")
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAlwaysEncryptedOptionCombination);
}
- if(connectionBuilder.AttestationProtocol == SqlConnectionAttestationProtocol.None)
+ if (connectionBuilder.AttestationProtocol == SqlConnectionAttestationProtocol.None)
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAttestationProtocolNoneWithUrl);
}
@@ -1456,19 +1480,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
if (!string.IsNullOrEmpty(connectionDetails.ApplicationIntent))
{
- ApplicationIntent intent;
- switch (connectionDetails.ApplicationIntent)
+ if (Enum.TryParse(connectionDetails.ApplicationIntent, true, out ApplicationIntent value))
{
- case "ReadOnly":
- intent = ApplicationIntent.ReadOnly;
- break;
- case "ReadWrite":
- intent = ApplicationIntent.ReadWrite;
- break;
- default:
- throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent));
+ connectionBuilder.ApplicationIntent = value;
+ }
+ else
+ {
+ throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent));
}
- connectionBuilder.ApplicationIntent = intent;
}
if (!string.IsNullOrEmpty(connectionDetails.CurrentLanguage))
{
@@ -1585,7 +1604,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
ApplicationIntent = builder.ApplicationIntent.ToString(),
ApplicationName = builder.ApplicationName,
AttachDbFilename = builder.AttachDBFilename,
- AuthenticationType = builder.IntegratedSecurity ? "Integrated" : "SqlLogin",
+ AuthenticationType = builder.IntegratedSecurity ? "Integrated" :
+ (builder.Authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive
+ ? "ActiveDirectoryInteractive" : "SqlLogin"),
ConnectRetryCount = builder.ConnectRetryCount,
ConnectRetryInterval = builder.ConnectRetryInterval,
ConnectTimeout = builder.ConnectTimeout,
@@ -1793,7 +1814,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
SqlConnection sqlConn = new SqlConnection(connectionString);
// Fill in Azure authentication token if needed
- if (connInfo.ConnectionDetails.AzureAccountToken != null)
+ if (connInfo.ConnectionDetails.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA)
{
sqlConn.AccessToken = connInfo.ConnectionDetails.AzureAccountToken;
}
@@ -1824,7 +1845,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
{
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo, featureName);
ServerConnection serverConnection;
- if (connInfo.ConnectionDetails.AzureAccountToken != null)
+ if (connInfo.ConnectionDetails.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA)
{
serverConnection = new ServerConnection(sqlConnection, new AzureAccessToken(connInfo.ConnectionDetails.AzureAccountToken));
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/SecurityTokenRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/SecurityTokenRequest.cs
index 013ba709..960c27a2 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/SecurityTokenRequest.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/SecurityTokenRequest.cs
@@ -11,25 +11,25 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
{
class RequestSecurityTokenParams
{
- ///
- /// Gets or sets the address of the authority to issue token.
- ///
- public string Authority { get; set; }
-
///
/// Gets or sets the provider that indicates the type of linked account to query.
///
public string Provider { get; set; }
+ ///
+ /// Gets or sets the authority URL from where token is requested.
+ ///
+ public string Authority { get; set; }
+
///
/// Gets or sets the identifier of the target resource that is the recipient of the requested token.
///
public string Resource { get; set; }
///
- /// Gets or sets the scope of the authentication request.
+ /// Gets or sets the scope array of the authentication request.
///
- public string Scope { get; set; }
+ public string [] Scopes { get; set; }
}
class RequestSecurityTokenResponse
diff --git a/src/Microsoft.SqlTools.ServiceLayer/DacFx/DacFxOperation.cs b/src/Microsoft.SqlTools.ServiceLayer/DacFx/DacFxOperation.cs
index b66c6fa9..64f86757 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/DacFx/DacFxOperation.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/DacFx/DacFxOperation.cs
@@ -8,6 +8,7 @@ using Microsoft.SqlServer.Dac;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.TaskServices;
using Microsoft.SqlTools.ServiceLayer.Utility;
+using static Microsoft.SqlTools.Shared.Utility.Constants;
using Microsoft.SqlTools.Utility;
using System;
using System.Diagnostics;
@@ -86,7 +87,9 @@ namespace Microsoft.SqlTools.ServiceLayer.DacFx
try
{
// Pass in Azure authentication token if needed
- this.DacServices = this.ConnInfo.ConnectionDetails.AzureAccountToken != null ? new DacServices(this.ConnectionString, new AccessTokenProvider(this.ConnInfo.ConnectionDetails.AzureAccountToken)) : new DacServices(this.ConnectionString);
+ this.DacServices = this.ConnInfo.ConnectionDetails.AzureAccountToken != null && this.ConnInfo.ConnectionDetails.AuthenticationType == AzureMFA
+ ? new DacServices(this.ConnectionString, new AccessTokenProvider(this.ConnInfo.ConnectionDetails.AzureAccountToken))
+ : new DacServices(this.ConnectionString);
Execute();
}
catch (Exception e)
diff --git a/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs b/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs
index 539b82fb..1bc46184 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs
@@ -40,6 +40,7 @@ using Microsoft.SqlTools.ServiceLayer.SqlAssessment;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.SqlProjects;
using Microsoft.SqlTools.ServiceLayer.TableDesigner;
+using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.ServiceLayer.Workspace;
namespace Microsoft.SqlTools.ServiceLayer
@@ -53,7 +54,7 @@ namespace Microsoft.SqlTools.ServiceLayer
private static object lockObject = new object();
private static bool isLoaded;
- internal static ServiceHost CreateAndStartServiceHost(SqlToolsContext sqlToolsContext, Stream? inputStream = null, Stream? outputStream = null)
+ internal static ServiceHost CreateAndStartServiceHost(SqlToolsContext sqlToolsContext, ServiceLayerCommandOptions? commandOptions, Stream? inputStream = null, Stream? outputStream = null)
{
ServiceHost serviceHost = ServiceHost.Instance;
lock (lockObject)
@@ -63,7 +64,7 @@ namespace Microsoft.SqlTools.ServiceLayer
// Grab the instance of the service host
serviceHost.Initialize(inputStream, outputStream);
- InitializeRequestHandlersAndServices(serviceHost, sqlToolsContext);
+ InitializeRequestHandlersAndServices(serviceHost, sqlToolsContext, commandOptions);
// Start the service only after all request handlers are setup. This is vital
// as otherwise the Initialize event can be lost - it's processed and discarded before the handler
@@ -75,7 +76,7 @@ namespace Microsoft.SqlTools.ServiceLayer
return serviceHost;
}
- private static void InitializeRequestHandlersAndServices(ServiceHost serviceHost, SqlToolsContext sqlToolsContext)
+ private static void InitializeRequestHandlersAndServices(ServiceHost serviceHost, SqlToolsContext sqlToolsContext, ServiceLayerCommandOptions? commandOptions)
{
// Load extension provider, which currently finds all exports in current DLL. Can be changed to find based
// on directory or assembly list quite easily in the future
@@ -96,7 +97,7 @@ namespace Microsoft.SqlTools.ServiceLayer
LanguageService.Instance.InitializeService(serviceHost, sqlToolsContext);
serviceProvider.RegisterSingleService(LanguageService.Instance);
- ConnectionService.Instance.InitializeService(serviceHost);
+ ConnectionService.Instance.InitializeService(serviceHost, commandOptions);
serviceProvider.RegisterSingleService(ConnectionService.Instance);
CredentialService.Instance.InitializeService(serviceHost);
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj
index 156e5ccf..1149eb4e 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj
+++ b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj
@@ -57,6 +57,8 @@
+
+
diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs
index e2b008d4..7e7f1b22 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs
@@ -400,7 +400,9 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer
var builder = ConnectionService.CreateConnectionStringBuilder(session.ConnectionInfo.ConnectionDetails);
builder.InitialCatalog = node.NodeValue;
builder.ApplicationName = TableDesignerService.TableDesignerApplicationName;
- var azureToken = session.ConnectionInfo.ConnectionDetails.AzureAccountToken;
+ // Set Access Token only when authentication mode is not specified.
+ var azureToken = builder.Authentication == SqlAuthenticationMethod.NotSpecified
+ ? session.ConnectionInfo.ConnectionDetails.AzureAccountToken : null;
TableDesignerCacheManager.StartDatabaseModelInitialization(builder.ToString(), azureToken);
}
catch (Exception ex)
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Program.cs b/src/Microsoft.SqlTools.ServiceLayer/Program.cs
index f63bd63f..cfc118b2 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Program.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Program.cs
@@ -41,7 +41,16 @@ namespace Microsoft.SqlTools.ServiceLayer
logFilePath = Logger.GenerateLogFilePath("sqltools");
}
- Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "sqltools", commandOptions.AutoFlushLog);
+ Logger.Initialize(tracingLevel: commandOptions.TracingLevel, commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "sqltools", commandOptions.AutoFlushLog);
+
+ // Register PII Logging configuration change callback
+ Workspace.WorkspaceService.Instance.RegisterConfigChangeCallback((newSettings, oldSettings, context) =>
+ {
+ Logger.IsPiiEnabled = newSettings?.MssqlTools?.PiiLogging ?? false;
+ Logger.Information(Logger.IsPiiEnabled ? "PII Logging enabled" : "PII Logging disabled");
+ return Task.FromResult(true);
+ });
+
// Only enable SQL Client logging when verbose or higher to avoid extra overhead when the
// detailed logging it provides isn't needed
if (Logger.TracingLevel.HasFlag(SourceLevels.Verbose))
@@ -53,7 +62,7 @@ namespace Microsoft.SqlTools.ServiceLayer
var hostDetails = new HostDetails(version: new Version(1, 0));
SqlToolsContext sqlToolsContext = new SqlToolsContext(hostDetails);
- ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext);
+ ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext, commandOptions);
serviceHost.MessageDispatcher.ParallelMessageProcessing = commandOptions.ParallelMessageProcessing;
// If this service was started by another process, then it should shutdown when that parent process does.
diff --git a/src/Microsoft.SqlTools.ServiceLayer/SchemaCompare/SchemaCompareUtils.cs b/src/Microsoft.SqlTools.ServiceLayer/SchemaCompare/SchemaCompareUtils.cs
index b85c6e24..707fb8eb 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/SchemaCompare/SchemaCompareUtils.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/SchemaCompare/SchemaCompareUtils.cs
@@ -17,6 +17,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.DacFx.Contracts;
using Microsoft.SqlTools.ServiceLayer.SchemaCompare.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility;
+using static Microsoft.SqlTools.Shared.Utility.Constants;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.SchemaCompare
@@ -187,7 +188,9 @@ namespace Microsoft.SqlTools.ServiceLayer.SchemaCompare
case SchemaCompareEndpointType.Database:
{
string connectionString = GetConnectionString(connInfo, endpointInfo.DatabaseName);
- return connInfo.ConnectionDetails?.AzureAccountToken != null
+
+ // Set Access Token only when authentication mode is not specified.
+ return connInfo.ConnectionDetails?.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA
? new SchemaCompareDatabaseEndpoint(connectionString, new AccessTokenProvider(connInfo.ConnectionDetails.AzureAccountToken))
: new SchemaCompareDatabaseEndpoint(connectionString);
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs
index 9e7d10c7..e0c99530 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs
@@ -17,6 +17,7 @@ using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
using Microsoft.SqlTools.Utility;
using Microsoft.SqlTools.ServiceLayer.Utility;
+using static Microsoft.SqlTools.Shared.Utility.Constants;
using System.Linq;
namespace Microsoft.SqlTools.ServiceLayer.Scripting
@@ -109,7 +110,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
if (connInfo != null)
{
parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails);
- accessToken = connInfo.ConnectionDetails.AzureAccountToken;
+ // Set Access Token only when authentication type is AzureMFA.
+ if (connInfo.ConnectionDetails.AuthenticationType == AzureMFA)
+ {
+ accessToken = connInfo.ConnectionDetails.AzureAccountToken;
+ }
}
else
{
diff --git a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/CompoundSqlToolsSettingsValues.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/CompoundSqlToolsSettingsValues.cs
index d38cb574..fc70aad9 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/CompoundSqlToolsSettingsValues.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/CompoundSqlToolsSettingsValues.cs
@@ -18,7 +18,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
/// group such as Intellisense is defined on a serialized setting it's used in the order of mssql, then sql, then
/// falls back to a default value.
///
- public class CompoundToolsSettingsValues: ISqlToolsSettingsValues
+ public class CompoundToolsSettingsValues : ISqlToolsSettingsValues
{
private List priorityList = new List();
private SqlToolsSettingsValues defaultValues;
@@ -44,11 +44,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
/// Gets or sets the detailed IntelliSense settings
///
public IntelliSenseSettings IntelliSense
- {
+ {
get
{
return GetSettingOrDefault((settings) => settings.IntelliSense);
- }
+ }
set
{
priorityList[0].IntelliSense = value;
@@ -59,11 +59,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
/// Gets or sets the query execution settings
///
public QueryExecutionSettings QueryExecutionSettings
- {
+ {
get
{
return GetSettingOrDefault((settings) => settings.QueryExecutionSettings);
- }
+ }
set
{
priorityList[0].QueryExecutionSettings = value;
@@ -74,11 +74,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
/// Gets or sets the formatter settings
///
public FormatterSettings Format
- {
+ {
get
{
return GetSettingOrDefault((settings) => settings.Format);
- }
+ }
set
{
priorityList[0].Format = value;
@@ -89,15 +89,24 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
/// Gets or sets the object explorer settings
///
public ObjectExplorerSettings ObjectExplorer
- {
+ {
get
{
return GetSettingOrDefault((settings) => settings.ObjectExplorer);
- }
+ }
set
{
priorityList[0].ObjectExplorer = value;
}
}
+
+ ///
+ /// Gets or sets PII Logging setting.
+ ///
+ public bool PiiLogging
+ {
+ get => GetSettingOrDefault((settings) => settings.PiiLogging);
+ set => priorityList[0].PiiLogging = value;
+ }
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/ISqlToolsSettingsValues.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/ISqlToolsSettingsValues.cs
index 8cac95b1..c6629628 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/ISqlToolsSettingsValues.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/ISqlToolsSettingsValues.cs
@@ -5,6 +5,8 @@
#nullable disable
+using System;
+
namespace Microsoft.SqlTools.ServiceLayer.SqlContext
{
///
@@ -31,5 +33,10 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
/// Object Explorer specific settings
///
ObjectExplorerSettings ObjectExplorer { get; set; }
+
+ ///
+ /// PII Logging setting
+ ///
+ Boolean PiiLogging { get; set; }
}
}
\ No newline at end of file
diff --git a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettingsValues.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettingsValues.cs
index f1879f0e..8acb87c4 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettingsValues.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettingsValues.cs
@@ -5,6 +5,7 @@
#nullable disable
+using System;
using Newtonsoft.Json;
namespace Microsoft.SqlTools.ServiceLayer.SqlContext
@@ -25,6 +26,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
QueryExecutionSettings = new QueryExecutionSettings();
Format = new FormatterSettings();
TableDesigner = new TableDesignerSettings();
+ PiiLogging = false;
}
}
@@ -57,5 +59,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext
///
[JsonProperty("tableDesigner")]
public TableDesignerSettings TableDesigner { get; set; }
+
+ ///
+ /// Gets or sets the setting to enable PII Logging.
+ ///
+ [JsonProperty("piiLogging")]
+ public Boolean PiiLogging { get; set; }
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/TableDesigner/TableDesignerService.cs b/src/Microsoft.SqlTools.ServiceLayer/TableDesigner/TableDesignerService.cs
index f6abde88..7f140d6d 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/TableDesigner/TableDesignerService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/TableDesigner/TableDesignerService.cs
@@ -1800,7 +1800,11 @@ namespace Microsoft.SqlTools.ServiceLayer.TableDesigner
connectionStringBuilder.InitialCatalog = tableInfo.Database;
connectionStringBuilder.ApplicationName = TableDesignerService.TableDesignerApplicationName;
var connectionString = connectionStringBuilder.ToString();
- tableDesigner = new Dac.TableDesigner(connectionString, tableInfo.AccessToken, tableInfo.Schema, tableInfo.Name, tableInfo.IsNewTable);
+
+ // Set Access Token only when authentication mode is not specified.
+ var accessToken = connectionStringBuilder.Authentication == SqlAuthenticationMethod.NotSpecified
+ ? tableInfo.AccessToken : null;
+ tableDesigner = new Dac.TableDesigner(connectionString, accessToken, tableInfo.Schema, tableInfo.Name, tableInfo.IsNewTable);
}
else
{
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/ServiceLayerCommandOptions.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/ServiceLayerCommandOptions.cs
index 44b965e7..a08d1c8e 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Utility/ServiceLayerCommandOptions.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/ServiceLayerCommandOptions.cs
@@ -12,7 +12,7 @@ using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Utility
{
- class ServiceLayerCommandOptions : CommandOptions
+ public class ServiceLayerCommandOptions : CommandOptions
{
internal const string ServiceLayerServiceName = "MicrosoftSqlToolsServiceLayer.exe";
diff --git a/src/Microsoft.SqlTools.Shared/Microsoft.SqlTools.Shared.csproj b/src/Microsoft.SqlTools.Shared/Microsoft.SqlTools.Shared.csproj
new file mode 100644
index 00000000..084c4ff0
--- /dev/null
+++ b/src/Microsoft.SqlTools.Shared/Microsoft.SqlTools.Shared.csproj
@@ -0,0 +1,8 @@
+
+
+
+ enable
+ enable
+
+
+
diff --git a/src/Microsoft.SqlTools.Shared/Utility/Constants.cs b/src/Microsoft.SqlTools.Shared/Utility/Constants.cs
new file mode 100644
index 00000000..d3a9d4e6
--- /dev/null
+++ b/src/Microsoft.SqlTools.Shared/Utility/Constants.cs
@@ -0,0 +1,18 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.Shared.Utility
+{
+ public class Constants
+ {
+ // Authentication Types
+ public const string Integrated = "Integrated";
+ public const string SqlLogin = "SqlLogin";
+ public const string AzureMFA = "AzureMFA";
+ public const string dstsAuth = "dstsAuth";
+ public const string ActiveDirectoryInteractive = "ActiveDirectoryInteractive";
+ public const string ActiveDirectoryPassword = "ActiveDirectoryPassword";
+ }
+}
diff --git a/src/Microsoft.SqlTools.Hosting/Utility/Logger.cs b/src/Microsoft.SqlTools.Shared/Utility/Logger.cs
similarity index 87%
rename from src/Microsoft.SqlTools.Hosting/Utility/Logger.cs
rename to src/Microsoft.SqlTools.Shared/Utility/Logger.cs
index 26496206..46a582ff 100644
--- a/src/Microsoft.SqlTools.Hosting/Utility/Logger.cs
+++ b/src/Microsoft.SqlTools.Shared/Utility/Logger.cs
@@ -3,11 +3,8 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
-using System;
using System.Diagnostics;
using System.Globalization;
-using System.IO;
-using System.Threading;
namespace Microsoft.SqlTools.Utility
{
@@ -29,41 +26,45 @@ namespace Microsoft.SqlTools.Utility
///
public static class Logger
{
- internal const SourceLevels defaultTracingLevel = SourceLevels.Critical;
- internal const string defaultTraceSource = "sqltools";
+ public const SourceLevels defaultTracingLevel = SourceLevels.Critical;
+ public const string defaultTraceSource = "sqltools";
private static SourceLevels tracingLevel = defaultTracingLevel;
- private static string logFileFullPath;
+ private static string? logFileFullPath;
+
+ public static TraceSource? TraceSource { get; set; }
- internal static TraceSource TraceSource { get; set; }
public static string LogFileFullPath
{
- get => logFileFullPath;
+ get => logFileFullPath!;
private set
{
- //If the log file path has a directory component then ensure that the directory exists.
- if (!string.IsNullOrEmpty(Path.GetDirectoryName(value)) && !Directory.Exists(Path.GetDirectoryName(value)))
+ if (value != null)
{
- Directory.CreateDirectory(Path.GetDirectoryName(value));
- }
+ //If the log file path has a directory component then ensure that the directory exists.
+ if (!string.IsNullOrEmpty(Path.GetDirectoryName(value)) && !Directory.Exists(Path.GetDirectoryName(value)))
+ {
+ Directory.CreateDirectory(Path.GetDirectoryName(value)!);
+ }
- logFileFullPath = value;
+ logFileFullPath = value;
+ }
ConfigureListener();
}
}
- private static SqlToolsTraceListener Listener { get; set; }
+ private static SqlToolsTraceListener? Listener { get; set; }
private static void ConfigureLogFile(string logFilePrefix) => LogFileFullPath = GenerateLogFilePath(logFilePrefix);
///
/// Calling this method will turn on inclusion CallStack in the log for all future traces
///
- public static void StartCallStack() => Listener.TraceOutputOptions |= TraceOptions.Callstack;
+ public static void StartCallStack() => Listener!.TraceOutputOptions |= TraceOptions.Callstack;
///
/// Calling this method will turn off inclusion of CallStack in the log for all future traces
///
- public static void StopCallStack() => Listener.TraceOutputOptions &= ~TraceOptions.Callstack;
+ public static void StopCallStack() => Listener!.TraceOutputOptions &= ~TraceOptions.Callstack;
///
/// Calls flush on defaultTracingLevel configured listeners.
@@ -86,17 +87,19 @@ namespace Microsoft.SqlTools.Utility
get => tracingLevel;
set
{
- if(TraceSource != null)
+ if (TraceSource != null)
{
// configures the source level filter. This alone is not enough for tracing that is done via "Trace" class instead of "TraceSource" object
TraceSource.Switch = new SourceSwitch(TraceSource.Name, value.ToString());
}
// configure the listener level filter
tracingLevel = value;
- Listener.Filter = new EventTypeFilter(tracingLevel);
+ Listener!.Filter = new EventTypeFilter(tracingLevel);
}
}
+ public static bool IsPiiEnabled { get; set; } = false;
+
public static bool AutoFlush { get; set; } = false;
///
@@ -115,11 +118,13 @@ namespace Microsoft.SqlTools.Utility
///
public static void Initialize(
SourceLevels tracingLevel = defaultTracingLevel,
- string logFilePath = null,
+ bool piiEnabled = false,
+ string? logFilePath = null,
string traceSource = defaultTraceSource,
bool autoFlush = false)
{
Logger.tracingLevel = tracingLevel;
+ Logger.IsPiiEnabled = piiEnabled;
Logger.AutoFlush = autoFlush;
TraceSource = new TraceSource(traceSource, Logger.tracingLevel);
if (string.IsNullOrWhiteSpace(logFilePath))
@@ -146,11 +151,12 @@ namespace Microsoft.SqlTools.Utility
///
/// Optional. Specifies whether the log is flushed after every message
///
- public static void Initialize(string tracingLevel, string logFilePath = null, string traceSource = defaultTraceSource, bool autoFlush = false)
+ public static void Initialize(string tracingLevel, bool piiEnabled, string? logFilePath = null, string traceSource = defaultTraceSource, bool autoFlush = false)
{
Initialize(Enum.TryParse(tracingLevel, out SourceLevels sourceTracingLevel)
? sourceTracingLevel
: defaultTracingLevel
+ , piiEnabled
, logFilePath
, traceSource
, autoFlush);
@@ -169,7 +175,7 @@ namespace Microsoft.SqlTools.Utility
throw new ArgumentOutOfRangeException(nameof(logFilePrefix), $"LogfilePath cannot be configured if argument {nameof(logFilePrefix)} has not been set");
}
// Create the log directory
- string logDir = Path.GetDirectoryName(logFilePrefix);
+ string? logDir = Path.GetDirectoryName(logFilePrefix);
if (!string.IsNullOrWhiteSpace(logDir))
{
if (!Directory.Exists(logDir))
@@ -199,7 +205,7 @@ namespace Microsoft.SqlTools.Utility
}
string fileName;
- try
+ try
{
var now = DateTime.Now;
fileName = string.Format(CultureInfo.InvariantCulture,
@@ -239,6 +245,16 @@ namespace Microsoft.SqlTools.Utility
/// The message text to be written.
public static void Write(TraceEventType eventType, string logMessage) => Write(eventType, LogEvent.Default, logMessage);
+ ///
+ /// Writes a PII message to the log file with the Verbose event level when PII flag is enabled.
+ ///
+ /// The message text to be written.
+ public static void Pii(string logMessage) {
+ if (IsPiiEnabled) {
+ Write(TraceEventType.Verbose, logMessage);
+ }
+ }
+
///
/// Writes a message to the log file with the Verbose event level
///
@@ -386,9 +402,9 @@ namespace Microsoft.SqlTools.Utility
);
}
#region forward actual write/close/flush/dispose calls to the underlying listener.
- public override void Write(string message) => Listener.Write(message);
+ public override void Write(string? message) => Listener.Write(message);
- public override void WriteLine(string message) => Listener.WriteLine(message);
+ public override void WriteLine(string? message) => Listener.WriteLine(message);
///
/// Closes the so that it no longer
@@ -430,41 +446,41 @@ namespace Microsoft.SqlTools.Utility
}
#endregion
- public override void TraceEvent(TraceEventCache eventCache, String source, TraceEventType eventType, int id)
+ public override void TraceEvent(TraceEventCache? eventCache, String source, TraceEventType eventType, int id)
{
TraceEvent(eventCache, source, eventType, id, String.Empty);
}
// All other TraceEvent methods come through this one.
- public override void TraceEvent(TraceEventCache eventCache, String source, TraceEventType eventType, int id, string message)
+ public override void TraceEvent(TraceEventCache? eventCache, String source, TraceEventType eventType, int id, string? message)
{
if (Filter != null && !Filter.ShouldTrace(eventCache, source, eventType, id, message, null, null, null))
{
return;
}
- WriteHeader(eventCache, source, eventType, id);
- WriteLine(message);
- WriteFooter(eventCache);
+ WriteHeader(eventCache!, source, eventType, id);
+ WriteLine(message!);
+ WriteFooter(eventCache!);
}
- public override void TraceEvent(TraceEventCache eventCache, String source, TraceEventType eventType, int id, string format, params object[] args)
+ public override void TraceEvent(TraceEventCache? eventCache, String source, TraceEventType eventType, int id, string? format, params object?[]? args)
{
if (Filter != null && !Filter.ShouldTrace(eventCache, source, eventType, id, format, args, null, null))
{
return;
}
- WriteHeader(eventCache, source, eventType, id);
+ WriteHeader(eventCache!, source, eventType, id);
if (args != null)
{
- WriteLine(String.Format(CultureInfo.InvariantCulture, format, args));
+ WriteLine(String.Format(CultureInfo.InvariantCulture, format!, args));
}
else
{
- WriteLine(format);
+ WriteLine(format!);
}
- WriteFooter(eventCache);
+ WriteFooter(eventCache!);
}
private void WriteHeader(TraceEventCache eventCache, String source, TraceEventType eventType, int id)
diff --git a/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/DataSourceFactoryTests.cs b/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/DataSourceFactoryTests.cs
index 07afb461..c9e7caf4 100644
--- a/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/DataSourceFactoryTests.cs
+++ b/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/DataSourceFactoryTests.cs
@@ -12,14 +12,15 @@ using Microsoft.Kusto.ServiceLayer.DataSource;
using Microsoft.Kusto.ServiceLayer.DataSource.Intellisense;
using Microsoft.Kusto.ServiceLayer.LanguageServices;
using Microsoft.Kusto.ServiceLayer.Workspace.Contracts;
+using static Microsoft.SqlTools.Shared.Utility.Constants;
using NUnit.Framework;
namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
{
public class DataSourceFactoryTests
{
- [TestCase(typeof(ArgumentException), "ConnectionString", "", "AzureMFA")]
- [TestCase(typeof(ArgumentException), "ConnectionString", "", "dstsAuth")]
+ [TestCase(typeof(ArgumentException), "ConnectionString", "", AzureMFA)]
+ [TestCase(typeof(ArgumentException), "ConnectionString", "", dstsAuth)]
public void Create_Throws_Exceptions_For_InvalidAzureAccountToken(Type exceptionType, string connectionString, string azureAccountToken, string authType)
{
Program.ServiceName = "Kusto";
diff --git a/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/KustoClientTests.cs b/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/KustoClientTests.cs
index 1a5f7aa2..99b8b1e8 100644
--- a/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/KustoClientTests.cs
+++ b/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/KustoClientTests.cs
@@ -8,14 +8,15 @@
using System;
using Microsoft.Kusto.ServiceLayer.DataSource.Contracts;
using Microsoft.Kusto.ServiceLayer.DataSource.Kusto;
+using static Microsoft.SqlTools.Shared.Utility.Constants;
using NUnit.Framework;
namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
{
public class KustoClientTests
{
- [TestCase("dstsAuth")]
- [TestCase("AzureMFA")]
+ [TestCase(dstsAuth)]
+ [TestCase(AzureMFA)]
public void Constructor_Throws_ArgumentException_For_MissingToken(string authType)
{
var connectionDetails = new DataSourceConnectionDetails
@@ -37,7 +38,7 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
UserToken = "UserToken",
ServerName = clusterName,
DatabaseName = "",
- AuthenticationType = "AzureMFA"
+ AuthenticationType = AzureMFA
};
var client = new KustoClient(connectionDetails, "ownerUri");
@@ -46,10 +47,10 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
Assert.AreEqual("NetDefaultDB", client.DatabaseName);
}
- [TestCase("dstsAuth")]
- [TestCase("AzureMFA")]
+ [TestCase(dstsAuth)]
+ [TestCase(AzureMFA)]
[TestCase("NoAuth")]
- [TestCase("SqlLogin")]
+ [TestCase(SqlLogin)]
public void Constructor_Creates_Client_With_Valid_AuthenticationType(string authenticationType)
{
string clusterName = "https://fake.url.com";
@@ -59,8 +60,8 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource
ServerName = clusterName,
DatabaseName = "FakeDatabaseName",
AuthenticationType = authenticationType,
- UserName = authenticationType == "SqlLogin" ? "username": null,
- Password = authenticationType == "SqlLogin" ? "password": null
+ UserName = authenticationType == SqlLogin ? "username": null,
+ Password = authenticationType == SqlLogin ? "password": null
};
var client = new KustoClient(connectionDetails, "ownerUri");
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj
index e21705fa..e62da1a9 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj
@@ -7,8 +7,8 @@
false
-
-
+
+
@@ -25,5 +25,6 @@
+
\ No newline at end of file
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestLogger.cs b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestLogger.cs
index b4ffd905..6625fb8a 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestLogger.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestLogger.cs
@@ -32,6 +32,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common
public SourceLevels TracingLevel { get; set; } = SourceLevels.Critical;
public bool DoNotUseTraceSource { get; set; } = false;
+ public bool IsPiiEnabled { get; set; } = false;
+
public bool AutoFlush { get; set; } = false;
private List pendingVerifications;
@@ -43,7 +45,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common
public string LogFileName { get => logFileName ?? Logger.LogFileFullPath; set => logFileName = value; }
public void Initialize() =>
- Logger.Initialize(TracingLevel, LogFilePath, TraceSource, AutoFlush); // initialize the logger
+ Logger.Initialize(TracingLevel, IsPiiEnabled, LogFilePath, TraceSource, AutoFlush); // initialize the logger
public string LogContents
{
get
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs
index 6c2ba0b4..d01ecf9d 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs
@@ -195,7 +195,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common
// Initialize the ServiceHost, using a MemoryStream for the output stream so that we don't fill up the logs
// with a bunch of outgoing messages (which aren't used for anything during tests)
- ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext, null, new MemoryStream());
+ ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext, null, null, new MemoryStream());
// Set up our logger to write to Console for tests to help debug issues
Logger.Initialize(autoFlush: true);
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs
index 10c87f29..c5669d60 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs
@@ -12,6 +12,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Test.Common;
using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility;
+using static Microsoft.SqlTools.Shared.Utility.Constants;
using Moq;
using Moq.Protected;
using NUnit.Framework;
@@ -479,6 +480,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
new object[] {null, "12345678"},
new object[] {"", "12345678"},
};
+
///
/// Verify that when using integrated authentication, the username and/or password can be empty.
///
@@ -497,13 +499,67 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
DatabaseName = "test",
UserName = userName,
Password = password,
- AuthenticationType = "Integrated"
+ AuthenticationType = Integrated
}
});
Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful");
}
+ ///
+ /// Verify that username is required when using Active Directory Interactive authentication.
+ /// Both AzureMFA and ActiveDirectoryInteractive should work same way, when SqlAuthenticationProvider is enabled.
+ ///
+ [TestCase(null, AzureMFA)]
+ [TestCase(null, ActiveDirectoryInteractive)]
+ public async Task ConnectingWitNoUsernameFailsForAADInteractiveAuth(string userName, string authMode)
+ {
+ // This is an exception scenario test, therefore using ConnectionService instance instead of TestConnectionService directly.
+ ConnectionService.Instance.EnableSqlAuthenticationProvider = true;
+ // Connect
+ var connectionResult = await
+ TestObjects.GetTestConnectionService()
+ .Connect(new ConnectParams()
+ {
+ OwnerUri = "file:///my/test/file.sql",
+ Connection = new ConnectionDetails()
+ {
+ ServerName = "my-server",
+ DatabaseName = "test",
+ UserName = userName,
+ AuthenticationType = authMode
+ }
+ });
+
+ Assert.That(connectionResult.ConnectionId, Is.Null.Or.Empty, "Connection should not be successful");
+ }
+
+ ///
+ /// Verify that password is ignored when using Active Directory Interactive authentication.
+ ///
+ [TestCase("user", "anything", AzureMFA)]
+ [TestCase("user", "anything", ActiveDirectoryInteractive)]
+ public async Task ConnectingWitPasswordIsIgnoredForAADInteractiveAuth(string username, string password, string authMode)
+ {
+ // Connect
+ var connectionResult = await
+ TestObjects.GetTestConnectionService()
+ .Connect(new ConnectParams()
+ {
+ OwnerUri = "file:///my/test/file.sql",
+ Connection = new ConnectionDetails()
+ {
+ ServerName = "my-server",
+ DatabaseName = "test",
+ UserName = username,
+ Password = password,
+ AuthenticationType = authMode
+ }
+ });
+
+ Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "Connection should be successful");
+ }
+
///
/// Verify that when connecting with a null parameters object, an error is thrown.
///
@@ -1773,16 +1829,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
details.AzureAccountToken = azureAccountToken;
details.UserName = "";
details.Password = "";
- details.AuthenticationType = "AzureMFA";
+ details.AuthenticationType = AzureMFA;
- // If I open a connection using connection details that include an account token
+ // Open a connection using connection details that include an account token
await connectionService.Connect(new ConnectParams
{
OwnerUri = "testURI",
Connection = details
});
- // Then the connection factory got called with details including an account token
+ // Validate that the connection factory gets called with details NOT including an account token
mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny(), It.Is(accountToken => accountToken == azureAccountToken)), Times.Once());
}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj
index 0f947ee0..9719192d 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj
@@ -34,6 +34,7 @@
+
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ServiceHost/LoggerTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ServiceHost/LoggerTests.cs
index 672e474d..98c2bd0f 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ServiceHost/LoggerTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ServiceHost/LoggerTests.cs
@@ -109,7 +109,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ServiceHost
string tracingLevel = null;
SourceLevels expectedTracingLevel = Logger.defaultTracingLevel;
string expectedTraceSource = Logger.defaultTraceSource;
- Logger.Initialize(tracingLevel);
+ Logger.Initialize(tracingLevel, false);
bool isLogFileExpectedToExist = false;
TestLogger.VerifyInitialization(expectedTracingLevel, expectedTraceSource, Logger.LogFileFullPath, isLogFileExpectedToExist, testNo++);
TestLogger.Cleanup(Logger.LogFileFullPath);
@@ -120,7 +120,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ServiceHost
string tracingLevel = null;
SourceLevels expectedTracingLevel = Logger.defaultTracingLevel;
string expectedTraceSource = Logger.defaultTraceSource;
- Logger.Initialize(tracingLevel);
+ Logger.Initialize(tracingLevel, false);
bool isLogFileExpectedToExist = false;
TestLogger.VerifyInitialization(expectedTracingLevel, expectedTraceSource, Logger.LogFileFullPath, isLogFileExpectedToExist, testNo++);
TestLogger.Cleanup(Logger.LogFileFullPath);
@@ -131,7 +131,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ServiceHost
string tracingLevel = "invalid";
SourceLevels expectedTracingLevel = Logger.defaultTracingLevel;
string expectedTraceSource = Logger.defaultTraceSource;
- Logger.Initialize(tracingLevel);
+ Logger.Initialize(tracingLevel, false);
bool isLogFileExpectedToExist = false;
TestLogger.VerifyInitialization(expectedTracingLevel, expectedTraceSource, Logger.LogFileFullPath, isLogFileExpectedToExist, testNo++);
TestLogger.Cleanup(Logger.LogFileFullPath);