diff --git a/src/OpenIddict.Core/Caches/OpenIddictApplicationCache.cs b/src/OpenIddict.Core/Caches/OpenIddictApplicationCache.cs index 8a21ac37..f9a75d8c 100644 --- a/src/OpenIddict.Core/Caches/OpenIddictApplicationCache.cs +++ b/src/OpenIddict.Core/Caches/OpenIddictApplicationCache.cs @@ -5,12 +5,14 @@ */ using System; +using System.Collections.Concurrent; using System.Collections.Immutable; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; using OpenIddict.Abstractions; namespace OpenIddict.Core @@ -21,7 +23,8 @@ namespace OpenIddict.Core /// The type of the Application entity. public class OpenIddictApplicationCache : IOpenIddictApplicationCache, IDisposable where TApplication : class { - private readonly IMemoryCache _cache; + private readonly MemoryCache _cache; + private readonly ConcurrentDictionary> _signals; private readonly IOpenIddictApplicationStore _store; public OpenIddictApplicationCache( @@ -33,6 +36,7 @@ namespace OpenIddict.Core SizeLimit = options.CurrentValue.EntityCacheLimit }); + _signals = new ConcurrentDictionary>(StringComparer.Ordinal); _store = resolver.Get(); } @@ -51,14 +55,51 @@ namespace OpenIddict.Core throw new ArgumentNullException(nameof(application)); } + _cache.Remove(new + { + Method = nameof(FindByClientIdAsync), + Identifier = await _store.GetClientIdAsync(application, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindByIdAsync), + Identifier = await _store.GetIdAsync(application, cancellationToken) + }); + + foreach (var address in await _store.GetPostLogoutRedirectUrisAsync(application, cancellationToken)) + { + _cache.Remove(new + { + Method = nameof(FindByPostLogoutRedirectUriAsync), + Address = address + }); + } + + foreach (var address in await _store.GetRedirectUrisAsync(application, cancellationToken)) + { + _cache.Remove(new + { + Method = nameof(FindByRedirectUriAsync), + Address = address + }); + } + + var signal = await CreateExpirationSignalAsync(application, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + using (var entry = _cache.CreateEntry(new { Method = nameof(FindByIdAsync), Identifier = await _store.GetIdAsync(application, cancellationToken) })) { - entry.SetSize(1L); - entry.SetValue(application); + entry.AddExpirationToken(signal) + .SetSize(1L) + .SetValue(application); } using (var entry = _cache.CreateEntry(new @@ -67,15 +108,24 @@ namespace OpenIddict.Core Identifier = await _store.GetClientIdAsync(application, cancellationToken) })) { - entry.SetSize(1L); - entry.SetValue(application); + entry.AddExpirationToken(signal) + .SetSize(1L) + .SetValue(application); } } /// - /// Disposes the cache held by this instance. + /// Disposes the resources held by this instance. /// - public void Dispose() => _cache.Dispose(); + public void Dispose() + { + foreach (var signal in _signals) + { + signal.Value.Value.Dispose(); + } + + _cache.Dispose(); + } /// /// Retrieves an application using its client identifier. @@ -113,6 +163,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + if (application != null) + { + var signal = await CreateExpirationSignalAsync(application, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(1L); entry.SetValue(application); } @@ -159,6 +220,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + if (application != null) + { + var signal = await CreateExpirationSignalAsync(application, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(1L); entry.SetValue(application); } @@ -206,6 +278,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var application in applications) + { + var signal = await CreateExpirationSignalAsync(application, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(applications.Length); entry.SetValue(applications); } @@ -253,6 +336,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var application in applications) + { + var signal = await CreateExpirationSignalAsync(application, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(applications.Length); entry.SetValue(applications); } @@ -278,35 +372,53 @@ namespace OpenIddict.Core throw new ArgumentNullException(nameof(application)); } - _cache.Remove(new + var identifier = await _store.GetIdAsync(application, cancellationToken); + if (string.IsNullOrEmpty(identifier)) { - Method = nameof(FindByClientIdAsync), - Identifier = await _store.GetClientIdAsync(application, cancellationToken) - }); + throw new InvalidOperationException("The application identifier cannot be extracted."); + } - _cache.Remove(new + if (_signals.TryGetValue(identifier, out Lazy signal)) { - Method = nameof(FindByIdAsync), - Identifier = await _store.GetIdAsync(application, cancellationToken) - }); + signal.Value.Cancel(); - foreach (var address in await _store.GetPostLogoutRedirectUrisAsync(application, cancellationToken)) + _signals.TryRemove(identifier, out signal); + } + } + + /// + /// Creates an expiration signal allowing to invalidate all the + /// cache entries associated with the specified application. + /// + /// The application associated with the expiration signal. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns an expiration signal for the specified application. + /// + protected virtual async Task CreateExpirationSignalAsync( + [NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) { - _cache.Remove(new - { - Method = nameof(FindByPostLogoutRedirectUriAsync), - Address = address - }); + throw new ArgumentNullException(nameof(application)); } - foreach (var address in await _store.GetRedirectUrisAsync(application, cancellationToken)) + var identifier = await _store.GetIdAsync(application, cancellationToken); + if (string.IsNullOrEmpty(identifier)) { - _cache.Remove(new - { - Method = nameof(FindByRedirectUriAsync), - Address = address - }); + throw new InvalidOperationException("The application identifier cannot be extracted."); } + + var signal = _signals.GetOrAdd(identifier, delegate + { + // Note: a Lazy is used here to ensure only one CancellationTokenSource + // can be created. Not doing so would result in expiration signals being potentially linked to + // multiple sources, with a single one of them being eventually tracked and thus, cancelable. + return new Lazy(() => new CancellationTokenSource()); + }); + + return new CancellationChangeToken(signal.Value.Token); } } } diff --git a/src/OpenIddict.Core/Caches/OpenIddictAuthorizationCache.cs b/src/OpenIddict.Core/Caches/OpenIddictAuthorizationCache.cs index f453e50a..53b673cd 100644 --- a/src/OpenIddict.Core/Caches/OpenIddictAuthorizationCache.cs +++ b/src/OpenIddict.Core/Caches/OpenIddictAuthorizationCache.cs @@ -5,12 +5,14 @@ */ using System; +using System.Collections.Concurrent; using System.Collections.Immutable; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; using OpenIddict.Abstractions; namespace OpenIddict.Core @@ -21,7 +23,8 @@ namespace OpenIddict.Core /// The type of the Authorization entity. public class OpenIddictAuthorizationCache : IOpenIddictAuthorizationCache, IDisposable where TAuthorization : class { - private readonly IMemoryCache _cache; + private readonly MemoryCache _cache; + private readonly ConcurrentDictionary> _signals; private readonly IOpenIddictAuthorizationStore _store; public OpenIddictAuthorizationCache( @@ -33,6 +36,7 @@ namespace OpenIddict.Core SizeLimit = options.CurrentValue.EntityCacheLimit }); + _signals = new ConcurrentDictionary>(StringComparer.Ordinal); _store = resolver.Get(); } @@ -51,21 +55,78 @@ namespace OpenIddict.Core throw new ArgumentNullException(nameof(authorization)); } + _cache.Remove(new + { + Method = nameof(FindAsync), + Subject = await _store.GetSubjectAsync(authorization, cancellationToken), + Client = await _store.GetApplicationIdAsync(authorization, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindAsync), + Subject = await _store.GetSubjectAsync(authorization, cancellationToken), + Client = await _store.GetApplicationIdAsync(authorization, cancellationToken), + Status = await _store.GetStatusAsync(authorization, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindAsync), + Subject = await _store.GetSubjectAsync(authorization, cancellationToken), + Client = await _store.GetApplicationIdAsync(authorization, cancellationToken), + Status = await _store.GetStatusAsync(authorization, cancellationToken), + Type = await _store.GetTypeAsync(authorization, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindByApplicationIdAsync), + Identifier = await _store.GetApplicationIdAsync(authorization, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindByIdAsync), + Identifier = await _store.GetIdAsync(authorization, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindBySubjectAsync), + Subject = await _store.GetSubjectAsync(authorization, cancellationToken) + }); + + var signal = await CreateExpirationTokenAsync(authorization, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + using (var entry = _cache.CreateEntry(new { Method = nameof(FindByIdAsync), Identifier = await _store.GetIdAsync(authorization, cancellationToken) })) { - entry.SetSize(1L); - entry.SetValue(authorization); + entry.AddExpirationToken(signal) + .SetSize(1L) + .SetValue(authorization); } } /// - /// Disposes the cache held by this instance. + /// Disposes the resources held by this instance. /// - public void Dispose() => _cache.Dispose(); + public void Dispose() + { + foreach (var signal in _signals) + { + signal.Value.Value.Dispose(); + } + + _cache.Dispose(); + } /// /// Retrieves the authorizations corresponding to the specified @@ -112,6 +173,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var authorization in authorizations) + { + var signal = await CreateExpirationTokenAsync(authorization, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(authorizations.Length); entry.SetValue(authorizations); } @@ -174,6 +246,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var authorization in authorizations) + { + var signal = await CreateExpirationTokenAsync(authorization, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(authorizations.Length); entry.SetValue(authorizations); } @@ -243,6 +326,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var authorization in authorizations) + { + var signal = await CreateExpirationTokenAsync(authorization, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(authorizations.Length); entry.SetValue(authorizations); } @@ -345,6 +439,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var authorization in authorizations) + { + var signal = await CreateExpirationTokenAsync(authorization, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(authorizations.Length); entry.SetValue(authorizations); } @@ -391,6 +496,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + if (authorization != null) + { + var signal = await CreateExpirationTokenAsync(authorization, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(1L); entry.SetValue(authorization); } @@ -438,6 +554,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var authorization in authorizations) + { + var signal = await CreateExpirationTokenAsync(authorization, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(authorizations.Length); entry.SetValue(authorizations); } @@ -463,47 +590,53 @@ namespace OpenIddict.Core throw new ArgumentNullException(nameof(authorization)); } - _cache.Remove(new + var identifier = await _store.GetIdAsync(authorization, cancellationToken); + if (string.IsNullOrEmpty(identifier)) { - Method = nameof(FindAsync), - Subject = await _store.GetSubjectAsync(authorization, cancellationToken), - Client = await _store.GetApplicationIdAsync(authorization, cancellationToken) - }); + throw new InvalidOperationException("The application identifier cannot be extracted."); + } - _cache.Remove(new + if (_signals.TryGetValue(identifier, out Lazy signal)) { - Method = nameof(FindAsync), - Subject = await _store.GetSubjectAsync(authorization, cancellationToken), - Client = await _store.GetApplicationIdAsync(authorization, cancellationToken), - Status = await _store.GetStatusAsync(authorization, cancellationToken) - }); + signal.Value.Cancel(); - _cache.Remove(new - { - Method = nameof(FindAsync), - Subject = await _store.GetSubjectAsync(authorization, cancellationToken), - Client = await _store.GetApplicationIdAsync(authorization, cancellationToken), - Status = await _store.GetStatusAsync(authorization, cancellationToken), - Type = await _store.GetTypeAsync(authorization, cancellationToken) - }); + _signals.TryRemove(identifier, out signal); + } + } - _cache.Remove(new + /// + /// Creates an expiration signal allowing to invalidate all the + /// cache entries associated with the specified authorization. + /// + /// The authorization associated with the expiration signal. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns an expiration signal for the specified authorization. + /// + protected virtual async Task CreateExpirationTokenAsync( + [NotNull] TAuthorization authorization, CancellationToken cancellationToken) + { + if (authorization == null) { - Method = nameof(FindByApplicationIdAsync), - Identifier = await _store.GetApplicationIdAsync(authorization, cancellationToken) - }); + throw new ArgumentNullException(nameof(authorization)); + } - _cache.Remove(new + var identifier = await _store.GetIdAsync(authorization, cancellationToken); + if (string.IsNullOrEmpty(identifier)) { - Method = nameof(FindByIdAsync), - Identifier = await _store.GetIdAsync(authorization, cancellationToken) - }); + throw new InvalidOperationException("The authorization identifier cannot be extracted."); + } - _cache.Remove(new + var signal = _signals.GetOrAdd(identifier, delegate { - Method = nameof(FindBySubjectAsync), - Subject = await _store.GetSubjectAsync(authorization, cancellationToken) + // Note: a Lazy is used here to ensure only one CancellationTokenSource + // can be created. Not doing so would result in expiration signals being potentially linked to + // multiple sources, with a single one of them being eventually tracked and thus, cancelable. + return new Lazy(() => new CancellationTokenSource()); }); + + return new CancellationChangeToken(signal.Value.Token); } } } diff --git a/src/OpenIddict.Core/Caches/OpenIddictScopeCache.cs b/src/OpenIddict.Core/Caches/OpenIddictScopeCache.cs index bec4c35e..f30dfdfb 100644 --- a/src/OpenIddict.Core/Caches/OpenIddictScopeCache.cs +++ b/src/OpenIddict.Core/Caches/OpenIddictScopeCache.cs @@ -5,6 +5,7 @@ */ using System; +using System.Collections.Concurrent; using System.Collections.Immutable; using System.Linq; using System.Threading; @@ -12,6 +13,7 @@ using System.Threading.Tasks; using JetBrains.Annotations; using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; using OpenIddict.Abstractions; namespace OpenIddict.Core @@ -22,7 +24,8 @@ namespace OpenIddict.Core /// The type of the Scope entity. public class OpenIddictScopeCache : IOpenIddictScopeCache, IDisposable where TScope : class { - private readonly IMemoryCache _cache; + private readonly MemoryCache _cache; + private readonly ConcurrentDictionary> _signals; private readonly IOpenIddictScopeStore _store; public OpenIddictScopeCache( @@ -34,6 +37,7 @@ namespace OpenIddict.Core SizeLimit = options.CurrentValue.EntityCacheLimit }); + _signals = new ConcurrentDictionary>(StringComparer.Ordinal); _store = resolver.Get(); } @@ -52,14 +56,42 @@ namespace OpenIddict.Core throw new ArgumentNullException(nameof(scope)); } + _cache.Remove(new + { + Method = nameof(FindByIdAsync), + Identifier = await _store.GetIdAsync(scope, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindByNameAsync), + Name = await _store.GetNameAsync(scope, cancellationToken) + }); + + foreach (var resource in await _store.GetResourcesAsync(scope, cancellationToken)) + { + _cache.Remove(new + { + Method = nameof(FindByResourceAsync), + Resource = resource + }); + } + + var signal = await CreateExpirationSignalAsync(scope, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration token."); + } + using (var entry = _cache.CreateEntry(new { Method = nameof(FindByIdAsync), Identifier = await _store.GetIdAsync(scope, cancellationToken) })) { - entry.SetSize(1L); - entry.SetValue(scope); + entry.AddExpirationToken(signal) + .SetSize(1L) + .SetValue(scope); } using (var entry = _cache.CreateEntry(new @@ -68,15 +100,24 @@ namespace OpenIddict.Core Name = await _store.GetNameAsync(scope, cancellationToken) })) { - entry.SetSize(1L); - entry.SetValue(scope); + entry.AddExpirationToken(signal) + .SetSize(1L) + .SetValue(scope); } } /// - /// Disposes the cache held by this instance. + /// Disposes the resources held by this instance. /// - public void Dispose() => _cache.Dispose(); + public void Dispose() + { + foreach (var signal in _signals) + { + signal.Value.Value.Dispose(); + } + + _cache.Dispose(); + } /// /// Retrieves a scope using its unique identifier. @@ -114,6 +155,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + if (scope != null) + { + var signal = await CreateExpirationSignalAsync(scope, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(1L); entry.SetValue(scope); } @@ -160,6 +212,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + if (scope != null) + { + var signal = await CreateExpirationSignalAsync(scope, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(1L); entry.SetValue(scope); } @@ -244,6 +307,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var scope in scopes) + { + var signal = await CreateExpirationSignalAsync(scope, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(scopes.Length); entry.SetValue(scopes); } @@ -269,26 +343,52 @@ namespace OpenIddict.Core throw new ArgumentNullException(nameof(scope)); } - _cache.Remove(new + var identifier = await _store.GetIdAsync(scope, cancellationToken); + if (string.IsNullOrEmpty(identifier)) { - Method = nameof(FindByIdAsync), - Identifier = await _store.GetIdAsync(scope, cancellationToken) - }); + throw new InvalidOperationException("The application identifier cannot be extracted."); + } - _cache.Remove(new + if (_signals.TryGetValue(identifier, out Lazy signal)) { - Method = nameof(FindByNameAsync), - Name = await _store.GetNameAsync(scope, cancellationToken) - }); + signal.Value.Cancel(); - foreach (var resource in await _store.GetResourcesAsync(scope, cancellationToken)) + _signals.TryRemove(identifier, out signal); + } + } + + /// + /// Creates an expiration signal allowing to invalidate all the + /// cache entries associated with the specified scope. + /// + /// The scope associated with the expiration signal. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns an expiration signal for the specified scope. + /// + protected virtual async Task CreateExpirationSignalAsync([NotNull] TScope scope, CancellationToken cancellationToken) + { + if (scope == null) { - _cache.Remove(new - { - Method = nameof(FindByResourceAsync), - Resource = resource - }); + throw new ArgumentNullException(nameof(scope)); } + + var identifier = await _store.GetIdAsync(scope, cancellationToken); + if (string.IsNullOrEmpty(identifier)) + { + throw new InvalidOperationException("The scope identifier cannot be extracted."); + } + + var signal = _signals.GetOrAdd(identifier, delegate + { + // Note: a Lazy is used here to ensure only one CancellationTokenSource + // can be created. Not doing so would result in expiration signals being potentially linked to + // multiple sources, with a single one of them being eventually tracked and thus, cancelable. + return new Lazy(() => new CancellationTokenSource()); + }); + + return new CancellationChangeToken(signal.Value.Token); } } } diff --git a/src/OpenIddict.Core/Caches/OpenIddictTokenCache.cs b/src/OpenIddict.Core/Caches/OpenIddictTokenCache.cs index 121fb5b3..d2e8d451 100644 --- a/src/OpenIddict.Core/Caches/OpenIddictTokenCache.cs +++ b/src/OpenIddict.Core/Caches/OpenIddictTokenCache.cs @@ -5,12 +5,14 @@ */ using System; +using System.Collections.Concurrent; using System.Collections.Immutable; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; using OpenIddict.Abstractions; namespace OpenIddict.Core @@ -21,7 +23,8 @@ namespace OpenIddict.Core /// The type of the Token entity. public class OpenIddictTokenCache : IOpenIddictTokenCache, IDisposable where TToken : class { - private readonly IMemoryCache _cache; + private readonly MemoryCache _cache; + private readonly ConcurrentDictionary> _signals; private readonly IOpenIddictTokenStore _store; public OpenIddictTokenCache( @@ -33,6 +36,7 @@ namespace OpenIddict.Core SizeLimit = options.CurrentValue.EntityCacheLimit }); + _signals = new ConcurrentDictionary>(StringComparer.Ordinal); _store = resolver.Get(); } @@ -51,14 +55,75 @@ namespace OpenIddict.Core throw new ArgumentNullException(nameof(token)); } + _cache.Remove(new + { + Method = nameof(FindAsync), + Subject = await _store.GetSubjectAsync(token, cancellationToken), + Client = await _store.GetApplicationIdAsync(token, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindAsync), + Subject = await _store.GetSubjectAsync(token, cancellationToken), + Client = await _store.GetApplicationIdAsync(token, cancellationToken), + Status = await _store.GetStatusAsync(token, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindAsync), + Subject = await _store.GetSubjectAsync(token, cancellationToken), + Client = await _store.GetApplicationIdAsync(token, cancellationToken), + Status = await _store.GetStatusAsync(token, cancellationToken), + Type = await _store.GetTypeAsync(token, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindByApplicationIdAsync), + Identifier = await _store.GetApplicationIdAsync(token, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindByAuthorizationIdAsync), + Identifier = await _store.GetAuthorizationIdAsync(token, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindByIdAsync), + Identifier = await _store.GetIdAsync(token, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindByReferenceIdAsync), + Identifier = await _store.GetReferenceIdAsync(token, cancellationToken) + }); + + _cache.Remove(new + { + Method = nameof(FindBySubjectAsync), + Subject = await _store.GetSubjectAsync(token, cancellationToken) + }); + + var signal = await CreateExpirationSignalAsync(token, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + using (var entry = _cache.CreateEntry(new { Method = nameof(FindByIdAsync), Identifier = await _store.GetIdAsync(token, cancellationToken) })) { - entry.SetSize(1L); - entry.SetValue(token); + entry.AddExpirationToken(signal) + .SetSize(1L) + .SetValue(token); } using (var entry = _cache.CreateEntry(new @@ -67,15 +132,24 @@ namespace OpenIddict.Core Identifier = await _store.GetReferenceIdAsync(token, cancellationToken) })) { - entry.SetSize(1L); - entry.SetValue(token); + entry.AddExpirationToken(signal) + .SetSize(1L) + .SetValue(token); } } /// - /// Disposes the cache held by this instance. + /// Disposes the resources held by this instance. /// - public void Dispose() => _cache.Dispose(); + public void Dispose() + { + foreach (var signal in _signals) + { + signal.Value.Value.Dispose(); + } + + _cache.Dispose(); + } /// /// Retrieves the tokens corresponding to the specified @@ -122,6 +196,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var token in tokens) + { + var signal = await CreateExpirationSignalAsync(token, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(tokens.Length); entry.SetValue(tokens); } @@ -184,6 +269,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var token in tokens) + { + var signal = await CreateExpirationSignalAsync(token, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(tokens.Length); entry.SetValue(tokens); } @@ -253,6 +349,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var token in tokens) + { + var signal = await CreateExpirationSignalAsync(token, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(tokens.Length); entry.SetValue(tokens); } @@ -300,6 +407,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var token in tokens) + { + var signal = await CreateExpirationSignalAsync(token, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(tokens.Length); entry.SetValue(tokens); } @@ -347,6 +465,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var token in tokens) + { + var signal = await CreateExpirationSignalAsync(token, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(tokens.Length); entry.SetValue(tokens); } @@ -393,6 +522,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + if (token != null) + { + var signal = await CreateExpirationSignalAsync(token, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(1L); entry.SetValue(token); } @@ -440,6 +580,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + if (token != null) + { + var signal = await CreateExpirationSignalAsync(token, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(1L); entry.SetValue(token); } @@ -486,6 +637,17 @@ namespace OpenIddict.Core using (var entry = _cache.CreateEntry(parameters)) { + foreach (var token in tokens) + { + var signal = await CreateExpirationSignalAsync(token, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + entry.SetSize(tokens.Length); entry.SetValue(tokens); } @@ -511,59 +673,52 @@ namespace OpenIddict.Core throw new ArgumentNullException(nameof(token)); } - _cache.Remove(new - { - Method = nameof(FindAsync), - Subject = await _store.GetSubjectAsync(token, cancellationToken), - Client = await _store.GetApplicationIdAsync(token, cancellationToken) - }); - - _cache.Remove(new + var identifier = await _store.GetIdAsync(token, cancellationToken); + if (string.IsNullOrEmpty(identifier)) { - Method = nameof(FindAsync), - Subject = await _store.GetSubjectAsync(token, cancellationToken), - Client = await _store.GetApplicationIdAsync(token, cancellationToken), - Status = await _store.GetStatusAsync(token, cancellationToken) - }); + throw new InvalidOperationException("The application identifier cannot be extracted."); + } - _cache.Remove(new + if (_signals.TryGetValue(identifier, out Lazy signal)) { - Method = nameof(FindAsync), - Subject = await _store.GetSubjectAsync(token, cancellationToken), - Client = await _store.GetApplicationIdAsync(token, cancellationToken), - Status = await _store.GetStatusAsync(token, cancellationToken), - Type = await _store.GetTypeAsync(token, cancellationToken) - }); + signal.Value.Cancel(); - _cache.Remove(new - { - Method = nameof(FindByApplicationIdAsync), - Identifier = await _store.GetApplicationIdAsync(token, cancellationToken) - }); + _signals.TryRemove(identifier, out signal); + } + } - _cache.Remove(new + /// + /// Creates an expiration signal allowing to invalidate all the + /// cache entries associated with the specified token. + /// + /// The token associated with the expiration signal. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns an expiration signal for the specified token. + /// + protected virtual async Task CreateExpirationSignalAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) { - Method = nameof(FindByAuthorizationIdAsync), - Identifier = await _store.GetAuthorizationIdAsync(token, cancellationToken) - }); + throw new ArgumentNullException(nameof(token)); + } - _cache.Remove(new + var identifier = await _store.GetIdAsync(token, cancellationToken); + if (string.IsNullOrEmpty(identifier)) { - Method = nameof(FindByIdAsync), - Identifier = await _store.GetIdAsync(token, cancellationToken) - }); + throw new InvalidOperationException("The token identifier cannot be extracted."); + } - _cache.Remove(new + var signal = _signals.GetOrAdd(identifier, delegate { - Method = nameof(FindByReferenceIdAsync), - Identifier = await _store.GetReferenceIdAsync(token, cancellationToken) + // Note: a Lazy is used here to ensure only one CancellationTokenSource + // can be created. Not doing so would result in expiration signals being potentially linked to + // multiple sources, with a single one of them being eventually tracked and thus, cancelable. + return new Lazy(() => new CancellationTokenSource()); }); - _cache.Remove(new - { - Method = nameof(FindBySubjectAsync), - Subject = await _store.GetSubjectAsync(token, cancellationToken) - }); + return new CancellationChangeToken(signal.Value.Token); } } } diff --git a/src/OpenIddict.Core/Managers/OpenIddictApplicationManager.cs b/src/OpenIddict.Core/Managers/OpenIddictApplicationManager.cs index 48dd73e9..9437635a 100644 --- a/src/OpenIddict.Core/Managers/OpenIddictApplicationManager.cs +++ b/src/OpenIddict.Core/Managers/OpenIddictApplicationManager.cs @@ -157,6 +157,11 @@ namespace OpenIddict.Core } await Store.CreateAsync(application, cancellationToken); + + if (!Options.CurrentValue.DisableEntityCaching) + { + await Cache.AddAsync(application, cancellationToken); + } } /// @@ -899,12 +904,13 @@ namespace OpenIddict.Core throw new OpenIddictExceptions.ValidationException(builder.ToString(), results); } + await Store.UpdateAsync(application, cancellationToken); + if (!Options.CurrentValue.DisableEntityCaching) { await Cache.RemoveAsync(application, cancellationToken); + await Cache.AddAsync(application, cancellationToken); } - - await Store.UpdateAsync(application, cancellationToken); } /// diff --git a/src/OpenIddict.Core/Managers/OpenIddictAuthorizationManager.cs b/src/OpenIddict.Core/Managers/OpenIddictAuthorizationManager.cs index beca6591..2ec888ee 100644 --- a/src/OpenIddict.Core/Managers/OpenIddictAuthorizationManager.cs +++ b/src/OpenIddict.Core/Managers/OpenIddictAuthorizationManager.cs @@ -126,6 +126,11 @@ namespace OpenIddict.Core } await Store.CreateAsync(authorization, cancellationToken); + + if (!Options.CurrentValue.DisableEntityCaching) + { + await Cache.AddAsync(authorization, cancellationToken); + } } /// @@ -1115,12 +1120,13 @@ namespace OpenIddict.Core throw new OpenIddictExceptions.ValidationException(builder.ToString(), results); } + await Store.UpdateAsync(authorization, cancellationToken); + if (!Options.CurrentValue.DisableEntityCaching) { await Cache.RemoveAsync(authorization, cancellationToken); + await Cache.AddAsync(authorization, cancellationToken); } - - await Store.UpdateAsync(authorization, cancellationToken); } /// diff --git a/src/OpenIddict.Core/Managers/OpenIddictScopeManager.cs b/src/OpenIddict.Core/Managers/OpenIddictScopeManager.cs index 4ddf0772..cdf115dd 100644 --- a/src/OpenIddict.Core/Managers/OpenIddictScopeManager.cs +++ b/src/OpenIddict.Core/Managers/OpenIddictScopeManager.cs @@ -120,6 +120,11 @@ namespace OpenIddict.Core } await Store.CreateAsync(scope, cancellationToken); + + if (!Options.CurrentValue.DisableEntityCaching) + { + await Cache.AddAsync(scope, cancellationToken); + } } /// @@ -679,12 +684,13 @@ namespace OpenIddict.Core throw new OpenIddictExceptions.ValidationException(builder.ToString(), results); } + await Store.UpdateAsync(scope, cancellationToken); + if (!Options.CurrentValue.DisableEntityCaching) { await Cache.RemoveAsync(scope, cancellationToken); + await Cache.AddAsync(scope, cancellationToken); } - - await Store.UpdateAsync(scope, cancellationToken); } /// diff --git a/src/OpenIddict.Core/Managers/OpenIddictTokenManager.cs b/src/OpenIddict.Core/Managers/OpenIddictTokenManager.cs index 96669932..4d016b72 100644 --- a/src/OpenIddict.Core/Managers/OpenIddictTokenManager.cs +++ b/src/OpenIddict.Core/Managers/OpenIddictTokenManager.cs @@ -134,6 +134,11 @@ namespace OpenIddict.Core } await Store.CreateAsync(token, cancellationToken); + + if (!Options.CurrentValue.DisableEntityCaching) + { + await Cache.AddAsync(token, cancellationToken); + } } /// @@ -1180,12 +1185,13 @@ namespace OpenIddict.Core throw new OpenIddictExceptions.ValidationException(builder.ToString(), results); } + await Store.UpdateAsync(token, cancellationToken); + if (!Options.CurrentValue.DisableEntityCaching) { await Cache.RemoveAsync(token, cancellationToken); + await Cache.AddAsync(token, cancellationToken); } - - await Store.UpdateAsync(token, cancellationToken); } ///