diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs
index 9e8334dc..82542721 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs
@@ -173,5 +173,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
ConnectionTypeToConnectionMap.TryRemove(type, out connection);
}
}
+
+ ///
+ /// Updates the Auth Token and Expires On fields
+ ///
+ public void UpdateAuthToken(string token, int expiresOn)
+ {
+ ConnectionDetails.AzureAccountToken = token;
+ ConnectionDetails.ExpiresOn = expiresOn;
+ }
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
index 2768c1f8..48d0a3dc 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
@@ -33,6 +33,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
public const string AdminConnectionPrefix = "ADMIN:";
internal const string PasswordPlaceholder = "******";
private const string SqlAzureEdition = "SQL Azure";
+ public const int MaxTolerance = 2 * 60; // two minutes - standard tolerance across ADS for AAD tokens
///
/// Singleton service instance
@@ -59,6 +60,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
private readonly ConcurrentDictionary cancelTupleToCancellationTokenSourceMap =
new ConcurrentDictionary();
+ ///
+ /// A map containing the uris of connections with expired tokens, these editors should have intellisense
+ /// disabled until the new refresh token is returned, upon which they will be removed from the map
+ ///
+ public readonly ConcurrentDictionary TokenUpdateUris = new ConcurrentDictionary();
private readonly object cancellationTokenSourceLock = new object();
private ConcurrentDictionary connectedQueues = new ConcurrentDictionary();
@@ -228,11 +234,84 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
public virtual bool TryFindConnection(string ownerUri, out ConnectionInfo connectionInfo) => this.OwnerToConnectionMap.TryGetValue(ownerUri, out connectionInfo);
///
- /// Validates the given ConnectParams object.
+ /// Refreshes the auth token of a given connection, if needed
///
- /// The params to validate
- /// A ConnectionCompleteParams object upon validation error,
- /// null upon validation success
+ /// The URI of the connection
+ /// True if a refreshed was needed and requested, false otherwise
+
+ internal async Task TryRequestRefreshAuthToken(string ownerUri)
+ {
+ ConnectionInfo connInfo;
+ if (this.TryFindConnection(ownerUri, out connInfo))
+ {
+ // If not an azure connection, no need to refresh token
+ if (connInfo.ConnectionDetails.AuthenticationType != "AzureMFA")
+ {
+ return false;
+ }
+ else
+ {
+ // Check if token is expired or about to expire
+ if (connInfo.ConnectionDetails.ExpiresOn - DateTimeOffset.Now.ToUnixTimeSeconds() < MaxTolerance)
+ {
+
+ var requestMessage = new RefreshTokenParams
+ {
+ AccountId = connInfo.ConnectionDetails.GetOptionValue("azureAccount", string.Empty),
+ TenantId = connInfo.ConnectionDetails.GetOptionValue("azureTenantId", string.Empty),
+ Provider = "Azure",
+ Resource = "SQL",
+ Uri = ownerUri
+ };
+ if (string.IsNullOrEmpty(requestMessage.TenantId))
+ {
+ Logger.Error("No tenant in connection details when refreshing token for connection {ownerUri}");
+ return false;
+ }
+ if (string.IsNullOrEmpty(requestMessage.AccountId))
+ {
+ Logger.Error("No accountId in connection details when refreshing token for connection {ownerUri}");
+ return false;
+ }
+ // Check if the token is updating already, in which case there is no need to request a new one,
+ // but still return true so that autocompletion is disabled until the token is refreshed
+ if (!this.TokenUpdateUris.TryAdd(ownerUri, true))
+ {
+ return true;
+ }
+ await this.ServiceHost.SendEvent(RefreshTokenNotification.Type, requestMessage);
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ }
+ else
+ {
+ Logger.Error("Failed to find connection when refreshing token");
+ return false;
+ }
+ }
+
+ ///
+ /// Requests an update of the azure auth token
+ ///
+ /// The token to update
+ /// true upon successful update, false if it failed to find
+ /// the connection
+ internal void UpdateAuthToken(TokenRefreshedParams tokenRefreshedParams)
+ {
+ if (!this.TryFindConnection(tokenRefreshedParams.Uri, out ConnectionInfo connection))
+ {
+ Logger.Error($"Failed to find connection when updating refreshed token for URI {tokenRefreshedParams.Uri}");
+ return;
+ }
+ this.TokenUpdateUris.Remove(tokenRefreshedParams.Uri, out var result);
+ connection.UpdateAuthToken(tokenRefreshedParams.Token, tokenRefreshedParams.ExpiresOn);
+ }
+
public ConnectionCompleteParams ValidateConnectParams(ConnectParams connectionParams)
{
string paramValidationErrorMessage;
@@ -807,6 +886,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
return false;
}
+ // This clears the uri of the connection from the tokenUpdateUris map, which is used to track
+ // open editors that have requested a refreshed AAD token.
+ this.TokenUpdateUris.Remove(disconnectParams.OwnerUri, out bool result);
+
// Call Close() on the connections we want to disconnect
// If no connections were located, return false
if (!CloseConnections(info, disconnectParams.Type))
@@ -1318,7 +1401,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
ConnectionInfo info;
SqlConnectionStringBuilder connStringBuilder;
try
- {
+ {
// set connection string using connection uri if connection details are undefined
if (connStringParams.ConnectionDetails == null)
{
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs
index d300d7cd..fac37cef 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs
@@ -543,7 +543,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
SetOptionValue("databaseDisplayName", value);
}
}
-
+
public string AzureAccountToken
{
get
@@ -556,6 +556,18 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
}
}
+ public int? ExpiresOn
+ {
+ get
+ {
+ return GetOptionValue("expiresOn");
+ }
+ set
+ {
+ SetOptionValue("expiresOn", value);
+ }
+ }
+
public bool IsComparableTo(ConnectionDetails other)
{
if (other == null)
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/RefreshTokenNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/RefreshTokenNotification.cs
new file mode 100644
index 00000000..9cf02e81
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/RefreshTokenNotification.cs
@@ -0,0 +1,73 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.Hosting.Protocol.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
+{
+ class RefreshTokenParams
+ {
+ ///
+ /// ID of the tenant
+ ///
+ public string TenantId { get; set; }
+
+ ///
+ /// Gets or sets the provider that indicates the type of linked account to query.
+ ///
+ public string Provider { get; set; }
+
+ ///
+ /// Gets or sets the identifier of the target resource of the requested token.
+ ///
+ public string Resource { get; set; }
+
+ ///
+ /// Gets or sets the account ID
+ ///
+ public string AccountId { get; set; }
+
+ ///
+ /// Gets or sets the URI
+ ///
+ public string Uri { get; set; }
+
+ }
+
+ ///
+ /// Refresh token request mapping entry
+ ///
+ class RefreshTokenNotification
+ {
+ public static readonly
+ EventType Type =
+ EventType.Create("account/refreshToken");
+ }
+
+ class TokenRefreshedParams
+ {
+ ///
+ /// Gets or sets the refresh token.
+ ///
+ public string Token { get; set; }
+
+ ///
+ /// Gets or sets the token expiration, a Unix epoch
+ ///
+ public int ExpiresOn { get; set; }
+
+ ///
+ /// Connection URI
+ ///
+ public string Uri { get; set; }
+ }
+
+ class TokenRefreshedNotification
+ {
+ public static readonly
+ EventType Type =
+ EventType.Create("account/tokenRefreshed");
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs
index 9b2dbc21..f7d42ed3 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/LanguageService.cs
@@ -267,6 +267,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
serviceHost.SetRequestHandler(CompletionExtLoadRequest.Type, HandleCompletionExtLoadRequest);
serviceHost.SetEventHandler(RebuildIntelliSenseNotification.Type, HandleRebuildIntelliSenseNotification);
serviceHost.SetEventHandler(LanguageFlavorChangeNotification.Type, HandleDidChangeLanguageFlavorNotification);
+ serviceHost.SetEventHandler(TokenRefreshedNotification.Type, HandleTokenRefreshedNotification);
// Register a no-op shutdown task for validation of the shutdown logic
serviceHost.RegisterShutdownTask(async (shutdownParams, shutdownRequestContext) =>
@@ -452,11 +453,15 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
}
else
{
- // get the current list of completion items and return to client
- ConnectionServiceInstance.TryFindConnection(
- scriptFile.ClientUri,
- out ConnectionInfo connInfo);
-
+ ConnectionInfo connInfo = null;
+ // Check if we need to refresh the auth token, and if we do then don't pass in the
+ // connection so that we only show the default options until the refreshed token is returned
+ if (!await connectionService.TryRequestRefreshAuthToken(scriptFile.ClientUri))
+ {
+ ConnectionServiceInstance.TryFindConnection(
+ scriptFile.ClientUri,
+ out connInfo);
+ }
var completionItems = await GetCompletionItems(
textDocumentPosition, scriptFile, connInfo);
@@ -719,6 +724,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
{
try
{
+ // This clears the uri of the connection from the tokenUpdateUris map, which is used to track
+ // open editors that have requested a refreshed AAD token.
+ connectionService.TokenUpdateUris.Remove(uri, out var result);
// if not in the preview window and diagnostics are enabled then clear diagnostics
if (!IsPreviewWindow(scriptFile)
&& CurrentWorkspaceSettings.IsDiagnosticsEnabled)
@@ -906,6 +914,15 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
}
}
+ internal Task HandleTokenRefreshedNotification(
+ TokenRefreshedParams tokenRefreshedParams,
+ EventContext eventContext
+ )
+ {
+ connectionService.UpdateAuthToken(tokenRefreshedParams);
+ return Task.CompletedTask;
+ }
+
#endregion
@@ -1061,7 +1078,6 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
Monitor.Exit(scriptInfo.BuildingMetadataLock);
}
}
-
PrepopulateCommonMetadata(info, scriptInfo, this.BindingQueue);
// Send a notification to signal that autocomplete is ready