Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Microsoft.Azure.Cosmos
{
using System;
using System.Globalization;
using System.Net;
using System.Threading.Tasks;
using global::Azure.Core;
using Microsoft.Azure.Cosmos.Core.Trace;
Expand All @@ -18,7 +19,7 @@ internal sealed class AuthorizationTokenProviderTokenCredential : AuthorizationT
private const string InferenceTokenPrefix = "Bearer ";
internal readonly TokenCredentialCache tokenCredentialCache;
private bool isDisposed = false;

internal readonly TokenCredential tokenCredential;

public AuthorizationTokenProviderTokenCredential(
Expand Down Expand Up @@ -116,5 +117,78 @@ public override void Dispose()
this.tokenCredentialCache.Dispose();
}
}

/// <summary>
/// Inspects a response for an AAD claims challenge and, when one is found,
/// resets the token cache so that the next token request carries the challenge claims.
/// </summary>
/// <param name="statusCode">HTTP status code of the response being inspected.</param>
/// <param name="headers">Response headers; the WWW-Authenticate value is read from this collection.</param>
/// <returns>
/// True when the response is a 401 whose WWW-Authenticate header contains a claims challenge indicator
/// (the token cache has then been reset); false in every other case.
/// </returns>
internal bool TryHandleTokenRevocation(
    HttpStatusCode statusCode,
    INameValueCollection headers)
{
    // Only an Unauthorized response with headers can carry a claims challenge.
    if (headers == null || statusCode != HttpStatusCode.Unauthorized)
    {
        return false;
    }

    string challengeHeader = headers[HttpConstants.HttpHeaders.WwwAuthenticate];
    if (string.IsNullOrEmpty(challengeHeader))
    {
        return false;
    }

    // Neither indicator present means this is an ordinary 401, not a revocation signal.
    if (challengeHeader.IndexOf("insufficient_claims", StringComparison.OrdinalIgnoreCase) < 0
        && challengeHeader.IndexOf("claims=", StringComparison.OrdinalIgnoreCase) < 0)
    {
        return false;
    }

    // The extracted value may be null when the indicator is present without a quoted claims parameter;
    // the cache is reset either way.
    string extractedClaims = AuthorizationTokenProviderTokenCredential.ExtractClaimsFromWwwAuthenticate(challengeHeader);
    this.tokenCredentialCache.ResetCachedToken(extractedClaims);

    bool hasClaims = extractedClaims != null;
    DefaultTrace.TraceInformation(
        "AAD token revocation detected (claims challenge present). Token cache reset. " +
        "Request will be retried with fresh token including claims. HasClaims={0}",
        hasClaims);

    return true;
}

/// <summary>
/// Pulls the value of the quoted <c>claims="..."</c> parameter out of a WWW-Authenticate header value.
/// </summary>
/// <param name="wwwAuthenticateHeader">Raw WWW-Authenticate header value; may be null or empty.</param>
/// <returns>
/// The text between the quotes of the first <c>claims="</c> parameter (matched case-insensitively),
/// or null when the header is null/empty, the parameter is absent, or the closing quote is missing.
/// </returns>
private static string ExtractClaimsFromWwwAuthenticate(string wwwAuthenticateHeader)
{
    const string claimsPrefix = "claims=\"";

    if (string.IsNullOrEmpty(wwwAuthenticateHeader))
    {
        return null;
    }

    int prefixPosition = wwwAuthenticateHeader.IndexOf(claimsPrefix, StringComparison.OrdinalIgnoreCase);
    if (prefixPosition < 0)
    {
        return null;
    }

    // The value starts right after the opening quote and runs up to the next quote.
    int valueStart = prefixPosition + claimsPrefix.Length;
    int closingQuote = wwwAuthenticateHeader.IndexOf('"', valueStart);

    return closingQuote < 0
        ? null
        : wwwAuthenticateHeader.Substring(valueStart, closingQuote - valueStart);
}
}
}
221 changes: 176 additions & 45 deletions Microsoft.Azure.Cosmos/src/Authorization/TokenCredentialCache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace Microsoft.Azure.Cosmos
using System.Threading;
using System.Threading.Tasks;
using global::Azure;
using global::Azure.Core;
using global::Azure.Core;
using Microsoft.Azure.Cosmos.Authorization;
using Microsoft.Azure.Cosmos.Core.Trace;
using Microsoft.Azure.Cosmos.Resource.CosmosExceptions;
Expand All @@ -36,8 +36,8 @@ internal sealed class TokenCredentialCache : IDisposable
// Each token refresh retry is scheduled at half of the remaining time. Given the default of 1hr it will retry at 30m, 15m, 7.5m, 3.75m, 1.875m.
// If the background refresh fails with less than a minute remaining, then just allow the request to hit the exception.
public static readonly TimeSpan MinimumTimeBetweenBackgroundRefreshInterval = TimeSpan.FromMinutes(1);
private readonly IScopeProvider scopeProvider;

private readonly IScopeProvider scopeProvider;
private readonly TokenCredential tokenCredential;
private readonly CancellationTokenSource cancellationTokenSource;
private readonly CancellationToken cancellationToken;
Expand All @@ -50,7 +50,8 @@ internal sealed class TokenCredentialCache : IDisposable
private Task<AccessToken>? currentRefreshOperation = null;
private AccessToken? cachedAccessToken = null;
private bool isBackgroundTaskRunning = false;
private bool isDisposed = false;
private bool isDisposed = false;
private string? cachedClaimsChallenge = null;

internal TokenCredentialCache(
TokenCredential tokenCredential,
Expand All @@ -62,10 +63,10 @@ internal TokenCredentialCache(
if (accountEndpoint == null)
{
throw new ArgumentNullException(nameof(accountEndpoint));
}
this.scopeProvider = new Microsoft.Azure.Cosmos.Authorization.CosmosScopeProvider(accountEndpoint);
}

this.scopeProvider = new Microsoft.Azure.Cosmos.Authorization.CosmosScopeProvider(accountEndpoint);

if (backgroundTokenCredentialRefreshInterval.HasValue)
{
if (backgroundTokenCredentialRefreshInterval.Value <= TimeSpan.Zero)
Expand Down Expand Up @@ -121,10 +122,34 @@ public void Dispose()
}

this.cancellationTokenSource.Cancel();
this.cancellationTokenSource.Dispose();
this.cancellationTokenSource.Dispose();
this.isDisposed = true;
}

/// <summary>
/// Discards the cached access token and remembers the supplied claims challenge
/// so it is available to a later token request. Does nothing once the cache is disposed.
/// </summary>
/// <param name="claimsChallenge">
/// Optional base64-encoded claims challenge taken from a WWW-Authenticate header;
/// null clears any previously stored challenge.
/// </param>
internal void ResetCachedToken(string? claimsChallenge = null)
{
    if (!this.isDisposed)
    {
        // All cached state is cleared together while holding the background refresh lock.
        lock (this.backgroundRefreshLock)
        {
            this.cachedAccessToken = null;
            this.currentRefreshOperation = null;
            this.isBackgroundTaskRunning = false;
            this.cachedClaimsChallenge = claimsChallenge;
        }

        DefaultTrace.TraceInformation(
            $"TokenCredentialCache: Token cache reset due to AAD revocation signal. HasClaims={claimsChallenge != null}");
    }
}

private async Task<AccessToken> GetNewTokenAsync(
ITrace trace)
{
Expand Down Expand Up @@ -161,13 +186,91 @@ private async Task<AccessToken> GetNewTokenAsync(
return await currentTask;
}

/// <summary>
/// Builds the claims JSON sent with a token request by combining the CAE client capability (cp1)
/// with an optional claims challenge.
/// For normal requests (no challenge) only the cp1 capability JSON is returned.
/// For token revocation the challenge is base64-decoded and an <c>xms_cc</c> member carrying cp1
/// is inserted at the end of its <c>access_token</c> object.
/// </summary>
/// <param name="claimsChallenge">The base64-encoded claims challenge from the WWW-Authenticate header (null or empty for normal requests).</param>
/// <returns>
/// JSON string with client capabilities and optional claims (NOT base64-encoded).
/// Falls back to the cp1-only JSON when the challenge cannot be decoded or has no well-formed
/// <c>access_token</c> object; this method does not throw for malformed input.
/// </returns>
internal static string MergeClaimsWithClientCapabilities(string? claimsChallenge)
{
    const string clientCapabilitiesJson = "{\"access_token\":{\"xms_cc\":{\"values\":[\"cp1\"]}}}";
    const string clientCapabilitiesMember = "\"xms_cc\":{\"values\":[\"cp1\"]}";

    // Return only cp1 capability
    if (string.IsNullOrEmpty(claimsChallenge))
    {
        return clientCapabilitiesJson;
    }

    // Token Revocation: Merge claims challenge with cp1
    try
    {
        byte[] claimsBytes = Convert.FromBase64String(claimsChallenge);
        string claimsJson = System.Text.Encoding.UTF8.GetString(claimsBytes);

        int accessTokenIndex = claimsJson.IndexOf("\"access_token\"", StringComparison.Ordinal);
        if (accessTokenIndex < 0)
        {
            DefaultTrace.TraceWarning("TokenCredentialCache: CAE claims challenge missing 'access_token' key, using client capabilities only");
            return clientCapabilitiesJson;
        }

        int openBraceIndex = claimsJson.IndexOf('{', accessTokenIndex);
        if (openBraceIndex < 0)
        {
            DefaultTrace.TraceWarning("TokenCredentialCache: Malformed CAE claims challenge, using client capabilities only");
            return clientCapabilitiesJson;
        }

        // Find the brace that closes the access_token object. Braces inside JSON string
        // literals are data, not structure, so string state (including escapes) is tracked.
        int depth = 1;
        int closeBraceIndex = -1;
        bool insideString = false;

        // Any string literal seen before the closing brace means the object already has at
        // least one member (every JSON object member starts with a string key).
        bool hasExistingMembers = false;

        for (int i = openBraceIndex + 1; i < claimsJson.Length; i++)
        {
            char current = claimsJson[i];

            if (insideString)
            {
                if (current == '\\')
                {
                    // Skip the escaped character so an escaped quote does not end the string.
                    i++;
                }
                else if (current == '"')
                {
                    insideString = false;
                }

                continue;
            }

            if (current == '"')
            {
                insideString = true;
                hasExistingMembers = true;
            }
            else if (current == '{')
            {
                depth++;
            }
            else if (current == '}')
            {
                depth--;
                if (depth == 0)
                {
                    closeBraceIndex = i;
                    break;
                }
            }
        }

        if (closeBraceIndex < 0)
        {
            DefaultTrace.TraceWarning("TokenCredentialCache: Malformed CAE claims challenge, using client capabilities only");
            return clientCapabilitiesJson;
        }

        // A separating comma is only valid when the object already contains a member;
        // emitting it for an empty object would produce invalid JSON ("{,...}").
        string separator = hasExistingMembers ? "," : string.Empty;

        return claimsJson.Substring(0, closeBraceIndex) +
            separator +
            clientCapabilitiesMember +
            claimsJson.Substring(closeBraceIndex);
    }
    catch (Exception ex)
    {
        // Best effort: an undecodable challenge must not fail the token request.
        DefaultTrace.TraceWarning($"TokenCredentialCache: Failed to merge claims challenge: {ex.Message}. Using client capabilities only.");
        return clientCapabilitiesJson;
    }
}

private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
ITrace trace)
{
Exception? lastException = null;
const int totalRetryCount = 2;
TokenRequestContext tokenRequestContext = default;
{
Exception? lastException = null;
const int totalRetryCount = 2;
TokenRequestContext tokenRequestContext = default;

try
{
for (int retry = 0; retry < totalRetryCount; retry++)
Expand All @@ -184,10 +287,30 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
name: nameof(this.RefreshCachedTokenWithRetryHelperAsync),
component: TraceComponent.Authorization,
level: Tracing.TraceLevel.Info))
{
{
try
{
tokenRequestContext = this.scopeProvider.GetTokenRequestContext();
{
tokenRequestContext = this.scopeProvider.GetTokenRequestContext();

string mergedClaims = MergeClaimsWithClientCapabilities(this.cachedClaimsChallenge);

if (string.IsNullOrEmpty(this.cachedClaimsChallenge))
{
DefaultTrace.TraceInformation(
$"Requesting AAD token with CAE client capabilities (cp1). Retry={retry}");
}
else
{
DefaultTrace.TraceInformation(
$"Requesting AAD token for revocation with claims challenge and client capabilities (cp1). Retry={retry}");
}

tokenRequestContext = new TokenRequestContext(
scopes: tokenRequestContext.Scopes,
parentRequestId: tokenRequestContext.ParentRequestId,
claims: mergedClaims,
tenantId: tokenRequestContext.TenantId,
isCaeEnabled: tokenRequestContext.IsCaeEnabled);

this.cachedAccessToken = await this.tokenCredential.GetTokenAsync(
requestContext: tokenRequestContext,
Expand All @@ -203,6 +326,9 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
throw new ArgumentOutOfRangeException($"TokenCredential.GetTokenAsync returned a token that is already expired. Current Time:{DateTime.UtcNow:O}; Token expire time:{this.cachedAccessToken.Value.ExpiresOn:O}");
}

// Clear claims challenge after successful token acquisition
this.cachedClaimsChallenge = null;

if (!this.userDefinedBackgroundTokenCredentialRefreshInterval.HasValue)
{
double refreshIntervalInSeconds = (this.cachedAccessToken.Value.ExpiresOn - DateTimeOffset.UtcNow).TotalSeconds * DefaultBackgroundTokenCredentialRefreshIntervalPercentage;
Expand All @@ -220,10 +346,10 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
lastException = operationCancelled;
getTokenTrace.AddDatum(
$"OperationCanceledException at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
operationCancelled.Message);
DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");
operationCancelled.Message);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");

throw CosmosExceptionFactory.CreateRequestTimeoutException(
message: ClientResources.FailedToGetAadToken,
Expand All @@ -234,29 +360,34 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
innerException: lastException,
trace: getTokenTrace);
}
catch (Exception exception)
{
lastException = exception;
getTokenTrace.AddDatum(
$"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
exception.Message);

DefaultTrace.TraceError($"TokenCredential.GetToken() failed with RequestFailedException. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}");

// Don't retry on auth failures
if (exception is RequestFailedException requestFailedException &&
(requestFailedException.Status == (int)HttpStatusCode.Unauthorized ||
requestFailedException.Status == (int)HttpStatusCode.Forbidden))
{
this.cachedAccessToken = default;
throw;
}
bool didFallback = this.scopeProvider.TryFallback(exception);

if (didFallback)
{
DefaultTrace.TraceInformation($"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}. Fallback attempted: {didFallback}");
}
catch (Exception exception)
{
lastException = exception;
getTokenTrace.AddDatum(
$"Exception at {DateTime.UtcNow.ToString(CultureInfo.InvariantCulture)}",
exception.Message);

DefaultTrace.TraceError(
$"TokenCredential.GetTokenAsync() failed. " +
$"scope = {string.Join(";", tokenRequestContext.Scopes)}, " +
$"hasClaimsChallenge = {this.cachedClaimsChallenge != null}, " +
$"retry = {retry}, " +
$"Exception = {lastException.Message}");
// Don't retry on auth failures
if (exception is RequestFailedException requestFailedException &&
(requestFailedException.Status == (int)HttpStatusCode.Unauthorized ||
requestFailedException.Status == (int)HttpStatusCode.Forbidden))
{
this.cachedAccessToken = default;
this.cachedClaimsChallenge = null;
throw;
}
bool didFallback = this.scopeProvider.TryFallback(exception);

if (didFallback)
{
DefaultTrace.TraceInformation($"TokenCredential.GetTokenAsync() failed. scope = {string.Join(";", tokenRequestContext.Scopes)}, retry = {retry}, Exception = {lastException.Message}. Fallback attempted: {didFallback}");
}
}
}
}
Expand All @@ -265,7 +396,7 @@ private async ValueTask<AccessToken> RefreshCachedTokenWithRetryHelperAsync(
{
throw new ArgumentException("Last exception is null.");
}

this.cachedClaimsChallenge = null;
// The retries have been exhausted. Throw the last exception.
throw lastException;
}
Expand Down
Loading
Loading