diff --git a/samples/Mvc.Server/Helpers/AsyncEnumerableExtensions.cs b/samples/Mvc.Server/Helpers/AsyncEnumerableExtensions.cs new file mode 100644 index 00000000..f4c85f06 --- /dev/null +++ b/samples/Mvc.Server/Helpers/AsyncEnumerableExtensions.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace Mvc.Server.Helpers +{ + public static class AsyncEnumerableExtensions + { + public static Task> ToListAsync(this IAsyncEnumerable source) + { + if (source == null) + { + throw new ArgumentNullException(nameof(source)); + } + + return ExecuteAsync(); + + async Task> ExecuteAsync() + { + var list = new List(); + + await foreach (var element in source) + { + list.Add(element); + } + + return list; + } + } + } +} diff --git a/src/OpenIddict.Core/Caches/OpenIddictApplicationCache.cs b/src/OpenIddict.Core/Caches/OpenIddictApplicationCache.cs index f633e493..88c17bc2 100644 --- a/src/OpenIddict.Core/Caches/OpenIddictApplicationCache.cs +++ b/src/OpenIddict.Core/Caches/OpenIddictApplicationCache.cs @@ -8,7 +8,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Collections.Immutable; -using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; @@ -85,33 +85,17 @@ namespace OpenIddict.Core }); } - var signal = await CreateExpirationSignalAsync(application, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } - - using (var entry = _cache.CreateEntry(new + await CreateEntryAsync(new { Method = nameof(FindByIdAsync), Identifier = await _store.GetIdAsync(application, cancellationToken) - })) - { - entry.AddExpirationToken(signal) - .SetSize(1L) - .SetValue(application); - } + }, application, cancellationToken); - using (var entry = _cache.CreateEntry(new + await CreateEntryAsync(new { Method = nameof(FindByClientIdAsync), Identifier = await _store.GetClientIdAsync(application, cancellationToken) - })) - { - entry.AddExpirationToken(signal) - .SetSize(1L) - .SetValue(application); - } + }, application, cancellationToken); } /// @@ -154,6 +138,8 @@ namespace OpenIddict.Core return new ValueTask(application); } + return new ValueTask(ExecuteAsync()); + async Task ExecuteAsync() { if ((application = await _store.FindByClientIdAsync(identifier, cancellationToken)) != null) @@ -161,27 +147,10 @@ namespace OpenIddict.Core await AddAsync(application, cancellationToken); } - using (var entry = _cache.CreateEntry(parameters)) - { - if (application != null) - { - var signal = await CreateExpirationSignalAsync(application, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } - - entry.AddExpirationToken(signal); - } - - entry.SetSize(1L); - entry.SetValue(application); - } + await CreateEntryAsync(parameters, application, cancellationToken); return application; } - - return new ValueTask(ExecuteAsync()); } /// @@ -211,6 +180,8 @@ namespace OpenIddict.Core return new ValueTask(application); } + return new ValueTask(ExecuteAsync()); + async Task ExecuteAsync() { if ((application = await _store.FindByIdAsync(identifier, cancellationToken)) != null) @@ -218,27 +189,10 @@ namespace OpenIddict.Core await AddAsync(application, cancellationToken); } - using (var entry = _cache.CreateEntry(parameters)) - { - if (application != null) - { - var signal = await CreateExpirationSignalAsync(application, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } - - entry.AddExpirationToken(signal); - } - - entry.SetSize(1L); - entry.SetValue(application); - } + await CreateEntryAsync(parameters, application, cancellationToken); return application; } - - return new ValueTask(ExecuteAsync()); } /// @@ -255,42 +209,30 @@ namespace OpenIddict.Core throw new ArgumentException("The address cannot be null or empty.", nameof(address)); } - var parameters = new - { - Method = nameof(FindByPostLogoutRedirectUriAsync), - Address = address - }; - - if (_cache.TryGetValue(parameters, out ImmutableArray applications)) - { - return applications.ToAsyncEnumerable(); - } + return ExecuteAsync(cancellationToken); - async IAsyncEnumerable ExecuteAsync() + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - var applications = ImmutableArray.CreateRange(await _store.FindByPostLogoutRedirectUriAsync( - address, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var application in applications) + var parameters = new { - await AddAsync(application, cancellationToken); - } + Method = nameof(FindByPostLogoutRedirectUriAsync), + Address = address + }; - using (var entry = _cache.CreateEntry(parameters)) + if (!_cache.TryGetValue(parameters, out ImmutableArray applications)) { - foreach (var application in applications) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var application in _store.FindByPostLogoutRedirectUriAsync(address, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(application, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(application); - entry.AddExpirationToken(signal); + await AddAsync(application, cancellationToken); } - entry.SetSize(applications.Length); - entry.SetValue(applications); + applications = builder.ToImmutable(); + + await CreateEntryAsync(parameters, applications, cancellationToken); } foreach (var application in applications) @@ -298,8 +240,6 @@ namespace OpenIddict.Core yield return application; } } - - return ExecuteAsync(); } /// @@ -316,42 +256,30 @@ namespace OpenIddict.Core throw new ArgumentException("The address cannot be null or empty.", nameof(address)); } - var parameters = new - { - Method = nameof(FindByRedirectUriAsync), - Address = address - }; + return ExecuteAsync(cancellationToken); - if (_cache.TryGetValue(parameters, out ImmutableArray applications)) + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - return applications.ToAsyncEnumerable(); - } - - async IAsyncEnumerable ExecuteAsync() - { - var applications = ImmutableArray.CreateRange(await _store.FindByRedirectUriAsync( - address, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var application in applications) + var parameters = new { - await AddAsync(application, cancellationToken); - } + Method = nameof(FindByRedirectUriAsync), + Address = address + }; - using (var entry = _cache.CreateEntry(parameters)) + if (!_cache.TryGetValue(parameters, out ImmutableArray applications)) { - foreach (var application in applications) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var application in _store.FindByRedirectUriAsync(address, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(application, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(application); - entry.AddExpirationToken(signal); + await AddAsync(application, cancellationToken); } - entry.SetSize(applications.Length); - entry.SetValue(applications); + applications = builder.ToImmutable(); + + await CreateEntryAsync(parameters, applications, cancellationToken); } foreach (var application in applications) @@ -359,8 +287,6 @@ namespace OpenIddict.Core yield return application; } } - - return ExecuteAsync(); } /// @@ -389,6 +315,70 @@ namespace OpenIddict.Core } } + /// + /// Creates a cache entry for the specified key. + /// + /// The cache key. + /// The application to store in the cache entry, if applicable. + /// The that can be used to abort the operation. + /// A that can be used to monitor the asynchronous operation. + protected virtual async ValueTask CreateEntryAsync( + [NotNull] object key, [CanBeNull] TApplication application, CancellationToken cancellationToken) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + + using var entry = _cache.CreateEntry(key); + + if (application != null) + { + var signal = await CreateExpirationSignalAsync(application, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + + entry.SetSize(1L); + entry.SetValue(application); + } + + /// + /// Creates a cache entry for the specified key. + /// + /// The cache key. + /// The applications to store in the cache entry. + /// The that can be used to abort the operation. + /// A that can be used to monitor the asynchronous operation. + protected virtual async ValueTask CreateEntryAsync( + [NotNull] object key, [CanBeNull] ImmutableArray applications, CancellationToken cancellationToken) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + + using var entry = _cache.CreateEntry(key); + + foreach (var application in applications) + { + var signal = await CreateExpirationSignalAsync(application, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + + entry.SetSize(applications.Length); + entry.SetValue(applications); + } + /// /// Creates an expiration signal allowing to invalidate all the /// cache entries associated with the specified application. diff --git a/src/OpenIddict.Core/Caches/OpenIddictAuthorizationCache.cs b/src/OpenIddict.Core/Caches/OpenIddictAuthorizationCache.cs index 5268f22a..2deb8cf5 100644 --- a/src/OpenIddict.Core/Caches/OpenIddictAuthorizationCache.cs +++ b/src/OpenIddict.Core/Caches/OpenIddictAuthorizationCache.cs @@ -8,7 +8,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Collections.Immutable; -using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; @@ -99,22 +99,11 @@ namespace OpenIddict.Core Subject = await _store.GetSubjectAsync(authorization, cancellationToken) }); - var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } - - using (var entry = _cache.CreateEntry(new + await CreateEntryAsync(new { Method = nameof(FindByIdAsync), Identifier = await _store.GetIdAsync(authorization, cancellationToken) - })) - { - entry.AddExpirationToken(signal) - .SetSize(1L) - .SetValue(authorization); - } + }, authorization, cancellationToken); } /// @@ -151,43 +140,31 @@ namespace OpenIddict.Core throw new ArgumentException("The client identifier cannot be null or empty.", nameof(client)); } - var parameters = new - { - Method = nameof(FindAsync), - Subject = subject, - Client = client - }; - - if (_cache.TryGetValue(parameters, out ImmutableArray authorizations)) - { - return authorizations.ToAsyncEnumerable(); - } + return ExecuteAsync(cancellationToken); - async IAsyncEnumerable ExecuteAsync() + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - var authorizations = ImmutableArray.CreateRange(await _store.FindAsync( - subject, client, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var authorization in authorizations) + var parameters = new { - await AddAsync(authorization, cancellationToken); - } + Method = nameof(FindAsync), + Subject = subject, + Client = client + }; - using (var entry = _cache.CreateEntry(parameters)) + if (!_cache.TryGetValue(parameters, out ImmutableArray authorizations)) { - foreach (var authorization in authorizations) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var authorization in _store.FindAsync(subject, client, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(authorization); - entry.AddExpirationToken(signal); + await AddAsync(authorization, cancellationToken); } - entry.SetSize(authorizations.Length); - entry.SetValue(authorizations); + authorizations = builder.ToImmutable(); + + await CreateEntryAsync(parameters, authorizations, cancellationToken); } foreach (var authorization in authorizations) @@ -195,8 +172,6 @@ namespace OpenIddict.Core yield return authorization; } } - - return ExecuteAsync(); } /// @@ -226,44 +201,32 @@ namespace OpenIddict.Core throw new ArgumentException("The status cannot be null or empty.", nameof(status)); } - var parameters = new - { - Method = nameof(FindAsync), - Subject = subject, - Client = client, - Status = status - }; + return ExecuteAsync(cancellationToken); - if (_cache.TryGetValue(parameters, out ImmutableArray authorizations)) + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - return authorizations.ToAsyncEnumerable(); - } - - async IAsyncEnumerable ExecuteAsync() - { - var authorizations = ImmutableArray.CreateRange(await _store.FindAsync( - subject, client, status, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var authorization in authorizations) + var parameters = new { - await AddAsync(authorization, cancellationToken); - } + Method = nameof(FindAsync), + Subject = subject, + Client = client, + Status = status + }; - using (var entry = _cache.CreateEntry(parameters)) + if (!_cache.TryGetValue(parameters, out ImmutableArray authorizations)) { - foreach (var authorization in authorizations) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var authorization in _store.FindAsync(subject, client, status, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(authorization); - entry.AddExpirationToken(signal); + await AddAsync(authorization, cancellationToken); } - entry.SetSize(authorizations.Length); - entry.SetValue(authorizations); + authorizations = builder.ToImmutable(); + + await CreateEntryAsync(parameters, authorizations, cancellationToken); } foreach (var authorization in authorizations) @@ -271,8 +234,6 @@ namespace OpenIddict.Core yield return authorization; } } - - return ExecuteAsync(); } /// @@ -308,45 +269,33 @@ namespace OpenIddict.Core throw new ArgumentException("The type cannot be null or empty.", nameof(type)); } - var parameters = new - { - Method = nameof(FindAsync), - Subject = subject, - Client = client, - Status = status, - Type = type - }; - - if (_cache.TryGetValue(parameters, out ImmutableArray authorizations)) - { - return authorizations.ToAsyncEnumerable(); - } + return ExecuteAsync(cancellationToken); - async IAsyncEnumerable ExecuteAsync() + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - var authorizations = ImmutableArray.CreateRange(await _store.FindAsync( - subject, client, status, type, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var authorization in authorizations) + var parameters = new { - await AddAsync(authorization, cancellationToken); - } - - using (var entry = _cache.CreateEntry(parameters)) + Method = nameof(FindAsync), + Subject = subject, + Client = client, + Status = status, + Type = type + }; + + if (!_cache.TryGetValue(parameters, out ImmutableArray authorizations)) { - foreach (var authorization in authorizations) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var authorization in _store.FindAsync(subject, client, status, type, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(authorization); - entry.AddExpirationToken(signal); + await AddAsync(authorization, cancellationToken); } - entry.SetSize(authorizations.Length); - entry.SetValue(authorizations); + authorizations = builder.ToImmutable(); + + await CreateEntryAsync(parameters, authorizations, cancellationToken); } foreach (var authorization in authorizations) @@ -354,8 +303,6 @@ namespace OpenIddict.Core yield return authorization; } } - - return ExecuteAsync(); } /// @@ -395,7 +342,9 @@ namespace OpenIddict.Core // Note: this method is only partially cached. - async IAsyncEnumerable ExecuteAsync() + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { await foreach (var authorization in _store.FindAsync(subject, client, status, type, scopes, cancellationToken)) { @@ -404,8 +353,6 @@ namespace OpenIddict.Core yield return authorization; } } - - return ExecuteAsync(); } /// @@ -422,42 +369,30 @@ namespace OpenIddict.Core throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); } - var parameters = new - { - Method = nameof(FindByApplicationIdAsync), - Identifier = identifier - }; + return ExecuteAsync(cancellationToken); - if (_cache.TryGetValue(parameters, out ImmutableArray authorizations)) + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - return authorizations.ToAsyncEnumerable(); - } - - async IAsyncEnumerable ExecuteAsync() - { - var authorizations = ImmutableArray.CreateRange(await _store.FindByApplicationIdAsync( - identifier, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var authorization in authorizations) + var parameters = new { - await AddAsync(authorization, cancellationToken); - } + Method = nameof(FindByApplicationIdAsync), + Identifier = identifier + }; - using (var entry = _cache.CreateEntry(parameters)) + if (!_cache.TryGetValue(parameters, out ImmutableArray authorizations)) { - foreach (var authorization in authorizations) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var authorization in _store.FindByApplicationIdAsync(identifier, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(authorization); - entry.AddExpirationToken(signal); + await AddAsync(authorization, cancellationToken); } - entry.SetSize(authorizations.Length); - entry.SetValue(authorizations); + authorizations = builder.ToImmutable(); + + await CreateEntryAsync(parameters, authorizations, cancellationToken); } foreach (var authorization in authorizations) @@ -465,8 +400,6 @@ namespace OpenIddict.Core yield return authorization; } } - - return ExecuteAsync(); } /// @@ -496,6 +429,8 @@ namespace OpenIddict.Core return new ValueTask(authorization); } + return new ValueTask(ExecuteAsync()); + async Task ExecuteAsync() { if ((authorization = await _store.FindByIdAsync(identifier, cancellationToken)) != null) @@ -503,27 +438,10 @@ namespace OpenIddict.Core await AddAsync(authorization, cancellationToken); } - using (var entry = _cache.CreateEntry(parameters)) - { - if (authorization != null) - { - var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } - - entry.AddExpirationToken(signal); - } - - entry.SetSize(1L); - entry.SetValue(authorization); - } + await CreateEntryAsync(parameters, authorization, cancellationToken); return authorization; } - - return new ValueTask(ExecuteAsync()); } /// @@ -540,42 +458,30 @@ namespace OpenIddict.Core throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); } - var parameters = new - { - Method = nameof(FindBySubjectAsync), - Subject = subject - }; - - if (_cache.TryGetValue(parameters, out ImmutableArray authorizations)) - { - return authorizations.ToAsyncEnumerable(); - } + return ExecuteAsync(cancellationToken); - async IAsyncEnumerable ExecuteAsync() + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - var authorizations = ImmutableArray.CreateRange(await _store.FindBySubjectAsync( - subject, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var authorization in authorizations) + var parameters = new { - await AddAsync(authorization, cancellationToken); - } + Method = nameof(FindBySubjectAsync), + Subject = subject + }; - using (var entry = _cache.CreateEntry(parameters)) + if (!_cache.TryGetValue(parameters, out ImmutableArray authorizations)) { - foreach (var authorization in authorizations) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var authorization in _store.FindBySubjectAsync(subject, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(authorization); - entry.AddExpirationToken(signal); + await AddAsync(authorization, cancellationToken); } - entry.SetSize(authorizations.Length); - entry.SetValue(authorizations); + authorizations = builder.ToImmutable(); + + await CreateEntryAsync(parameters, authorizations, cancellationToken); } foreach (var authorization in authorizations) @@ -583,8 +489,6 @@ namespace OpenIddict.Core yield return authorization; } } - - return ExecuteAsync(); } /// @@ -613,6 +517,70 @@ namespace OpenIddict.Core } } + /// + /// Creates a cache entry for the specified key. + /// + /// The cache key. + /// The authorization to store in the cache entry, if applicable. + /// The that can be used to abort the operation. + /// A that can be used to monitor the asynchronous operation. + protected virtual async ValueTask CreateEntryAsync( + [NotNull] object key, [CanBeNull] TAuthorization authorization, CancellationToken cancellationToken) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + + using var entry = _cache.CreateEntry(key); + + if (authorization != null) + { + var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + + entry.SetSize(1L); + entry.SetValue(authorization); + } + + /// + /// Creates a cache entry for the specified key. + /// + /// The cache key. + /// The authorizations to store in the cache entry. + /// The that can be used to abort the operation. + /// A that can be used to monitor the asynchronous operation. + protected virtual async ValueTask CreateEntryAsync( + [NotNull] object key, [CanBeNull] ImmutableArray authorizations, CancellationToken cancellationToken) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + + using var entry = _cache.CreateEntry(key); + + foreach (var authorization in authorizations) + { + var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + + entry.SetSize(authorizations.Length); + entry.SetValue(authorizations); + } + /// /// Creates an expiration signal allowing to invalidate all the /// cache entries associated with the specified authorization. diff --git a/src/OpenIddict.Core/Caches/OpenIddictScopeCache.cs b/src/OpenIddict.Core/Caches/OpenIddictScopeCache.cs index 3dab926b..ccb121b8 100644 --- a/src/OpenIddict.Core/Caches/OpenIddictScopeCache.cs +++ b/src/OpenIddict.Core/Caches/OpenIddictScopeCache.cs @@ -9,6 +9,7 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Collections.Immutable; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; @@ -76,33 +77,17 @@ namespace OpenIddict.Core }); } - var signal = await CreateExpirationSignalAsync(scope, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration token."); - } - - using (var entry = _cache.CreateEntry(new + await CreateEntryAsync(new { Method = nameof(FindByIdAsync), Identifier = await _store.GetIdAsync(scope, cancellationToken) - })) - { - entry.AddExpirationToken(signal) - .SetSize(1L) - .SetValue(scope); - } + }, scope, cancellationToken); - using (var entry = _cache.CreateEntry(new + await CreateEntryAsync(new { Method = nameof(FindByNameAsync), Name = await _store.GetNameAsync(scope, cancellationToken) - })) - { - entry.AddExpirationToken(signal) - .SetSize(1L) - .SetValue(scope); - } + }, scope, cancellationToken); } /// @@ -152,22 +137,7 @@ namespace OpenIddict.Core await AddAsync(scope, cancellationToken); } - using (var entry = _cache.CreateEntry(parameters)) - { - if (scope != null) - { - var signal = await CreateExpirationSignalAsync(scope, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } - - entry.AddExpirationToken(signal); - } - - entry.SetSize(1L); - entry.SetValue(scope); - } + await CreateEntryAsync(parameters, scope, cancellationToken); return scope; } @@ -209,22 +179,7 @@ namespace OpenIddict.Core await AddAsync(scope, cancellationToken); } - using (var entry = _cache.CreateEntry(parameters)) - { - if (scope != null) - { - var signal = await CreateExpirationSignalAsync(scope, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } - - entry.AddExpirationToken(signal); - } - - entry.SetSize(1L); - entry.SetValue(scope); - } + await CreateEntryAsync(parameters, scope, cancellationToken); return scope; } @@ -240,11 +195,6 @@ namespace OpenIddict.Core /// The scopes corresponding to the specified names. public IAsyncEnumerable FindByNamesAsync(ImmutableArray names, CancellationToken cancellationToken) { - if (names.IsDefaultOrEmpty) - { - return AsyncEnumerable.Empty(); - } - if (names.Any(name => string.IsNullOrEmpty(name))) { throw new ArgumentException("Scope names cannot be null or empty.", nameof(names)); @@ -252,7 +202,9 @@ namespace OpenIddict.Core // Note: this method is only partially cached. - async IAsyncEnumerable ExecuteAsync() + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { await foreach (var scope in _store.FindByNamesAsync(names, cancellationToken)) { @@ -261,8 +213,6 @@ namespace OpenIddict.Core yield return scope; } } - - return ExecuteAsync(); } /// @@ -278,42 +228,30 @@ namespace OpenIddict.Core throw new ArgumentException("The resource cannot be null or empty.", nameof(resource)); } - var parameters = new - { - Method = nameof(FindByResourceAsync), - Resource = resource - }; - - if (_cache.TryGetValue(parameters, out ImmutableArray scopes)) - { - return scopes.ToAsyncEnumerable(); - } + return ExecuteAsync(cancellationToken); - async IAsyncEnumerable ExecuteAsync() + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - var scopes = ImmutableArray.CreateRange(await _store.FindByResourceAsync( - resource, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var scope in scopes) + var parameters = new { - await AddAsync(scope, cancellationToken); - } + Method = nameof(FindByResourceAsync), + Resource = resource + }; - using (var entry = _cache.CreateEntry(parameters)) + if (!_cache.TryGetValue(parameters, out ImmutableArray scopes)) { - foreach (var scope in scopes) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var scope in _store.FindByResourceAsync(resource, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(scope, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(scope); - entry.AddExpirationToken(signal); + await AddAsync(scope, cancellationToken); } - entry.SetSize(scopes.Length); - entry.SetValue(scopes); + scopes = builder.ToImmutable(); + + await CreateEntryAsync(parameters, scopes, cancellationToken); } foreach (var scope in scopes) @@ -321,8 +259,6 @@ namespace OpenIddict.Core yield return scope; } } - - return ExecuteAsync(); } /// @@ -351,6 +287,70 @@ namespace OpenIddict.Core } } + /// + /// Creates a cache entry for the specified key. + /// + /// The cache key. + /// The scope to store in the cache entry, if applicable. + /// The that can be used to abort the operation. + /// A that can be used to monitor the asynchronous operation. + protected virtual async ValueTask CreateEntryAsync( + [NotNull] object key, [CanBeNull] TScope scope, CancellationToken cancellationToken) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + + using var entry = _cache.CreateEntry(key); + + if (scope != null) + { + var signal = await CreateExpirationSignalAsync(scope, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + + entry.SetSize(1L); + entry.SetValue(scope); + } + + /// + /// Creates a cache entry for the specified key. + /// + /// The cache key. + /// The scopes to store in the cache entry. + /// The that can be used to abort the operation. + /// A that can be used to monitor the asynchronous operation. + protected virtual async ValueTask CreateEntryAsync( + [NotNull] object key, [CanBeNull] ImmutableArray scopes, CancellationToken cancellationToken) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + + using var entry = _cache.CreateEntry(key); + + foreach (var scope in scopes) + { + var signal = await CreateExpirationSignalAsync(scope, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + + entry.SetSize(scopes.Length); + entry.SetValue(scopes); + } + /// /// Creates an expiration signal allowing to invalidate all the /// cache entries associated with the specified scope. diff --git a/src/OpenIddict.Core/Caches/OpenIddictTokenCache.cs b/src/OpenIddict.Core/Caches/OpenIddictTokenCache.cs index b8c98056..4317a093 100644 --- a/src/OpenIddict.Core/Caches/OpenIddictTokenCache.cs +++ b/src/OpenIddict.Core/Caches/OpenIddictTokenCache.cs @@ -8,7 +8,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Collections.Immutable; -using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; @@ -109,33 +109,17 @@ namespace OpenIddict.Core Subject = await _store.GetSubjectAsync(token, cancellationToken) }); - var signal = await CreateExpirationSignalAsync(token, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } - - using (var entry = _cache.CreateEntry(new + await CreateEntryAsync(new { Method = nameof(FindByIdAsync), Identifier = await _store.GetIdAsync(token, cancellationToken) - })) - { - entry.AddExpirationToken(signal) - .SetSize(1L) - .SetValue(token); - } + }, token, cancellationToken); - using (var entry = _cache.CreateEntry(new + await CreateEntryAsync(new { Method = nameof(FindByReferenceIdAsync), Identifier = await _store.GetReferenceIdAsync(token, cancellationToken) - })) - { - entry.AddExpirationToken(signal) - .SetSize(1L) - .SetValue(token); - } + }, token, cancellationToken); } /// @@ -172,43 +156,31 @@ namespace OpenIddict.Core throw new ArgumentException("The client identifier cannot be null or empty.", nameof(client)); } - var parameters = new - { - Method = nameof(FindAsync), - Subject = subject, - Client = client - }; + return ExecuteAsync(cancellationToken); - if (_cache.TryGetValue(parameters, out ImmutableArray tokens)) + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - return tokens.ToAsyncEnumerable(); - } - - async IAsyncEnumerable ExecuteAsync() - { - var tokens = ImmutableArray.CreateRange(await _store.FindAsync( - subject, client, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var token in tokens) + var parameters = new { - await AddAsync(token, cancellationToken); - } + Method = nameof(FindAsync), + Subject = subject, + Client = client + }; - using (var entry = _cache.CreateEntry(parameters)) + if (!_cache.TryGetValue(parameters, out ImmutableArray tokens)) { - foreach (var token in tokens) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var token in _store.FindAsync(subject, client, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(token, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(token); - entry.AddExpirationToken(signal); + await AddAsync(token, cancellationToken); } - entry.SetSize(tokens.Length); - entry.SetValue(tokens); + tokens = builder.ToImmutable(); + + await CreateEntryAsync(parameters, tokens, cancellationToken); } foreach (var token in tokens) @@ -216,8 +188,6 @@ namespace OpenIddict.Core yield return token; } } - - return ExecuteAsync(); } /// @@ -247,44 +217,32 @@ namespace OpenIddict.Core throw new ArgumentException("The status cannot be null or empty.", nameof(status)); } - var parameters = new - { - Method = nameof(FindAsync), - Subject = subject, - Client = client, - Status = status - }; - - if (_cache.TryGetValue(parameters, out ImmutableArray tokens)) - { - return tokens.ToAsyncEnumerable(); - } + return ExecuteAsync(cancellationToken); - async IAsyncEnumerable ExecuteAsync() + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - var tokens = ImmutableArray.CreateRange(await _store.FindAsync( - subject, client, status, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var token in tokens) + var parameters = new { - await AddAsync(token, cancellationToken); - } + Method = nameof(FindAsync), + Subject = subject, + Client = client, + Status = status + }; - using (var entry = _cache.CreateEntry(parameters)) + if (!_cache.TryGetValue(parameters, out ImmutableArray tokens)) { - foreach (var token in tokens) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var token in _store.FindAsync(subject, client, status, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(token, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(token); - entry.AddExpirationToken(signal); + await AddAsync(token, cancellationToken); } - entry.SetSize(tokens.Length); - entry.SetValue(tokens); + tokens = builder.ToImmutable(); + + await CreateEntryAsync(parameters, tokens, cancellationToken); } foreach (var token in tokens) @@ -292,8 +250,6 @@ namespace OpenIddict.Core yield return token; } } - - return ExecuteAsync(); } /// @@ -329,45 +285,33 @@ namespace OpenIddict.Core throw new ArgumentException("The type cannot be null or empty.", nameof(type)); } - var parameters = new - { - Method = nameof(FindAsync), - Subject = subject, - Client = client, - Status = status, - Type = type - }; - - if (_cache.TryGetValue(parameters, out ImmutableArray tokens)) - { - return tokens.ToAsyncEnumerable(); - } + return ExecuteAsync(cancellationToken); - async IAsyncEnumerable ExecuteAsync() + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - var tokens = ImmutableArray.CreateRange(await _store.FindAsync( - subject, client, status, type, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var token in tokens) + var parameters = new { - await AddAsync(token, cancellationToken); - } - - using (var entry = _cache.CreateEntry(parameters)) + Method = nameof(FindAsync), + Subject = subject, + Client = client, + Status = status, + Type = type + }; + + if (!_cache.TryGetValue(parameters, out ImmutableArray tokens)) { - foreach (var token in tokens) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var token in _store.FindAsync(subject, client, status, type, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(token, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(token); - entry.AddExpirationToken(signal); + await AddAsync(token, cancellationToken); } - entry.SetSize(tokens.Length); - entry.SetValue(tokens); + tokens = builder.ToImmutable(); + + await CreateEntryAsync(parameters, tokens, cancellationToken); } foreach (var token in tokens) @@ -375,8 +319,6 @@ namespace OpenIddict.Core yield return token; } } - - return ExecuteAsync(); } /// @@ -393,42 +335,30 @@ namespace OpenIddict.Core throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); } - var parameters = new - { - Method = nameof(FindByApplicationIdAsync), - Identifier = identifier - }; + return ExecuteAsync(cancellationToken); - if (_cache.TryGetValue(parameters, out ImmutableArray tokens)) + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - return tokens.ToAsyncEnumerable(); - } - - async IAsyncEnumerable ExecuteAsync() - { - var tokens = ImmutableArray.CreateRange(await _store.FindByApplicationIdAsync( - identifier, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var token in tokens) + var parameters = new { - await AddAsync(token, cancellationToken); - } + Method = nameof(FindByApplicationIdAsync), + Identifier = identifier + }; - using (var entry = _cache.CreateEntry(parameters)) + if (!_cache.TryGetValue(parameters, out ImmutableArray tokens)) { - foreach (var token in tokens) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var token in _store.FindByApplicationIdAsync(identifier, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(token, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(token); - entry.AddExpirationToken(signal); + await AddAsync(token, cancellationToken); } - entry.SetSize(tokens.Length); - entry.SetValue(tokens); + tokens = builder.ToImmutable(); + + await CreateEntryAsync(parameters, tokens, cancellationToken); } foreach (var token in tokens) @@ -436,8 +366,6 @@ namespace OpenIddict.Core yield return token; } } - - return ExecuteAsync(); } /// @@ -454,42 +382,30 @@ namespace OpenIddict.Core throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); } - var parameters = new - { - Method = nameof(FindByAuthorizationIdAsync), - Identifier = identifier - }; + return ExecuteAsync(cancellationToken); - if (_cache.TryGetValue(parameters, out ImmutableArray tokens)) + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - return tokens.ToAsyncEnumerable(); - } - - async IAsyncEnumerable ExecuteAsync() - { - var tokens = ImmutableArray.CreateRange(await _store.FindByAuthorizationIdAsync( - identifier, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var token in tokens) + var parameters = new { - await AddAsync(token, cancellationToken); - } + Method = nameof(FindByAuthorizationIdAsync), + Identifier = identifier + }; - using (var entry = _cache.CreateEntry(parameters)) + if (!_cache.TryGetValue(parameters, out ImmutableArray tokens)) { - foreach (var token in tokens) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var token in _store.FindByAuthorizationIdAsync(identifier, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(token, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(token); - entry.AddExpirationToken(signal); + await AddAsync(token, cancellationToken); } - entry.SetSize(tokens.Length); - entry.SetValue(tokens); + tokens = builder.ToImmutable(); + + await CreateEntryAsync(parameters, tokens, cancellationToken); } foreach (var token in tokens) @@ -497,8 +413,6 @@ namespace OpenIddict.Core yield return token; } } - - return ExecuteAsync(); } /// @@ -528,6 +442,8 @@ namespace OpenIddict.Core return new ValueTask(token); } + return new ValueTask(ExecuteAsync()); + async Task ExecuteAsync() { if ((token = await _store.FindByIdAsync(identifier, cancellationToken)) != null) @@ -535,27 +451,10 @@ namespace OpenIddict.Core await AddAsync(token, cancellationToken); } - using (var entry = _cache.CreateEntry(parameters)) - { - if (token != null) - { - var signal = await CreateExpirationSignalAsync(token, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } - - entry.AddExpirationToken(signal); - } - - entry.SetSize(1L); - entry.SetValue(token); - } + await CreateEntryAsync(parameters, token, cancellationToken); return token; } - - return new ValueTask(ExecuteAsync()); } /// @@ -586,6 +485,8 @@ namespace OpenIddict.Core return new ValueTask(token); } + return new ValueTask(ExecuteAsync()); + async Task ExecuteAsync() { if ((token = await _store.FindByReferenceIdAsync(identifier, cancellationToken)) != null) @@ -593,27 +494,10 @@ namespace OpenIddict.Core await AddAsync(token, cancellationToken); } - using (var entry = _cache.CreateEntry(parameters)) - { - if (token != null) - { - var signal = await CreateExpirationSignalAsync(token, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } - - entry.AddExpirationToken(signal); - } - - entry.SetSize(1L); - entry.SetValue(token); - } + await CreateEntryAsync(parameters, token, cancellationToken); return token; } - - return new ValueTask(ExecuteAsync()); } /// @@ -629,42 +513,30 @@ namespace OpenIddict.Core throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); } - var parameters = new - { - Method = nameof(FindBySubjectAsync), - Identifier = subject - }; - - if (_cache.TryGetValue(parameters, out ImmutableArray tokens)) - { - return tokens.ToAsyncEnumerable(); - } + return ExecuteAsync(cancellationToken); - async IAsyncEnumerable ExecuteAsync() + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) { - var tokens = ImmutableArray.CreateRange(await _store.FindBySubjectAsync( - subject, cancellationToken).ToListAsync(cancellationToken)); - - foreach (var token in tokens) + var parameters = new { - await AddAsync(token, cancellationToken); - } + Method = nameof(FindBySubjectAsync), + Identifier = subject + }; - using (var entry = _cache.CreateEntry(parameters)) + if (!_cache.TryGetValue(parameters, out ImmutableArray tokens)) { - foreach (var token in tokens) + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var token in _store.FindBySubjectAsync(subject, cancellationToken)) { - var signal = await CreateExpirationSignalAsync(token, cancellationToken); - if (signal == null) - { - throw new InvalidOperationException("An error occurred while creating an expiration signal."); - } + builder.Add(token); - entry.AddExpirationToken(signal); + await AddAsync(token, cancellationToken); } - entry.SetSize(tokens.Length); - entry.SetValue(tokens); + tokens = builder.ToImmutable(); + + await CreateEntryAsync(parameters, tokens, cancellationToken); } foreach (var token in tokens) @@ -672,8 +544,6 @@ namespace OpenIddict.Core yield return token; } } - - return ExecuteAsync(); } /// @@ -702,6 +572,70 @@ namespace OpenIddict.Core } } + /// + /// Creates a cache entry for the specified key. + /// + /// The cache key. + /// The token to store in the cache entry, if applicable. + /// The that can be used to abort the operation. + /// A that can be used to monitor the asynchronous operation. + protected virtual async ValueTask CreateEntryAsync( + [NotNull] object key, [CanBeNull] TToken token, CancellationToken cancellationToken) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + + using var entry = _cache.CreateEntry(key); + + if (token != null) + { + var signal = await CreateExpirationSignalAsync(token, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + + entry.SetSize(1L); + entry.SetValue(token); + } + + /// + /// Creates a cache entry for the specified key. + /// + /// The cache key. + /// The tokens to store in the cache entry. + /// The that can be used to abort the operation. + /// A that can be used to monitor the asynchronous operation. + protected virtual async ValueTask CreateEntryAsync( + [NotNull] object key, [CanBeNull] ImmutableArray tokens, CancellationToken cancellationToken) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + + using var entry = _cache.CreateEntry(key); + + foreach (var token in tokens) + { + var signal = await CreateExpirationSignalAsync(token, cancellationToken); + if (signal == null) + { + throw new InvalidOperationException("An error occurred while creating an expiration signal."); + } + + entry.AddExpirationToken(signal); + } + + entry.SetSize(tokens.Length); + entry.SetValue(tokens); + } + /// /// Creates an expiration signal allowing to invalidate all the /// cache entries associated with the specified token. diff --git a/src/OpenIddict.Core/Managers/OpenIddictApplicationManager.cs b/src/OpenIddict.Core/Managers/OpenIddictApplicationManager.cs index a880ba88..375f186a 100644 --- a/src/OpenIddict.Core/Managers/OpenIddictApplicationManager.cs +++ b/src/OpenIddict.Core/Managers/OpenIddictApplicationManager.cs @@ -156,7 +156,7 @@ namespace OpenIddict.Core await Store.SetClientSecretAsync(application, secret, cancellationToken); } - var results = await ValidateAsync(application, cancellationToken).ToListAsync(cancellationToken); + var results = await GetValidationResultsAsync(application, cancellationToken); if (results.Any(result => result != ValidationResult.Success)) { var builder = new StringBuilder(); @@ -168,7 +168,7 @@ namespace OpenIddict.Core builder.AppendLine(result.ErrorMessage); } - throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); + throw new OpenIddictExceptions.ValidationException(builder.ToString(), results); } await Store.CreateAsync(application, cancellationToken); @@ -177,6 +177,19 @@ namespace OpenIddict.Core { await Cache.AddAsync(application, cancellationToken); } + + async Task> GetValidationResultsAsync( + TApplication application, CancellationToken cancellationToken) + { + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var result in ValidateAsync(application, cancellationToken)) + { + builder.Add(result); + } + + return builder.ToImmutable(); + } } /// @@ -341,12 +354,23 @@ namespace OpenIddict.Core return applications; } + return ExecuteAsync(cancellationToken); + // SQL engines like Microsoft SQL Server or MySQL are known to use case-insensitive lookups by default. // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return applications.WhereAwait(async application => - (await Store.GetPostLogoutRedirectUrisAsync(application, cancellationToken)).Contains(address, StringComparer.Ordinal)); + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var application in applications) + { + var addresses = await Store.GetPostLogoutRedirectUrisAsync(application, cancellationToken); + if (addresses.Contains(address, StringComparer.Ordinal)) + { + yield return application; + } + } + } } /// @@ -376,8 +400,19 @@ namespace OpenIddict.Core // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return applications.WhereAwait(async application => - (await Store.GetRedirectUrisAsync(application, cancellationToken)).Contains(address, StringComparer.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var application in applications) + { + var addresses = await Store.GetRedirectUrisAsync(application, cancellationToken); + if (addresses.Contains(address, StringComparer.Ordinal)) + { + yield return application; + } + } + } } /// @@ -872,7 +907,7 @@ namespace OpenIddict.Core throw new ArgumentNullException(nameof(application)); } - var results = await ValidateAsync(application, cancellationToken).ToListAsync(cancellationToken); + var results = await GetValidationResultsAsync(application, cancellationToken); if (results.Any(result => result != ValidationResult.Success)) { var builder = new StringBuilder(); @@ -884,7 +919,7 @@ namespace OpenIddict.Core builder.AppendLine(result.ErrorMessage); } - throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); + throw new OpenIddictExceptions.ValidationException(builder.ToString(), results); } await Store.UpdateAsync(application, cancellationToken); @@ -894,6 +929,19 @@ namespace OpenIddict.Core await Cache.RemoveAsync(application, cancellationToken); await Cache.AddAsync(application, cancellationToken); } + + async Task> GetValidationResultsAsync( + TApplication application, CancellationToken cancellationToken) + { + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var result in ValidateAsync(application, cancellationToken)) + { + builder.Add(result); + } + + return builder.ToImmutable(); + } } /// @@ -1379,10 +1427,10 @@ namespace OpenIddict.Core => await FindByIdAsync(identifier, cancellationToken); IAsyncEnumerable IOpenIddictApplicationManager.FindByPostLogoutRedirectUriAsync(string address, CancellationToken cancellationToken) - => FindByPostLogoutRedirectUriAsync(address, cancellationToken).OfType(); + => FindByPostLogoutRedirectUriAsync(address, cancellationToken); IAsyncEnumerable IOpenIddictApplicationManager.FindByRedirectUriAsync(string address, CancellationToken cancellationToken) - => FindByRedirectUriAsync(address, cancellationToken).OfType(); + => FindByRedirectUriAsync(address, cancellationToken); ValueTask IOpenIddictApplicationManager.GetAsync(Func, IQueryable> query, CancellationToken cancellationToken) => GetAsync(query, cancellationToken); @@ -1430,7 +1478,7 @@ namespace OpenIddict.Core => HasRequirementAsync((TApplication) application, requirement, cancellationToken); IAsyncEnumerable IOpenIddictApplicationManager.ListAsync(int? count, int? offset, CancellationToken cancellationToken) - => ListAsync(count, offset, cancellationToken).OfType(); + => ListAsync(count, offset, cancellationToken); IAsyncEnumerable IOpenIddictApplicationManager.ListAsync(Func, IQueryable> query, CancellationToken cancellationToken) => ListAsync(query, cancellationToken); diff --git a/src/OpenIddict.Core/Managers/OpenIddictAuthorizationManager.cs b/src/OpenIddict.Core/Managers/OpenIddictAuthorizationManager.cs index c6dc1b42..f8f8bf93 100644 --- a/src/OpenIddict.Core/Managers/OpenIddictAuthorizationManager.cs +++ b/src/OpenIddict.Core/Managers/OpenIddictAuthorizationManager.cs @@ -114,7 +114,7 @@ namespace OpenIddict.Core await Store.SetStatusAsync(authorization, Statuses.Valid, cancellationToken); } - var results = await ValidateAsync(authorization, cancellationToken).ToListAsync(cancellationToken); + var results = await GetValidationResultsAsync(authorization, cancellationToken); if (results.Any(result => result != ValidationResult.Success)) { var builder = new StringBuilder(); @@ -126,7 +126,7 @@ namespace OpenIddict.Core builder.AppendLine(result.ErrorMessage); } - throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); + throw new OpenIddictExceptions.ValidationException(builder.ToString(), results); } await Store.CreateAsync(authorization, cancellationToken); @@ -135,6 +135,19 @@ namespace OpenIddict.Core { await Cache.AddAsync(authorization, cancellationToken); } + + async Task> GetValidationResultsAsync( + TAuthorization authorization, CancellationToken cancellationToken) + { + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var result in ValidateAsync(authorization, cancellationToken)) + { + builder.Add(result); + } + + return builder.ToImmutable(); + } } /// @@ -272,8 +285,22 @@ namespace OpenIddict.Core return authorizations; } - return authorizations.WhereAwait(async authorization => string.Equals( - await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal)); + // SQL engines like Microsoft SQL Server or MySQL are known to use case-insensitive lookups by default. + // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation + // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. + + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var authorization in authorizations) + { + if (string.Equals(await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal)) + { + yield return authorization; + } + } + } } /// @@ -316,8 +343,18 @@ namespace OpenIddict.Core // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return authorizations.WhereAwait(async authorization => string.Equals( - await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var authorization in authorizations) + { + if (string.Equals(await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal)) + { + yield return authorization; + } + } + } } /// @@ -362,8 +399,22 @@ namespace OpenIddict.Core return authorizations; } - return authorizations.WhereAwait(async authorization => string.Equals( - await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal)); + // SQL engines like Microsoft SQL Server or MySQL are known to use case-insensitive lookups by default. + // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation + // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. + + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var authorization in authorizations) + { + if (string.Equals(await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal)) + { + yield return authorization; + } + } + } } /// @@ -414,9 +465,25 @@ namespace OpenIddict.Core // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return authorizations.WhereAwait(async authorization => string.Equals( - await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal) && - await HasScopesAsync(authorization, scopes, cancellationToken)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var authorization in authorizations) + { + if (!string.Equals(await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal)) + { + continue; + } + + if (!await HasScopesAsync(authorization, scopes, cancellationToken)) + { + continue; + } + + yield return authorization; + } + } } /// @@ -446,8 +513,18 @@ namespace OpenIddict.Core // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return authorizations.WhereAwait(async authorization => string.Equals( - await Store.GetApplicationIdAsync(authorization, cancellationToken), identifier, StringComparison.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var authorization in authorizations) + { + if (string.Equals(await Store.GetApplicationIdAsync(authorization, cancellationToken), identifier, StringComparison.Ordinal)) + { + yield return authorization; + } + } + } } /// @@ -514,8 +591,18 @@ namespace OpenIddict.Core // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return authorizations.WhereAwait(async authorization => string.Equals( - await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var authorization in authorizations) + { + if (string.Equals(await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal)) + { + yield return authorization; + } + } + } } /// @@ -951,7 +1038,7 @@ namespace OpenIddict.Core throw new ArgumentNullException(nameof(authorization)); } - var results = await ValidateAsync(authorization, cancellationToken).ToListAsync(cancellationToken); + var results = await GetValidationResultsAsync(authorization, cancellationToken); if (results.Any(result => result != ValidationResult.Success)) { var builder = new StringBuilder(); @@ -963,7 +1050,7 @@ namespace OpenIddict.Core builder.AppendLine(result.ErrorMessage); } - throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); + throw new OpenIddictExceptions.ValidationException(builder.ToString(), results); } await Store.UpdateAsync(authorization, cancellationToken); @@ -973,6 +1060,19 @@ namespace OpenIddict.Core await Cache.RemoveAsync(authorization, cancellationToken); await Cache.AddAsync(authorization, cancellationToken); } + + async Task> GetValidationResultsAsync( + TAuthorization authorization, CancellationToken cancellationToken) + { + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var result in ValidateAsync(authorization, cancellationToken)) + { + builder.Add(result); + } + + return builder.ToImmutable(); + } } /// @@ -1070,25 +1170,25 @@ namespace OpenIddict.Core => DeleteAsync((TAuthorization) authorization, cancellationToken); IAsyncEnumerable IOpenIddictAuthorizationManager.FindAsync(string subject, string client, CancellationToken cancellationToken) - => FindAsync(subject, client, cancellationToken).OfType(); + => FindAsync(subject, client, cancellationToken); IAsyncEnumerable IOpenIddictAuthorizationManager.FindAsync(string subject, string client, string status, CancellationToken cancellationToken) - => FindAsync(subject, client, status, cancellationToken).OfType(); + => FindAsync(subject, client, status, cancellationToken); IAsyncEnumerable IOpenIddictAuthorizationManager.FindAsync(string subject, string client, string status, string type, CancellationToken cancellationToken) - => FindAsync(subject, client, status, type, cancellationToken).OfType(); + => FindAsync(subject, client, status, type, cancellationToken); IAsyncEnumerable IOpenIddictAuthorizationManager.FindAsync(string subject, string client, string status, string type, ImmutableArray scopes, CancellationToken cancellationToken) - => FindAsync(subject, client, status, type, scopes, cancellationToken).OfType(); + => FindAsync(subject, client, status, type, scopes, cancellationToken); IAsyncEnumerable IOpenIddictAuthorizationManager.FindByApplicationIdAsync(string identifier, CancellationToken cancellationToken) - => FindByApplicationIdAsync(identifier, cancellationToken).OfType(); + => FindByApplicationIdAsync(identifier, cancellationToken); async ValueTask IOpenIddictAuthorizationManager.FindByIdAsync(string identifier, CancellationToken cancellationToken) => await FindByIdAsync(identifier, cancellationToken); IAsyncEnumerable IOpenIddictAuthorizationManager.FindBySubjectAsync(string subject, CancellationToken cancellationToken) - => FindBySubjectAsync(subject, cancellationToken).OfType(); + => FindBySubjectAsync(subject, cancellationToken); ValueTask IOpenIddictAuthorizationManager.GetApplicationIdAsync(object authorization, CancellationToken cancellationToken) => GetApplicationIdAsync((TAuthorization) authorization, cancellationToken); @@ -1124,7 +1224,7 @@ namespace OpenIddict.Core => HasTypeAsync((TAuthorization) authorization, type, cancellationToken); IAsyncEnumerable IOpenIddictAuthorizationManager.ListAsync(int? count, int? offset, CancellationToken cancellationToken) - => ListAsync(count, offset, cancellationToken).OfType(); + => ListAsync(count, offset, cancellationToken); IAsyncEnumerable IOpenIddictAuthorizationManager.ListAsync(Func, IQueryable> query, CancellationToken cancellationToken) => ListAsync(query, cancellationToken); diff --git a/src/OpenIddict.Core/Managers/OpenIddictScopeManager.cs b/src/OpenIddict.Core/Managers/OpenIddictScopeManager.cs index f56511a0..dd2f8d11 100644 --- a/src/OpenIddict.Core/Managers/OpenIddictScopeManager.cs +++ b/src/OpenIddict.Core/Managers/OpenIddictScopeManager.cs @@ -106,7 +106,7 @@ namespace OpenIddict.Core throw new ArgumentNullException(nameof(scope)); } - var results = await ValidateAsync(scope, cancellationToken).ToListAsync(cancellationToken); + var results = await GetValidationResultsAsync(scope, cancellationToken); if (results.Any(result => result != ValidationResult.Success)) { var builder = new StringBuilder(); @@ -118,7 +118,7 @@ namespace OpenIddict.Core builder.AppendLine(result.ErrorMessage); } - throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); + throw new OpenIddictExceptions.ValidationException(builder.ToString(), results); } await Store.CreateAsync(scope, cancellationToken); @@ -127,6 +127,19 @@ namespace OpenIddict.Core { await Cache.AddAsync(scope, cancellationToken); } + + async Task> GetValidationResultsAsync( + TScope scope, CancellationToken cancellationToken) + { + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var result in ValidateAsync(scope, cancellationToken)) + { + builder.Add(result); + } + + return builder.ToImmutable(); + } } /// @@ -264,11 +277,6 @@ namespace OpenIddict.Core public virtual IAsyncEnumerable FindByNamesAsync( ImmutableArray names, CancellationToken cancellationToken = default) { - if (names.IsDefaultOrEmpty) - { - return AsyncEnumerable.Empty(); - } - if (names.Any(name => string.IsNullOrEmpty(name))) { throw new ArgumentException("Scope names cannot be null or empty.", nameof(names)); @@ -287,7 +295,18 @@ namespace OpenIddict.Core // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return scopes.WhereAwait(async scope => names.Contains(await Store.GetNameAsync(scope, cancellationToken), StringComparer.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var scope in scopes) + { + if (names.Contains(await Store.GetNameAsync(scope, cancellationToken), StringComparer.Ordinal)) + { + yield return scope; + } + } + } } /// @@ -317,8 +336,19 @@ namespace OpenIddict.Core // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return scopes.WhereAwait(async scope => - (await Store.GetResourcesAsync(scope, cancellationToken)).Contains(resource, StringComparer.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var scope in scopes) + { + var resources = await Store.GetResourcesAsync(scope, cancellationToken); + if (resources.Contains(resource, StringComparer.Ordinal)) + { + yield return scope; + } + } + } } /// @@ -518,29 +548,19 @@ namespace OpenIddict.Core /// The scopes. /// The that can be used to abort the operation. /// All the resources associated with the specified scopes. - public virtual IAsyncEnumerable ListResourcesAsync( - ImmutableArray scopes, CancellationToken cancellationToken = default) + public virtual async IAsyncEnumerable ListResourcesAsync( + ImmutableArray scopes, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - if (scopes.IsDefaultOrEmpty) + var resources = new HashSet(StringComparer.Ordinal); + + await foreach (var scope in FindByNamesAsync(scopes, cancellationToken)) { - return AsyncEnumerable.Empty(); + resources.UnionWith(await GetResourcesAsync(scope, cancellationToken)); } - return ExecuteAsync(cancellationToken); - - async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + foreach (var resource in resources) { - var resources = new HashSet(StringComparer.Ordinal); - - await foreach (var scope in FindByNamesAsync(scopes, cancellationToken)) - { - resources.UnionWith(await GetResourcesAsync(scope, cancellationToken)); - } - - foreach (var resource in resources) - { - yield return resource; - } + yield return resource; } } @@ -617,7 +637,7 @@ namespace OpenIddict.Core throw new ArgumentNullException(nameof(scope)); } - var results = await ValidateAsync(scope, cancellationToken).ToListAsync(cancellationToken); + var results = await GetValidationResultsAsync(scope, cancellationToken); if (results.Any(result => result != ValidationResult.Success)) { var builder = new StringBuilder(); @@ -629,7 +649,7 @@ namespace OpenIddict.Core builder.AppendLine(result.ErrorMessage); } - throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); + throw new OpenIddictExceptions.ValidationException(builder.ToString(), results); } await Store.UpdateAsync(scope, cancellationToken); @@ -639,6 +659,19 @@ namespace OpenIddict.Core await Cache.RemoveAsync(scope, cancellationToken); await Cache.AddAsync(scope, cancellationToken); } + + async Task> GetValidationResultsAsync( + TScope scope, CancellationToken cancellationToken) + { + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var result in ValidateAsync(scope, cancellationToken)) + { + builder.Add(result); + } + + return builder.ToImmutable(); + } } /// @@ -732,10 +765,10 @@ namespace OpenIddict.Core => await FindByNameAsync(name, cancellationToken); IAsyncEnumerable IOpenIddictScopeManager.FindByNamesAsync(ImmutableArray names, CancellationToken cancellationToken) - => FindByNamesAsync(names, cancellationToken).OfType(); + => FindByNamesAsync(names, cancellationToken); IAsyncEnumerable IOpenIddictScopeManager.FindByResourceAsync(string resource, CancellationToken cancellationToken) - => FindByResourceAsync(resource, cancellationToken).OfType(); + => FindByResourceAsync(resource, cancellationToken); ValueTask IOpenIddictScopeManager.GetAsync(Func, IQueryable> query, CancellationToken cancellationToken) => GetAsync(query, cancellationToken); @@ -759,7 +792,7 @@ namespace OpenIddict.Core => GetResourcesAsync((TScope) scope, cancellationToken); IAsyncEnumerable IOpenIddictScopeManager.ListAsync(int? count, int? offset, CancellationToken cancellationToken) - => ListAsync(count, offset, cancellationToken).OfType(); + => ListAsync(count, offset, cancellationToken); IAsyncEnumerable IOpenIddictScopeManager.ListAsync(Func, IQueryable> query, CancellationToken cancellationToken) => ListAsync(query, cancellationToken); diff --git a/src/OpenIddict.Core/Managers/OpenIddictTokenManager.cs b/src/OpenIddict.Core/Managers/OpenIddictTokenManager.cs index 0c4f5b42..0000bc8b 100644 --- a/src/OpenIddict.Core/Managers/OpenIddictTokenManager.cs +++ b/src/OpenIddict.Core/Managers/OpenIddictTokenManager.cs @@ -122,7 +122,7 @@ namespace OpenIddict.Core await Store.SetReferenceIdAsync(token, identifier, cancellationToken); } - var results = await ValidateAsync(token, cancellationToken).ToListAsync(cancellationToken); + var results = await GetValidationResultsAsync(token, cancellationToken); if (results.Any(result => result != ValidationResult.Success)) { var builder = new StringBuilder(); @@ -134,7 +134,7 @@ namespace OpenIddict.Core builder.AppendLine(result.ErrorMessage); } - throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); + throw new OpenIddictExceptions.ValidationException(builder.ToString(), results); } await Store.CreateAsync(token, cancellationToken); @@ -143,6 +143,19 @@ namespace OpenIddict.Core { await Cache.AddAsync(token, cancellationToken); } + + async Task> GetValidationResultsAsync( + TToken token, CancellationToken cancellationToken) + { + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var result in ValidateAsync(token, cancellationToken)) + { + builder.Add(result); + } + + return builder.ToImmutable(); + } } /// @@ -230,8 +243,18 @@ namespace OpenIddict.Core // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return tokens.WhereAwait(async token => string.Equals(await Store.GetSubjectAsync( - token, cancellationToken), subject, StringComparison.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var token in tokens) + { + if (string.Equals(await Store.GetSubjectAsync(token, cancellationToken), subject, StringComparison.Ordinal)) + { + yield return token; + } + } + } } /// @@ -274,8 +297,18 @@ namespace OpenIddict.Core // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return tokens.WhereAwait(async token => string.Equals(await Store.GetSubjectAsync( - token, cancellationToken), subject, StringComparison.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var token in tokens) + { + if (string.Equals(await Store.GetSubjectAsync(token, cancellationToken), subject, StringComparison.Ordinal)) + { + yield return token; + } + } + } } /// @@ -324,8 +357,18 @@ namespace OpenIddict.Core // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return tokens.WhereAwait(async token => string.Equals(await Store.GetSubjectAsync( - token, cancellationToken), subject, StringComparison.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var token in tokens) + { + if (string.Equals(await Store.GetSubjectAsync(token, cancellationToken), subject, StringComparison.Ordinal)) + { + yield return token; + } + } + } } /// @@ -355,8 +398,18 @@ namespace OpenIddict.Core // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return tokens.WhereAwait(async token => string.Equals(await Store.GetApplicationIdAsync( - token, cancellationToken), identifier, StringComparison.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var token in tokens) + { + if (string.Equals(await Store.GetApplicationIdAsync(token, cancellationToken), identifier, StringComparison.Ordinal)) + { + yield return token; + } + } + } } /// @@ -386,8 +439,18 @@ namespace OpenIddict.Core // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return tokens.WhereAwait(async token => string.Equals(await Store.GetAuthorizationIdAsync( - token, cancellationToken), identifier, StringComparison.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var token in tokens) + { + if (string.Equals(await Store.GetAuthorizationIdAsync(token, cancellationToken), identifier, StringComparison.Ordinal)) + { + yield return token; + } + } + } } /// @@ -495,8 +558,18 @@ namespace OpenIddict.Core // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. - return tokens.WhereAwait(async token => string.Equals(await Store.GetSubjectAsync( - token, cancellationToken), subject, StringComparison.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var token in tokens) + { + if (string.Equals(await Store.GetSubjectAsync(token, cancellationToken), subject, StringComparison.Ordinal)) + { + yield return token; + } + } + } } /// @@ -1168,7 +1241,7 @@ namespace OpenIddict.Core throw new ArgumentNullException(nameof(token)); } - var results = await ValidateAsync(token, cancellationToken).ToListAsync(cancellationToken); + var results = await GetValidationResultsAsync(token, cancellationToken); if (results.Any(result => result != ValidationResult.Success)) { var builder = new StringBuilder(); @@ -1180,7 +1253,7 @@ namespace OpenIddict.Core builder.AppendLine(result.ErrorMessage); } - throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); + throw new OpenIddictExceptions.ValidationException(builder.ToString(), results); } await Store.UpdateAsync(token, cancellationToken); @@ -1190,6 +1263,19 @@ namespace OpenIddict.Core await Cache.RemoveAsync(token, cancellationToken); await Cache.AddAsync(token, cancellationToken); } + + async Task> GetValidationResultsAsync( + TToken token, CancellationToken cancellationToken) + { + var builder = ImmutableArray.CreateBuilder(); + + await foreach (var result in ValidateAsync(token, cancellationToken)) + { + builder.Add(result); + } + + return builder.ToImmutable(); + } } /// @@ -1318,19 +1404,19 @@ namespace OpenIddict.Core => DeleteAsync((TToken) token, cancellationToken); IAsyncEnumerable IOpenIddictTokenManager.FindAsync(string subject, string client, CancellationToken cancellationToken) - => FindAsync(subject, client, cancellationToken).OfType(); + => FindAsync(subject, client, cancellationToken); IAsyncEnumerable IOpenIddictTokenManager.FindAsync(string subject, string client, string status, CancellationToken cancellationToken) - => FindAsync(subject, client, status, cancellationToken).OfType(); + => FindAsync(subject, client, status, cancellationToken); IAsyncEnumerable IOpenIddictTokenManager.FindAsync(string subject, string client, string status, string type, CancellationToken cancellationToken) - => FindAsync(subject, client, status, type, cancellationToken).OfType(); + => FindAsync(subject, client, status, type, cancellationToken); IAsyncEnumerable IOpenIddictTokenManager.FindByApplicationIdAsync(string identifier, CancellationToken cancellationToken) - => FindByApplicationIdAsync(identifier, cancellationToken).OfType(); + => FindByApplicationIdAsync(identifier, cancellationToken); IAsyncEnumerable IOpenIddictTokenManager.FindByAuthorizationIdAsync(string identifier, CancellationToken cancellationToken) - => FindByAuthorizationIdAsync(identifier, cancellationToken).OfType(); + => FindByAuthorizationIdAsync(identifier, cancellationToken); async ValueTask IOpenIddictTokenManager.FindByIdAsync(string identifier, CancellationToken cancellationToken) => await FindByIdAsync(identifier, cancellationToken); @@ -1339,7 +1425,7 @@ namespace OpenIddict.Core => await FindByReferenceIdAsync(identifier, cancellationToken); IAsyncEnumerable IOpenIddictTokenManager.FindBySubjectAsync(string subject, CancellationToken cancellationToken) - => FindBySubjectAsync(subject, cancellationToken).OfType(); + => FindBySubjectAsync(subject, cancellationToken); ValueTask IOpenIddictTokenManager.GetApplicationIdAsync(object token, CancellationToken cancellationToken) => GetApplicationIdAsync((TToken) token, cancellationToken); @@ -1384,7 +1470,7 @@ namespace OpenIddict.Core => HasTypeAsync((TToken) token, type, cancellationToken); IAsyncEnumerable IOpenIddictTokenManager.ListAsync(int? count, int? offset, CancellationToken cancellationToken) - => ListAsync(count, offset, cancellationToken).OfType(); + => ListAsync(count, offset, cancellationToken); IAsyncEnumerable IOpenIddictTokenManager.ListAsync(Func, IQueryable> query, CancellationToken cancellationToken) => ListAsync(query, cancellationToken); diff --git a/src/OpenIddict.Core/OpenIddict.Core.csproj b/src/OpenIddict.Core/OpenIddict.Core.csproj index 275109da..38f41e72 100644 --- a/src/OpenIddict.Core/OpenIddict.Core.csproj +++ b/src/OpenIddict.Core/OpenIddict.Core.csproj @@ -18,7 +18,6 @@ - diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkApplicationStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkApplicationStore.cs index eca958f8..b0ade607 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkApplicationStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkApplicationStore.cs @@ -12,6 +12,7 @@ using System.Data; using System.Data.Entity; using System.Data.Entity.Infrastructure; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Encodings.Web; using System.Text.Json; @@ -303,10 +304,24 @@ namespace OpenIddict.EntityFramework // are retrieved, a second pass is made to ensure only valid elements are returned. // Implementers that use this method in a hot path may want to override this method // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. - return Applications.Where(application => application.PostLogoutRedirectUris.Contains(address)) - .AsAsyncEnumerable(cancellationToken) - .WhereAwait(async application => (await GetPostLogoutRedirectUrisAsync(application, cancellationToken)) - .Contains(address, StringComparer.Ordinal)); + + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + var applications = (from application in Applications + where application.PostLogoutRedirectUris.Contains(address) + select application).AsAsyncEnumerable(cancellationToken); + + await foreach (var application in applications) + { + var addresses = await GetPostLogoutRedirectUrisAsync(application, cancellationToken); + if (addresses.Contains(address, StringComparer.Ordinal)) + { + yield return application; + } + } + } } /// @@ -328,10 +343,24 @@ namespace OpenIddict.EntityFramework // are retrieved, a second pass is made to ensure only valid elements are returned. // Implementers that use this method in a hot path may want to override this method // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. - return Applications.Where(application => application.RedirectUris.Contains(address)) - .AsAsyncEnumerable(cancellationToken) - .WhereAwait(async application => (await GetRedirectUrisAsync(application, cancellationToken)) - .Contains(address, StringComparer.Ordinal)); + + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + var applications = (from application in Applications + where application.RedirectUris.Contains(address) + select application).AsAsyncEnumerable(cancellationToken); + + await foreach (var application in applications) + { + var addresses = await GetRedirectUrisAsync(application, cancellationToken); + if (addresses.Contains(address, StringComparer.Ordinal)) + { + yield return application; + } + } + } } /// diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkAuthorizationStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkAuthorizationStore.cs index b2ef3062..0d5ffe47 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkAuthorizationStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkAuthorizationStore.cs @@ -12,6 +12,7 @@ using System.Data; using System.Data.Entity; using System.Data.Entity.Infrastructure; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Encodings.Web; using System.Text.Json; @@ -342,9 +343,50 @@ namespace OpenIddict.EntityFramework [NotNull] string subject, [NotNull] string client, [NotNull] string status, [NotNull] string type, ImmutableArray scopes, CancellationToken cancellationToken) - => FindAsync(subject, client, status, type, cancellationToken) - .WhereAwait(async authorization => new HashSet( - await GetScopesAsync(authorization, cancellationToken), StringComparer.Ordinal).IsSupersetOf(scopes)); + { + if (string.IsNullOrEmpty(subject)) + { + throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); + } + + if (string.IsNullOrEmpty(client)) + { + throw new ArgumentException("The client identifier cannot be null or empty.", nameof(client)); + } + + if (string.IsNullOrEmpty(status)) + { + throw new ArgumentException("The status cannot be null or empty.", nameof(status)); + } + + if (string.IsNullOrEmpty(type)) + { + throw new ArgumentException("The type cannot be null or empty.", nameof(type)); + } + + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + var key = ConvertIdentifierFromString(client); + + var authorizations = (from authorization in Authorizations.Include(authorization => authorization.Application) + where authorization.Application != null && + authorization.Application.Id.Equals(key) && + authorization.Subject == subject && + authorization.Status == status && + authorization.Type == type + select authorization).AsAsyncEnumerable(cancellationToken); + + await foreach (var authorization in authorizations) + { + if (new HashSet(await GetScopesAsync(authorization, cancellationToken), StringComparer.Ordinal).IsSupersetOf(scopes)) + { + yield return authorization; + } + } + } + } /// /// Retrieves the list of authorizations corresponding to the specified application identifier. diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkScopeStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkScopeStore.cs index 3aa2aa7b..b1ac4fa4 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkScopeStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkScopeStore.cs @@ -11,6 +11,7 @@ using System.ComponentModel; using System.Data.Entity; using System.Data.Entity.Infrastructure; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Encodings.Web; using System.Text.Json; @@ -247,9 +248,24 @@ namespace OpenIddict.EntityFramework // are retrieved, a second pass is made to ensure only valid elements are returned. // Implementers that use this method in a hot path may want to override this method // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. - return Scopes.Where(scope => scope.Resources.Contains(resource)) - .AsAsyncEnumerable(cancellationToken) - .WhereAwait(async scope => (await GetResourcesAsync(scope, cancellationToken)).Contains(resource, StringComparer.Ordinal)); + + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + var scopes = (from scope in Scopes + where scope.Resources.Contains(resource) + select scope).AsAsyncEnumerable(cancellationToken); + + await foreach (var scope in scopes) + { + var resources = await GetResourcesAsync(scope, cancellationToken); + if (resources.Contains(resource, StringComparer.Ordinal)) + { + yield return scope; + } + } + } } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreApplicationStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreApplicationStore.cs index 14b974eb..72068d10 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreApplicationStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreApplicationStore.cs @@ -10,6 +10,7 @@ using System.Collections.Immutable; using System.ComponentModel; using System.Data; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Encodings.Web; using System.Text.Json; @@ -347,12 +348,24 @@ namespace OpenIddict.EntityFrameworkCore // are retrieved, a second pass is made to ensure only valid elements are returned. // Implementers that use this method in a hot path may want to override this method // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. - var applications = (from application in Applications.AsTracking() - where application.PostLogoutRedirectUris.Contains(address) - select application).AsAsyncEnumerable(); - return applications.WhereAwait(async application => - (await GetPostLogoutRedirectUrisAsync(application, cancellationToken)).Contains(address, StringComparer.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + var applications = (from application in Applications.AsTracking() + where application.PostLogoutRedirectUris.Contains(address) + select application).AsAsyncEnumerable(); + + await foreach (var application in applications) + { + var addresses = await GetPostLogoutRedirectUrisAsync(application, cancellationToken); + if (addresses.Contains(address, StringComparer.Ordinal)) + { + yield return application; + } + } + } } /// @@ -374,12 +387,24 @@ namespace OpenIddict.EntityFrameworkCore // are retrieved, a second pass is made to ensure only valid elements are returned. // Implementers that use this method in a hot path may want to override this method // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. - var applications = (from application in Applications.AsTracking() - where application.RedirectUris.Contains(address) - select application).AsAsyncEnumerable(); - return applications.WhereAwait(async application => - (await GetRedirectUrisAsync(application, cancellationToken)).Contains(address, StringComparer.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + var applications = (from application in Applications.AsTracking() + where application.RedirectUris.Contains(address) + select application).AsAsyncEnumerable(); + + await foreach (var application in applications) + { + var addresses = await GetRedirectUrisAsync(application, cancellationToken); + if (addresses.Contains(address, StringComparer.Ordinal)) + { + yield return application; + } + } + } } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreAuthorizationStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreAuthorizationStore.cs index 1dc7043e..c088230c 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreAuthorizationStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreAuthorizationStore.cs @@ -10,6 +10,7 @@ using System.Collections.Immutable; using System.ComponentModel; using System.Data; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Encodings.Web; using System.Text.Json; @@ -393,9 +394,55 @@ namespace OpenIddict.EntityFrameworkCore [NotNull] string subject, [NotNull] string client, [NotNull] string status, [NotNull] string type, ImmutableArray scopes, CancellationToken cancellationToken) - => FindAsync(subject, client, status, type, cancellationToken) - .WhereAwait(async authorization => new HashSet( - await GetScopesAsync(authorization, cancellationToken), StringComparer.Ordinal).IsSupersetOf(scopes)); + { + if (string.IsNullOrEmpty(subject)) + { + throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); + } + + if (string.IsNullOrEmpty(client)) + { + throw new ArgumentException("The client identifier cannot be null or empty.", nameof(client)); + } + + if (string.IsNullOrEmpty(status)) + { + throw new ArgumentException("The status cannot be null or empty.", nameof(status)); + } + + if (string.IsNullOrEmpty(type)) + { + throw new ArgumentException("The type cannot be null or empty.", nameof(type)); + } + + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + // Note: due to a bug in Entity Framework Core's query visitor, the authorizations can't be + // filtered using authorization.Application.Id.Equals(key). To work around this issue, + // this method is overriden to use an explicit join before applying the equality check. + // See https://github.com/openiddict/openiddict-core/issues/499 for more information. + + var key = ConvertIdentifierFromString(client); + + var authorizations = (from authorization in Authorizations.Include(authorization => authorization.Application).AsTracking() + where authorization.Subject == subject && + authorization.Status == status && + authorization.Type == type + join application in Applications.AsTracking() on authorization.Application.Id equals application.Id + where application.Id.Equals(key) + select authorization).AsAsyncEnumerable(); + + await foreach (var authorization in authorizations) + { + if (new HashSet(await GetScopesAsync(authorization, cancellationToken), StringComparer.Ordinal).IsSupersetOf(scopes)) + { + yield return authorization; + } + } + } + } /// /// Retrieves the list of authorizations corresponding to the specified application identifier. diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreScopeStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreScopeStore.cs index ba081f04..b3527d67 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreScopeStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreScopeStore.cs @@ -9,6 +9,7 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.ComponentModel; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Encodings.Web; using System.Text.Json; @@ -258,12 +259,29 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentException("The resource cannot be null or empty.", nameof(resource)); } - var scopes = (from scope in Scopes.AsTracking() - where scope.Resources.Contains(resource) - select scope).AsAsyncEnumerable(); + // To optimize the efficiency of the query a bit, only scopes whose stringified + // Resources column contains the specified resource are returned. Once the scopes + // are retrieved, a second pass is made to ensure only valid elements are returned. + // Implementers that use this method in a hot path may want to override this method + // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. - return scopes.WhereAwait(async scope => - (await GetResourcesAsync(scope, cancellationToken)).Contains(resource, StringComparer.Ordinal)); + return ExecuteAsync(cancellationToken); + + async IAsyncEnumerable ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + var scopes = (from scope in Scopes.AsTracking() + where scope.Resources.Contains(resource) + select scope).AsAsyncEnumerable(); + + await foreach (var scope in scopes) + { + var resources = await GetResourcesAsync(scope, cancellationToken); + if (resources.Contains(resource, StringComparer.Ordinal)) + { + yield return scope; + } + } + } } /// diff --git a/test/OpenIddict.Server.IntegrationTests/OpenIddict.Server.IntegrationTests.csproj b/test/OpenIddict.Server.IntegrationTests/OpenIddict.Server.IntegrationTests.csproj index 8315858a..edb28bd9 100644 --- a/test/OpenIddict.Server.IntegrationTests/OpenIddict.Server.IntegrationTests.csproj +++ b/test/OpenIddict.Server.IntegrationTests/OpenIddict.Server.IntegrationTests.csproj @@ -18,6 +18,7 @@ +