From 8a898837dcdfa6da42eac88616782bd3069cf8c2 Mon Sep 17 00:00:00 2001 From: Cheena Malhotra <13396919+cheenamalhotra@users.noreply.github.com> Date: Mon, 10 Jul 2023 15:02:52 -0700 Subject: [PATCH] Fix app name and connection string defaults to match extension defaults (#2132) --- .../Hosting/Protocol/Constants.cs | 1 - .../Connection/ConnectionService.cs | 93 +++++++++++-------- .../Connection/ConnectionServiceTests.cs | 10 +- 3 files changed, 57 insertions(+), 47 deletions(-) diff --git a/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Constants.cs b/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Constants.cs index f2b7aabb..44d186fd 100644 --- a/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Constants.cs +++ b/src/Microsoft.SqlTools.Hosting/Hosting/Protocol/Constants.cs @@ -13,7 +13,6 @@ namespace Microsoft.SqlTools.Hosting.Protocol public const string ContentLengthFormatString = "Content-Length: {0}\r\n\r\n"; public static readonly JsonSerializerSettings JsonSerializerSettings; - public static readonly string DefaultApplicationName = "azdata"; public static readonly string SqlLoginAuthenticationType = "SqlLogin"; // Feature names diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 7e7f049d..f7f5d058 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -63,6 +63,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// public static ConnectionService Instance => instance.Value; + private static readonly SqlConnectionStringBuilder defaultBuilder = new SqlConnectionStringBuilder(); + /// /// IV and Key as received from Encryption Key Notification event. /// @@ -165,6 +167,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection return response.Token; } + /// + /// Default Application name as received in service startup + /// + public static string ApplicationName { get; set; } /// /// Enables configured 'Sql Authentication Provider' for 'Active Directory Interactive' authentication mode to be used @@ -482,9 +488,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // Connection Service will not set custom application name if connection pooling is enabled on service. if (!EnableConnectionPooling && !string.IsNullOrWhiteSpace(applicationName) && !string.IsNullOrWhiteSpace(featureName) && !applicationName.EndsWith(featureName)) { - int azdataStartIndex = applicationName.IndexOf(Constants.DefaultApplicationName); - string originalAppName = azdataStartIndex != -1 - ? applicationName.Substring(0, azdataStartIndex + Constants.DefaultApplicationName.Length) + int appNameStartIndex = applicationName.IndexOf(ApplicationName); + string originalAppName = appNameStartIndex != -1 + ? applicationName.Substring(0, appNameStartIndex + ApplicationName.Length) : applicationName; // Reset to default if azdata not found. appNameWithFeature = $"{originalAppName}-{featureName}"; } @@ -1096,6 +1102,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection if (commandOptions != null) { + ApplicationName = commandOptions.ApplicationName switch + { + "azuredatastudio" => "azdata", + "code" => "vscode-mssql", + _ => "sqltools" // fallback + }; + if (commandOptions.EnableSqlAuthenticationProvider) { // Register SqlAuthenticationProvider with SqlConnection for AAD Interactive (MFA) authentication. @@ -1636,7 +1649,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // default connection string application name to always be included unless set to false if (connStringBuilder.ApplicationName == null && (!connStringParams.IncludeApplicationName.HasValue || connStringParams.IncludeApplicationName.Value == true)) { - connStringBuilder.ApplicationName = Constants.DefaultApplicationName; + connStringBuilder.ApplicationName = ApplicationName; } connectionString = connStringBuilder.ConnectionString; @@ -1665,41 +1678,43 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection public ConnectionDetails ParseConnectionString(string connectionString) { SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connectionString); + + // Set defaults as per MSSQL connection property defaults, not SqlClient's Connection string buider defaults ConnectionDetails details = new ConnectionDetails() { - ApplicationIntent = builder.ApplicationIntent.ToString(), - ApplicationName = builder.ApplicationName, - AttachDbFilename = builder.AttachDBFilename, + ApplicationIntent = defaultBuilder.ApplicationIntent != builder.ApplicationIntent ? builder.ApplicationIntent.ToString() : null, + ApplicationName = defaultBuilder.ApplicationName != builder.ApplicationName ? builder.ApplicationName : ApplicationName, + AttachDbFilename = defaultBuilder.AttachDBFilename != builder.AttachDBFilename ? builder.AttachDBFilename.ToString() : null, AuthenticationType = builder.IntegratedSecurity ? "Integrated" : (builder.Authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive - ? "ActiveDirectoryInteractive" : "SqlLogin"), - ConnectRetryCount = builder.ConnectRetryCount, - ConnectRetryInterval = builder.ConnectRetryInterval, - ConnectTimeout = builder.ConnectTimeout, - CommandTimeout = builder.CommandTimeout, - CurrentLanguage = builder.CurrentLanguage, - DatabaseName = builder.InitialCatalog, - ColumnEncryptionSetting = builder.ColumnEncryptionSetting.ToString(), - EnclaveAttestationProtocol = builder.AttestationProtocol == SqlConnectionAttestationProtocol.NotSpecified ? null : builder.AttestationProtocol.ToString(), - EnclaveAttestationUrl = builder.EnclaveAttestationUrl, - Encrypt = builder.Encrypt.ToString(), - FailoverPartner = builder.FailoverPartner, - HostNameInCertificate = builder.HostNameInCertificate, - LoadBalanceTimeout = builder.LoadBalanceTimeout, - MaxPoolSize = builder.MaxPoolSize, - MinPoolSize = builder.MinPoolSize, - MultipleActiveResultSets = builder.MultipleActiveResultSets, - MultiSubnetFailover = builder.MultiSubnetFailover, - PacketSize = builder.PacketSize, - Password = !builder.IntegratedSecurity ? builder.Password : string.Empty, - PersistSecurityInfo = builder.PersistSecurityInfo, - Pooling = builder.Pooling, - Replication = builder.Replication, - ServerName = builder.DataSource, - TrustServerCertificate = builder.TrustServerCertificate, - TypeSystemVersion = builder.TypeSystemVersion, - UserName = builder.UserID, - WorkstationId = builder.WorkstationID, + ? "AzureMFA" : "SqlLogin"), + ConnectRetryCount = defaultBuilder.ConnectRetryCount != builder.ConnectRetryCount ? builder.ConnectRetryCount : 1, + ConnectRetryInterval = defaultBuilder.ConnectRetryInterval != builder.ConnectRetryInterval ? builder.ConnectRetryInterval : 10, + ConnectTimeout = defaultBuilder.ConnectTimeout != builder.ConnectTimeout ? builder.ConnectTimeout : 30, + CommandTimeout = defaultBuilder.CommandTimeout != builder.CommandTimeout ? builder.CommandTimeout : 30, + CurrentLanguage = defaultBuilder.CurrentLanguage != builder.CurrentLanguage ? builder.CurrentLanguage : null, + DatabaseName = defaultBuilder.InitialCatalog != builder.InitialCatalog ? builder.InitialCatalog : null, + ColumnEncryptionSetting = defaultBuilder.ColumnEncryptionSetting != builder.ColumnEncryptionSetting ? builder.ColumnEncryptionSetting.ToString() : null, + EnclaveAttestationProtocol = defaultBuilder.AttestationProtocol != builder.AttestationProtocol ? builder.AttestationProtocol.ToString() : null, + EnclaveAttestationUrl = defaultBuilder.EnclaveAttestationUrl != builder.EnclaveAttestationUrl ? builder.EnclaveAttestationUrl : null, + Encrypt = defaultBuilder.Encrypt != builder.Encrypt ? builder.Encrypt.ToString() : Boolean.TrueString.ToLower(CultureInfo.InvariantCulture), + FailoverPartner = defaultBuilder.FailoverPartner != builder.FailoverPartner ? builder.FailoverPartner : null, + HostNameInCertificate = defaultBuilder.HostNameInCertificate != builder.HostNameInCertificate ? builder.HostNameInCertificate : null, + LoadBalanceTimeout = defaultBuilder.LoadBalanceTimeout != builder.LoadBalanceTimeout ? builder.LoadBalanceTimeout : null, + MaxPoolSize = defaultBuilder.MaxPoolSize != builder.MaxPoolSize ? builder.MaxPoolSize : null, + MinPoolSize = defaultBuilder.MinPoolSize != builder.MinPoolSize ? builder.MinPoolSize : null, + MultipleActiveResultSets = defaultBuilder.MultipleActiveResultSets != builder.MultipleActiveResultSets ? builder.MultipleActiveResultSets : null, + MultiSubnetFailover = defaultBuilder.MultiSubnetFailover != builder.MultiSubnetFailover ? builder.MultiSubnetFailover : null, + PacketSize = defaultBuilder.PacketSize != builder.PacketSize ? builder.PacketSize : null, + Password = !builder.IntegratedSecurity ? builder.Password : null, + PersistSecurityInfo = defaultBuilder.PersistSecurityInfo != builder.PersistSecurityInfo ? builder.PersistSecurityInfo : null, + Pooling = defaultBuilder.Pooling != builder.Pooling ? builder.Pooling : null, + Replication = defaultBuilder.Replication != builder.Replication ? builder.Replication : null, + ServerName = defaultBuilder.DataSource != builder.DataSource ? builder.DataSource : null, + TrustServerCertificate = defaultBuilder.TrustServerCertificate != builder.TrustServerCertificate ? builder.TrustServerCertificate : false, + TypeSystemVersion = defaultBuilder.TypeSystemVersion != builder.TypeSystemVersion ? builder.TypeSystemVersion : null, + UserName = defaultBuilder.UserID != builder.UserID ? builder.UserID : null, + WorkstationId = defaultBuilder.WorkstationID != builder.WorkstationID ? builder.WorkstationID : null }; return details; @@ -1853,8 +1868,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection try { // capture original values - int? originalTimeout = connInfo.ConnectionDetails.ConnectTimeout; - int? originalCommandTimeout = connInfo.ConnectionDetails.CommandTimeout; bool? originalPersistSecurityInfo = connInfo.ConnectionDetails.PersistSecurityInfo; bool? originalPooling = connInfo.ConnectionDetails.Pooling; @@ -1862,8 +1875,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection bool shouldForceDisablePooling = !EnableConnectionPooling && featureName != Constants.LanguageServiceFeature; // increase the connection and command timeout to at least 30 seconds and and build connection string - connInfo.ConnectionDetails.ConnectTimeout = Math.Max(30, originalTimeout ?? 0); - connInfo.ConnectionDetails.CommandTimeout = Math.Max(30, originalCommandTimeout ?? 0); + connInfo.ConnectionDetails.ConnectTimeout = Math.Max(30, connInfo.ConnectionDetails.ConnectTimeout ?? 0); + connInfo.ConnectionDetails.CommandTimeout = Math.Max(30, connInfo.ConnectionDetails.CommandTimeout ?? 0); // enable PersistSecurityInfo to handle issues in SMO where the connection context is lost in reconnections connInfo.ConnectionDetails.PersistSecurityInfo = true; @@ -1878,8 +1891,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection string connectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails, shouldForceDisablePooling); // restore original values - connInfo.ConnectionDetails.ConnectTimeout = originalTimeout; - connInfo.ConnectionDetails.CommandTimeout = originalCommandTimeout; connInfo.ConnectionDetails.PersistSecurityInfo = originalPersistSecurityInfo; connInfo.ConnectionDetails.Pooling = originalPooling; diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs index 1b76c0b1..89af4576 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs @@ -1807,9 +1807,9 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.That(details.DatabaseName, Is.EqualTo("{databasename}"), "Unexpected database name"); Assert.That(details.UserName, Is.EqualTo("{your_username}"), "Unexpected username"); Assert.That(details.Password, Is.EqualTo("{your_password}"), "Unexpected password"); - Assert.That(details.PersistSecurityInfo, Is.False, "Unexpected Persist Security Info"); - Assert.That(details.MultipleActiveResultSets, Is.False, "Unexpected Multiple Active Result Sets value"); - Assert.That(details.Encrypt, Is.EqualTo("True"), "Unexpected Encrypt value"); + Assert.That(details.PersistSecurityInfo, Is.Null, "Unexpected Persist Security Info"); + Assert.That(details.MultipleActiveResultSets, Is.Null, "Unexpected Multiple Active Result Sets value"); + Assert.That(details.Encrypt, Is.EqualTo("true"), "Unexpected Encrypt value"); Assert.That(details.TrustServerCertificate, Is.False, "Unexpected database name value"); Assert.That(details.HostNameInCertificate, Is.EqualTo("{servername}"), "Unexpected Host Name in Certificate value"); Assert.That(details.ConnectTimeout, Is.EqualTo(30), "Unexpected Connect Timeout value"); @@ -1832,8 +1832,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.That(details.DatabaseName, Is.EqualTo("{databasename}"), "Unexpected database name"); Assert.That(details.UserName, Is.EqualTo("{your_username}"), "Unexpected username"); Assert.That(details.Password, Is.EqualTo("{your_password}"), "Unexpected password"); - Assert.That(details.PersistSecurityInfo, Is.False, "Unexpected Persist Security Info"); - Assert.That(details.MultipleActiveResultSets, Is.False, "Unexpected Multiple Active Result Sets value"); + Assert.That(details.PersistSecurityInfo, Is.Null, "Unexpected Persist Security Info"); + Assert.That(details.MultipleActiveResultSets, Is.Null, "Unexpected Multiple Active Result Sets value"); Assert.That(details.Encrypt, Is.EqualTo(SqlConnectionEncryptOption.Strict.ToString()), "Unexpected Encrypt value"); Assert.That(details.TrustServerCertificate, Is.False, "Unexpected database name value"); Assert.That(details.HostNameInCertificate, Is.EqualTo("{servername}"), "Unexpected Host Name in Certificate value");