diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs index 9e8334dc..82542721 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs @@ -173,5 +173,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection ConnectionTypeToConnectionMap.TryRemove(type, out connection); } } + + /// + /// Updates the Auth Token and Expires On fields + /// + public void UpdateAuthToken(string token, int expiresOn) + { + ConnectionDetails.AzureAccountToken = token; + ConnectionDetails.ExpiresOn = expiresOn; + } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 2768c1f8..48d0a3dc 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -33,6 +33,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection public const string AdminConnectionPrefix = "ADMIN:"; internal const string PasswordPlaceholder = "******"; private const string SqlAzureEdition = "SQL Azure"; + public const int MaxTolerance = 2 * 60; // two minutes - standard tolerance across ADS for AAD tokens /// /// Singleton service instance @@ -59,6 +60,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection private readonly ConcurrentDictionary cancelTupleToCancellationTokenSourceMap = new ConcurrentDictionary(); + /// + /// A map containing the uris of connections with expired tokens, these editors should have intellisense + /// disabled until the new refresh token is returned, upon which they will be removed from the map + /// + public readonly ConcurrentDictionary TokenUpdateUris = new ConcurrentDictionary(); private readonly object cancellationTokenSourceLock = new object(); private ConcurrentDictionary connectedQueues = new ConcurrentDictionary(); @@ -228,11 +234,84 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection public virtual bool TryFindConnection(string ownerUri, out ConnectionInfo connectionInfo) => this.OwnerToConnectionMap.TryGetValue(ownerUri, out connectionInfo); /// - /// Validates the given ConnectParams object. + /// Refreshes the auth token of a given connection, if needed /// - /// The params to validate - /// A ConnectionCompleteParams object upon validation error, - /// null upon validation success + /// The URI of the connection + /// True if a refreshed was needed and requested, false otherwise + + internal async Task TryRequestRefreshAuthToken(string ownerUri) + { + ConnectionInfo connInfo; + if (this.TryFindConnection(ownerUri, out connInfo)) + { + // If not an azure connection, no need to refresh token + if (connInfo.ConnectionDetails.AuthenticationType != "AzureMFA") + { + return false; + } + else + { + // Check if token is expired or about to expire + if (connInfo.ConnectionDetails.ExpiresOn - DateTimeOffset.Now.ToUnixTimeSeconds() < MaxTolerance) + { + + var requestMessage = new RefreshTokenParams + { + AccountId = connInfo.ConnectionDetails.GetOptionValue("azureAccount", string.Empty), + TenantId = connInfo.ConnectionDetails.GetOptionValue("azureTenantId", string.Empty), + Provider = "Azure", + Resource = "SQL", + Uri = ownerUri + }; + if (string.IsNullOrEmpty(requestMessage.TenantId)) + { + Logger.Error("No tenant in connection details when refreshing token for connection {ownerUri}"); + return false; + } + if (string.IsNullOrEmpty(requestMessage.AccountId)) + { + Logger.Error("No accountId in connection details when refreshing token for connection {ownerUri}"); + return false; + } + // Check if the token is updating already, in which case there is no need to request a new one, + // but still return true so that autocompletion is disabled until the token is refreshed + if (!this.TokenUpdateUris.TryAdd(ownerUri, true)) + { + return true; + } + await this.ServiceHost.SendEvent(RefreshTokenNotification.Type, requestMessage); + return true; + } + else + { + return false; + } + } + } + else + { + Logger.Error("Failed to find connection when refreshing token"); + return false; + } + } + + /// + /// Requests an update of the azure auth token + /// + /// The token to update + /// true upon successful update, false if it failed to find + /// the connection + internal void UpdateAuthToken(TokenRefreshedParams tokenRefreshedParams) + { + if (!this.TryFindConnection(tokenRefreshedParams.Uri, out ConnectionInfo connection)) + { + Logger.Error($"Failed to find connection when updating refreshed token for URI {tokenRefreshedParams.Uri}"); + return; + } + this.TokenUpdateUris.Remove(tokenRefreshedParams.Uri, out var result); + connection.UpdateAuthToken(tokenRefreshedParams.Token, tokenRefreshedParams.ExpiresOn); + } + public ConnectionCompleteParams ValidateConnectParams(ConnectParams connectionParams) { string paramValidationErrorMessage; @@ -807,6 +886,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection return false; } + // This clears the uri of the connection from the tokenUpdateUris map, which is used to track + // open editors that have requested a refreshed AAD token. + this.TokenUpdateUris.Remove(disconnectParams.OwnerUri, out bool result); + // Call Close() on the connections we want to disconnect // If no connections were located, return false if (!CloseConnections(info, disconnectParams.Type)) @@ -1318,7 +1401,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection ConnectionInfo info; SqlConnectionStringBuilder connStringBuilder; try - { + { // set connection string using connection uri if connection details are undefined if (connStringParams.ConnectionDetails == null) { diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs index d300d7cd..fac37cef 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs @@ -543,7 +543,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts SetOptionValue("databaseDisplayName", value); } } - + public string AzureAccountToken { get @@ -556,6 +556,18 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts } } + public int? ExpiresOn + { + get + { + return GetOptionValue("expiresOn"); + } + set + { + SetOptionValue("expiresOn", value); + } + } + public bool IsComparableTo(ConnectionDetails other) { if (other == null) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/RefreshTokenNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/RefreshTokenNotification.cs new file mode 100644 index 00000000..9cf02e81 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/RefreshTokenNotification.cs @@ -0,0 +1,73 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + class RefreshTokenParams + { + /// + /// ID of the tenant + /// + public string TenantId { get; set; } + + /// + /// Gets or sets the provider that indicates the type of linked account to query. + /// + public string Provider { get; set; } + + /// + /// Gets or sets the identifier of the target resource of the requested token. + /// + public string Resource { get; set; } + + /// + /// Gets or sets the account ID + /// + public string AccountId { get; set; } + + /// + /// Gets or sets the URI + /// + public string Uri { get; set; } + + } + + /// + /// Refresh token request mapping entry + /// + class RefreshTokenNotification + { + public static readonly + EventType Type = + EventType.Create("account/refreshToken"); + } + + class TokenRefreshedParams + { + /// + /// Gets or sets the refresh token. + /// + public string Token { get; set; } + + /// + /// Gets or sets the token expiration, a Unix epoch + /// + public int ExpiresOn { get; set; } + + /// + /// Connection URI + /// + public string Uri { get; set; } + } + + class TokenRefreshedNotification + { + public static readonly + EventType Type = + EventType.Create("account/tokenRefreshed"); + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs index 9b2dbc21..f7d42ed3 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs @@ -267,6 +267,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices serviceHost.SetRequestHandler(CompletionExtLoadRequest.Type, HandleCompletionExtLoadRequest); serviceHost.SetEventHandler(RebuildIntelliSenseNotification.Type, HandleRebuildIntelliSenseNotification); serviceHost.SetEventHandler(LanguageFlavorChangeNotification.Type, HandleDidChangeLanguageFlavorNotification); + serviceHost.SetEventHandler(TokenRefreshedNotification.Type, HandleTokenRefreshedNotification); // Register a no-op shutdown task for validation of the shutdown logic serviceHost.RegisterShutdownTask(async (shutdownParams, shutdownRequestContext) => @@ -452,11 +453,15 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } else { - // get the current list of completion items and return to client - ConnectionServiceInstance.TryFindConnection( - scriptFile.ClientUri, - out ConnectionInfo connInfo); - + ConnectionInfo connInfo = null; + // Check if we need to refresh the auth token, and if we do then don't pass in the + // connection so that we only show the default options until the refreshed token is returned + if (!await connectionService.TryRequestRefreshAuthToken(scriptFile.ClientUri)) + { + ConnectionServiceInstance.TryFindConnection( + scriptFile.ClientUri, + out connInfo); + } var completionItems = await GetCompletionItems( textDocumentPosition, scriptFile, connInfo); @@ -719,6 +724,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices { try { + // This clears the uri of the connection from the tokenUpdateUris map, which is used to track + // open editors that have requested a refreshed AAD token. + connectionService.TokenUpdateUris.Remove(uri, out var result); // if not in the preview window and diagnostics are enabled then clear diagnostics if (!IsPreviewWindow(scriptFile) && CurrentWorkspaceSettings.IsDiagnosticsEnabled) @@ -906,6 +914,15 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices } } + internal Task HandleTokenRefreshedNotification( + TokenRefreshedParams tokenRefreshedParams, + EventContext eventContext + ) + { + connectionService.UpdateAuthToken(tokenRefreshedParams); + return Task.CompletedTask; + } + #endregion @@ -1061,7 +1078,6 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices Monitor.Exit(scriptInfo.BuildingMetadataLock); } } - PrepopulateCommonMetadata(info, scriptInfo, this.BindingQueue); // Send a notification to signal that autocomplete is ready