From b7cffa3194c0cc6552c47c25fca9c719766f7c2e Mon Sep 17 00:00:00 2001 From: Benjamin Russell Date: Wed, 10 Jan 2018 14:02:07 -0800 Subject: [PATCH] Edit Data: Fix to work with tables with triggers (#576) * Moving logic for adding default values to new rows * Fixing implementation of script generation to handle default values all around * Unit tests! * WIP * Reworking row create script/command generation to work more cleanly and work on triggered tables * Addressing some bugs with the create row implementation * Implementing the trigger table fix for row updates Some small improvements to the create/update tests. --- .../EditData/UpdateManagement/RowCreate.cs | 230 +++++++-------- .../EditData/UpdateManagement/RowUpdate.cs | 122 ++++---- .../SqlScriptFormatters/ToSqlScript.cs | 41 ++- .../EditData/Common.cs | 1 + .../EditData/RowCreateTests.cs | 65 ++++- .../EditData/RowUpdateTests.cs | 264 +++++++++++------- .../UtilityTests/ToSqlScriptTests.cs | 97 +++---- 7 files changed, 483 insertions(+), 337 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowCreate.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowCreate.cs index a3e0511d..7b6c6484 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowCreate.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowCreate.cs @@ -24,11 +24,12 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement /// public sealed class RowCreate : RowEditBase { - private const string InsertScriptStart = "INSERT INTO {0}"; - private const string InsertScriptColumns = "({0})"; - private const string InsertScriptOut = " OUTPUT {0}"; - private const string InsertScriptDefault = " DEFAULT VALUES"; - private const string InsertScriptValues = " VALUES ({0})"; + private const string DeclareStatement = "DECLARE {0} TABLE ({1})"; + private const string InsertOutputDefaultStatement = "INSERT INTO {0} OUTPUT {1} INTO {2} DEFAULT VALUES"; + private const string InsertOutputValuesStatement = "INSERT INTO {0}({1}) OUTPUT {2} INTO {3} VALUES ({4})"; + private const string InsertScriptDefaultStatement = "INSERT INTO {0} DEFAULT VALUES"; + private const string InsertScriptValuesStatement = "INSERT INTO {0}({1}) VALUES ({2})"; + private const string SelectStatement = "SELECT {0} FROM {1}"; internal readonly CellUpdate[] newCells; @@ -88,13 +89,72 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement { Validate.IsNotNull(nameof(connection), connection); - // Build the script and generate a command - ScriptBuildResult result = BuildInsertScript(forCommand: true); + // Process the cells and columns + List declareColumns = new List(); + List inColumnNames = new List(); + List outClauseColumnNames = new List(); + List inValues = new List(); + List inParameters = new List(); + List selectColumns = new List(); + for(int i = 0; i < AssociatedObjectMetadata.Columns.Length; i++) + { + DbColumnWrapper column = AssociatedResultSet.Columns[i]; + EditColumnMetadata metadata = AssociatedObjectMetadata.Columns[i]; + CellUpdate cell = newCells[i]; + + // Add the output columns regardless of whether the column is read only + outClauseColumnNames.Add($"inserted.{metadata.EscapedName}"); + declareColumns.Add($"{metadata.EscapedName} {ToSqlScript.FormatColumnType(column, useSemanticEquivalent: true)}"); + selectColumns.Add(metadata.EscapedName); + + // Continue if we're not inserting a value for this column + if (!IsCellValueProvided(column, cell, DefaultValues[i])) + { + continue; + } + + // Add the input column + inColumnNames.Add(metadata.EscapedName); + + // Add the input values as parameters + string paramName = $"@Value{RowId}_{i}"; + inValues.Add(paramName); + inParameters.Add(new SqlParameter(paramName, column.SqlDbType) {Value = cell.Value}); + } + // Put everything together into a single query + // Step 1) Build a temp table for inserting output values into + string tempTableName = $"@Insert{RowId}Output"; + string declareStatement = string.Format(DeclareStatement, tempTableName, string.Join(", ", declareColumns)); + + // Step 2) Build the insert statement + string joinedOutClauseNames = string.Join(", ", outClauseColumnNames); + string insertStatement = inValues.Count > 0 + ? string.Format(InsertOutputValuesStatement, + AssociatedObjectMetadata.EscapedMultipartName, + string.Join(", ", inColumnNames), + joinedOutClauseNames, + tempTableName, + string.Join(", ", inValues)) + : string.Format(InsertOutputDefaultStatement, + AssociatedObjectMetadata.EscapedMultipartName, + joinedOutClauseNames, + tempTableName); + + // Step 3) Build the select statement + string selectStatement = string.Format(SelectStatement, string.Join(", ", selectColumns), tempTableName); + + // Step 4) Put it all together into a results object + StringBuilder query = new StringBuilder(); + query.AppendLine(declareStatement); + query.AppendLine(insertStatement); + query.Append(selectStatement); + + // Build the command DbCommand command = connection.CreateCommand(); - command.CommandText = result.ScriptText; + command.CommandText = query.ToString(); command.CommandType = CommandType.Text; - command.Parameters.AddRange(result.ScriptParameters); + command.Parameters.AddRange(inParameters.ToArray()); return command; } @@ -123,7 +183,32 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement /// INSERT INTO statement public override string GetScript() { - return BuildInsertScript(forCommand: false).ScriptText; + // Process the cells and columns + List inColumns = new List(); + List inValues = new List(); + for (int i = 0; i < AssociatedObjectMetadata.Columns.Length; i++) + { + DbColumnWrapper column = AssociatedResultSet.Columns[i]; + CellUpdate cell = newCells[i]; + + // Continue if we're not inserting a value for this column + if (!IsCellValueProvided(column, cell, DefaultValues[i])) + { + continue; + } + + // Column is provided + inColumns.Add(AssociatedObjectMetadata.Columns[i].EscapedName); + inValues.Add(ToSqlScript.FormatValue(cell.AsDbCellValue, column)); + } + + // Build the insert statement + return inValues.Count > 0 + ? string.Format(InsertScriptValuesStatement, + AssociatedObjectMetadata.EscapedMultipartName, + string.Join(", ", inColumns), + string.Join(", ", inValues)) + : string.Format(InsertScriptDefaultStatement, AssociatedObjectMetadata.EscapedMultipartName); } /// @@ -173,111 +258,40 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement #endregion /// - /// Generates an INSERT script that will insert this row + /// Verifies the column and cell, ensuring a column that needs a value has one. /// - /// - /// If true the script will be generated with an OUTPUT clause for returning all - /// values in the inserted row (including computed values). The script will also generate - /// parameters for inserting the values. - /// If false the script will not have an OUTPUT clause and will have the values - /// directly inserted into the script (with proper escaping, of course). - /// - /// A script build result object with the script text and any parameters + /// Column that will be inserted into + /// Current cell value for this row + /// Default value for the column in this row /// - /// Thrown if there are columns that are not readonly, do not have default values, and were - /// not assigned values. + /// Thrown if the column needs a value but it is not provided /// - private ScriptBuildResult BuildInsertScript(bool forCommand) - { - // Process all the columns in this table - List inValues = new List(); - List inColumns = new List(); - List outColumns = new List(); - List sqlParameters = new List(); - for (int i = 0; i < AssociatedObjectMetadata.Columns.Length; i++) + /// + /// true If the column has a value provided + /// false If the column does not have a value provided (column is read-only, has default, etc) + /// + private static bool IsCellValueProvided(DbColumnWrapper column, CellUpdate cell, string defaultCell) + { + // Skip columns that cannot be updated + if (!column.IsUpdatable) { - DbColumnWrapper column = AssociatedResultSet.Columns[i]; - CellUpdate cell = newCells[i]; + return false; + } - // Add an out column if we're doing this for a command - if (forCommand) + // Make sure a value was provided for the cell + if (cell == null) + { + // If the column is not nullable and there is not default defined, then fail + if (!column.AllowDBNull.HasTrue() && defaultCell == null) { - outColumns.Add($"inserted.{ToSqlScript.FormatIdentifier(column.ColumnName)}"); + throw new InvalidOperationException(SR.EditDataCreateScriptMissingValue(column.ColumnName)); } - - // Skip columns that cannot be updated - if (!column.IsUpdatable) - { - continue; - } - - // Make sure a value was provided for the cell - if (cell == null) - { - // If the column is not nullable and there is no default defined, then fail - if (!column.AllowDBNull.HasTrue() && DefaultValues[i] == null) - { - throw new InvalidOperationException(SR.EditDataCreateScriptMissingValue(column.ColumnName)); - } - // There is a default value (or omitting the value is fine), so trust the db will apply it correctly - continue; - } - - // Add the input values - if (forCommand) - { - // Since this script is for command use, add parameter for the input value to the list - string paramName = $"@Value{RowId}_{i}"; - inValues.Add(paramName); - - SqlParameter param = new SqlParameter(paramName, cell.Column.SqlDbType) {Value = cell.Value}; - sqlParameters.Add(param); - } - else - { - // This script isn't for command use, add the value, formatted for insertion - inValues.Add(ToSqlScript.FormatValue(cell.Value, column)); - } - - // Add the column to the in columns - inColumns.Add(ToSqlScript.FormatIdentifier(column.ColumnName)); - } - - // Begin the script (ie, INSERT INTO blah) - StringBuilder queryBuilder = new StringBuilder(); - queryBuilder.AppendFormat(InsertScriptStart, AssociatedObjectMetadata.EscapedMultipartName); - - // Add the input columns (if there are any) - if (inColumns.Count > 0) - { - string joinedInColumns = string.Join(", ", inColumns); - queryBuilder.AppendFormat(InsertScriptColumns, joinedInColumns); - } - - // Add the output columns (this will be empty if we are not building for command) - if (outColumns.Count > 0) - { - string joinedOutColumns = string.Join(", ", outColumns); - queryBuilder.AppendFormat(InsertScriptOut, joinedOutColumns); - } - - // Add the input values (if there any) or use the default values - if (inValues.Count > 0) - { - string joinedInValues = string.Join(", ", inValues); - queryBuilder.AppendFormat(InsertScriptValues, joinedInValues); - } - else - { - queryBuilder.AppendFormat(InsertScriptDefault); + // There is a default value (or omitting the value is fine), so trust the db will apply it correctly + return false; } - return new ScriptBuildResult - { - ScriptText = queryBuilder.ToString(), - ScriptParameters = sqlParameters.ToArray() - }; + return true; } private EditCell GetEditCell(CellUpdate cell, int index) @@ -301,11 +315,5 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement } return new EditCell(dbCell, isDirty: true); } - - private class ScriptBuildResult - { - public string ScriptText { get; set; } - public SqlParameter[] ScriptParameters { get; set; } - } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowUpdate.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowUpdate.cs index 529b040c..615e266d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowUpdate.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowUpdate.cs @@ -9,6 +9,7 @@ using System.Data; using System.Data.Common; using System.Data.SqlClient; using System.Linq; +using System.Text; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.EditData.Contracts; using Microsoft.SqlTools.ServiceLayer.QueryExecution; @@ -23,11 +24,12 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement /// public sealed class RowUpdate : RowEditBase { - private const string UpdateScriptStart = @"UPDATE {0}"; - private const string UpdateScriptStartMemOptimized = @"UPDATE {0} WITH (SNAPSHOT)"; - - private const string UpdateScript = @"{0} SET {1} {2}"; - private const string UpdateScriptOutput = @"{0} SET {1} OUTPUT {2} {3}"; + private const string DeclareStatement = "DECLARE {0} TABLE ({1})"; + private const string UpdateOutput = "UPDATE {0} SET {1} OUTPUT {2} INTO {3} {4}"; + private const string UpdateOutputMemOptimized = "UPDATE {0} WITH (SNAPSHOT) SET {1} OUTPUT {2} INTO {3} {4}"; + private const string UpdateScript = "UPDATE {0} SET {1} {2}"; + private const string UpdateScriptMemOptimized = "UPDATE {0} WITH (SNAPSHOT) SET {1} {2}"; + private const string SelectStatement = "SELECT {0} FROM {1}"; internal readonly ConcurrentDictionary cellUpdates; private readonly IList associatedRow; @@ -75,40 +77,66 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement public override DbCommand GetCommand(DbConnection connection) { Validate.IsNotNull(nameof(connection), connection); - DbCommand command = connection.CreateCommand(); - - // Build the "SET" portion of the statement + + // Process the cells and columns + List declareColumns = new List(); + List inParameters = new List(); List setComponents = new List(); - foreach (var updateElement in cellUpdates) + List outClauseColumns = new List(); + List selectColumns = new List(); + for (int i = 0; i < AssociatedObjectMetadata.Columns.Length; i++) { - string formattedColumnName = ToSqlScript.FormatIdentifier(updateElement.Value.Column.ColumnName); - string paramName = $"@Value{RowId}_{updateElement.Key}"; - setComponents.Add($"{formattedColumnName} = {paramName}"); - SqlParameter parameter = new SqlParameter(paramName, updateElement.Value.Column.SqlDbType) + EditColumnMetadata metadata = AssociatedObjectMetadata.Columns[i]; + + // Add the output columns regardless of whether the column is read only + declareColumns.Add($"{metadata.EscapedName} {ToSqlScript.FormatColumnType(metadata.DbColumn, useSemanticEquivalent: true)}"); + outClauseColumns.Add($"inserted.{metadata.EscapedName}"); + selectColumns.Add(metadata.EscapedName); + + // If we have a new value for the column, proccess it now + CellUpdate cellUpdate; + if (cellUpdates.TryGetValue(i, out cellUpdate)) { - Value = updateElement.Value.Value - }; - command.Parameters.Add(parameter); + string paramName = $"@Value{RowId}_{i}"; + setComponents.Add($"{metadata.EscapedName} = {paramName}"); + inParameters.Add(new SqlParameter(paramName, AssociatedResultSet.Columns[i].SqlDbType) {Value = cellUpdate.Value}); + } } - string setComponentsJoined = string.Join(", ", setComponents); - - // Build the "OUTPUT" portion of the statement - var outColumns = from c in AssociatedResultSet.Columns - let formatted = ToSqlScript.FormatIdentifier(c.ColumnName) - select $"inserted.{formatted}"; - string outColumnsJoined = string.Join(", ", outColumns); - - // Get the where clause - WhereClause where = GetWhereClause(true); - command.Parameters.AddRange(where.Parameters.ToArray()); - - // Get the start of the statement - string statementStart = GetStatementStart(); - - // Put the whole #! together - command.CommandText = string.Format(UpdateScriptOutput, statementStart, setComponentsJoined, - outColumnsJoined, where.CommandText); + + // Put everything together into a single query + // Step 1) Build a temp table for inserting output values into + string tempTableName = $"@Update{RowId}Output"; + string declareStatement = string.Format(DeclareStatement, tempTableName, string.Join(", ", declareColumns)); + + // Step 2) Build the update statement + WhereClause whereClause = GetWhereClause(true); + + string updateStatementFormat = AssociatedObjectMetadata.IsMemoryOptimized + ? UpdateOutputMemOptimized + : UpdateOutput; + string updateStatement = string.Format(updateStatementFormat, + AssociatedObjectMetadata.EscapedMultipartName, + string.Join(", ", setComponents), + string.Join(", ", outClauseColumns), + tempTableName, + whereClause.CommandText); + + // Step 3) Build the select statement + string selectStatement = string.Format(SelectStatement, string.Join(", ", selectColumns), tempTableName); + + // Step 4) Put it all together into a results object + StringBuilder query = new StringBuilder(); + query.AppendLine(declareStatement); + query.AppendLine(updateStatement); + query.Append(selectStatement); + + // Build the command + DbCommand command = connection.CreateCommand(); + command.CommandText = query.ToString(); command.CommandType = CommandType.Text; + command.Parameters.AddRange(inParameters.ToArray()); + command.Parameters.AddRange(whereClause.Parameters.ToArray()); + return command; } @@ -153,15 +181,18 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement return $"{formattedColumnName} = {formattedValue}"; }); string setClause = string.Join(", ", setComponents); - - // Get the where clause + + // Put everything together into a single query string whereClause = GetWhereClause(false).CommandText; + string updateStatementFormat = AssociatedObjectMetadata.IsMemoryOptimized + ? UpdateScriptMemOptimized + : UpdateScript; - // Get the start of the statement - string statementStart = GetStatementStart(); - - // Put the whole #! together - return string.Format(UpdateScript, statementStart, setClause, whereClause); + return string.Format(updateStatementFormat, + AssociatedObjectMetadata.EscapedMultipartName, + setClause, + whereClause + ); } /// @@ -226,14 +257,5 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement } #endregion - - private string GetStatementStart() - { - string formatString = AssociatedObjectMetadata.IsMemoryOptimized - ? UpdateScriptStartMemOptimized - : UpdateScriptStart; - - return string.Format(formatString, AssociatedObjectMetadata.EscapedMultipartName); - } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/SqlScriptFormatters/ToSqlScript.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/SqlScriptFormatters/ToSqlScript.cs index 0698ceac..38c4e4ff 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Utility/SqlScriptFormatters/ToSqlScript.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/SqlScriptFormatters/ToSqlScript.cs @@ -1,4 +1,4 @@ -// +// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. // @@ -75,23 +75,36 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters /// /// /// - public static string FormatColumnType(DbColumn column) + public static string FormatColumnType(DbColumn column, bool useSemanticEquivalent = false) { string typeName = column.DataTypeName.ToUpperInvariant(); // TODO: This doesn't support UDTs at all. // TODO: It's unclear if this will work on a case-sensitive db collation + // Strip any unecessary info from the front certain types + if (typeName.EndsWith("HIERARCHYID") || typeName.EndsWith("GEOGRAPHY") || typeName.EndsWith("GEOMETRY")) + { + string[] typeNameComponents = typeName.Split("."); + typeName = typeNameComponents[typeNameComponents.Length - 1]; + } + + // Replace timestamp columns with semantic equivalent if requested + if (useSemanticEquivalent && typeName == "TIMESTAMP") + { + typeName = "VARBINARY(8)"; + } + // If the type supports length parameters, the add those - switch (column.DataTypeName.ToLowerInvariant()) + switch (typeName) { // Types with length - case "char": - case "nchar": - case "varchar": - case "nvarchar": - case "binary": - case "varbinary": + case "CHAR": + case "NCHAR": + case "VARCHAR": + case "NVARCHAR": + case "BINARY": + case "VARBINARY": if (!column.ColumnSize.HasValue) { throw new InvalidOperationException(SR.SqlScriptFormatterLengthTypeMissingSize); @@ -105,8 +118,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters break; // Types with precision and scale - case "numeric": - case "decimal": + case "NUMERIC": + case "DECIMAL": if (!column.NumericPrecision.HasValue || !column.NumericScale.HasValue) { throw new InvalidOperationException(SR.SqlScriptFormatterDecimalMissingPrecision); @@ -115,9 +128,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters break; // Types with scale only - case "datetime2": - case "datetimeoffset": - case "time": + case "DATETIME2": + case "DATETIMEOFFSET": + case "TIME": if (!column.NumericScale.HasValue) { throw new InvalidOperationException(SR.SqlScriptFormatterScalarTypeMissingScale); diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/Common.cs index c5a8b30d..aacadad8 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/Common.cs @@ -8,6 +8,7 @@ using System.Data.Common; using System.Linq; using System.Threading; using System.Threading.Tasks; +using Castle.Components.DictionaryAdapter; using Microsoft.SqlTools.ServiceLayer.EditData; using Microsoft.SqlTools.ServiceLayer.EditData.Contracts; using Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement; diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowCreateTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowCreateTests.cs index 668626f4..a9185a3c 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowCreateTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowCreateTests.cs @@ -306,45 +306,84 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData private static void ValidateCommandAgainstRegex(string sql, RegexExpectedOutput expectedOutput) { + // Break the query into parts + string[] splitSql = sql.Split(Environment.NewLine); + Assert.Equal(3, splitSql.Length); + + // Check the declare statement first + Regex declareRegex = new Regex(@"^DECLARE @(.+) TABLE \((.+)\)$"); + Match declareMatch = declareRegex.Match(splitSql[0]); + Assert.True(declareMatch.Success); + + // Declared table name matches + Assert.True(declareMatch.Groups[1].Value.StartsWith("Insert")); + Assert.True(declareMatch.Groups[1].Value.EndsWith("Output")); + + // Correct number of columns in declared table + string[] declareCols = declareMatch.Groups[2].Value.Split(", "); + Assert.Equal(expectedOutput.ExpectedOutColumns, declareCols.Length); + + // Check the insert statement in the middle if (expectedOutput.ExpectedInColumns == 0 || expectedOutput.ExpectedInValues == 0) { // If expected output was null make sure we match the default values reges - Regex r = new Regex(@"INSERT INTO (.+) OUTPUT (.+) DEFAULT VALUES"); - Match m = r.Match(sql); - Assert.True(m.Success); + Regex insertRegex = new Regex(@"^INSERT INTO (.+) OUTPUT (.+) INTO @(.+) DEFAULT VALUES$"); + Match insertMatch = insertRegex.Match(splitSql[1]); + Assert.True(insertMatch.Success); // Table name matches - Assert.Equal(Common.TableName, m.Groups[1].Value); + Assert.Equal(Common.TableName, insertMatch.Groups[1].Value); // Output columns match - string[] outCols = m.Groups[2].Value.Split(", "); + string[] outCols = insertMatch.Groups[2].Value.Split(", "); Assert.Equal(expectedOutput.ExpectedOutColumns, outCols.Length); Assert.All(outCols, col => Assert.StartsWith("inserted.", col)); + + // Output table name matches + Assert.StartsWith("Insert", insertMatch.Groups[3].Value); + Assert.EndsWith("Output", insertMatch.Groups[3].Value); } else { // Do the whole validation - Regex r = new Regex(@"INSERT INTO (.+)\((.+)\) OUTPUT (.+) VALUES \((.+)\)"); - Match m = r.Match(sql); - Assert.True(m.Success); + Regex insertRegex = new Regex(@"^INSERT INTO (.+)\((.+)\) OUTPUT (.+) INTO @(.+) VALUES \((.+)\)$"); + Match insertMatch = insertRegex.Match(splitSql[1]); + Assert.True(insertMatch.Success); // Table name matches - Assert.Equal(Common.TableName, m.Groups[1].Value); + Assert.Equal(Common.TableName, insertMatch.Groups[1].Value); // Output columns match - string[] outCols = m.Groups[3].Value.Split(", "); + string[] outCols = insertMatch.Groups[3].Value.Split(", "); Assert.Equal(expectedOutput.ExpectedOutColumns, outCols.Length); Assert.All(outCols, col => Assert.StartsWith("inserted.", col)); // In columns match - string[] inCols = m.Groups[2].Value.Split(", "); + string[] inCols = insertMatch.Groups[2].Value.Split(", "); Assert.Equal(expectedOutput.ExpectedInColumns, inCols.Length); + // Output table name matches + Assert.StartsWith("Insert", insertMatch.Groups[4].Value); + Assert.EndsWith("Output", insertMatch.Groups[4].Value); + // In values match - string[] inVals = m.Groups[4].Value.Split(", "); + string[] inVals = insertMatch.Groups[5].Value.Split(", "); Assert.Equal(expectedOutput.ExpectedInValues, inVals.Length); - Assert.All(inVals, val => Assert.Matches(@"@.+\d+", val)); + Assert.All(inVals, val => Assert.Matches(@"@.+\d+_\d+", val)); } + + // Check the select statement last + Regex selectRegex = new Regex(@"^SELECT (.+) FROM @(.+)$"); + Match selectMatch = selectRegex.Match(splitSql[2]); + Assert.True(selectMatch.Success); + + // Correct number of columns in declared table + string[] selectCols = selectMatch.Groups[1].Value.Split(", "); + Assert.Equal(expectedOutput.ExpectedOutColumns, selectCols.Length); + + // Declared table name matches + Assert.True(selectMatch.Groups[2].Value.StartsWith("Insert")); + Assert.True(selectMatch.Groups[2].Value.EndsWith("Output")); } #endregion diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowUpdateTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowUpdateTests.cs index 71e81992..55c6a50e 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowUpdateTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowUpdateTests.cs @@ -38,84 +38,21 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData Assert.Equal(data.TableMetadata, rc.AssociatedObjectMetadata); } - [Fact] - public async Task SetCell() + #region SetCell Tests + + [Theory] + [InlineData(-1)] // Negative + [InlineData(3)] // At edge of acceptable values + [InlineData(100)] // Way too large value + public async Task SetCellOutOfRange(int columnId) { - // Setup: Create a row update + // Setup: Generate a row create RowUpdate ru = await GetStandardRowUpdate(); - // If: I set a cell that can be updated - EditUpdateCellResult eucr = ru.SetCell(0, "col1"); - - // Then: - // ... A edit cell was returned - Assert.NotNull(eucr); - Assert.NotNull(eucr.Cell); - - // ... The new value we provided should be returned - Assert.Equal("col1", eucr.Cell.DisplayValue); - Assert.False(eucr.Cell.IsNull); - - // ... The row is still dirty - Assert.True(eucr.IsRowDirty); - - // ... The cell should be dirty - Assert.True(eucr.Cell.IsDirty); - - // ... There should be a cell update in the cell list - Assert.Contains(0, ru.cellUpdates.Keys); - Assert.NotNull(ru.cellUpdates[0]); + // If: I attempt to set a cell on a column that is out of range, I should get an exception + Assert.Throws(() => ru.SetCell(columnId, string.Empty)); } - - [Fact] - public void SetCellHasCorrections() - { - // Setup: - // ... Generate a result set with a single binary column - DbColumn[] cols = - { - new TestDbColumn - { - DataType = typeof(byte[]), - DataTypeName = "binary" - } - }; - object[][] rows = { new object[]{new byte[] {0x00}}}; - var testResultSet = new TestResultSet(cols, rows); - var testReader = new TestDbDataReader(new[] { testResultSet }, false); - var rs = new ResultSet(0, 0, MemoryFileSystem.GetFileStreamFactory()); - rs.ReadResultToEnd(testReader, CancellationToken.None).Wait(); - - // ... Generate the metadata - var etm = Common.GetCustomEditTableMetadata(cols); - - // ... Create the row update - RowUpdate ru = new RowUpdate(0, rs, etm); - - // If: I set a cell in the newly created row to something that will be corrected - EditUpdateCellResult eucr = ru.SetCell(0, "1000"); - - // Then: - // ... A edit cell was returned - Assert.NotNull(eucr); - Assert.NotNull(eucr.Cell); - - // ... The value we used won't be returned - Assert.NotEmpty(eucr.Cell.DisplayValue); - Assert.NotEqual("1000", eucr.Cell.DisplayValue); - Assert.False(eucr.Cell.IsNull); - - // ... The cell should be dirty - Assert.True(eucr.Cell.IsDirty); - - // ... The row is still dirty - Assert.True(eucr.IsRowDirty); - - // ... There should be a cell update in the cell list - Assert.Contains(0, ru.cellUpdates.Keys); - Assert.NotNull(ru.cellUpdates[0]); - } - + [Fact] public async Task SetCellImplicitRevertTest() { @@ -189,6 +126,86 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData // TODO: Make sure that the script and command things will return null } + + [Fact] + public void SetCellHasCorrections() + { + // Setup: + // ... Generate a result set with a single binary column + DbColumn[] cols = + { + new TestDbColumn + { + DataType = typeof(byte[]), + DataTypeName = "binary" + } + }; + object[][] rows = { new object[]{new byte[] {0x00}}}; + var testResultSet = new TestResultSet(cols, rows); + var testReader = new TestDbDataReader(new[] { testResultSet }, false); + var rs = new ResultSet(0, 0, MemoryFileSystem.GetFileStreamFactory()); + rs.ReadResultToEnd(testReader, CancellationToken.None).Wait(); + + // ... Generate the metadata + var etm = Common.GetCustomEditTableMetadata(cols); + + // ... Create the row update + RowUpdate ru = new RowUpdate(0, rs, etm); + + // If: I set a cell in the newly created row to something that will be corrected + EditUpdateCellResult eucr = ru.SetCell(0, "1000"); + + // Then: + // ... A edit cell was returned + Assert.NotNull(eucr); + Assert.NotNull(eucr.Cell); + + // ... The value we used won't be returned + Assert.NotEmpty(eucr.Cell.DisplayValue); + Assert.NotEqual("1000", eucr.Cell.DisplayValue); + Assert.False(eucr.Cell.IsNull); + + // ... The cell should be dirty + Assert.True(eucr.Cell.IsDirty); + + // ... The row is still dirty + Assert.True(eucr.IsRowDirty); + + // ... There should be a cell update in the cell list + Assert.Contains(0, ru.cellUpdates.Keys); + Assert.NotNull(ru.cellUpdates[0]); + } + + [Fact] + public async Task SetCell() + { + // Setup: Create a row update + RowUpdate ru = await GetStandardRowUpdate(); + + // If: I set a cell that can be updated + EditUpdateCellResult eucr = ru.SetCell(0, "col1"); + + // Then: + // ... A edit cell was returned + Assert.NotNull(eucr); + Assert.NotNull(eucr.Cell); + + // ... The new value we provided should be returned + Assert.Equal("col1", eucr.Cell.DisplayValue); + Assert.False(eucr.Cell.IsNull); + + // ... The row is still dirty + Assert.True(eucr.IsRowDirty); + + // ... The cell should be dirty + Assert.True(eucr.Cell.IsDirty); + + // ... There should be a cell update in the cell list + Assert.Contains(0, ru.cellUpdates.Keys); + Assert.NotNull(ru.cellUpdates[0]); + } + + #endregion [Theory] [InlineData(true)] @@ -224,6 +241,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData Assert.Equal(3, updateSplit.Length); Assert.All(updateSplit, s => Assert.Equal(2, s.Split('=').Length)); } + + #region GetCommand Tests [Theory] [InlineData(true, true)] @@ -249,37 +268,66 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData // ... The command should not be null Assert.NotNull(cmd); + // ... Validate the command's makeup + // Break the query into parts + string[] splitSql = cmd.CommandText.Split(Environment.NewLine); + Assert.Equal(3, splitSql.Length); + + // Check the declare statement first + Regex declareRegex = new Regex(@"^DECLARE @(.+) TABLE \((.+)\)$"); + Match declareMatch = declareRegex.Match(splitSql[0]); + Assert.True(declareMatch.Success); + + // Declared table name matches + Assert.True(declareMatch.Groups[1].Value.StartsWith("Update")); + Assert.True(declareMatch.Groups[1].Value.EndsWith("Output")); + + // Correct number of columns in declared table + string[] declareCols = declareMatch.Groups[2].Value.Split(", "); + Assert.Equal(rs.Columns.Length, declareCols.Length); + + // Check the update statement in the middle + string regex = isMemoryOptimized + ? @"^UPDATE (.+) WITH \(SNAPSHOT\) SET (.+) OUTPUT (.+) INTO @(.+) WHERE .+$" + : @"^UPDATE (.+) SET (.+) OUTPUT (.+) INTO @(.+) WHERE .+$"; + Regex updateRegex = new Regex(regex); + Match updateMatch = updateRegex.Match(splitSql[1]); + Assert.True(updateMatch.Success); + + // Table name matches + Assert.Equal(Common.TableName, updateMatch.Groups[1].Value); + + // Output columns match + string[] outCols = updateMatch.Groups[3].Value.Split(", "); + Assert.Equal(rs.Columns.Length, outCols.Length); + Assert.All(outCols, col => Assert.StartsWith("inserted.", col)); + + // Set columns match + string[] setCols = updateMatch.Groups[2].Value.Split(", "); + Assert.Equal(3, setCols.Length); + Assert.All(setCols, s => Assert.Matches(@".+ = @Value\d+_\d+", s)); + + // Output table name matches + Assert.StartsWith("Update", updateMatch.Groups[4].Value); + Assert.EndsWith("Output", updateMatch.Groups[4].Value); + + // Check the select statement last + Regex selectRegex = new Regex(@"^SELECT (.+) FROM @(.+)$"); + Match selectMatch = selectRegex.Match(splitSql[2]); + Assert.True(selectMatch.Success); + + // Correct number of columns in select statement + string[] selectCols = selectMatch.Groups[1].Value.Split(", "); + Assert.Equal(rs.Columns.Length, selectCols.Length); + + // Select table name matches + Assert.StartsWith("Update", selectMatch.Groups[2].Value); + Assert.EndsWith("Output", selectMatch.Groups[2].Value); + // ... There should be an appropriate number of parameters in it // (1 or 3 keys, 3 value parameters) int expectedKeys = includeIdentity ? 1 : 3; Assert.Equal(expectedKeys + 3, cmd.Parameters.Count); - - // ... It should be formatted into an update script with output - string regexFormat = isMemoryOptimized - ? @"UPDATE (.+) WITH \(SNAPSHOT\) SET (.+) OUTPUT (.+) WHERE (.+)" - : @"UPDATE (.+) SET (.+) OUTPUT(.+) WHERE (.+)"; - Regex r = new Regex(regexFormat); - var m = r.Match(cmd.CommandText); - Assert.True(m.Success); - - // ... There should be a table - string tbl = m.Groups[1].Value; - Assert.Equal(data.TableMetadata.EscapedMultipartName, tbl); - - // ... There should be 3 parameters for input - string[] inCols = m.Groups[2].Value.Split(','); - Assert.Equal(3, inCols.Length); - Assert.All(inCols, s => Assert.Matches(@"\[.+\] = @Value\d+", s)); - - // ... There should be 3 OR 4 columns for output - string[] outCols = m.Groups[3].Value.Split(','); - Assert.Equal(includeIdentity ? 4 : 3, outCols.Length); - Assert.All(outCols, s => Assert.StartsWith("inserted.", s.Trim())); - - // ... There should be 1 OR 3 columns for where components - string[] whereComponents = m.Groups[4].Value.Split(new[] {"AND"}, StringSplitOptions.None); - Assert.Equal(expectedKeys, whereComponents.Length); - Assert.All(whereComponents, s => Assert.Matches(@"\(.+ = @Param\d+\)", s)); } [Fact] @@ -292,7 +340,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData // Then: It should throw an exception Assert.Throws(() => ru.GetCommand(null)); } + + #endregion + #region GetEditRow Tests + [Fact] public async Task GetEditRow() { @@ -344,6 +396,10 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData Assert.Throws(() => ru.GetEditRow(null)); } + #endregion + + #region ApplyChanges Tests + [Theory] [InlineData(true)] [InlineData(false)] @@ -382,6 +438,10 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData await Assert.ThrowsAsync(() => ru.ApplyChanges(null)); } + #endregion + + #region RevertCell Tests + [Theory] [InlineData(-1)] // Negative [InlineData(3)] // At edge of acceptable values @@ -485,6 +545,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData // ... The cell should no longer be set Assert.DoesNotContain(0, ru.cellUpdates.Keys); } + + #endregion private async Task GetStandardRowUpdate() { diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/UtilityTests/ToSqlScriptTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/UtilityTests/ToSqlScriptTests.cs index 4c18f515..bf011230 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/UtilityTests/ToSqlScriptTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/UtilityTests/ToSqlScriptTests.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.Data.Common; using System.Text.RegularExpressions; @@ -312,60 +312,61 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.UtilityTests { get { - yield return new object[] {new FormatterTestDbColumn("biGint"), "BIGINT"}; - yield return new object[] {new FormatterTestDbColumn("biT"), "BIT"}; - yield return new object[] {new FormatterTestDbColumn("deCimal", precision: 18, scale: 0), "DECIMAL(18, 0)"}; - yield return new object[] {new FormatterTestDbColumn("deCimal", precision: 22, scale: 2), "DECIMAL(22, 2)"}; - yield return new object[] {new FormatterTestDbColumn("inT"), "INT"}; - yield return new object[] {new FormatterTestDbColumn("moNey"), "MONEY"}; - yield return new object[] {new FormatterTestDbColumn("nuMeric", precision: 18, scale: 0), "NUMERIC(18, 0)"}; - yield return new object[] {new FormatterTestDbColumn("nuMeric", precision: 22, scale: 2), "NUMERIC(22, 2)"}; - yield return new object[] {new FormatterTestDbColumn("smAllint"), "SMALLINT"}; - yield return new object[] {new FormatterTestDbColumn("smAllmoney"), "SMALLMONEY"}; - yield return new object[] {new FormatterTestDbColumn("tiNyint"), "TINYINT"}; - yield return new object[] {new FormatterTestDbColumn("biNary", size: 255), "BINARY(255)"}; - yield return new object[] {new FormatterTestDbColumn("biNary", size: 10), "BINARY(10)"}; - yield return new object[] {new FormatterTestDbColumn("vaRbinary", size: 255), "VARBINARY(255)"}; - yield return new object[] {new FormatterTestDbColumn("vaRbinary", size: 10), "VARBINARY(10)"}; - yield return new object[] {new FormatterTestDbColumn("vaRbinary", size: int.MaxValue), "VARBINARY(MAX)"}; - yield return new object[] {new FormatterTestDbColumn("imAge"), "IMAGE"}; - yield return new object[] {new FormatterTestDbColumn("smAlldatetime"), "SMALLDATETIME"}; - yield return new object[] {new FormatterTestDbColumn("daTetime"), "DATETIME"}; - yield return new object[] {new FormatterTestDbColumn("daTetime2", scale: 7), "DATETIME2(7)"}; - yield return new object[] {new FormatterTestDbColumn("daTetime2", scale: 0), "DATETIME2(0)"}; - yield return new object[] {new FormatterTestDbColumn("daTetimeoffset", scale: 7), "DATETIMEOFFSET(7)"}; - yield return new object[] {new FormatterTestDbColumn("daTetimeoffset", scale: 0), "DATETIMEOFFSET(0)"}; - yield return new object[] {new FormatterTestDbColumn("tiMe", scale: 7), "TIME(7)"}; - yield return new object[] {new FormatterTestDbColumn("flOat"), "FLOAT"}; - yield return new object[] {new FormatterTestDbColumn("reAl"), "REAL"}; - yield return new object[] {new FormatterTestDbColumn("chAr", size: 1), "CHAR(1)"}; - yield return new object[] {new FormatterTestDbColumn("chAr", size: 255), "CHAR(255)"}; - yield return new object[] {new FormatterTestDbColumn("ncHar", size: 1), "NCHAR(1)"}; - yield return new object[] {new FormatterTestDbColumn("ncHar", size: 255), "NCHAR(255)"}; - yield return new object[] {new FormatterTestDbColumn("vaRchar", size: 1), "VARCHAR(1)"}; - yield return new object[] {new FormatterTestDbColumn("vaRchar", size: 255), "VARCHAR(255)"}; - yield return new object[] {new FormatterTestDbColumn("vaRchar", size: int.MaxValue), "VARCHAR(MAX)"}; - yield return new object[] {new FormatterTestDbColumn("nvArchar", size: 1), "NVARCHAR(1)"}; - yield return new object[] {new FormatterTestDbColumn("nvArchar", size: 255), "NVARCHAR(255)"}; - yield return new object[] {new FormatterTestDbColumn("nvArchar", size: int.MaxValue), "NVARCHAR(MAX)"}; - yield return new object[] {new FormatterTestDbColumn("teXt"), "TEXT"}; - yield return new object[] {new FormatterTestDbColumn("nteXt"), "NTEXT"}; - yield return new object[] {new FormatterTestDbColumn("unIqueidentifier"), "UNIQUEIDENTIFIER"}; - yield return new object[] {new FormatterTestDbColumn("sqL_variant"), "SQL_VARIANT"}; - yield return new object[] {new FormatterTestDbColumn("somEthing.sys.hierarchyid"), "SOMETHING.SYS.HIERARCHYID"}; - yield return new object[] {new FormatterTestDbColumn("geOgraphy"), "GEOGRAPHY"}; - yield return new object[] {new FormatterTestDbColumn("geOmetry"), "GEOMETRY"}; - yield return new object[] {new FormatterTestDbColumn("sySname"), "SYSNAME"}; - yield return new object[] {new FormatterTestDbColumn("tiMestamp"), "TIMESTAMP"}; + yield return new object[] {false, new FormatterTestDbColumn("biGint"), "BIGINT"}; + yield return new object[] {false, new FormatterTestDbColumn("biT"), "BIT"}; + yield return new object[] {false, new FormatterTestDbColumn("deCimal", precision: 18, scale: 0), "DECIMAL(18, 0)"}; + yield return new object[] {false, new FormatterTestDbColumn("deCimal", precision: 22, scale: 2), "DECIMAL(22, 2)"}; + yield return new object[] {false, new FormatterTestDbColumn("inT"), "INT"}; + yield return new object[] {false, new FormatterTestDbColumn("moNey"), "MONEY"}; + yield return new object[] {false, new FormatterTestDbColumn("nuMeric", precision: 18, scale: 0), "NUMERIC(18, 0)"}; + yield return new object[] {false, new FormatterTestDbColumn("nuMeric", precision: 22, scale: 2), "NUMERIC(22, 2)"}; + yield return new object[] {false, new FormatterTestDbColumn("smAllint"), "SMALLINT"}; + yield return new object[] {false, new FormatterTestDbColumn("smAllmoney"), "SMALLMONEY"}; + yield return new object[] {false, new FormatterTestDbColumn("tiNyint"), "TINYINT"}; + yield return new object[] {false, new FormatterTestDbColumn("biNary", size: 255), "BINARY(255)"}; + yield return new object[] {false, new FormatterTestDbColumn("biNary", size: 10), "BINARY(10)"}; + yield return new object[] {false, new FormatterTestDbColumn("vaRbinary", size: 255), "VARBINARY(255)"}; + yield return new object[] {false, new FormatterTestDbColumn("vaRbinary", size: 10), "VARBINARY(10)"}; + yield return new object[] {false, new FormatterTestDbColumn("vaRbinary", size: int.MaxValue), "VARBINARY(MAX)"}; + yield return new object[] {false, new FormatterTestDbColumn("imAge"), "IMAGE"}; + yield return new object[] {false, new FormatterTestDbColumn("smAlldatetime"), "SMALLDATETIME"}; + yield return new object[] {false, new FormatterTestDbColumn("daTetime"), "DATETIME"}; + yield return new object[] {false, new FormatterTestDbColumn("daTetime2", scale: 7), "DATETIME2(7)"}; + yield return new object[] {false, new FormatterTestDbColumn("daTetime2", scale: 0), "DATETIME2(0)"}; + yield return new object[] {false, new FormatterTestDbColumn("daTetimeoffset", scale: 7), "DATETIMEOFFSET(7)"}; + yield return new object[] {false, new FormatterTestDbColumn("daTetimeoffset", scale: 0), "DATETIMEOFFSET(0)"}; + yield return new object[] {false, new FormatterTestDbColumn("tiMe", scale: 7), "TIME(7)"}; + yield return new object[] {false, new FormatterTestDbColumn("flOat"), "FLOAT"}; + yield return new object[] {false, new FormatterTestDbColumn("reAl"), "REAL"}; + yield return new object[] {false, new FormatterTestDbColumn("chAr", size: 1), "CHAR(1)"}; + yield return new object[] {false, new FormatterTestDbColumn("chAr", size: 255), "CHAR(255)"}; + yield return new object[] {false, new FormatterTestDbColumn("ncHar", size: 1), "NCHAR(1)"}; + yield return new object[] {false, new FormatterTestDbColumn("ncHar", size: 255), "NCHAR(255)"}; + yield return new object[] {false, new FormatterTestDbColumn("vaRchar", size: 1), "VARCHAR(1)"}; + yield return new object[] {false, new FormatterTestDbColumn("vaRchar", size: 255), "VARCHAR(255)"}; + yield return new object[] {false, new FormatterTestDbColumn("vaRchar", size: int.MaxValue), "VARCHAR(MAX)"}; + yield return new object[] {false, new FormatterTestDbColumn("nvArchar", size: 1), "NVARCHAR(1)"}; + yield return new object[] {false, new FormatterTestDbColumn("nvArchar", size: 255), "NVARCHAR(255)"}; + yield return new object[] {false, new FormatterTestDbColumn("nvArchar", size: int.MaxValue), "NVARCHAR(MAX)"}; + yield return new object[] {false, new FormatterTestDbColumn("teXt"), "TEXT"}; + yield return new object[] {false, new FormatterTestDbColumn("nteXt"), "NTEXT"}; + yield return new object[] {false, new FormatterTestDbColumn("unIqueidentifier"), "UNIQUEIDENTIFIER"}; + yield return new object[] {false, new FormatterTestDbColumn("sqL_variant"), "SQL_VARIANT"}; + yield return new object[] {false, new FormatterTestDbColumn("somEthing.sys.hierarchyid"), "HIERARCHYID"}; + yield return new object[] {false, new FormatterTestDbColumn("table.geOgraphy"), "GEOGRAPHY"}; + yield return new object[] {false, new FormatterTestDbColumn("table.geOmetry"), "GEOMETRY"}; + yield return new object[] {false, new FormatterTestDbColumn("sySname"), "SYSNAME"}; + yield return new object[] {false, new FormatterTestDbColumn("tiMestamp"), "TIMESTAMP"}; + yield return new object[] {true, new FormatterTestDbColumn("tiMestamp"), "VARBINARY(8)"}; } } [Theory] [MemberData(nameof(FormatColumnTypeData))] - public void FormatColumnType(DbColumn input, string expectedOutput) + public void FormatColumnType(bool useSemanticEquivalent, DbColumn input, string expectedOutput) { // If: I supply the input columns - string output = ToSqlScript.FormatColumnType(input); + string output = ToSqlScript.FormatColumnType(input, useSemanticEquivalent); // Then: The output should match the expected output Assert.Equal(expectedOutput, output);