From b0371ae00e003b46b73d420e2f2d0ed96c394632 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Chalet?= Date: Sat, 3 Feb 2024 18:19:04 +0100 Subject: [PATCH] Update the EF 6/EF Core stores to retrieve the entities from the change tracker when available --- ...enIddictEntityFrameworkApplicationStore.cs | 34 ++++++++++++++----- ...IddictEntityFrameworkAuthorizationStore.cs | 16 ++++++--- .../OpenIddictEntityFrameworkScopeStore.cs | 32 ++++++++++++----- .../OpenIddictEntityFrameworkTokenStore.cs | 32 ++++++++++++----- ...dictEntityFrameworkCoreApplicationStore.cs | 33 +++++++++++++----- ...ctEntityFrameworkCoreAuthorizationStore.cs | 16 ++++++--- ...OpenIddictEntityFrameworkCoreScopeStore.cs | 32 ++++++++++++----- ...OpenIddictEntityFrameworkCoreTokenStore.cs | 32 ++++++++++++----- 8 files changed, 169 insertions(+), 58 deletions(-) diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkApplicationStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkApplicationStore.cs index 801394c3..f33a741c 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkApplicationStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkApplicationStore.cs @@ -210,31 +210,47 @@ public class OpenIddictEntityFrameworkApplicationStore - public virtual async ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) + public virtual ValueTask FindByClientIdAsync(string identifier, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(identifier)) { throw new ArgumentException(SR.GetResourceString(SR.ID0195), nameof(identifier)); } - var key = ConvertIdentifierFromString(identifier); + return GetTrackedEntity() is TApplication application ? new(application) : new(QueryAsync()); + + TApplication? GetTrackedEntity() => + (from entry in Context.ChangeTracker.Entries() + where string.Equals(entry.Entity.ClientId, identifier, StringComparison.Ordinal) + select entry.Entity).FirstOrDefault(); - return await (from application in Applications - where application.Id!.Equals(key) - select application).FirstOrDefaultAsync(cancellationToken); + Task QueryAsync() => + (from application in Applications + where application.ClientId == identifier + select application).FirstOrDefaultAsync(cancellationToken); } /// - public virtual async ValueTask FindByClientIdAsync(string identifier, CancellationToken cancellationToken) + public virtual ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(identifier)) { throw new ArgumentException(SR.GetResourceString(SR.ID0195), nameof(identifier)); } - return await (from application in Applications - where application.ClientId == identifier - select application).FirstOrDefaultAsync(cancellationToken); + var key = ConvertIdentifierFromString(identifier); + + return GetTrackedEntity() is TApplication application ? new(application) : new(QueryAsync()); + + TApplication? GetTrackedEntity() => + (from entry in Context.ChangeTracker.Entries() + where entry.Entity.Id is TKey identifier && identifier.Equals(key) + select entry.Entity).FirstOrDefault(); + + Task QueryAsync() => + (from application in Applications + where application.Id!.Equals(key) + select application).FirstOrDefaultAsync(cancellationToken); } /// diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkAuthorizationStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkAuthorizationStore.cs index 771c1ae9..b92b1597 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkAuthorizationStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkAuthorizationStore.cs @@ -336,7 +336,7 @@ public class OpenIddictEntityFrameworkAuthorizationStore - public virtual async ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) + public virtual ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(identifier)) { @@ -345,9 +345,17 @@ public class OpenIddictEntityFrameworkAuthorizationStore authorization.Application) - where authorization.Id!.Equals(key) - select authorization).FirstOrDefaultAsync(cancellationToken); + return GetTrackedEntity() is TAuthorization authorization ? new(authorization) : new(QueryAsync()); + + TAuthorization? GetTrackedEntity() => + (from entry in Context.ChangeTracker.Entries() + where entry.Entity.Id is TKey identifier && identifier.Equals(key) + select entry.Entity).FirstOrDefault(); + + Task QueryAsync() => + (from authorization in Authorizations.Include(authorization => authorization.Application) + where authorization.Id!.Equals(key) + select authorization).FirstOrDefaultAsync(cancellationToken); } /// diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkScopeStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkScopeStore.cs index b468f092..dc1bcbc2 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkScopeStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkScopeStore.cs @@ -130,7 +130,7 @@ public class OpenIddictEntityFrameworkScopeStore : IOpen } /// - public virtual async ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) + public virtual ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(identifier)) { @@ -139,22 +139,38 @@ public class OpenIddictEntityFrameworkScopeStore : IOpen var key = ConvertIdentifierFromString(identifier); - return await (from scope in Scopes - where scope.Id!.Equals(key) - select scope).FirstOrDefaultAsync(cancellationToken); + return GetTrackedEntity() is TScope scope ? new(scope) : new(QueryAsync()); + + TScope? GetTrackedEntity() => + (from entry in Context.ChangeTracker.Entries() + where entry.Entity.Id is TKey identifier && identifier.Equals(key) + select entry.Entity).FirstOrDefault(); + + Task QueryAsync() => + (from scope in Scopes + where scope.Id!.Equals(key) + select scope).FirstOrDefaultAsync(cancellationToken); } /// - public virtual async ValueTask FindByNameAsync(string name, CancellationToken cancellationToken) + public virtual ValueTask FindByNameAsync(string name, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(name)) { throw new ArgumentException(SR.GetResourceString(SR.ID0202), nameof(name)); } - return await (from scope in Scopes - where scope.Name == name - select scope).FirstOrDefaultAsync(cancellationToken); + return GetTrackedEntity() is TScope scope ? new(scope) : new(QueryAsync()); + + TScope? GetTrackedEntity() => + (from entry in Context.ChangeTracker.Entries() + where string.Equals(entry.Entity.Name, name, StringComparison.Ordinal) + select entry.Entity).FirstOrDefault(); + + Task QueryAsync() => + (from scope in Scopes + where scope.Name == name + select scope).FirstOrDefaultAsync(cancellationToken); } /// diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkTokenStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkTokenStore.cs index d48e6ff8..4eb211b1 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkTokenStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkTokenStore.cs @@ -262,7 +262,7 @@ public class OpenIddictEntityFrameworkTokenStore - public virtual async ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) + public virtual ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(identifier)) { @@ -271,22 +271,38 @@ public class OpenIddictEntityFrameworkTokenStore token.Application).Include(token => token.Authorization) - where token.Id!.Equals(key) - select token).FirstOrDefaultAsync(cancellationToken); + return GetTrackedEntity() is TToken token ? new(token) : new(QueryAsync()); + + TToken? GetTrackedEntity() => + (from entry in Context.ChangeTracker.Entries() + where entry.Entity.Id is TKey identifier && identifier.Equals(key) + select entry.Entity).FirstOrDefault(); + + Task QueryAsync() => + (from token in Tokens.Include(token => token.Application).Include(token => token.Authorization) + where token.Id!.Equals(key) + select token).FirstOrDefaultAsync(cancellationToken); } /// - public virtual async ValueTask FindByReferenceIdAsync(string identifier, CancellationToken cancellationToken) + public virtual ValueTask FindByReferenceIdAsync(string identifier, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(identifier)) { throw new ArgumentException(SR.GetResourceString(SR.ID0195), nameof(identifier)); } - return await (from token in Tokens.Include(token => token.Application).Include(token => token.Authorization) - where token.ReferenceId == identifier - select token).FirstOrDefaultAsync(cancellationToken); + return GetTrackedEntity() is TToken token ? new(token) : new(QueryAsync()); + + TToken? GetTrackedEntity() => + (from entry in Context.ChangeTracker.Entries() + where string.Equals(entry.Entity.ReferenceId, identifier, StringComparison.Ordinal) + select entry.Entity).FirstOrDefault(); + + Task QueryAsync() => + (from token in Tokens.Include(token => token.Application).Include(token => token.Authorization) + where token.ReferenceId == identifier + select token).FirstOrDefaultAsync(cancellationToken); } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreApplicationStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreApplicationStore.cs index 02fa0de1..7ed453eb 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreApplicationStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreApplicationStore.cs @@ -9,7 +9,6 @@ using System.ComponentModel; using System.Data; using System.Diagnostics.CodeAnalysis; using System.Globalization; -using System.Net; using System.Runtime.CompilerServices; using System.Text; using System.Text.Encodings.Web; @@ -280,20 +279,28 @@ public class OpenIddictEntityFrameworkCoreApplicationStore - public virtual async ValueTask FindByClientIdAsync(string identifier, CancellationToken cancellationToken) + public virtual ValueTask FindByClientIdAsync(string identifier, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(identifier)) { throw new ArgumentException(SR.GetResourceString(SR.ID0195), nameof(identifier)); } - return await (from application in Applications.AsTracking() - where application.ClientId == identifier - select application).FirstOrDefaultAsync(cancellationToken); + return GetTrackedEntity() is TApplication application ? new(application) : new(QueryAsync()); + + TApplication? GetTrackedEntity() => + (from entry in Context.ChangeTracker.Entries() + where string.Equals(entry.Entity.ClientId, identifier, StringComparison.Ordinal) + select entry.Entity).FirstOrDefault(); + + Task QueryAsync() => + (from application in Applications.AsTracking() + where application.ClientId == identifier + select application).FirstOrDefaultAsync(cancellationToken); } /// - public virtual async ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) + public virtual ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(identifier)) { @@ -302,9 +309,17 @@ public class OpenIddictEntityFrameworkCoreApplicationStore + (from entry in Context.ChangeTracker.Entries() + where entry.Entity.Id is TKey identifier && identifier.Equals(key) + select entry.Entity).FirstOrDefault(); + + Task QueryAsync() => + (from application in Applications.AsTracking() + where application.Id!.Equals(key) + select application).FirstOrDefaultAsync(cancellationToken); } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreAuthorizationStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreAuthorizationStore.cs index bf09aa67..5a854423 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreAuthorizationStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreAuthorizationStore.cs @@ -423,7 +423,7 @@ public class OpenIddictEntityFrameworkCoreAuthorizationStore - public virtual async ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) + public virtual ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(identifier)) { @@ -432,9 +432,17 @@ public class OpenIddictEntityFrameworkCoreAuthorizationStore authorization.Application).AsTracking() - where authorization.Id!.Equals(key) - select authorization).FirstOrDefaultAsync(cancellationToken); + return GetTrackedEntity() is TAuthorization authorization ? new(authorization) : new(QueryAsync()); + + TAuthorization? GetTrackedEntity() => + (from entry in Context.ChangeTracker.Entries() + where entry.Entity.Id is TKey identifier && identifier.Equals(key) + select entry.Entity).FirstOrDefault(); + + Task QueryAsync() => + (from authorization in Authorizations.Include(authorization => authorization.Application).AsTracking() + where authorization.Id!.Equals(key) + select authorization).FirstOrDefaultAsync(cancellationToken); } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreScopeStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreScopeStore.cs index 326f59b5..6436df02 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreScopeStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreScopeStore.cs @@ -146,7 +146,7 @@ public class OpenIddictEntityFrameworkCoreScopeStore : I } /// - public virtual async ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) + public virtual ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(identifier)) { @@ -155,22 +155,38 @@ public class OpenIddictEntityFrameworkCoreScopeStore : I var key = ConvertIdentifierFromString(identifier); - return await (from scope in Scopes.AsTracking() - where scope.Id!.Equals(key) - select scope).FirstOrDefaultAsync(cancellationToken); + return GetTrackedEntity() is TScope scope ? new(scope) : new(QueryAsync()); + + TScope? GetTrackedEntity() => + (from entry in Context.ChangeTracker.Entries() + where entry.Entity.Id is TKey identifier && identifier.Equals(key) + select entry.Entity).FirstOrDefault(); + + Task QueryAsync() => + (from scope in Scopes.AsTracking() + where scope.Id!.Equals(key) + select scope).FirstOrDefaultAsync(cancellationToken); } /// - public virtual async ValueTask FindByNameAsync(string name, CancellationToken cancellationToken) + public virtual ValueTask FindByNameAsync(string name, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(name)) { throw new ArgumentException(SR.GetResourceString(SR.ID0202), nameof(name)); } - return await (from scope in Scopes.AsTracking() - where scope.Name == name - select scope).FirstOrDefaultAsync(cancellationToken); + return GetTrackedEntity() is TScope scope ? new(scope) : new(QueryAsync()); + + TScope? GetTrackedEntity() => + (from entry in Context.ChangeTracker.Entries() + where string.Equals(entry.Entity.Name, name, StringComparison.Ordinal) + select entry.Entity).FirstOrDefault(); + + Task QueryAsync() => + (from scope in Scopes.AsTracking() + where scope.Name == name + select scope).FirstOrDefaultAsync(cancellationToken); } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreTokenStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreTokenStore.cs index 4700ba07..defa48ff 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreTokenStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreTokenStore.cs @@ -311,7 +311,7 @@ public class OpenIddictEntityFrameworkCoreTokenStore - public virtual async ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) + public virtual ValueTask FindByIdAsync(string identifier, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(identifier)) { @@ -320,22 +320,38 @@ public class OpenIddictEntityFrameworkCoreTokenStore token.Application).Include(token => token.Authorization).AsTracking() - where token.Id!.Equals(key) - select token).FirstOrDefaultAsync(cancellationToken); + return GetTrackedEntity() is TToken token ? new(token) : new(QueryAsync()); + + TToken? GetTrackedEntity() => + (from entry in Context.ChangeTracker.Entries() + where entry.Entity.Id is TKey identifier && identifier.Equals(key) + select entry.Entity).FirstOrDefault(); + + Task QueryAsync() => + (from token in Tokens.Include(token => token.Application).Include(token => token.Authorization).AsTracking() + where token.Id!.Equals(key) + select token).FirstOrDefaultAsync(cancellationToken); } /// - public virtual async ValueTask FindByReferenceIdAsync(string identifier, CancellationToken cancellationToken) + public virtual ValueTask FindByReferenceIdAsync(string identifier, CancellationToken cancellationToken) { if (string.IsNullOrEmpty(identifier)) { throw new ArgumentException(SR.GetResourceString(SR.ID0195), nameof(identifier)); } - return await (from token in Tokens.Include(token => token.Application).Include(token => token.Authorization).AsTracking() - where token.ReferenceId == identifier - select token).FirstOrDefaultAsync(cancellationToken); + return GetTrackedEntity() is TToken token ? new(token) : new(QueryAsync()); + + TToken? GetTrackedEntity() => + (from entry in Context.ChangeTracker.Entries() + where string.Equals(entry.Entity.ReferenceId, identifier, StringComparison.Ordinal) + select entry.Entity).FirstOrDefault(); + + Task QueryAsync() => + (from token in Tokens.Include(token => token.Application).Include(token => token.Authorization).AsTracking() + where token.ReferenceId == identifier + select token).FirstOrDefaultAsync(cancellationToken); } ///