diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowUpdate.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowUpdate.cs index 62e22fb4..c5e5888e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowUpdate.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/UpdateManagement/RowUpdate.cs @@ -3,11 +3,13 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Data; using System.Data.Common; using System.Data.SqlClient; +using System.Globalization; using System.Linq; using System.Text; using System.Threading.Tasks; @@ -30,6 +32,14 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement private const string UpdateScript = "UPDATE {0} SET {1} {2}"; private const string UpdateScriptMemOptimized = "UPDATE {0} WITH (SNAPSHOT) SET {1} {2}"; private const string SelectStatement = "SELECT {0} FROM {1}"; + private string validateUpdateOnlyOneRow = "DECLARE @numberOfRows int = 0;" + Environment.NewLine + + "Select @numberOfRows = count(*) FROM {0} {1} " + Environment.NewLine + + "IF (@numberOfRows > 1) " + Environment.NewLine + + "Begin" + Environment.NewLine + + " DECLARE @error NVARCHAR(100) = N'The row value(s) updated do not make the row unique or they alter multiple rows(' + CAST(@numberOfRows as varchar(10)) + ' rows)';" + Environment.NewLine + + " RAISERROR (@error, 16, 1) " + Environment.NewLine + + "End" + Environment.NewLine + + "ELSE BEGIN" + Environment.NewLine; internal readonly ConcurrentDictionary cellUpdates; private readonly IList associatedRow; @@ -77,7 +87,7 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement public override DbCommand GetCommand(DbConnection connection) { Validate.IsNotNull(nameof(connection), connection); - + // Process the cells and columns List declareColumns = new List(); List inParameters = new List(); @@ -120,6 +130,11 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement string.Join(", ", outClauseColumns), tempTableName, whereClause.CommandText); + + + string validateScript = string.Format(CultureInfo.InvariantCulture, validateUpdateOnlyOneRow, + AssociatedObjectMetadata.EscapedMultipartName, + whereClause.CommandText); // Step 3) Build the select statement string selectStatement = string.Format(SelectStatement, string.Join(", ", selectColumns), tempTableName); @@ -127,8 +142,10 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement // Step 4) Put it all together into a results object StringBuilder query = new StringBuilder(); query.AppendLine(declareStatement); + query.AppendLine(validateScript); query.AppendLine(updateStatement); - query.Append(selectStatement); + query.AppendLine(selectStatement); + query.Append("END"); // Build the command DbCommand command = connection.CreateCommand(); diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowUpdateTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowUpdateTests.cs index 55c6a50e..09a6a8fc 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowUpdateTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/EditData/RowUpdateTests.cs @@ -271,7 +271,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData // ... Validate the command's makeup // Break the query into parts string[] splitSql = cmd.CommandText.Split(Environment.NewLine); - Assert.Equal(3, splitSql.Length); + Assert.True(splitSql.Length >= 3); // Check the declare statement first Regex declareRegex = new Regex(@"^DECLARE @(.+) TABLE \((.+)\)$"); @@ -291,7 +291,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData ? @"^UPDATE (.+) WITH \(SNAPSHOT\) SET (.+) OUTPUT (.+) INTO @(.+) WHERE .+$" : @"^UPDATE (.+) SET (.+) OUTPUT (.+) INTO @(.+) WHERE .+$"; Regex updateRegex = new Regex(regex); - Match updateMatch = updateRegex.Match(splitSql[1]); + Match updateMatch = updateRegex.Match(splitSql[10]); Assert.True(updateMatch.Success); // Table name matches @@ -313,7 +313,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.EditData // Check the select statement last Regex selectRegex = new Regex(@"^SELECT (.+) FROM @(.+)$"); - Match selectMatch = selectRegex.Match(splitSql[2]); + Match selectMatch = selectRegex.Match(splitSql[11]); Assert.True(selectMatch.Success); // Correct number of columns in select statement