diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/EditSession.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/EditSession.cs
index 91206df8..1f1d7fc3 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/EditData/EditSession.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/EditSession.cs
@@ -516,12 +516,21 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
editOperations.Sort();
foreach (var editOperation in editOperations)
{
- // Get the command from the edit operation and execute it
- using (DbCommand editCommand = editOperation.GetCommand(connection))
- using (DbDataReader reader = await editCommand.ExecuteReaderAsync())
+ try
{
- // Apply the changes of the command to the result set
- await editOperation.ApplyChanges(reader);
+ // Get the command from the edit operation and execute it
+ using (DbCommand editCommand = editOperation.GetCommand(connection))
+ using (DbDataReader reader = await editCommand.ExecuteReaderAsync())
+ {
+ // Apply the changes of the command to the result set
+ await editOperation.ApplyChanges(reader);
+ }
+ }
+ catch (EditDataDeleteException)
+ {
+ //clear EditCache to allow for deletion of other rows.
+ EditCache.TryRemove(editOperation.RowId, out RowEditBase xe);
+ throw;
}
// If we succeeded in applying the changes, then remove this from the cache
diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowDelete.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowDelete.cs
index 3bbed902..7164a21e 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowDelete.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowDelete.cs
@@ -5,6 +5,7 @@
using System;
using System.Data.Common;
+using System.Diagnostics;
using System.Globalization;
using System.Linq;
using System.Threading.Tasks;
@@ -15,6 +16,26 @@ using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
{
+ ///
+ /// An error indicating that a delete action will delete multiple rows.
+ ///
+ public class EditDataDeleteException : Exception
+ {
+ public EditDataDeleteException()
+ {
+ }
+
+ public EditDataDeleteException(string message)
+ : base(message)
+ {
+ }
+
+ public EditDataDeleteException(string message, Exception inner)
+ : base(message, inner)
+ {
+ }
+ }
+
///
/// Represents a row that should be deleted. This will generate a DELETE statement
///
@@ -22,6 +43,7 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
{
private const string DeleteStatement = "DELETE FROM {0} {1}";
private const string DeleteMemoryOptimizedStatement = "DELETE FROM {0} WITH(SNAPSHOT) {1}";
+ private const string VerifyStatement = "SELECT COUNT (*) FROM ";
///
/// Constructs a new RowDelete object
@@ -66,6 +88,11 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
// Return a SqlCommand with formatted with the parameters from the where clause
WhereClause where = GetWhereClause(true);
string commandText = GetCommandText(where.CommandText);
+ string verifyText = GetVerifyText(where.CommandText);
+ if (!CheckForDuplicateDeleteRows(where, verifyText, connection))
+ {
+ throw new EditDataDeleteException("This action will delete more than one row!");
+ }
DbCommand command = connection.CreateCommand();
command.CommandText = commandText;
@@ -74,6 +101,38 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
return command;
}
+ ///
+ /// Runs a query using the where clause to determine if duplicates are found (causes issues when deleting).
+ /// If no duplicates are found, the check passes, else it returns false;
+ ///
+ private bool CheckForDuplicateDeleteRows(WhereClause where, string input, DbConnection connection)
+ {
+ using (DbCommand command = connection.CreateCommand())
+ {
+ command.CommandText = input;
+ command.Parameters.AddRange(where.Parameters.ToArray());
+ using (DbDataReader reader = command.ExecuteReader())
+ {
+ try
+ {
+ while (reader.Read())
+ {
+ //If the count of the row is
+ if (reader.GetInt32(0) != 1)
+ {
+ return false;
+ }
+ }
+ }
+ catch (Exception ex)
+ {
+ Logger.Write(TraceEventType.Error, ex.ToString());
+ }
+ }
+ }
+ return true;
+ }
+
///
/// Generates a edit row that represents a row pending deletion. All the original cells are
/// intact but the state is dirty.
@@ -101,6 +160,15 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
return GetCommandText(GetWhereClause(false).CommandText);
}
+ ///
+ /// Generates a WHERE statement to verify the row delete is unique.
+ ///
+ /// String of the WHERE statement
+ public string GetVerifyScript()
+ {
+ return GetVerifyText(GetWhereClause(false).CommandText);
+ }
+
///
/// This method should not be called. A cell cannot be reverted on a row that is pending
/// deletion.
@@ -131,6 +199,11 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
return RowId.CompareTo(rowEdit.RowId) * -1;
}
+ private string GetVerifyText(string whereText)
+ {
+ return $"{VerifyStatement}{AssociatedObjectMetadata.EscapedMultipartName} {whereText}";
+ }
+
private string GetCommandText(string whereText)
{
string formatString = AssociatedObjectMetadata.IsMemoryOptimized
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowDeleteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowDeleteTests.cs
index 8784e9e4..5da6dcbd 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowDeleteTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowDeleteTests.cs
@@ -193,6 +193,56 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData
Assert.Throws(() => rd.RevertCell(0));
}
+ [Fact]
+ public async Task GetVerifyQuery()
+ {
+ // Setup: Create a row update and set the first row cell to have values
+ // ... other than "1" for testing purposes (simulated select query result).
+ Common.TestDbColumnsWithTableMetadata data = new Common.TestDbColumnsWithTableMetadata(false, false, 0, 0);
+ var rs = await Common.GetResultSet(data.DbColumns, false);
+ RowUpdate ru = new RowUpdate(0, rs, data.TableMetadata);
+ object[][] rows =
+ {
+ new object[] {"2", "0", "0"},
+ };
+ var testResultSet = new TestResultSet(data.DbColumns, rows);
+ var newRowReader = new TestDbDataReader(new[] { testResultSet }, false);
+ await ru.ApplyChanges(newRowReader);
+
+ // ... Create a row delete.
+ RowDelete rd = new RowDelete(0, rs, data.TableMetadata);
+ int expectedKeys = 3;
+
+ // If: I generate a verify command
+ String verifyCommand = rd.GetVerifyScript();
+
+ // Then:
+ // ... The command should not be null
+ Assert.NotNull(verifyCommand);
+
+ // ... It should be formatted into an where script
+ string regexTest = @"SELECT COUNT \(\*\) FROM (.+) WHERE (.+)";
+ Regex r = new Regex(regexTest);
+ var m = r.Match(verifyCommand);
+ Assert.True(m.Success);
+
+ // ... There should be a table
+ string tbl = m.Groups[1].Value;
+ Assert.Equal(data.TableMetadata.EscapedMultipartName, tbl);
+
+ // ... There should be as many where components as there are keys
+ string[] whereComponents = m.Groups[2].Value.Split(new[] { "AND" }, StringSplitOptions.None);
+ Assert.Equal(expectedKeys, whereComponents.Length);
+
+ // ... Mock db connection for building the command
+ var mockConn = new TestSqlConnection(new[] { testResultSet });
+
+ // If: I attempt to get a command for a simulated delete of a row with duplicates.
+ // Then: The Command will throw an exception as it detects there are
+ // ... 2 or more rows with the same value in the simulated query results data.
+ Assert.Throws(() => rd.GetCommand(mockConn));
+ }
+
private async Task GetStandardRowDelete()
{
Common.TestDbColumnsWithTableMetadata data = new Common.TestDbColumnsWithTableMetadata(false, false, 0, 0);
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestDbDataReader.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestDbDataReader.cs
index 1cc989cd..c517f198 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestDbDataReader.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestDbDataReader.cs
@@ -219,7 +219,13 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
public override int GetInt32(int ordinal)
{
- throw new NotImplementedException();
+ string allChars = ((string) RowEnumerator.Current[ordinal]);
+ int x = 0;
+ if(allChars.Length != 1 || !Int32.TryParse(allChars.ToString(), out x) )
+ {
+ throw new InvalidCastException();
+ }
+ return x;
}
public override short GetInt16(int ordinal)