diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/CachedServerInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/CachedServerInfo.cs index 7646262c..0610fff4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/CachedServerInfo.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/CachedServerInfo.cs @@ -17,17 +17,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection /// internal static class CachedServerInfo { + public enum CacheVariable { + IsSqlDw, + IsAzure + } + private struct CachedInfo { public bool IsAzure; public DateTime LastUpdate; + public bool IsSqlDw; } private static ConcurrentDictionary _cache; private static object _cacheLock; private const int _maxCacheSize = 1024; private const int _deleteBatchSize = 512; - private const int MinimalQueryTimeoutSecondsForAzure = 300; static CachedServerInfo() @@ -70,19 +75,29 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection public static void AddOrUpdateIsAzure(IDbConnection connection, bool isAzure) { - Validate.IsNotNull(nameof(connection), connection); - - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connection.ConnectionString); - AddOrUpdateIsAzure(builder.DataSource, isAzure); + AddOrUpdateCache(connection, isAzure, CacheVariable.IsAzure); } - public static void AddOrUpdateIsAzure(string dataSource, bool isAzure) + public static void AddOrUpdateIsSqlDw(IDbConnection connection, bool isSqlDw) + { + AddOrUpdateCache(connection, isSqlDw, CacheVariable.IsSqlDw); + } + + private static void AddOrUpdateCache(IDbConnection connection, bool newState, CacheVariable cacheVar) + { + Validate.IsNotNull(nameof(connection), connection); + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connection.ConnectionString); + AddOrUpdateCache(builder.DataSource, newState, cacheVar); + } + + internal static void AddOrUpdateCache(string dataSource, bool newState, CacheVariable cacheVar) { Validate.IsNotNullOrWhitespaceString(nameof(dataSource), dataSource); CachedInfo info; bool hasFound = _cache.TryGetValue(dataSource, out info); - if (hasFound && info.IsAzure == isAzure) + if ((cacheVar == CacheVariable.IsSqlDw && hasFound && info.IsSqlDw == newState) || + (cacheVar == CacheVariable.IsAzure && hasFound && info.IsAzure == newState)) { return; } @@ -108,13 +123,44 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection } } - info.IsAzure = isAzure; + if (cacheVar == CacheVariable.IsSqlDw) + { + info.IsSqlDw = newState; + } + else if (cacheVar == CacheVariable.IsAzure) + { + info.IsAzure = newState; + } info.LastUpdate = DateTime.UtcNow; _cache.AddOrUpdate(dataSource, info, (key, oldValue) => info); } } } + public static bool TryGetIsSqlDw(IDbConnection connection, out bool isSqlDw) + { + Validate.IsNotNull(nameof(connection), connection); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connection.ConnectionString); + return TryGetIsSqlDw(builder.DataSource, out isSqlDw); + } + + public static bool TryGetIsSqlDw(string dataSource, out bool isSqlDw) + { + Validate.IsNotNullOrWhitespaceString(nameof(dataSource), dataSource); + CachedInfo info; + bool hasFound = _cache.TryGetValue(dataSource, out info); + + if(hasFound) + { + isSqlDw = info.IsSqlDw; + return true; + } + + isSqlDw = false; + return false; + } + private static string SafeGetDataSourceFromConnection(IDbConnection connection) { if (connection == null) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs index 3ce8c9e8..2cad15d4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs @@ -12,6 +12,7 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Security; using Microsoft.SqlTools.ServiceLayer.Utility; +using Microsoft.SqlServer.Management.Common; namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection { @@ -26,6 +27,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection private const string ApplicationIntent = "ApplicationIntent"; private const string MultiSubnetFailover = "MultiSubnetFailover"; private const string DacFxApplicationName = "DacFx"; + + private const int SqlDwEngineEditionId = (int)DatabaseEngineEdition.SqlDataWarehouse; // See MSDN documentation for "SERVERPROPERTY (SQL Azure Database)" for "EngineEdition" property: // http://msdn.microsoft.com/en-us/library/ee336261.aspx @@ -459,6 +462,56 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection return engineEditionId == SqlAzureEngineEditionId; } + /// + /// Determines if the type of database that a connection is being made to is SQL data warehouse. + /// + /// + /// True if the database is a SQL data warehouse + public static bool IsSqlDwDatabase(IDbConnection connection) + { + Validate.IsNotNull(nameof(connection), connection); + + Func executeCommand = commandText => + { + bool result = false; + ExecuteReader(connection, + commandText, + readResult: (reader) => + { + reader.Read(); + int engineEditionId = int.Parse(reader[0].ToString(), CultureInfo.InvariantCulture); + + result = IsSqlDwEngineId(engineEditionId); + } + ); + return result; + }; + + bool isSqlDw = false; + try + { + isSqlDw = executeCommand(SqlConnectionHelperScripts.EngineEdition); + } + catch (SqlException) + { + // The default query contains a WITH (NOLOCK). This doesn't work for Azure DW, so when things don't work out, + // we'll fall back to a version without NOLOCK and try again. + isSqlDw = executeCommand(SqlConnectionHelperScripts.EngineEditionWithLock); + } + + return isSqlDw; + } + + /// + /// Compares the engine edition id of a given database with that of SQL data warehouse. + /// + /// + /// True if the engine edition id is that of SQL data warehouse + private static bool IsSqlDwEngineId(int engineEditionId) + { + return engineEditionId == SqlDwEngineEditionId; + } + /// /// Handles the exceptions typically thrown when a SQLConnection is being opened /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs index 579e86f6..8b66ade8 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs @@ -47,6 +47,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection private readonly RetryPolicy _connectionRetryPolicy; private RetryPolicy _commandRetryPolicy; private Guid _azureSessionId; + private bool _isSqlDwDatabase; /// /// Initializes a new instance of the ReliableSqlConnection class with a given connection string @@ -96,6 +97,25 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection } } + /// + /// Determines if a connection is being made to a SQL DW database. + /// + /// A connection object. + private bool IsSqlDwConnection(IDbConnection conn) + { + //Set the connection only if it has not been set earlier. + //This is assuming that it is highly unlikely for a connection to change between instances. + //Hence any subsequent calls to this method will just return the cached value and not + //verify again if this is a SQL DW database connection or not. + if (!CachedServerInfo.TryGetIsSqlDw(conn, out _isSqlDwDatabase)) + { + _isSqlDwDatabase = ReliableConnectionHelper.IsSqlDwDatabase(conn); + CachedServerInfo.AddOrUpdateIsSqlDw(conn, _isSqlDwDatabase);; + } + + return _isSqlDwDatabase; + } + [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Security", "CA2100:Review SQL queries for security vulnerabilities")] internal static void SetLockAndCommandTimeout(IDbConnection conn) { @@ -120,7 +140,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection } } - internal static void SetDefaultAnsiSettings(IDbConnection conn) + internal static void SetDefaultAnsiSettings(IDbConnection conn, bool isSqlDw) { Validate.IsNotNull(nameof(conn), conn); @@ -136,8 +156,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection using (IDbCommand cmd = conn.CreateCommand()) { cmd.CommandTimeout = CachedServerInfo.GetQueryTimeoutSeconds(conn); - cmd.CommandText = @"SET ANSI_NULLS, ANSI_PADDING, ANSI_WARNINGS, ARITHABORT, CONCAT_NULL_YIELDS_NULL, QUOTED_IDENTIFIER ON; + if (!isSqlDw) + { + cmd.CommandText = @"SET ANSI_NULLS, ANSI_PADDING, ANSI_WARNINGS, ARITHABORT, CONCAT_NULL_YIELDS_NULL, QUOTED_IDENTIFIER ON; SET NUMERIC_ROUNDABORT OFF;"; + } + else + { + cmd.CommandText = @"SET ANSI_NULLS ON; SET ANSI_PADDING ON; SET ANSI_WARNINGS ON; SET ARITHABORT ON; SET CONCAT_NULL_YIELDS_NULL ON; SET QUOTED_IDENTIFIER ON;"; //SQL DW does not support NUMERIC_ROUNDABORT + } cmd.ExecuteNonQuery(); } } @@ -343,7 +370,7 @@ SET NUMERIC_ROUNDABORT OFF;"; _underlyingConnection.Open(); } SetLockAndCommandTimeout(_underlyingConnection); - SetDefaultAnsiSettings(_underlyingConnection); + SetDefaultAnsiSettings(_underlyingConnection, IsSqlDwConnection(_underlyingConnection)); }); return _underlyingConnection; diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/Resources.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/Resources.cs index 3fcbe225..796cfeb1 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/Resources.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/Resources.cs @@ -145,5 +145,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection return "Unable to retrieve Azure session-id."; } } + + internal static string ServerInfoCacheMiss + { + get + { + return "Server Info does not have the requested property in the cache"; + } + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/CachedServerInfoTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/CachedServerInfoTests.cs new file mode 100644 index 00000000..74276634 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/CachedServerInfoTests.cs @@ -0,0 +1,82 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using Xunit; +using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Connection +{ + /// + /// Tests for Sever Information Caching Class + /// + public class CachedServerInfoTests + { + + [Theory] + [InlineData(true)] // is SqlDW instance + [InlineData(false)] // is not a SqlDw Instance + public void AddOrUpdateIsSqlDw(bool state) + { + // Set sqlDw result into cache + bool isSqlDwResult; + CachedServerInfo.AddOrUpdateCache("testDataSource", state, CachedServerInfo.CacheVariable.IsSqlDw); + + // Expect the same returned result + Assert.True(CachedServerInfo.TryGetIsSqlDw("testDataSource", out isSqlDwResult)); + Assert.Equal(isSqlDwResult, state); + } + + [Theory] + [InlineData(true)] // is SqlDW instance + [InlineData(false)] // is not a SqlDw Instance + public void AddOrUpdateIsSqlDwFalseToggle(bool state) + { + // Set sqlDw result into cache + bool isSqlDwResult; + CachedServerInfo.AddOrUpdateCache("testDataSource", state, CachedServerInfo.CacheVariable.IsSqlDw); + + // Expect the same returned result + Assert.True(CachedServerInfo.TryGetIsSqlDw("testDataSource", out isSqlDwResult)); + Assert.Equal(isSqlDwResult, state); + + // Toggle isSqlDw cache state + bool isSqlDwResultToggle; + CachedServerInfo.AddOrUpdateCache("testDataSource", !state, CachedServerInfo.CacheVariable.IsSqlDw); + + // Expect the oppisite returned result + Assert.True(CachedServerInfo.TryGetIsSqlDw("testDataSource", out isSqlDwResultToggle)); + Assert.Equal(isSqlDwResultToggle, !state); + + } + + [Fact] + public void AddOrUpdateIsSqlDwFalseToggle() + { + bool state = true; + // Set sqlDw result into cache + bool isSqlDwResult; + bool isSqlDwResult2; + CachedServerInfo.AddOrUpdateCache("testDataSource", state, CachedServerInfo.CacheVariable.IsSqlDw); + CachedServerInfo.AddOrUpdateCache("testDataSource2", !state, CachedServerInfo.CacheVariable.IsSqlDw); + + // Expect the same returned result + Assert.True(CachedServerInfo.TryGetIsSqlDw("testDataSource", out isSqlDwResult)); + Assert.True(CachedServerInfo.TryGetIsSqlDw("testDataSource2", out isSqlDwResult2)); + + // Assert cache is set on a per connection basis + Assert.Equal(isSqlDwResult, state); + Assert.Equal(isSqlDwResult2, !state); + + } + + [Fact] + public void AskforSqlDwBeforeCached() + { + bool isSqlDwResult; + Assert.False(CachedServerInfo.TryGetIsSqlDw("testDataSourceWithNoCache", out isSqlDwResult)); + } + } +} \ No newline at end of file