Browse Source

Remove the System.Linq.Async dependency from OpenIddict.Core

pull/984/head
Kévin Chalet 6 years ago
parent
commit
4a2dedfd8c
  1. 31
      samples/Mvc.Server/Helpers/AsyncEnumerableExtensions.cs
  2. 220
      src/OpenIddict.Core/Caches/OpenIddictApplicationCache.cs
  3. 342
      src/OpenIddict.Core/Caches/OpenIddictAuthorizationCache.cs
  4. 178
      src/OpenIddict.Core/Caches/OpenIddictScopeCache.cs
  5. 410
      src/OpenIddict.Core/Caches/OpenIddictTokenCache.cs
  6. 70
      src/OpenIddict.Core/Managers/OpenIddictApplicationManager.cs
  7. 148
      src/OpenIddict.Core/Managers/OpenIddictAuthorizationManager.cs
  8. 99
      src/OpenIddict.Core/Managers/OpenIddictScopeManager.cs
  9. 132
      src/OpenIddict.Core/Managers/OpenIddictTokenManager.cs
  10. 1
      src/OpenIddict.Core/OpenIddict.Core.csproj
  11. 45
      src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkApplicationStore.cs
  12. 48
      src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkAuthorizationStore.cs
  13. 22
      src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkScopeStore.cs
  14. 45
      src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreApplicationStore.cs
  15. 53
      src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreAuthorizationStore.cs
  16. 28
      src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreScopeStore.cs
  17. 1
      test/OpenIddict.Server.IntegrationTests/OpenIddict.Server.IntegrationTests.csproj

31
samples/Mvc.Server/Helpers/AsyncEnumerableExtensions.cs

@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
namespace Mvc.Server.Helpers
{
public static class AsyncEnumerableExtensions
{
public static Task<List<T>> ToListAsync<T>(this IAsyncEnumerable<T> source)
{
if (source == null)
{
throw new ArgumentNullException(nameof(source));
}
return ExecuteAsync();
async Task<List<T>> ExecuteAsync()
{
var list = new List<T>();
await foreach (var element in source)
{
list.Add(element);
}
return list;
}
}
}
}

220
src/OpenIddict.Core/Caches/OpenIddictApplicationCache.cs

@ -8,7 +8,7 @@ using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Linq; using System.Runtime.CompilerServices;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using JetBrains.Annotations; using JetBrains.Annotations;
@ -85,33 +85,17 @@ namespace OpenIddict.Core
}); });
} }
var signal = await CreateExpirationSignalAsync(application, cancellationToken); await CreateEntryAsync(new
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
using (var entry = _cache.CreateEntry(new
{ {
Method = nameof(FindByIdAsync), Method = nameof(FindByIdAsync),
Identifier = await _store.GetIdAsync(application, cancellationToken) Identifier = await _store.GetIdAsync(application, cancellationToken)
})) }, application, cancellationToken);
{
entry.AddExpirationToken(signal)
.SetSize(1L)
.SetValue(application);
}
using (var entry = _cache.CreateEntry(new await CreateEntryAsync(new
{ {
Method = nameof(FindByClientIdAsync), Method = nameof(FindByClientIdAsync),
Identifier = await _store.GetClientIdAsync(application, cancellationToken) Identifier = await _store.GetClientIdAsync(application, cancellationToken)
})) }, application, cancellationToken);
{
entry.AddExpirationToken(signal)
.SetSize(1L)
.SetValue(application);
}
} }
/// <summary> /// <summary>
@ -154,6 +138,8 @@ namespace OpenIddict.Core
return new ValueTask<TApplication>(application); return new ValueTask<TApplication>(application);
} }
return new ValueTask<TApplication>(ExecuteAsync());
async Task<TApplication> ExecuteAsync() async Task<TApplication> ExecuteAsync()
{ {
if ((application = await _store.FindByClientIdAsync(identifier, cancellationToken)) != null) if ((application = await _store.FindByClientIdAsync(identifier, cancellationToken)) != null)
@ -161,27 +147,10 @@ namespace OpenIddict.Core
await AddAsync(application, cancellationToken); await AddAsync(application, cancellationToken);
} }
using (var entry = _cache.CreateEntry(parameters)) await CreateEntryAsync(parameters, application, cancellationToken);
{
if (application != null)
{
var signal = await CreateExpirationSignalAsync(application, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(1L);
entry.SetValue(application);
}
return application; return application;
} }
return new ValueTask<TApplication>(ExecuteAsync());
} }
/// <summary> /// <summary>
@ -211,6 +180,8 @@ namespace OpenIddict.Core
return new ValueTask<TApplication>(application); return new ValueTask<TApplication>(application);
} }
return new ValueTask<TApplication>(ExecuteAsync());
async Task<TApplication> ExecuteAsync() async Task<TApplication> ExecuteAsync()
{ {
if ((application = await _store.FindByIdAsync(identifier, cancellationToken)) != null) if ((application = await _store.FindByIdAsync(identifier, cancellationToken)) != null)
@ -218,27 +189,10 @@ namespace OpenIddict.Core
await AddAsync(application, cancellationToken); await AddAsync(application, cancellationToken);
} }
using (var entry = _cache.CreateEntry(parameters)) await CreateEntryAsync(parameters, application, cancellationToken);
{
if (application != null)
{
var signal = await CreateExpirationSignalAsync(application, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(1L);
entry.SetValue(application);
}
return application; return application;
} }
return new ValueTask<TApplication>(ExecuteAsync());
} }
/// <summary> /// <summary>
@ -255,42 +209,30 @@ namespace OpenIddict.Core
throw new ArgumentException("The address cannot be null or empty.", nameof(address)); throw new ArgumentException("The address cannot be null or empty.", nameof(address));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindByPostLogoutRedirectUriAsync),
Address = address
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TApplication> applications))
{
return applications.ToAsyncEnumerable();
}
async IAsyncEnumerable<TApplication> ExecuteAsync() async IAsyncEnumerable<TApplication> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
var applications = ImmutableArray.CreateRange(await _store.FindByPostLogoutRedirectUriAsync( var parameters = new
address, cancellationToken).ToListAsync(cancellationToken));
foreach (var application in applications)
{ {
await AddAsync(application, cancellationToken); Method = nameof(FindByPostLogoutRedirectUriAsync),
} Address = address
};
using (var entry = _cache.CreateEntry(parameters)) if (!_cache.TryGetValue(parameters, out ImmutableArray<TApplication> applications))
{ {
foreach (var application in applications) var builder = ImmutableArray.CreateBuilder<TApplication>();
await foreach (var application in _store.FindByPostLogoutRedirectUriAsync(address, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(application, cancellationToken); builder.Add(application);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(application, cancellationToken);
} }
entry.SetSize(applications.Length); applications = builder.ToImmutable();
entry.SetValue(applications);
await CreateEntryAsync(parameters, applications, cancellationToken);
} }
foreach (var application in applications) foreach (var application in applications)
@ -298,8 +240,6 @@ namespace OpenIddict.Core
yield return application; yield return application;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -316,42 +256,30 @@ namespace OpenIddict.Core
throw new ArgumentException("The address cannot be null or empty.", nameof(address)); throw new ArgumentException("The address cannot be null or empty.", nameof(address));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindByRedirectUriAsync),
Address = address
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TApplication> applications)) async IAsyncEnumerable<TApplication> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
return applications.ToAsyncEnumerable(); var parameters = new
}
async IAsyncEnumerable<TApplication> ExecuteAsync()
{
var applications = ImmutableArray.CreateRange(await _store.FindByRedirectUriAsync(
address, cancellationToken).ToListAsync(cancellationToken));
foreach (var application in applications)
{ {
await AddAsync(application, cancellationToken); Method = nameof(FindByRedirectUriAsync),
} Address = address
};
using (var entry = _cache.CreateEntry(parameters)) if (!_cache.TryGetValue(parameters, out ImmutableArray<TApplication> applications))
{ {
foreach (var application in applications) var builder = ImmutableArray.CreateBuilder<TApplication>();
await foreach (var application in _store.FindByRedirectUriAsync(address, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(application, cancellationToken); builder.Add(application);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(application, cancellationToken);
} }
entry.SetSize(applications.Length); applications = builder.ToImmutable();
entry.SetValue(applications);
await CreateEntryAsync(parameters, applications, cancellationToken);
} }
foreach (var application in applications) foreach (var application in applications)
@ -359,8 +287,6 @@ namespace OpenIddict.Core
yield return application; yield return application;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -389,6 +315,70 @@ namespace OpenIddict.Core
} }
} }
/// <summary>
/// Creates a cache entry for the specified key.
/// </summary>
/// <param name="key">The cache key.</param>
/// <param name="application">The application to store in the cache entry, if applicable.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> that can be used to abort the operation.</param>
/// <returns>A <see cref="ValueTask"/> that can be used to monitor the asynchronous operation.</returns>
protected virtual async ValueTask CreateEntryAsync(
[NotNull] object key, [CanBeNull] TApplication application, CancellationToken cancellationToken)
{
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
using var entry = _cache.CreateEntry(key);
if (application != null)
{
var signal = await CreateExpirationSignalAsync(application, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(1L);
entry.SetValue(application);
}
/// <summary>
/// Creates a cache entry for the specified key.
/// </summary>
/// <param name="key">The cache key.</param>
/// <param name="applications">The applications to store in the cache entry.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> that can be used to abort the operation.</param>
/// <returns>A <see cref="ValueTask"/> that can be used to monitor the asynchronous operation.</returns>
protected virtual async ValueTask CreateEntryAsync(
[NotNull] object key, [CanBeNull] ImmutableArray<TApplication> applications, CancellationToken cancellationToken)
{
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
using var entry = _cache.CreateEntry(key);
foreach (var application in applications)
{
var signal = await CreateExpirationSignalAsync(application, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(applications.Length);
entry.SetValue(applications);
}
/// <summary> /// <summary>
/// Creates an expiration signal allowing to invalidate all the /// Creates an expiration signal allowing to invalidate all the
/// cache entries associated with the specified application. /// cache entries associated with the specified application.

342
src/OpenIddict.Core/Caches/OpenIddictAuthorizationCache.cs

@ -8,7 +8,7 @@ using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Linq; using System.Runtime.CompilerServices;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using JetBrains.Annotations; using JetBrains.Annotations;
@ -99,22 +99,11 @@ namespace OpenIddict.Core
Subject = await _store.GetSubjectAsync(authorization, cancellationToken) Subject = await _store.GetSubjectAsync(authorization, cancellationToken)
}); });
var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); await CreateEntryAsync(new
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
using (var entry = _cache.CreateEntry(new
{ {
Method = nameof(FindByIdAsync), Method = nameof(FindByIdAsync),
Identifier = await _store.GetIdAsync(authorization, cancellationToken) Identifier = await _store.GetIdAsync(authorization, cancellationToken)
})) }, authorization, cancellationToken);
{
entry.AddExpirationToken(signal)
.SetSize(1L)
.SetValue(authorization);
}
} }
/// <summary> /// <summary>
@ -151,43 +140,31 @@ namespace OpenIddict.Core
throw new ArgumentException("The client identifier cannot be null or empty.", nameof(client)); throw new ArgumentException("The client identifier cannot be null or empty.", nameof(client));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindAsync),
Subject = subject,
Client = client
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TAuthorization> authorizations))
{
return authorizations.ToAsyncEnumerable();
}
async IAsyncEnumerable<TAuthorization> ExecuteAsync() async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
var authorizations = ImmutableArray.CreateRange(await _store.FindAsync( var parameters = new
subject, client, cancellationToken).ToListAsync(cancellationToken));
foreach (var authorization in authorizations)
{ {
await AddAsync(authorization, cancellationToken); Method = nameof(FindAsync),
} Subject = subject,
Client = client
};
using (var entry = _cache.CreateEntry(parameters)) if (!_cache.TryGetValue(parameters, out ImmutableArray<TAuthorization> authorizations))
{ {
foreach (var authorization in authorizations) var builder = ImmutableArray.CreateBuilder<TAuthorization>();
await foreach (var authorization in _store.FindAsync(subject, client, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); builder.Add(authorization);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(authorization, cancellationToken);
} }
entry.SetSize(authorizations.Length); authorizations = builder.ToImmutable();
entry.SetValue(authorizations);
await CreateEntryAsync(parameters, authorizations, cancellationToken);
} }
foreach (var authorization in authorizations) foreach (var authorization in authorizations)
@ -195,8 +172,6 @@ namespace OpenIddict.Core
yield return authorization; yield return authorization;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -226,44 +201,32 @@ namespace OpenIddict.Core
throw new ArgumentException("The status cannot be null or empty.", nameof(status)); throw new ArgumentException("The status cannot be null or empty.", nameof(status));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindAsync),
Subject = subject,
Client = client,
Status = status
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TAuthorization> authorizations)) async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
return authorizations.ToAsyncEnumerable(); var parameters = new
}
async IAsyncEnumerable<TAuthorization> ExecuteAsync()
{
var authorizations = ImmutableArray.CreateRange(await _store.FindAsync(
subject, client, status, cancellationToken).ToListAsync(cancellationToken));
foreach (var authorization in authorizations)
{ {
await AddAsync(authorization, cancellationToken); Method = nameof(FindAsync),
} Subject = subject,
Client = client,
Status = status
};
using (var entry = _cache.CreateEntry(parameters)) if (!_cache.TryGetValue(parameters, out ImmutableArray<TAuthorization> authorizations))
{ {
foreach (var authorization in authorizations) var builder = ImmutableArray.CreateBuilder<TAuthorization>();
await foreach (var authorization in _store.FindAsync(subject, client, status, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); builder.Add(authorization);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(authorization, cancellationToken);
} }
entry.SetSize(authorizations.Length); authorizations = builder.ToImmutable();
entry.SetValue(authorizations);
await CreateEntryAsync(parameters, authorizations, cancellationToken);
} }
foreach (var authorization in authorizations) foreach (var authorization in authorizations)
@ -271,8 +234,6 @@ namespace OpenIddict.Core
yield return authorization; yield return authorization;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -308,45 +269,33 @@ namespace OpenIddict.Core
throw new ArgumentException("The type cannot be null or empty.", nameof(type)); throw new ArgumentException("The type cannot be null or empty.", nameof(type));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindAsync),
Subject = subject,
Client = client,
Status = status,
Type = type
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TAuthorization> authorizations))
{
return authorizations.ToAsyncEnumerable();
}
async IAsyncEnumerable<TAuthorization> ExecuteAsync() async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
var authorizations = ImmutableArray.CreateRange(await _store.FindAsync( var parameters = new
subject, client, status, type, cancellationToken).ToListAsync(cancellationToken));
foreach (var authorization in authorizations)
{ {
await AddAsync(authorization, cancellationToken); Method = nameof(FindAsync),
} Subject = subject,
Client = client,
using (var entry = _cache.CreateEntry(parameters)) Status = status,
Type = type
};
if (!_cache.TryGetValue(parameters, out ImmutableArray<TAuthorization> authorizations))
{ {
foreach (var authorization in authorizations) var builder = ImmutableArray.CreateBuilder<TAuthorization>();
await foreach (var authorization in _store.FindAsync(subject, client, status, type, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); builder.Add(authorization);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(authorization, cancellationToken);
} }
entry.SetSize(authorizations.Length); authorizations = builder.ToImmutable();
entry.SetValue(authorizations);
await CreateEntryAsync(parameters, authorizations, cancellationToken);
} }
foreach (var authorization in authorizations) foreach (var authorization in authorizations)
@ -354,8 +303,6 @@ namespace OpenIddict.Core
yield return authorization; yield return authorization;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -395,7 +342,9 @@ namespace OpenIddict.Core
// Note: this method is only partially cached. // Note: this method is only partially cached.
async IAsyncEnumerable<TAuthorization> ExecuteAsync() return ExecuteAsync(cancellationToken);
async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
await foreach (var authorization in _store.FindAsync(subject, client, status, type, scopes, cancellationToken)) await foreach (var authorization in _store.FindAsync(subject, client, status, type, scopes, cancellationToken))
{ {
@ -404,8 +353,6 @@ namespace OpenIddict.Core
yield return authorization; yield return authorization;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -422,42 +369,30 @@ namespace OpenIddict.Core
throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindByApplicationIdAsync),
Identifier = identifier
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TAuthorization> authorizations)) async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
return authorizations.ToAsyncEnumerable(); var parameters = new
}
async IAsyncEnumerable<TAuthorization> ExecuteAsync()
{
var authorizations = ImmutableArray.CreateRange(await _store.FindByApplicationIdAsync(
identifier, cancellationToken).ToListAsync(cancellationToken));
foreach (var authorization in authorizations)
{ {
await AddAsync(authorization, cancellationToken); Method = nameof(FindByApplicationIdAsync),
} Identifier = identifier
};
using (var entry = _cache.CreateEntry(parameters)) if (!_cache.TryGetValue(parameters, out ImmutableArray<TAuthorization> authorizations))
{ {
foreach (var authorization in authorizations) var builder = ImmutableArray.CreateBuilder<TAuthorization>();
await foreach (var authorization in _store.FindByApplicationIdAsync(identifier, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); builder.Add(authorization);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(authorization, cancellationToken);
} }
entry.SetSize(authorizations.Length); authorizations = builder.ToImmutable();
entry.SetValue(authorizations);
await CreateEntryAsync(parameters, authorizations, cancellationToken);
} }
foreach (var authorization in authorizations) foreach (var authorization in authorizations)
@ -465,8 +400,6 @@ namespace OpenIddict.Core
yield return authorization; yield return authorization;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -496,6 +429,8 @@ namespace OpenIddict.Core
return new ValueTask<TAuthorization>(authorization); return new ValueTask<TAuthorization>(authorization);
} }
return new ValueTask<TAuthorization>(ExecuteAsync());
async Task<TAuthorization> ExecuteAsync() async Task<TAuthorization> ExecuteAsync()
{ {
if ((authorization = await _store.FindByIdAsync(identifier, cancellationToken)) != null) if ((authorization = await _store.FindByIdAsync(identifier, cancellationToken)) != null)
@ -503,27 +438,10 @@ namespace OpenIddict.Core
await AddAsync(authorization, cancellationToken); await AddAsync(authorization, cancellationToken);
} }
using (var entry = _cache.CreateEntry(parameters)) await CreateEntryAsync(parameters, authorization, cancellationToken);
{
if (authorization != null)
{
var signal = await CreateExpirationSignalAsync(authorization, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(1L);
entry.SetValue(authorization);
}
return authorization; return authorization;
} }
return new ValueTask<TAuthorization>(ExecuteAsync());
} }
/// <summary> /// <summary>
@ -540,42 +458,30 @@ namespace OpenIddict.Core
throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); throw new ArgumentException("The subject cannot be null or empty.", nameof(subject));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindBySubjectAsync),
Subject = subject
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TAuthorization> authorizations))
{
return authorizations.ToAsyncEnumerable();
}
async IAsyncEnumerable<TAuthorization> ExecuteAsync() async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
var authorizations = ImmutableArray.CreateRange(await _store.FindBySubjectAsync( var parameters = new
subject, cancellationToken).ToListAsync(cancellationToken));
foreach (var authorization in authorizations)
{ {
await AddAsync(authorization, cancellationToken); Method = nameof(FindBySubjectAsync),
} Subject = subject
};
using (var entry = _cache.CreateEntry(parameters)) if (!_cache.TryGetValue(parameters, out ImmutableArray<TAuthorization> authorizations))
{ {
foreach (var authorization in authorizations) var builder = ImmutableArray.CreateBuilder<TAuthorization>();
await foreach (var authorization in _store.FindBySubjectAsync(subject, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(authorization, cancellationToken); builder.Add(authorization);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(authorization, cancellationToken);
} }
entry.SetSize(authorizations.Length); authorizations = builder.ToImmutable();
entry.SetValue(authorizations);
await CreateEntryAsync(parameters, authorizations, cancellationToken);
} }
foreach (var authorization in authorizations) foreach (var authorization in authorizations)
@ -583,8 +489,6 @@ namespace OpenIddict.Core
yield return authorization; yield return authorization;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -613,6 +517,70 @@ namespace OpenIddict.Core
} }
} }
/// <summary>
/// Creates a cache entry for the specified key.
/// </summary>
/// <param name="key">The cache key.</param>
/// <param name="authorization">The authorization to store in the cache entry, if applicable.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> that can be used to abort the operation.</param>
/// <returns>A <see cref="ValueTask"/> that can be used to monitor the asynchronous operation.</returns>
protected virtual async ValueTask CreateEntryAsync(
[NotNull] object key, [CanBeNull] TAuthorization authorization, CancellationToken cancellationToken)
{
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
using var entry = _cache.CreateEntry(key);
if (authorization != null)
{
var signal = await CreateExpirationSignalAsync(authorization, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(1L);
entry.SetValue(authorization);
}
/// <summary>
/// Creates a cache entry for the specified key.
/// </summary>
/// <param name="key">The cache key.</param>
/// <param name="authorizations">The authorizations to store in the cache entry.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> that can be used to abort the operation.</param>
/// <returns>A <see cref="ValueTask"/> that can be used to monitor the asynchronous operation.</returns>
protected virtual async ValueTask CreateEntryAsync(
[NotNull] object key, [CanBeNull] ImmutableArray<TAuthorization> authorizations, CancellationToken cancellationToken)
{
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
using var entry = _cache.CreateEntry(key);
foreach (var authorization in authorizations)
{
var signal = await CreateExpirationSignalAsync(authorization, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(authorizations.Length);
entry.SetValue(authorizations);
}
/// <summary> /// <summary>
/// Creates an expiration signal allowing to invalidate all the /// Creates an expiration signal allowing to invalidate all the
/// cache entries associated with the specified authorization. /// cache entries associated with the specified authorization.

178
src/OpenIddict.Core/Caches/OpenIddictScopeCache.cs

@ -9,6 +9,7 @@ using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using JetBrains.Annotations; using JetBrains.Annotations;
@ -76,33 +77,17 @@ namespace OpenIddict.Core
}); });
} }
var signal = await CreateExpirationSignalAsync(scope, cancellationToken); await CreateEntryAsync(new
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration token.");
}
using (var entry = _cache.CreateEntry(new
{ {
Method = nameof(FindByIdAsync), Method = nameof(FindByIdAsync),
Identifier = await _store.GetIdAsync(scope, cancellationToken) Identifier = await _store.GetIdAsync(scope, cancellationToken)
})) }, scope, cancellationToken);
{
entry.AddExpirationToken(signal)
.SetSize(1L)
.SetValue(scope);
}
using (var entry = _cache.CreateEntry(new await CreateEntryAsync(new
{ {
Method = nameof(FindByNameAsync), Method = nameof(FindByNameAsync),
Name = await _store.GetNameAsync(scope, cancellationToken) Name = await _store.GetNameAsync(scope, cancellationToken)
})) }, scope, cancellationToken);
{
entry.AddExpirationToken(signal)
.SetSize(1L)
.SetValue(scope);
}
} }
/// <summary> /// <summary>
@ -152,22 +137,7 @@ namespace OpenIddict.Core
await AddAsync(scope, cancellationToken); await AddAsync(scope, cancellationToken);
} }
using (var entry = _cache.CreateEntry(parameters)) await CreateEntryAsync(parameters, scope, cancellationToken);
{
if (scope != null)
{
var signal = await CreateExpirationSignalAsync(scope, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(1L);
entry.SetValue(scope);
}
return scope; return scope;
} }
@ -209,22 +179,7 @@ namespace OpenIddict.Core
await AddAsync(scope, cancellationToken); await AddAsync(scope, cancellationToken);
} }
using (var entry = _cache.CreateEntry(parameters)) await CreateEntryAsync(parameters, scope, cancellationToken);
{
if (scope != null)
{
var signal = await CreateExpirationSignalAsync(scope, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(1L);
entry.SetValue(scope);
}
return scope; return scope;
} }
@ -240,11 +195,6 @@ namespace OpenIddict.Core
/// <returns>The scopes corresponding to the specified names.</returns> /// <returns>The scopes corresponding to the specified names.</returns>
public IAsyncEnumerable<TScope> FindByNamesAsync(ImmutableArray<string> names, CancellationToken cancellationToken) public IAsyncEnumerable<TScope> FindByNamesAsync(ImmutableArray<string> names, CancellationToken cancellationToken)
{ {
if (names.IsDefaultOrEmpty)
{
return AsyncEnumerable.Empty<TScope>();
}
if (names.Any(name => string.IsNullOrEmpty(name))) if (names.Any(name => string.IsNullOrEmpty(name)))
{ {
throw new ArgumentException("Scope names cannot be null or empty.", nameof(names)); throw new ArgumentException("Scope names cannot be null or empty.", nameof(names));
@ -252,7 +202,9 @@ namespace OpenIddict.Core
// Note: this method is only partially cached. // Note: this method is only partially cached.
async IAsyncEnumerable<TScope> ExecuteAsync() return ExecuteAsync(cancellationToken);
async IAsyncEnumerable<TScope> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
await foreach (var scope in _store.FindByNamesAsync(names, cancellationToken)) await foreach (var scope in _store.FindByNamesAsync(names, cancellationToken))
{ {
@ -261,8 +213,6 @@ namespace OpenIddict.Core
yield return scope; yield return scope;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -278,42 +228,30 @@ namespace OpenIddict.Core
throw new ArgumentException("The resource cannot be null or empty.", nameof(resource)); throw new ArgumentException("The resource cannot be null or empty.", nameof(resource));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindByResourceAsync),
Resource = resource
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TScope> scopes))
{
return scopes.ToAsyncEnumerable();
}
async IAsyncEnumerable<TScope> ExecuteAsync() async IAsyncEnumerable<TScope> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
var scopes = ImmutableArray.CreateRange(await _store.FindByResourceAsync( var parameters = new
resource, cancellationToken).ToListAsync(cancellationToken));
foreach (var scope in scopes)
{ {
await AddAsync(scope, cancellationToken); Method = nameof(FindByResourceAsync),
} Resource = resource
};
using (var entry = _cache.CreateEntry(parameters)) if (!_cache.TryGetValue(parameters, out ImmutableArray<TScope> scopes))
{ {
foreach (var scope in scopes) var builder = ImmutableArray.CreateBuilder<TScope>();
await foreach (var scope in _store.FindByResourceAsync(resource, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(scope, cancellationToken); builder.Add(scope);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(scope, cancellationToken);
} }
entry.SetSize(scopes.Length); scopes = builder.ToImmutable();
entry.SetValue(scopes);
await CreateEntryAsync(parameters, scopes, cancellationToken);
} }
foreach (var scope in scopes) foreach (var scope in scopes)
@ -321,8 +259,6 @@ namespace OpenIddict.Core
yield return scope; yield return scope;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -351,6 +287,70 @@ namespace OpenIddict.Core
} }
} }
/// <summary>
/// Creates a cache entry for the specified key.
/// </summary>
/// <param name="key">The cache key.</param>
/// <param name="scope">The scope to store in the cache entry, if applicable.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> that can be used to abort the operation.</param>
/// <returns>A <see cref="ValueTask"/> that can be used to monitor the asynchronous operation.</returns>
protected virtual async ValueTask CreateEntryAsync(
[NotNull] object key, [CanBeNull] TScope scope, CancellationToken cancellationToken)
{
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
using var entry = _cache.CreateEntry(key);
if (scope != null)
{
var signal = await CreateExpirationSignalAsync(scope, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(1L);
entry.SetValue(scope);
}
/// <summary>
/// Creates a cache entry for the specified key.
/// </summary>
/// <param name="key">The cache key.</param>
/// <param name="scopes">The scopes to store in the cache entry.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> that can be used to abort the operation.</param>
/// <returns>A <see cref="ValueTask"/> that can be used to monitor the asynchronous operation.</returns>
protected virtual async ValueTask CreateEntryAsync(
[NotNull] object key, [CanBeNull] ImmutableArray<TScope> scopes, CancellationToken cancellationToken)
{
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
using var entry = _cache.CreateEntry(key);
foreach (var scope in scopes)
{
var signal = await CreateExpirationSignalAsync(scope, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(scopes.Length);
entry.SetValue(scopes);
}
/// <summary> /// <summary>
/// Creates an expiration signal allowing to invalidate all the /// Creates an expiration signal allowing to invalidate all the
/// cache entries associated with the specified scope. /// cache entries associated with the specified scope.

410
src/OpenIddict.Core/Caches/OpenIddictTokenCache.cs

@ -8,7 +8,7 @@ using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Linq; using System.Runtime.CompilerServices;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using JetBrains.Annotations; using JetBrains.Annotations;
@ -109,33 +109,17 @@ namespace OpenIddict.Core
Subject = await _store.GetSubjectAsync(token, cancellationToken) Subject = await _store.GetSubjectAsync(token, cancellationToken)
}); });
var signal = await CreateExpirationSignalAsync(token, cancellationToken); await CreateEntryAsync(new
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
using (var entry = _cache.CreateEntry(new
{ {
Method = nameof(FindByIdAsync), Method = nameof(FindByIdAsync),
Identifier = await _store.GetIdAsync(token, cancellationToken) Identifier = await _store.GetIdAsync(token, cancellationToken)
})) }, token, cancellationToken);
{
entry.AddExpirationToken(signal)
.SetSize(1L)
.SetValue(token);
}
using (var entry = _cache.CreateEntry(new await CreateEntryAsync(new
{ {
Method = nameof(FindByReferenceIdAsync), Method = nameof(FindByReferenceIdAsync),
Identifier = await _store.GetReferenceIdAsync(token, cancellationToken) Identifier = await _store.GetReferenceIdAsync(token, cancellationToken)
})) }, token, cancellationToken);
{
entry.AddExpirationToken(signal)
.SetSize(1L)
.SetValue(token);
}
} }
/// <summary> /// <summary>
@ -172,43 +156,31 @@ namespace OpenIddict.Core
throw new ArgumentException("The client identifier cannot be null or empty.", nameof(client)); throw new ArgumentException("The client identifier cannot be null or empty.", nameof(client));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindAsync),
Subject = subject,
Client = client
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TToken> tokens)) async IAsyncEnumerable<TToken> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
return tokens.ToAsyncEnumerable(); var parameters = new
}
async IAsyncEnumerable<TToken> ExecuteAsync()
{
var tokens = ImmutableArray.CreateRange(await _store.FindAsync(
subject, client, cancellationToken).ToListAsync(cancellationToken));
foreach (var token in tokens)
{ {
await AddAsync(token, cancellationToken); Method = nameof(FindAsync),
} Subject = subject,
Client = client
};
using (var entry = _cache.CreateEntry(parameters)) if (!_cache.TryGetValue(parameters, out ImmutableArray<TToken> tokens))
{ {
foreach (var token in tokens) var builder = ImmutableArray.CreateBuilder<TToken>();
await foreach (var token in _store.FindAsync(subject, client, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(token, cancellationToken); builder.Add(token);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(token, cancellationToken);
} }
entry.SetSize(tokens.Length); tokens = builder.ToImmutable();
entry.SetValue(tokens);
await CreateEntryAsync(parameters, tokens, cancellationToken);
} }
foreach (var token in tokens) foreach (var token in tokens)
@ -216,8 +188,6 @@ namespace OpenIddict.Core
yield return token; yield return token;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -247,44 +217,32 @@ namespace OpenIddict.Core
throw new ArgumentException("The status cannot be null or empty.", nameof(status)); throw new ArgumentException("The status cannot be null or empty.", nameof(status));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindAsync),
Subject = subject,
Client = client,
Status = status
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TToken> tokens))
{
return tokens.ToAsyncEnumerable();
}
async IAsyncEnumerable<TToken> ExecuteAsync() async IAsyncEnumerable<TToken> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
var tokens = ImmutableArray.CreateRange(await _store.FindAsync( var parameters = new
subject, client, status, cancellationToken).ToListAsync(cancellationToken));
foreach (var token in tokens)
{ {
await AddAsync(token, cancellationToken); Method = nameof(FindAsync),
} Subject = subject,
Client = client,
Status = status
};
using (var entry = _cache.CreateEntry(parameters)) if (!_cache.TryGetValue(parameters, out ImmutableArray<TToken> tokens))
{ {
foreach (var token in tokens) var builder = ImmutableArray.CreateBuilder<TToken>();
await foreach (var token in _store.FindAsync(subject, client, status, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(token, cancellationToken); builder.Add(token);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(token, cancellationToken);
} }
entry.SetSize(tokens.Length); tokens = builder.ToImmutable();
entry.SetValue(tokens);
await CreateEntryAsync(parameters, tokens, cancellationToken);
} }
foreach (var token in tokens) foreach (var token in tokens)
@ -292,8 +250,6 @@ namespace OpenIddict.Core
yield return token; yield return token;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -329,45 +285,33 @@ namespace OpenIddict.Core
throw new ArgumentException("The type cannot be null or empty.", nameof(type)); throw new ArgumentException("The type cannot be null or empty.", nameof(type));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindAsync),
Subject = subject,
Client = client,
Status = status,
Type = type
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TToken> tokens))
{
return tokens.ToAsyncEnumerable();
}
async IAsyncEnumerable<TToken> ExecuteAsync() async IAsyncEnumerable<TToken> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
var tokens = ImmutableArray.CreateRange(await _store.FindAsync( var parameters = new
subject, client, status, type, cancellationToken).ToListAsync(cancellationToken));
foreach (var token in tokens)
{ {
await AddAsync(token, cancellationToken); Method = nameof(FindAsync),
} Subject = subject,
Client = client,
using (var entry = _cache.CreateEntry(parameters)) Status = status,
Type = type
};
if (!_cache.TryGetValue(parameters, out ImmutableArray<TToken> tokens))
{ {
foreach (var token in tokens) var builder = ImmutableArray.CreateBuilder<TToken>();
await foreach (var token in _store.FindAsync(subject, client, status, type, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(token, cancellationToken); builder.Add(token);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(token, cancellationToken);
} }
entry.SetSize(tokens.Length); tokens = builder.ToImmutable();
entry.SetValue(tokens);
await CreateEntryAsync(parameters, tokens, cancellationToken);
} }
foreach (var token in tokens) foreach (var token in tokens)
@ -375,8 +319,6 @@ namespace OpenIddict.Core
yield return token; yield return token;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -393,42 +335,30 @@ namespace OpenIddict.Core
throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindByApplicationIdAsync),
Identifier = identifier
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TToken> tokens)) async IAsyncEnumerable<TToken> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
return tokens.ToAsyncEnumerable(); var parameters = new
}
async IAsyncEnumerable<TToken> ExecuteAsync()
{
var tokens = ImmutableArray.CreateRange(await _store.FindByApplicationIdAsync(
identifier, cancellationToken).ToListAsync(cancellationToken));
foreach (var token in tokens)
{ {
await AddAsync(token, cancellationToken); Method = nameof(FindByApplicationIdAsync),
} Identifier = identifier
};
using (var entry = _cache.CreateEntry(parameters)) if (!_cache.TryGetValue(parameters, out ImmutableArray<TToken> tokens))
{ {
foreach (var token in tokens) var builder = ImmutableArray.CreateBuilder<TToken>();
await foreach (var token in _store.FindByApplicationIdAsync(identifier, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(token, cancellationToken); builder.Add(token);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(token, cancellationToken);
} }
entry.SetSize(tokens.Length); tokens = builder.ToImmutable();
entry.SetValue(tokens);
await CreateEntryAsync(parameters, tokens, cancellationToken);
} }
foreach (var token in tokens) foreach (var token in tokens)
@ -436,8 +366,6 @@ namespace OpenIddict.Core
yield return token; yield return token;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -454,42 +382,30 @@ namespace OpenIddict.Core
throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindByAuthorizationIdAsync),
Identifier = identifier
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TToken> tokens)) async IAsyncEnumerable<TToken> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
return tokens.ToAsyncEnumerable(); var parameters = new
}
async IAsyncEnumerable<TToken> ExecuteAsync()
{
var tokens = ImmutableArray.CreateRange(await _store.FindByAuthorizationIdAsync(
identifier, cancellationToken).ToListAsync(cancellationToken));
foreach (var token in tokens)
{ {
await AddAsync(token, cancellationToken); Method = nameof(FindByAuthorizationIdAsync),
} Identifier = identifier
};
using (var entry = _cache.CreateEntry(parameters)) if (!_cache.TryGetValue(parameters, out ImmutableArray<TToken> tokens))
{ {
foreach (var token in tokens) var builder = ImmutableArray.CreateBuilder<TToken>();
await foreach (var token in _store.FindByAuthorizationIdAsync(identifier, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(token, cancellationToken); builder.Add(token);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(token, cancellationToken);
} }
entry.SetSize(tokens.Length); tokens = builder.ToImmutable();
entry.SetValue(tokens);
await CreateEntryAsync(parameters, tokens, cancellationToken);
} }
foreach (var token in tokens) foreach (var token in tokens)
@ -497,8 +413,6 @@ namespace OpenIddict.Core
yield return token; yield return token;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -528,6 +442,8 @@ namespace OpenIddict.Core
return new ValueTask<TToken>(token); return new ValueTask<TToken>(token);
} }
return new ValueTask<TToken>(ExecuteAsync());
async Task<TToken> ExecuteAsync() async Task<TToken> ExecuteAsync()
{ {
if ((token = await _store.FindByIdAsync(identifier, cancellationToken)) != null) if ((token = await _store.FindByIdAsync(identifier, cancellationToken)) != null)
@ -535,27 +451,10 @@ namespace OpenIddict.Core
await AddAsync(token, cancellationToken); await AddAsync(token, cancellationToken);
} }
using (var entry = _cache.CreateEntry(parameters)) await CreateEntryAsync(parameters, token, cancellationToken);
{
if (token != null)
{
var signal = await CreateExpirationSignalAsync(token, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(1L);
entry.SetValue(token);
}
return token; return token;
} }
return new ValueTask<TToken>(ExecuteAsync());
} }
/// <summary> /// <summary>
@ -586,6 +485,8 @@ namespace OpenIddict.Core
return new ValueTask<TToken>(token); return new ValueTask<TToken>(token);
} }
return new ValueTask<TToken>(ExecuteAsync());
async Task<TToken> ExecuteAsync() async Task<TToken> ExecuteAsync()
{ {
if ((token = await _store.FindByReferenceIdAsync(identifier, cancellationToken)) != null) if ((token = await _store.FindByReferenceIdAsync(identifier, cancellationToken)) != null)
@ -593,27 +494,10 @@ namespace OpenIddict.Core
await AddAsync(token, cancellationToken); await AddAsync(token, cancellationToken);
} }
using (var entry = _cache.CreateEntry(parameters)) await CreateEntryAsync(parameters, token, cancellationToken);
{
if (token != null)
{
var signal = await CreateExpirationSignalAsync(token, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(1L);
entry.SetValue(token);
}
return token; return token;
} }
return new ValueTask<TToken>(ExecuteAsync());
} }
/// <summary> /// <summary>
@ -629,42 +513,30 @@ namespace OpenIddict.Core
throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); throw new ArgumentException("The subject cannot be null or empty.", nameof(subject));
} }
var parameters = new return ExecuteAsync(cancellationToken);
{
Method = nameof(FindBySubjectAsync),
Identifier = subject
};
if (_cache.TryGetValue(parameters, out ImmutableArray<TToken> tokens))
{
return tokens.ToAsyncEnumerable();
}
async IAsyncEnumerable<TToken> ExecuteAsync() async IAsyncEnumerable<TToken> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
var tokens = ImmutableArray.CreateRange(await _store.FindBySubjectAsync( var parameters = new
subject, cancellationToken).ToListAsync(cancellationToken));
foreach (var token in tokens)
{ {
await AddAsync(token, cancellationToken); Method = nameof(FindBySubjectAsync),
} Identifier = subject
};
using (var entry = _cache.CreateEntry(parameters)) if (!_cache.TryGetValue(parameters, out ImmutableArray<TToken> tokens))
{ {
foreach (var token in tokens) var builder = ImmutableArray.CreateBuilder<TToken>();
await foreach (var token in _store.FindBySubjectAsync(subject, cancellationToken))
{ {
var signal = await CreateExpirationSignalAsync(token, cancellationToken); builder.Add(token);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal); await AddAsync(token, cancellationToken);
} }
entry.SetSize(tokens.Length); tokens = builder.ToImmutable();
entry.SetValue(tokens);
await CreateEntryAsync(parameters, tokens, cancellationToken);
} }
foreach (var token in tokens) foreach (var token in tokens)
@ -672,8 +544,6 @@ namespace OpenIddict.Core
yield return token; yield return token;
} }
} }
return ExecuteAsync();
} }
/// <summary> /// <summary>
@ -702,6 +572,70 @@ namespace OpenIddict.Core
} }
} }
/// <summary>
/// Creates a cache entry for the specified key.
/// </summary>
/// <param name="key">The cache key.</param>
/// <param name="token">The token to store in the cache entry, if applicable.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> that can be used to abort the operation.</param>
/// <returns>A <see cref="ValueTask"/> that can be used to monitor the asynchronous operation.</returns>
protected virtual async ValueTask CreateEntryAsync(
[NotNull] object key, [CanBeNull] TToken token, CancellationToken cancellationToken)
{
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
using var entry = _cache.CreateEntry(key);
if (token != null)
{
var signal = await CreateExpirationSignalAsync(token, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(1L);
entry.SetValue(token);
}
/// <summary>
/// Creates a cache entry for the specified key.
/// </summary>
/// <param name="key">The cache key.</param>
/// <param name="tokens">The tokens to store in the cache entry.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> that can be used to abort the operation.</param>
/// <returns>A <see cref="ValueTask"/> that can be used to monitor the asynchronous operation.</returns>
protected virtual async ValueTask CreateEntryAsync(
[NotNull] object key, [CanBeNull] ImmutableArray<TToken> tokens, CancellationToken cancellationToken)
{
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
using var entry = _cache.CreateEntry(key);
foreach (var token in tokens)
{
var signal = await CreateExpirationSignalAsync(token, cancellationToken);
if (signal == null)
{
throw new InvalidOperationException("An error occurred while creating an expiration signal.");
}
entry.AddExpirationToken(signal);
}
entry.SetSize(tokens.Length);
entry.SetValue(tokens);
}
/// <summary> /// <summary>
/// Creates an expiration signal allowing to invalidate all the /// Creates an expiration signal allowing to invalidate all the
/// cache entries associated with the specified token. /// cache entries associated with the specified token.

70
src/OpenIddict.Core/Managers/OpenIddictApplicationManager.cs

@ -156,7 +156,7 @@ namespace OpenIddict.Core
await Store.SetClientSecretAsync(application, secret, cancellationToken); await Store.SetClientSecretAsync(application, secret, cancellationToken);
} }
var results = await ValidateAsync(application, cancellationToken).ToListAsync(cancellationToken); var results = await GetValidationResultsAsync(application, cancellationToken);
if (results.Any(result => result != ValidationResult.Success)) if (results.Any(result => result != ValidationResult.Success))
{ {
var builder = new StringBuilder(); var builder = new StringBuilder();
@ -168,7 +168,7 @@ namespace OpenIddict.Core
builder.AppendLine(result.ErrorMessage); builder.AppendLine(result.ErrorMessage);
} }
throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); throw new OpenIddictExceptions.ValidationException(builder.ToString(), results);
} }
await Store.CreateAsync(application, cancellationToken); await Store.CreateAsync(application, cancellationToken);
@ -177,6 +177,19 @@ namespace OpenIddict.Core
{ {
await Cache.AddAsync(application, cancellationToken); await Cache.AddAsync(application, cancellationToken);
} }
async Task<ImmutableArray<ValidationResult>> GetValidationResultsAsync(
TApplication application, CancellationToken cancellationToken)
{
var builder = ImmutableArray.CreateBuilder<ValidationResult>();
await foreach (var result in ValidateAsync(application, cancellationToken))
{
builder.Add(result);
}
return builder.ToImmutable();
}
} }
/// <summary> /// <summary>
@ -341,12 +354,23 @@ namespace OpenIddict.Core
return applications; return applications;
} }
return ExecuteAsync(cancellationToken);
// SQL engines like Microsoft SQL Server or MySQL are known to use case-insensitive lookups by default. // SQL engines like Microsoft SQL Server or MySQL are known to use case-insensitive lookups by default.
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return applications.WhereAwait(async application => async IAsyncEnumerable<TApplication> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
(await Store.GetPostLogoutRedirectUrisAsync(application, cancellationToken)).Contains(address, StringComparer.Ordinal)); {
await foreach (var application in applications)
{
var addresses = await Store.GetPostLogoutRedirectUrisAsync(application, cancellationToken);
if (addresses.Contains(address, StringComparer.Ordinal))
{
yield return application;
}
}
}
} }
/// <summary> /// <summary>
@ -376,8 +400,19 @@ namespace OpenIddict.Core
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return applications.WhereAwait(async application => return ExecuteAsync(cancellationToken);
(await Store.GetRedirectUrisAsync(application, cancellationToken)).Contains(address, StringComparer.Ordinal));
async IAsyncEnumerable<TApplication> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var application in applications)
{
var addresses = await Store.GetRedirectUrisAsync(application, cancellationToken);
if (addresses.Contains(address, StringComparer.Ordinal))
{
yield return application;
}
}
}
} }
/// <summary> /// <summary>
@ -872,7 +907,7 @@ namespace OpenIddict.Core
throw new ArgumentNullException(nameof(application)); throw new ArgumentNullException(nameof(application));
} }
var results = await ValidateAsync(application, cancellationToken).ToListAsync(cancellationToken); var results = await GetValidationResultsAsync(application, cancellationToken);
if (results.Any(result => result != ValidationResult.Success)) if (results.Any(result => result != ValidationResult.Success))
{ {
var builder = new StringBuilder(); var builder = new StringBuilder();
@ -884,7 +919,7 @@ namespace OpenIddict.Core
builder.AppendLine(result.ErrorMessage); builder.AppendLine(result.ErrorMessage);
} }
throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); throw new OpenIddictExceptions.ValidationException(builder.ToString(), results);
} }
await Store.UpdateAsync(application, cancellationToken); await Store.UpdateAsync(application, cancellationToken);
@ -894,6 +929,19 @@ namespace OpenIddict.Core
await Cache.RemoveAsync(application, cancellationToken); await Cache.RemoveAsync(application, cancellationToken);
await Cache.AddAsync(application, cancellationToken); await Cache.AddAsync(application, cancellationToken);
} }
async Task<ImmutableArray<ValidationResult>> GetValidationResultsAsync(
TApplication application, CancellationToken cancellationToken)
{
var builder = ImmutableArray.CreateBuilder<ValidationResult>();
await foreach (var result in ValidateAsync(application, cancellationToken))
{
builder.Add(result);
}
return builder.ToImmutable();
}
} }
/// <summary> /// <summary>
@ -1379,10 +1427,10 @@ namespace OpenIddict.Core
=> await FindByIdAsync(identifier, cancellationToken); => await FindByIdAsync(identifier, cancellationToken);
IAsyncEnumerable<object> IOpenIddictApplicationManager.FindByPostLogoutRedirectUriAsync(string address, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictApplicationManager.FindByPostLogoutRedirectUriAsync(string address, CancellationToken cancellationToken)
=> FindByPostLogoutRedirectUriAsync(address, cancellationToken).OfType<object>(); => FindByPostLogoutRedirectUriAsync(address, cancellationToken);
IAsyncEnumerable<object> IOpenIddictApplicationManager.FindByRedirectUriAsync(string address, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictApplicationManager.FindByRedirectUriAsync(string address, CancellationToken cancellationToken)
=> FindByRedirectUriAsync(address, cancellationToken).OfType<object>(); => FindByRedirectUriAsync(address, cancellationToken);
ValueTask<TResult> IOpenIddictApplicationManager.GetAsync<TResult>(Func<IQueryable<object>, IQueryable<TResult>> query, CancellationToken cancellationToken) ValueTask<TResult> IOpenIddictApplicationManager.GetAsync<TResult>(Func<IQueryable<object>, IQueryable<TResult>> query, CancellationToken cancellationToken)
=> GetAsync(query, cancellationToken); => GetAsync(query, cancellationToken);
@ -1430,7 +1478,7 @@ namespace OpenIddict.Core
=> HasRequirementAsync((TApplication) application, requirement, cancellationToken); => HasRequirementAsync((TApplication) application, requirement, cancellationToken);
IAsyncEnumerable<object> IOpenIddictApplicationManager.ListAsync(int? count, int? offset, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictApplicationManager.ListAsync(int? count, int? offset, CancellationToken cancellationToken)
=> ListAsync(count, offset, cancellationToken).OfType<object>(); => ListAsync(count, offset, cancellationToken);
IAsyncEnumerable<TResult> IOpenIddictApplicationManager.ListAsync<TResult>(Func<IQueryable<object>, IQueryable<TResult>> query, CancellationToken cancellationToken) IAsyncEnumerable<TResult> IOpenIddictApplicationManager.ListAsync<TResult>(Func<IQueryable<object>, IQueryable<TResult>> query, CancellationToken cancellationToken)
=> ListAsync(query, cancellationToken); => ListAsync(query, cancellationToken);

148
src/OpenIddict.Core/Managers/OpenIddictAuthorizationManager.cs

@ -114,7 +114,7 @@ namespace OpenIddict.Core
await Store.SetStatusAsync(authorization, Statuses.Valid, cancellationToken); await Store.SetStatusAsync(authorization, Statuses.Valid, cancellationToken);
} }
var results = await ValidateAsync(authorization, cancellationToken).ToListAsync(cancellationToken); var results = await GetValidationResultsAsync(authorization, cancellationToken);
if (results.Any(result => result != ValidationResult.Success)) if (results.Any(result => result != ValidationResult.Success))
{ {
var builder = new StringBuilder(); var builder = new StringBuilder();
@ -126,7 +126,7 @@ namespace OpenIddict.Core
builder.AppendLine(result.ErrorMessage); builder.AppendLine(result.ErrorMessage);
} }
throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); throw new OpenIddictExceptions.ValidationException(builder.ToString(), results);
} }
await Store.CreateAsync(authorization, cancellationToken); await Store.CreateAsync(authorization, cancellationToken);
@ -135,6 +135,19 @@ namespace OpenIddict.Core
{ {
await Cache.AddAsync(authorization, cancellationToken); await Cache.AddAsync(authorization, cancellationToken);
} }
async Task<ImmutableArray<ValidationResult>> GetValidationResultsAsync(
TAuthorization authorization, CancellationToken cancellationToken)
{
var builder = ImmutableArray.CreateBuilder<ValidationResult>();
await foreach (var result in ValidateAsync(authorization, cancellationToken))
{
builder.Add(result);
}
return builder.ToImmutable();
}
} }
/// <summary> /// <summary>
@ -272,8 +285,22 @@ namespace OpenIddict.Core
return authorizations; return authorizations;
} }
return authorizations.WhereAwait(async authorization => string.Equals( // SQL engines like Microsoft SQL Server or MySQL are known to use case-insensitive lookups by default.
await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal)); // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return ExecuteAsync(cancellationToken);
async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var authorization in authorizations)
{
if (string.Equals(await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal))
{
yield return authorization;
}
}
}
} }
/// <summary> /// <summary>
@ -316,8 +343,18 @@ namespace OpenIddict.Core
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return authorizations.WhereAwait(async authorization => string.Equals( return ExecuteAsync(cancellationToken);
await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal));
async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var authorization in authorizations)
{
if (string.Equals(await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal))
{
yield return authorization;
}
}
}
} }
/// <summary> /// <summary>
@ -362,8 +399,22 @@ namespace OpenIddict.Core
return authorizations; return authorizations;
} }
return authorizations.WhereAwait(async authorization => string.Equals( // SQL engines like Microsoft SQL Server or MySQL are known to use case-insensitive lookups by default.
await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal)); // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return ExecuteAsync(cancellationToken);
async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var authorization in authorizations)
{
if (string.Equals(await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal))
{
yield return authorization;
}
}
}
} }
/// <summary> /// <summary>
@ -414,9 +465,25 @@ namespace OpenIddict.Core
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return authorizations.WhereAwait(async authorization => string.Equals( return ExecuteAsync(cancellationToken);
await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal) &&
await HasScopesAsync(authorization, scopes, cancellationToken)); async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var authorization in authorizations)
{
if (!string.Equals(await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal))
{
continue;
}
if (!await HasScopesAsync(authorization, scopes, cancellationToken))
{
continue;
}
yield return authorization;
}
}
} }
/// <summary> /// <summary>
@ -446,8 +513,18 @@ namespace OpenIddict.Core
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return authorizations.WhereAwait(async authorization => string.Equals( return ExecuteAsync(cancellationToken);
await Store.GetApplicationIdAsync(authorization, cancellationToken), identifier, StringComparison.Ordinal));
async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var authorization in authorizations)
{
if (string.Equals(await Store.GetApplicationIdAsync(authorization, cancellationToken), identifier, StringComparison.Ordinal))
{
yield return authorization;
}
}
}
} }
/// <summary> /// <summary>
@ -514,8 +591,18 @@ namespace OpenIddict.Core
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return authorizations.WhereAwait(async authorization => string.Equals( return ExecuteAsync(cancellationToken);
await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal));
async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var authorization in authorizations)
{
if (string.Equals(await Store.GetSubjectAsync(authorization, cancellationToken), subject, StringComparison.Ordinal))
{
yield return authorization;
}
}
}
} }
/// <summary> /// <summary>
@ -951,7 +1038,7 @@ namespace OpenIddict.Core
throw new ArgumentNullException(nameof(authorization)); throw new ArgumentNullException(nameof(authorization));
} }
var results = await ValidateAsync(authorization, cancellationToken).ToListAsync(cancellationToken); var results = await GetValidationResultsAsync(authorization, cancellationToken);
if (results.Any(result => result != ValidationResult.Success)) if (results.Any(result => result != ValidationResult.Success))
{ {
var builder = new StringBuilder(); var builder = new StringBuilder();
@ -963,7 +1050,7 @@ namespace OpenIddict.Core
builder.AppendLine(result.ErrorMessage); builder.AppendLine(result.ErrorMessage);
} }
throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); throw new OpenIddictExceptions.ValidationException(builder.ToString(), results);
} }
await Store.UpdateAsync(authorization, cancellationToken); await Store.UpdateAsync(authorization, cancellationToken);
@ -973,6 +1060,19 @@ namespace OpenIddict.Core
await Cache.RemoveAsync(authorization, cancellationToken); await Cache.RemoveAsync(authorization, cancellationToken);
await Cache.AddAsync(authorization, cancellationToken); await Cache.AddAsync(authorization, cancellationToken);
} }
async Task<ImmutableArray<ValidationResult>> GetValidationResultsAsync(
TAuthorization authorization, CancellationToken cancellationToken)
{
var builder = ImmutableArray.CreateBuilder<ValidationResult>();
await foreach (var result in ValidateAsync(authorization, cancellationToken))
{
builder.Add(result);
}
return builder.ToImmutable();
}
} }
/// <summary> /// <summary>
@ -1070,25 +1170,25 @@ namespace OpenIddict.Core
=> DeleteAsync((TAuthorization) authorization, cancellationToken); => DeleteAsync((TAuthorization) authorization, cancellationToken);
IAsyncEnumerable<object> IOpenIddictAuthorizationManager.FindAsync(string subject, string client, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictAuthorizationManager.FindAsync(string subject, string client, CancellationToken cancellationToken)
=> FindAsync(subject, client, cancellationToken).OfType<object>(); => FindAsync(subject, client, cancellationToken);
IAsyncEnumerable<object> IOpenIddictAuthorizationManager.FindAsync(string subject, string client, string status, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictAuthorizationManager.FindAsync(string subject, string client, string status, CancellationToken cancellationToken)
=> FindAsync(subject, client, status, cancellationToken).OfType<object>(); => FindAsync(subject, client, status, cancellationToken);
IAsyncEnumerable<object> IOpenIddictAuthorizationManager.FindAsync(string subject, string client, string status, string type, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictAuthorizationManager.FindAsync(string subject, string client, string status, string type, CancellationToken cancellationToken)
=> FindAsync(subject, client, status, type, cancellationToken).OfType<object>(); => FindAsync(subject, client, status, type, cancellationToken);
IAsyncEnumerable<object> IOpenIddictAuthorizationManager.FindAsync(string subject, string client, string status, string type, ImmutableArray<string> scopes, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictAuthorizationManager.FindAsync(string subject, string client, string status, string type, ImmutableArray<string> scopes, CancellationToken cancellationToken)
=> FindAsync(subject, client, status, type, scopes, cancellationToken).OfType<object>(); => FindAsync(subject, client, status, type, scopes, cancellationToken);
IAsyncEnumerable<object> IOpenIddictAuthorizationManager.FindByApplicationIdAsync(string identifier, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictAuthorizationManager.FindByApplicationIdAsync(string identifier, CancellationToken cancellationToken)
=> FindByApplicationIdAsync(identifier, cancellationToken).OfType<object>(); => FindByApplicationIdAsync(identifier, cancellationToken);
async ValueTask<object> IOpenIddictAuthorizationManager.FindByIdAsync(string identifier, CancellationToken cancellationToken) async ValueTask<object> IOpenIddictAuthorizationManager.FindByIdAsync(string identifier, CancellationToken cancellationToken)
=> await FindByIdAsync(identifier, cancellationToken); => await FindByIdAsync(identifier, cancellationToken);
IAsyncEnumerable<object> IOpenIddictAuthorizationManager.FindBySubjectAsync(string subject, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictAuthorizationManager.FindBySubjectAsync(string subject, CancellationToken cancellationToken)
=> FindBySubjectAsync(subject, cancellationToken).OfType<object>(); => FindBySubjectAsync(subject, cancellationToken);
ValueTask<string> IOpenIddictAuthorizationManager.GetApplicationIdAsync(object authorization, CancellationToken cancellationToken) ValueTask<string> IOpenIddictAuthorizationManager.GetApplicationIdAsync(object authorization, CancellationToken cancellationToken)
=> GetApplicationIdAsync((TAuthorization) authorization, cancellationToken); => GetApplicationIdAsync((TAuthorization) authorization, cancellationToken);
@ -1124,7 +1224,7 @@ namespace OpenIddict.Core
=> HasTypeAsync((TAuthorization) authorization, type, cancellationToken); => HasTypeAsync((TAuthorization) authorization, type, cancellationToken);
IAsyncEnumerable<object> IOpenIddictAuthorizationManager.ListAsync(int? count, int? offset, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictAuthorizationManager.ListAsync(int? count, int? offset, CancellationToken cancellationToken)
=> ListAsync(count, offset, cancellationToken).OfType<object>(); => ListAsync(count, offset, cancellationToken);
IAsyncEnumerable<TResult> IOpenIddictAuthorizationManager.ListAsync<TResult>(Func<IQueryable<object>, IQueryable<TResult>> query, CancellationToken cancellationToken) IAsyncEnumerable<TResult> IOpenIddictAuthorizationManager.ListAsync<TResult>(Func<IQueryable<object>, IQueryable<TResult>> query, CancellationToken cancellationToken)
=> ListAsync(query, cancellationToken); => ListAsync(query, cancellationToken);

99
src/OpenIddict.Core/Managers/OpenIddictScopeManager.cs

@ -106,7 +106,7 @@ namespace OpenIddict.Core
throw new ArgumentNullException(nameof(scope)); throw new ArgumentNullException(nameof(scope));
} }
var results = await ValidateAsync(scope, cancellationToken).ToListAsync(cancellationToken); var results = await GetValidationResultsAsync(scope, cancellationToken);
if (results.Any(result => result != ValidationResult.Success)) if (results.Any(result => result != ValidationResult.Success))
{ {
var builder = new StringBuilder(); var builder = new StringBuilder();
@ -118,7 +118,7 @@ namespace OpenIddict.Core
builder.AppendLine(result.ErrorMessage); builder.AppendLine(result.ErrorMessage);
} }
throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); throw new OpenIddictExceptions.ValidationException(builder.ToString(), results);
} }
await Store.CreateAsync(scope, cancellationToken); await Store.CreateAsync(scope, cancellationToken);
@ -127,6 +127,19 @@ namespace OpenIddict.Core
{ {
await Cache.AddAsync(scope, cancellationToken); await Cache.AddAsync(scope, cancellationToken);
} }
async Task<ImmutableArray<ValidationResult>> GetValidationResultsAsync(
TScope scope, CancellationToken cancellationToken)
{
var builder = ImmutableArray.CreateBuilder<ValidationResult>();
await foreach (var result in ValidateAsync(scope, cancellationToken))
{
builder.Add(result);
}
return builder.ToImmutable();
}
} }
/// <summary> /// <summary>
@ -264,11 +277,6 @@ namespace OpenIddict.Core
public virtual IAsyncEnumerable<TScope> FindByNamesAsync( public virtual IAsyncEnumerable<TScope> FindByNamesAsync(
ImmutableArray<string> names, CancellationToken cancellationToken = default) ImmutableArray<string> names, CancellationToken cancellationToken = default)
{ {
if (names.IsDefaultOrEmpty)
{
return AsyncEnumerable.Empty<TScope>();
}
if (names.Any(name => string.IsNullOrEmpty(name))) if (names.Any(name => string.IsNullOrEmpty(name)))
{ {
throw new ArgumentException("Scope names cannot be null or empty.", nameof(names)); throw new ArgumentException("Scope names cannot be null or empty.", nameof(names));
@ -287,7 +295,18 @@ namespace OpenIddict.Core
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return scopes.WhereAwait(async scope => names.Contains(await Store.GetNameAsync(scope, cancellationToken), StringComparer.Ordinal)); return ExecuteAsync(cancellationToken);
async IAsyncEnumerable<TScope> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var scope in scopes)
{
if (names.Contains(await Store.GetNameAsync(scope, cancellationToken), StringComparer.Ordinal))
{
yield return scope;
}
}
}
} }
/// <summary> /// <summary>
@ -317,8 +336,19 @@ namespace OpenIddict.Core
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return scopes.WhereAwait(async scope => return ExecuteAsync(cancellationToken);
(await Store.GetResourcesAsync(scope, cancellationToken)).Contains(resource, StringComparer.Ordinal));
async IAsyncEnumerable<TScope> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var scope in scopes)
{
var resources = await Store.GetResourcesAsync(scope, cancellationToken);
if (resources.Contains(resource, StringComparer.Ordinal))
{
yield return scope;
}
}
}
} }
/// <summary> /// <summary>
@ -518,29 +548,19 @@ namespace OpenIddict.Core
/// <param name="scopes">The scopes.</param> /// <param name="scopes">The scopes.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> that can be used to abort the operation.</param> /// <param name="cancellationToken">The <see cref="CancellationToken"/> that can be used to abort the operation.</param>
/// <returns>All the resources associated with the specified scopes.</returns> /// <returns>All the resources associated with the specified scopes.</returns>
public virtual IAsyncEnumerable<string> ListResourcesAsync( public virtual async IAsyncEnumerable<string> ListResourcesAsync(
ImmutableArray<string> scopes, CancellationToken cancellationToken = default) ImmutableArray<string> scopes, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{ {
if (scopes.IsDefaultOrEmpty) var resources = new HashSet<string>(StringComparer.Ordinal);
await foreach (var scope in FindByNamesAsync(scopes, cancellationToken))
{ {
return AsyncEnumerable.Empty<string>(); resources.UnionWith(await GetResourcesAsync(scope, cancellationToken));
} }
return ExecuteAsync(cancellationToken); foreach (var resource in resources)
async IAsyncEnumerable<string> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{ {
var resources = new HashSet<string>(StringComparer.Ordinal); yield return resource;
await foreach (var scope in FindByNamesAsync(scopes, cancellationToken))
{
resources.UnionWith(await GetResourcesAsync(scope, cancellationToken));
}
foreach (var resource in resources)
{
yield return resource;
}
} }
} }
@ -617,7 +637,7 @@ namespace OpenIddict.Core
throw new ArgumentNullException(nameof(scope)); throw new ArgumentNullException(nameof(scope));
} }
var results = await ValidateAsync(scope, cancellationToken).ToListAsync(cancellationToken); var results = await GetValidationResultsAsync(scope, cancellationToken);
if (results.Any(result => result != ValidationResult.Success)) if (results.Any(result => result != ValidationResult.Success))
{ {
var builder = new StringBuilder(); var builder = new StringBuilder();
@ -629,7 +649,7 @@ namespace OpenIddict.Core
builder.AppendLine(result.ErrorMessage); builder.AppendLine(result.ErrorMessage);
} }
throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); throw new OpenIddictExceptions.ValidationException(builder.ToString(), results);
} }
await Store.UpdateAsync(scope, cancellationToken); await Store.UpdateAsync(scope, cancellationToken);
@ -639,6 +659,19 @@ namespace OpenIddict.Core
await Cache.RemoveAsync(scope, cancellationToken); await Cache.RemoveAsync(scope, cancellationToken);
await Cache.AddAsync(scope, cancellationToken); await Cache.AddAsync(scope, cancellationToken);
} }
async Task<ImmutableArray<ValidationResult>> GetValidationResultsAsync(
TScope scope, CancellationToken cancellationToken)
{
var builder = ImmutableArray.CreateBuilder<ValidationResult>();
await foreach (var result in ValidateAsync(scope, cancellationToken))
{
builder.Add(result);
}
return builder.ToImmutable();
}
} }
/// <summary> /// <summary>
@ -732,10 +765,10 @@ namespace OpenIddict.Core
=> await FindByNameAsync(name, cancellationToken); => await FindByNameAsync(name, cancellationToken);
IAsyncEnumerable<object> IOpenIddictScopeManager.FindByNamesAsync(ImmutableArray<string> names, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictScopeManager.FindByNamesAsync(ImmutableArray<string> names, CancellationToken cancellationToken)
=> FindByNamesAsync(names, cancellationToken).OfType<object>(); => FindByNamesAsync(names, cancellationToken);
IAsyncEnumerable<object> IOpenIddictScopeManager.FindByResourceAsync(string resource, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictScopeManager.FindByResourceAsync(string resource, CancellationToken cancellationToken)
=> FindByResourceAsync(resource, cancellationToken).OfType<object>(); => FindByResourceAsync(resource, cancellationToken);
ValueTask<TResult> IOpenIddictScopeManager.GetAsync<TResult>(Func<IQueryable<object>, IQueryable<TResult>> query, CancellationToken cancellationToken) ValueTask<TResult> IOpenIddictScopeManager.GetAsync<TResult>(Func<IQueryable<object>, IQueryable<TResult>> query, CancellationToken cancellationToken)
=> GetAsync(query, cancellationToken); => GetAsync(query, cancellationToken);
@ -759,7 +792,7 @@ namespace OpenIddict.Core
=> GetResourcesAsync((TScope) scope, cancellationToken); => GetResourcesAsync((TScope) scope, cancellationToken);
IAsyncEnumerable<object> IOpenIddictScopeManager.ListAsync(int? count, int? offset, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictScopeManager.ListAsync(int? count, int? offset, CancellationToken cancellationToken)
=> ListAsync(count, offset, cancellationToken).OfType<object>(); => ListAsync(count, offset, cancellationToken);
IAsyncEnumerable<TResult> IOpenIddictScopeManager.ListAsync<TResult>(Func<IQueryable<object>, IQueryable<TResult>> query, CancellationToken cancellationToken) IAsyncEnumerable<TResult> IOpenIddictScopeManager.ListAsync<TResult>(Func<IQueryable<object>, IQueryable<TResult>> query, CancellationToken cancellationToken)
=> ListAsync(query, cancellationToken); => ListAsync(query, cancellationToken);

132
src/OpenIddict.Core/Managers/OpenIddictTokenManager.cs

@ -122,7 +122,7 @@ namespace OpenIddict.Core
await Store.SetReferenceIdAsync(token, identifier, cancellationToken); await Store.SetReferenceIdAsync(token, identifier, cancellationToken);
} }
var results = await ValidateAsync(token, cancellationToken).ToListAsync(cancellationToken); var results = await GetValidationResultsAsync(token, cancellationToken);
if (results.Any(result => result != ValidationResult.Success)) if (results.Any(result => result != ValidationResult.Success))
{ {
var builder = new StringBuilder(); var builder = new StringBuilder();
@ -134,7 +134,7 @@ namespace OpenIddict.Core
builder.AppendLine(result.ErrorMessage); builder.AppendLine(result.ErrorMessage);
} }
throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); throw new OpenIddictExceptions.ValidationException(builder.ToString(), results);
} }
await Store.CreateAsync(token, cancellationToken); await Store.CreateAsync(token, cancellationToken);
@ -143,6 +143,19 @@ namespace OpenIddict.Core
{ {
await Cache.AddAsync(token, cancellationToken); await Cache.AddAsync(token, cancellationToken);
} }
async Task<ImmutableArray<ValidationResult>> GetValidationResultsAsync(
TToken token, CancellationToken cancellationToken)
{
var builder = ImmutableArray.CreateBuilder<ValidationResult>();
await foreach (var result in ValidateAsync(token, cancellationToken))
{
builder.Add(result);
}
return builder.ToImmutable();
}
} }
/// <summary> /// <summary>
@ -230,8 +243,18 @@ namespace OpenIddict.Core
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return tokens.WhereAwait(async token => string.Equals(await Store.GetSubjectAsync( return ExecuteAsync(cancellationToken);
token, cancellationToken), subject, StringComparison.Ordinal));
async IAsyncEnumerable<TToken> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var token in tokens)
{
if (string.Equals(await Store.GetSubjectAsync(token, cancellationToken), subject, StringComparison.Ordinal))
{
yield return token;
}
}
}
} }
/// <summary> /// <summary>
@ -274,8 +297,18 @@ namespace OpenIddict.Core
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return tokens.WhereAwait(async token => string.Equals(await Store.GetSubjectAsync( return ExecuteAsync(cancellationToken);
token, cancellationToken), subject, StringComparison.Ordinal));
async IAsyncEnumerable<TToken> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var token in tokens)
{
if (string.Equals(await Store.GetSubjectAsync(token, cancellationToken), subject, StringComparison.Ordinal))
{
yield return token;
}
}
}
} }
/// <summary> /// <summary>
@ -324,8 +357,18 @@ namespace OpenIddict.Core
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return tokens.WhereAwait(async token => string.Equals(await Store.GetSubjectAsync( return ExecuteAsync(cancellationToken);
token, cancellationToken), subject, StringComparison.Ordinal));
async IAsyncEnumerable<TToken> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var token in tokens)
{
if (string.Equals(await Store.GetSubjectAsync(token, cancellationToken), subject, StringComparison.Ordinal))
{
yield return token;
}
}
}
} }
/// <summary> /// <summary>
@ -355,8 +398,18 @@ namespace OpenIddict.Core
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return tokens.WhereAwait(async token => string.Equals(await Store.GetApplicationIdAsync( return ExecuteAsync(cancellationToken);
token, cancellationToken), identifier, StringComparison.Ordinal));
async IAsyncEnumerable<TToken> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var token in tokens)
{
if (string.Equals(await Store.GetApplicationIdAsync(token, cancellationToken), identifier, StringComparison.Ordinal))
{
yield return token;
}
}
}
} }
/// <summary> /// <summary>
@ -386,8 +439,18 @@ namespace OpenIddict.Core
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return tokens.WhereAwait(async token => string.Equals(await Store.GetAuthorizationIdAsync( return ExecuteAsync(cancellationToken);
token, cancellationToken), identifier, StringComparison.Ordinal));
async IAsyncEnumerable<TToken> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var token in tokens)
{
if (string.Equals(await Store.GetAuthorizationIdAsync(token, cancellationToken), identifier, StringComparison.Ordinal))
{
yield return token;
}
}
}
} }
/// <summary> /// <summary>
@ -495,8 +558,18 @@ namespace OpenIddict.Core
// To ensure a case-sensitive comparison is enforced independently of the database/table/query collation // To ensure a case-sensitive comparison is enforced independently of the database/table/query collation
// used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here. // used by the store, a second pass using string.Equals(StringComparison.Ordinal) is manually made here.
return tokens.WhereAwait(async token => string.Equals(await Store.GetSubjectAsync( return ExecuteAsync(cancellationToken);
token, cancellationToken), subject, StringComparison.Ordinal));
async IAsyncEnumerable<TToken> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
await foreach (var token in tokens)
{
if (string.Equals(await Store.GetSubjectAsync(token, cancellationToken), subject, StringComparison.Ordinal))
{
yield return token;
}
}
}
} }
/// <summary> /// <summary>
@ -1168,7 +1241,7 @@ namespace OpenIddict.Core
throw new ArgumentNullException(nameof(token)); throw new ArgumentNullException(nameof(token));
} }
var results = await ValidateAsync(token, cancellationToken).ToListAsync(cancellationToken); var results = await GetValidationResultsAsync(token, cancellationToken);
if (results.Any(result => result != ValidationResult.Success)) if (results.Any(result => result != ValidationResult.Success))
{ {
var builder = new StringBuilder(); var builder = new StringBuilder();
@ -1180,7 +1253,7 @@ namespace OpenIddict.Core
builder.AppendLine(result.ErrorMessage); builder.AppendLine(result.ErrorMessage);
} }
throw new OpenIddictExceptions.ValidationException(builder.ToString(), results.ToImmutableArray()); throw new OpenIddictExceptions.ValidationException(builder.ToString(), results);
} }
await Store.UpdateAsync(token, cancellationToken); await Store.UpdateAsync(token, cancellationToken);
@ -1190,6 +1263,19 @@ namespace OpenIddict.Core
await Cache.RemoveAsync(token, cancellationToken); await Cache.RemoveAsync(token, cancellationToken);
await Cache.AddAsync(token, cancellationToken); await Cache.AddAsync(token, cancellationToken);
} }
async Task<ImmutableArray<ValidationResult>> GetValidationResultsAsync(
TToken token, CancellationToken cancellationToken)
{
var builder = ImmutableArray.CreateBuilder<ValidationResult>();
await foreach (var result in ValidateAsync(token, cancellationToken))
{
builder.Add(result);
}
return builder.ToImmutable();
}
} }
/// <summary> /// <summary>
@ -1318,19 +1404,19 @@ namespace OpenIddict.Core
=> DeleteAsync((TToken) token, cancellationToken); => DeleteAsync((TToken) token, cancellationToken);
IAsyncEnumerable<object> IOpenIddictTokenManager.FindAsync(string subject, string client, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictTokenManager.FindAsync(string subject, string client, CancellationToken cancellationToken)
=> FindAsync(subject, client, cancellationToken).OfType<object>(); => FindAsync(subject, client, cancellationToken);
IAsyncEnumerable<object> IOpenIddictTokenManager.FindAsync(string subject, string client, string status, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictTokenManager.FindAsync(string subject, string client, string status, CancellationToken cancellationToken)
=> FindAsync(subject, client, status, cancellationToken).OfType<object>(); => FindAsync(subject, client, status, cancellationToken);
IAsyncEnumerable<object> IOpenIddictTokenManager.FindAsync(string subject, string client, string status, string type, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictTokenManager.FindAsync(string subject, string client, string status, string type, CancellationToken cancellationToken)
=> FindAsync(subject, client, status, type, cancellationToken).OfType<object>(); => FindAsync(subject, client, status, type, cancellationToken);
IAsyncEnumerable<object> IOpenIddictTokenManager.FindByApplicationIdAsync(string identifier, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictTokenManager.FindByApplicationIdAsync(string identifier, CancellationToken cancellationToken)
=> FindByApplicationIdAsync(identifier, cancellationToken).OfType<object>(); => FindByApplicationIdAsync(identifier, cancellationToken);
IAsyncEnumerable<object> IOpenIddictTokenManager.FindByAuthorizationIdAsync(string identifier, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictTokenManager.FindByAuthorizationIdAsync(string identifier, CancellationToken cancellationToken)
=> FindByAuthorizationIdAsync(identifier, cancellationToken).OfType<object>(); => FindByAuthorizationIdAsync(identifier, cancellationToken);
async ValueTask<object> IOpenIddictTokenManager.FindByIdAsync(string identifier, CancellationToken cancellationToken) async ValueTask<object> IOpenIddictTokenManager.FindByIdAsync(string identifier, CancellationToken cancellationToken)
=> await FindByIdAsync(identifier, cancellationToken); => await FindByIdAsync(identifier, cancellationToken);
@ -1339,7 +1425,7 @@ namespace OpenIddict.Core
=> await FindByReferenceIdAsync(identifier, cancellationToken); => await FindByReferenceIdAsync(identifier, cancellationToken);
IAsyncEnumerable<object> IOpenIddictTokenManager.FindBySubjectAsync(string subject, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictTokenManager.FindBySubjectAsync(string subject, CancellationToken cancellationToken)
=> FindBySubjectAsync(subject, cancellationToken).OfType<object>(); => FindBySubjectAsync(subject, cancellationToken);
ValueTask<string> IOpenIddictTokenManager.GetApplicationIdAsync(object token, CancellationToken cancellationToken) ValueTask<string> IOpenIddictTokenManager.GetApplicationIdAsync(object token, CancellationToken cancellationToken)
=> GetApplicationIdAsync((TToken) token, cancellationToken); => GetApplicationIdAsync((TToken) token, cancellationToken);
@ -1384,7 +1470,7 @@ namespace OpenIddict.Core
=> HasTypeAsync((TToken) token, type, cancellationToken); => HasTypeAsync((TToken) token, type, cancellationToken);
IAsyncEnumerable<object> IOpenIddictTokenManager.ListAsync(int? count, int? offset, CancellationToken cancellationToken) IAsyncEnumerable<object> IOpenIddictTokenManager.ListAsync(int? count, int? offset, CancellationToken cancellationToken)
=> ListAsync(count, offset, cancellationToken).OfType<object>(); => ListAsync(count, offset, cancellationToken);
IAsyncEnumerable<TResult> IOpenIddictTokenManager.ListAsync<TResult>(Func<IQueryable<object>, IQueryable<TResult>> query, CancellationToken cancellationToken) IAsyncEnumerable<TResult> IOpenIddictTokenManager.ListAsync<TResult>(Func<IQueryable<object>, IQueryable<TResult>> query, CancellationToken cancellationToken)
=> ListAsync(query, cancellationToken); => ListAsync(query, cancellationToken);

1
src/OpenIddict.Core/OpenIddict.Core.csproj

@ -18,7 +18,6 @@
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="$(ExtensionsVersion)" /> <PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="$(ExtensionsVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="$(ExtensionsVersion)" /> <PackageReference Include="Microsoft.Extensions.Logging" Version="$(ExtensionsVersion)" />
<PackageReference Include="Microsoft.Extensions.Options" Version="$(ExtensionsVersion)" /> <PackageReference Include="Microsoft.Extensions.Options" Version="$(ExtensionsVersion)" />
<PackageReference Include="System.Linq.Async" Version="$(LinqAsyncVersion)" />
</ItemGroup> </ItemGroup>
<ItemGroup Condition=" '$(TargetFramework)' == 'net472' Or '$(TargetFramework)' == 'netstandard2.0' "> <ItemGroup Condition=" '$(TargetFramework)' == 'net472' Or '$(TargetFramework)' == 'netstandard2.0' ">

45
src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkApplicationStore.cs

@ -12,6 +12,7 @@ using System.Data;
using System.Data.Entity; using System.Data.Entity;
using System.Data.Entity.Infrastructure; using System.Data.Entity.Infrastructure;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using System.Text.Encodings.Web; using System.Text.Encodings.Web;
using System.Text.Json; using System.Text.Json;
@ -303,10 +304,24 @@ namespace OpenIddict.EntityFramework
// are retrieved, a second pass is made to ensure only valid elements are returned. // are retrieved, a second pass is made to ensure only valid elements are returned.
// Implementers that use this method in a hot path may want to override this method // Implementers that use this method in a hot path may want to override this method
// to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient.
return Applications.Where(application => application.PostLogoutRedirectUris.Contains(address))
.AsAsyncEnumerable(cancellationToken) return ExecuteAsync(cancellationToken);
.WhereAwait(async application => (await GetPostLogoutRedirectUrisAsync(application, cancellationToken))
.Contains(address, StringComparer.Ordinal)); async IAsyncEnumerable<TApplication> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
var applications = (from application in Applications
where application.PostLogoutRedirectUris.Contains(address)
select application).AsAsyncEnumerable(cancellationToken);
await foreach (var application in applications)
{
var addresses = await GetPostLogoutRedirectUrisAsync(application, cancellationToken);
if (addresses.Contains(address, StringComparer.Ordinal))
{
yield return application;
}
}
}
} }
/// <summary> /// <summary>
@ -328,10 +343,24 @@ namespace OpenIddict.EntityFramework
// are retrieved, a second pass is made to ensure only valid elements are returned. // are retrieved, a second pass is made to ensure only valid elements are returned.
// Implementers that use this method in a hot path may want to override this method // Implementers that use this method in a hot path may want to override this method
// to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient.
return Applications.Where(application => application.RedirectUris.Contains(address))
.AsAsyncEnumerable(cancellationToken) return ExecuteAsync(cancellationToken);
.WhereAwait(async application => (await GetRedirectUrisAsync(application, cancellationToken))
.Contains(address, StringComparer.Ordinal)); async IAsyncEnumerable<TApplication> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
var applications = (from application in Applications
where application.RedirectUris.Contains(address)
select application).AsAsyncEnumerable(cancellationToken);
await foreach (var application in applications)
{
var addresses = await GetRedirectUrisAsync(application, cancellationToken);
if (addresses.Contains(address, StringComparer.Ordinal))
{
yield return application;
}
}
}
} }
/// <summary> /// <summary>

48
src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkAuthorizationStore.cs

@ -12,6 +12,7 @@ using System.Data;
using System.Data.Entity; using System.Data.Entity;
using System.Data.Entity.Infrastructure; using System.Data.Entity.Infrastructure;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using System.Text.Encodings.Web; using System.Text.Encodings.Web;
using System.Text.Json; using System.Text.Json;
@ -342,9 +343,50 @@ namespace OpenIddict.EntityFramework
[NotNull] string subject, [NotNull] string client, [NotNull] string subject, [NotNull] string client,
[NotNull] string status, [NotNull] string type, [NotNull] string status, [NotNull] string type,
ImmutableArray<string> scopes, CancellationToken cancellationToken) ImmutableArray<string> scopes, CancellationToken cancellationToken)
=> FindAsync(subject, client, status, type, cancellationToken) {
.WhereAwait(async authorization => new HashSet<string>( if (string.IsNullOrEmpty(subject))
await GetScopesAsync(authorization, cancellationToken), StringComparer.Ordinal).IsSupersetOf(scopes)); {
throw new ArgumentException("The subject cannot be null or empty.", nameof(subject));
}
if (string.IsNullOrEmpty(client))
{
throw new ArgumentException("The client identifier cannot be null or empty.", nameof(client));
}
if (string.IsNullOrEmpty(status))
{
throw new ArgumentException("The status cannot be null or empty.", nameof(status));
}
if (string.IsNullOrEmpty(type))
{
throw new ArgumentException("The type cannot be null or empty.", nameof(type));
}
return ExecuteAsync(cancellationToken);
async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
var key = ConvertIdentifierFromString(client);
var authorizations = (from authorization in Authorizations.Include(authorization => authorization.Application)
where authorization.Application != null &&
authorization.Application.Id.Equals(key) &&
authorization.Subject == subject &&
authorization.Status == status &&
authorization.Type == type
select authorization).AsAsyncEnumerable(cancellationToken);
await foreach (var authorization in authorizations)
{
if (new HashSet<string>(await GetScopesAsync(authorization, cancellationToken), StringComparer.Ordinal).IsSupersetOf(scopes))
{
yield return authorization;
}
}
}
}
/// <summary> /// <summary>
/// Retrieves the list of authorizations corresponding to the specified application identifier. /// Retrieves the list of authorizations corresponding to the specified application identifier.

22
src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkScopeStore.cs

@ -11,6 +11,7 @@ using System.ComponentModel;
using System.Data.Entity; using System.Data.Entity;
using System.Data.Entity.Infrastructure; using System.Data.Entity.Infrastructure;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using System.Text.Encodings.Web; using System.Text.Encodings.Web;
using System.Text.Json; using System.Text.Json;
@ -247,9 +248,24 @@ namespace OpenIddict.EntityFramework
// are retrieved, a second pass is made to ensure only valid elements are returned. // are retrieved, a second pass is made to ensure only valid elements are returned.
// Implementers that use this method in a hot path may want to override this method // Implementers that use this method in a hot path may want to override this method
// to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient.
return Scopes.Where(scope => scope.Resources.Contains(resource))
.AsAsyncEnumerable(cancellationToken) return ExecuteAsync(cancellationToken);
.WhereAwait(async scope => (await GetResourcesAsync(scope, cancellationToken)).Contains(resource, StringComparer.Ordinal));
async IAsyncEnumerable<TScope> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
var scopes = (from scope in Scopes
where scope.Resources.Contains(resource)
select scope).AsAsyncEnumerable(cancellationToken);
await foreach (var scope in scopes)
{
var resources = await GetResourcesAsync(scope, cancellationToken);
if (resources.Contains(resource, StringComparer.Ordinal))
{
yield return scope;
}
}
}
} }
/// <summary> /// <summary>

45
src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreApplicationStore.cs

@ -10,6 +10,7 @@ using System.Collections.Immutable;
using System.ComponentModel; using System.ComponentModel;
using System.Data; using System.Data;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using System.Text.Encodings.Web; using System.Text.Encodings.Web;
using System.Text.Json; using System.Text.Json;
@ -347,12 +348,24 @@ namespace OpenIddict.EntityFrameworkCore
// are retrieved, a second pass is made to ensure only valid elements are returned. // are retrieved, a second pass is made to ensure only valid elements are returned.
// Implementers that use this method in a hot path may want to override this method // Implementers that use this method in a hot path may want to override this method
// to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient.
var applications = (from application in Applications.AsTracking()
where application.PostLogoutRedirectUris.Contains(address)
select application).AsAsyncEnumerable();
return applications.WhereAwait(async application => return ExecuteAsync(cancellationToken);
(await GetPostLogoutRedirectUrisAsync(application, cancellationToken)).Contains(address, StringComparer.Ordinal));
async IAsyncEnumerable<TApplication> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
var applications = (from application in Applications.AsTracking()
where application.PostLogoutRedirectUris.Contains(address)
select application).AsAsyncEnumerable();
await foreach (var application in applications)
{
var addresses = await GetPostLogoutRedirectUrisAsync(application, cancellationToken);
if (addresses.Contains(address, StringComparer.Ordinal))
{
yield return application;
}
}
}
} }
/// <summary> /// <summary>
@ -374,12 +387,24 @@ namespace OpenIddict.EntityFrameworkCore
// are retrieved, a second pass is made to ensure only valid elements are returned. // are retrieved, a second pass is made to ensure only valid elements are returned.
// Implementers that use this method in a hot path may want to override this method // Implementers that use this method in a hot path may want to override this method
// to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient.
var applications = (from application in Applications.AsTracking()
where application.RedirectUris.Contains(address)
select application).AsAsyncEnumerable();
return applications.WhereAwait(async application => return ExecuteAsync(cancellationToken);
(await GetRedirectUrisAsync(application, cancellationToken)).Contains(address, StringComparer.Ordinal));
async IAsyncEnumerable<TApplication> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
var applications = (from application in Applications.AsTracking()
where application.RedirectUris.Contains(address)
select application).AsAsyncEnumerable();
await foreach (var application in applications)
{
var addresses = await GetRedirectUrisAsync(application, cancellationToken);
if (addresses.Contains(address, StringComparer.Ordinal))
{
yield return application;
}
}
}
} }
/// <summary> /// <summary>

53
src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreAuthorizationStore.cs

@ -10,6 +10,7 @@ using System.Collections.Immutable;
using System.ComponentModel; using System.ComponentModel;
using System.Data; using System.Data;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using System.Text.Encodings.Web; using System.Text.Encodings.Web;
using System.Text.Json; using System.Text.Json;
@ -393,9 +394,55 @@ namespace OpenIddict.EntityFrameworkCore
[NotNull] string subject, [NotNull] string client, [NotNull] string subject, [NotNull] string client,
[NotNull] string status, [NotNull] string type, [NotNull] string status, [NotNull] string type,
ImmutableArray<string> scopes, CancellationToken cancellationToken) ImmutableArray<string> scopes, CancellationToken cancellationToken)
=> FindAsync(subject, client, status, type, cancellationToken) {
.WhereAwait(async authorization => new HashSet<string>( if (string.IsNullOrEmpty(subject))
await GetScopesAsync(authorization, cancellationToken), StringComparer.Ordinal).IsSupersetOf(scopes)); {
throw new ArgumentException("The subject cannot be null or empty.", nameof(subject));
}
if (string.IsNullOrEmpty(client))
{
throw new ArgumentException("The client identifier cannot be null or empty.", nameof(client));
}
if (string.IsNullOrEmpty(status))
{
throw new ArgumentException("The status cannot be null or empty.", nameof(status));
}
if (string.IsNullOrEmpty(type))
{
throw new ArgumentException("The type cannot be null or empty.", nameof(type));
}
return ExecuteAsync(cancellationToken);
async IAsyncEnumerable<TAuthorization> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
// Note: due to a bug in Entity Framework Core's query visitor, the authorizations can't be
// filtered using authorization.Application.Id.Equals(key). To work around this issue,
// this method is overriden to use an explicit join before applying the equality check.
// See https://github.com/openiddict/openiddict-core/issues/499 for more information.
var key = ConvertIdentifierFromString(client);
var authorizations = (from authorization in Authorizations.Include(authorization => authorization.Application).AsTracking()
where authorization.Subject == subject &&
authorization.Status == status &&
authorization.Type == type
join application in Applications.AsTracking() on authorization.Application.Id equals application.Id
where application.Id.Equals(key)
select authorization).AsAsyncEnumerable();
await foreach (var authorization in authorizations)
{
if (new HashSet<string>(await GetScopesAsync(authorization, cancellationToken), StringComparer.Ordinal).IsSupersetOf(scopes))
{
yield return authorization;
}
}
}
}
/// <summary> /// <summary>
/// Retrieves the list of authorizations corresponding to the specified application identifier. /// Retrieves the list of authorizations corresponding to the specified application identifier.

28
src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreScopeStore.cs

@ -9,6 +9,7 @@ using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.ComponentModel; using System.ComponentModel;
using System.Linq; using System.Linq;
using System.Runtime.CompilerServices;
using System.Text; using System.Text;
using System.Text.Encodings.Web; using System.Text.Encodings.Web;
using System.Text.Json; using System.Text.Json;
@ -258,12 +259,29 @@ namespace OpenIddict.EntityFrameworkCore
throw new ArgumentException("The resource cannot be null or empty.", nameof(resource)); throw new ArgumentException("The resource cannot be null or empty.", nameof(resource));
} }
var scopes = (from scope in Scopes.AsTracking() // To optimize the efficiency of the query a bit, only scopes whose stringified
where scope.Resources.Contains(resource) // Resources column contains the specified resource are returned. Once the scopes
select scope).AsAsyncEnumerable(); // are retrieved, a second pass is made to ensure only valid elements are returned.
// Implementers that use this method in a hot path may want to override this method
// to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient.
return scopes.WhereAwait(async scope => return ExecuteAsync(cancellationToken);
(await GetResourcesAsync(scope, cancellationToken)).Contains(resource, StringComparer.Ordinal));
async IAsyncEnumerable<TScope> ExecuteAsync([EnumeratorCancellation] CancellationToken cancellationToken)
{
var scopes = (from scope in Scopes.AsTracking()
where scope.Resources.Contains(resource)
select scope).AsAsyncEnumerable();
await foreach (var scope in scopes)
{
var resources = await GetResourcesAsync(scope, cancellationToken);
if (resources.Contains(resource, StringComparer.Ordinal))
{
yield return scope;
}
}
}
} }
/// <summary> /// <summary>

1
test/OpenIddict.Server.IntegrationTests/OpenIddict.Server.IntegrationTests.csproj

@ -18,6 +18,7 @@
<PackageReference Include="AngleSharp" Version="$(AngleSharpVersion)" /> <PackageReference Include="AngleSharp" Version="$(AngleSharpVersion)" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="$(ExtensionsVersion)" /> <PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="$(ExtensionsVersion)" />
<PackageReference Include="Moq" Version="$(MoqVersion)" /> <PackageReference Include="Moq" Version="$(MoqVersion)" />
<PackageReference Include="System.Linq.Async" Version="$(LinqAsyncVersion)" />
</ItemGroup> </ItemGroup>
<ItemGroup Condition=" '$(TargetFramework)' == 'net472' "> <ItemGroup Condition=" '$(TargetFramework)' == 'net472' ">

Loading…
Cancel
Save