Refreshes token for intellisense (#1476)

* check if token refresh needed

* add more checks

* simplify logic

* add summary and change to false

* wip

* wip

* add ExpiresOn field to check when token needs to be refreshed

* expired token check

* wip

* wip

* wip

* update expiresOn check

* wip

* wip

* working refresh token

* add closing tag

* fix summary

* pr comments

* add max tolerance

* refactoring

* refactoring and updating descriptions

* remove comment

* pr updates

* more pr comments

* pr comments

* wip

* pr comments - add state tracker

* update comment

* fix type

* pr comments

* fix race condition

* wip

* pr comments

* add comment

* pr comments

* nullable int

* pr comments

* remove uri from map upon disconnect

* pr comments

* remove uri from map upon editor close

* pr comments
This commit is contained in:
Christopher Suh
2022-05-13 11:47:37 -07:00
committed by GitHub
parent 01fe402adf
commit 106b6baeda
5 changed files with 205 additions and 12 deletions

View File

@@ -33,6 +33,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
public const string AdminConnectionPrefix = "ADMIN:";
internal const string PasswordPlaceholder = "******";
private const string SqlAzureEdition = "SQL Azure";
public const int MaxTolerance = 2 * 60; // two minutes - standard tolerance across ADS for AAD tokens
/// <summary>
/// Singleton service instance
@@ -59,6 +60,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
private readonly ConcurrentDictionary<CancelTokenKey, CancellationTokenSource> cancelTupleToCancellationTokenSourceMap =
new ConcurrentDictionary<CancelTokenKey, CancellationTokenSource>();
/// <summary>
/// A map containing the uris of connections with expired tokens, these editors should have intellisense
/// disabled until the new refresh token is returned, upon which they will be removed from the map
/// </summary>
public readonly ConcurrentDictionary<string, Boolean> TokenUpdateUris = new ConcurrentDictionary<string, Boolean>();
private readonly object cancellationTokenSourceLock = new object();
private ConcurrentDictionary<string, IConnectedBindingQueue> connectedQueues = new ConcurrentDictionary<string, IConnectedBindingQueue>();
@@ -228,11 +234,84 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
public virtual bool TryFindConnection(string ownerUri, out ConnectionInfo connectionInfo) => this.OwnerToConnectionMap.TryGetValue(ownerUri, out connectionInfo);
/// <summary>
/// Validates the given ConnectParams object.
/// Refreshes the auth token of a given connection, if needed
/// </summary>
/// <param name="connectionParams">The params to validate</param>
/// <returns>A ConnectionCompleteParams object upon validation error,
/// null upon validation success</returns>
/// <param name="ownerUri">The URI of the connection</param>
/// <returns> True if a refreshed was needed and requested, false otherwise </returns>
internal async Task<bool> TryRequestRefreshAuthToken(string ownerUri)
{
ConnectionInfo connInfo;
if (this.TryFindConnection(ownerUri, out connInfo))
{
// If not an azure connection, no need to refresh token
if (connInfo.ConnectionDetails.AuthenticationType != "AzureMFA")
{
return false;
}
else
{
// Check if token is expired or about to expire
if (connInfo.ConnectionDetails.ExpiresOn - DateTimeOffset.Now.ToUnixTimeSeconds() < MaxTolerance)
{
var requestMessage = new RefreshTokenParams
{
AccountId = connInfo.ConnectionDetails.GetOptionValue("azureAccount", string.Empty),
TenantId = connInfo.ConnectionDetails.GetOptionValue("azureTenantId", string.Empty),
Provider = "Azure",
Resource = "SQL",
Uri = ownerUri
};
if (string.IsNullOrEmpty(requestMessage.TenantId))
{
Logger.Error("No tenant in connection details when refreshing token for connection {ownerUri}");
return false;
}
if (string.IsNullOrEmpty(requestMessage.AccountId))
{
Logger.Error("No accountId in connection details when refreshing token for connection {ownerUri}");
return false;
}
// Check if the token is updating already, in which case there is no need to request a new one,
// but still return true so that autocompletion is disabled until the token is refreshed
if (!this.TokenUpdateUris.TryAdd(ownerUri, true))
{
return true;
}
await this.ServiceHost.SendEvent(RefreshTokenNotification.Type, requestMessage);
return true;
}
else
{
return false;
}
}
}
else
{
Logger.Error("Failed to find connection when refreshing token");
return false;
}
}
/// <summary>
/// Requests an update of the azure auth token
/// </summary>
/// <param name="refreshToken">The token to update</param>
/// <returns>true upon successful update, false if it failed to find
/// the connection</returns>
internal void UpdateAuthToken(TokenRefreshedParams tokenRefreshedParams)
{
if (!this.TryFindConnection(tokenRefreshedParams.Uri, out ConnectionInfo connection))
{
Logger.Error($"Failed to find connection when updating refreshed token for URI {tokenRefreshedParams.Uri}");
return;
}
this.TokenUpdateUris.Remove(tokenRefreshedParams.Uri, out var result);
connection.UpdateAuthToken(tokenRefreshedParams.Token, tokenRefreshedParams.ExpiresOn);
}
public ConnectionCompleteParams ValidateConnectParams(ConnectParams connectionParams)
{
string paramValidationErrorMessage;
@@ -807,6 +886,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
return false;
}
// This clears the uri of the connection from the tokenUpdateUris map, which is used to track
// open editors that have requested a refreshed AAD token.
this.TokenUpdateUris.Remove(disconnectParams.OwnerUri, out bool result);
// Call Close() on the connections we want to disconnect
// If no connections were located, return false
if (!CloseConnections(info, disconnectParams.Type))
@@ -1318,7 +1401,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
ConnectionInfo info;
SqlConnectionStringBuilder connStringBuilder;
try
{
{
// set connection string using connection uri if connection details are undefined
if (connStringParams.ConnectionDetails == null)
{