diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 0bf5d869..02d5b120 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -75,7 +75,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { get { - return this.ownerToConnectionMap; + return this.ownerToConnectionMap; } } @@ -264,6 +264,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection return validationResults; } + TrySetConnectionType(connectionParams); + + connectionParams.Connection.ApplicationName = GetApplicationNameWithFeature(connectionParams.Connection.ApplicationName, connectionParams.Type); // If there is no ConnectionInfo in the map, create a new ConnectionInfo, // but wait until later when we are connected to add it to the map. ConnectionInfo connectionInfo; @@ -305,16 +308,63 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // Invoke callback notifications InvokeOnConnectionActivities(connectionInfo, connectionParams); - if(connectionParams.Type == ConnectionType.ObjectExplorer) + TryCloseConnectionTemporaryConnection(connectionParams, connectionInfo); + + return completeParams; + } + + private void TryCloseConnectionTemporaryConnection(ConnectParams connectionParams, ConnectionInfo connectionInfo) + { + try { - DbConnection connection; - if (connectionInfo.TryGetConnection(ConnectionType.ObjectExplorer, out connection)) + if (connectionParams.Type == ConnectionType.ObjectExplorer || connectionParams.Type == ConnectionType.Dashboard || connectionParams.Type == ConnectionType.ConnectionValidation) { - // OE doesn't need to keep the connection open - connection.Close(); + DbConnection connection; + string type = connectionParams.Type; + if (connectionInfo.TryGetConnection(type, out connection)) + { + // OE doesn't need to keep the connection open + connection.Close(); + } + } + } + catch (Exception ex) + { + Logger.Write(LogLevel.Normal, "Failed to close temporary connections. error: " + ex.Message); + } + } + + private static string GetApplicationNameWithFeature(string applicationName, string featureName) + { + string appNameWithFeature = applicationName; + + if (!string.IsNullOrWhiteSpace(applicationName) && !string.IsNullOrWhiteSpace(featureName)) + { + int index = applicationName.IndexOf('-'); + string appName = applicationName; + if (index > 0) + { + appName = applicationName.Substring(0, index - 1); + } + appNameWithFeature = $"{appName}-{featureName}"; + } + + return appNameWithFeature; + } + + private void TrySetConnectionType(ConnectParams connectionParams) + { + if (connectionParams != null && connectionParams.Type == ConnectionType.Default && !string.IsNullOrWhiteSpace(connectionParams.OwnerUri)) + { + if (connectionParams.OwnerUri.ToLowerInvariant().StartsWith("dashboard://")) + { + connectionParams.Type = ConnectionType.Dashboard; + } + else if (connectionParams.OwnerUri.ToLowerInvariant().StartsWith("connection://")) + { + connectionParams.Type = ConnectionType.ConnectionValidation; } } - return completeParams; } private bool IsConnectionChanged(ConnectParams connectionParams, ConnectionInfo connectionInfo) @@ -1159,7 +1209,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// private void InvokeOnConnectionActivities(ConnectionInfo connectionInfo, ConnectParams connectParams) { - if (connectParams.Type != ConnectionType.Default) + if (connectParams.Type != ConnectionType.Default && connectParams.Type != ConnectionType.ConnectionValidation) { return; } @@ -1219,7 +1269,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// Note: we need to audit all uses of this method to determine why we're /// bypassing normal ConnectionService connection management /// - internal static SqlConnection OpenSqlConnection(ConnectionInfo connInfo) + internal static SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName = null) { try { @@ -1234,6 +1284,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection connInfo.ConnectionDetails.PersistSecurityInfo = true; // turn off connection pool to avoid hold locks on server resources after calling SqlConnection Close method connInfo.ConnectionDetails.Pooling = false; + connInfo.ConnectionDetails.ApplicationName = GetApplicationNameWithFeature(connInfo.ConnectionDetails.ApplicationName, featureName); // generate connection string string connectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails); diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs index 09dec4af..9764846f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs @@ -17,5 +17,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection public const string Query = "Query"; public const string Edit = "Edit"; public const string ObjectExplorer = "ObjectExplorer"; + public const string Dashboard = "Dashboard"; + public const string ConnectionValidation = "ConnectionValidation"; } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/DatabaseLocksManager.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/DatabaseLocksManager.cs index 26c7dca3..e667488b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/DatabaseLocksManager.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/DatabaseLocksManager.cs @@ -31,7 +31,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection private Dictionary databaseAccessEvents = new Dictionary(); private object databaseAccessLock = new object(); - public const int DefaultWaitToGetFullAccess = 60000; + public const int DefaultWaitToGetFullAccess = 10000; public int waitToGetFullAccess = DefaultWaitToGetFullAccess; private ManualResetEvent GetResetEvent(string serverName, string databaseName) @@ -53,6 +53,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection public bool GainFullAccessToDatabase(string serverName, string databaseName) { /* + * TODO: add the lock so not two process can get full access at the same time ManualResetEvent resetEvent = GetResetEvent(serverName, databaseName); if (resetEvent.WaitOne(this.waitToGetFullAccess)) { @@ -71,7 +72,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection */ foreach (IConnectedBindingQueue item in ConnectionService.ConnectedQueues) { - item.CloseConnections(serverName, databaseName); + item.CloseConnections(serverName, databaseName, DefaultWaitToGetFullAccess); } return true; @@ -91,7 +92,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection */ foreach (IConnectedBindingQueue item in ConnectionService.ConnectedQueues) { - item.OpenConnections(serverName, databaseName); + item.OpenConnections(serverName, databaseName, DefaultWaitToGetFullAccess); } return true; diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs index 97d45e90..d32bae38 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/DisasterRecoveryService.cs @@ -144,7 +144,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery { using (DatabaseTaskHelper helper = AdminService.CreateDatabaseTaskHelper(connInfo, databaseExists: true)) { - using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo)) + using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo, "Backup")) { if (sqlConn != null && !connInfo.IsSqlDW && !connInfo.IsAzure) { @@ -307,7 +307,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery if (supported && connInfo != null) { DatabaseTaskHelper helper = AdminService.CreateDatabaseTaskHelper(connInfo, databaseExists: true); - SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo); + SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo, "Backup"); // Connection gets discounnected when backup is done BackupOperation backupOperation = CreateBackupOperation(helper.DataContainer, sqlConn, backupParams.BackupInfo); @@ -344,7 +344,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery if (connInfo != null) { - using (sqlConn = ConnectionService.OpenSqlConnection(connInfo)) + using (sqlConn = ConnectionService.OpenSqlConnection(connInfo, "DisasterRecovery")) { if (sqlConn != null && !connInfo.IsSqlDW && !connInfo.IsAzure) { diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs index 7da77df8..1a427251 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs @@ -19,9 +19,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { public interface IConnectedBindingQueue { - void CloseConnections(string serverName, string databaseName); - void OpenConnections(string serverName, string databaseName); - string AddConnectionContext(ConnectionInfo connInfo, bool overwrite = false); + void CloseConnections(string serverName, string databaseName, int millisecondsTimeout); + void OpenConnections(string serverName, string databaseName, int millisecondsTimeout); + string AddConnectionContext(ConnectionInfo connInfo, string featureName = null, bool overwrite = false); void Dispose(); QueueItem QueueBindingOperation( string key, @@ -91,28 +91,28 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } - public void CloseConnections(string serverName, string databaseName) + public void CloseConnections(string serverName, string databaseName, int millisecondsTimeout) { string connectionKey = GetConnectionContextKey(serverName, databaseName); var contexts = GetBindingContexts(connectionKey); foreach (var bindingContext in contexts) { - if (bindingContext.BindingLock.WaitOne(2000)) + if (bindingContext.BindingLock.WaitOne(millisecondsTimeout)) { bindingContext.ServerConnection.Disconnect(); } } } - public void OpenConnections(string serverName, string databaseName) + public void OpenConnections(string serverName, string databaseName, int millisecondsTimeout) { string connectionKey = GetConnectionContextKey(serverName, databaseName); var contexts = GetBindingContexts(connectionKey); foreach (var bindingContext in contexts) { - if (bindingContext.BindingLock.WaitOne(2000)) + if (bindingContext.BindingLock.WaitOne(millisecondsTimeout)) { - //bindingContext.ServerConnection.Connect(); + bindingContext.ServerConnection.Connect(); } } } @@ -122,7 +122,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices /// /// Connection info used to create binding context /// Overwrite existing context - public virtual string AddConnectionContext(ConnectionInfo connInfo, bool overwrite = false) + public virtual string AddConnectionContext(ConnectionInfo connInfo, string featureName = null, bool overwrite = false) { if (connInfo == null) { @@ -150,7 +150,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices try { bindingContext.BindingLock.Reset(); - SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo); + SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo, featureName); // populate the binding context to work with the SMO metadata provider bindingContext.ServerConnection = new ServerConnection(sqlConn); @@ -166,7 +166,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } bindingContext.BindingTimeout = ConnectedBindingQueue.DefaultBindingTimeout; - bindingContext.IsConnected = true; + bindingContext.IsConnected = true; } catch (Exception) { diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs index 78168812..bc0afe6b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs @@ -602,7 +602,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { try { - this.BindingQueue.AddConnectionContext(connInfo, overwrite: true); + this.BindingQueue.AddConnectionContext(connInfo, featureName: "LanguageService", overwrite: true); } catch (Exception ex) { @@ -849,7 +849,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { try { - scriptInfo.ConnectionKey = this.BindingQueue.AddConnectionContext(info); + scriptInfo.ConnectionKey = this.BindingQueue.AddConnectionContext(info, "languageService"); scriptInfo.IsConnected = true; } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataService.cs b/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataService.cs index 472065bf..0da6363b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataService.cs @@ -74,7 +74,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata var metadata = new List(); if (connInfo != null) { - using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo)) + using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo, "Metadata")) { ReadMetadata(sqlConn, metadata); } @@ -129,7 +129,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata ColumnMetadata[] metadata = null; if (connInfo != null) { - using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo)) + using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connInfo, "Metadata")) { TableMetadata table = new SmoMetadataFactory().GetObjectMetadata( sqlConn, metadataParams.Schema, diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs index b6891198..3785ff96 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs @@ -392,7 +392,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer try { QueueItem queueItem = bindingQueue.QueueBindingOperation( - key: bindingQueue.AddConnectionContext(session.ConnectionInfo), + key: bindingQueue.AddConnectionContext(session.ConnectionInfo, "OE"), bindingTimeout: PrepopulateBindTimeout, waitForLockTimeout: PrepopulateBindTimeout, bindOperation: (bindingContext, cancelToken) => diff --git a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs index c53cdf17..d24aa9b0 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs @@ -231,7 +231,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting private void RunSelectTask(ConnectionInfo connInfo, ScriptingParams parameters, RequestContext requestContext) { ConnectionServiceInstance.ConnectionQueue.QueueBindingOperation( - key: ConnectionServiceInstance.ConnectionQueue.AddConnectionContext(connInfo), + key: ConnectionServiceInstance.ConnectionQueue.AddConnectionContext(connInfo, "Scripting"), bindingTimeout: ScriptingOperationTimeout, bindOperation: (bindingContext, cancelToken) => { diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs index d3697aab..84153adc 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/DisasterRecovery/RestoreDatabaseServiceTests.cs @@ -188,7 +188,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.DisasterRecovery connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(testDb.DatabaseName, queryTempFile.FilePath, ConnectionType.ObjectExplorer); //Opening a connection to db to lock the db - connectionService.ConnectionQueue.AddConnectionContext(connectionResult.ConnectionInfo, true); + connectionService.ConnectionQueue.AddConnectionContext(connectionResult.ConnectionInfo, "", true); try { diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/DatabaseLocksManagerTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/DatabaseLocksManagerTests.cs index 2ecbd44e..8ef90b97 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/DatabaseLocksManagerTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/DatabaseLocksManagerTests.cs @@ -19,14 +19,14 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection public void GainFullAccessShouldDisconnectTheConnections() { var connectionLock = new Mock(); - connectionLock.Setup(x => x.CloseConnections(server1, database1)); + connectionLock.Setup(x => x.CloseConnections(server1, database1, DatabaseLocksManager.DefaultWaitToGetFullAccess)); using (DatabaseLocksManager databaseLocksManager = CreateManager()) { databaseLocksManager.ConnectionService.RegisterConnectedQueue("test", connectionLock.Object); databaseLocksManager.GainFullAccessToDatabase(server1, database1); - connectionLock.Verify(x => x.CloseConnections(server1, database1)); + connectionLock.Verify(x => x.CloseConnections(server1, database1, DatabaseLocksManager.DefaultWaitToGetFullAccess)); } } @@ -34,14 +34,14 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection public void ReleaseAccessShouldConnectTheConnections() { var connectionLock = new Mock(); - connectionLock.Setup(x => x.OpenConnections(server1, database1)); + connectionLock.Setup(x => x.OpenConnections(server1, database1, DatabaseLocksManager.DefaultWaitToGetFullAccess)); using (DatabaseLocksManager databaseLocksManager = CreateManager()) { databaseLocksManager.ConnectionService.RegisterConnectedQueue("test", connectionLock.Object); databaseLocksManager.ReleaseAccess(server1, database1); - connectionLock.Verify(x => x.OpenConnections(server1, database1)); + connectionLock.Verify(x => x.OpenConnections(server1, database1, DatabaseLocksManager.DefaultWaitToGetFullAccess)); } } @@ -74,10 +74,10 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection DatabaseLocksManager databaseLocksManager = new DatabaseLocksManager(2000); var connectionLock1 = new Mock(); var connectionLock2 = new Mock(); - connectionLock1.Setup(x => x.CloseConnections(It.IsAny(), It.IsAny())); - connectionLock2.Setup(x => x.OpenConnections(It.IsAny(), It.IsAny())); - connectionLock1.Setup(x => x.OpenConnections(It.IsAny(), It.IsAny())); - connectionLock2.Setup(x => x.CloseConnections(It.IsAny(), It.IsAny())); + connectionLock1.Setup(x => x.CloseConnections(It.IsAny(), It.IsAny(), It.IsAny())); + connectionLock2.Setup(x => x.OpenConnections(It.IsAny(), It.IsAny(), It.IsAny())); + connectionLock1.Setup(x => x.OpenConnections(It.IsAny(), It.IsAny(), It.IsAny())); + connectionLock2.Setup(x => x.CloseConnections(It.IsAny(), It.IsAny(), It.IsAny())); ConnectionService connectionService = new ConnectionService(); databaseLocksManager.ConnectionService = connectionService; diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/LanguageServer/LanguageServiceTestBase.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/LanguageServer/LanguageServiceTestBase.cs index 4642c727..2a27ef41 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/LanguageServer/LanguageServiceTestBase.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/LanguageServer/LanguageServiceTestBase.cs @@ -78,7 +78,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.LanguageServer // setup binding queue mock bindingQueue = new Mock(); - bindingQueue.Setup(q => q.AddConnectionContext(It.IsAny(), It.IsAny())) + bindingQueue.Setup(q => q.AddConnectionContext(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(this.testConnectionKey); langService = new LanguageService();