diff --git a/Packages.props b/Packages.props index baa6bf15..debbf14f 100644 --- a/Packages.props +++ b/Packages.props @@ -6,8 +6,8 @@ - - + + @@ -19,14 +19,14 @@ - - - - + + + + - + @@ -45,4 +45,4 @@ - \ No newline at end of file + diff --git a/bin/nuget/Microsoft.SqlServer.Migration.Assessment.1.0.20220816.40.nupkg b/bin/nuget/Microsoft.SqlServer.Migration.Assessment.1.0.20221014.16.nupkg similarity index 71% rename from bin/nuget/Microsoft.SqlServer.Migration.Assessment.1.0.20220816.40.nupkg rename to bin/nuget/Microsoft.SqlServer.Migration.Assessment.1.0.20221014.16.nupkg index 3adccc1c..12cce49b 100644 Binary files a/bin/nuget/Microsoft.SqlServer.Migration.Assessment.1.0.20220816.40.nupkg and b/bin/nuget/Microsoft.SqlServer.Migration.Assessment.1.0.20221014.16.nupkg differ diff --git a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs index ede01b0b..cdd9b04a 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs +++ b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs @@ -91,14 +91,6 @@ namespace Microsoft.Kusto.ServiceLayer.Connection GroupName = "Initialization" }, new ConnectionOption - { - Name = "asynchronousProcessing", - DisplayName = "Asynchronous processing enabled", - Description = "When true, enables usage of the Asynchronous functionality in the .Net Framework Data Provider", - ValueType = ConnectionOption.ValueTypeBoolean, - GroupName = "Initialization" - }, - new ConnectionOption { Name = "connectTimeout", DisplayName = "Connect timeout", diff --git a/src/Microsoft.Kusto.ServiceLayer/Microsoft.Kusto.ServiceLayer.csproj b/src/Microsoft.Kusto.ServiceLayer/Microsoft.Kusto.ServiceLayer.csproj index e917d052..694c68e5 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Microsoft.Kusto.ServiceLayer.csproj +++ b/src/Microsoft.Kusto.ServiceLayer/Microsoft.Kusto.ServiceLayer.csproj @@ -31,8 +31,9 @@ - + + diff --git a/src/Microsoft.SqlTools.Hosting/Utility/GeneralRequestDetails.cs b/src/Microsoft.SqlTools.Hosting/Utility/GeneralRequestDetails.cs index 0611b1fc..4f39e6fe 100644 --- a/src/Microsoft.SqlTools.Hosting/Utility/GeneralRequestDetails.cs +++ b/src/Microsoft.SqlTools.Hosting/Utility/GeneralRequestDetails.cs @@ -86,7 +86,7 @@ namespace Microsoft.SqlTools.Utility enumValue = Enum.Parse(t, value); return true; } - catch(Exception) + catch (Exception) { enumValue = default(T); return false; diff --git a/src/Microsoft.SqlTools.ManagedBatchParser/BatchParser/ExecutionEngineCode/BatchParserSqlCmd.cs b/src/Microsoft.SqlTools.ManagedBatchParser/BatchParser/ExecutionEngineCode/BatchParserSqlCmd.cs index c3cb2897..a07633bd 100644 --- a/src/Microsoft.SqlTools.ManagedBatchParser/BatchParser/ExecutionEngineCode/BatchParserSqlCmd.cs +++ b/src/Microsoft.SqlTools.ManagedBatchParser/BatchParser/ExecutionEngineCode/BatchParserSqlCmd.cs @@ -155,12 +155,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode } } - if (fullFileName == null) - { - fullFileName = Path.GetFullPath(fileName); - } - - return fullFileName; + return fullFileName ?? Path.GetFullPath(fileName); } catch (ArgumentException) { diff --git a/src/Microsoft.SqlTools.ManagedBatchParser/BatchParser/ExecutionEngineCode/ScriptExecutionArgs.cs b/src/Microsoft.SqlTools.ManagedBatchParser/BatchParser/ExecutionEngineCode/ScriptExecutionArgs.cs index 340017af..9e04a2d7 100644 --- a/src/Microsoft.SqlTools.ManagedBatchParser/BatchParser/ExecutionEngineCode/ScriptExecutionArgs.cs +++ b/src/Microsoft.SqlTools.ManagedBatchParser/BatchParser/ExecutionEngineCode/ScriptExecutionArgs.cs @@ -128,15 +128,7 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode internal Dictionary Variables { - get - { - if (cmdVariables == null) - { - cmdVariables = new Dictionary(StringComparer.CurrentCultureIgnoreCase); - } - - return cmdVariables; - } + get => cmdVariables ??= new Dictionary(StringComparer.CurrentCultureIgnoreCase); } #endregion } diff --git a/src/Microsoft.SqlTools.ManagedBatchParser/BatchParser/Parser.cs b/src/Microsoft.SqlTools.ManagedBatchParser/BatchParser/Parser.cs index 6d2582a7..18b11e12 100644 --- a/src/Microsoft.SqlTools.ManagedBatchParser/BatchParser/Parser.cs +++ b/src/Microsoft.SqlTools.ManagedBatchParser/BatchParser/Parser.cs @@ -608,18 +608,10 @@ namespace Microsoft.SqlTools.ServiceLayer.BatchParser } internal void RaiseError(ErrorCode errorCode, string message = null) - { - RaiseError(errorCode, LookaheadToken, message); - } + => RaiseError(errorCode, LookaheadToken, message); internal static void RaiseError(ErrorCode errorCode, Token token, string message = null) - { - if (message == null) - { - message = string.Format(CultureInfo.CurrentCulture, SR.BatchParser_IncorrectSyntax, token.Text); - } - throw new BatchParserException(errorCode, token, message); - } + => throw new BatchParserException(errorCode, token, message ?? string.Format(CultureInfo.CurrentCulture, SR.BatchParser_IncorrectSyntax, token.Text)); internal string ResolveVariables(Token inputToken, int offset, List variableRefs) { diff --git a/src/Microsoft.SqlTools.ManagedBatchParser/Microsoft.SqlTools.ManagedBatchParser.csproj b/src/Microsoft.SqlTools.ManagedBatchParser/Microsoft.SqlTools.ManagedBatchParser.csproj index 01dc72c3..0c87959c 100644 --- a/src/Microsoft.SqlTools.ManagedBatchParser/Microsoft.SqlTools.ManagedBatchParser.csproj +++ b/src/Microsoft.SqlTools.ManagedBatchParser/Microsoft.SqlTools.ManagedBatchParser.csproj @@ -1,7 +1,8 @@  - - netstandard2.0 + + net6.0;net472 + 9.0 disable false false diff --git a/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableConnectionHelper.cs b/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableConnectionHelper.cs index fd21b6b0..752a0fba 100644 --- a/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableConnectionHelper.cs +++ b/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableConnectionHelper.cs @@ -276,10 +276,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection Debug.Assert(conn.State == ConnectionState.Open, "connection passed to ExecuteNonQuery should be open."); cmd = conn.CreateCommand(); - if (initializeCommand == null) - { - initializeCommand = SetCommandTimeout; - } + + initializeCommand ??= SetCommandTimeout; initializeCommand(cmd); cmd.CommandText = commandText; @@ -331,10 +329,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection Debug.Assert(conn.State == ConnectionState.Open, "connection passed to ExecuteScalar should be open."); cmd = conn.CreateCommand(); - if (initializeCommand == null) - { - initializeCommand = SetCommandTimeout; - } + + initializeCommand ??= SetCommandTimeout; initializeCommand(cmd); cmd.CommandText = commandText; @@ -384,11 +380,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection { cmd = conn.CreateCommand(); - if (initializeCommand == null) - { - initializeCommand = SetCommandTimeout; - } - + initializeCommand ??= SetCommandTimeout; initializeCommand(cmd); cmd.CommandText = commandText; diff --git a/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableSqlConnection.cs b/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableSqlConnection.cs index e6b81a47..5b2d3b26 100644 --- a/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableSqlConnection.cs +++ b/src/Microsoft.SqlTools.ManagedBatchParser/ReliableConnection/ReliableSqlConnection.cs @@ -478,10 +478,7 @@ SET NUMERIC_ROUNDABORT OFF;"; { // Verify whether or not the connection is valid and is open. This code may be retried therefore // it is important to ensure that a connection is re-established should it have previously failed. - if (command.Connection == null) - { - command.Connection = this; - } + command.Connection ??= this; if (command.Connection.State != ConnectionState.Open) { diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs index c742a834..a1c4a043 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs @@ -91,14 +91,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection GroupName = "Initialization" }, new ConnectionOption - { - Name = "asynchronousProcessing", - DisplayName = "Asynchronous processing enabled", - Description = "When true, enables usage of the Asynchronous functionality in the .Net Framework Data Provider", - ValueType = ConnectionOption.ValueTypeBoolean, - GroupName = "Initialization" - }, - new ConnectionOption { Name = "connectTimeout", DisplayName = "Connect timeout", @@ -152,10 +144,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { Name = "encrypt", DisplayName = "Encrypt", - Description = - "When true, SQL Server uses SSL encryption for all data sent between the client and server if the servers has a certificate installed", + Description = "When set, SQL Server uses provided setting for SSL encryption for all data sent between the client and server.", + ValueType = ConnectionOption.ValueTypeCategory, GroupName = "Security", - ValueType = ConnectionOption.ValueTypeBoolean + CategoryValues = new CategoryValue[] { + new CategoryValue { DisplayName = "Optional", Name = "Optional" }, + new CategoryValue { DisplayName = "Mandatory", Name = "Mandatory" }, + new CategoryValue { DisplayName = "Strict", Name = "Strict" } + } }, new ConnectionOption { @@ -174,6 +170,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection ValueType = ConnectionOption.ValueTypeBoolean }, new ConnectionOption + { + Name = "hostNameInCertificate", + DisplayName = "HostNameInCertificate", + Description = "Specifies host name in certificate to be used for certificate validation, when encryption is enabled.", + GroupName = "Security", + ValueType = ConnectionOption.ValueTypeString, + }, + new ConnectionOption { Name = "attachedDBFileName", DisplayName = "Attached DB file name", diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 528b6eca..0792e0b8 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -410,7 +410,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection if (response?.ErrorNumber == 40613) { counter++; - if(counter != MaxServerlessReconnectTries) { + if (counter != MaxServerlessReconnectTries) + { Logger.Information($"Database for connection {connectionInfo.OwnerUri} is paused, retrying connection. Attempt #{counter}"); } } @@ -1299,14 +1300,26 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection connectionBuilder.EnclaveAttestationUrl = connectionDetails.EnclaveAttestationUrl; } - if (connectionDetails.Encrypt.HasValue) + + if (!string.IsNullOrEmpty(connectionDetails.Encrypt)) { - connectionBuilder.Encrypt = connectionDetails.Encrypt.Value; + connectionBuilder.Encrypt = connectionDetails.Encrypt.ToLowerInvariant() switch + { + "optional" or "false" or "no" => SqlConnectionEncryptOption.Optional, + "mandatory" or "true" or "yes" => SqlConnectionEncryptOption.Mandatory, + "strict" => SqlConnectionEncryptOption.Strict, + _ => throw new ArgumentException(SR.ConnectionServiceConnStringInvalidEncryptOption(connectionDetails.Encrypt)) + }; } + if (connectionDetails.TrustServerCertificate.HasValue) { connectionBuilder.TrustServerCertificate = connectionDetails.TrustServerCertificate.Value; } + if (!string.IsNullOrEmpty(connectionDetails.HostNameInCertificate)) + { + connectionBuilder.HostNameInCertificate = connectionDetails.HostNameInCertificate; + } if (connectionDetails.PersistSecurityInfo.HasValue) { connectionBuilder.PersistSecurityInfo = connectionDetails.PersistSecurityInfo.Value; @@ -1471,8 +1484,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection ColumnEncryptionSetting = builder.ColumnEncryptionSetting.ToString(), EnclaveAttestationProtocol = builder.AttestationProtocol == SqlConnectionAttestationProtocol.NotSpecified ? null : builder.AttestationProtocol.ToString(), EnclaveAttestationUrl = builder.EnclaveAttestationUrl, - Encrypt = builder.Encrypt, + Encrypt = builder.Encrypt.ToString(), FailoverPartner = builder.FailoverPartner, + HostNameInCertificate = builder.HostNameInCertificate, LoadBalanceTimeout = builder.LoadBalanceTimeout, MaxPoolSize = builder.MaxPoolSize, MinPoolSize = builder.MinPoolSize, diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs index fac37cef..7694d16a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs @@ -147,13 +147,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts } /// - /// Gets or sets a Boolean value that indicates whether SQL Server uses SSL encryption for all data sent between the client and server if the server has a certificate installed. + /// Gets or sets a value that indicates encryption mode that SQL Server should use to perform SSL encryption for all the data sent between the client and server. Supported values are: Optional, Mandatory, Strict. /// - public bool? Encrypt + public string Encrypt { get { - return GetOptionValue("encrypt"); + return GetOptionValue("encrypt"); } set @@ -178,6 +178,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts } } + /// + /// Gets or sets a value that indicates the host name in the certificate to be used for certificate validation when encryption is enabled. + /// + public string HostNameInCertificate + { + get + { + return GetOptionValue("hostNameInCertificate"); + } + + set + { + SetOptionValue("hostNameInCertificate", value); + } + } + /// /// Gets or sets a Boolean value that indicates if security-sensitive information, such as the password, is not returned as part of the connection if the connection is open or has ever been in an open state. /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs index 33161fc0..0fd7ae83 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs @@ -27,6 +27,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts EnclaveAttestationUrl = details.EnclaveAttestationUrl, Encrypt = details.Encrypt, TrustServerCertificate = details.TrustServerCertificate, + HostNameInCertificate = details.HostNameInCertificate, PersistSecurityInfo = details.PersistSecurityInfo, ConnectTimeout = details.ConnectTimeout, ConnectRetryCount = details.ConnectRetryCount, diff --git a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.cs b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.cs index 13b1af8e..14aff68d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.cs @@ -9641,6 +9641,11 @@ namespace Microsoft.SqlTools.ServiceLayer return Keys.GetString(Keys.ConnectionServiceConnStringInvalidColumnEncryptionSetting, columnEncryptionSetting); } + public static string ConnectionServiceConnStringInvalidEncryptOption(string encrypt) + { + return Keys.GetString(Keys.ConnectionServiceConnStringInvalidEncryptOption, encrypt); + } + public static string ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(string enclaveAttestationProtocol) { return Keys.GetString(Keys.ConnectionServiceConnStringInvalidEnclaveAttestationProtocol, enclaveAttestationProtocol); @@ -10062,6 +10067,9 @@ namespace Microsoft.SqlTools.ServiceLayer public const string ConnectionServiceConnStringInvalidColumnEncryptionSetting = "ConnectionServiceConnStringInvalidColumnEncryptionSetting"; + public const string ConnectionServiceConnStringInvalidEncryptOption = "ConnectionServiceConnStringInvalidEncryptOption"; + + public const string ConnectionServiceConnStringInvalidEnclaveAttestationProtocol = "ConnectionServiceConnStringInvalidEnclaveAttestationProtocol"; diff --git a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.resx b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.resx index c2dbeef8..8669e7ed 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.resx +++ b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.resx @@ -144,6 +144,11 @@ Invalid value '{0}' for ComlumEncryption. Valid values are 'Enabled' and 'Disabled'. . Parameters: 0 - columnEncryptionSetting (string) + + + Invalid value '{0}' for Encrypt. Valid values are 'Optional', 'Mandatory', 'Strict', 'True', 'False', 'Yes' and 'No'. + . + Parameters: 0 - encrypt (string) Invalid value '{0}' for EnclaveAttestationProtocol. Valid values are 'AAS' and 'HGS'. diff --git a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings index 84a8abd3..7a52cb53 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings +++ b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.strings @@ -35,6 +35,8 @@ ConnectionServiceConnStringInvalidAuthType(string authType) = Invalid value '{0} ConnectionServiceConnStringInvalidColumnEncryptionSetting(string columnEncryptionSetting) = Invalid value '{0}' for ComlumEncryption. Valid values are 'Enabled' and 'Disabled'. +ConnectionServiceConnStringInvalidEncryptOption(string encrypt) = Invalid value '{0}' for Encrypt. Valid values are 'Optional', 'Mandatory', 'Strict', 'True', 'False', 'Yes' and 'No'. + ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(string enclaveAttestationProtocol) = Invalid value '{0}' for EnclaveAttestationProtocol. Valid values are 'AAS' and 'HGS'. ConnectionServiceConnStringInvalidAlwaysEncryptedOptionCombination = The Attestation Protocol and Enclave Attestation URL requires Always Encrypted to be set to Enabled. diff --git a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.xlf b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.xlf index 9ef56950..3de1029f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.xlf +++ b/src/Microsoft.SqlTools.ServiceLayer/Localization/sr.xlf @@ -2034,6 +2034,12 @@ Invalid value '{0}' for ComlumEncryption. Valid values are 'Enabled' and 'Disabled'. . Parameters: 0 - columnEncryptionSetting (string) + + + Invalid value '{0}' for Encrypt. Valid values are 'Optional', 'Mandatory', 'Strict', 'True', 'False', 'Yes' and 'No'. + Invalid value '{0}' for Encrypt. Valid values are 'Optional', 'Mandatory', 'Strict', 'True', 'False', 'Yes' and 'No'. + . + Parameters: 0 - encrypt (string) Invalid value '{0}' for EnclaveAttestationProtocol. Valid values are 'AAS' and 'HGS'. diff --git a/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj index d5081c62..291a05d9 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj +++ b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj @@ -49,11 +49,12 @@ - + + - + diff --git a/test/Microsoft.Kusto.ServiceLayer.UnitTests/Connection/ConnectionProviderOptionsHelperTests.cs b/test/Microsoft.Kusto.ServiceLayer.UnitTests/Connection/ConnectionProviderOptionsHelperTests.cs index d3219753..c52d06cd 100644 --- a/test/Microsoft.Kusto.ServiceLayer.UnitTests/Connection/ConnectionProviderOptionsHelperTests.cs +++ b/test/Microsoft.Kusto.ServiceLayer.UnitTests/Connection/ConnectionProviderOptionsHelperTests.cs @@ -11,10 +11,10 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.Connection public class ConnectionProviderOptionsHelperTests { [Test] - public void BuildConnectionProviderOptions_Returns_31_Options() + public void BuildConnectionProviderOptions_Returns_30_Options() { var providerOptions = ConnectionProviderOptionsHelper.BuildConnectionProviderOptions(); - Assert.AreEqual(31, providerOptions.Options.Length); + Assert.AreEqual(30, providerOptions.Options.Length); } } } \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs index e5a069bf..72653ac1 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs @@ -242,8 +242,19 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection csb.Password = connectParams.Connection.Password; } csb.ConnectTimeout = connectParams.Connection.ConnectTimeout.HasValue ? connectParams.Connection.ConnectTimeout.Value: 30; - csb.Encrypt = connectParams.Connection.Encrypt.HasValue ? connectParams.Connection.Encrypt.Value : false; + + csb.Encrypt = connectParams.Connection.Encrypt?.ToLowerInvariant() switch { + "optional" or "false" or "no" => SqlConnectionEncryptOption.Optional, + "mandatory" or "true" or "yes" => SqlConnectionEncryptOption.Mandatory, + "strict" => SqlConnectionEncryptOption.Strict, + _ => default + }; + csb.TrustServerCertificate = connectParams.Connection.TrustServerCertificate.HasValue ? connectParams.Connection.TrustServerCertificate.Value : false; + if (!string.IsNullOrEmpty(connectParams.Connection.HostNameInCertificate)) + { + csb.HostNameInCertificate = connectParams.Connection.HostNameInCertificate; + } return csb; } diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DacFx/DacFxServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DacFx/DacFxServiceTests.cs index d3791d02..cf6255a4 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DacFx/DacFxServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DacFx/DacFxServiceTests.cs @@ -852,7 +852,7 @@ Streaming query statement contains a reference to missing output stream 'Missing Assert.That(options.ObjectTypesDictionary, Is.Not.Null, "Object types dictionary is empty"); // Verify that the objects dictionary has all the item from Enum - Assert.That(options.ObjectTypesDictionary.Count, Is.EqualTo(Enum.GetNames(typeof(ObjectType)).Length), @"ObjectTypesDictionary is missing these objectTypes: {0}", + Assert.That(options.ObjectTypesDictionary.Count, Is.EqualTo(Enum.GetNames(typeof(ObjectType)).Length), @"ObjectTypesDictionary is missing these objectTypes: {0}", string.Join(", ", Enum.GetNames(typeof(ObjectType)).Except(options.ObjectTypesDictionary.Keys))); // Verify the options in the objects dictionary exists in the ObjectType Enum @@ -956,158 +956,158 @@ Streaming query statement contains a reference to missing output stream 'Missing } } } -} -[TestFixture] -public class TSqlModelRequestTests -{ - private string TSqlModelTestFolder = string.Empty; - - private DacFxService service = new DacFxService(); - - [SetUp] - public void Create() + [TestFixture] + public class TSqlModelRequestTests { - TSqlModelTestFolder = Path.Combine("..", "..", "..", "DacFx", "TSqlModels", Guid.NewGuid().ToString()); - Directory.CreateDirectory(TSqlModelTestFolder); - } + private string TSqlModelTestFolder = string.Empty; - [TearDown] - public void CleanUp() - { - Directory.Delete(TSqlModelTestFolder, true); - } + private DacFxService service = new DacFxService(); - /// - /// Verify the generate Tsql model operation - /// - [Test] - public void GenerateTSqlModelFromSqlFiles() - { - string sqlTable1DefinitionPath = Path.Join(TSqlModelTestFolder, "table1.sql"); - string sqlTable2DefinitionPath = Path.Join(TSqlModelTestFolder, "table2.sql"); - const string table1 = @"CREATE TABLE [dbo].[table1] + [SetUp] + public void Create() + { + TSqlModelTestFolder = Path.Combine("..", "..", "..", "DacFx", "TSqlModels", Guid.NewGuid().ToString()); + Directory.CreateDirectory(TSqlModelTestFolder); + } + + [TearDown] + public void CleanUp() + { + Directory.Delete(TSqlModelTestFolder, true); + } + + /// + /// Verify the generate Tsql model operation + /// + [Test] + public void GenerateTSqlModelFromSqlFiles() + { + string sqlTable1DefinitionPath = Path.Join(TSqlModelTestFolder, "table1.sql"); + string sqlTable2DefinitionPath = Path.Join(TSqlModelTestFolder, "table2.sql"); + const string table1 = @"CREATE TABLE [dbo].[table1] ( [ID] INT NOT NULL PRIMARY KEY, )"; - const string table2 = @"CREATE TABLE [dbo].[table2] + const string table2 = @"CREATE TABLE [dbo].[table2] ( [ID] INT NOT NULL PRIMARY KEY, )"; - // create sql file - File.WriteAllText(sqlTable1DefinitionPath, table1); - File.WriteAllText(sqlTable2DefinitionPath, table2); + // create sql file + File.WriteAllText(sqlTable1DefinitionPath, table1); + File.WriteAllText(sqlTable2DefinitionPath, table2); - var generateTSqlScriptParams = new GenerateTSqlModelParams + var generateTSqlScriptParams = new GenerateTSqlModelParams + { + ProjectUri = Path.Join(TSqlModelTestFolder, "test.sqlproj"), + ModelTargetVersion = "Sql160", + FilePaths = new[] { sqlTable1DefinitionPath, sqlTable2DefinitionPath } + }; + + GenerateTSqlModelOperation op = new GenerateTSqlModelOperation(generateTSqlScriptParams); + var model = op.GenerateTSqlModel(); + var objects = model.GetObjects(DacQueryScopes.UserDefined, ModelSchema.Table).ToList(); + + Assert.That(model.Version.ToString(), Is.EqualTo(generateTSqlScriptParams.ModelTargetVersion), $"Model version is not equal to {generateTSqlScriptParams.ModelTargetVersion}"); + Assert.That(objects, Is.Not.Empty, "Model is empty"); + + var tableNames = objects.Select(o => o.Name.ToString()).ToList(); + + Assert.That(tableNames.Count, Is.EqualTo(2), "Model was not populated correctly"); + CollectionAssert.AreEquivalent(tableNames, new[] { "[dbo].[table1]", "[dbo].[table2]" }, "Table names do not match"); + } + + /// + /// Verify the generate Tsql model operation, creates an empty model when files are empty + /// + [Test] + public void GenerateEmptyTSqlModel() { - ProjectUri = Path.Join(TSqlModelTestFolder, "test.sqlproj"), - ModelTargetVersion = "Sql160", - FilePaths = new[] { sqlTable1DefinitionPath, sqlTable2DefinitionPath } - }; + var generateTSqlScriptParams = new GenerateTSqlModelParams + { + ProjectUri = Path.Join(TSqlModelTestFolder, "test.sqlproj"), + ModelTargetVersion = "Sql160", + FilePaths = new string[] { } + }; - GenerateTSqlModelOperation op = new GenerateTSqlModelOperation(generateTSqlScriptParams); - var model = op.GenerateTSqlModel(); - var objects = model.GetObjects(DacQueryScopes.UserDefined, ModelSchema.Table).ToList(); + GenerateTSqlModelOperation op = new GenerateTSqlModelOperation(generateTSqlScriptParams); + var model = op.GenerateTSqlModel(); - Assert.That(model.Version.ToString(), Is.EqualTo(generateTSqlScriptParams.ModelTargetVersion), $"Model version is not equal to {generateTSqlScriptParams.ModelTargetVersion}"); - Assert.That(objects, Is.Not.Empty, "Model is empty"); + Assert.That(model.GetObjects(DacQueryScopes.UserDefined, ModelSchema.Table).ToList().Count, Is.EqualTo(0), "Model is not empty"); + Assert.That(model.Version.ToString(), Is.EqualTo(generateTSqlScriptParams.ModelTargetVersion), $"Model version is not equal to {generateTSqlScriptParams.ModelTargetVersion}"); + } - var tableNames = objects.Select(o => o.Name.ToString()).ToList(); - - Assert.That(tableNames.Count, Is.EqualTo(2), "Model was not populated correctly"); - CollectionAssert.AreEquivalent(tableNames, new[] { "[dbo].[table1]", "[dbo].[table2]" }, "Table names do not match"); - } - - /// - /// Verify the generate Tsql model operation, creates an empty model when files are empty - /// - [Test] - public void GenerateEmptyTSqlModel() - { - var generateTSqlScriptParams = new GenerateTSqlModelParams + /// + /// Verify the generate TSql Model handle + /// + [Test] + public async Task VerifyGenerateTSqlModelHandle() { - ProjectUri = Path.Join(TSqlModelTestFolder, "test.sqlproj"), - ModelTargetVersion = "Sql160", - FilePaths = new string[] { } - }; + var generateTSqlScriptParams = new GenerateTSqlModelParams + { + ProjectUri = Path.Join(TSqlModelTestFolder, "test.sqlproj"), + ModelTargetVersion = "Sql160", + FilePaths = new string[] { } + }; - GenerateTSqlModelOperation op = new GenerateTSqlModelOperation(generateTSqlScriptParams); - var model = op.GenerateTSqlModel(); + var requestContext = new Mock>(); + requestContext.Setup((RequestContext x) => x.SendResult(It.Is((result) => result == true))).Returns(Task.FromResult(new object())); - Assert.That(model.GetObjects(DacQueryScopes.UserDefined, ModelSchema.Table).ToList().Count, Is.EqualTo(0), "Model is not empty"); - Assert.That(model.Version.ToString(), Is.EqualTo(generateTSqlScriptParams.ModelTargetVersion), $"Model version is not equal to {generateTSqlScriptParams.ModelTargetVersion}"); - } + await service.HandleGenerateTSqlModelRequest(generateTSqlScriptParams, requestContext.Object); + Assert.That(service.projectModels.Value, Contains.Key(generateTSqlScriptParams.ProjectUri), "Model was not stored under project uri"); + } - /// - /// Verify the generate TSql Model handle - /// - [Test] - public async Task VerifyGenerateTSqlModelHandle() - { - var generateTSqlScriptParams = new GenerateTSqlModelParams + /// + /// Verify the get objects TSql Model handle + /// + [Test] + public async Task VerifyGetObjectsFromTSqlModelHandle() { - ProjectUri = Path.Join(TSqlModelTestFolder, "test.sqlproj"), - ModelTargetVersion = "Sql160", - FilePaths = new string[] { } - }; - - var requestContext = new Mock>(); - requestContext.Setup((RequestContext x) => x.SendResult(It.Is((result) => result == true))).Returns(Task.FromResult(new object())); - - await service.HandleGenerateTSqlModelRequest(generateTSqlScriptParams, requestContext.Object); - Assert.That(service.projectModels.Value, Contains.Key(generateTSqlScriptParams.ProjectUri), "Model was not stored under project uri"); - } - - /// - /// Verify the get objects TSql Model handle - /// - [Test] - public async Task VerifyGetObjectsFromTSqlModelHandle() - { - string sqlTable1DefinitionPath = Path.Join(TSqlModelTestFolder, "table1.sql"); - string sqlTable2DefinitionPath = Path.Join(TSqlModelTestFolder, "table2.sql"); - string view1DefinitionPath = Path.Join(TSqlModelTestFolder, "view1.sql"); - const string table1 = @"CREATE TABLE [dbo].[table1] + string sqlTable1DefinitionPath = Path.Join(TSqlModelTestFolder, "table1.sql"); + string sqlTable2DefinitionPath = Path.Join(TSqlModelTestFolder, "table2.sql"); + string view1DefinitionPath = Path.Join(TSqlModelTestFolder, "view1.sql"); + const string table1 = @"CREATE TABLE [dbo].[table1] ( [ID] INT NOT NULL PRIMARY KEY, )"; - const string table2 = @"CREATE TABLE [dbo].[table2] + const string table2 = @"CREATE TABLE [dbo].[table2] ( [ID] INT NOT NULL PRIMARY KEY, )"; - const string view1 = "CREATE VIEW [dbo].[view1] AS SELECT dbo.table1.* FROM dbo.table1"; - // create sql file - File.WriteAllText(sqlTable1DefinitionPath, table1); - File.WriteAllText(sqlTable2DefinitionPath, table2); - File.WriteAllText(view1DefinitionPath, view1); + const string view1 = "CREATE VIEW [dbo].[view1] AS SELECT dbo.table1.* FROM dbo.table1"; + // create sql file + File.WriteAllText(sqlTable1DefinitionPath, table1); + File.WriteAllText(sqlTable2DefinitionPath, table2); + File.WriteAllText(view1DefinitionPath, view1); - var generateTSqlScriptParams = new GenerateTSqlModelParams - { - ProjectUri = Path.Join(TSqlModelTestFolder, "test.sqlproj"), - ModelTargetVersion = "Sql160", - FilePaths = new[] { sqlTable1DefinitionPath, sqlTable2DefinitionPath } - }; - - GenerateTSqlModelOperation op = new GenerateTSqlModelOperation(generateTSqlScriptParams); - var model = op.GenerateTSqlModel(); + var generateTSqlScriptParams = new GenerateTSqlModelParams + { + ProjectUri = Path.Join(TSqlModelTestFolder, "test.sqlproj"), + ModelTargetVersion = "Sql160", + FilePaths = new[] { sqlTable1DefinitionPath, sqlTable2DefinitionPath } + }; - service.projectModels.Value.TryAdd(generateTSqlScriptParams.ProjectUri, model); - - var getObjectsParams = new GetObjectsFromTSqlModelParams - { - ProjectUri = Path.Join(TSqlModelTestFolder, "test.sqlproj"), - ObjectTypes = new[] { "Table" } - }; + GenerateTSqlModelOperation op = new GenerateTSqlModelOperation(generateTSqlScriptParams); + var model = op.GenerateTSqlModel(); - var requestContext = new Mock>(); - var actualResponse = new List(); - requestContext.Setup(x => x.SendResult(It.IsAny())) - .Callback(actual => actualResponse = actual.ToList()) - .Returns(Task.CompletedTask); - await service.HandleGetObjectsFromTSqlModelRequest(getObjectsParams, requestContext.Object); + service.projectModels.Value.TryAdd(generateTSqlScriptParams.ProjectUri, model); - Assert.IsNotNull(actualResponse); - Assert.AreEqual(actualResponse.Count, 2); - CollectionAssert.AreEquivalent(actualResponse.Select(o => o.Name), new[] { "[dbo].[table1]", "[dbo].[table2]" }, "Table names do not match"); + var getObjectsParams = new GetObjectsFromTSqlModelParams + { + ProjectUri = Path.Join(TSqlModelTestFolder, "test.sqlproj"), + ObjectTypes = new[] { "Table" } + }; + + var requestContext = new Mock>(); + var actualResponse = new List(); + requestContext.Setup(x => x.SendResult(It.IsAny())) + .Callback(actual => actualResponse = actual.ToList()) + .Returns(Task.CompletedTask); + await service.HandleGetObjectsFromTSqlModelRequest(getObjectsParams, requestContext.Object); + + Assert.IsNotNull(actualResponse); + Assert.AreEqual(actualResponse.Count, 2); + CollectionAssert.AreEquivalent(actualResponse.Select(o => o.Name), new[] { "[dbo].[table1]", "[dbo].[table2]" }, "Table names do not match"); + } } -} +} \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Utility/LiveConnectionHelper.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Utility/LiveConnectionHelper.cs index 46dc73f4..7d8ce12f 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Utility/LiveConnectionHelper.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Utility/LiveConnectionHelper.cs @@ -46,11 +46,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility } public static TestConnectionResult InitLiveConnectionInfo(string databaseName = null, string ownerUri = null) - { - var task = InitLiveConnectionInfoAsync(databaseName, ownerUri, ServiceLayer.Connection.ConnectionType.Default); - task.Wait(); - return task.Result; - } + => InitLiveConnectionInfoAsync(databaseName, ownerUri, ServiceLayer.Connection.ConnectionType.Default).ConfigureAwait(false).GetAwaiter().GetResult(); public static async Task InitLiveConnectionInfoAsync(string databaseName = "master", string ownerUri = null, string connectionType = ServiceLayer.Connection.ConnectionType.Default, TestServerType serverType = TestServerType.OnPrem) diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/ConnectionSetting.cs b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/ConnectionSetting.cs index 81fd1edb..bbde4832 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/ConnectionSetting.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/ConnectionSetting.cs @@ -16,9 +16,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common public class ConnectionSetting { [JsonProperty("mssql.connections")] - public List Connections { get; set; } + public List? Connections { get; set; } - public InstanceInfo GetConnectionProfile(string profileName, string serverName) + public InstanceInfo? GetConnectionProfile(string profileName, string serverName) { if (!string.IsNullOrEmpty(profileName) && Connections != null) { @@ -28,7 +28,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common return byProfileName; } } - return Connections.FirstOrDefault(x => x.ServerName == serverName); + return Connections?.FirstOrDefault(x => x.ServerName == serverName); } } @@ -47,23 +47,29 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common public string ServerName { get; set; } [JsonProperty(NullValueHandling = NullValueHandling.Ignore)] - public string Database { get; set; } + public string? Database { get; set; } [JsonProperty(NullValueHandling = NullValueHandling.Ignore)] - public string User { get; set; } + public string? User { get; set; } [JsonProperty(NullValueHandling = NullValueHandling.Ignore)] - public string Password { get; set; } + public string? Password { get; set; } [JsonProperty(NullValueHandling = NullValueHandling.Ignore)] - public string ProfileName { get; set; } + public string? ProfileName { get; set; } + + [JsonProperty(NullValueHandling = NullValueHandling.Ignore)] + public string? Encrypt { get; set; } + + [JsonProperty(NullValueHandling = NullValueHandling.Ignore)] + public string? HostNameInCertificate { get; set; } public TestServerType ServerType { get; set; } public AuthenticationType AuthenticationType { get; set; } [JsonProperty(NullValueHandling = NullValueHandling.Ignore)] - public string RemoteSharePath { get; set; } + public string? RemoteSharePath { get; set; } public int ConnectTimeout { get; set; } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/SqlTestDb.cs b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/SqlTestDb.cs index a0cdcf46..f85764e1 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/SqlTestDb.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/SqlTestDb.cs @@ -35,6 +35,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common { DataSource = connectParams.Connection.ServerName, InitialCatalog = connectParams.Connection.DatabaseName, + TrustServerCertificate = true }; if (connectParams.Connection.AuthenticationType == "Integrated") @@ -45,6 +46,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common { builder.UserID = connectParams.Connection.UserName; builder.Password = connectParams.Connection.Password; + builder.PersistSecurityInfo = true; + } + + if (!string.IsNullOrEmpty(connectParams.Connection.Encrypt)) + { + builder.Encrypt = connectParams.Connection.Encrypt switch + { + "optional" or "false" or "no" => SqlConnectionEncryptOption.Optional, + "mandatory" or "true" or "yes" => SqlConnectionEncryptOption.Mandatory, + "strict" => SqlConnectionEncryptOption.Strict, + _ => SqlConnectionEncryptOption.Optional + }; + } + + if (!string.IsNullOrEmpty(connectParams.Connection.HostNameInCertificate)) + { + builder.HostNameInCertificate = connectParams.Connection.HostNameInCertificate; } return builder.ToString(); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestConnectionProfileService.cs b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestConnectionProfileService.cs index 2db5727f..a470cf41 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestConnectionProfileService.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/TestConnectionProfileService.cs @@ -6,6 +6,7 @@ using System; using System.Collections.Generic; using System.Globalization; +using Microsoft.Data.SqlClient; using Microsoft.SqlTools.Credentials.Contracts; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using NUnit.Framework; @@ -37,12 +38,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common } } - public static InstanceInfo SqlAzure + public static InstanceInfo? SqlAzure { get { return GetInstance(SqlAzureInstanceKey); } } - public static InstanceInfo SqlOnPrem + public static InstanceInfo? SqlOnPrem { get { return GetInstance(SqlOnPremInstanceKey); } } @@ -50,30 +51,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common /// /// Returns the SQL connection info for given version key /// - public static InstanceInfo GetInstance(string key) + public static InstanceInfo? GetInstance(string key) { - InstanceInfo instanceInfo; - connectionProfilesCache.TryGetValue(key, out instanceInfo); + connectionProfilesCache.TryGetValue(key, out InstanceInfo? instanceInfo); Assert.True(instanceInfo != null, string.Format(CultureInfo.InvariantCulture, "Cannot find any instance for version key: {0}", key)); return instanceInfo; } - public ConnectParams GetConnectionParameters(string key = SqlOnPremInstanceKey, string databaseName = null) + public ConnectParams? GetConnectionParameters(string key = SqlOnPremInstanceKey, string databaseName = null) { - InstanceInfo instanceInfo = GetInstance(key); - if (instanceInfo != null) - { - ConnectParams connectParam = CreateConnectParams(instanceInfo, key, databaseName); - - return connectParam; - } - return null; + InstanceInfo? instanceInfo = GetInstance(key); + return instanceInfo != null ? CreateConnectParams(instanceInfo, key, databaseName) : null; } /// /// Returns database connection parameters for given server type /// - public ConnectParams GetConnectionParameters(TestServerType serverType = TestServerType.OnPrem, string databaseName = null) + public ConnectParams? GetConnectionParameters(TestServerType serverType = TestServerType.OnPrem, string databaseName = null) { string key = ConvertServerTypeToVersionKey(serverType); return GetConnectionParameters(key, databaseName); @@ -94,11 +88,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common Console.WriteLine("DBTestInstance not configured. Run 'dotnet run Microsoft.SqlTools.ServiceLayer.TestEnvConfig' from the command line to configure"); } - if (testServers != null && settings != null) + if (testServers != null) { foreach (var serverIdentity in testServers) { - var instance = settings != null ? settings.GetConnectionProfile(serverIdentity.ProfileName, serverIdentity.ServerName) : null; + var instance = settings?.GetConnectionProfile(serverIdentity.ProfileName, serverIdentity.ServerName); if (instance?.ServerType == TestServerType.None) { instance.ServerType = serverIdentity.ServerType; @@ -106,7 +100,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common } } } - if (settings != null) + if (settings?.Connections != null) { foreach (var instance in settings.Connections) { @@ -162,12 +156,38 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common ConnectParams connectParams = new ConnectParams(); connectParams.Connection = new ConnectionDetails(); connectParams.Connection.ServerName = connectionProfile.ServerName; - connectParams.Connection.DatabaseName = connectionProfile.Database; - connectParams.Connection.DatabaseDisplayName = connectionProfile.Database; - connectParams.Connection.UserName = connectionProfile.User; - connectParams.Connection.Password = connectionProfile.Password; - connectParams.Connection.MaxPoolSize = 200; + + if (!string.IsNullOrEmpty(connectionProfile.Database)) + { + connectParams.Connection.DatabaseName = connectionProfile.Database; + connectParams.Connection.DatabaseDisplayName = connectionProfile.Database; + } + if (!string.IsNullOrEmpty(connectionProfile.User)) + { + connectParams.Connection.UserName = connectionProfile.User; + } + if (!string.IsNullOrEmpty(connectionProfile.Password)) + { + connectParams.Connection.Password = connectionProfile.Password; + } + connectParams.Connection.AuthenticationType = connectionProfile.AuthenticationType.ToString(); + connectParams.Connection.MaxPoolSize = 200; + + if (!string.IsNullOrEmpty(connectionProfile.Encrypt)) + { + connectParams.Connection.Encrypt = connectionProfile.Encrypt; + } + else + { + connectParams.Connection.Encrypt = SqlConnectionEncryptOption.Optional.ToString(); + } + + if (!string.IsNullOrEmpty(connectionProfile.HostNameInCertificate)) + { + connectParams.Connection.HostNameInCertificate = connectionProfile.HostNameInCertificate; + } + if (!string.IsNullOrEmpty(databaseName)) { connectParams.Connection.DatabaseName = databaseName; @@ -176,9 +196,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common if (key == SqlAzureInstanceKey || key == SqlAzureInstanceKey) { connectParams.Connection.ConnectTimeout = 30; - connectParams.Connection.Encrypt = true; + connectParams.Connection.Encrypt = SqlConnectionEncryptOption.Mandatory.ToString(); connectParams.Connection.TrustServerCertificate = false; } + return connectParams; } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionDetailsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionDetailsTests.cs index c66ea43a..ba68753c 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionDetailsTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionDetailsTests.cs @@ -25,6 +25,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection var expectedForStrings = default(string); var expectedForInt = default(int?); var expectedForBoolean = default(bool?); + var expectedEncryptOption = default(string?); Assert.AreEqual(details.ApplicationIntent, expectedForStrings); Assert.AreEqual(details.ApplicationName, expectedForStrings); @@ -48,13 +49,14 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.AreEqual(details.ColumnEncryptionSetting, expectedForStrings); Assert.AreEqual(details.EnclaveAttestationUrl, expectedForStrings); Assert.AreEqual(details.EnclaveAttestationProtocol, expectedForStrings); - Assert.AreEqual(details.Encrypt, expectedForBoolean); + Assert.AreEqual(details.Encrypt, expectedEncryptOption); Assert.AreEqual(details.MultipleActiveResultSets, expectedForBoolean); Assert.AreEqual(details.MultiSubnetFailover, expectedForBoolean); Assert.AreEqual(details.PersistSecurityInfo, expectedForBoolean); Assert.AreEqual(details.Pooling, expectedForBoolean); Assert.AreEqual(details.Replication, expectedForBoolean); Assert.AreEqual(details.TrustServerCertificate, expectedForBoolean); + Assert.AreEqual(details.HostNameInCertificate, expectedForStrings); Assert.AreEqual(details.Port, expectedForInt); } @@ -88,13 +90,14 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection details.ColumnEncryptionSetting = expectedForStrings + index++; details.EnclaveAttestationProtocol = expectedForStrings + index++; details.EnclaveAttestationUrl = expectedForStrings + index++; - details.Encrypt = (index++ % 2 == 0); + details.Encrypt = expectedForStrings + index++; details.MultipleActiveResultSets = (index++ % 2 == 0); details.MultiSubnetFailover = (index++ % 2 == 0); details.PersistSecurityInfo = (index++ % 2 == 0); details.Pooling = (index++ % 2 == 0); details.Replication = (index++ % 2 == 0); details.TrustServerCertificate = (index++ % 2 == 0); + details.HostNameInCertificate = expectedForStrings + index++; details.Port = expectedForInt + index++; index = 0; @@ -120,13 +123,14 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.AreEqual(details.ColumnEncryptionSetting, expectedForStrings + index++); Assert.AreEqual(details.EnclaveAttestationProtocol, expectedForStrings + index++); Assert.AreEqual(details.EnclaveAttestationUrl, expectedForStrings + index++); - Assert.AreEqual(details.Encrypt, (index++ % 2 == 0)); + Assert.AreEqual(details.Encrypt, expectedForStrings + index++); Assert.AreEqual(details.MultipleActiveResultSets, (index++ % 2 == 0)); Assert.AreEqual(details.MultiSubnetFailover, (index++ % 2 == 0)); Assert.AreEqual(details.PersistSecurityInfo, (index++ % 2 == 0)); Assert.AreEqual(details.Pooling, (index++ % 2 == 0)); Assert.AreEqual(details.Replication, (index++ % 2 == 0)); Assert.AreEqual(details.TrustServerCertificate, (index++ % 2 == 0)); + Assert.AreEqual(details.HostNameInCertificate, expectedForStrings + index++); Assert.AreEqual(details.Port, (expectedForInt + index++)); } @@ -161,16 +165,17 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection details.ColumnEncryptionSetting = expectedForStrings + index++; details.EnclaveAttestationProtocol = expectedForStrings + index++; details.EnclaveAttestationUrl = expectedForStrings + index++; - details.Encrypt = (index++ % 2 == 0); + details.Encrypt = expectedForStrings + index++; details.MultipleActiveResultSets = (index++ % 2 == 0); details.MultiSubnetFailover = (index++ % 2 == 0); details.PersistSecurityInfo = (index++ % 2 == 0); details.Pooling = (index++ % 2 == 0); details.Replication = (index++ % 2 == 0); details.TrustServerCertificate = (index++ % 2 == 0); + details.HostNameInCertificate = expectedForStrings + index++; details.Port = expectedForInt + index++; - if(optionMetadata.Options.Count() != details.Options.Count) + if (optionMetadata.Options.Count() != details.Options.Count) { var optionsNotInMetadata = details.Options.Where(o => !optionMetadata.Options.Any(m => m.Name == o.Key)); var optionNames = optionsNotInMetadata.Any() ? optionsNotInMetadata.Select(s => s.Key).Aggregate((i, j) => i + "," + j) : null; @@ -180,7 +185,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection { var metadata = optionMetadata.Options.FirstOrDefault(x => x.Name == option.Key); Assert.NotNull(metadata); - if(metadata.ValueType == ConnectionOption.ValueTypeString) + if (metadata.ValueType == ConnectionOption.ValueTypeString) { Assert.True(option.Value is string); } @@ -200,7 +205,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection public void SettingConnectiomTimeoutToLongShouldStillReturnInt() { ConnectionDetails details = new ConnectionDetails(); - + long timeout = 30; int? expectedValue = 30; details.Options["connectTimeout"] = timeout; @@ -226,44 +231,21 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection } [Test] - public void SettingEncryptToStringShouldStillReturnBoolean() + public void SettingEncrypShouldReturnExpectedEncryptOption() { ConnectionDetails details = new ConnectionDetails(); + details.Options["Encrypt"] = true.ToString(); + Assert.That(details.Encrypt, Is.EqualTo(true.ToString()), "Encrypt should be Mandatory."); - string encrypt = "True"; - bool? expectedValue = true; - details.Options["encrypt"] = encrypt; - - Assert.AreEqual(details.Encrypt, expectedValue); + details.Options["Encrypt"] = "Strict"; + Assert.That(details.Encrypt, Is.EqualTo("Strict"), "Encrypt should be Strict."); } [Test] - public void SettingEncryptToLowecaseStringShouldStillReturnBoolean() + public void EncryptShouldReturnMandatoryIfNotSet() { ConnectionDetails details = new ConnectionDetails(); - - string encrypt = "true"; - bool? expectedValue = true; - details.Options["encrypt"] = encrypt; - - Assert.AreEqual(details.Encrypt, expectedValue); - } - - [Test] - public void EncryptShouldReturnNullIfNotSet() - { - ConnectionDetails details = new ConnectionDetails(); - bool? expectedValue = null; - Assert.AreEqual(details.Encrypt, expectedValue); - } - - [Test] - public void EncryptShouldReturnNullIfSetToNull() - { - ConnectionDetails details = new ConnectionDetails(); - details.Options["encrypt"] = null; - int? expectedValue = null; - Assert.AreEqual(details.ConnectTimeout, expectedValue); + Assert.That(details.Encrypt, Is.Null, "Encrypt should be null when set to null"); } [Test] @@ -273,11 +255,12 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection long timeout = long.MaxValue; int? expectedValue = null; + string? expectedEncryptValue = "Mandatory"; details.Options["connectTimeout"] = timeout; - details.Options["encrypt"] = true; + details.Options["encrypt"] = expectedEncryptValue; - Assert.AreEqual(details.ConnectTimeout, expectedValue); - Assert.AreEqual(true, details.Encrypt); + Assert.That(details.ConnectTimeout, Is.EqualTo(expectedValue), "Connect Timeout not as expected"); + Assert.That(details.Encrypt, Is.EqualTo("Mandatory"), "Encrypt should be mandatory."); } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs index b3bc8026..d073c36c 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs @@ -273,7 +273,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection /// provided as a parameter. /// [Test] - public async Task CanConnectWithEmptyDatabaseName([Values(null, "")]string databaseName) + public async Task CanConnectWithEmptyDatabaseName([Values(null, "")] string databaseName) { // Connect var connectionDetails = TestObjects.GetTestConnectionDetails(); @@ -294,7 +294,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection /// provided as a parameter. /// [Test] - public async Task ConnectToDefaultDatabaseRespondsWithActualDbName([Values("master", "nonMasterDb")]string expectedDbName) + public async Task ConnectToDefaultDatabaseRespondsWithActualDbName([Values("master", "nonMasterDb")] string expectedDbName) { // Given connecting with empty database name will return the expected DB name var connectionMock = new Mock { CallBase = true }; @@ -440,11 +440,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection new object[] {"Integrated", "file://my/sample/file.sql", null, "test", "sa", "123456"}, new object[] {"Integrated", "", "my-server", "test", "sa", "123456"}, new object[] {"Integrated", "file://my/sample/file.sql", "", "test", "sa", "123456"} - }; + }; /// /// Verify that when connecting with invalid parameters, an error is thrown. /// - [Test, TestCaseSource(nameof(invalidParameters))] + [Test, TestCaseSource(nameof(invalidParameters))] public async Task ConnectingWithInvalidParametersYieldsErrorMessage(string authType, string ownerUri, string server, string database, string userName, string password) { // Connect with invalid parameters @@ -521,8 +521,9 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection { new object[] {"AuthenticationType", "Integrated", "Integrated Security" }, new object[] {"AuthenticationType", "SqlLogin", ""}, - new object[] {"Encrypt", true, "Encrypt"}, - new object[] {"Encrypt", false, "Encrypt"}, + new object[] {"Encrypt", "Mandatory", "Encrypt"}, + new object[] {"Encrypt", "Optional", "Encrypt"}, + new object[] {"Encrypt", "Strict", "Encrypt"}, new object[] {"ColumnEncryptionSetting", "Enabled", "Column Encryption Setting=Enabled"}, new object[] {"ColumnEncryptionSetting", "Disabled", "Column Encryption Setting=Disabled"}, new object[] {"ColumnEncryptionSetting", "enabled", "Column Encryption Setting=Enabled"}, @@ -533,6 +534,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection new object[] {"ColumnEncryptionSetting", "DiSaBlEd", "Column Encryption Setting=Disabled"}, new object[] {"TrustServerCertificate", true, "Trust Server Certificate"}, new object[] {"TrustServerCertificate", false, "Trust Server Certificate"}, + new object[] {"HostNameInCertificate", "hostname", "Host Name In Certificate"}, new object[] {"PersistSecurityInfo", true, "Persist Security Info"}, new object[] {"PersistSecurityInfo", false, "Persist Security Info"}, new object[] {"ConnectTimeout", 15, "Connect Timeout"}, @@ -603,7 +605,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection .SetValue(details, propertyValue); // Test that a connection string can be created without exceptions - string connectionString = ConnectionService.BuildConnectionString(details); + string connectionString = ConnectionService.BuildConnectionString(details); Assert.That(connectionString, Contains.Substring(connectionStringMarker), "Verify that the parameter is in the connection string"); } @@ -613,6 +615,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection new object[] {"ColumnEncryptionSetting", "NotAValidColumnEncryptionSetting" }, new object[] {"EnclaveAttestationProtocol", "NotAValidEnclaveAttestationProtocol" }, }; + /// /// Build connection string with an invalid property type /// @@ -625,7 +628,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.Throws(() => ConnectionService.BuildConnectionString(details)); } - private static readonly Tuple[][] optionCombos = + private static readonly Tuple[][] optionCombos = { new [] { @@ -647,6 +650,29 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection } }; + private static readonly object[] EncryptionCombinations = + { + new object[] { SqlConnectionEncryptOption.Optional, SqlConnectionEncryptOption.Optional }, + new object[] { SqlConnectionEncryptOption.Mandatory, SqlConnectionEncryptOption.Mandatory }, + new object[] { SqlConnectionEncryptOption.Strict, SqlConnectionEncryptOption.Strict }, + }; + + /// + /// Verify that Strict Encryption parameters can be built into a connection string for connecting. + /// + [Test, TestCaseSource(nameof(EncryptionCombinations))] + public void ConnectingWithStrictEncryptionBuildsConnectionString(SqlConnectionEncryptOption encryptValue, SqlConnectionEncryptOption expected) + { + // Create a test connection details object and set the property to a specific value + ConnectionDetails details = TestObjects.GetTestConnectionDetails(); + details.Encrypt = encryptValue.ToString(); + + // Test that a connection string can be created without exceptions + string connectionString = ConnectionService.BuildConnectionString(details); + + Assert.That(connectionString, Contains.Substring("Encrypt=" + expected.ToString()), "Encrypt not as expected."); + } + /// /// Build connection string with an invalid property combinations /// @@ -1137,10 +1163,10 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection /// Test that the connection summary comparer creates a hash code correctly /// [Test] - public void TestConnectionSummaryComparerHashCode([Values]bool objectNull, - [Values(null, "server")]string serverName, - [Values(null, "test")]string databaseName, - [Values(null, "sa")]string userName) + public void TestConnectionSummaryComparerHashCode([Values] bool objectNull, + [Values(null, "server")] string serverName, + [Values(null, "test")] string databaseName, + [Values(null, "sa")] string userName) { // Given a connection summary and comparer object ConnectionSummary summary = null; @@ -1341,13 +1367,13 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection } [Test] - public async Task GetOrOpenNullOwnerUri([Values(null, "")]string ownerUri) + public async Task GetOrOpenNullOwnerUri([Values(null, "")] string ownerUri) { // If: I have a connection service and I ask for a connection with an invalid ownerUri // Then: An exception should be thrown var service = TestObjects.GetTestConnectionService(); - Assert.ThrowsAsync( - () => service.GetOrOpenConnection(ownerUri, ConnectionType.Default)); + Assert.ThrowsAsync( + () => service.GetOrOpenConnection(ownerUri, ConnectionType.Default)); } [Test] @@ -1643,19 +1669,43 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection // If we make a connection to a live database ConnectionService service = ConnectionService.Instance; - var connectionString = "Server=tcp:{servername},1433;Initial Catalog={databasename};Persist Security Info=False;User ID={your_username};Password={your_password};MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=False;Connection Timeout=30;"; + var connectionString = "Server=tcp:{servername},1433;Initial Catalog={databasename};Persist Security Info=False;User ID={your_username};Password={your_password};MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=False;Connection Timeout=30;HostNameInCertificate={servername}"; var details = service.ParseConnectionString(connectionString); + Assert.That(details.ServerName, Is.EqualTo("tcp:{servername},1433"), "Unexpected server name"); + Assert.That(details.DatabaseName, Is.EqualTo("{databasename}"), "Unexpected database name"); + Assert.That(details.UserName, Is.EqualTo("{your_username}"), "Unexpected username"); + Assert.That(details.Password, Is.EqualTo("{your_password}"), "Unexpected password"); + Assert.That(details.PersistSecurityInfo, Is.False, "Unexpected Persist Security Info"); + Assert.That(details.MultipleActiveResultSets, Is.False, "Unexpected Multiple Active Result Sets value"); + Assert.That(details.Encrypt, Is.EqualTo("True"), "Unexpected Encrypt value"); + Assert.That(details.TrustServerCertificate, Is.False, "Unexpected database name value"); + Assert.That(details.HostNameInCertificate, Is.EqualTo("{servername}"), "Unexpected Host Name in Certificate value"); + Assert.That(details.ConnectTimeout, Is.EqualTo(30), "Unexpected Connect Timeout value"); + } - Assert.AreEqual("tcp:{servername},1433", details.ServerName); - Assert.AreEqual("{databasename}", details.DatabaseName); - Assert.AreEqual("{your_username}", details.UserName); - Assert.AreEqual("{your_password}", details.Password); - Assert.AreEqual(false, details.PersistSecurityInfo); - Assert.AreEqual(false, details.MultipleActiveResultSets); - Assert.AreEqual(true, details.Encrypt); - Assert.AreEqual(false, details.TrustServerCertificate); - Assert.AreEqual(30, details.ConnectTimeout); + /// + /// Test ParseConnectionString + /// + [Test] + public void ParseConnectionStringTest_StrictEncryption() + { + // If we make a connection to a live database + ConnectionService service = ConnectionService.Instance; + + var connectionString = "Server=tcp:{servername},1433;Initial Catalog={databasename};Persist Security Info=False;User ID={your_username};Password={your_password};MultipleActiveResultSets=False;Encrypt=Strict;TrustServerCertificate=False;Connection Timeout=30;HostNameInCertificate={servername}"; + + var details = service.ParseConnectionString(connectionString); + Assert.That(details.ServerName, Is.EqualTo("tcp:{servername},1433"), "Unexpected server name"); + Assert.That(details.DatabaseName, Is.EqualTo("{databasename}"), "Unexpected database name"); + Assert.That(details.UserName, Is.EqualTo("{your_username}"), "Unexpected username"); + Assert.That(details.Password, Is.EqualTo("{your_password}"), "Unexpected password"); + Assert.That(details.PersistSecurityInfo, Is.False, "Unexpected Persist Security Info"); + Assert.That(details.MultipleActiveResultSets, Is.False, "Unexpected Multiple Active Result Sets value"); + Assert.That(details.Encrypt, Is.EqualTo(SqlConnectionEncryptOption.Strict.ToString()), "Unexpected Encrypt value"); + Assert.That(details.TrustServerCertificate, Is.False, "Unexpected database name value"); + Assert.That(details.HostNameInCertificate, Is.EqualTo("{servername}"), "Unexpected Host Name in Certificate value"); + Assert.That(details.ConnectTimeout, Is.EqualTo(30), "Unexpected Connect Timeout value"); } [Test] diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Fakes/FakeDataFactory.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Fakes/FakeDataFactory.cs index 56fb8500..562331e9 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Fakes/FakeDataFactory.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Fakes/FakeDataFactory.cs @@ -169,7 +169,6 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ResourceProvider.Fakes // connectionStringBuilder.ConnectTimeout = 123; // connectionStringBuilder.Encrypt = true; // connectionStringBuilder.ApplicationIntent = ApplicationIntent.ReadWrite; - // connectionStringBuilder.AsynchronousProcessing = true; // connectionStringBuilder.MaxPoolSize = 45; // connectionStringBuilder.MinPoolSize = 3; // connectionStringBuilder.PacketSize = 600;