From 187b6ecc14d3035939d406b8097017e4587ad4ae Mon Sep 17 00:00:00 2001 From: Cheena Malhotra <13396919+cheenamalhotra@users.noreply.github.com> Date: Thu, 2 Mar 2023 09:39:54 -0800 Subject: [PATCH] Introduce AAD interactive auth mode (#1860) --- Packages.props | 2 + sqltoolsservice.sln | 22 ++- .../ConnectionProviderOptionsHelper.cs | 3 +- .../Connection/ConnectionService.cs | 3 +- .../DataSource/DataSourceFactory.cs | 3 +- .../DataSource/Kusto/KustoClient.cs | 7 +- .../Microsoft.Kusto.ServiceLayer.csproj | 1 + src/Microsoft.Kusto.ServiceLayer/Program.cs | 2 +- .../AccessToken.cs | 33 ++++ .../AuthenticationMethod.cs | 15 ++ .../AuthenticationParams.cs | 82 +++++++++ .../Authenticator.cs | 158 ++++++++++++++++++ .../Microsoft.SqlTools.Authentication.csproj | 18 ++ .../Sql/AuthenticationProvider.cs | 110 ++++++++++++ .../Utility/Utils.cs | 119 +++++++++++++ src/Microsoft.SqlTools.Credentials/Program.cs | 2 +- .../Microsoft.SqlTools.Hosting.csproj | 3 + .../Utility/CommandOptions.cs | 24 +++ .../ReliableSqlConnection.cs | 3 +- src/Microsoft.SqlTools.Migration/Program.cs | 2 +- .../Program.cs | 2 +- .../ConnectionProviderOptionsHelper.cs | 3 +- .../Connection/ConnectionService.cs | 129 ++++++++------ .../Contracts/SecurityTokenRequest.cs | 14 +- .../DacFx/DacFxOperation.cs | 5 +- .../HostLoader.cs | 9 +- .../Microsoft.SqlTools.ServiceLayer.csproj | 2 + .../ObjectExplorer/ObjectExplorerService.cs | 4 +- .../Program.cs | 13 +- .../SchemaCompare/SchemaCompareUtils.cs | 5 +- .../Scripting/ScriptingService.cs | 7 +- .../CompoundSqlToolsSettingsValues.cs | 27 ++- .../SqlContext/ISqlToolsSettingsValues.cs | 7 + .../SqlContext/SqlToolsSettingsValues.cs | 8 + .../TableDesigner/TableDesignerService.cs | 6 +- .../Utility/ServiceLayerCommandOptions.cs | 2 +- .../Microsoft.SqlTools.Shared.csproj | 8 + .../Utility/Constants.cs | 18 ++ .../Utility/Logger.cs | 84 ++++++---- .../DataSource/DataSourceFactoryTests.cs | 5 +- .../DataSource/KustoClientTests.cs | 17 +- ...t.SqlTools.ServiceLayer.Test.Common.csproj | 5 +- .../TestLogger.cs | 4 +- .../TestServiceProvider.cs | 2 +- .../Connection/ConnectionServiceTests.cs | 64 ++++++- ...oft.SqlTools.ServiceLayer.UnitTests.csproj | 1 + .../ServiceHost/LoggerTests.cs | 6 +- 47 files changed, 918 insertions(+), 151 deletions(-) create mode 100644 src/Microsoft.SqlTools.Authentication/AccessToken.cs create mode 100644 src/Microsoft.SqlTools.Authentication/AuthenticationMethod.cs create mode 100644 src/Microsoft.SqlTools.Authentication/AuthenticationParams.cs create mode 100644 src/Microsoft.SqlTools.Authentication/Authenticator.cs create mode 100644 src/Microsoft.SqlTools.Authentication/Microsoft.SqlTools.Authentication.csproj create mode 100644 src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs create mode 100644 src/Microsoft.SqlTools.Authentication/Utility/Utils.cs create mode 100644 src/Microsoft.SqlTools.Shared/Microsoft.SqlTools.Shared.csproj create mode 100644 src/Microsoft.SqlTools.Shared/Utility/Constants.cs rename src/{Microsoft.SqlTools.Hosting => Microsoft.SqlTools.Shared}/Utility/Logger.cs (87%) diff --git a/Packages.props b/Packages.props index 5b389bf4..aeb1d6e7 100644 --- a/Packages.props +++ b/Packages.props @@ -15,6 +15,8 @@ + + diff --git a/sqltoolsservice.sln b/sqltoolsservice.sln index 0364baa9..47941438 100644 --- a/sqltoolsservice.sln +++ b/sqltoolsservice.sln @@ -1,6 +1,6 @@ Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio Version 16 -VisualStudioVersion = 16.0.29409.12 +# Visual Studio Version 17 +VisualStudioVersion = 17.5.33209.295 MinimumVisualStudioVersion = 10.0.40219.1 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{2BBD7364-054F-4693-97CD-1C395E3E84A9}" ProjectSection(SolutionItems) = preProject @@ -96,6 +96,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.SqlTools.Migratio EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.SqlTools.Migration.IntegrationTests", "test\Microsoft.SqlTools.Migration.IntegrationTests\Microsoft.SqlTools.Migration.IntegrationTests.csproj", "{5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.SqlTools.Authentication", "src\Microsoft.SqlTools.Authentication\Microsoft.SqlTools.Authentication.csproj", "{2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.SqlTools.Shared", "src\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj", "{531EC0E0-F400-42C5-BCFA-6691313B5F3E}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -229,6 +233,18 @@ Global {5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}.Integration|Any CPU.Build.0 = Debug|Any CPU {5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}.Release|Any CPU.ActiveCfg = Release|Any CPU {5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9}.Release|Any CPU.Build.0 = Release|Any CPU + {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Debug|Any CPU.Build.0 = Debug|Any CPU + {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Integration|Any CPU.ActiveCfg = Debug|Any CPU + {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Integration|Any CPU.Build.0 = Debug|Any CPU + {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC}.Release|Any CPU.Build.0 = Release|Any CPU + {531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Integration|Any CPU.ActiveCfg = Debug|Any CPU + {531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Integration|Any CPU.Build.0 = Debug|Any CPU + {531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {531EC0E0-F400-42C5-BCFA-6691313B5F3E}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -256,6 +272,8 @@ Global {07296730-DAB7-4B0B-9D09-ABD9A5025D68} = {2BBD7364-054F-4693-97CD-1C395E3E84A9} {22DB0C12-6848-4503-AD1C-DAD6A1D631AE} = {2BBD7364-054F-4693-97CD-1C395E3E84A9} {5C7F4DAC-F794-4C21-A031-DCAAFAF3C0A9} = {AB9CA2B8-6F70-431C-8A1D-67479D8A7BE4} + {2A32C3B6-3E9F-4A8E-BF98-59F9AEF6DAAC} = {2BBD7364-054F-4693-97CD-1C395E3E84A9} + {531EC0E0-F400-42C5-BCFA-6691313B5F3E} = {2BBD7364-054F-4693-97CD-1C395E3E84A9} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {B31CDF4B-2851-45E5-8C5F-BE97125D9DD8} diff --git a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs index cdd9b04a..de9e4a73 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs +++ b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs @@ -4,6 +4,7 @@ // using Microsoft.SqlTools.Hosting.Contracts; +using static Microsoft.SqlTools.Shared.Utility.Constants; namespace Microsoft.Kusto.ServiceLayer.Connection { @@ -50,7 +51,7 @@ namespace Microsoft.Kusto.ServiceLayer.Connection CategoryValues = new CategoryValue[] { new CategoryValue { DisplayName = "SQL Login", Name = "SqlLogin" }, new CategoryValue { DisplayName = "Windows Authentication", Name = "Integrated" }, - new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = "AzureMFA" } + new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = AzureMFA } }, IsIdentity = true, IsRequired = true, diff --git a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionService.cs index 54b5aba2..c7cc33c0 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionService.cs @@ -19,6 +19,7 @@ using System.Diagnostics; using Microsoft.Kusto.ServiceLayer.DataSource; using Microsoft.Kusto.ServiceLayer.DataSource.Metadata; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; +using static Microsoft.SqlTools.Shared.Utility.Constants; namespace Microsoft.Kusto.ServiceLayer.Connection { @@ -916,7 +917,7 @@ namespace Microsoft.Kusto.ServiceLayer.Connection return new ConnectionDetails { ApplicationName = builder.ApplicationNameForTracing, - AuthenticationType = "AzureMFA", + AuthenticationType = AzureMFA, DatabaseName = builder.InitialCatalog, ServerName = builder.DataSource, UserName = builder.UserID, diff --git a/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceFactory.cs b/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceFactory.cs index 3a3e337c..1f140297 100644 --- a/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceFactory.cs +++ b/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceFactory.cs @@ -20,6 +20,7 @@ using Microsoft.Kusto.ServiceLayer.LanguageServices; using Microsoft.Kusto.ServiceLayer.Utility; using Microsoft.Kusto.ServiceLayer.Workspace.Contracts; using CompletionItem = Microsoft.Kusto.ServiceLayer.LanguageServices.Contracts.CompletionItem; +using static Microsoft.SqlTools.Shared.Utility.Constants; namespace Microsoft.Kusto.ServiceLayer.DataSource { @@ -63,7 +64,7 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource private DataSourceConnectionDetails MapKustoConnectionDetails(ConnectionDetails connectionDetails) { - if (connectionDetails.AuthenticationType == "dstsAuth" || connectionDetails.AuthenticationType == "AzureMFA") + if (connectionDetails.AuthenticationType == dstsAuth || connectionDetails.AuthenticationType == AzureMFA) { ValidationUtils.IsTrue(!string.IsNullOrWhiteSpace(connectionDetails.AccountToken), $"The Kusto User Token is not specified - set {nameof(connectionDetails.AccountToken)}"); diff --git a/src/Microsoft.Kusto.ServiceLayer/DataSource/Kusto/KustoClient.cs b/src/Microsoft.Kusto.ServiceLayer/DataSource/Kusto/KustoClient.cs index 660937b5..ac671977 100644 --- a/src/Microsoft.Kusto.ServiceLayer/DataSource/Kusto/KustoClient.cs +++ b/src/Microsoft.Kusto.ServiceLayer/DataSource/Kusto/KustoClient.cs @@ -21,6 +21,7 @@ using Kusto.Language.Editor; using Microsoft.Kusto.ServiceLayer.Connection; using Microsoft.Kusto.ServiceLayer.DataSource.Contracts; using Microsoft.Kusto.ServiceLayer.Utility; +using static Microsoft.SqlTools.Shared.Utility.Constants; namespace Microsoft.Kusto.ServiceLayer.DataSource.Kusto { @@ -74,7 +75,7 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource.Kusto ServerName = ClusterName, DatabaseName = DatabaseName, UserToken = accountToken, - AuthenticationType = "AzureMFA" + AuthenticationType = AzureMFA }; Initialize(connectionDetails); @@ -96,8 +97,8 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource.Kusto switch (connectionDetails.AuthenticationType) { - case "AzureMFA": return stringBuilder.WithAadUserTokenAuthentication(connectionDetails.UserToken); - case "dstsAuth": return stringBuilder.WithDstsUserTokenAuthentication(connectionDetails.UserToken); + case AzureMFA: return stringBuilder.WithAadUserTokenAuthentication(connectionDetails.UserToken); + case dstsAuth: return stringBuilder.WithDstsUserTokenAuthentication(connectionDetails.UserToken); default: return string.IsNullOrWhiteSpace(connectionDetails.UserName) && string.IsNullOrWhiteSpace(connectionDetails.Password) ? stringBuilder diff --git a/src/Microsoft.Kusto.ServiceLayer/Microsoft.Kusto.ServiceLayer.csproj b/src/Microsoft.Kusto.ServiceLayer/Microsoft.Kusto.ServiceLayer.csproj index e646918e..80501dcf 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Microsoft.Kusto.ServiceLayer.csproj +++ b/src/Microsoft.Kusto.ServiceLayer/Microsoft.Kusto.ServiceLayer.csproj @@ -51,6 +51,7 @@ + diff --git a/src/Microsoft.Kusto.ServiceLayer/Program.cs b/src/Microsoft.Kusto.ServiceLayer/Program.cs index 91cb0749..02573e65 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Program.cs +++ b/src/Microsoft.Kusto.ServiceLayer/Program.cs @@ -41,7 +41,7 @@ namespace Microsoft.Kusto.ServiceLayer logFilePath = Logger.GenerateLogFilePath("kustoservice"); } - Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "kustoservice", commandOptions.AutoFlushLog); + Logger.Initialize(tracingLevel: commandOptions.TracingLevel, piiEnabled: commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "kustoservice", commandOptions.AutoFlushLog); // set up the host details and profile paths var hostDetails = new HostDetails(version: new Version(1, 0)); diff --git a/src/Microsoft.SqlTools.Authentication/AccessToken.cs b/src/Microsoft.SqlTools.Authentication/AccessToken.cs new file mode 100644 index 00000000..b698ce87 --- /dev/null +++ b/src/Microsoft.SqlTools.Authentication/AccessToken.cs @@ -0,0 +1,33 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.Authentication +{ + /// + /// Represents an access token data object. + /// + public class AccessToken + { + /// + /// OAuth 2.0 JWT encoded access token string + /// + public string Token { get; set; } + + /// + /// Expiry date of token + /// + public DateTimeOffset ExpiresOn { get; set; } + + /// + /// Default constructor for Access Token object + /// + /// Access token as string + /// Expiry date + public AccessToken(string token, DateTimeOffset expiresOn) { + this.Token = token; + this.ExpiresOn = expiresOn; + } + } +} diff --git a/src/Microsoft.SqlTools.Authentication/AuthenticationMethod.cs b/src/Microsoft.SqlTools.Authentication/AuthenticationMethod.cs new file mode 100644 index 00000000..606507dc --- /dev/null +++ b/src/Microsoft.SqlTools.Authentication/AuthenticationMethod.cs @@ -0,0 +1,15 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.Authentication +{ + /// + /// Supported Active Directory authentication modes + /// + public enum AuthenticationMethod + { + ActiveDirectoryInteractive + } +} diff --git a/src/Microsoft.SqlTools.Authentication/AuthenticationParams.cs b/src/Microsoft.SqlTools.Authentication/AuthenticationParams.cs new file mode 100644 index 00000000..cbfa5995 --- /dev/null +++ b/src/Microsoft.SqlTools.Authentication/AuthenticationParams.cs @@ -0,0 +1,82 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.Authentication +{ + /// + /// Parameters to be passed to to request an access token + /// + public class AuthenticationParams + { + /// + /// Authentication method to be used by . + /// + public AuthenticationMethod AuthenticationMethod { get; set; } + + /// + /// Authority URL, e.g. https://login.microsoftonline.com/ + /// + public string Authority { get; set; } + + /// + /// Audience for which access token should be acquired, e.g. common, organizations, consumers. + /// It can also be a tenant Id when authenticating multi-tenant application accounts. + /// + public string Audience { get; set; } + + /// + /// Resource URL, e.g. https://database.windows.net/ + /// + public string Resource { get; set; } + + /// + /// Array of scopes for which access token is requested. + /// + public string[] Scopes { get; set; } + + /// + /// Connection Id, that will be passed to Azure AD when requesting access token. + /// It can be used for tracking AAD request status if needed. + /// + public Guid ConnectionId { get; set; } + + /// + /// User name to be provided as userhint when acquiring access token. + /// + public string UserName { get; set; } + + /// + /// Default constructor + /// + /// Authentication Method to be used. + /// Authority URL + /// Audience + /// Resource for which token is requested. + /// Scopes for access token + /// User hint information + /// Connection Id for tracing AAD request + public AuthenticationParams(AuthenticationMethod authMethod, string authority, string audience, + string resource, string[] scopes, string userName, Guid connectionId) { + this.AuthenticationMethod = authMethod; + this.Authority = authority; + this.Audience = audience; + this.Resource = resource; + this.Scopes = scopes; + this.UserName = userName; + this.ConnectionId = connectionId; + } + + public string ToLogString(bool piiEnabled) + { + return $"\tAuthenticationMethod: {AuthenticationMethod}" + + $"\n\tAuthority: {Authority}" + + $"\n\tAudience: {Audience}" + + $"\n\tResource: {Resource}" + + $"\n\tScopes: {Scopes.ToString()}" + + (piiEnabled ? $"\n\tUsername: {UserName}" : "") + + $"\n\tConnection Id: {ConnectionId}"; + } + } +} diff --git a/src/Microsoft.SqlTools.Authentication/Authenticator.cs b/src/Microsoft.SqlTools.Authentication/Authenticator.cs new file mode 100644 index 00000000..d149a85f --- /dev/null +++ b/src/Microsoft.SqlTools.Authentication/Authenticator.cs @@ -0,0 +1,158 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Concurrent; +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.Extensions.Msal; +using Microsoft.SqlTools.Authentication.Utility; +using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger; + +namespace Microsoft.SqlTools.Authentication +{ + /// + /// Provides APIs to acquire access token using MSAL.NET v4 with provided . + /// + public class Authenticator + { + private string applicationClientId; + private string applicationName; + private string cacheFolderPath; + private string cacheFileName; + private MsalCacheHelper cacheHelper; + private static ConcurrentDictionary PublicClientAppMap + = new ConcurrentDictionary(); + + #region Public APIs + public Authenticator(string appClientId, string appName, string cacheFolderPath, string cacheFileName) + { + this.applicationClientId = appClientId; + this.applicationName = appName; + this.cacheFolderPath = cacheFolderPath; + this.cacheFileName = cacheFileName; + + // Storage creation properties are used to enable file system caching with MSAL + var storageCreationProperties = new StorageCreationPropertiesBuilder(this.cacheFileName, this.cacheFolderPath) + .WithUnprotectedFile().Build(); + + // This hooks up the cross-platform cache into MSAL + this.cacheHelper = MsalCacheHelper.CreateAsync(storageCreationProperties).ConfigureAwait(false).GetAwaiter().GetResult(); + } + + /// + /// Acquires access token synchronously. + /// + /// Authentication parameters to be used for access token request. + /// Cancellation token. + /// Access Token with expiry date + public AccessToken? GetToken(AuthenticationParams @params, CancellationToken cancellationToken) => + GetTokenAsync(@params, cancellationToken).GetAwaiter().GetResult(); + + /// + /// Acquires access token asynchronously. + /// + /// Authentication parameters to be used for access token request. + /// Cancellation token. + /// Access Token with expiry date + /// + public async Task GetTokenAsync(AuthenticationParams @params, CancellationToken cancellationToken) + { + SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Received @params: {@params.ToLogString(SqlToolsLogger.IsPiiEnabled)}"); + + IPublicClientApplication publicClientApplication = GetPublicClientAppInstance(@params.Authority, @params.Audience); + + AccessToken? accessToken; + if (@params.AuthenticationMethod == AuthenticationMethod.ActiveDirectoryInteractive) + { + // Find account + IEnumerator? accounts = (await publicClientApplication.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator(); + IAccount? account = default; + + if (!string.IsNullOrEmpty(@params.UserName) && accounts.MoveNext()) + { + // Handle username format to extract email: "John Doe - johndoe@constoso.com" + string username = @params.UserName.Contains(" - ") ? @params.UserName.Split(" - ")[1] : @params.UserName; + if (!Utils.isValidEmail(username)) + { + SqlToolsLogger.Pii($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Unexpected username format, email not retreived: {@params.UserName}. " + + $"Accepted formats are: 'johndoe@org.com' or 'John Doe - johndoe@org.com'."); + throw new Exception($"Invalid email address format for user: [{username}] received for Azure Active Directory authentication."); + } + + do + { + IAccount? currentVal = accounts.Current; + if (string.Compare(username, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0) + { + account = currentVal; + SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | User account found in MSAL Cache: {account.HomeAccountId}"); + break; + } + } + while (accounts.MoveNext()); + + if (null != account) + { + try + { + // Fetch token silently + var result = await publicClientApplication.AcquireTokenSilent(@params.Scopes, account) + .ExecuteAsync(cancellationToken: cancellationToken) + .ConfigureAwait(false); + accessToken = new AccessToken(result!.AccessToken, result!.ExpiresOn); + } + catch (Exception e) + { + SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Silent authentication failed for resource {@params.Resource} for ConnectionId {@params.ConnectionId}."); + SqlToolsLogger.Error(e); + throw; + } + } + else + { + SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Account not found in MSAL cache for user."); + throw new Exception($"User account '{username}' not found in MSAL cache, please add linked account or refresh account credentials."); + } + } + else + { + SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | User account not received."); + throw new Exception($"User account not received."); + } + } + else + { + SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Authentication Method ${@params.AuthenticationMethod} is not supported."); + throw new Exception($"Authentication Method ${@params.AuthenticationMethod} is not supported."); + } + + return accessToken; + } + + #endregion + + #region Private methods + private IPublicClientApplication GetPublicClientAppInstance(string authority, string audience) + { + string authorityUrl = authority + '/' + audience; + if (!PublicClientAppMap.TryGetValue(authorityUrl, out IPublicClientApplication? clientApplicationInstance)) + { + clientApplicationInstance = CreatePublicClientAppInstance(authority, audience); + this.cacheHelper.RegisterCache(clientApplicationInstance.UserTokenCache); + PublicClientAppMap.TryAdd(authorityUrl, clientApplicationInstance); + } + return clientApplicationInstance; + } + + private IPublicClientApplication CreatePublicClientAppInstance(string authority, string audience) => + PublicClientApplicationBuilder.Create(this.applicationClientId) + .WithAuthority(authority, audience) + .WithClientName(this.applicationName) + .WithLogging(Utils.MSALLogCallback) + .WithDefaultRedirectUri() + .Build(); + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.Authentication/Microsoft.SqlTools.Authentication.csproj b/src/Microsoft.SqlTools.Authentication/Microsoft.SqlTools.Authentication.csproj new file mode 100644 index 00000000..517cd82b --- /dev/null +++ b/src/Microsoft.SqlTools.Authentication/Microsoft.SqlTools.Authentication.csproj @@ -0,0 +1,18 @@ + + + + enable + enable + + + + + + + + + + + + + diff --git a/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs b/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs new file mode 100644 index 00000000..74dbbcbe --- /dev/null +++ b/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs @@ -0,0 +1,110 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.Data.SqlClient; +using Microsoft.SqlTools.Authentication.Utility; + +namespace Microsoft.SqlTools.Authentication.Sql +{ + /// + /// Provides an implementation of for SQL Tools to be able to perform Federated authentication + /// silently with Microsoft.Data.SqlClient integration only for "ActiveDirectory" authentication modes. + /// When registered, the SqlClient driver calls the API + /// with server-sent authority information to request access token when needed. + /// + public class AuthenticationProvider : SqlAuthenticationProvider + { + private const string ApplicationClientId = "a69788c6-1d43-44ed-9ca3-b83e194da255"; + private const string AzureTokenFolder = "Azure Accounts"; + private const string MsalCacheName = "azureTokenCacheMsal-azure_publicCloud"; + private const string s_defaultScopeSuffix = "/.default"; + + private Authenticator authenticator; + + /// + /// Instantiates AuthenticationProvider to be used for AAD authentication with MSAL.NET and MSAL.js co-ordinated. + /// + /// Application Name that identifies user folder path location for reading/writing to shared cache. + /// Callback that handles AAD authentication when user interaction is needed. + public AuthenticationProvider(string applicationName) + { + if(string.IsNullOrEmpty(applicationName)) { + applicationName = nameof(SqlTools); + } + var cachePath = Path.Combine(Utils.BuildAppDirectoryPath(), applicationName, AzureTokenFolder); + this.authenticator = new Authenticator(ApplicationClientId, applicationName, cachePath, MsalCacheName); + } + + /// + /// Acquires access token with provided + /// + /// Authentication parameters + /// Authentication token containing access token and expiry date. + public override async Task AcquireTokenAsync(SqlAuthenticationParameters parameters) + { + // Setup scope + string resource = parameters.Resource; + string scope; + + if (parameters.Resource.EndsWith(s_defaultScopeSuffix)) + { + scope = parameters.Resource; + resource = parameters.Resource.Substring(0, parameters.Resource.LastIndexOf('/') + 1); + } + else + { + scope = parameters.Resource + s_defaultScopeSuffix; + } + + string[] scopes = new string[] { scope }; + + CancellationTokenSource cts = new CancellationTokenSource(); + + // Use Connection timeout value to cancel token acquire request after certain period of time. + cts.CancelAfter(parameters.ConnectionTimeout * 1000); // Convert to milliseconds + + /* We split audience from Authority URL here. Audience can be one of the following: + * The Azure AD authority audience enumeration + * The tenant ID, which can be: + * - A GUID (the ID of your Azure AD instance), for single-tenant applications + * - A domain name associated with your Azure AD instance (also for single-tenant applications) + * One of these placeholders as a tenant ID in place of the Azure AD authority audience enumeration: + * - `organizations` for a multitenant application + * - `consumers` to sign in users only with their personal accounts + * - `common` to sign in users with their work and school accounts or their personal Microsoft accounts + * + * MSAL will throw a meaningful exception if you specify both the Azure AD authority audience and the tenant ID. + * If you don't specify an audience, your app will target Azure AD and personal Microsoft accounts as an audience. (That is, it will behave as though `common` were specified.) + * More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration + **/ + + int seperatorIndex = parameters.Authority.LastIndexOf('/'); + string authority = parameters.Authority.Remove(seperatorIndex + 1); + string audience = parameters.Authority.Substring(seperatorIndex + 1); + string? userName = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId; + + AuthenticationParams @params = new AuthenticationParams( + AuthenticationMethod.ActiveDirectoryInteractive, + authority, + audience, + resource, + scopes, + userName!, + parameters.ConnectionId); + + AccessToken? result = await authenticator.GetTokenAsync(@params, cts.Token).ConfigureAwait(false); + return new SqlAuthenticationToken(result!.Token, result!.ExpiresOn); + } + + /// + /// Whether or not provided is supported. + /// + /// SQL Authentication method + /// + public override bool IsSupported(SqlAuthenticationMethod authenticationMethod) + => authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive; + + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs b/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs new file mode 100644 index 00000000..bc3b6712 --- /dev/null +++ b/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs @@ -0,0 +1,119 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Net.Mail; +using System.Runtime.InteropServices; +using Microsoft.Identity.Client; +using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger; + +namespace Microsoft.SqlTools.Authentication.Utility +{ + internal sealed class Utils + { + /// + /// Validates provided follows email format. + /// + /// Email address + /// Whether email is in correct format. + public static bool isValidEmail(string userEmail) + { + try + { + new MailAddress(userEmail); + return true; + } + catch (FormatException) + { + return false; + } + } + + /// + /// Builds directory path based on environment settings. + /// + /// Application directory path + /// When called on unsupported platform. + public static string BuildAppDirectoryPath() + { + var homedir = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); + + // Windows + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var appData = Environment.GetEnvironmentVariable("APPDATA"); + var userProfile = Environment.GetEnvironmentVariable("USERPROFILE"); + if (appData != null) + { + return appData; + } + else if (userProfile != null) + { + return string.Join(Environment.GetEnvironmentVariable("USERPROFILE"), "AppData", "Roaming"); + } + else + { + throw new Exception("Not able to find APPDATA or USERPROFILE"); + } + } + + // Mac + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return string.Join(homedir, "Library", "Application Support"); + } + + // Linux + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + var xdgConfigHome = Environment.GetEnvironmentVariable("XDG_CONFIG_HOME"); + if (xdgConfigHome != null) + { + return xdgConfigHome; + } + else + { + return string.Join(homedir, ".config"); + } + } + else + { + throw new Exception("Platform not supported"); + } + } + + /// + /// Log callback handler used for MSAL Client applications. + /// + /// Log level + /// Log message + /// Whether message contains PII information. + public static void MSALLogCallback(LogLevel logLevel, string message, bool pii) + { + switch (logLevel) + { + case LogLevel.Error: + if (pii) SqlToolsLogger.Pii(message); + else SqlToolsLogger.Error(message); + break; + case LogLevel.Warning: + if (pii) SqlToolsLogger.Pii(message); + else SqlToolsLogger.Warning(message); + break; + case LogLevel.Info: + if (pii) SqlToolsLogger.Pii(message); + else SqlToolsLogger.Information(message); + break; + case LogLevel.Verbose: + if (pii) SqlToolsLogger.Pii(message); + else SqlToolsLogger.Verbose(message); + break; + case LogLevel.Always: + if (pii) SqlToolsLogger.Pii(message); + else SqlToolsLogger.Critical(message); + break; + } + } + } +} diff --git a/src/Microsoft.SqlTools.Credentials/Program.cs b/src/Microsoft.SqlTools.Credentials/Program.cs index b270df5c..5dc87a38 100644 --- a/src/Microsoft.SqlTools.Credentials/Program.cs +++ b/src/Microsoft.SqlTools.Credentials/Program.cs @@ -38,7 +38,7 @@ namespace Microsoft.SqlTools.Credentials logFilePath = Logger.GenerateLogFilePath("credentials"); } - Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "credentials", commandOptions.AutoFlushLog); + Logger.Initialize(tracingLevel: commandOptions.TracingLevel, piiEnabled: commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "credentials", commandOptions.AutoFlushLog); // set up the host details and profile paths var hostDetails = new HostDetails( diff --git a/src/Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj b/src/Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj index 5749dde3..4efcce20 100644 --- a/src/Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj +++ b/src/Microsoft.SqlTools.Hosting/Microsoft.SqlTools.Hosting.csproj @@ -25,4 +25,7 @@ + + + \ No newline at end of file diff --git a/src/Microsoft.SqlTools.Hosting/Utility/CommandOptions.cs b/src/Microsoft.SqlTools.Hosting/Utility/CommandOptions.cs index 0095cc8c..b56c084b 100644 --- a/src/Microsoft.SqlTools.Hosting/Utility/CommandOptions.cs +++ b/src/Microsoft.SqlTools.Hosting/Utility/CommandOptions.cs @@ -28,6 +28,7 @@ namespace Microsoft.SqlTools.Utility ServiceName = serviceName; ErrorMessage = string.Empty; Locale = string.Empty; + ApplicationName = string.Empty; try { @@ -42,12 +43,18 @@ namespace Microsoft.SqlTools.Utility switch (argName) { + case "-application-name": + ApplicationName = args[++i]; + break; case "-autoflush-log": AutoFlushLog = true; break; case "-tracing-level": TracingLevel = args[++i]; break; + case "-pii-logging": + PiiLogging = true; + break; case "-log-file": LogFilePath = args[++i]; break; @@ -66,6 +73,9 @@ namespace Microsoft.SqlTools.Utility case "-parallel-message-processing": ParallelMessageProcessing = true; break; + case "-enable-sql-authentication-provider": + EnableSqlAuthenticationProvider = true; + break; case "-parent-pid": string nextArg = args[++i]; if (Int32.TryParse(nextArg, out int parsedInt)) @@ -95,6 +105,11 @@ namespace Microsoft.SqlTools.Utility } } + /// + /// Name of application that is sending command options + /// + public string ApplicationName { get; private set; } + /// /// Contains any error messages during execution /// @@ -137,6 +152,8 @@ namespace Microsoft.SqlTools.Utility public string TracingLevel { get; private set; } + public bool PiiLogging { get; private set; } + public string LogFilePath { get; private set; } public bool AutoFlushLog { get; private set; } = false; @@ -147,6 +164,13 @@ namespace Microsoft.SqlTools.Utility /// public bool ParallelMessageProcessing { get; private set; } = false; + /// + /// Enables configured 'Sql Authentication Provider' for 'Active Directory Interactive' authentication mode to be used + /// when user chooses 'Azure MFA'. This setting enables MSAL.NET to acquire token with SqlClient integration. + /// Currently this option is disabled by default, it's planned to be enabled by default in future releases. + /// + public bool EnableSqlAuthenticationProvider { get; private set; } = false; + /// /// The ID of the process that started this service. This is used to check when the parent /// process exits so that the service process can exit at the same time. diff --git a/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableSqlConnection.cs b/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableSqlConnection.cs index 5b2d3b26..8d207dad 100644 --- a/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableSqlConnection.cs +++ b/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableSqlConnection.cs @@ -70,7 +70,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection _connectionRetryPolicy.RetryOccurred += RetryConnectionCallback; _commandRetryPolicy.RetryOccurred += RetryCommandCallback; - if (azureAccountToken != null) + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(_underlyingConnection.ConnectionString); + if (builder.Authentication is SqlAuthenticationMethod.NotSpecified && azureAccountToken is not null) { _underlyingConnection.AccessToken = azureAccountToken; } diff --git a/src/Microsoft.SqlTools.Migration/Program.cs b/src/Microsoft.SqlTools.Migration/Program.cs index 10e1d73c..d262f8d2 100644 --- a/src/Microsoft.SqlTools.Migration/Program.cs +++ b/src/Microsoft.SqlTools.Migration/Program.cs @@ -38,7 +38,7 @@ namespace Microsoft.SqlTools.Migration logFilePath = Logger.GenerateLogFilePath(logFilePath); } - Logger.Initialize(SourceLevels.Verbose, logFilePath, "Migration", commandOptions.AutoFlushLog); + Logger.Initialize(SourceLevels.Verbose, piiEnabled: commandOptions.PiiLogging, logFilePath, "Migration", commandOptions.AutoFlushLog); Logger.Verbose("Starting SqlTools Migration Server..."); diff --git a/src/Microsoft.SqlTools.ResourceProvider/Program.cs b/src/Microsoft.SqlTools.ResourceProvider/Program.cs index 5609bcae..2716e65f 100644 --- a/src/Microsoft.SqlTools.ResourceProvider/Program.cs +++ b/src/Microsoft.SqlTools.ResourceProvider/Program.cs @@ -41,7 +41,7 @@ namespace Microsoft.SqlTools.ResourceProvider } // we need to switch to Information when preparing for public preview - Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "resourceprovider", commandOptions.AutoFlushLog); + Logger.Initialize(tracingLevel: commandOptions.TracingLevel, commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "resourceprovider", commandOptions.AutoFlushLog); Logger.Write(TraceEventType.Information, "Starting SqlTools Resource Provider"); // set up the host details and profile paths diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs index bdcd58e0..1a09551a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs @@ -6,6 +6,7 @@ #nullable disable using Microsoft.SqlTools.Hosting.Contracts; +using static Microsoft.SqlTools.Shared.Utility.Constants; namespace Microsoft.SqlTools.ServiceLayer.Connection { @@ -52,7 +53,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection CategoryValues = new CategoryValue[] { new CategoryValue { DisplayName = "SQL Login", Name = "SqlLogin" }, new CategoryValue { DisplayName = "Windows Authentication", Name = "Integrated" }, - new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = "AzureMFA" } + new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = AzureMFA } }, IsIdentity = true, IsRequired = true, diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 067feeff..6485ef65 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -23,7 +23,9 @@ using Microsoft.SqlTools.ServiceLayer.LanguageServices; using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.Utility; +using static Microsoft.SqlTools.Shared.Utility.Constants; using System.Diagnostics; +using Microsoft.SqlTools.Authentication.Sql; namespace Microsoft.SqlTools.ServiceLayer.Connection { @@ -143,17 +145,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { RequestSecurityTokenParams message = new RequestSecurityTokenParams() { - Authority = authority, Provider = "Azure", + Authority = authority, Resource = resource, - Scope = scope + Scopes = new string[] { scope } }; - RequestSecurityTokenResponse response = await Instance.ServiceHost.SendRequest(SecurityTokenRequest.Type, message, true); + RequestSecurityTokenResponse response = await Instance.ServiceHost.SendRequest(SecurityTokenRequest.Type, message, true).ConfigureAwait(false); return response.Token; } + /// + /// Enables configured 'Sql Authentication Provider' for 'Active Directory Interactive' authentication mode to be used + /// when user chooses 'Azure MFA'. + /// + public bool EnableSqlAuthenticationProvider { get; set; } + /// /// Returns a connection queue for given type /// @@ -248,7 +256,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection if (this.TryFindConnection(ownerUri, out connInfo)) { // If not an azure connection, no need to refresh token - if (connInfo.ConnectionDetails.AuthenticationType != "AzureMFA") + if (connInfo.ConnectionDetails.AuthenticationType != AzureMFA) { return false; } @@ -594,7 +602,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection connectionInfo.IsSqlDb = serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDatabase; connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse); // Determines that access token is used for creating connection. - connectionInfo.IsAzureAuth = connectionInfo.ConnectionDetails.AuthenticationType == "AzureMFA"; + connectionInfo.IsAzureAuth = connectionInfo.ConnectionDetails.AuthenticationType == AzureMFA; connectionInfo.EngineEdition = (DatabaseEngineEdition)serverInfo.EngineEditionId; // Azure Data Studio supports SQL Server 2014 and later releases. response.IsSupportedVersion = serverInfo.IsCloud || serverInfo.ServerMajorVersion >= 12; @@ -1061,10 +1069,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection return handler.HandleRequest(this.connectionFactory, info); } - public void InitializeService(IProtocolEndpoint serviceHost) + public void InitializeService(IProtocolEndpoint serviceHost, ServiceLayerCommandOptions commandOptions) { this.ServiceHost = serviceHost; + if (commandOptions != null && commandOptions.EnableSqlAuthenticationProvider) + { + // Register SqlAuthenticationProvider with SqlConnection for AAD Interactive (MFA) authentication. + var provider = new AuthenticationProvider(commandOptions.ApplicationName); + SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryInteractive, provider); + + this.EnableSqlAuthenticationProvider = true; + Logger.Information("Registering implementation of SQL Authentication provider for 'Active Directory Interactive' authentication mode."); + } + // Register request and event handlers with the Service Host serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest, true); serviceHost.SetRequestHandler(CancelConnectRequest.Type, HandleCancelConnectRequest, true); @@ -1304,31 +1322,46 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection connectionBuilder = new SqlConnectionStringBuilder { - ["Data Source"] = dataSource, - ["User Id"] = connectionDetails.UserName, - ["Password"] = connectionDetails.Password + DataSource = dataSource }; } // Check for any optional parameters if (!string.IsNullOrEmpty(connectionDetails.DatabaseName)) { - connectionBuilder["Initial Catalog"] = connectionDetails.DatabaseName; + connectionBuilder.InitialCatalog = connectionDetails.DatabaseName; } if (!string.IsNullOrEmpty(connectionDetails.AuthenticationType)) { switch (connectionDetails.AuthenticationType) { - case "Integrated": + case Integrated: connectionBuilder.IntegratedSecurity = true; break; - case "SqlLogin": + case SqlLogin: + connectionBuilder.UserID = connectionDetails.UserName; + connectionBuilder.Password = connectionDetails.Password; + connectionBuilder.Authentication = SqlAuthenticationMethod.SqlPassword; break; - case "AzureMFA": - connectionBuilder.UserID = ""; - connectionBuilder.Password = ""; + case AzureMFA: + if (Instance.EnableSqlAuthenticationProvider) + { + connectionBuilder.UserID = connectionDetails.UserName; + connectionDetails.AuthenticationType = ActiveDirectoryInteractive; + connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive; + } + else + { + connectionBuilder.UserID = ""; + } break; - case "ActiveDirectoryPassword": + case ActiveDirectoryInteractive: + connectionBuilder.UserID = connectionDetails.UserName; + connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive; + break; + case ActiveDirectoryPassword: + connectionBuilder.UserID = connectionDetails.UserName; + connectionBuilder.Password = connectionDetails.Password; connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryPassword; break; default: @@ -1337,16 +1370,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } if (!string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting)) { - switch (connectionDetails.ColumnEncryptionSetting.ToUpper()) + if (Enum.TryParse(connectionDetails.ColumnEncryptionSetting, true, out var value)) { - case "ENABLED": - connectionBuilder.ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Enabled; - break; - case "DISABLED": - connectionBuilder.ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Disabled; - break; - default: - throw new ArgumentException(SR.ConnectionServiceConnStringInvalidColumnEncryptionSetting(connectionDetails.ColumnEncryptionSetting)); + connectionBuilder.ColumnEncryptionSetting = value; + } + else + { + throw new ArgumentException(SR.ConnectionServiceConnStringInvalidColumnEncryptionSetting(connectionDetails.ColumnEncryptionSetting)); } } if (!string.IsNullOrEmpty(connectionDetails.SecureEnclaves)) @@ -1365,36 +1395,30 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } if (!string.IsNullOrEmpty(connectionDetails.EnclaveAttestationProtocol)) { - if (string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting) || connectionDetails.ColumnEncryptionSetting.ToUpper() == "DISABLED" + if (connectionBuilder.ColumnEncryptionSetting != SqlConnectionColumnEncryptionSetting.Enabled || string.IsNullOrEmpty(connectionDetails.SecureEnclaves) || connectionDetails.SecureEnclaves.ToUpper() == "DISABLED") { throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAlwaysEncryptedOptionCombination); } - switch (connectionDetails.EnclaveAttestationProtocol.ToUpper()) + if (Enum.TryParse(connectionDetails.EnclaveAttestationProtocol, true, out var value)) { - case "AAS": - connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.AAS; - break; - case "HGS": - connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.HGS; - break; - case "NONE": - connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.None; - break; - default: - throw new ArgumentException(SR.ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(connectionDetails.EnclaveAttestationProtocol)); + connectionBuilder.AttestationProtocol = value; + } + else + { + throw new ArgumentException(SR.ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(connectionDetails.EnclaveAttestationProtocol)); } } if (!string.IsNullOrEmpty(connectionDetails.EnclaveAttestationUrl)) { - if (string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting) || connectionDetails.ColumnEncryptionSetting.ToUpper() == "DISABLED" + if (connectionBuilder.ColumnEncryptionSetting != SqlConnectionColumnEncryptionSetting.Enabled || string.IsNullOrEmpty(connectionDetails.SecureEnclaves) || connectionDetails.SecureEnclaves.ToUpper() == "DISABLED") { throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAlwaysEncryptedOptionCombination); } - if(connectionBuilder.AttestationProtocol == SqlConnectionAttestationProtocol.None) + if (connectionBuilder.AttestationProtocol == SqlConnectionAttestationProtocol.None) { throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAttestationProtocolNoneWithUrl); } @@ -1456,19 +1480,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } if (!string.IsNullOrEmpty(connectionDetails.ApplicationIntent)) { - ApplicationIntent intent; - switch (connectionDetails.ApplicationIntent) + if (Enum.TryParse(connectionDetails.ApplicationIntent, true, out ApplicationIntent value)) { - case "ReadOnly": - intent = ApplicationIntent.ReadOnly; - break; - case "ReadWrite": - intent = ApplicationIntent.ReadWrite; - break; - default: - throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent)); + connectionBuilder.ApplicationIntent = value; + } + else + { + throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent)); } - connectionBuilder.ApplicationIntent = intent; } if (!string.IsNullOrEmpty(connectionDetails.CurrentLanguage)) { @@ -1585,7 +1604,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection ApplicationIntent = builder.ApplicationIntent.ToString(), ApplicationName = builder.ApplicationName, AttachDbFilename = builder.AttachDBFilename, - AuthenticationType = builder.IntegratedSecurity ? "Integrated" : "SqlLogin", + AuthenticationType = builder.IntegratedSecurity ? "Integrated" : + (builder.Authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive + ? "ActiveDirectoryInteractive" : "SqlLogin"), ConnectRetryCount = builder.ConnectRetryCount, ConnectRetryInterval = builder.ConnectRetryInterval, ConnectTimeout = builder.ConnectTimeout, @@ -1793,7 +1814,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection SqlConnection sqlConn = new SqlConnection(connectionString); // Fill in Azure authentication token if needed - if (connInfo.ConnectionDetails.AzureAccountToken != null) + if (connInfo.ConnectionDetails.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA) { sqlConn.AccessToken = connInfo.ConnectionDetails.AzureAccountToken; } @@ -1824,7 +1845,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { var sqlConnection = ConnectionService.OpenSqlConnection(connInfo, featureName); ServerConnection serverConnection; - if (connInfo.ConnectionDetails.AzureAccountToken != null) + if (connInfo.ConnectionDetails.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA) { serverConnection = new ServerConnection(sqlConnection, new AzureAccessToken(connInfo.ConnectionDetails.AzureAccountToken)); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/SecurityTokenRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/SecurityTokenRequest.cs index 013ba709..960c27a2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/SecurityTokenRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/SecurityTokenRequest.cs @@ -11,25 +11,25 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts { class RequestSecurityTokenParams { - /// - /// Gets or sets the address of the authority to issue token. - /// - public string Authority { get; set; } - /// /// Gets or sets the provider that indicates the type of linked account to query. /// public string Provider { get; set; } + /// + /// Gets or sets the authority URL from where token is requested. + /// + public string Authority { get; set; } + /// /// Gets or sets the identifier of the target resource that is the recipient of the requested token. /// public string Resource { get; set; } /// - /// Gets or sets the scope of the authentication request. + /// Gets or sets the scope array of the authentication request. /// - public string Scope { get; set; } + public string [] Scopes { get; set; } } class RequestSecurityTokenResponse diff --git a/src/Microsoft.SqlTools.ServiceLayer/DacFx/DacFxOperation.cs b/src/Microsoft.SqlTools.ServiceLayer/DacFx/DacFxOperation.cs index b66c6fa9..64f86757 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/DacFx/DacFxOperation.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/DacFx/DacFxOperation.cs @@ -8,6 +8,7 @@ using Microsoft.SqlServer.Dac; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.TaskServices; using Microsoft.SqlTools.ServiceLayer.Utility; +using static Microsoft.SqlTools.Shared.Utility.Constants; using Microsoft.SqlTools.Utility; using System; using System.Diagnostics; @@ -86,7 +87,9 @@ namespace Microsoft.SqlTools.ServiceLayer.DacFx try { // Pass in Azure authentication token if needed - this.DacServices = this.ConnInfo.ConnectionDetails.AzureAccountToken != null ? new DacServices(this.ConnectionString, new AccessTokenProvider(this.ConnInfo.ConnectionDetails.AzureAccountToken)) : new DacServices(this.ConnectionString); + this.DacServices = this.ConnInfo.ConnectionDetails.AzureAccountToken != null && this.ConnInfo.ConnectionDetails.AuthenticationType == AzureMFA + ? new DacServices(this.ConnectionString, new AccessTokenProvider(this.ConnInfo.ConnectionDetails.AzureAccountToken)) + : new DacServices(this.ConnectionString); Execute(); } catch (Exception e) diff --git a/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs b/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs index 539b82fb..1bc46184 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs @@ -40,6 +40,7 @@ using Microsoft.SqlTools.ServiceLayer.SqlAssessment; using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.SqlProjects; using Microsoft.SqlTools.ServiceLayer.TableDesigner; +using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace; namespace Microsoft.SqlTools.ServiceLayer @@ -53,7 +54,7 @@ namespace Microsoft.SqlTools.ServiceLayer private static object lockObject = new object(); private static bool isLoaded; - internal static ServiceHost CreateAndStartServiceHost(SqlToolsContext sqlToolsContext, Stream? inputStream = null, Stream? outputStream = null) + internal static ServiceHost CreateAndStartServiceHost(SqlToolsContext sqlToolsContext, ServiceLayerCommandOptions? commandOptions, Stream? inputStream = null, Stream? outputStream = null) { ServiceHost serviceHost = ServiceHost.Instance; lock (lockObject) @@ -63,7 +64,7 @@ namespace Microsoft.SqlTools.ServiceLayer // Grab the instance of the service host serviceHost.Initialize(inputStream, outputStream); - InitializeRequestHandlersAndServices(serviceHost, sqlToolsContext); + InitializeRequestHandlersAndServices(serviceHost, sqlToolsContext, commandOptions); // Start the service only after all request handlers are setup. This is vital // as otherwise the Initialize event can be lost - it's processed and discarded before the handler @@ -75,7 +76,7 @@ namespace Microsoft.SqlTools.ServiceLayer return serviceHost; } - private static void InitializeRequestHandlersAndServices(ServiceHost serviceHost, SqlToolsContext sqlToolsContext) + private static void InitializeRequestHandlersAndServices(ServiceHost serviceHost, SqlToolsContext sqlToolsContext, ServiceLayerCommandOptions? commandOptions) { // Load extension provider, which currently finds all exports in current DLL. Can be changed to find based // on directory or assembly list quite easily in the future @@ -96,7 +97,7 @@ namespace Microsoft.SqlTools.ServiceLayer LanguageService.Instance.InitializeService(serviceHost, sqlToolsContext); serviceProvider.RegisterSingleService(LanguageService.Instance); - ConnectionService.Instance.InitializeService(serviceHost); + ConnectionService.Instance.InitializeService(serviceHost, commandOptions); serviceProvider.RegisterSingleService(ConnectionService.Instance); CredentialService.Instance.InitializeService(serviceHost); diff --git a/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj index 156e5ccf..1149eb4e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj +++ b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj @@ -57,6 +57,8 @@ + + diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs index e2b008d4..7e7f1b22 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs @@ -400,7 +400,9 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer var builder = ConnectionService.CreateConnectionStringBuilder(session.ConnectionInfo.ConnectionDetails); builder.InitialCatalog = node.NodeValue; builder.ApplicationName = TableDesignerService.TableDesignerApplicationName; - var azureToken = session.ConnectionInfo.ConnectionDetails.AzureAccountToken; + // Set Access Token only when authentication mode is not specified. + var azureToken = builder.Authentication == SqlAuthenticationMethod.NotSpecified + ? session.ConnectionInfo.ConnectionDetails.AzureAccountToken : null; TableDesignerCacheManager.StartDatabaseModelInitialization(builder.ToString(), azureToken); } catch (Exception ex) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Program.cs b/src/Microsoft.SqlTools.ServiceLayer/Program.cs index f63bd63f..cfc118b2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Program.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Program.cs @@ -41,7 +41,16 @@ namespace Microsoft.SqlTools.ServiceLayer logFilePath = Logger.GenerateLogFilePath("sqltools"); } - Logger.Initialize(tracingLevel: commandOptions.TracingLevel, logFilePath: logFilePath, traceSource: "sqltools", commandOptions.AutoFlushLog); + Logger.Initialize(tracingLevel: commandOptions.TracingLevel, commandOptions.PiiLogging, logFilePath: logFilePath, traceSource: "sqltools", commandOptions.AutoFlushLog); + + // Register PII Logging configuration change callback + Workspace.WorkspaceService.Instance.RegisterConfigChangeCallback((newSettings, oldSettings, context) => + { + Logger.IsPiiEnabled = newSettings?.MssqlTools?.PiiLogging ?? false; + Logger.Information(Logger.IsPiiEnabled ? "PII Logging enabled" : "PII Logging disabled"); + return Task.FromResult(true); + }); + // Only enable SQL Client logging when verbose or higher to avoid extra overhead when the // detailed logging it provides isn't needed if (Logger.TracingLevel.HasFlag(SourceLevels.Verbose)) @@ -53,7 +62,7 @@ namespace Microsoft.SqlTools.ServiceLayer var hostDetails = new HostDetails(version: new Version(1, 0)); SqlToolsContext sqlToolsContext = new SqlToolsContext(hostDetails); - ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext); + ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext, commandOptions); serviceHost.MessageDispatcher.ParallelMessageProcessing = commandOptions.ParallelMessageProcessing; // If this service was started by another process, then it should shutdown when that parent process does. diff --git a/src/Microsoft.SqlTools.ServiceLayer/SchemaCompare/SchemaCompareUtils.cs b/src/Microsoft.SqlTools.ServiceLayer/SchemaCompare/SchemaCompareUtils.cs index b85c6e24..707fb8eb 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/SchemaCompare/SchemaCompareUtils.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/SchemaCompare/SchemaCompareUtils.cs @@ -17,6 +17,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.DacFx.Contracts; using Microsoft.SqlTools.ServiceLayer.SchemaCompare.Contracts; using Microsoft.SqlTools.ServiceLayer.Utility; +using static Microsoft.SqlTools.Shared.Utility.Constants; using Microsoft.SqlTools.Utility; namespace Microsoft.SqlTools.ServiceLayer.SchemaCompare @@ -187,7 +188,9 @@ namespace Microsoft.SqlTools.ServiceLayer.SchemaCompare case SchemaCompareEndpointType.Database: { string connectionString = GetConnectionString(connInfo, endpointInfo.DatabaseName); - return connInfo.ConnectionDetails?.AzureAccountToken != null + + // Set Access Token only when authentication mode is not specified. + return connInfo.ConnectionDetails?.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA ? new SchemaCompareDatabaseEndpoint(connectionString, new AccessTokenProvider(connInfo.ConnectionDetails.AzureAccountToken)) : new SchemaCompareDatabaseEndpoint(connectionString); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs index 9e7d10c7..e0c99530 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs @@ -17,6 +17,7 @@ using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts; using Microsoft.SqlTools.Utility; using Microsoft.SqlTools.ServiceLayer.Utility; +using static Microsoft.SqlTools.Shared.Utility.Constants; using System.Linq; namespace Microsoft.SqlTools.ServiceLayer.Scripting @@ -109,7 +110,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting if (connInfo != null) { parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails); - accessToken = connInfo.ConnectionDetails.AzureAccountToken; + // Set Access Token only when authentication type is AzureMFA. + if (connInfo.ConnectionDetails.AuthenticationType == AzureMFA) + { + accessToken = connInfo.ConnectionDetails.AzureAccountToken; + } } else { diff --git a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/CompoundSqlToolsSettingsValues.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/CompoundSqlToolsSettingsValues.cs index d38cb574..fc70aad9 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/CompoundSqlToolsSettingsValues.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/CompoundSqlToolsSettingsValues.cs @@ -18,7 +18,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext /// group such as Intellisense is defined on a serialized setting it's used in the order of mssql, then sql, then /// falls back to a default value. /// - public class CompoundToolsSettingsValues: ISqlToolsSettingsValues + public class CompoundToolsSettingsValues : ISqlToolsSettingsValues { private List priorityList = new List(); private SqlToolsSettingsValues defaultValues; @@ -44,11 +44,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext /// Gets or sets the detailed IntelliSense settings /// public IntelliSenseSettings IntelliSense - { + { get { return GetSettingOrDefault((settings) => settings.IntelliSense); - } + } set { priorityList[0].IntelliSense = value; @@ -59,11 +59,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext /// Gets or sets the query execution settings /// public QueryExecutionSettings QueryExecutionSettings - { + { get { return GetSettingOrDefault((settings) => settings.QueryExecutionSettings); - } + } set { priorityList[0].QueryExecutionSettings = value; @@ -74,11 +74,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext /// Gets or sets the formatter settings /// public FormatterSettings Format - { + { get { return GetSettingOrDefault((settings) => settings.Format); - } + } set { priorityList[0].Format = value; @@ -89,15 +89,24 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext /// Gets or sets the object explorer settings /// public ObjectExplorerSettings ObjectExplorer - { + { get { return GetSettingOrDefault((settings) => settings.ObjectExplorer); - } + } set { priorityList[0].ObjectExplorer = value; } } + + /// + /// Gets or sets PII Logging setting. + /// + public bool PiiLogging + { + get => GetSettingOrDefault((settings) => settings.PiiLogging); + set => priorityList[0].PiiLogging = value; + } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/ISqlToolsSettingsValues.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/ISqlToolsSettingsValues.cs index 8cac95b1..c6629628 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/ISqlToolsSettingsValues.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/ISqlToolsSettingsValues.cs @@ -5,6 +5,8 @@ #nullable disable +using System; + namespace Microsoft.SqlTools.ServiceLayer.SqlContext { /// @@ -31,5 +33,10 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext /// Object Explorer specific settings /// ObjectExplorerSettings ObjectExplorer { get; set; } + + /// + /// PII Logging setting + /// + Boolean PiiLogging { get; set; } } } \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettingsValues.cs b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettingsValues.cs index f1879f0e..8acb87c4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettingsValues.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/SqlContext/SqlToolsSettingsValues.cs @@ -5,6 +5,7 @@ #nullable disable +using System; using Newtonsoft.Json; namespace Microsoft.SqlTools.ServiceLayer.SqlContext @@ -25,6 +26,7 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext QueryExecutionSettings = new QueryExecutionSettings(); Format = new FormatterSettings(); TableDesigner = new TableDesignerSettings(); + PiiLogging = false; } } @@ -57,5 +59,11 @@ namespace Microsoft.SqlTools.ServiceLayer.SqlContext /// [JsonProperty("tableDesigner")] public TableDesignerSettings TableDesigner { get; set; } + + /// + /// Gets or sets the setting to enable PII Logging. + /// + [JsonProperty("piiLogging")] + public Boolean PiiLogging { get; set; } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/TableDesigner/TableDesignerService.cs b/src/Microsoft.SqlTools.ServiceLayer/TableDesigner/TableDesignerService.cs index f6abde88..7f140d6d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/TableDesigner/TableDesignerService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/TableDesigner/TableDesignerService.cs @@ -1800,7 +1800,11 @@ namespace Microsoft.SqlTools.ServiceLayer.TableDesigner connectionStringBuilder.InitialCatalog = tableInfo.Database; connectionStringBuilder.ApplicationName = TableDesignerService.TableDesignerApplicationName; var connectionString = connectionStringBuilder.ToString(); - tableDesigner = new Dac.TableDesigner(connectionString, tableInfo.AccessToken, tableInfo.Schema, tableInfo.Name, tableInfo.IsNewTable); + + // Set Access Token only when authentication mode is not specified. + var accessToken = connectionStringBuilder.Authentication == SqlAuthenticationMethod.NotSpecified + ? tableInfo.AccessToken : null; + tableDesigner = new Dac.TableDesigner(connectionString, accessToken, tableInfo.Schema, tableInfo.Name, tableInfo.IsNewTable); } else { diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/ServiceLayerCommandOptions.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/ServiceLayerCommandOptions.cs index 44b965e7..a08d1c8e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Utility/ServiceLayerCommandOptions.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/ServiceLayerCommandOptions.cs @@ -12,7 +12,7 @@ using Microsoft.SqlTools.Utility; namespace Microsoft.SqlTools.ServiceLayer.Utility { - class ServiceLayerCommandOptions : CommandOptions + public class ServiceLayerCommandOptions : CommandOptions { internal const string ServiceLayerServiceName = "MicrosoftSqlToolsServiceLayer.exe"; diff --git a/src/Microsoft.SqlTools.Shared/Microsoft.SqlTools.Shared.csproj b/src/Microsoft.SqlTools.Shared/Microsoft.SqlTools.Shared.csproj new file mode 100644 index 00000000..084c4ff0 --- /dev/null +++ b/src/Microsoft.SqlTools.Shared/Microsoft.SqlTools.Shared.csproj @@ -0,0 +1,8 @@ + + + + enable + enable + + + diff --git a/src/Microsoft.SqlTools.Shared/Utility/Constants.cs b/src/Microsoft.SqlTools.Shared/Utility/Constants.cs new file mode 100644 index 00000000..d3a9d4e6 --- /dev/null +++ b/src/Microsoft.SqlTools.Shared/Utility/Constants.cs @@ -0,0 +1,18 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.Shared.Utility +{ + public class Constants + { + // Authentication Types + public const string Integrated = "Integrated"; + public const string SqlLogin = "SqlLogin"; + public const string AzureMFA = "AzureMFA"; + public const string dstsAuth = "dstsAuth"; + public const string ActiveDirectoryInteractive = "ActiveDirectoryInteractive"; + public const string ActiveDirectoryPassword = "ActiveDirectoryPassword"; + } +} diff --git a/src/Microsoft.SqlTools.Hosting/Utility/Logger.cs b/src/Microsoft.SqlTools.Shared/Utility/Logger.cs similarity index 87% rename from src/Microsoft.SqlTools.Hosting/Utility/Logger.cs rename to src/Microsoft.SqlTools.Shared/Utility/Logger.cs index 26496206..46a582ff 100644 --- a/src/Microsoft.SqlTools.Hosting/Utility/Logger.cs +++ b/src/Microsoft.SqlTools.Shared/Utility/Logger.cs @@ -3,11 +3,8 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using System; using System.Diagnostics; using System.Globalization; -using System.IO; -using System.Threading; namespace Microsoft.SqlTools.Utility { @@ -29,41 +26,45 @@ namespace Microsoft.SqlTools.Utility /// public static class Logger { - internal const SourceLevels defaultTracingLevel = SourceLevels.Critical; - internal const string defaultTraceSource = "sqltools"; + public const SourceLevels defaultTracingLevel = SourceLevels.Critical; + public const string defaultTraceSource = "sqltools"; private static SourceLevels tracingLevel = defaultTracingLevel; - private static string logFileFullPath; + private static string? logFileFullPath; + + public static TraceSource? TraceSource { get; set; } - internal static TraceSource TraceSource { get; set; } public static string LogFileFullPath { - get => logFileFullPath; + get => logFileFullPath!; private set { - //If the log file path has a directory component then ensure that the directory exists. - if (!string.IsNullOrEmpty(Path.GetDirectoryName(value)) && !Directory.Exists(Path.GetDirectoryName(value))) + if (value != null) { - Directory.CreateDirectory(Path.GetDirectoryName(value)); - } + //If the log file path has a directory component then ensure that the directory exists. + if (!string.IsNullOrEmpty(Path.GetDirectoryName(value)) && !Directory.Exists(Path.GetDirectoryName(value))) + { + Directory.CreateDirectory(Path.GetDirectoryName(value)!); + } - logFileFullPath = value; + logFileFullPath = value; + } ConfigureListener(); } } - private static SqlToolsTraceListener Listener { get; set; } + private static SqlToolsTraceListener? Listener { get; set; } private static void ConfigureLogFile(string logFilePrefix) => LogFileFullPath = GenerateLogFilePath(logFilePrefix); /// /// Calling this method will turn on inclusion CallStack in the log for all future traces /// - public static void StartCallStack() => Listener.TraceOutputOptions |= TraceOptions.Callstack; + public static void StartCallStack() => Listener!.TraceOutputOptions |= TraceOptions.Callstack; /// /// Calling this method will turn off inclusion of CallStack in the log for all future traces /// - public static void StopCallStack() => Listener.TraceOutputOptions &= ~TraceOptions.Callstack; + public static void StopCallStack() => Listener!.TraceOutputOptions &= ~TraceOptions.Callstack; /// /// Calls flush on defaultTracingLevel configured listeners. @@ -86,17 +87,19 @@ namespace Microsoft.SqlTools.Utility get => tracingLevel; set { - if(TraceSource != null) + if (TraceSource != null) { // configures the source level filter. This alone is not enough for tracing that is done via "Trace" class instead of "TraceSource" object TraceSource.Switch = new SourceSwitch(TraceSource.Name, value.ToString()); } // configure the listener level filter tracingLevel = value; - Listener.Filter = new EventTypeFilter(tracingLevel); + Listener!.Filter = new EventTypeFilter(tracingLevel); } } + public static bool IsPiiEnabled { get; set; } = false; + public static bool AutoFlush { get; set; } = false; /// @@ -115,11 +118,13 @@ namespace Microsoft.SqlTools.Utility /// public static void Initialize( SourceLevels tracingLevel = defaultTracingLevel, - string logFilePath = null, + bool piiEnabled = false, + string? logFilePath = null, string traceSource = defaultTraceSource, bool autoFlush = false) { Logger.tracingLevel = tracingLevel; + Logger.IsPiiEnabled = piiEnabled; Logger.AutoFlush = autoFlush; TraceSource = new TraceSource(traceSource, Logger.tracingLevel); if (string.IsNullOrWhiteSpace(logFilePath)) @@ -146,11 +151,12 @@ namespace Microsoft.SqlTools.Utility /// /// Optional. Specifies whether the log is flushed after every message /// - public static void Initialize(string tracingLevel, string logFilePath = null, string traceSource = defaultTraceSource, bool autoFlush = false) + public static void Initialize(string tracingLevel, bool piiEnabled, string? logFilePath = null, string traceSource = defaultTraceSource, bool autoFlush = false) { Initialize(Enum.TryParse(tracingLevel, out SourceLevels sourceTracingLevel) ? sourceTracingLevel : defaultTracingLevel + , piiEnabled , logFilePath , traceSource , autoFlush); @@ -169,7 +175,7 @@ namespace Microsoft.SqlTools.Utility throw new ArgumentOutOfRangeException(nameof(logFilePrefix), $"LogfilePath cannot be configured if argument {nameof(logFilePrefix)} has not been set"); } // Create the log directory - string logDir = Path.GetDirectoryName(logFilePrefix); + string? logDir = Path.GetDirectoryName(logFilePrefix); if (!string.IsNullOrWhiteSpace(logDir)) { if (!Directory.Exists(logDir)) @@ -199,7 +205,7 @@ namespace Microsoft.SqlTools.Utility } string fileName; - try + try { var now = DateTime.Now; fileName = string.Format(CultureInfo.InvariantCulture, @@ -239,6 +245,16 @@ namespace Microsoft.SqlTools.Utility /// The message text to be written. public static void Write(TraceEventType eventType, string logMessage) => Write(eventType, LogEvent.Default, logMessage); + /// + /// Writes a PII message to the log file with the Verbose event level when PII flag is enabled. + /// + /// The message text to be written. + public static void Pii(string logMessage) { + if (IsPiiEnabled) { + Write(TraceEventType.Verbose, logMessage); + } + } + /// /// Writes a message to the log file with the Verbose event level /// @@ -386,9 +402,9 @@ namespace Microsoft.SqlTools.Utility ); } #region forward actual write/close/flush/dispose calls to the underlying listener. - public override void Write(string message) => Listener.Write(message); + public override void Write(string? message) => Listener.Write(message); - public override void WriteLine(string message) => Listener.WriteLine(message); + public override void WriteLine(string? message) => Listener.WriteLine(message); /// /// Closes the so that it no longer @@ -430,41 +446,41 @@ namespace Microsoft.SqlTools.Utility } #endregion - public override void TraceEvent(TraceEventCache eventCache, String source, TraceEventType eventType, int id) + public override void TraceEvent(TraceEventCache? eventCache, String source, TraceEventType eventType, int id) { TraceEvent(eventCache, source, eventType, id, String.Empty); } // All other TraceEvent methods come through this one. - public override void TraceEvent(TraceEventCache eventCache, String source, TraceEventType eventType, int id, string message) + public override void TraceEvent(TraceEventCache? eventCache, String source, TraceEventType eventType, int id, string? message) { if (Filter != null && !Filter.ShouldTrace(eventCache, source, eventType, id, message, null, null, null)) { return; } - WriteHeader(eventCache, source, eventType, id); - WriteLine(message); - WriteFooter(eventCache); + WriteHeader(eventCache!, source, eventType, id); + WriteLine(message!); + WriteFooter(eventCache!); } - public override void TraceEvent(TraceEventCache eventCache, String source, TraceEventType eventType, int id, string format, params object[] args) + public override void TraceEvent(TraceEventCache? eventCache, String source, TraceEventType eventType, int id, string? format, params object?[]? args) { if (Filter != null && !Filter.ShouldTrace(eventCache, source, eventType, id, format, args, null, null)) { return; } - WriteHeader(eventCache, source, eventType, id); + WriteHeader(eventCache!, source, eventType, id); if (args != null) { - WriteLine(String.Format(CultureInfo.InvariantCulture, format, args)); + WriteLine(String.Format(CultureInfo.InvariantCulture, format!, args)); } else { - WriteLine(format); + WriteLine(format!); } - WriteFooter(eventCache); + WriteFooter(eventCache!); } private void WriteHeader(TraceEventCache eventCache, String source, TraceEventType eventType, int id) diff --git a/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/DataSourceFactoryTests.cs b/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/DataSourceFactoryTests.cs index 07afb461..c9e7caf4 100644 --- a/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/DataSourceFactoryTests.cs +++ b/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/DataSourceFactoryTests.cs @@ -12,14 +12,15 @@ using Microsoft.Kusto.ServiceLayer.DataSource; using Microsoft.Kusto.ServiceLayer.DataSource.Intellisense; using Microsoft.Kusto.ServiceLayer.LanguageServices; using Microsoft.Kusto.ServiceLayer.Workspace.Contracts; +using static Microsoft.SqlTools.Shared.Utility.Constants; using NUnit.Framework; namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource { public class DataSourceFactoryTests { - [TestCase(typeof(ArgumentException), "ConnectionString", "", "AzureMFA")] - [TestCase(typeof(ArgumentException), "ConnectionString", "", "dstsAuth")] + [TestCase(typeof(ArgumentException), "ConnectionString", "", AzureMFA)] + [TestCase(typeof(ArgumentException), "ConnectionString", "", dstsAuth)] public void Create_Throws_Exceptions_For_InvalidAzureAccountToken(Type exceptionType, string connectionString, string azureAccountToken, string authType) { Program.ServiceName = "Kusto"; diff --git a/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/KustoClientTests.cs b/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/KustoClientTests.cs index 1a5f7aa2..99b8b1e8 100644 --- a/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/KustoClientTests.cs +++ b/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/KustoClientTests.cs @@ -8,14 +8,15 @@ using System; using Microsoft.Kusto.ServiceLayer.DataSource.Contracts; using Microsoft.Kusto.ServiceLayer.DataSource.Kusto; +using static Microsoft.SqlTools.Shared.Utility.Constants; using NUnit.Framework; namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource { public class KustoClientTests { - [TestCase("dstsAuth")] - [TestCase("AzureMFA")] + [TestCase(dstsAuth)] + [TestCase(AzureMFA)] public void Constructor_Throws_ArgumentException_For_MissingToken(string authType) { var connectionDetails = new DataSourceConnectionDetails @@ -37,7 +38,7 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource UserToken = "UserToken", ServerName = clusterName, DatabaseName = "", - AuthenticationType = "AzureMFA" + AuthenticationType = AzureMFA }; var client = new KustoClient(connectionDetails, "ownerUri"); @@ -46,10 +47,10 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource Assert.AreEqual("NetDefaultDB", client.DatabaseName); } - [TestCase("dstsAuth")] - [TestCase("AzureMFA")] + [TestCase(dstsAuth)] + [TestCase(AzureMFA)] [TestCase("NoAuth")] - [TestCase("SqlLogin")] + [TestCase(SqlLogin)] public void Constructor_Creates_Client_With_Valid_AuthenticationType(string authenticationType) { string clusterName = "https://fake.url.com"; @@ -59,8 +60,8 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource ServerName = clusterName, DatabaseName = "FakeDatabaseName", AuthenticationType = authenticationType, - UserName = authenticationType == "SqlLogin" ? "username": null, - Password = authenticationType == "SqlLogin" ? "password": null + UserName = authenticationType == SqlLogin ? "username": null, + Password = authenticationType == SqlLogin ? "password": null }; var client = new KustoClient(connectionDetails, "ownerUri"); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj index e21705fa..e62da1a9 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj +++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj @@ -7,8 +7,8 @@ false - - + + @@ -25,5 +25,6 @@ + \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestLogger.cs b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestLogger.cs index b4ffd905..6625fb8a 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestLogger.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestLogger.cs @@ -32,6 +32,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common public SourceLevels TracingLevel { get; set; } = SourceLevels.Critical; public bool DoNotUseTraceSource { get; set; } = false; + public bool IsPiiEnabled { get; set; } = false; + public bool AutoFlush { get; set; } = false; private List pendingVerifications; @@ -43,7 +45,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common public string LogFileName { get => logFileName ?? Logger.LogFileFullPath; set => logFileName = value; } public void Initialize() => - Logger.Initialize(TracingLevel, LogFilePath, TraceSource, AutoFlush); // initialize the logger + Logger.Initialize(TracingLevel, IsPiiEnabled, LogFilePath, TraceSource, AutoFlush); // initialize the logger public string LogContents { get diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs index 6c2ba0b4..d01ecf9d 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestServiceProvider.cs @@ -195,7 +195,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common // Initialize the ServiceHost, using a MemoryStream for the output stream so that we don't fill up the logs // with a bunch of outgoing messages (which aren't used for anything during tests) - ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext, null, new MemoryStream()); + ServiceHost serviceHost = HostLoader.CreateAndStartServiceHost(sqlToolsContext, null, null, new MemoryStream()); // Set up our logger to write to Console for tests to help debug issues Logger.Initialize(autoFlush: true); diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs index 10c87f29..c5669d60 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs @@ -12,6 +12,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.Test.Common; using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility; +using static Microsoft.SqlTools.Shared.Utility.Constants; using Moq; using Moq.Protected; using NUnit.Framework; @@ -479,6 +480,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection new object[] {null, "12345678"}, new object[] {"", "12345678"}, }; + /// /// Verify that when using integrated authentication, the username and/or password can be empty. /// @@ -497,13 +499,67 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection DatabaseName = "test", UserName = userName, Password = password, - AuthenticationType = "Integrated" + AuthenticationType = Integrated } }); Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "check that the connection was successful"); } + /// + /// Verify that username is required when using Active Directory Interactive authentication. + /// Both AzureMFA and ActiveDirectoryInteractive should work same way, when SqlAuthenticationProvider is enabled. + /// + [TestCase(null, AzureMFA)] + [TestCase(null, ActiveDirectoryInteractive)] + public async Task ConnectingWitNoUsernameFailsForAADInteractiveAuth(string userName, string authMode) + { + // This is an exception scenario test, therefore using ConnectionService instance instead of TestConnectionService directly. + ConnectionService.Instance.EnableSqlAuthenticationProvider = true; + // Connect + var connectionResult = await + TestObjects.GetTestConnectionService() + .Connect(new ConnectParams() + { + OwnerUri = "file:///my/test/file.sql", + Connection = new ConnectionDetails() + { + ServerName = "my-server", + DatabaseName = "test", + UserName = userName, + AuthenticationType = authMode + } + }); + + Assert.That(connectionResult.ConnectionId, Is.Null.Or.Empty, "Connection should not be successful"); + } + + /// + /// Verify that password is ignored when using Active Directory Interactive authentication. + /// + [TestCase("user", "anything", AzureMFA)] + [TestCase("user", "anything", ActiveDirectoryInteractive)] + public async Task ConnectingWitPasswordIsIgnoredForAADInteractiveAuth(string username, string password, string authMode) + { + // Connect + var connectionResult = await + TestObjects.GetTestConnectionService() + .Connect(new ConnectParams() + { + OwnerUri = "file:///my/test/file.sql", + Connection = new ConnectionDetails() + { + ServerName = "my-server", + DatabaseName = "test", + UserName = username, + Password = password, + AuthenticationType = authMode + } + }); + + Assert.That(connectionResult.ConnectionId, Is.Not.Null.Or.Empty, "Connection should be successful"); + } + /// /// Verify that when connecting with a null parameters object, an error is thrown. /// @@ -1773,16 +1829,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection details.AzureAccountToken = azureAccountToken; details.UserName = ""; details.Password = ""; - details.AuthenticationType = "AzureMFA"; + details.AuthenticationType = AzureMFA; - // If I open a connection using connection details that include an account token + // Open a connection using connection details that include an account token await connectionService.Connect(new ConnectParams { OwnerUri = "testURI", Connection = details }); - // Then the connection factory got called with details including an account token + // Validate that the connection factory gets called with details NOT including an account token mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny(), It.Is(accountToken => accountToken == azureAccountToken)), Times.Once()); } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj index 0f947ee0..9719192d 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj @@ -34,6 +34,7 @@ + diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ServiceHost/LoggerTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ServiceHost/LoggerTests.cs index 672e474d..98c2bd0f 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ServiceHost/LoggerTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ServiceHost/LoggerTests.cs @@ -109,7 +109,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ServiceHost string tracingLevel = null; SourceLevels expectedTracingLevel = Logger.defaultTracingLevel; string expectedTraceSource = Logger.defaultTraceSource; - Logger.Initialize(tracingLevel); + Logger.Initialize(tracingLevel, false); bool isLogFileExpectedToExist = false; TestLogger.VerifyInitialization(expectedTracingLevel, expectedTraceSource, Logger.LogFileFullPath, isLogFileExpectedToExist, testNo++); TestLogger.Cleanup(Logger.LogFileFullPath); @@ -120,7 +120,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ServiceHost string tracingLevel = null; SourceLevels expectedTracingLevel = Logger.defaultTracingLevel; string expectedTraceSource = Logger.defaultTraceSource; - Logger.Initialize(tracingLevel); + Logger.Initialize(tracingLevel, false); bool isLogFileExpectedToExist = false; TestLogger.VerifyInitialization(expectedTracingLevel, expectedTraceSource, Logger.LogFileFullPath, isLogFileExpectedToExist, testNo++); TestLogger.Cleanup(Logger.LogFileFullPath); @@ -131,7 +131,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ServiceHost string tracingLevel = "invalid"; SourceLevels expectedTracingLevel = Logger.defaultTracingLevel; string expectedTraceSource = Logger.defaultTraceSource; - Logger.Initialize(tracingLevel); + Logger.Initialize(tracingLevel, false); bool isLogFileExpectedToExist = false; TestLogger.VerifyInitialization(expectedTracingLevel, expectedTraceSource, Logger.LogFileFullPath, isLogFileExpectedToExist, testNo++); TestLogger.Cleanup(Logger.LogFileFullPath);