edit/commit Command (#262)

The main goal of this feature is to enable a command that will
1) Generate a parameterized command for each edit that is in the session
2) Execute that command against the server
3) Update the cached results of the table/view that's being edited with the committed changes (including computed/identity columns)

There's some secret sauce in here where I cheated around worrying about gaps in the updated results. This was accomplished by implementing an IComparable for row edit objects that ensures deletes are the *last* actions to occur and that they occur from the bottom of the list up (highest row ID to lowest). Thus, all other actions that are dependent on the row ID are performed first, then the largest row ID is deleted, then next largest, etc. Nevertheless, by the end of a commit the associated ResultSet is still the source of truth. It is expected that the results grid will need updating once changes are committed.

Also worth noting, although this pull request supports a "many edits, one commit" approach, it will work just fine for a "one edit, one commit" approach.

* WIP

* Adding basic commit support. Deletions work!

* Nailing down the commit logic, insert commits work!

* Updates work!

* Fixing bug in DbColumnWrapper IsReadOnly setting

* Comments

* ResultSet unit tests, fixing issue with seeking in mock writers

* Unit tests for RowCreate commands

* Unit tests for RowDelete

* RowUpdate unit tests

* Session and edit base tests

* Fixing broken unit tests

* Moving constants to constants file

* Addressing code review feedback

* Fixes from merge issues, string consts

* Removing ad-hoc code

* fixing as per @abist requests

* Fixing a couple more issues
This commit is contained in:
Benjamin Russell
2017-03-03 15:47:47 -08:00
committed by GitHub
parent f00136cffb
commit 52ac038ebe
44 changed files with 2546 additions and 2464 deletions

View File

@@ -0,0 +1,30 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.EditData.Contracts
{
/// <summary>
/// Parameters for a request to commit pending edit operations
/// </summary>
public class EditCommitParams : SessionOperationParams
{
}
/// <summary>
/// Parameters to return upon successful completion of commiting pending edit operations
/// </summary>
public class EditCommitResult
{
}
public class EditCommitRequest
{
public static readonly
RequestType<EditCommitParams, EditCommitResult> Type =
RequestType<EditCommitParams, EditCommitResult>.Create("edit/commit");
}
}

View File

@@ -54,8 +54,8 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
private readonly QueryExecutionService queryExecutionService;
private readonly Lazy<ConcurrentDictionary<string, Session>> editSessions = new Lazy<ConcurrentDictionary<string, Session>>(
() => new ConcurrentDictionary<string, Session>());
private readonly Lazy<ConcurrentDictionary<string, EditSession>> editSessions = new Lazy<ConcurrentDictionary<string, EditSession>>(
() => new ConcurrentDictionary<string, EditSession>());
#endregion
@@ -64,7 +64,7 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
/// <summary>
/// Dictionary mapping OwnerURIs to active sessions
/// </summary>
internal ConcurrentDictionary<string, Session> ActiveSessions => editSessions.Value;
internal ConcurrentDictionary<string, EditSession> ActiveSessions => editSessions.Value;
#endregion
@@ -86,14 +86,14 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
#region Request Handlers
internal async Task HandleSessionRequest<TResult>(SessionOperationParams sessionParams,
RequestContext<TResult> requestContext, Func<Session, TResult> sessionOperation)
RequestContext<TResult> requestContext, Func<EditSession, TResult> sessionOperation)
{
try
{
Session session = GetActiveSessionOrThrow(sessionParams.OwnerUri);
EditSession editSession = GetActiveSessionOrThrow(sessionParams.OwnerUri);
// Get the result from execution of the session operation
TResult result = sessionOperation(session);
// Get the result from execution of the editSession operation
TResult result = sessionOperation(editSession);
await requestContext.SendResult(result);
}
catch (Exception e)
@@ -135,9 +135,9 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
// Sanity check the owner URI
Validate.IsNotNullOrWhitespaceString(nameof(disposeParams.OwnerUri), disposeParams.OwnerUri);
// Attempt to remove the session
Session session;
if (!ActiveSessions.TryRemove(disposeParams.OwnerUri, out session))
// Attempt to remove the editSession
EditSession editSession;
if (!ActiveSessions.TryRemove(disposeParams.OwnerUri, out editSession))
{
await requestContext.SendError(SR.EditDataSessionNotFound);
return;
@@ -219,6 +219,31 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
session => session.UpdateCell(updateParams.RowId, updateParams.ColumnId, updateParams.NewValue));
}
internal async Task HandleCommitRequest(EditCommitParams commitParams,
RequestContext<EditCommitResult> requestContext)
{
// Setup a callback for if the edits have been successfully written to the db
Func<Task> successHandler = () => requestContext.SendResult(new EditCommitResult());
// Setup a callback for if the edits failed to be written to db
Func<Exception, Task> failureHandler = e => requestContext.SendError(e.Message);
try
{
// Get the editSession
EditSession editSession = GetActiveSessionOrThrow(commitParams.OwnerUri);
// Get a connection for doing the committing
DbConnection conn = await connectionService.GetOrOpenConnection(commitParams.OwnerUri,
ConnectionType.Edit);
editSession.CommitEdits(conn, successHandler, failureHandler);
}
catch (Exception e)
{
await failureHandler(e);
}
}
#endregion
#region Private Helpers
@@ -229,19 +254,19 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
/// <exception cref="Exception">If the edit session doesn't exist</exception>
/// <param name="ownerUri">Owner URI for the edit session</param>
/// <returns>The edit session that corresponds to the owner URI</returns>
private Session GetActiveSessionOrThrow(string ownerUri)
private EditSession GetActiveSessionOrThrow(string ownerUri)
{
// Sanity check the owner URI is provided
Validate.IsNotNullOrWhitespaceString(nameof(ownerUri), ownerUri);
// Attempt to get the session, throw if unable
Session session;
if (!ActiveSessions.TryGetValue(ownerUri, out session))
// Attempt to get the editSession, throw if unable
EditSession editSession;
if (!ActiveSessions.TryGetValue(ownerUri, out editSession))
{
throw new Exception(SR.EditDataSessionNotFound);
}
return session;
return editSession;
}
private async Task QueryCompleteCallback(Query query, EditInitializeParams initParams,
@@ -254,19 +279,19 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
try
{
// Validate the query for a session
ResultSet resultSet = Session.ValidateQueryForSession(query);
// Validate the query for a editSession
ResultSet resultSet = EditSession.ValidateQueryForSession(query);
// Get a connection we'll use for SMO metadata lookup (and committing, later on)
DbConnection conn = await connectionService.GetOrOpenConnection(initParams.OwnerUri, ConnectionType.Edit);
var metadata = metadataFactory.GetObjectMetadata(conn, resultSet.Columns,
initParams.ObjectName, initParams.ObjectType);
// Create the session and add it to the sessions list
Session session = new Session(resultSet, metadata);
if (!ActiveSessions.TryAdd(initParams.OwnerUri, session))
// Create the editSession and add it to the sessions list
EditSession editSession = new EditSession(resultSet, metadata);
if (!ActiveSessions.TryAdd(initParams.OwnerUri, editSession))
{
throw new InvalidOperationException("Failed to create edit session, session already exists.");
throw new InvalidOperationException("Failed to create edit editSession, editSession already exists.");
}
readyParams.Success = true;
}

View File

@@ -5,8 +5,10 @@
using System;
using System.Collections.Concurrent;
using System.Data.Common;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.EditData.Contracts;
using Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
@@ -18,22 +20,18 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
/// Represents an edit "session" bound to the results of a query, containing a cache of edits
/// that are pending. Provides logic for performing edit operations.
/// </summary>
public class Session
public class EditSession
{
#region Member Variables
private readonly ResultSet associatedResultSet;
private readonly IEditTableMetadata objectMetadata;
#endregion
/// <summary>
/// Constructs a new edit session bound to the result set and metadat object provided
/// </summary>
/// <param name="resultSet">The result set of the table to be edited</param>
/// <param name="objMetadata">Metadata provider for the table to be edited</param>
public Session(ResultSet resultSet, IEditTableMetadata objMetadata)
public EditSession(ResultSet resultSet, IEditTableMetadata objMetadata)
{
Validate.IsNotNull(nameof(resultSet), resultSet);
Validate.IsNotNull(nameof(objMetadata), objMetadata);
@@ -47,6 +45,12 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
#region Properties
/// <summary>
/// The task that is running to commit the changes to the db
/// Internal for unit test purposes.
/// </summary>
internal Task CommitTask { get; set; }
/// <summary>
/// The internal ID for the next row in the table. Internal for unit testing purposes only.
/// </summary>
@@ -55,7 +59,7 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
/// <summary>
/// The cache of pending updates. Internal for unit test purposes only
/// </summary>
internal ConcurrentDictionary<long, RowEditBase> EditCache { get;}
internal ConcurrentDictionary<long, RowEditBase> EditCache { get; }
#endregion
@@ -109,6 +113,29 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
return newRowId;
}
/// <summary>
/// Commits the edits in the cache to the database and then to the associated result set of
/// this edit session. This is launched asynchronously.
/// </summary>
/// <param name="connection">The connection to use for executing the query</param>
/// <param name="successHandler">Callback to perform when the commit process has finished</param>
/// <param name="errorHandler">Callback to perform if the commit process has failed at some point</param>
public void CommitEdits(DbConnection connection, Func<Task> successHandler, Func<Exception, Task> errorHandler)
{
Validate.IsNotNull(nameof(connection), connection);
Validate.IsNotNull(nameof(successHandler), successHandler);
Validate.IsNotNull(nameof(errorHandler), errorHandler);
// Make sure that there isn't a commit task in progress
if (CommitTask != null && !CommitTask.IsCompleted)
{
throw new InvalidOperationException(SR.EditDataCommitInProgress);
}
// Start up the commit process
CommitTask = CommitEditsInternal(connection, successHandler, errorHandler);
}
/// <summary>
/// Creates a delete row update and adds it to the update cache
/// </summary>
@@ -149,6 +176,11 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
}
}
/// <summary>
/// Generates a single script file with all the pending edits scripted.
/// </summary>
/// <param name="outputPath">The path to output the script to</param>
/// <returns></returns>
public string ScriptEdits(string outputPath)
{
// Validate the output path
@@ -203,7 +235,10 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
// Attempt to get the row that is being edited, create a new update object if one
// doesn't exist
RowEditBase editRow = EditCache.GetOrAdd(rowId, new RowUpdate(rowId, associatedResultSet, objectMetadata));
// NOTE: This *must* be done as a lambda. RowUpdate creation requires that the row
// exist in the result set. We only want a new RowUpdate to be created if the edit
// doesn't already exist in the cache
RowEditBase editRow = EditCache.GetOrAdd(rowId, key => new RowUpdate(rowId, associatedResultSet, objectMetadata));
// Pass the call to the row update
return editRow.SetCell(columnId, newValue);
@@ -211,5 +246,36 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
#endregion
private async Task CommitEditsInternal(DbConnection connection, Func<Task> successHandler, Func<Exception, Task> errorHandler)
{
try
{
// @TODO: Add support for transactional commits
// Trust the RowEdit to sort itself appropriately
var editOperations = EditCache.Values.ToList();
editOperations.Sort();
foreach (var editOperation in editOperations)
{
// Get the command from the edit operation and execute it
using (DbCommand editCommand = editOperation.GetCommand(connection))
using (DbDataReader reader = await editCommand.ExecuteReaderAsync())
{
// Apply the changes of the command to the result set
await editOperation.ApplyChanges(reader);
}
// If we succeeded in applying the changes, then remove this from the cache
// @TODO: Prevent edit sessions from being modified while a commit is in progress
RowEditBase re;
EditCache.TryRemove(editOperation.RowId, out re);
}
await successHandler();
}
catch (Exception e)
{
await errorHandler(e);
}
}
}
}

View File

@@ -4,9 +4,9 @@
//
using System;
using System.Data.Common;
using System.Globalization;
using System.Text.RegularExpressions;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
@@ -26,7 +26,7 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
/// </summary>
/// <param name="column">Column the cell will be under</param>
/// <param name="valueAsString">The string from the client to convert to an object</param>
public CellUpdate(DbColumn column, string valueAsString)
public CellUpdate(DbColumnWrapper column, string valueAsString)
{
Validate.IsNotNull(nameof(column), column);
Validate.IsNotNull(nameof(valueAsString), valueAsString);
@@ -89,7 +89,7 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
/// <summary>
/// The column that the cell will be placed in
/// </summary>
public DbColumn Column { get; }
public DbColumnWrapper Column { get; }
/// <summary>
/// The object representation of the cell provided by the client

View File

@@ -5,10 +5,16 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Data.SqlClient;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.EditData.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
{
@@ -17,7 +23,9 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
/// </summary>
public sealed class RowCreate : RowEditBase
{
private const string InsertStatement = "INSERT INTO {0}({1}) VALUES ({2})";
private const string InsertStart = "INSERT INTO {0}({1})";
private const string InsertCompleteScript = "{0} VALUES ({1})";
private const string InsertCompleteOutput = "{0} OUTPUT {1} VALUES ({2})";
private readonly CellUpdate[] newCells;
@@ -34,42 +42,121 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
}
/// <summary>
/// Generates the INSERT INTO statement that will apply the row creation
/// Sort ID for a RowCreate object. Setting to 1 ensures that these are the first changes
/// to be committed
/// </summary>
/// <returns>INSERT INTO statement</returns>
public override string GetScript()
{
List<string> columnNames = new List<string>();
List<string> columnValues = new List<string>();
protected override int SortId => 1;
// Build the column list and value list
#region Public Methods
/// <summary>
/// Applies the changes to the associated result set after successfully executing the
/// change on the database
/// </summary>
/// <param name="dataReader">
/// Reader returned from the execution of the command to insert a new row. Should contain
/// a single row that represents the newly added row.
/// </param>
public override Task ApplyChanges(DbDataReader dataReader)
{
Validate.IsNotNull(nameof(dataReader), dataReader);
return AssociatedResultSet.AddRow(dataReader);
}
/// <summary>
/// Generates a command that can be executed to insert a new row -- and return the newly
/// inserted row.
/// </summary>
/// <param name="connection">The connection the command should be associated with</param>
/// <returns>Command to insert the new row</returns>
public override DbCommand GetCommand(DbConnection connection)
{
Validate.IsNotNull(nameof(connection), connection);
// Process all the columns. Add the column to the output columns, add updateable
// columns to the input parameters
List<string> outColumns = new List<string>();
List<string> inColumns = new List<string>();
DbCommand command = connection.CreateCommand();
for (int i = 0; i < AssociatedResultSet.Columns.Length; i++)
{
DbColumnWrapper column = AssociatedResultSet.Columns[i];
CellUpdate cell = newCells[i];
// If the column is not updatable, then skip it
// Add the column to the output
outColumns.Add($"inserted.{SqlScriptFormatter.FormatIdentifier(column.ColumnName)}");
// Skip columns that cannot be updated
if (!column.IsUpdatable)
{
continue;
}
// If the cell doesn't have a value, but is updatable, don't try to create the script
// If we're missing a cell, then we cannot continue
if (cell == null)
{
throw new InvalidOperationException(SR.EditDataCreateScriptMissingValue);
}
// Add the column and the data to their respective lists
columnNames.Add(SqlScriptFormatter.FormatIdentifier(column.ColumnName));
columnValues.Add(SqlScriptFormatter.FormatValue(cell.Value, column));
// Create a parameter for the value and add it to the command
// Add the parameterization to the list and add it to the command
string paramName = $"@Value{RowId}{i}";
inColumns.Add(paramName);
SqlParameter param = new SqlParameter(paramName, cell.Column.SqlDbType)
{
Value = cell.Value
};
command.Parameters.Add(param);
}
string joinedInColumns = string.Join(", ", inColumns);
string joinedOutColumns = string.Join(", ", outColumns);
// Get the start clause
string start = GetTableClause();
// Put together the components of the statement
string joinedColumnNames = string.Join(", ", columnNames);
string joinedColumnValues = string.Join(", ", columnValues);
return string.Format(InsertStatement, AssociatedObjectMetadata.EscapedMultipartName, joinedColumnNames,
joinedColumnValues);
// Put the whole #! together
command.CommandText = string.Format(InsertCompleteOutput, start, joinedOutColumns, joinedInColumns);
command.CommandType = CommandType.Text;
return command;
}
/// <summary>
/// Generates the INSERT INTO statement that will apply the row creation
/// </summary>
/// <returns>INSERT INTO statement</returns>
public override string GetScript()
{
// Process all the cells, and generate the values
List<string> values = new List<string>();
for (int i = 0; i < AssociatedResultSet.Columns.Length; i++)
{
DbColumnWrapper column = AssociatedResultSet.Columns[i];
CellUpdate cell = newCells[i];
// Skip columns that cannot be updated
if (!column.IsUpdatable)
{
continue;
}
// If we're missing a cell, then we cannot continue
if (cell == null)
{
throw new InvalidOperationException(SR.EditDataCreateScriptMissingValue);
}
// Format the value and add it to the list
values.Add(SqlScriptFormatter.FormatValue(cell.Value, column));
}
string joinedValues = string.Join(", ", values);
// Get the start clause
string start = GetTableClause();
// Put the whole #! together
return string.Format(InsertCompleteScript, start, joinedValues);
}
/// <summary>
@@ -99,5 +186,19 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
};
return eucr;
}
#endregion
private string GetTableClause()
{
// Get all the columns that will be provided
var inColumns = from c in AssociatedResultSet.Columns
where c.IsUpdatable
select SqlScriptFormatter.FormatIdentifier(c.ColumnName);
// Package it into a single INSERT statement starter
string inColumnsJoined = string.Join(", ", inColumns);
return string.Format(InsertStart, AssociatedObjectMetadata.EscapedMultipartName, inColumnsJoined);
}
}
}

View File

@@ -4,9 +4,12 @@
//
using System;
using System.Data.Common;
using System.Globalization;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.EditData.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
{
@@ -29,15 +32,53 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
{
}
/// <summary>
/// Sort ID for a RowDelete object. Setting to 2 ensures that these are the LAST changes
/// to be committed
/// </summary>
protected override int SortId => 2;
/// <summary>
/// Applies the changes to the associated result set after successfully executing the
/// change on the database
/// </summary>
/// <param name="dataReader">
/// Reader returned from the execution of the command to insert a new row. Should NOT
/// contain any rows.
/// </param>
public override Task ApplyChanges(DbDataReader dataReader)
{
// Take the result set and remove the row from it
AssociatedResultSet.RemoveRow(RowId);
return Task.FromResult(0);
}
/// <summary>
/// Generates a command for deleting the selected row
/// </summary>
/// <returns></returns>
public override DbCommand GetCommand(DbConnection connection)
{
Validate.IsNotNull(nameof(connection), connection);
// Return a SqlCommand with formatted with the parameters from the where clause
WhereClause where = GetWhereClause(true);
string commandText = GetCommandText(where.CommandText);
DbCommand command = connection.CreateCommand();
command.CommandText = commandText;
command.Parameters.AddRange(where.Parameters.ToArray());
return command;
}
/// <summary>
/// Generates a DELETE statement to delete this row
/// </summary>
/// <returns>String of the DELETE statement</returns>
public override string GetScript()
{
string formatString = AssociatedObjectMetadata.IsMemoryOptimized ? DeleteMemoryOptimizedStatement : DeleteStatement;
return string.Format(CultureInfo.InvariantCulture, formatString,
AssociatedObjectMetadata.EscapedMultipartName, GetWhereClause(false).CommandText);
return GetCommandText(GetWhereClause(false).CommandText);
}
/// <summary>
@@ -51,5 +92,23 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
{
throw new InvalidOperationException(SR.EditDataDeleteSetCell);
}
protected override int CompareToSameType(RowEditBase rowEdit)
{
// We want to sort by row ID *IN REVERSE* to make sure we delete from the bottom first.
// If we delete from the top first, it will change IDs, making all subsequent deletes
// off by one or more!
return RowId.CompareTo(rowEdit.RowId) * -1;
}
private string GetCommandText(string whereText)
{
string formatString = AssociatedObjectMetadata.IsMemoryOptimized
? DeleteMemoryOptimizedStatement
: DeleteStatement;
return string.Format(CultureInfo.InvariantCulture, formatString,
AssociatedObjectMetadata.EscapedMultipartName, whereText);
}
}
}

View File

@@ -8,6 +8,7 @@ using System.Collections.Generic;
using System.Data.Common;
using System.Data.SqlClient;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.EditData.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
@@ -18,9 +19,10 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
/// <summary>
/// Base class for row edit operations. Provides basic information and helper functionality
/// that all RowEdit implementations can use. Defines functionality that must be implemented
/// in all child classes.
/// in all child classes. Implements a custom IComparable to enable sorting by type of the edit
/// and then by an overrideable
/// </summary>
public abstract class RowEditBase
public abstract class RowEditBase : IComparable<RowEditBase>
{
/// <summary>
/// Internal parameterless constructor, required for mocking
@@ -58,8 +60,31 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
/// </summary>
public IEditTableMetadata AssociatedObjectMetadata { get; }
/// <summary>
/// Sort ID for a row edit. Ensures that when a collection of RowEditBase objects are
/// sorted, the appropriate types are sorted to the top.
/// </summary>
protected abstract int SortId { get; }
#endregion
#region Abstract Methods
/// <summary>
/// Applies the changes to the associated result set
/// </summary>
/// <param name="dataReader">
/// Data reader from execution of the command to commit the change to the db
/// </param>
public abstract Task ApplyChanges(DbDataReader dataReader);
/// <summary>
/// Gets a command that will commit the change to the db
/// </summary>
/// <param name="connection">The connection to associate the command to</param>
/// <returns>Command to commit the change to the db</returns>
public abstract DbCommand GetCommand(DbConnection connection);
/// <summary>
/// Converts the row edit into a SQL statement
/// </summary>
@@ -74,6 +99,10 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
/// <returns>The value of the cell after applying validation logic</returns>
public abstract EditUpdateCellResult SetCell(int columnId, string newValue);
#endregion
#region Protected Helper Methods
/// <summary>
/// Performs validation of column ID and if column can be updated.
/// </summary>
@@ -146,7 +175,11 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
// we execute multiple row edits at once.
string paramName = $"@Param{RowId}{col.Ordinal}";
cellDataClause = $"= {paramName}";
output.Parameters.Add(new SqlParameter(paramName, col.DbColumn.SqlDbType));
SqlParameter parameter = new SqlParameter(paramName, col.DbColumn.SqlDbType)
{
Value = cellData.RawObject
};
output.Parameters.Add(parameter);
}
else
{
@@ -163,6 +196,66 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
return output;
}
#endregion
#region IComparable Implementation
/// <summary>
/// Compares a row edit against another row edit. If they are the same type, then we
/// compare using an overrideable "same type" comparer. If they are different types, they
/// are sorted by their sort indexes.
///
/// In general, RowCreate and RowUpdates are sorted to the top. RowDeletes are sorted last.
/// If there are ties, default behavior is to sort by row ID ascending.
/// </summary>
/// <param name="other">The other row edit to compare against</param>
/// <returns>
/// A positive value if this edit should go first, a negative value if the other edit
/// should go first. 0 is returned if there is a tie.
/// </returns>
public int CompareTo(RowEditBase other)
{
// If the other is null, this one will come out on top
if (other == null)
{
return 1;
}
// If types are the same, use the type's tiebreaking sorter
if (GetType() == other.GetType())
{
return CompareToSameType(other);
}
// If the type's sort index is the same, use our tiebreaking sorter
// If they are different, use that as the comparison
int sortIdComparison = SortId.CompareTo(other.SortId);
return sortIdComparison == 0
? CompareByRowId(other)
: sortIdComparison;
}
/// <summary>
/// Default behavior for sorting if the two compared row edits are the same type. Sorts
/// by row ID ascending.
/// </summary>
/// <param name="rowEdit">The other row edit to compare against</param>
protected virtual int CompareToSameType(RowEditBase rowEdit)
{
return CompareByRowId(rowEdit);
}
/// <summary>
/// Compares two row edits by their row ID ascending.
/// </summary>
/// <param name="rowEdit">The other row edit to compare against</param>
private int CompareByRowId(RowEditBase rowEdit)
{
return RowId.CompareTo(rowEdit.RowId);
}
#endregion
/// <summary>
/// Represents a WHERE clause that can be used for identifying a row in a table.
/// </summary>

View File

@@ -5,12 +5,16 @@
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Data;
using System.Data.Common;
using System.Data.SqlClient;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.EditData.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
{
@@ -19,8 +23,11 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
/// </summary>
public sealed class RowUpdate : RowEditBase
{
private const string UpdateStatement = "UPDATE {0} SET {1} {2}";
private const string UpdateStatementMemoryOptimized = "UPDATE {0} WITH (SNAPSHOT) SET {1} {2}";
private const string UpdateScriptStart = @"UPDATE {0}";
private const string UpdateScriptStartMemOptimized = @"UPDATE {0} WITH (SNAPSHOT)";
private const string UpdateScript = @"{0} SET {1} {2}";
private const string UpdateScriptOutput = @"{0} SET {1} OUTPUT {2} {3}";
private readonly Dictionary<int, CellUpdate> cellUpdates;
private readonly IList<DbCellValue> associatedRow;
@@ -38,6 +45,73 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
associatedRow = associatedResultSet.GetRow(rowId);
}
/// <summary>
/// Sort order property. Sorts to same position as RowCreate
/// </summary>
protected override int SortId => 1;
#region Public Methods
/// <summary>
/// Applies the changes to the associated result set after successfully executing the
/// change on the database
/// </summary>
/// <param name="dataReader">
/// Reader returned from the execution of the command to update a row. Should contain
/// a single row that represents all the values of the row.
/// </param>
public override Task ApplyChanges(DbDataReader dataReader)
{
Validate.IsNotNull(nameof(dataReader), dataReader);
return AssociatedResultSet.UpdateRow(RowId, dataReader);
}
/// <summary>
/// Generates a command that can be executed to update a row -- and return the contents of
/// the updated row.
/// </summary>
/// <param name="connection">The connection the command should be associated with</param>
/// <returns>Command to update the row</returns>
public override DbCommand GetCommand(DbConnection connection)
{
Validate.IsNotNull(nameof(connection), connection);
DbCommand command = connection.CreateCommand();
// Build the "SET" portion of the statement
List<string> setComponents = new List<string>();
foreach (var updateElement in cellUpdates)
{
string formattedColumnName = SqlScriptFormatter.FormatIdentifier(updateElement.Value.Column.ColumnName);
string paramName = $"@Value{RowId}{updateElement.Key}";
setComponents.Add($"{formattedColumnName} = {paramName}");
SqlParameter parameter = new SqlParameter(paramName, updateElement.Value.Column.SqlDbType)
{
Value = updateElement.Value.Value
};
command.Parameters.Add(parameter);
}
string setComponentsJoined = string.Join(", ", setComponents);
// Build the "OUTPUT" portion of the statement
var outColumns = from c in AssociatedResultSet.Columns
let formatted = SqlScriptFormatter.FormatIdentifier(c.ColumnName)
select $"inserted.{formatted}";
string outColumnsJoined = string.Join(", ", outColumns);
// Get the where clause
WhereClause where = GetWhereClause(true);
command.Parameters.AddRange(where.Parameters.ToArray());
// Get the start of the statement
string statementStart = GetStatementStart();
// Put the whole #! together
command.CommandText = string.Format(UpdateScriptOutput, statementStart, setComponentsJoined,
outColumnsJoined, where.CommandText);
command.CommandType = CommandType.Text;
return command;
}
/// <summary>
/// Constructs an update statement to change the associated row.
/// </summary>
@@ -45,7 +119,7 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
public override string GetScript()
{
// Build the "SET" portion of the statement
IEnumerable<string> setComponents = cellUpdates.Values.Select(cellUpdate =>
var setComponents = cellUpdates.Values.Select(cellUpdate =>
{
string formattedColumnName = SqlScriptFormatter.FormatIdentifier(cellUpdate.Column.ColumnName);
string formattedValue = SqlScriptFormatter.FormatValue(cellUpdate.Value, cellUpdate.Column);
@@ -56,10 +130,11 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
// Get the where clause
string whereClause = GetWhereClause(false).CommandText;
// Put it all together
string formatString = AssociatedObjectMetadata.IsMemoryOptimized ? UpdateStatementMemoryOptimized : UpdateStatement;
return string.Format(CultureInfo.InvariantCulture, formatString,
AssociatedObjectMetadata.EscapedMultipartName, setClause, whereClause);
// Get the start of the statement
string statementStart = GetStatementStart();
// Put the whole #! together
return string.Format(UpdateScript, statementStart, setClause, whereClause);
}
/// <summary>
@@ -106,5 +181,16 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData.UpdateManagement
IsRevert = false // If we're in this branch, it is not a revert
};
}
#endregion
private string GetStatementStart()
{
string formatString = AssociatedObjectMetadata.IsMemoryOptimized
? UpdateScriptStartMemOptimized
: UpdateScriptStart;
return string.Format(formatString, AssociatedObjectMetadata.EscapedMultipartName);
}
}
}