diff --git a/src/OpenIddict.Core/Stores/OpenIddictApplicationStore.cs b/src/OpenIddict.Core/Stores/OpenIddictApplicationStore.cs index 9b6a88a1..dca463fc 100644 --- a/src/OpenIddict.Core/Stores/OpenIddictApplicationStore.cs +++ b/src/OpenIddict.Core/Stores/OpenIddictApplicationStore.cs @@ -389,15 +389,17 @@ namespace OpenIddict.Core // To mitigate that, the resulting array is stored in the memory cache. var key = string.Concat(nameof(GetPermissionsAsync), "\x1e", application.Permissions); - var permissions = Cache.Get(key) as ImmutableArray?; - if (permissions == null) + var permissions = Cache.GetOrCreate(key, entry => { - permissions = Cache.Set(key, JArray.Parse(application.Permissions) + entry.SetPriority(CacheItemPriority.High) + .SetSlidingExpiration(TimeSpan.FromMinutes(1)); + + return JArray.Parse(application.Permissions) .Select(element => (string) element) - .ToImmutableArray()); - } + .ToImmutableArray(); + }); - return new ValueTask>(permissions.GetValueOrDefault()); + return new ValueTask>(permissions); } /// @@ -425,15 +427,17 @@ namespace OpenIddict.Core // To mitigate that, the resulting array is stored in the memory cache. var key = string.Concat(nameof(GetPostLogoutRedirectUrisAsync), "\x1e", application.PostLogoutRedirectUris); - var addresses = Cache.Get(key) as ImmutableArray?; - if (addresses == null) + var addresses = Cache.GetOrCreate(key, entry => { - addresses = Cache.Set(key, JArray.Parse(application.PostLogoutRedirectUris) + entry.SetPriority(CacheItemPriority.High) + .SetSlidingExpiration(TimeSpan.FromMinutes(1)); + + return JArray.Parse(application.PostLogoutRedirectUris) .Select(element => (string) element) - .ToImmutableArray()); - } + .ToImmutableArray(); + }); - return new ValueTask>(addresses.GetValueOrDefault()); + return new ValueTask>(addresses); } /// @@ -485,15 +489,17 @@ namespace OpenIddict.Core // To mitigate that, the resulting array is stored in the memory cache. var key = string.Concat(nameof(GetRedirectUrisAsync), "\x1e", application.RedirectUris); - var addresses = Cache.Get(key) as ImmutableArray?; - if (addresses == null) + var addresses = Cache.GetOrCreate(key, entry => { - addresses = Cache.Set(key, JArray.Parse(application.RedirectUris) + entry.SetPriority(CacheItemPriority.High) + .SetSlidingExpiration(TimeSpan.FromMinutes(1)); + + return JArray.Parse(application.RedirectUris) .Select(element => (string) element) - .ToImmutableArray()); - } + .ToImmutableArray(); + }); - return new ValueTask>(addresses.GetValueOrDefault()); + return new ValueTask>(addresses); } /// diff --git a/src/OpenIddict.Core/Stores/OpenIddictScopeStore.cs b/src/OpenIddict.Core/Stores/OpenIddictScopeStore.cs index b7686484..728b698d 100644 --- a/src/OpenIddict.Core/Stores/OpenIddictScopeStore.cs +++ b/src/OpenIddict.Core/Stores/OpenIddictScopeStore.cs @@ -302,15 +302,17 @@ namespace OpenIddict.Core // To mitigate that, the resulting array is stored in the memory cache. var key = string.Concat(nameof(GetResourcesAsync), "\x1e", scope.Resources); - var resources = Cache.Get(key) as ImmutableArray?; - if (resources == null) + var resources = Cache.GetOrCreate(key, entry => { - resources = Cache.Set(key, JArray.Parse(scope.Resources) + entry.SetPriority(CacheItemPriority.High) + .SetSlidingExpiration(TimeSpan.FromMinutes(1)); + + return JArray.Parse(scope.Resources) .Select(element => (string) element) - .ToImmutableArray()); - } + .ToImmutableArray(); + }); - return new ValueTask>(resources.GetValueOrDefault()); + return new ValueTask>(resources); } /// diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs index 46ef278d..e5d1717b 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs @@ -29,9 +29,7 @@ namespace OpenIddict.EntityFramework OpenIddictToken, TContext, string> where TContext : DbContext { - public OpenIddictApplicationStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictApplicationStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -49,9 +47,7 @@ namespace OpenIddict.EntityFramework where TContext : DbContext where TKey : IEquatable { - public OpenIddictApplicationStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictApplicationStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -74,9 +70,7 @@ namespace OpenIddict.EntityFramework where TContext : DbContext where TKey : IEquatable { - public OpenIddictApplicationStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictApplicationStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(cache) { if (context == null) @@ -216,25 +210,6 @@ namespace OpenIddict.EntityFramework } } - /// - /// Retrieves an application using its unique identifier. - /// - /// The unique identifier associated with the application. - /// The that can be used to abort the operation. - /// - /// A that can be used to monitor the asynchronous operation, - /// whose result returns the client application corresponding to the identifier. - /// - public override Task FindByIdAsync([NotNull] string identifier, CancellationToken cancellationToken) - { - if (string.IsNullOrEmpty(identifier)) - { - throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); - } - - return Applications.FindAsync(cancellationToken, ConvertIdentifierFromString(identifier)); - } - /// /// Executes the specified query and returns the first element. /// diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs index 0fca5218..e646c9ea 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs @@ -29,9 +29,7 @@ namespace OpenIddict.EntityFramework OpenIddictToken, TContext, string> where TContext : DbContext { - public OpenIddictAuthorizationStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictAuthorizationStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -49,9 +47,7 @@ namespace OpenIddict.EntityFramework where TContext : DbContext where TKey : IEquatable { - public OpenIddictAuthorizationStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictAuthorizationStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -74,9 +70,7 @@ namespace OpenIddict.EntityFramework where TContext : DbContext where TKey : IEquatable { - public OpenIddictAuthorizationStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictAuthorizationStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(cache) { if (context == null) @@ -198,35 +192,6 @@ namespace OpenIddict.EntityFramework } } - /// - /// Retrieves an authorization using its unique identifier. - /// - /// The unique identifier associated with the authorization. - /// The that can be used to abort the operation. - /// - /// A that can be used to monitor the asynchronous operation, - /// whose result returns the authorization corresponding to the identifier. - /// - public override Task FindByIdAsync([NotNull] string identifier, CancellationToken cancellationToken) - { - if (string.IsNullOrEmpty(identifier)) - { - throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); - } - - var authorization = (from entry in Context.ChangeTracker.Entries() - where entry.Entity != null - where entry.Entity.Id.Equals(ConvertIdentifierFromString(identifier)) - select entry.Entity).FirstOrDefault(); - - if (authorization != null) - { - return Task.FromResult(authorization); - } - - return base.FindByIdAsync(identifier, cancellationToken); - } - /// /// Retrieves the optional application identifier associated with an authorization. /// diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictScopeStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictScopeStore.cs index f7459fbd..20aaab33 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictScopeStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictScopeStore.cs @@ -43,9 +43,7 @@ namespace OpenIddict.EntityFramework where TContext : DbContext where TKey : IEquatable { - public OpenIddictScopeStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictScopeStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -63,9 +61,7 @@ namespace OpenIddict.EntityFramework where TContext : DbContext where TKey : IEquatable { - public OpenIddictScopeStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictScopeStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(cache) { if (context == null) @@ -146,25 +142,6 @@ namespace OpenIddict.EntityFramework return Context.SaveChangesAsync(cancellationToken); } - /// - /// Retrieves a scope using its unique identifier. - /// - /// The unique identifier associated with the scope. - /// The that can be used to abort the operation. - /// - /// A that can be used to monitor the asynchronous operation, - /// whose result returns the scope corresponding to the identifier. - /// - public override Task FindByIdAsync([NotNull] string identifier, CancellationToken cancellationToken) - { - if (string.IsNullOrEmpty(identifier)) - { - throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); - } - - return Scopes.FindAsync(cancellationToken, ConvertIdentifierFromString(identifier)); - } - /// /// Executes the specified query and returns the first element. /// diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictTokenStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictTokenStore.cs index 31fcb1c9..eb457884 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictTokenStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictTokenStore.cs @@ -29,9 +29,7 @@ namespace OpenIddict.EntityFramework OpenIddictAuthorization, TContext, string> where TContext : DbContext { - public OpenIddictTokenStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictTokenStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -49,9 +47,7 @@ namespace OpenIddict.EntityFramework where TContext : DbContext where TKey : IEquatable { - public OpenIddictTokenStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictTokenStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -74,9 +70,7 @@ namespace OpenIddict.EntityFramework where TContext : DbContext where TKey : IEquatable { - public OpenIddictTokenStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictTokenStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(cache) { if (context == null) @@ -165,35 +159,6 @@ namespace OpenIddict.EntityFramework return Context.SaveChangesAsync(cancellationToken); } - /// - /// Retrieves a token using its unique identifier. - /// - /// The unique identifier associated with the token. - /// The that can be used to abort the operation. - /// - /// A that can be used to monitor the asynchronous operation, - /// whose result returns the token corresponding to the unique identifier. - /// - public override Task FindByIdAsync([NotNull] string identifier, CancellationToken cancellationToken) - { - if (string.IsNullOrEmpty(identifier)) - { - throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); - } - - var token = (from entry in Context.ChangeTracker.Entries() - where entry.Entity != null - where entry.Entity.Id.Equals(ConvertIdentifierFromString(identifier)) - select entry.Entity).FirstOrDefault(); - - if (token != null) - { - return Task.FromResult(token); - } - - return base.FindByIdAsync(identifier, cancellationToken); - } - /// /// Retrieves the optional application identifier associated with a token. /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs index 8cc655e4..ee794471 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs @@ -9,11 +9,13 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Data; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.Extensions.Caching.Memory; using OpenIddict.Core; @@ -31,9 +33,7 @@ namespace OpenIddict.EntityFrameworkCore OpenIddictToken, TContext, string> where TContext : DbContext { - public OpenIddictApplicationStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictApplicationStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -51,9 +51,7 @@ namespace OpenIddict.EntityFrameworkCore where TContext : DbContext where TKey : IEquatable { - public OpenIddictApplicationStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictApplicationStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -76,9 +74,7 @@ namespace OpenIddict.EntityFrameworkCore where TContext : DbContext where TKey : IEquatable { - public OpenIddictApplicationStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictApplicationStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(cache) { if (context == null) @@ -240,6 +236,37 @@ namespace OpenIddict.EntityFrameworkCore } } + /// + /// Retrieves an application using its client identifier. + /// + /// The client identifier associated with the application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the client application corresponding to the identifier. + /// + public override Task FindByClientIdAsync([NotNull] string identifier, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(identifier)) + { + throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); + } + + const string key = nameof(FindByClientIdAsync) + "\x1e" + nameof(identifier); + + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, string id) => + (from application in context.Set().AsTracking() + where application.ClientId == id + select application).FirstOrDefault()); + }); + + return query(Context, identifier); + } + /// /// Retrieves an application using its unique identifier. /// @@ -256,7 +283,125 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); } - return Applications.FindAsync(new object[] { ConvertIdentifierFromString(identifier) }, cancellationToken); + const string key = nameof(FindByIdAsync) + "\x1e" + nameof(identifier); + + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, TKey id) => + (from application in context.Set().AsTracking() + where application.Id.Equals(id) + select application).FirstOrDefault()); + }); + + return query(Context, ConvertIdentifierFromString(identifier)); + } + + /// + /// Retrieves all the applications associated with the specified post_logout_redirect_uri. + /// + /// The post_logout_redirect_uri associated with the applications. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, whose result + /// returns the client applications corresponding to the specified post_logout_redirect_uri. + /// + public override async Task> FindByPostLogoutRedirectUriAsync([NotNull] string address, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(address)) + { + throw new ArgumentException("The address cannot be null or empty.", nameof(address)); + } + + const string key = nameof(FindByPostLogoutRedirectUriAsync) + "\x1e" + nameof(address); + + // To optimize the efficiency of the query a bit, only applications whose stringified + // PostLogoutRedirectUris contains the specified URL are returned. Once the applications + // are retrieved, a second pass is made to ensure only valid elements are returned. + // Implementers that use this method in a hot path may want to override this method + // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, string uri) => + from application in context.Set().AsTracking() + where application.PostLogoutRedirectUris.Contains(uri) + select application); + }); + + var builder = ImmutableArray.CreateBuilder(); + + foreach (var application in await query(Context, address).ToListAsync(cancellationToken)) + { + foreach (var uri in await GetPostLogoutRedirectUrisAsync(application, cancellationToken)) + { + // Note: the post_logout_redirect_uri must be compared + // using case-sensitive "Simple String Comparison". + if (string.Equals(uri, address, StringComparison.Ordinal)) + { + builder.Add(application); + + break; + } + } + } + + return builder.ToImmutable(); + } + + /// + /// Retrieves all the applications associated with the specified redirect_uri. + /// + /// The redirect_uri associated with the applications. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, whose result + /// returns the client applications corresponding to the specified redirect_uri. + /// + public override async Task> FindByRedirectUriAsync([NotNull] string address, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(address)) + { + throw new ArgumentException("The address cannot be null or empty.", nameof(address)); + } + + const string key = nameof(FindByRedirectUriAsync) + "\x1e" + nameof(address); + + // To optimize the efficiency of the query a bit, only applications whose stringified + // RedirectUris property contains the specified URL are returned. Once the applications + // are retrieved, a second pass is made to ensure only valid elements are returned. + // Implementers that use this method in a hot path may want to override this method + // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, string uri) => + from application in context.Set().AsTracking() + where application.RedirectUris.Contains(uri) + select application); + }); + + var builder = ImmutableArray.CreateBuilder(); + + foreach (var application in await query(Context, address).ToListAsync(cancellationToken)) + { + foreach (var uri in await GetRedirectUrisAsync(application, cancellationToken)) + { + // Note: the redirect_uri must be compared using case-sensitive "Simple String Comparison". + // See http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest for more information. + if (string.Equals(uri, address, StringComparison.Ordinal)) + { + builder.Add(application); + + break; + } + } + } + + return builder.ToImmutable(); } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs index 5b0ed0ed..7bb3e628 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs @@ -9,11 +9,13 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Data; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.Extensions.Caching.Memory; using OpenIddict.Core; @@ -31,9 +33,7 @@ namespace OpenIddict.EntityFrameworkCore OpenIddictToken, TContext, string> where TContext : DbContext { - public OpenIddictAuthorizationStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictAuthorizationStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -51,9 +51,7 @@ namespace OpenIddict.EntityFrameworkCore where TContext : DbContext where TKey : IEquatable { - public OpenIddictAuthorizationStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictAuthorizationStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -76,9 +74,7 @@ namespace OpenIddict.EntityFrameworkCore where TContext : DbContext where TKey : IEquatable { - public OpenIddictAuthorizationStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictAuthorizationStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(cache) { if (context == null) @@ -240,21 +236,28 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentException("The client cannot be null or empty.", nameof(client)); } + const string key = nameof(FindAsync) + "\x1e" + nameof(subject) + "\x1e" + nameof(client); + // Note: due to a bug in Entity Framework Core's query visitor, the authorizations can't be // filtered using authorization.Application.Id.Equals(key). To work around this issue, // this method is overriden to use an explicit join before applying the equality check. // See https://github.com/openiddict/openiddict-core/issues/499 for more information. - - IQueryable Query(IQueryable authorizations, - IQueryable applications, TKey key, string principal) - => from authorization in authorizations.Include(authorization => authorization.Application).AsTracking() - where authorization.Subject == principal - join application in applications.AsTracking() on authorization.Application.Id equals application.Id - where application.Id.Equals(key) - select authorization; - - return ImmutableArray.CreateRange(await Query( - Authorizations, Applications, ConvertIdentifierFromString(client), subject).ToListAsync(cancellationToken)); + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, TKey id, string principal) => + from authorization in context.Set() + .Include(authorization => authorization.Application) + .AsTracking() + where authorization.Subject == principal + join application in context.Set().AsTracking() on authorization.Application.Id equals application.Id + where application.Id.Equals(id) + select authorization); + }); + + return ImmutableArray.CreateRange(await query(Context, + ConvertIdentifierFromString(client), subject).ToListAsync(cancellationToken)); } /// @@ -287,21 +290,28 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentException("The status cannot be null or empty.", nameof(status)); } + const string key = nameof(FindAsync) + "\x1e" + nameof(subject) + "\x1e" + nameof(client) + "\x1e" + nameof(status); + // Note: due to a bug in Entity Framework Core's query visitor, the authorizations can't be // filtered using authorization.Application.Id.Equals(key). To work around this issue, // this method is overriden to use an explicit join before applying the equality check. // See https://github.com/openiddict/openiddict-core/issues/499 for more information. - - IQueryable Query(IQueryable authorizations, - IQueryable applications, TKey key, string principal, string state) - => from authorization in authorizations.Include(authorization => authorization.Application).AsTracking() - where authorization.Subject == principal && authorization.Status == state - join application in applications.AsTracking() on authorization.Application.Id equals application.Id - where application.Id.Equals(key) - select authorization; - - return ImmutableArray.CreateRange(await Query( - Authorizations, Applications, ConvertIdentifierFromString(client), subject, status).ToListAsync(cancellationToken)); + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, TKey id, string principal, string state) => + from authorization in context.Set() + .Include(authorization => authorization.Application) + .AsTracking() + where authorization.Subject == principal && authorization.Status == state + join application in context.Set().AsTracking() on authorization.Application.Id equals application.Id + where application.Id.Equals(id) + select authorization); + }); + + return ImmutableArray.CreateRange(await query(Context, + ConvertIdentifierFromString(client), subject, status).ToListAsync(cancellationToken)); } /// @@ -340,23 +350,31 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentException("The type cannot be null or empty.", nameof(type)); } + const string key = nameof(FindAsync) + "\x1e" + nameof(subject) + "\x1e" + + nameof(client) + "\x1e" + nameof(status) + "\x1e" + nameof(type); + // Note: due to a bug in Entity Framework Core's query visitor, the authorizations can't be // filtered using authorization.Application.Id.Equals(key). To work around this issue, // this method is overriden to use an explicit join before applying the equality check. // See https://github.com/openiddict/openiddict-core/issues/499 for more information. - - IQueryable Query(IQueryable authorizations, - IQueryable applications, TKey key, string principal, string state, string kind) - => from authorization in authorizations.Include(authorization => authorization.Application).AsTracking() - where authorization.Subject == principal && - authorization.Status == state && - authorization.Type == kind - join application in applications.AsTracking() on authorization.Application.Id equals application.Id - where application.Id.Equals(key) - select authorization; - - return ImmutableArray.CreateRange(await Query( - Authorizations, Applications, ConvertIdentifierFromString(client), subject, status, type).ToListAsync(cancellationToken)); + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, TKey id, string principal, string state, string kind) => + from authorization in context.Set() + .Include(authorization => authorization.Application) + .AsTracking() + where authorization.Subject == principal && + authorization.Status == state && + authorization.Type == kind + join application in context.Set().AsTracking() on authorization.Application.Id equals application.Id + where application.Id.Equals(id) + select authorization); + }); + + return ImmutableArray.CreateRange(await query(Context, + ConvertIdentifierFromString(client), subject, status, type).ToListAsync(cancellationToken)); } /// @@ -375,17 +393,55 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); } - var authorization = (from entry in Context.ChangeTracker.Entries() - where entry.Entity != null && - entry.Entity.Id.Equals(ConvertIdentifierFromString(identifier)) - select entry.Entity).FirstOrDefault(); + const string key = nameof(FindByIdAsync) + "\x1e" + nameof(identifier); + + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, TKey id) => + (from authorization in context.Set() + .Include(authorization => authorization.Application) + .AsTracking() + where authorization.Id.Equals(id) + select authorization).FirstOrDefault()); + }); + + return query(Context, ConvertIdentifierFromString(identifier)); + } - if (authorization != null) + /// + /// Retrieves all the authorizations corresponding to the specified subject. + /// + /// The subject associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the authorizations corresponding to the specified subject. + /// + public override async Task> FindBySubjectAsync( + [NotNull] string subject, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(subject)) { - return Task.FromResult(authorization); + throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); } - return base.FindByIdAsync(identifier, cancellationToken); + const string key = nameof(FindBySubjectAsync) + "\x1e" + nameof(subject); + + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, string principal) => + from authorization in context.Set() + .Include(authorization => authorization.Application) + .AsTracking() + where authorization.Subject == principal + select authorization); + }); + + return ImmutableArray.CreateRange(await query(Context, subject).ToListAsync(cancellationToken)); } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictScopeStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictScopeStore.cs index 26750a04..c9155347 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictScopeStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictScopeStore.cs @@ -11,8 +11,8 @@ using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Query; using Microsoft.Extensions.Caching.Memory; -using OpenIddict.Core; using OpenIddict.Models; namespace OpenIddict.EntityFrameworkCore @@ -25,9 +25,7 @@ namespace OpenIddict.EntityFrameworkCore public class OpenIddictScopeStore : OpenIddictScopeStore where TContext : DbContext { - public OpenIddictScopeStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictScopeStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -43,9 +41,7 @@ namespace OpenIddict.EntityFrameworkCore where TContext : DbContext where TKey : IEquatable { - public OpenIddictScopeStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictScopeStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -63,9 +59,7 @@ namespace OpenIddict.EntityFrameworkCore where TContext : DbContext where TKey : IEquatable { - public OpenIddictScopeStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictScopeStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(cache) { if (context == null) @@ -162,7 +156,82 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); } - return Scopes.FindAsync(new object[] { ConvertIdentifierFromString(identifier) }, cancellationToken); + const string key = nameof(FindByIdAsync) + "\x1e" + nameof(identifier); + + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, TKey id) => + (from scope in context.Set().AsTracking() + where scope.Id.Equals(id) + select scope).FirstOrDefault()); + }); + + return query(Context, ConvertIdentifierFromString(identifier)); + } + + /// + /// Retrieves a scope using its name. + /// + /// The name associated with the scope. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the scope corresponding to the specified name. + /// + public override Task FindByNameAsync([NotNull] string name, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(name)) + { + throw new ArgumentException("The scope name cannot be null or empty.", nameof(name)); + } + + const string key = nameof(FindByNameAsync) + "\x1e" + nameof(name); + + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, string id) => + (from scope in context.Set().AsTracking() + where scope.Name == id + select scope).FirstOrDefault()); + }); + + return query(Context, name); + } + + /// + /// Retrieves a list of scopes using their name. + /// + /// The names associated with the scopes. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the scopes corresponding to the specified names. + /// + public override async Task> FindByNamesAsync( + ImmutableArray names, CancellationToken cancellationToken) + { + if (names.Any(name => string.IsNullOrEmpty(name))) + { + throw new ArgumentException("Scope names cannot be null or empty.", nameof(names)); + } + + const string key = nameof(FindByNamesAsync) + "\x1e" + nameof(names); + + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, ImmutableArray ids) => + from scope in context.Set().AsTracking() + where ids.Contains(scope.Name) + select scope); + }); + + return ImmutableArray.CreateRange(await query(Context, names).ToListAsync(cancellationToken)); } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictTokenStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictTokenStore.cs index 46f106e9..34bcc3e9 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictTokenStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictTokenStore.cs @@ -9,11 +9,13 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Data; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.Extensions.Caching.Memory; using OpenIddict.Core; @@ -31,9 +33,7 @@ namespace OpenIddict.EntityFrameworkCore OpenIddictAuthorization, TContext, string> where TContext : DbContext { - public OpenIddictTokenStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictTokenStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -51,9 +51,7 @@ namespace OpenIddict.EntityFrameworkCore where TContext : DbContext where TKey : IEquatable { - public OpenIddictTokenStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictTokenStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(context, cache) { } @@ -76,9 +74,7 @@ namespace OpenIddict.EntityFrameworkCore where TContext : DbContext where TKey : IEquatable { - public OpenIddictTokenStore( - [NotNull] TContext context, - [NotNull] IMemoryCache cache) + public OpenIddictTokenStore([NotNull] TContext context, [NotNull] IMemoryCache cache) : base(cache) { if (context == null) @@ -183,19 +179,28 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); } + const string key = nameof(FindByApplicationIdAsync) + "\x1e" + nameof(identifier); + // Note: due to a bug in Entity Framework Core's query visitor, the tokens can't be // filtered using token.Application.Id.Equals(key). To work around this issue, // this method is overriden to use an explicit join before applying the equality check. // See https://github.com/openiddict/openiddict-core/issues/499 for more information. - - IQueryable Query(IQueryable applications, IQueryable tokens, TKey key) - => from token in tokens.Include(token => token.Application).Include(token => token.Authorization).AsTracking() - join application in applications.AsTracking() on token.Application.Id equals application.Id - where application.Id.Equals(key) - select token; - - return ImmutableArray.CreateRange(await Query( - Applications, Tokens, ConvertIdentifierFromString(identifier)).ToListAsync(cancellationToken)); + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, TKey id) => + from token in context.Set() + .Include(token => token.Application) + .Include(token => token.Authorization) + .AsTracking() + join application in context.Set().AsTracking() on token.Application.Id equals application.Id + where application.Id.Equals(id) + select token); + }); + + return ImmutableArray.CreateRange(await query(Context, + ConvertIdentifierFromString(identifier)).ToListAsync(cancellationToken)); } /// @@ -214,19 +219,28 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); } + const string key = nameof(FindByAuthorizationIdAsync) + "\x1e" + nameof(identifier); + // Note: due to a bug in Entity Framework Core's query visitor, the tokens can't be // filtered using token.Authorization.Id.Equals(key). To work around this issue, // this method is overriden to use an explicit join before applying the equality check. // See https://github.com/openiddict/openiddict-core/issues/499 for more information. - - IQueryable Query(IQueryable authorizations, IQueryable tokens, TKey key) - => from token in tokens.Include(token => token.Application).Include(token => token.Authorization).AsTracking() - join authorization in authorizations.AsTracking() on token.Authorization.Id equals authorization.Id - where authorization.Id.Equals(key) - select token; - - return ImmutableArray.CreateRange(await Query( - Authorizations, Tokens, ConvertIdentifierFromString(identifier)).ToListAsync(cancellationToken)); + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, TKey id) => + from token in context.Set() + .Include(token => token.Application) + .Include(token => token.Authorization) + .AsTracking() + join authorization in context.Set().AsTracking() on token.Authorization.Id equals authorization.Id + where authorization.Id.Equals(id) + select token); + }); + + return ImmutableArray.CreateRange(await query(Context, + ConvertIdentifierFromString(identifier)).ToListAsync(cancellationToken)); } /// @@ -245,17 +259,91 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); } - var token = (from entry in Context.ChangeTracker.Entries() - where entry.Entity != null && - entry.Entity.Id.Equals(ConvertIdentifierFromString(identifier)) - select entry.Entity).FirstOrDefault(); + const string key = nameof(FindByIdAsync) + "\x1e" + nameof(identifier); + + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, TKey id) => + (from token in context.Set() + .Include(token => token.Application) + .Include(token => token.Authorization) + .AsTracking() + where token.Id.Equals(id) + select token).FirstOrDefault()); + }); + + return query(Context, ConvertIdentifierFromString(identifier)); + } + + /// + /// Retrieves the list of tokens corresponding to the specified reference identifier. + /// Note: the reference identifier may be hashed or encrypted for security reasons. + /// + /// The reference identifier associated with the tokens. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the tokens corresponding to the specified reference identifier. + /// + public override Task FindByReferenceIdAsync([NotNull] string identifier, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(identifier)) + { + throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); + } + + const string key = nameof(FindByReferenceIdAsync) + "\x1e" + nameof(identifier); + + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, string id) => + (from token in context.Set() + .Include(token => token.Application) + .Include(token => token.Authorization) + .AsTracking() + where token.ReferenceId == id + select token).FirstOrDefault()); + }); + + return query(Context, identifier); + } - if (token != null) + /// + /// Retrieves the list of tokens corresponding to the specified subject. + /// + /// The subject associated with the tokens. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the tokens corresponding to the specified subject. + /// + public override async Task> FindBySubjectAsync([NotNull] string subject, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(subject)) { - return Task.FromResult(token); + throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); } - return base.FindByIdAsync(identifier, cancellationToken); + const string key = nameof(FindBySubjectAsync) + "\x1e" + nameof(subject); + + var query = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.NeverRemove); + + return EF.CompileAsyncQuery((TContext context, string principal) => + from token in context.Set() + .Include(token => token.Application) + .Include(token => token.Authorization) + .AsTracking() + where token.Subject == principal + select token); + }); + + return ImmutableArray.CreateRange(await query(Context, subject).ToListAsync(cancellationToken)); } ///