diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 7ae92a63..ee0f0730 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -613,6 +613,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } // Make sure a default connection exists + DbConnection connection; DbConnection defaultConnection; if (!connectionInfo.TryGetConnection(ConnectionType.Default, out defaultConnection)) { @@ -622,42 +623,99 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection if(IsDedicatedAdminConnection(connectionInfo.ConnectionDetails)) { // Since this is a dedicated connection only 1 is allowed at any time. Return the default connection for use in the requested action - return defaultConnection; + connection = defaultConnection; + } + else + { + // Try to get the DbConnection and create if it doesn't already exist + if (!connectionInfo.TryGetConnection(connectionType, out connection) && ConnectionType.Default != connectionType) + { + connection = await TryOpenConnectionForConnectionType(ownerUri, connectionType, alwaysPersistSecurity, connectionInfo); + } } - // Try to get the DbConnection - DbConnection connection; - if (!connectionInfo.TryGetConnection(connectionType, out connection) && ConnectionType.Default != connectionType) - { - // If the DbConnection does not exist and is not the default connection, create one. - // We can't create the default (initial) connection here because we won't have a ConnectionDetails - // if Connect() has not yet been called. - bool? originalPersistSecurityInfo = connectionInfo.ConnectionDetails.PersistSecurityInfo; - if (alwaysPersistSecurity) - { - connectionInfo.ConnectionDetails.PersistSecurityInfo = true; - } - ConnectParams connectParams = new ConnectParams - { - OwnerUri = ownerUri, - Connection = connectionInfo.ConnectionDetails, - Type = connectionType - }; - try - { - await Connect(connectParams); - } - finally - { - connectionInfo.ConnectionDetails.PersistSecurityInfo = originalPersistSecurityInfo; - } - - connectionInfo.TryGetConnection(connectionType, out connection); - } + VerifyConnectionOpen(connection); return connection; } + private async Task TryOpenConnectionForConnectionType(string ownerUri, string connectionType, + bool alwaysPersistSecurity, ConnectionInfo connectionInfo) + { + // If the DbConnection does not exist and is not the default connection, create one. + // We can't create the default (initial) connection here because we won't have a ConnectionDetails + // if Connect() has not yet been called. + bool? originalPersistSecurityInfo = connectionInfo.ConnectionDetails.PersistSecurityInfo; + if (alwaysPersistSecurity) + { + connectionInfo.ConnectionDetails.PersistSecurityInfo = true; + } + ConnectParams connectParams = new ConnectParams + { + OwnerUri = ownerUri, + Connection = connectionInfo.ConnectionDetails, + Type = connectionType + }; + try + { + await Connect(connectParams); + } + finally + { + connectionInfo.ConnectionDetails.PersistSecurityInfo = originalPersistSecurityInfo; + } + + DbConnection connection; + connectionInfo.TryGetConnection(connectionType, out connection); + return connection; + } + + private void VerifyConnectionOpen(DbConnection connection) + { + if (connection == null) + { + // Ignore this connection + return; + } + + if (connection.State != ConnectionState.Open) + { + // Note: this will fail and throw to the caller if something goes wrong. + // This seems the right thing to do but if this causes serviceability issues where stack trace + // is unexpected, might consider catching and allowing later code to fail. But given we want to get + // an opened connection for any action using this, it seems OK to handle in this manner + ClearPool(connection); + connection.Open(); + } + } + + /// + /// Clears the connection pool if this is a SqlConnection of some kind. + /// + private void ClearPool(DbConnection connection) + { + SqlConnection sqlConn; + if (TryGetAsSqlConnection(connection, out sqlConn)) + { + SqlConnection.ClearPool(sqlConn); + } + } + + private bool TryGetAsSqlConnection(DbConnection dbConn, out SqlConnection sqlConn) + { + ReliableSqlConnection reliableConn = dbConn as ReliableSqlConnection; + if (reliableConn != null) + { + sqlConn = reliableConn.GetUnderlyingConnection(); + } + else + { + sqlConn = dbConn as SqlConnection; + } + + return sqlConn != null; + } + /// /// Cancel a connection that is in the process of opening. /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs index abefea30..ab9745b2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs @@ -141,8 +141,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting else { ScriptingScriptOperation operation = new ScriptingScriptOperation(parameters); - operation.PlanNotification += (sender, e) => requestContext.SendEvent(ScriptingPlanNotificationEvent.Type, e); - operation.ProgressNotification += (sender, e) => requestContext.SendEvent(ScriptingProgressNotificationEvent.Type, e); + operation.PlanNotification += (sender, e) => requestContext.SendEvent(ScriptingPlanNotificationEvent.Type, e).Wait(); + operation.ProgressNotification += (sender, e) => requestContext.SendEvent(ScriptingProgressNotificationEvent.Type, e).Wait(); operation.CompleteNotification += (sender, e) => this.SendScriptingCompleteEvent(requestContext, ScriptingCompleteEvent.Type, e, operation, parameters.ScriptDestination); RunTask(requestContext, operation); @@ -267,11 +267,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting } // send script result to client - requestContext.SendResult(new ScriptingResult { Script = script }); + requestContext.SendResult(new ScriptingResult { Script = script }).Wait(); } catch (Exception e) { - requestContext.SendError(e); + requestContext.SendError(e).Wait(); } return null; @@ -283,7 +283,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting /// private void RunTask(RequestContext context, ScriptingOperation operation) { - Task.Run(() => + Task.Run(async () => { try { @@ -292,14 +292,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting } catch (Exception e) { - context.SendError(e); + await context.SendError(e); } finally { ScriptingOperation temp; this.ActiveOperations.TryRemove(operation.OperationId, out temp); } - }).ContinueWithOnFaulted(null); + }).ContinueWithOnFaulted(async t => await context.SendError(t.Exception)); } /// diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver.Tests/SqlScriptPublishModelTests.cs b/test/Microsoft.SqlTools.ServiceLayer.TestDriver.Tests/SqlScriptPublishModelTests.cs index bbed2e10..203ceabb 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.TestDriver.Tests/SqlScriptPublishModelTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver.Tests/SqlScriptPublishModelTests.cs @@ -233,7 +233,7 @@ namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Tests var result = Task.Run(() => testService.Script(requestParams)); ScriptingProgressNotificationParams progressParams = await testService.Driver.WaitForEvent(ScriptingProgressNotificationEvent.Type, TimeSpan.FromSeconds(10)); - Task.Run(() => testService.CancelScript(progressParams.OperationId)); + Task.Run(() => testService.CancelScript(progressParams.OperationId).Wait()); ScriptingCompleteParams cancelEvent = await testService.Driver.WaitForEvent(ScriptingCompleteEvent.Type, TimeSpan.FromSeconds(10)); Assert.True(cancelEvent.Canceled); } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs index 982f96c2..36bd0e0d 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs @@ -1250,7 +1250,13 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection connInfo = service.OwnerToConnectionMap[connectionParameters.OwnerUri]; Assert.NotNull(defaultConn); Assert.Equal(connInfo.AllConnections.Count, 1); - + + // Verify that if the query connection was closed, it will be reopened on requesting the connection again + Assert.Equal(ConnectionState.Open, queryConn.State); + queryConn.Close(); + Assert.Equal(ConnectionState.Closed, queryConn.State); + queryConn = await service.GetOrOpenConnection(connectionParameters.OwnerUri, ConnectionType.Query); + Assert.Equal(ConnectionState.Open, queryConn.State); } [Fact] diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs index 112e0a8c..5e00f1f4 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs @@ -216,6 +216,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility public override void Close() { // No Op + this._state = ConnectionState.Closed; } public override void Open() @@ -225,6 +226,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility { throw new Exception("Invalid credentials provided"); } + this._state = ConnectionState.Open; } public override string ConnectionString { get; set; }