From f2eb590d977d24f5287b27052108a0127d06554a Mon Sep 17 00:00:00 2001 From: Justin M <63619224+JustinMDotNet@users.noreply.github.com> Date: Tue, 6 Oct 2020 17:20:13 -0700 Subject: [PATCH] 3490 Kusto Connection Refresh Fix (#1085) * 3490 Injected OwnerUri into KustClient to store for token refreshing. Removed UpdateAzureToken from IDataSource, DataSourceBase, and KustoDataSource. Removed logic for retrying queries related to Unauthorized datasource in Batch and Query. Changed ScriptingService, ScriptingScriptOperation, and ScriptAsScriptingOperation to take DataSource in the constructor instead of datasourcefactory. Changed ScriptingService to inject ConnectionService through InitializeService function. * 3490 Removed Catch block for DataSourceUnauthorizedException in ExecuteControlCommandAsync * 3490 Removed OwnerUri from KustoClient and used azureAccountToken to refresh token in ConnectionService. * 3490 Reverted unneeded changes. * 3490 Split ExecuteQuery in KustoClient to execute first query then remaining queries after * 3490 Passed OwnerUri down into KustoClient to refresh token. * 3490 Removed DataSourceUnauthorizedException. Refactored ExecuteQuery to catch aggregate exception. Added RefreshAzureToken logic to ExecuteControlCommand * 3490 Added logic to update ReliableDataSourceConnection azure token within ConnectionInfo. * 3490 Add retry logic to ExecuteQuery and ExecuteControlCommand in KustoClient --- .../Connection/ConnectionInfo.cs | 5 ++ .../Connection/ConnectionService.cs | 57 ++++------------- .../Connection/DataSourceConnectionFactory.cs | 4 +- .../IDataSourceConnectionFactory.cs | 6 +- .../DataSource/DataSourceBase.cs | 2 - .../DataSource/DataSourceFactory.cs | 4 +- .../DataSourceUnauthorizedException.cs | 11 ---- .../DataSource/IDataSource.cs | 6 -- .../DataSource/IDataSourceFactory.cs | 2 +- .../DataSource/IKustoClient.cs | 7 +- .../DataSource/KustoClient.cs | 64 +++++++++++-------- .../DataSource/KustoDataSource.cs | 4 -- .../ReliableDataSourceConnection.cs | 16 +++-- .../HostLoader.cs | 3 +- .../LanguageServices/ConnectedBindingQueue.cs | 2 +- .../QueryExecution/Batch.cs | 12 +--- .../QueryExecution/Query.cs | 6 -- .../Scripting/ScriptAsScriptingOperation.cs | 32 +++------- .../Scripting/ScriptingScriptOperation.cs | 9 +-- .../Scripting/ScriptingService.cs | 48 ++++---------- .../Scripting/SmoScriptingOperation.cs | 20 ++---- .../Connection/ConnectionInfoTests.cs | 6 +- .../DataSourceConnectionFactoryTests.cs | 2 +- .../DataSource/DataSourceFactoryTests.cs | 2 +- .../ConnectedBindingQueueTests.cs | 2 +- 25 files changed, 120 insertions(+), 212 deletions(-) delete mode 100644 src/Microsoft.Kusto.ServiceLayer/DataSource/Exceptions/DataSourceUnauthorizedException.cs diff --git a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionInfo.cs index 78fe4154..b9e26081 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionInfo.cs +++ b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionInfo.cs @@ -161,6 +161,11 @@ namespace Microsoft.Kusto.ServiceLayer.Connection public void UpdateAzureToken(string token) { ConnectionDetails.AzureAccountToken = token; + + foreach (var connection in _connectionTypeToConnectionMap.Values) + { + connection.UpdateAzureToken(token); + } } } } diff --git a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionService.cs index 22877353..42044a69 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.Kusto.ServiceLayer/Connection/ConnectionService.cs @@ -6,7 +6,6 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Globalization; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -268,23 +267,22 @@ namespace Microsoft.Kusto.ServiceLayer.Connection return completeParams; } - internal void RefreshAzureToken(string ownerUri) + internal string RefreshAzureToken(string ownerUri) { - ConnectionInfo existingConnection = OwnerToConnectionMap[ownerUri]; - + TryFindConnection(ownerUri, out ConnectionInfo connection); + var requestMessage = new RequestSecurityTokenParams { - AccountId = existingConnection.ConnectionDetails.GetOptionValue("azureAccount", string.Empty), - Authority = existingConnection.ConnectionDetails.GetOptionValue("azureTenantId", string.Empty), + AccountId = connection.ConnectionDetails.GetOptionValue("azureAccount", string.Empty), + Authority = connection.ConnectionDetails.GetOptionValue("azureTenantId", string.Empty), Provider = "Azure", Resource = "SQL" }; var response = Instance.ServiceHost.SendRequest(SecurityTokenRequest.Type, requestMessage, true).Result; - existingConnection.UpdateAzureToken(response.Token); - - existingConnection.TryGetConnection(ConnectionType.Query, out var reliableDataSourceConnection); - reliableDataSourceConnection.GetUnderlyingConnection().UpdateAzureToken(response.Token); + connection.UpdateAzureToken(response.Token); + + return response.Token; } private void TryCloseConnectionTemporaryConnection(ConnectParams connectionParams, ConnectionInfo connectionInfo) @@ -452,7 +450,7 @@ namespace Microsoft.Kusto.ServiceLayer.Connection string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails); // create a sql connection instance - connection = connectionInfo.Factory.CreateDataSourceConnection(connectionString, connectionInfo.ConnectionDetails.AzureAccountToken); + connection = connectionInfo.Factory.CreateDataSourceConnection(connectionString, connectionInfo.ConnectionDetails.AzureAccountToken, connectionInfo.OwnerUri); connectionInfo.AddConnection(connectionParams.Type, connection); // Add a cancellation token source so that the connection OpenAsync() can be cancelled @@ -786,9 +784,10 @@ namespace Microsoft.Kusto.ServiceLayer.Connection { throw new Exception(SR.ConnectionServiceListDbErrorNotConnected(owner)); } - ConnectionDetails connectionDetails = info.ConnectionDetails.Clone(); - IDataSource dataSource = OpenDataSourceConnection(info); + info.TryGetConnection(ConnectionType.Default, out ReliableDataSourceConnection connection); + IDataSource dataSource = connection.GetUnderlyingConnection(); + DataSourceObjectMetadata objectMetadata = MetadataFactory.CreateClusterMetadata(info.ConnectionDetails.ServerName); ListDatabasesResponse response = new ListDatabasesResponse(); @@ -813,7 +812,6 @@ namespace Microsoft.Kusto.ServiceLayer.Connection ServiceHost = serviceHost; _dataSourceConnectionFactory = dataSourceConnectionFactory; _dataSourceFactory = dataSourceFactory; - connectedQueues.AddOrUpdate("Default", connectedBindingQueue, (key, old) => connectedBindingQueue); LockedDatabaseManager.ConnectionService = this; @@ -1262,7 +1260,7 @@ namespace Microsoft.Kusto.ServiceLayer.Connection string connectionString = BuildConnectionString(info.ConnectionDetails); // create a sql connection instance - ReliableDataSourceConnection connection = info.Factory.CreateDataSourceConnection(connectionString, info.ConnectionDetails.AzureAccountToken); + ReliableDataSourceConnection connection = info.Factory.CreateDataSourceConnection(connectionString, info.ConnectionDetails.AzureAccountToken, ownerUri); connection.Open(); info.AddConnection(key, connection); } @@ -1358,35 +1356,6 @@ namespace Microsoft.Kusto.ServiceLayer.Connection } } - /// - /// Create and open a new SqlConnection from a ConnectionInfo object - /// Note: we need to audit all uses of this method to determine why we're - /// bypassing normal ConnectionService connection management - /// - /// The connection info to connect with - /// A plaintext string that will be included in the application name for the connection - /// A SqlConnection created with the given connection info - private IDataSource OpenDataSourceConnection(ConnectionInfo connInfo, string featureName = null) - { - try - { - // generate connection string - string connectionString = BuildConnectionString(connInfo.ConnectionDetails); - - // TODOKusto: Pass in type of DataSource needed to make this generic. Hard coded to Kusto right now. - return _dataSourceFactory.Create(DataSourceType.Kusto, connectionString, connInfo.ConnectionDetails.AzureAccountToken); - } - catch (Exception ex) - { - string error = string.Format(CultureInfo.InvariantCulture, - "Failed opening a DataSource of type {0}: error:{1} inner:{2} stacktrace:{3}", - DataSourceType.Kusto, ex.Message, ex.InnerException != null ? ex.InnerException.Message : string.Empty, ex.StackTrace); - Logger.Write(TraceEventType.Error, error); - } - - return null; - } - public static void EnsureConnectionIsOpen(ReliableDataSourceConnection conn, bool forceReopen = false) { // verify that the connection is open diff --git a/src/Microsoft.Kusto.ServiceLayer/Connection/DataSourceConnectionFactory.cs b/src/Microsoft.Kusto.ServiceLayer/Connection/DataSourceConnectionFactory.cs index 567a70ce..b18e7228 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Connection/DataSourceConnectionFactory.cs +++ b/src/Microsoft.Kusto.ServiceLayer/Connection/DataSourceConnectionFactory.cs @@ -28,11 +28,11 @@ namespace Microsoft.Kusto.ServiceLayer.Connection /// /// Creates a new SqlConnection object /// - public ReliableDataSourceConnection CreateDataSourceConnection(string connectionString, string azureAccountToken) + public ReliableDataSourceConnection CreateDataSourceConnection(string connectionString, string azureAccountToken, string ownerUri) { RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy(); RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy(); - return new ReliableDataSourceConnection(connectionString, connectionRetryPolicy, commandRetryPolicy, azureAccountToken, _dataSourceFactory); + return new ReliableDataSourceConnection(connectionString, connectionRetryPolicy, commandRetryPolicy, azureAccountToken, _dataSourceFactory, ownerUri); } } } diff --git a/src/Microsoft.Kusto.ServiceLayer/Connection/IDataSourceConnectionFactory.cs b/src/Microsoft.Kusto.ServiceLayer/Connection/IDataSourceConnectionFactory.cs index fbfd45b1..102779fa 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Connection/IDataSourceConnectionFactory.cs +++ b/src/Microsoft.Kusto.ServiceLayer/Connection/IDataSourceConnectionFactory.cs @@ -3,9 +3,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using System.Data.Common; - -namespace Microsoft.Kusto.ServiceLayer.Connection + namespace Microsoft.Kusto.ServiceLayer.Connection { /// /// Interface for the SQL Connection factory @@ -15,6 +13,6 @@ namespace Microsoft.Kusto.ServiceLayer.Connection /// /// Create a new SQL Connection object /// - ReliableDataSourceConnection CreateDataSourceConnection(string connectionString, string azureAccountToken); + ReliableDataSourceConnection CreateDataSourceConnection(string connectionString, string azureAccountToken, string ownerUri); } } diff --git a/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceBase.cs b/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceBase.cs index 6bd35bf7..adb9bcf6 100644 --- a/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceBase.cs +++ b/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceBase.cs @@ -89,8 +89,6 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource public abstract string GenerateExecuteFunctionScript(string functionName); - public abstract void UpdateAzureToken(string azureToken); - /// public DataSourceType DataSourceType { get; protected set; } diff --git a/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceFactory.cs b/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceFactory.cs index 6ca18479..f6064b78 100644 --- a/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceFactory.cs +++ b/src/Microsoft.Kusto.ServiceLayer/DataSource/DataSourceFactory.cs @@ -13,7 +13,7 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource [Export(typeof(IDataSourceFactory))] public class DataSourceFactory : IDataSourceFactory { - public IDataSource Create(DataSourceType dataSourceType, string connectionString, string azureAccountToken) + public IDataSource Create(DataSourceType dataSourceType, string connectionString, string azureAccountToken, string ownerUri) { ValidationUtils.IsArgumentNotNullOrWhiteSpace(connectionString, nameof(connectionString)); ValidationUtils.IsArgumentNotNullOrWhiteSpace(azureAccountToken, nameof(azureAccountToken)); @@ -22,7 +22,7 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource { case DataSourceType.Kusto: { - var kustoClient = new KustoClient(connectionString, azureAccountToken); + var kustoClient = new KustoClient(connectionString, azureAccountToken, ownerUri); return new KustoDataSource(kustoClient); } diff --git a/src/Microsoft.Kusto.ServiceLayer/DataSource/Exceptions/DataSourceUnauthorizedException.cs b/src/Microsoft.Kusto.ServiceLayer/DataSource/Exceptions/DataSourceUnauthorizedException.cs deleted file mode 100644 index 51ef3fa2..00000000 --- a/src/Microsoft.Kusto.ServiceLayer/DataSource/Exceptions/DataSourceUnauthorizedException.cs +++ /dev/null @@ -1,11 +0,0 @@ -using System; - -namespace Microsoft.Kusto.ServiceLayer.DataSource.Exceptions -{ - public class DataSourceUnauthorizedException : Exception - { - public DataSourceUnauthorizedException(Exception ex) : base (ex.Message, ex) - { - } - } -} \ No newline at end of file diff --git a/src/Microsoft.Kusto.ServiceLayer/DataSource/IDataSource.cs b/src/Microsoft.Kusto.ServiceLayer/DataSource/IDataSource.cs index 1979f430..7429016f 100644 --- a/src/Microsoft.Kusto.ServiceLayer/DataSource/IDataSource.cs +++ b/src/Microsoft.Kusto.ServiceLayer/DataSource/IDataSource.cs @@ -111,11 +111,5 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource /// /// string GenerateExecuteFunctionScript(string functionName); - - /// - /// Updates Azure Token - /// - /// - void UpdateAzureToken(string azureToken); } } \ No newline at end of file diff --git a/src/Microsoft.Kusto.ServiceLayer/DataSource/IDataSourceFactory.cs b/src/Microsoft.Kusto.ServiceLayer/DataSource/IDataSourceFactory.cs index 0b5b70a1..670239ac 100644 --- a/src/Microsoft.Kusto.ServiceLayer/DataSource/IDataSourceFactory.cs +++ b/src/Microsoft.Kusto.ServiceLayer/DataSource/IDataSourceFactory.cs @@ -2,6 +2,6 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource { public interface IDataSourceFactory { - IDataSource Create(DataSourceType dataSourceType, string connectionString, string azureAccountToken); + IDataSource Create(DataSourceType dataSourceType, string connectionString, string azureAccountToken, string ownerUri); } } \ No newline at end of file diff --git a/src/Microsoft.Kusto.ServiceLayer/DataSource/IKustoClient.cs b/src/Microsoft.Kusto.ServiceLayer/DataSource/IKustoClient.cs index 861f6cf2..6f3fc22e 100644 --- a/src/Microsoft.Kusto.ServiceLayer/DataSource/IKustoClient.cs +++ b/src/Microsoft.Kusto.ServiceLayer/DataSource/IKustoClient.cs @@ -17,9 +17,7 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource string DatabaseName { get; } - void UpdateAzureToken(string azureAccountToken); - - IDataReader ExecuteQuery(string query, CancellationToken cancellationToken, string databaseName = null); + IDataReader ExecuteQuery(string query, CancellationToken cancellationToken, string databaseName = null, int retryCount = 1); /// /// Executes a query or command against a kusto cluster and returns a sequence of result row instances. @@ -37,7 +35,8 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource /// Executes a Kusto control command. /// /// The command. - void ExecuteControlCommand(string command); + /// + void ExecuteControlCommand(string command, int retryCount = 1); void UpdateDatabase(string databaseName); diff --git a/src/Microsoft.Kusto.ServiceLayer/DataSource/KustoClient.cs b/src/Microsoft.Kusto.ServiceLayer/DataSource/KustoClient.cs index 9fa4b246..b72158fa 100644 --- a/src/Microsoft.Kusto.ServiceLayer/DataSource/KustoClient.cs +++ b/src/Microsoft.Kusto.ServiceLayer/DataSource/KustoClient.cs @@ -14,14 +14,16 @@ using Kusto.Data.Net.Client; using Kusto.Language; using Kusto.Language.Editor; using Microsoft.Data.SqlClient; +using Microsoft.Kusto.ServiceLayer.Connection; using Microsoft.Kusto.ServiceLayer.DataSource.DataSourceIntellisense; -using Microsoft.Kusto.ServiceLayer.DataSource.Exceptions; using Microsoft.Kusto.ServiceLayer.Utility; namespace Microsoft.Kusto.ServiceLayer.DataSource { public class KustoClient : IKustoClient { + private readonly string _ownerUri; + [DebuggerBrowsable(DebuggerBrowsableState.Never)] private ICslAdminProvider _kustoAdminProvider; @@ -36,8 +38,9 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource public string ClusterName { get; } public string DatabaseName { get; private set; } - public KustoClient(string connectionString, string azureAccountToken) + public KustoClient(string connectionString, string azureAccountToken, string ownerUri) { + _ownerUri = ownerUri; ClusterName = GetClusterName(connectionString); var databaseName = new SqlConnectionStringBuilder(connectionString).InitialCatalog; Initialize(ClusterName, databaseName, azureAccountToken); @@ -75,8 +78,9 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource _kustoAdminProvider = KustoClientFactory.CreateCslAdminProvider(stringBuilder); } - public void UpdateAzureToken(string azureAccountToken) + private void RefreshAzureToken() { + string azureAccountToken = ConnectionService.Instance.RefreshAzureToken(_ownerUri); _kustoQueryProvider.Dispose(); _kustoAdminProvider.Dispose(); Initialize(ClusterName, DatabaseName, azureAccountToken); @@ -162,7 +166,7 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource } } - public IDataReader ExecuteQuery(string query, CancellationToken cancellationToken, string databaseName = null) + public IDataReader ExecuteQuery(string query, CancellationToken cancellationToken, string databaseName = null, int retryCount = 1) { ValidationUtils.IsArgumentNotNullOrWhiteSpace(query, nameof(query)); @@ -175,29 +179,33 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource var script = CodeScript.From(query, GlobalState.Default); IDataReader[] origReaders = new IDataReader[script.Blocks.Count]; - - Parallel.ForEach(script.Blocks, (codeBlock, state, index) => + try { - var minimalQuery = codeBlock.Service.GetMinimalText(MinimalTextKind.RemoveLeadingWhitespaceAndComments); - - try + Parallel.ForEach(script.Blocks, (codeBlock, state, index) => { + var minimalQuery = + codeBlock.Service.GetMinimalText(MinimalTextKind.RemoveLeadingWhitespaceAndComments); IDataReader origReader = _kustoQueryProvider.ExecuteQuery( KustoQueryUtils.IsClusterLevelQuery(minimalQuery) ? "" : databaseName, minimalQuery, clientRequestProperties); - + origReaders[index] = origReader; - } - catch (KustoRequestException exception) when (exception.FailureCode == 401) // Unauthorized - { - throw new DataSourceUnauthorizedException(exception); - } - }); - - return new KustoResultsReader(origReaders); + }); + + return new KustoResultsReader(origReaders); + } + catch (AggregateException exception) + when (retryCount > 0 && + exception.InnerException is KustoRequestException innerException + && innerException.FailureCode == 401) // Unauthorized + { + RefreshAzureToken(); + retryCount--; + return ExecuteQuery(query, cancellationToken, databaseName, retryCount); + } } - + /// /// Executes a query or command against a kusto cluster and returns a sequence of result row instances. /// @@ -211,10 +219,6 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource var tableReader = results[WellKnownDataSet.PrimaryResult].Single().TableData.CreateDataReader(); return new ObjectReader(tableReader); } - catch (DataSourceUnauthorizedException) - { - throw; - } catch (Exception) when (!throwOnError) { return null; @@ -243,12 +247,22 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource /// Executes a Kusto control command. /// /// The command. - public void ExecuteControlCommand(string command) + /// + public void ExecuteControlCommand(string command, int retryCount = 1) { ValidationUtils.IsArgumentNotNullOrWhiteSpace(command, nameof(command)); - using (var adminOutput = _kustoAdminProvider.ExecuteControlCommand(command, null)) + try { + using (var adminOutput = _kustoAdminProvider.ExecuteControlCommand(command, null)) + { + } + } + catch (KustoRequestException exception) when (retryCount > 0 && exception.FailureCode == 401) // Unauthorized + { + RefreshAzureToken(); + retryCount--; + ExecuteControlCommand(command, retryCount); } } diff --git a/src/Microsoft.Kusto.ServiceLayer/DataSource/KustoDataSource.cs b/src/Microsoft.Kusto.ServiceLayer/DataSource/KustoDataSource.cs index a958121a..6366b3f8 100644 --- a/src/Microsoft.Kusto.ServiceLayer/DataSource/KustoDataSource.cs +++ b/src/Microsoft.Kusto.ServiceLayer/DataSource/KustoDataSource.cs @@ -796,10 +796,6 @@ namespace Microsoft.Kusto.ServiceLayer.DataSource return string.IsNullOrWhiteSpace(objectName) ? databaseName : $"{databaseName}.{objectName}"; } - public override void UpdateAzureToken(string azureToken) - { - _kustoClient.UpdateAzureToken(azureToken); - } #endregion } } diff --git a/src/Microsoft.Kusto.ServiceLayer/DataSource/ReliableDataSourceConnection.cs b/src/Microsoft.Kusto.ServiceLayer/DataSource/ReliableDataSourceConnection.cs index b0c30c0c..a59955b1 100644 --- a/src/Microsoft.Kusto.ServiceLayer/DataSource/ReliableDataSourceConnection.cs +++ b/src/Microsoft.Kusto.ServiceLayer/DataSource/ReliableDataSourceConnection.cs @@ -42,8 +42,9 @@ namespace Microsoft.Kusto.ServiceLayer.Connection private readonly Guid _azureSessionId = Guid.NewGuid(); private readonly string _connectionString; - private readonly string _azureAccountToken; + private string _azureAccountToken; private readonly IDataSourceFactory _dataSourceFactory; + private readonly string _ownerUri; /// /// Initializes a new instance of the ReliableKustoClient class with a given connection string @@ -55,13 +56,15 @@ namespace Microsoft.Kusto.ServiceLayer.Connection /// The retry policy defining whether to retry a request if a command fails to be executed. /// /// + /// public ReliableDataSourceConnection(string connectionString, RetryPolicy connectionRetryPolicy, - RetryPolicy commandRetryPolicy, string azureAccountToken, IDataSourceFactory dataSourceFactory) + RetryPolicy commandRetryPolicy, string azureAccountToken, IDataSourceFactory dataSourceFactory, string ownerUri) { _connectionString = connectionString; _azureAccountToken = azureAccountToken; _dataSourceFactory = dataSourceFactory; - _dataSource = dataSourceFactory.Create(DataSourceType.Kusto, connectionString, azureAccountToken); + _ownerUri = ownerUri; + _dataSource = dataSourceFactory.Create(DataSourceType.Kusto, connectionString, azureAccountToken, ownerUri); _connectionRetryPolicy = connectionRetryPolicy ?? RetryPolicyFactory.CreateNoRetryPolicy(); _commandRetryPolicy = commandRetryPolicy ?? RetryPolicyFactory.CreateNoRetryPolicy(); @@ -190,7 +193,7 @@ namespace Microsoft.Kusto.ServiceLayer.Connection { _connectionRetryPolicy.ExecuteAction(() => { - _dataSource = _dataSourceFactory.Create(DataSourceType.Kusto, _connectionString, _azureAccountToken); + _dataSource = _dataSourceFactory.Create(DataSourceType.Kusto, _connectionString, _azureAccountToken, _ownerUri); }); } } @@ -247,6 +250,11 @@ namespace Microsoft.Kusto.ServiceLayer.Connection { get { return _dataSource.DatabaseName; } } + + public void UpdateAzureToken(string token) + { + _azureAccountToken = token; + } } } diff --git a/src/Microsoft.Kusto.ServiceLayer/HostLoader.cs b/src/Microsoft.Kusto.ServiceLayer/HostLoader.cs index b9221b0c..165c8a16 100644 --- a/src/Microsoft.Kusto.ServiceLayer/HostLoader.cs +++ b/src/Microsoft.Kusto.ServiceLayer/HostLoader.cs @@ -12,7 +12,6 @@ using Microsoft.Kusto.ServiceLayer.Admin; using Microsoft.Kusto.ServiceLayer.Metadata; using Microsoft.Kusto.ServiceLayer.Connection; using Microsoft.Kusto.ServiceLayer.DataSource; -using Microsoft.Kusto.ServiceLayer.DataSource.Metadata; using Microsoft.Kusto.ServiceLayer.LanguageServices; using Microsoft.Kusto.ServiceLayer.QueryExecution; using Microsoft.Kusto.ServiceLayer.Scripting; @@ -91,7 +90,7 @@ namespace Microsoft.Kusto.ServiceLayer QueryExecutionService.Instance.InitializeService(serviceHost); serviceProvider.RegisterSingleService(QueryExecutionService.Instance); - ScriptingService.Instance.InitializeService(serviceHost, scripter, dataSourceFactory); + ScriptingService.Instance.InitializeService(serviceHost, scripter, ConnectionService.Instance); serviceProvider.RegisterSingleService(ScriptingService.Instance); AdminService.Instance.InitializeService(serviceHost, ConnectionService.Instance); diff --git a/src/Microsoft.Kusto.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs b/src/Microsoft.Kusto.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs index b0d0cfbc..5ea86e18 100644 --- a/src/Microsoft.Kusto.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs +++ b/src/Microsoft.Kusto.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs @@ -98,7 +98,7 @@ namespace Microsoft.Kusto.ServiceLayer.LanguageServices bindingContext.BindingLock.Reset(); string connectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails); - bindingContext.DataSource = _dataSourceFactory.Create(DataSourceType.Kusto, connectionString, connInfo.ConnectionDetails.AzureAccountToken); + bindingContext.DataSource = _dataSourceFactory.Create(DataSourceType.Kusto, connectionString, connInfo.ConnectionDetails.AzureAccountToken, connInfo.OwnerUri); bindingContext.BindingTimeout = DefaultBindingTimeout; bindingContext.IsConnected = true; } diff --git a/src/Microsoft.Kusto.ServiceLayer/QueryExecution/Batch.cs b/src/Microsoft.Kusto.ServiceLayer/QueryExecution/Batch.cs index f3557cf8..9950a28d 100644 --- a/src/Microsoft.Kusto.ServiceLayer/QueryExecution/Batch.cs +++ b/src/Microsoft.Kusto.ServiceLayer/QueryExecution/Batch.cs @@ -15,7 +15,6 @@ using Microsoft.Kusto.ServiceLayer.QueryExecution.Contracts; using Microsoft.Kusto.ServiceLayer.QueryExecution.DataStorage; using Microsoft.SqlTools.Utility; using System.Globalization; -using Microsoft.Kusto.ServiceLayer.DataSource.Exceptions; namespace Microsoft.Kusto.ServiceLayer.QueryExecution { @@ -66,8 +65,6 @@ namespace Microsoft.Kusto.ServiceLayer.QueryExecution /// private readonly bool getFullColumnSchema; - private int _retryCount; - #endregion internal Batch(string batchText, SelectionData selection, int ordinalId, @@ -88,7 +85,6 @@ namespace Microsoft.Kusto.ServiceLayer.QueryExecution this.outputFileFactory = outputFileFactory; specialAction = new SpecialAction(); BatchExecutionCount = executionCount > 0 ? executionCount : 1; - _retryCount = 1; this.getFullColumnSchema = getFullColumnSchema; } @@ -252,7 +248,7 @@ namespace Microsoft.Kusto.ServiceLayer.QueryExecution public async Task Execute(ReliableDataSourceConnection conn, CancellationToken cancellationToken) { // Sanity check to make sure we haven't already run this batch - if (HasExecuted && _retryCount < 0) + if (HasExecuted) { throw new InvalidOperationException("Batch has already executed."); } @@ -267,12 +263,6 @@ namespace Microsoft.Kusto.ServiceLayer.QueryExecution { await DoExecute(conn, cancellationToken); } - catch (DataSourceUnauthorizedException) - { - // Rerun the query once if unauthorized - _retryCount--; - throw; - } catch (TaskCanceledException) { // Cancellation isn't considered an error condition diff --git a/src/Microsoft.Kusto.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.Kusto.ServiceLayer/QueryExecution/Query.cs index aaa2eadc..7ef881d0 100644 --- a/src/Microsoft.Kusto.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.Kusto.ServiceLayer/QueryExecution/Query.cs @@ -16,7 +16,6 @@ using Microsoft.SqlTools.Utility; using Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode; using System.Collections.Generic; using System.Diagnostics; -using Microsoft.Kusto.ServiceLayer.DataSource.Exceptions; using Microsoft.Kusto.ServiceLayer.Utility; namespace Microsoft.Kusto.ServiceLayer.QueryExecution @@ -391,11 +390,6 @@ namespace Microsoft.Kusto.ServiceLayer.QueryExecution await QueryCompleted(this); } } - catch (DataSourceUnauthorizedException) - { - ConnectionService.Instance.RefreshAzureToken(editorConnection.OwnerUri); - await ExecuteInternal(); - } catch (Exception e) { HasErrored = true; diff --git a/src/Microsoft.Kusto.ServiceLayer/Scripting/ScriptAsScriptingOperation.cs b/src/Microsoft.Kusto.ServiceLayer/Scripting/ScriptAsScriptingOperation.cs index 3b54e1a1..d05689ba 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Scripting/ScriptAsScriptingOperation.cs +++ b/src/Microsoft.Kusto.ServiceLayer/Scripting/ScriptAsScriptingOperation.cs @@ -24,20 +24,15 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting { private readonly IScripter _scripter; private static readonly Dictionary scriptCompatibilityMap = LoadScriptCompatibilityMap(); + private string _serverName; + private string _databaseName; - public ScriptAsScriptingOperation(ScriptingParams parameters, string azureAccountToken, IScripter scripter, IDataSourceFactory dataSourceFactory) : base(parameters, dataSourceFactory) + public ScriptAsScriptingOperation(ScriptingParams parameters, IScripter scripter, IDataSource datasource) : + base(parameters, datasource) { - DataSource = _dataSourceFactory.Create(DataSourceType.Kusto, this.Parameters.ConnectionString, - azureAccountToken); _scripter = scripter; } - internal IDataSource DataSource { get; set; } - - private string serverName; - private string databaseName; - private bool disconnectAtDispose = false; - public override void Execute() { try @@ -49,7 +44,7 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting this.CancellationToken.ThrowIfCancellationRequested(); string resultScript = string.Empty; - UrnCollection urns = CreateUrns(DataSource); + UrnCollection urns = CreateUrns(_dataSource); ScriptingOptions options = new ScriptingOptions(); SetScriptBehavior(options); ScriptAsOptions scriptAsOptions = new ScriptAsOptions(this.Parameters.ScriptOptions); @@ -65,12 +60,12 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting switch (this.Parameters.Operation) { case ScriptingOperationType.Select: - resultScript = GenerateScriptSelect(DataSource, urns); + resultScript = GenerateScriptSelect(_dataSource, urns); break; case ScriptingOperationType.Alter: case ScriptingOperationType.Execute: - resultScript = GenerateScriptForFunction(DataSource); + resultScript = GenerateScriptForFunction(_dataSource); break; } @@ -118,13 +113,6 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting }); } } - finally - { - if (disconnectAtDispose && DataSource != null) - { - DataSource.Dispose(); - } - } } private string GenerateScriptSelect(IDataSource dataSource, UrnCollection urns) @@ -168,8 +156,8 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting { IEnumerable selectedObjects = new List(this.Parameters.ScriptingObjects); - serverName = dataSource.ClusterName; - databaseName = new SqlConnectionStringBuilder(this.Parameters.ConnectionString).InitialCatalog; + _serverName = dataSource.ClusterName; + _databaseName = new SqlConnectionStringBuilder(this.Parameters.ConnectionString).InitialCatalog; UrnCollection urnCollection = new UrnCollection(); foreach (var scriptingObject in selectedObjects) { @@ -178,7 +166,7 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting // TODO: get the default schema scriptingObject.Schema = "dbo"; } - urnCollection.Add(scriptingObject.ToUrn(serverName, databaseName)); + urnCollection.Add(scriptingObject.ToUrn(_serverName, _databaseName)); } return urnCollection; } diff --git a/src/Microsoft.Kusto.ServiceLayer/Scripting/ScriptingScriptOperation.cs b/src/Microsoft.Kusto.ServiceLayer/Scripting/ScriptingScriptOperation.cs index 8b8b3aa3..78816351 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Scripting/ScriptingScriptOperation.cs +++ b/src/Microsoft.Kusto.ServiceLayer/Scripting/ScriptingScriptOperation.cs @@ -20,18 +20,15 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting /// public sealed class ScriptingScriptOperation : SmoScriptingOperation { - private int scriptedObjectCount = 0; private int totalScriptedObjectCount = 0; private int eventSequenceNumber = 1; - private string azureAccessToken; - - public ScriptingScriptOperation(ScriptingParams parameters, string azureAccessToken, IDataSourceFactory dataSourceFactory) : base(parameters, dataSourceFactory) + public ScriptingScriptOperation(ScriptingParams parameters, IDataSource dataSource) : base(parameters, dataSource) { - this.azureAccessToken = azureAccessToken; + } public override void Execute() @@ -204,7 +201,7 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting selectedObjects.Count(), string.Join(", ", selectedObjects))); - string server = GetServerNameFromLiveInstance(this.Parameters.ConnectionString, this.azureAccessToken); + string server = GetServerNameFromLiveInstance(); string database = new SqlConnectionStringBuilder(this.Parameters.ConnectionString).InitialCatalog; foreach (ScriptingObject scriptingObject in selectedObjects) diff --git a/src/Microsoft.Kusto.ServiceLayer/Scripting/ScriptingService.cs b/src/Microsoft.Kusto.ServiceLayer/Scripting/ScriptingService.cs index deffcc2e..2f73e4fc 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Scripting/ScriptingService.cs +++ b/src/Microsoft.Kusto.ServiceLayer/Scripting/ScriptingService.cs @@ -10,7 +10,6 @@ using System.Threading.Tasks; using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.Hosting.Protocol.Contracts; using Microsoft.Kusto.ServiceLayer.Connection; -using Microsoft.Kusto.ServiceLayer.DataSource; using Microsoft.Kusto.ServiceLayer.Scripting.Contracts; using Microsoft.SqlTools.Utility; using Microsoft.Kusto.ServiceLayer.Utility; @@ -22,13 +21,11 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting /// public sealed class ScriptingService : IDisposable { - private const int ScriptingOperationTimeout = 60000; - private static readonly Lazy LazyInstance = new Lazy(() => new ScriptingService()); public static ScriptingService Instance => LazyInstance.Value; - private static ConnectionService connectionService; + private static ConnectionService _connectionService; private readonly Lazy> operations = new Lazy>(() => new ConcurrentDictionary()); @@ -36,26 +33,6 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting private bool disposed; private IScripter _scripter; - private IDataSourceFactory _dataSourceFactory; - - /// - /// Internal for testing purposes only - /// - internal static ConnectionService ConnectionServiceInstance - { - get - { - if (connectionService == null) - { - connectionService = ConnectionService.Instance; - } - return connectionService; - } - set - { - connectionService = value; - } - } /// /// The collection of active operations @@ -66,11 +43,13 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting /// Initializes the Scripting Service instance /// /// - /// - public void InitializeService(ServiceHost serviceHost, IScripter scripter, IDataSourceFactory dataSourceFactory) + /// + /// + public void InitializeService(ServiceHost serviceHost, IScripter scripter, ConnectionService connectionService) { _scripter = scripter; - _dataSourceFactory = dataSourceFactory; + _connectionService = connectionService; + serviceHost.SetRequestHandler(ScriptingRequest.Type, this.HandleScriptExecuteRequest); serviceHost.SetRequestHandler(ScriptingCancelRequest.Type, this.HandleScriptCancelRequest); serviceHost.SetRequestHandler(ScriptingListObjectsRequest.Type, this.HandleListObjectsRequest); @@ -108,22 +87,18 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting /// public async Task HandleScriptExecuteRequest(ScriptingParams parameters, RequestContext requestContext) { - SmoScriptingOperation operation = null; - try { // if a connection string wasn't provided as a parameter then // use the owner uri property to lookup its associated ConnectionInfo // and then build a connection string out of that - ConnectionInfo connInfo = null; - string accessToken = null; if (parameters.ConnectionString == null) { - ScriptingService.ConnectionServiceInstance.TryFindConnection(parameters.OwnerUri, out connInfo); + ConnectionInfo connInfo; + _connectionService.TryFindConnection(parameters.OwnerUri, out connInfo); if (connInfo != null) { parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails); - accessToken = connInfo.ConnectionDetails.AzureAccountToken; } else { @@ -131,13 +106,16 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting } } + SmoScriptingOperation operation; + var datasource = _connectionService.GetOrOpenConnection(parameters.OwnerUri, ConnectionType.Default) + .Result.GetUnderlyingConnection(); if (!ShouldCreateScriptAsOperation(parameters)) { - operation = new ScriptingScriptOperation(parameters, accessToken, _dataSourceFactory); + operation = new ScriptingScriptOperation(parameters, datasource); } else { - operation = new ScriptAsScriptingOperation(parameters, accessToken, _scripter, _dataSourceFactory); + operation = new ScriptAsScriptingOperation(parameters, _scripter, datasource); } operation.PlanNotification += (sender, e) => requestContext.SendEvent(ScriptingPlanNotificationEvent.Type, e).Wait(); diff --git a/src/Microsoft.Kusto.ServiceLayer/Scripting/SmoScriptingOperation.cs b/src/Microsoft.Kusto.ServiceLayer/Scripting/SmoScriptingOperation.cs index 33108281..bbaf99f8 100644 --- a/src/Microsoft.Kusto.ServiceLayer/Scripting/SmoScriptingOperation.cs +++ b/src/Microsoft.Kusto.ServiceLayer/Scripting/SmoScriptingOperation.cs @@ -20,14 +20,13 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting /// public abstract class SmoScriptingOperation : ScriptingOperation { - protected readonly IDataSourceFactory _dataSourceFactory; + protected readonly IDataSource _dataSource; private bool _disposed; - protected SmoScriptingOperation(ScriptingParams parameters, IDataSourceFactory dataSourceFactory) + protected SmoScriptingOperation(ScriptingParams parameters, IDataSource datasource) { - _dataSourceFactory = dataSourceFactory; + _dataSource = datasource; Validate.IsNotNull("parameters", parameters); - this.Parameters = parameters; } @@ -73,17 +72,10 @@ namespace Microsoft.Kusto.ServiceLayer.Scripting parameters.OperationId = this.OperationId; } - protected string GetServerNameFromLiveInstance(string connectionString, string azureAccessToken) + protected string GetServerNameFromLiveInstance() { - string serverName = string.Empty; - - using(var dataSource = _dataSourceFactory.Create(DataSourceType.Kusto, connectionString, azureAccessToken)) - { - serverName = dataSource.ClusterName; - } - - Logger.Write(TraceEventType.Verbose, string.Format("Resolved server name '{0}'", serverName)); - return serverName; + Logger.Write(TraceEventType.Verbose, string.Format("Resolved server name '{0}'", _dataSource.ClusterName)); + return _dataSource.ClusterName; } protected void ValidateScriptDatabaseParams() diff --git a/test/Microsoft.Kusto.ServiceLayer.UnitTests/Connection/ConnectionInfoTests.cs b/test/Microsoft.Kusto.ServiceLayer.UnitTests/Connection/ConnectionInfoTests.cs index a1405c49..259ff89e 100644 --- a/test/Microsoft.Kusto.ServiceLayer.UnitTests/Connection/ConnectionInfoTests.cs +++ b/test/Microsoft.Kusto.ServiceLayer.UnitTests/Connection/ConnectionInfoTests.cs @@ -36,7 +36,7 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.Connection var dataSourceFactoryMock = new Mock(); var reliableDataSource = new ReliableDataSourceConnection("", RetryPolicyFactory.NoRetryPolicy, - RetryPolicyFactory.NoRetryPolicy, "", dataSourceFactoryMock.Object); + RetryPolicyFactory.NoRetryPolicy, "", dataSourceFactoryMock.Object, ""); connectionInfo.AddConnection("ConnectionType", reliableDataSource); connectionInfo.TryGetConnection("ConnectionType", out var connection); @@ -60,7 +60,7 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.Connection var dataSourceFactoryMock = new Mock(); var reliableDataSource = new ReliableDataSourceConnection("", RetryPolicyFactory.NoRetryPolicy, - RetryPolicyFactory.NoRetryPolicy, "", dataSourceFactoryMock.Object); + RetryPolicyFactory.NoRetryPolicy, "", dataSourceFactoryMock.Object, ""); connectionInfo.AddConnection("ConnectionType", reliableDataSource); connectionInfo.RemoveConnection("ConnectionType"); @@ -77,7 +77,7 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.Connection var dataSourceFactoryMock = new Mock(); var reliableDataSource = new ReliableDataSourceConnection("", RetryPolicyFactory.NoRetryPolicy, - RetryPolicyFactory.NoRetryPolicy, "", dataSourceFactoryMock.Object); + RetryPolicyFactory.NoRetryPolicy, "", dataSourceFactoryMock.Object, ""); connectionInfo.AddConnection("ConnectionType", reliableDataSource); connectionInfo.RemoveAllConnections(); diff --git a/test/Microsoft.Kusto.ServiceLayer.UnitTests/Connection/DataSourceConnectionFactoryTests.cs b/test/Microsoft.Kusto.ServiceLayer.UnitTests/Connection/DataSourceConnectionFactoryTests.cs index 64ebab23..ec8db3e4 100644 --- a/test/Microsoft.Kusto.ServiceLayer.UnitTests/Connection/DataSourceConnectionFactoryTests.cs +++ b/test/Microsoft.Kusto.ServiceLayer.UnitTests/Connection/DataSourceConnectionFactoryTests.cs @@ -12,7 +12,7 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.Connection { var dataSourceFactoryMock = new Mock(); var connectionFactory = new DataSourceConnectionFactory(dataSourceFactoryMock.Object); - var connection = connectionFactory.CreateDataSourceConnection("", ""); + var connection = connectionFactory.CreateDataSourceConnection("", "", ""); Assert.IsNotNull(connection); } diff --git a/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/DataSourceFactoryTests.cs b/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/DataSourceFactoryTests.cs index 91db92bb..a488b3de 100644 --- a/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/DataSourceFactoryTests.cs +++ b/test/Microsoft.Kusto.ServiceLayer.UnitTests/DataSource/DataSourceFactoryTests.cs @@ -19,7 +19,7 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.DataSource { var dataSourceFactory = new DataSourceFactory(); Assert.Throws(exceptionType, - () => dataSourceFactory.Create(DataSourceType.None, connectionString, azureAccountToken)); + () => dataSourceFactory.Create(DataSourceType.None, connectionString, azureAccountToken, "")); } [Test] diff --git a/test/Microsoft.Kusto.ServiceLayer.UnitTests/LanguageServices/ConnectedBindingQueueTests.cs b/test/Microsoft.Kusto.ServiceLayer.UnitTests/LanguageServices/ConnectedBindingQueueTests.cs index c484699e..f32485b3 100644 --- a/test/Microsoft.Kusto.ServiceLayer.UnitTests/LanguageServices/ConnectedBindingQueueTests.cs +++ b/test/Microsoft.Kusto.ServiceLayer.UnitTests/LanguageServices/ConnectedBindingQueueTests.cs @@ -97,7 +97,7 @@ namespace Microsoft.Kusto.ServiceLayer.UnitTests.LanguageServices var dataSourceFactory = new Mock(); var dataSourceMock = new Mock(); dataSourceFactory - .Setup(x => x.Create(It.IsAny(), It.IsAny(), It.IsAny())) + .Setup(x => x.Create(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) .Returns(dataSourceMock.Object); var connectedBindingQueue =