diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/CachedServerInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/CachedServerInfo.cs index 0610fff4..ff56875f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/CachedServerInfo.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/CachedServerInfo.cs @@ -12,16 +12,89 @@ using Microsoft.SqlTools.ServiceLayer.Utility; namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection { + /// /// This class caches server information for subsequent use /// - internal static class CachedServerInfo + internal class CachedServerInfo { + /// + /// Singleton service instance + /// + private static readonly Lazy instance + = new Lazy(() => new CachedServerInfo()); + + /// + /// Gets the singleton instance + /// + public static CachedServerInfo Instance + { + get + { + return instance.Value; + } + } + public enum CacheVariable { IsSqlDw, IsAzure } + #region CacheKey implementation + internal class CacheKey : IEquatable + { + private string dataSource; + private string dbName; + + public CacheKey(SqlConnectionStringBuilder builder) + { + Validate.IsNotNull(nameof(builder), builder); + dataSource = builder.DataSource; + dbName = GetDatabaseName(builder); + } + + internal static string GetDatabaseName(SqlConnectionStringBuilder builder) + { + string dbName = string.Empty; + if (!string.IsNullOrEmpty((builder.InitialCatalog))) + { + dbName = builder.InitialCatalog; + } + else if (!string.IsNullOrEmpty((builder.AttachDBFilename))) + { + dbName = builder.AttachDBFilename; + } + return dbName; + } + + public override bool Equals(object obj) + { + if (obj == null) { return false; } + + CacheKey keyObj = obj as CacheKey; + if (keyObj == null) { return false; } + else { return Equals(keyObj); } + } + + public override int GetHashCode() + { + unchecked // Overflow is fine, just wrap + { + int hash = 17; + hash = (hash * 23) + (dataSource != null ? dataSource.GetHashCode() : 0); + hash = (hash * 23) + (dbName != null ? dbName.GetHashCode() : 0); + return hash; + } + } + + public bool Equals(CacheKey other) + { + return string.Equals(dataSource, other.dataSource, StringComparison.OrdinalIgnoreCase) + && string.Equals(dbName, other.dbName, StringComparison.OrdinalIgnoreCase); + } + } + #endregion + private struct CachedInfo { public bool IsAzure; @@ -29,38 +102,43 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection public bool IsSqlDw; } - private static ConcurrentDictionary _cache; - private static object _cacheLock; private const int _maxCacheSize = 1024; private const int _deleteBatchSize = 512; private const int MinimalQueryTimeoutSecondsForAzure = 300; - static CachedServerInfo() + private ConcurrentDictionary _cache; + private object _cacheLock; + + /// + /// Internal constructor for testing purposes. For all code use, please use the + /// default instance. + /// + internal CachedServerInfo() { - _cache = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase); + _cache = new ConcurrentDictionary(); _cacheLock = new object(); } - public static int GetQueryTimeoutSeconds(IDbConnection connection) + public int GetQueryTimeoutSeconds(IDbConnection connection) { - string dataSource = SafeGetDataSourceFromConnection(connection); - return GetQueryTimeoutSeconds(dataSource); + SqlConnectionStringBuilder connStringBuilder = SafeGetConnectionStringFromConnection(connection); + return GetQueryTimeoutSeconds(connStringBuilder); } - public static int GetQueryTimeoutSeconds(string dataSource) + public int GetQueryTimeoutSeconds(SqlConnectionStringBuilder builder) { //keep existing behavior and return the default ambient settings //if the provided data source is null or whitespace, or the original //setting is already 0 which means no limit. int originalValue = AmbientSettings.QueryTimeoutSeconds; - if (string.IsNullOrWhiteSpace(dataSource) + if (builder == null || string.IsNullOrWhiteSpace(builder.DataSource) || (originalValue == 0)) { return originalValue; } CachedInfo info; - bool hasFound = _cache.TryGetValue(dataSource, out info); + bool hasFound = TryGetCacheValue(builder, out info); if (hasFound && info.IsAzure && originalValue < MinimalQueryTimeoutSecondsForAzure) @@ -73,55 +151,43 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection } } - public static void AddOrUpdateIsAzure(IDbConnection connection, bool isAzure) + public void AddOrUpdateIsAzure(IDbConnection connection, bool isAzure) { AddOrUpdateCache(connection, isAzure, CacheVariable.IsAzure); } - public static void AddOrUpdateIsSqlDw(IDbConnection connection, bool isSqlDw) + public void AddOrUpdateIsSqlDw(IDbConnection connection, bool isSqlDw) { AddOrUpdateCache(connection, isSqlDw, CacheVariable.IsSqlDw); } - private static void AddOrUpdateCache(IDbConnection connection, bool newState, CacheVariable cacheVar) + private void AddOrUpdateCache(IDbConnection connection, bool newState, CacheVariable cacheVar) { Validate.IsNotNull(nameof(connection), connection); SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connection.ConnectionString); - AddOrUpdateCache(builder.DataSource, newState, cacheVar); + AddOrUpdateCache(builder, newState, cacheVar); } - internal static void AddOrUpdateCache(string dataSource, bool newState, CacheVariable cacheVar) + internal void AddOrUpdateCache(SqlConnectionStringBuilder builder, bool newState, CacheVariable cacheVar) { - Validate.IsNotNullOrWhitespaceString(nameof(dataSource), dataSource); + Validate.IsNotNull(nameof(builder), builder); + Validate.IsNotNullOrWhitespaceString(nameof(builder) + ".DataSource", builder.DataSource); CachedInfo info; - bool hasFound = _cache.TryGetValue(dataSource, out info); + bool hasFound = TryGetCacheValue(builder, out info); if ((cacheVar == CacheVariable.IsSqlDw && hasFound && info.IsSqlDw == newState) || (cacheVar == CacheVariable.IsAzure && hasFound && info.IsAzure == newState)) { + // No change needed return; } else { lock (_cacheLock) { - if (! _cache.ContainsKey(dataSource)) - { - //delete a batch of old elements when we try to add a new one and - //the capacity limitation is hit - if (_cache.Keys.Count > _maxCacheSize - 1) - { - var keysToDelete = _cache - .OrderBy(x => x.Value.LastUpdate) - .Take(_deleteBatchSize) - .Select(pair => pair.Key); - - foreach (string key in keysToDelete) - { - _cache.TryRemove(key, out info); - } - } - } + // Clean older keys, update info, and add this back into the cache + CacheKey key = new CacheKey(builder); + CleanupCache(key); if (cacheVar == CacheVariable.IsSqlDw) { @@ -132,24 +198,47 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection info.IsAzure = newState; } info.LastUpdate = DateTime.UtcNow; - _cache.AddOrUpdate(dataSource, info, (key, oldValue) => info); + _cache.AddOrUpdate(key, info, (k, oldValue) => info); } } } - public static bool TryGetIsSqlDw(IDbConnection connection, out bool isSqlDw) + private void CleanupCache(CacheKey newKey) + { + if (!_cache.ContainsKey(newKey)) + { + //delete a batch of old elements when we try to add a new one and + //the capacity limitation is hit + if (_cache.Keys.Count > _maxCacheSize - 1) + { + var keysToDelete = _cache + .OrderBy(x => x.Value.LastUpdate) + .Take(_deleteBatchSize) + .Select(pair => pair.Key); + + foreach (CacheKey key in keysToDelete) + { + CachedInfo info; + _cache.TryRemove(key, out info); + } + } + } + } + + public bool TryGetIsSqlDw(IDbConnection connection, out bool isSqlDw) { Validate.IsNotNull(nameof(connection), connection); SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connection.ConnectionString); - return TryGetIsSqlDw(builder.DataSource, out isSqlDw); + return TryGetIsSqlDw(builder, out isSqlDw); } - public static bool TryGetIsSqlDw(string dataSource, out bool isSqlDw) + public bool TryGetIsSqlDw(SqlConnectionStringBuilder builder, out bool isSqlDw) { - Validate.IsNotNullOrWhitespaceString(nameof(dataSource), dataSource); + Validate.IsNotNull(nameof(builder), builder); + Validate.IsNotNullOrWhitespaceString(nameof(builder) + ".DataSource", builder.DataSource); CachedInfo info; - bool hasFound = _cache.TryGetValue(dataSource, out info); + bool hasFound = TryGetCacheValue(builder, out info); if(hasFound) { @@ -161,7 +250,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection return false; } - private static string SafeGetDataSourceFromConnection(IDbConnection connection) + private static SqlConnectionStringBuilder SafeGetConnectionStringFromConnection(IDbConnection connection) { if (connection == null) { @@ -171,7 +260,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection try { SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connection.ConnectionString); - return builder.DataSource; + return builder; } catch { @@ -179,5 +268,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection return null; } } + + private bool TryGetCacheValue(SqlConnectionStringBuilder builder, out CachedInfo value) + { + CacheKey key = new CacheKey(builder); + return _cache.TryGetValue(key, out value); + } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs index 4a5a559e..643a2918 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs @@ -407,7 +407,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection public static void SetCommandTimeout(IDbCommand cmd) { Validate.IsNotNull(nameof(cmd), cmd); - cmd.CommandTimeout = CachedServerInfo.GetQueryTimeoutSeconds(cmd.Connection); + cmd.CommandTimeout = CachedServerInfo.Instance.GetQueryTimeoutSeconds(cmd.Connection); } @@ -773,7 +773,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection try { - CachedServerInfo.AddOrUpdateIsAzure(connection, serverInfo.IsCloud); + CachedServerInfo.Instance.AddOrUpdateIsAzure(connection, serverInfo.IsCloud); } catch (Exception ex) { diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs index 37ae5cd6..45773a82 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs @@ -109,10 +109,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection //This is assuming that it is highly unlikely for a connection to change between instances. //Hence any subsequent calls to this method will just return the cached value and not //verify again if this is a SQL DW database connection or not. - if (!CachedServerInfo.TryGetIsSqlDw(conn, out _isSqlDwDatabase)) + if (!CachedServerInfo.Instance.TryGetIsSqlDw(conn, out _isSqlDwDatabase)) { _isSqlDwDatabase = ReliableConnectionHelper.IsSqlDwDatabase(conn); - CachedServerInfo.AddOrUpdateIsSqlDw(conn, _isSqlDwDatabase);; + CachedServerInfo.Instance.AddOrUpdateIsSqlDw(conn, _isSqlDwDatabase);; } return _isSqlDwDatabase; @@ -137,7 +137,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection { cmd.CommandText = string.Format(CultureInfo.InvariantCulture, setLockTimeout, AmbientSettings.LockTimeoutMilliSeconds); cmd.CommandType = CommandType.Text; - cmd.CommandTimeout = CachedServerInfo.GetQueryTimeoutSeconds(conn); + cmd.CommandTimeout = CachedServerInfo.Instance.GetQueryTimeoutSeconds(conn); cmd.ExecuteNonQuery(); } } @@ -157,7 +157,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection // Configure the connection with proper ANSI settings and lock timeout using (IDbCommand cmd = conn.CreateCommand()) { - cmd.CommandTimeout = CachedServerInfo.GetQueryTimeoutSeconds(conn); + cmd.CommandTimeout = CachedServerInfo.Instance.GetQueryTimeoutSeconds(conn); if (!isSqlDw) { cmd.CommandText = @"SET ANSI_NULLS, ANSI_PADDING, ANSI_WARNINGS, ARITHABORT, CONCAT_NULL_YIELDS_NULL, QUOTED_IDENTIFIER ON; diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs index 2016b97e..d016eaab 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs @@ -9,18 +9,13 @@ using System.Data; using System.Data.Common; using System.Data.SqlClient; using System.Threading; -using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; -using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.Test.Common; -using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution; using Microsoft.SqlTools.ServiceLayer.Test.Utility; -using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; using Microsoft.SqlTools.Test.Utility; using Xunit; -using static Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection.ReliableConnectionHelper; using static Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection.RetryPolicy; using static Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection.RetryPolicy.TimeBasedRetryPolicy; using static Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection.SqlSchemaModelErrorCodes; diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/CachedServerInfoTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/CachedServerInfoTests.cs index 74276634..b47833f9 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/CachedServerInfoTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/CachedServerInfoTests.cs @@ -6,6 +6,7 @@ using System; using Xunit; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; +using System.Data.SqlClient; namespace Microsoft.SqlTools.ServiceLayer.Test.Connection { @@ -14,18 +15,70 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// public class CachedServerInfoTests { + private CachedServerInfo cache; + + public CachedServerInfoTests() + { + cache = new CachedServerInfo(); + } + + [Fact] + public void CacheMatchesNullDbNameToEmptyString() + { + // Set sqlDw result into cache + string dataSource = "testDataSource"; + bool isSqlDwResult; + SqlConnectionStringBuilder testSource = new SqlConnectionStringBuilder + { + DataSource = dataSource, + InitialCatalog = string.Empty + }; + cache.AddOrUpdateCache(testSource, true, CachedServerInfo.CacheVariable.IsSqlDw); + + // Expect the same returned result + Assert.True(cache.TryGetIsSqlDw(testSource, out isSqlDwResult)); + Assert.True(isSqlDwResult); + + // And expect the same for the null string + Assert.True(cache.TryGetIsSqlDw(new SqlConnectionStringBuilder + { + DataSource = dataSource + // Initial Catalog is null. Can't set explicitly as this throws + }, out isSqlDwResult)); + Assert.True(isSqlDwResult); + + // But expect false for a different DB + Assert.False(cache.TryGetIsSqlDw(new SqlConnectionStringBuilder + { + DataSource = dataSource, + InitialCatalog = "OtherDb" + }, out isSqlDwResult)); + } [Theory] - [InlineData(true)] // is SqlDW instance - [InlineData(false)] // is not a SqlDw Instance - public void AddOrUpdateIsSqlDw(bool state) + [InlineData(null, true)] // is SqlDW instance + [InlineData("", true)] // is SqlDW instance + [InlineData("myDb", true)] // is SqlDW instance + [InlineData(null, false)] // is not a SqlDw Instance + [InlineData("", false)] // is not a SqlDw Instance + [InlineData("myDb", false)] // is not SqlDW instance + public void AddOrUpdateIsSqlDw(string dbName, bool state) { // Set sqlDw result into cache bool isSqlDwResult; - CachedServerInfo.AddOrUpdateCache("testDataSource", state, CachedServerInfo.CacheVariable.IsSqlDw); + SqlConnectionStringBuilder testSource = new SqlConnectionStringBuilder + { + DataSource = "testDataSource" + }; + if (dbName != null) + { + testSource.InitialCatalog = dbName; + } + + cache.AddOrUpdateCache(testSource, state, CachedServerInfo.CacheVariable.IsSqlDw); // Expect the same returned result - Assert.True(CachedServerInfo.TryGetIsSqlDw("testDataSource", out isSqlDwResult)); + Assert.True(cache.TryGetIsSqlDw(testSource, out isSqlDwResult)); Assert.Equal(isSqlDwResult, state); } @@ -36,18 +89,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection { // Set sqlDw result into cache bool isSqlDwResult; - CachedServerInfo.AddOrUpdateCache("testDataSource", state, CachedServerInfo.CacheVariable.IsSqlDw); + SqlConnectionStringBuilder testSource = new SqlConnectionStringBuilder + { + DataSource = "testDataSource" + }; + cache.AddOrUpdateCache(testSource, state, CachedServerInfo.CacheVariable.IsSqlDw); // Expect the same returned result - Assert.True(CachedServerInfo.TryGetIsSqlDw("testDataSource", out isSqlDwResult)); + Assert.True(cache.TryGetIsSqlDw(testSource, out isSqlDwResult)); Assert.Equal(isSqlDwResult, state); // Toggle isSqlDw cache state bool isSqlDwResultToggle; - CachedServerInfo.AddOrUpdateCache("testDataSource", !state, CachedServerInfo.CacheVariable.IsSqlDw); + cache.AddOrUpdateCache(testSource, !state, CachedServerInfo.CacheVariable.IsSqlDw); // Expect the oppisite returned result - Assert.True(CachedServerInfo.TryGetIsSqlDw("testDataSource", out isSqlDwResultToggle)); + Assert.True(cache.TryGetIsSqlDw(testSource, out isSqlDwResultToggle)); Assert.Equal(isSqlDwResultToggle, !state); } @@ -56,19 +113,40 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection public void AddOrUpdateIsSqlDwFalseToggle() { bool state = true; + + SqlConnectionStringBuilder testSource = new SqlConnectionStringBuilder + { + DataSource = "testDataSource" + }; + + SqlConnectionStringBuilder sameServerDifferentDb = new SqlConnectionStringBuilder + { + DataSource = "testDataSource", + InitialCatalog = "myDb" + }; + SqlConnectionStringBuilder differentServerSameDb = new SqlConnectionStringBuilder + { + DataSource = "testDataSource2", + InitialCatalog = "" + }; + + cache.AddOrUpdateCache(testSource, state, CachedServerInfo.CacheVariable.IsSqlDw); + cache.AddOrUpdateCache(sameServerDifferentDb, !state, CachedServerInfo.CacheVariable.IsSqlDw); + cache.AddOrUpdateCache(differentServerSameDb, !state, CachedServerInfo.CacheVariable.IsSqlDw); + + // Expect the same returned result // Set sqlDw result into cache bool isSqlDwResult; bool isSqlDwResult2; - CachedServerInfo.AddOrUpdateCache("testDataSource", state, CachedServerInfo.CacheVariable.IsSqlDw); - CachedServerInfo.AddOrUpdateCache("testDataSource2", !state, CachedServerInfo.CacheVariable.IsSqlDw); - - // Expect the same returned result - Assert.True(CachedServerInfo.TryGetIsSqlDw("testDataSource", out isSqlDwResult)); - Assert.True(CachedServerInfo.TryGetIsSqlDw("testDataSource2", out isSqlDwResult2)); + bool isSqlDwResult3; + Assert.True(cache.TryGetIsSqlDw(testSource, out isSqlDwResult)); + Assert.True(cache.TryGetIsSqlDw(sameServerDifferentDb, out isSqlDwResult2)); + Assert.True(cache.TryGetIsSqlDw(differentServerSameDb, out isSqlDwResult3)); // Assert cache is set on a per connection basis Assert.Equal(isSqlDwResult, state); Assert.Equal(isSqlDwResult2, !state); + Assert.Equal(isSqlDwResult3, !state); } @@ -76,7 +154,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection public void AskforSqlDwBeforeCached() { bool isSqlDwResult; - Assert.False(CachedServerInfo.TryGetIsSqlDw("testDataSourceWithNoCache", out isSqlDwResult)); + Assert.False(cache.TryGetIsSqlDw(new SqlConnectionStringBuilder + { + DataSource = "testDataSourceUnCached" + }, + out isSqlDwResult)); } } } \ No newline at end of file