diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs index c81362dd..b79af0b2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs @@ -6,6 +6,7 @@ using System; using System.Collections.Generic; using System.Data; using System.Data.Common; +using System.Diagnostics; using System.Data.SqlClient; using System.Linq; using System.Threading; @@ -134,16 +135,26 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution try { + DbCommand command = null; + // Register the message listener to *this instance* of the batch // Note: This is being done to associate messages with batches ReliableSqlConnection sqlConn = conn as ReliableSqlConnection; if (sqlConn != null) { sqlConn.GetUnderlyingConnection().InfoMessage += StoreDbMessage; + command = sqlConn.GetUnderlyingConnection().CreateCommand(); + } + else + { + command = conn.CreateCommand(); } + // Make sure we aren't using a ReliableCommad since we do not want automatic retry + Debug.Assert(!(command is ReliableSqlConnection.ReliableSqlCommand), "ReliableSqlCommand command should not be used to execute queries"); + // Create a command that we'll use for executing the query - using (DbCommand command = conn.CreateCommand()) + using (command) { command.CommandText = BatchText; command.CommandType = CommandType.Text; @@ -190,10 +201,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution finally { // Remove the message event handler from the connection - SqlConnection sqlConn = conn as SqlConnection; + ReliableSqlConnection sqlConn = conn as ReliableSqlConnection; if (sqlConn != null) { - sqlConn.InfoMessage -= StoreDbMessage; + sqlConn.GetUnderlyingConnection().InfoMessage -= StoreDbMessage; } // Mark that we have executed diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs index 2c1c2ce2..cf2df73c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -195,10 +195,21 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution sqlConn.GetUnderlyingConnection().InfoMessage += OnInfoMessage; } - // We need these to execute synchronously, otherwise the user will be very unhappy - foreach (Batch b in Batches) + try { - await b.Execute(conn, cancellationSource.Token); + // We need these to execute synchronously, otherwise the user will be very unhappy + foreach (Batch b in Batches) + { + await b.Execute(conn, cancellationSource.Token); + } + } + finally + { + if (sqlConn != null) + { + // Subscribe to database informational messages + sqlConn.GetUnderlyingConnection().InfoMessage -= OnInfoMessage; + } } // TODO: Close connection after eliminating using statement for above TODO diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index 4aed41d8..777cbaa3 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -39,6 +39,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution public const string NoOpQuery = "-- No ops here, just us chickens."; + public const string UdtQuery = "SELECT hierarchyid::Parse('/')"; + public const string OwnerUri = "testFile"; public const int StandardRows = 5; diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs index b44b3618..edfd6bad 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs @@ -3,6 +3,8 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +//#define USE_LIVE_CONNECTION + using System; using System.Data.Common; using System.Linq; @@ -17,6 +19,7 @@ using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Test.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Microsoft.SqlTools.Test.Utility; using Moq; using Xunit; @@ -667,7 +670,35 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Assert.NotEmpty(complete.BatchSummaries[0].Messages); } - #endregion +#if USE_LIVE_CONNECTION + [Fact] + public void QueryUdtShouldNotRetry() + { + // If: + // ... I create a query with a udt column in the result set + ConnectionInfo connectionInfo = TestObjects.GetTestConnectionInfo(); + Query query = new Query(Common.UdtQuery, connectionInfo, new QueryExecutionSettings(), Common.GetFileStreamFactory()); + + // If: + // ... I then execute the query + DateTime startTime = DateTime.Now; + query.Execute().Wait(); + + // Then: + // ... The query should complete within 2 seconds since retry logic should not kick in + Assert.True(DateTime.Now.Subtract(startTime) < TimeSpan.FromSeconds(2), "Query completed slower than expected, did retry logic execute?"); + + // Then: + // ... There should be an error on the batch + Assert.True(query.HasExecuted); + Assert.NotEmpty(query.BatchSummaries); + Assert.Equal(1, query.BatchSummaries.Length); + Assert.True(query.BatchSummaries[0].HasError); + Assert.NotEmpty(query.BatchSummaries[0].Messages); + } +#endif + +#endregion private void VerifyQueryExecuteCallCount(Mock> mock, Times sendResultCalls, Times sendEventCalls, Times sendErrorCalls) { diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs index 82ffffd0..3474c4df 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs @@ -35,6 +35,17 @@ namespace Microsoft.SqlTools.Test.Utility #endif } + /// + /// Creates a test connection info instance. + /// + public static ConnectionInfo GetTestConnectionInfo() + { + return new ConnectionInfo( + GetTestSqlConnectionFactory(), + "file://some/file.sql", + GetTestConnectionDetails()); + } + public static ConnectParams GetTestConnectionParams() { return new ConnectParams()