diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowDelete.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowDelete.cs index 7164a21e..2b074348 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowDelete.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowDelete.cs @@ -89,9 +89,9 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement WhereClause where = GetWhereClause(true); string commandText = GetCommandText(where.CommandText); string verifyText = GetVerifyText(where.CommandText); - if (!CheckForDuplicateDeleteRows(where, verifyText, connection)) + if (HasDuplicateRows(where, verifyText, connection)) { - throw new EditDataDeleteException("This action will delete more than one row!"); + throw new EditDataDeleteException("Cannot delete: Action will delete more than one row"); } DbCommand command = connection.CreateCommand(); @@ -103,34 +103,23 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement /// /// Runs a query using the where clause to determine if duplicates are found (causes issues when deleting). - /// If no duplicates are found, the check passes, else it returns false; + /// If duplicates are found, the check returns true, else it returns false; /// - private bool CheckForDuplicateDeleteRows(WhereClause where, string input, DbConnection connection) + private bool HasDuplicateRows(WhereClause where, string input, DbConnection connection) { using (DbCommand command = connection.CreateCommand()) { command.CommandText = input; command.Parameters.AddRange(where.Parameters.ToArray()); - using (DbDataReader reader = command.ExecuteReader()) + try { - try - { - while (reader.Read()) - { - //If the count of the row is - if (reader.GetInt32(0) != 1) - { - return false; - } - } - } - catch (Exception ex) - { - Logger.Write(TraceEventType.Error, ex.ToString()); - } + return (Convert.ToInt32(command.ExecuteScalar())) > 1; + } + finally + { + command.Parameters.Clear(); } } - return true; } /// diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowDeleteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowDeleteTests.cs index 75f53673..36143312 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowDeleteTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowDeleteTests.cs @@ -83,7 +83,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData RowDelete rd = new RowDelete(0, rs, data.TableMetadata); // ... Mock db connection for building the command - var mockConn = new TestSqlConnection(null); + var mockConn = new TestEditDataSqlConnection(null); // If: I attempt to get a command for the edit DbCommand cmd = rd.GetCommand(mockConn); @@ -235,7 +235,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData Assert.AreEqual(expectedKeys, whereComponents.Length); // ... Mock db connection for building the command - var mockConn = new TestSqlConnection(new[] { testResultSet }); + var mockConn = new TestEditDataSqlConnection(new[] { testResultSet }); // If: I attempt to get a command for a simulated delete of a row with duplicates. // Then: The Command will throw an exception as it detects there are diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs index f4630cd4..3997983a 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs @@ -190,6 +190,63 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility private List listParams = new List(); } + /// + /// Test mock class for IDbCommand (Modified for Edit Data) + /// + public class TestEditDataSqlCommand : TestSqlCommand + { + internal TestEditDataSqlCommand(TestResultSet[] data) : base(data) + { + + } + + /// + /// Function to check for duplicates in Data's rows. + /// Returns the max duplicate count for a unique row value (int). + /// + public override object ExecuteScalar() + { + if (Data != null) + { + //Get row data and set up row count map. + object[] rowData = Data[0].Rows[0]; + Dictionary rowCountMap = new Dictionary(); + + //Go through each row value. + foreach (object rowValue in rowData) + { + if (rowCountMap.ContainsKey(rowValue)) + { + // Add to existing count + rowCountMap[rowValue] += 1; + } + else + { + // New unique value found, add to map with 1 count. + rowCountMap.Add(rowValue, 1); + } + } + + // Find the greatest number of duplicates among unique values + // in the map and return it. + int maxCount = 0; + foreach (var rowCount in rowCountMap) + { + if (rowCount.Value > maxCount) + { + maxCount = rowCount.Value; + } + } + return maxCount; + } + else + { + // Return 0 if Data is not provided. + return 0; + } + } + } + /// /// Test mock class for SqlConnection wrapper /// @@ -270,6 +327,27 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility } } + /// + /// Test mock class for SqlConnection wrapper (Modified for Edit Data) + /// + public class TestEditDataSqlConnection : TestSqlConnection + { + public TestEditDataSqlConnection() + { + + } + + public TestEditDataSqlConnection(TestResultSet[] data) : base(data) + { + + } + + protected override DbCommand CreateDbCommand() + { + return new TestEditDataSqlCommand(Data); + } + } + /// /// Test mock class for SqlConnection factory ///