mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-01-31 01:25:42 -05:00
Refreshes token for intellisense (#1476)
* check if token refresh needed * add more checks * simplify logic * add summary and change to false * wip * wip * add ExpiresOn field to check when token needs to be refreshed * expired token check * wip * wip * wip * update expiresOn check * wip * wip * working refresh token * add closing tag * fix summary * pr comments * add max tolerance * refactoring * refactoring and updating descriptions * remove comment * pr updates * more pr comments * pr comments * wip * pr comments - add state tracker * update comment * fix type * pr comments * fix race condition * wip * pr comments * add comment * pr comments * nullable int * pr comments * remove uri from map upon disconnect * pr comments * remove uri from map upon editor close * pr comments
This commit is contained in:
@@ -173,5 +173,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
ConnectionTypeToConnectionMap.TryRemove(type, out connection);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Updates the Auth Token and Expires On fields
|
||||
/// </summary>
|
||||
public void UpdateAuthToken(string token, int expiresOn)
|
||||
{
|
||||
ConnectionDetails.AzureAccountToken = token;
|
||||
ConnectionDetails.ExpiresOn = expiresOn;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
public const string AdminConnectionPrefix = "ADMIN:";
|
||||
internal const string PasswordPlaceholder = "******";
|
||||
private const string SqlAzureEdition = "SQL Azure";
|
||||
public const int MaxTolerance = 2 * 60; // two minutes - standard tolerance across ADS for AAD tokens
|
||||
|
||||
/// <summary>
|
||||
/// Singleton service instance
|
||||
@@ -59,6 +60,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
private readonly ConcurrentDictionary<CancelTokenKey, CancellationTokenSource> cancelTupleToCancellationTokenSourceMap =
|
||||
new ConcurrentDictionary<CancelTokenKey, CancellationTokenSource>();
|
||||
|
||||
/// <summary>
|
||||
/// A map containing the uris of connections with expired tokens, these editors should have intellisense
|
||||
/// disabled until the new refresh token is returned, upon which they will be removed from the map
|
||||
/// </summary>
|
||||
public readonly ConcurrentDictionary<string, Boolean> TokenUpdateUris = new ConcurrentDictionary<string, Boolean>();
|
||||
private readonly object cancellationTokenSourceLock = new object();
|
||||
|
||||
private ConcurrentDictionary<string, IConnectedBindingQueue> connectedQueues = new ConcurrentDictionary<string, IConnectedBindingQueue>();
|
||||
@@ -228,11 +234,84 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
public virtual bool TryFindConnection(string ownerUri, out ConnectionInfo connectionInfo) => this.OwnerToConnectionMap.TryGetValue(ownerUri, out connectionInfo);
|
||||
|
||||
/// <summary>
|
||||
/// Validates the given ConnectParams object.
|
||||
/// Refreshes the auth token of a given connection, if needed
|
||||
/// </summary>
|
||||
/// <param name="connectionParams">The params to validate</param>
|
||||
/// <returns>A ConnectionCompleteParams object upon validation error,
|
||||
/// null upon validation success</returns>
|
||||
/// <param name="ownerUri">The URI of the connection</param>
|
||||
/// <returns> True if a refreshed was needed and requested, false otherwise </returns>
|
||||
|
||||
internal async Task<bool> TryRequestRefreshAuthToken(string ownerUri)
|
||||
{
|
||||
ConnectionInfo connInfo;
|
||||
if (this.TryFindConnection(ownerUri, out connInfo))
|
||||
{
|
||||
// If not an azure connection, no need to refresh token
|
||||
if (connInfo.ConnectionDetails.AuthenticationType != "AzureMFA")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Check if token is expired or about to expire
|
||||
if (connInfo.ConnectionDetails.ExpiresOn - DateTimeOffset.Now.ToUnixTimeSeconds() < MaxTolerance)
|
||||
{
|
||||
|
||||
var requestMessage = new RefreshTokenParams
|
||||
{
|
||||
AccountId = connInfo.ConnectionDetails.GetOptionValue("azureAccount", string.Empty),
|
||||
TenantId = connInfo.ConnectionDetails.GetOptionValue("azureTenantId", string.Empty),
|
||||
Provider = "Azure",
|
||||
Resource = "SQL",
|
||||
Uri = ownerUri
|
||||
};
|
||||
if (string.IsNullOrEmpty(requestMessage.TenantId))
|
||||
{
|
||||
Logger.Error("No tenant in connection details when refreshing token for connection {ownerUri}");
|
||||
return false;
|
||||
}
|
||||
if (string.IsNullOrEmpty(requestMessage.AccountId))
|
||||
{
|
||||
Logger.Error("No accountId in connection details when refreshing token for connection {ownerUri}");
|
||||
return false;
|
||||
}
|
||||
// Check if the token is updating already, in which case there is no need to request a new one,
|
||||
// but still return true so that autocompletion is disabled until the token is refreshed
|
||||
if (!this.TokenUpdateUris.TryAdd(ownerUri, true))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
await this.ServiceHost.SendEvent(RefreshTokenNotification.Type, requestMessage);
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Logger.Error("Failed to find connection when refreshing token");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Requests an update of the azure auth token
|
||||
/// </summary>
|
||||
/// <param name="refreshToken">The token to update</param>
|
||||
/// <returns>true upon successful update, false if it failed to find
|
||||
/// the connection</returns>
|
||||
internal void UpdateAuthToken(TokenRefreshedParams tokenRefreshedParams)
|
||||
{
|
||||
if (!this.TryFindConnection(tokenRefreshedParams.Uri, out ConnectionInfo connection))
|
||||
{
|
||||
Logger.Error($"Failed to find connection when updating refreshed token for URI {tokenRefreshedParams.Uri}");
|
||||
return;
|
||||
}
|
||||
this.TokenUpdateUris.Remove(tokenRefreshedParams.Uri, out var result);
|
||||
connection.UpdateAuthToken(tokenRefreshedParams.Token, tokenRefreshedParams.ExpiresOn);
|
||||
}
|
||||
|
||||
public ConnectionCompleteParams ValidateConnectParams(ConnectParams connectionParams)
|
||||
{
|
||||
string paramValidationErrorMessage;
|
||||
@@ -807,6 +886,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
return false;
|
||||
}
|
||||
|
||||
// This clears the uri of the connection from the tokenUpdateUris map, which is used to track
|
||||
// open editors that have requested a refreshed AAD token.
|
||||
this.TokenUpdateUris.Remove(disconnectParams.OwnerUri, out bool result);
|
||||
|
||||
// Call Close() on the connections we want to disconnect
|
||||
// If no connections were located, return false
|
||||
if (!CloseConnections(info, disconnectParams.Type))
|
||||
@@ -1318,7 +1401,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
ConnectionInfo info;
|
||||
SqlConnectionStringBuilder connStringBuilder;
|
||||
try
|
||||
{
|
||||
{
|
||||
// set connection string using connection uri if connection details are undefined
|
||||
if (connStringParams.ConnectionDetails == null)
|
||||
{
|
||||
|
||||
@@ -543,7 +543,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
|
||||
SetOptionValue("databaseDisplayName", value);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public string AzureAccountToken
|
||||
{
|
||||
get
|
||||
@@ -556,6 +556,18 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
|
||||
}
|
||||
}
|
||||
|
||||
public int? ExpiresOn
|
||||
{
|
||||
get
|
||||
{
|
||||
return GetOptionValue<int?>("expiresOn");
|
||||
}
|
||||
set
|
||||
{
|
||||
SetOptionValue("expiresOn", value);
|
||||
}
|
||||
}
|
||||
|
||||
public bool IsComparableTo(ConnectionDetails other)
|
||||
{
|
||||
if (other == null)
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
|
||||
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
|
||||
{
|
||||
class RefreshTokenParams
|
||||
{
|
||||
/// <summary>
|
||||
/// ID of the tenant
|
||||
/// </summary>
|
||||
public string TenantId { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the provider that indicates the type of linked account to query.
|
||||
/// </summary>
|
||||
public string Provider { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the identifier of the target resource of the requested token.
|
||||
/// </summary>
|
||||
public string Resource { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the account ID
|
||||
/// </summary>
|
||||
public string AccountId { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the URI
|
||||
/// </summary>
|
||||
public string Uri { get; set; }
|
||||
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Refresh token request mapping entry
|
||||
/// </summary>
|
||||
class RefreshTokenNotification
|
||||
{
|
||||
public static readonly
|
||||
EventType<RefreshTokenParams> Type =
|
||||
EventType<RefreshTokenParams>.Create("account/refreshToken");
|
||||
}
|
||||
|
||||
class TokenRefreshedParams
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets or sets the refresh token.
|
||||
/// </summary>
|
||||
public string Token { get; set; }
|
||||
|
||||
/// <summmary>
|
||||
/// Gets or sets the token expiration, a Unix epoch
|
||||
/// </summary>
|
||||
public int ExpiresOn { get; set; }
|
||||
|
||||
/// <summmary>
|
||||
/// Connection URI
|
||||
/// </summary>
|
||||
public string Uri { get; set; }
|
||||
}
|
||||
|
||||
class TokenRefreshedNotification
|
||||
{
|
||||
public static readonly
|
||||
EventType<TokenRefreshedParams> Type =
|
||||
EventType<TokenRefreshedParams>.Create("account/tokenRefreshed");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user