//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using System;
using System.Threading;
using Microsoft.Data.SqlClient;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.SmoMetadataProvider;
using Microsoft.SqlServer.Management.SqlParser.Binder;
using Microsoft.SqlServer.Management.SqlParser.MetadataProvider;
using Microsoft.SqlTools.CoreServices.Connection;
using Microsoft.SqlTools.CoreServices.SqlContext;
using Microsoft.SqlTools.CoreServices.Workspace;
using Microsoft.SqlTools.DataProtocol.Contracts.Connection;

namespace Microsoft.SqlTools.CoreServices.LanguageServices
{
    /// <summary>
    /// A binding queue whose binding contexts are backed by live server connections.
    /// </summary>
    public interface IConnectedBindingQueue
    {
        /// <summary>
        /// Disconnects the server connection of every binding context registered for the
        /// given server and database.
        /// </summary>
        /// <param name="serverName">Server name used to build the lookup key (null becomes "NULL").</param>
        /// <param name="databaseName">Database name used to build the lookup key (null becomes "NULL").</param>
        /// <param name="millisecondsTimeout">How long to wait on each context's binding lock.</param>
        void CloseConnections(string serverName, string databaseName, int millisecondsTimeout);

        /// <summary>
        /// Reconnects the server connection of every binding context registered for the
        /// given server and database.
        /// </summary>
        /// <param name="serverName">Server name used to build the lookup key (null becomes "NULL").</param>
        /// <param name="databaseName">Database name used to build the lookup key (null becomes "NULL").</param>
        /// <param name="millisecondsTimeout">How long to wait on each context's binding lock.</param>
        void OpenConnections(string serverName, string databaseName, int millisecondsTimeout);

        /// <summary>
        /// Creates (or reuses) a connected binding context for the given connection.
        /// </summary>
        /// <returns>The key identifying the binding context, or an empty string for a null connection.</returns>
        string AddConnectionContext(ConnectionInfo connInfo, string featureName = null, bool overwrite = false);

        void Dispose();

        /// <summary>
        /// Queues a binding operation against the context identified by <paramref name="key"/>.
        /// Implemented by the base binding queue; see it for timeout and error-handling semantics.
        /// </summary>
        QueueItem QueueBindingOperation(
            string key,
            Func<IBindingContext, CancellationToken, object> bindOperation,
            Func<IBindingContext, object> timeoutOperation = null,
            Func<Exception, object> errorHandler = null,
            int? bindingTimeout = null,
            int? waitForLockTimeout = null);
    }

    /// <summary>
    /// Opens the SQL connection backing a binding context. Kept as a separate, overridable
    /// type so tests can substitute it via <see cref="ConnectedBindingQueue"/>'s SetConnectionOpener.
    /// </summary>
    public class SqlConnectionOpener
    {
        /// <summary>
        /// Virtual method used to support mocking and testing.
        /// </summary>
        /// <param name="connInfo">Connection to open.</param>
        /// <param name="featureName">Feature name forwarded to the connection service.</param>
        /// <returns>The connection returned by <c>ConnectionServiceCore.OpenSqlConnection</c>.</returns>
        public virtual SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName)
        {
            return ConnectionServiceCore.OpenSqlConnection(connInfo, featureName);
        }
    }

    /// <summary>
    /// ConnectedBindingQueue class for processing online binding requests.
    /// </summary>
    public class ConnectedBindingQueue : BindingQueue<ConnectedBindingContext>, IConnectedBindingQueue
    {
        /// <summary>
        /// Binding timeout assigned to each newly connected context
        /// (presumably milliseconds — confirm against the base binding queue).
        /// </summary>
        internal const int DefaultBindingTimeout = 500;

        /// <summary>
        /// Not referenced in this file; presumably a connection timeout in seconds — confirm at call sites.
        /// </summary>
        internal const int DefaultMinimumConnectionTimeout = 30;

        /// <summary>
        /// Flag determining if the connection queue requires online metadata objects.
        /// It's much cheaper to not construct these objects if not needed.
        /// </summary>
        private readonly bool needsMetadata;

        private SqlConnectionOpener connectionOpener;

        /// <summary>
        /// Gets the current settings.
        /// </summary>
        internal SqlToolsSettings CurrentSettings => SettingsService.Instance.CurrentSettings;

        /// <summary>
        /// Creates a queue that builds metadata providers and binders for its contexts.
        /// </summary>
        public ConnectedBindingQueue()
            : this(true)
        {
        }

        /// <summary>
        /// Creates a queue.
        /// </summary>
        /// <param name="needsMetadata">
        /// When true, each connected context also gets an SMO metadata provider, display-info
        /// provider and binder.
        /// </param>
        public ConnectedBindingQueue(bool needsMetadata)
        {
            this.needsMetadata = needsMetadata;
            this.connectionOpener = new SqlConnectionOpener();
        }

        // For testing purposes only
        internal void SetConnectionOpener(SqlConnectionOpener opener)
        {
            this.connectionOpener = opener;
        }

        /// <summary>
        /// Generates the binding-context key for a connection:
        /// server_database_user_authType, optionally followed by _databaseDisplayName and _groupId.
        /// Missing required parts are written as "NULL".
        /// </summary>
        /// <param name="connInfo">Connection whose details form the key.</param>
        private string GetConnectionContextKey(ConnectionInfo connInfo)
        {
            ConnectionDetails details = connInfo.ConnectionDetails;
            string key = string.Format("{0}_{1}_{2}_{3}",
                details.ServerName ?? "NULL",
                details.DatabaseName ?? "NULL",
                details.UserName ?? "NULL",
                details.AuthenticationType ?? "NULL");

            if (!string.IsNullOrEmpty(details.DatabaseDisplayName))
            {
                key += "_" + details.DatabaseDisplayName;
            }

            if (!string.IsNullOrEmpty(details.GroupId))
            {
                key += "_" + details.GroupId;
            }

            return key;
        }

        /// <summary>
        /// Generates the server_database key, which equals the first two segments of the full
        /// key produced from a ConnectionInfo.
        /// </summary>
        /// <remarks>
        /// NOTE(review): if GetBindingContexts matches by prefix (not visible here), "srv_db" would
        /// also match contexts for a database such as "db2" — verify.
        /// </remarks>
        private string GetConnectionContextKey(string serverName, string databaseName)
        {
            return string.Format("{0}_{1}", serverName ?? "NULL", databaseName ?? "NULL");
        }

        /// <inheritdoc />
        public void CloseConnections(string serverName, string databaseName, int millisecondsTimeout)
        {
            string connectionKey = GetConnectionContextKey(serverName, databaseName);
            var contexts = GetBindingContexts(connectionKey);
            foreach (var bindingContext in contexts)
            {
                // Contexts that timed out waiting for the lock are skipped.
                if (bindingContext.BindingLock.WaitOne(millisecondsTimeout))
                {
                    // A context whose AddConnectionContext failed before the connection was opened
                    // has no ServerConnection; there is nothing to disconnect.
                    bindingContext.ServerConnection?.Disconnect();
                }
            }
        }

        /// <inheritdoc />
        public void OpenConnections(string serverName, string databaseName, int millisecondsTimeout)
        {
            string connectionKey = GetConnectionContextKey(serverName, databaseName);
            var contexts = GetBindingContexts(connectionKey);
            foreach (var bindingContext in contexts)
            {
                if (!bindingContext.BindingLock.WaitOne(millisecondsTimeout))
                {
                    // Lock not obtained within the timeout; leave this context alone.
                    continue;
                }

                ServerConnection serverConnection = bindingContext.ServerConnection;
                if (serverConnection == null)
                {
                    // Never connected (setup failed earlier); nothing to reopen.
                    continue;
                }

                try
                {
                    serverConnection.Connect();
                }
                catch (Exception)
                {
                    // Best-effort reconnect: a failure leaves the context unchanged.
                    // TODO: consider removing the binding context when reconnecting fails.
                }
            }
        }

        /// <summary>
        /// Removes the binding context for the given connection, if one exists.
        /// A null connection is ignored.
        /// </summary>
        /// <param name="connInfo">Connection whose binding context should be removed.</param>
        public void RemoveBindingContext(ConnectionInfo connInfo)
        {
            if (connInfo == null)
            {
                return;
            }

            string connectionKey = GetConnectionContextKey(connInfo);
            if (BindingContextExists(connectionKey))
            {
                RemoveBindingContext(connectionKey);
            }
        }

        /// <summary>
        /// Use a ConnectionInfo item to create a connected binding context.
        /// </summary>
        /// <param name="connInfo">Connection info used to create binding context.</param>
        /// <param name="featureName">Feature name forwarded to the connection opener.</param>
        /// <param name="overwrite">Overwrite existing context.</param>
        /// <returns>
        /// The binding-context key, or <see cref="string.Empty"/> when <paramref name="connInfo"/> is null.
        /// Connection failures are not thrown; they leave the context with IsConnected = false.
        /// </returns>
        public virtual string AddConnectionContext(ConnectionInfo connInfo, string featureName = null, bool overwrite = false)
        {
            if (connInfo == null)
            {
                return string.Empty;
            }

            // lookup the current binding context
            string connectionKey = GetConnectionContextKey(connInfo);
            if (BindingContextExists(connectionKey))
            {
                if (!overwrite)
                {
                    // no need to populate the context again since the context already exists
                    return connectionKey;
                }

                RemoveBindingContext(connectionKey);
            }

            IBindingContext bindingContext = this.GetOrCreateBindingContext(connectionKey);

            // WaitOne() with no timeout returns only once the lock is signaled.
            // NOTE(review): WaitOne followed by Reset is not atomic, so two concurrent callers could
            // both get past this point — presumably callers serialize AddConnectionContext; verify.
            if (bindingContext.BindingLock.WaitOne())
            {
                try
                {
                    bindingContext.BindingLock.Reset();
                    SqlConnection sqlConn = connectionOpener.OpenSqlConnection(connInfo, featureName);

                    // populate the binding context to work with the SMO metadata provider
                    bindingContext.ServerConnection = new ServerConnection(sqlConn);

                    if (this.needsMetadata)
                    {
                        bindingContext.SmoMetadataProvider = SmoMetadataProvider.CreateConnectedProvider(bindingContext.ServerConnection);
                        bindingContext.MetadataDisplayInfoProvider = new MetadataDisplayInfoProvider();

                        // An unset (null) preference falls back to uppercase instead of throwing,
                        // which previously marked an open connection as disconnected.
                        bool lowerCaseSuggestions = this.CurrentSettings.SqlTools.IntelliSense.LowerCaseSuggestions == true;
                        bindingContext.MetadataDisplayInfoProvider.BuiltInCasing =
                            lowerCaseSuggestions ? CasingStyle.Lowercase : CasingStyle.Uppercase;
                        bindingContext.Binder = BinderProvider.CreateBinder(bindingContext.SmoMetadataProvider);
                    }

                    bindingContext.BindingTimeout = ConnectedBindingQueue.DefaultBindingTimeout;
                    bindingContext.IsConnected = true;
                }
                catch (Exception)
                {
                    // The exception is not surfaced; callers observe the failure through IsConnected.
                    bindingContext.IsConnected = false;
                }
                finally
                {
                    // Always release the lock, even if connecting failed.
                    bindingContext.BindingLock.Set();
                }
            }

            return connectionKey;
        }
    }
}