diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs index 3091d78b..ccd85d20 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs @@ -24,7 +24,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution public class Batch : IDisposable { #region Member Variables - /// /// For IDisposable implementation, whether or not this has been disposed /// @@ -219,7 +218,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #endregion #region Public Methods - + /// /// Executes this batch and captures any server messages that are returned. /// @@ -247,37 +246,41 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution // Register the message listener to *this instance* of the batch // Note: This is being done to associate messages with batches ReliableSqlConnection sqlConn = conn as ReliableSqlConnection; - DbCommand command; + DbCommand dbCommand; if (sqlConn != null) { // Register the message listener to *this instance* of the batch // Note: This is being done to associate messages with batches sqlConn.GetUnderlyingConnection().InfoMessage += ServerMessageHandler; - command = sqlConn.GetUnderlyingConnection().CreateCommand(); + dbCommand = sqlConn.GetUnderlyingConnection().CreateCommand(); // Add a handler for when the command completes - SqlCommand sqlCommand = (SqlCommand)command; + SqlCommand sqlCommand = (SqlCommand)dbCommand; sqlCommand.StatementCompleted += StatementCompletedHandler; } else { - command = conn.CreateCommand(); + dbCommand = conn.CreateCommand(); } // Make sure we aren't using a ReliableCommad since we do not want automatic retry - Debug.Assert(!(command is ReliableSqlConnection.ReliableSqlCommand), + Debug.Assert(!(dbCommand is ReliableSqlConnection.ReliableSqlCommand), "ReliableSqlCommand command should not be used to execute queries"); // Create a command that we'll use for executing the query - using (command) + using (dbCommand) { - command.CommandText = BatchText; - command.CommandType = CommandType.Text; - command.CommandTimeout = 0; + // Make sure that we cancel the command if the cancellation token is cancelled + cancellationToken.Register(() => dbCommand?.Cancel()); + + // Setup the command for executing the batch + dbCommand.CommandText = BatchText; + dbCommand.CommandType = CommandType.Text; + dbCommand.CommandTimeout = 0; executionStartTime = DateTime.Now; // Execute the command to get back a reader - using (DbDataReader reader = await command.ExecuteReaderAsync(cancellationToken)) + using (DbDataReader reader = await dbCommand.ExecuteReaderAsync(cancellationToken)) { int resultSetOrdinal = 0; do