From 866c61f399acb01f5b1df370a9a94d1f736b8947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Chalet?= Date: Fri, 23 Feb 2018 23:52:38 +0100 Subject: [PATCH] Update the Entity Framework Core stores to be compatible with QueryTrackingBehavior.NoTracking --- .../Stores/OpenIddictAuthorizationStore.cs | 23 ++++++---- .../Stores/OpenIddictTokenStore.cs | 42 ++++++++++++------- .../Stores/OpenIddictApplicationStore.cs | 1 + .../Stores/OpenIddictAuthorizationStore.cs | 5 ++- .../Stores/OpenIddictTokenStore.cs | 6 ++- .../Stores/OpenIddictApplicationStore.cs | 13 +++--- .../Stores/OpenIddictAuthorizationStore.cs | 33 ++++++++------- .../Stores/OpenIddictScopeStore.cs | 4 +- .../Stores/OpenIddictTokenStore.cs | 22 ++++++---- 9 files changed, 88 insertions(+), 61 deletions(-) diff --git a/src/OpenIddict.Core/Stores/OpenIddictAuthorizationStore.cs b/src/OpenIddict.Core/Stores/OpenIddictAuthorizationStore.cs index 97dcc6cd..9b9a92f5 100644 --- a/src/OpenIddict.Core/Stores/OpenIddictAuthorizationStore.cs +++ b/src/OpenIddict.Core/Stores/OpenIddictAuthorizationStore.cs @@ -281,7 +281,7 @@ namespace OpenIddict.Core /// A that can be used to monitor the asynchronous operation, /// whose result returns the application identifier associated with the authorization. /// - public virtual async ValueTask GetApplicationIdAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) + public virtual ValueTask GetApplicationIdAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) { if (authorization == null) { @@ -290,17 +290,22 @@ namespace OpenIddict.Core if (authorization.Application != null) { - return ConvertIdentifierToString(authorization.Application.Id); + return new ValueTask(ConvertIdentifierToString(authorization.Application.Id)); } - IQueryable Query(IQueryable authorizations, TKey key) - => from element in authorizations - where element.Id.Equals(key) && - element.Application != null - select element.Application.Id; + async Task RetrieveApplicationIdAsync() + { + IQueryable Query(IQueryable authorizations, TKey key) + => from element in authorizations + where element.Id.Equals(key) && + element.Application != null + select element.Application.Id; + + return ConvertIdentifierToString(await GetAsync( + (authorizations, key) => Query(authorizations, key), authorization.Id, cancellationToken)); + } - return ConvertIdentifierToString(await GetAsync( - (authorizations, key) => Query(authorizations, key), authorization.Id, cancellationToken)); + return new ValueTask(RetrieveApplicationIdAsync()); } /// diff --git a/src/OpenIddict.Core/Stores/OpenIddictTokenStore.cs b/src/OpenIddict.Core/Stores/OpenIddictTokenStore.cs index 9ef06ac0..3334fb16 100644 --- a/src/OpenIddict.Core/Stores/OpenIddictTokenStore.cs +++ b/src/OpenIddict.Core/Stores/OpenIddictTokenStore.cs @@ -239,7 +239,7 @@ namespace OpenIddict.Core /// A that can be used to monitor the asynchronous operation, /// whose result returns the application identifier associated with the token. /// - public virtual async ValueTask GetApplicationIdAsync([NotNull] TToken token, CancellationToken cancellationToken) + public virtual ValueTask GetApplicationIdAsync([NotNull] TToken token, CancellationToken cancellationToken) { if (token == null) { @@ -248,16 +248,21 @@ namespace OpenIddict.Core if (token.Application != null) { - return ConvertIdentifierToString(token.Application.Id); + return new ValueTask(ConvertIdentifierToString(token.Application.Id)); } - IQueryable Query(IQueryable tokens, TKey key) - => from element in tokens - where element.Id.Equals(key) && - element.Application != null - select element.Application.Id; + async Task RetrieveApplicationIdAsync() + { + IQueryable Query(IQueryable tokens, TKey key) + => from element in tokens + where element.Id.Equals(key) && + element.Application != null + select element.Application.Id; + + return ConvertIdentifierToString(await GetAsync((tokens, key) => Query(tokens, key), token.Id, cancellationToken)); + } - return ConvertIdentifierToString(await GetAsync((tokens, key) => Query(tokens, key), token.Id, cancellationToken)); + return new ValueTask(RetrieveApplicationIdAsync()); } /// @@ -269,7 +274,7 @@ namespace OpenIddict.Core /// A that can be used to monitor the asynchronous operation, /// whose result returns the authorization identifier associated with the token. /// - public virtual async ValueTask GetAuthorizationIdAsync([NotNull] TToken token, CancellationToken cancellationToken) + public virtual ValueTask GetAuthorizationIdAsync([NotNull] TToken token, CancellationToken cancellationToken) { if (token == null) { @@ -278,16 +283,21 @@ namespace OpenIddict.Core if (token.Authorization != null) { - return ConvertIdentifierToString(token.Authorization.Id); + return new ValueTask(ConvertIdentifierToString(token.Authorization.Id)); } - IQueryable Query(IQueryable tokens, TKey key) - => from element in tokens - where element.Id.Equals(key) && - element.Authorization != null - select element.Authorization.Id; + async Task RetrieveAuthorizationIdAsync() + { + IQueryable Query(IQueryable tokens, TKey key) + => from element in tokens + where element.Id.Equals(key) && + element.Authorization != null + select element.Authorization.Id; + + return ConvertIdentifierToString(await GetAsync((tokens, key) => Query(tokens, key), token.Id, cancellationToken)); + } - return ConvertIdentifierToString(await GetAsync((tokens, key) => Query(tokens, key), token.Id, cancellationToken)); + return new ValueTask(RetrieveAuthorizationIdAsync()); } /// diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs index 6a1dc4f3..f2cbeb4a 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs @@ -168,6 +168,7 @@ namespace OpenIddict.EntityFramework Task> ListTokensAsync() => (from token in Tokens + where token.Authorization == null where token.Application.Id.Equals(application.Id) select token).ToListAsync(cancellationToken); diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs index 00939ab9..9f614ba7 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs @@ -167,7 +167,7 @@ namespace OpenIddict.EntityFramework where token.Authorization.Id.Equals(authorization.Id) select token).ToListAsync(cancellationToken); - // Remove all the tokens associated with the application. + // Remove all the tokens associated with the authorization. foreach (var token in await ListTokensAsync()) { Tokens.Remove(token); @@ -383,7 +383,8 @@ namespace OpenIddict.EntityFramework /// /// A that can be used to monitor the asynchronous operation. /// - public override async Task SetApplicationIdAsync([NotNull] TAuthorization authorization, [CanBeNull] string identifier, CancellationToken cancellationToken) + public override async Task SetApplicationIdAsync([NotNull] TAuthorization authorization, + [CanBeNull] string identifier, CancellationToken cancellationToken) { if (authorization == null) { diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictTokenStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictTokenStore.cs index c401c05b..31fcb1c9 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictTokenStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictTokenStore.cs @@ -407,7 +407,8 @@ namespace OpenIddict.EntityFramework /// /// A that can be used to monitor the asynchronous operation. /// - public override async Task SetApplicationIdAsync([NotNull] TToken token, [CanBeNull] string identifier, CancellationToken cancellationToken) + public override async Task SetApplicationIdAsync([NotNull] TToken token, + [CanBeNull] string identifier, CancellationToken cancellationToken) { if (token == null) { @@ -452,7 +453,8 @@ namespace OpenIddict.EntityFramework /// /// A that can be used to monitor the asynchronous operation. /// - public override async Task SetAuthorizationIdAsync([NotNull] TToken token, [CanBeNull] string identifier, CancellationToken cancellationToken) + public override async Task SetAuthorizationIdAsync([NotNull] TToken token, + [CanBeNull] string identifier, CancellationToken cancellationToken) { if (token == null) { diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs index fafc3ff1..7aa4db13 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs @@ -167,8 +167,8 @@ namespace OpenIddict.EntityFrameworkCore // See https://github.com/openiddict/openiddict-core/issues/499 for more information. Task> ListAuthorizationsAsync() - => (from authorization in Authorizations.Include(authorization => authorization.Tokens) - join element in Applications on authorization.Application.Id equals element.Id + => (from authorization in Authorizations.Include(authorization => authorization.Tokens).AsTracking() + join element in Applications.AsTracking() on authorization.Application.Id equals element.Id where element.Id.Equals(application.Id) select authorization).ToListAsync(cancellationToken); @@ -178,8 +178,9 @@ namespace OpenIddict.EntityFrameworkCore // See https://github.com/openiddict/openiddict-core/issues/499 for more information. Task> ListTokensAsync() - => (from token in Tokens - join element in Applications on token.Application.Id equals element.Id + => (from token in Tokens.AsTracking() + where token.Authorization == null + join element in Applications.AsTracking() on token.Application.Id equals element.Id where element.Id.Equals(application.Id) select token).ToListAsync(cancellationToken); @@ -246,7 +247,7 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentNullException(nameof(query)); } - return query(Applications, state).FirstOrDefaultAsync(cancellationToken); + return query(Applications.AsTracking(), state).FirstOrDefaultAsync(cancellationToken); } /// @@ -270,7 +271,7 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentNullException(nameof(query)); } - return ImmutableArray.CreateRange(await query(Applications, state).ToListAsync(cancellationToken)); + return ImmutableArray.CreateRange(await query(Applications.AsTracking(), state).ToListAsync(cancellationToken)); } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs index 654eef1f..850555e1 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs @@ -170,12 +170,12 @@ namespace OpenIddict.EntityFrameworkCore // See https://github.com/openiddict/openiddict-core/issues/499 for more information. Task> ListTokensAsync() - => (from token in Tokens - join element in Authorizations on token.Authorization.Id equals element.Id + => (from token in Tokens.AsTracking() + join element in Authorizations.AsTracking() on token.Authorization.Id equals element.Id where element.Id.Equals(authorization.Id) select token).ToListAsync(cancellationToken); - // Remove all the tokens associated with the application. + // Remove all the tokens associated with the authorization. foreach (var token in await ListTokensAsync()) { Context.Remove(token); @@ -217,9 +217,9 @@ namespace OpenIddict.EntityFrameworkCore IQueryable Query(IQueryable authorizations, IQueryable applications, TKey key, string principal) - => from authorization in authorizations.Include(authorization => authorization.Application) + => from authorization in authorizations.Include(authorization => authorization.Application).AsTracking() where authorization.Subject == principal - join application in applications on authorization.Application.Id equals application.Id + join application in applications.AsTracking() on authorization.Application.Id equals application.Id where application.Id.Equals(key) select authorization; @@ -264,10 +264,9 @@ namespace OpenIddict.EntityFrameworkCore IQueryable Query(IQueryable authorizations, IQueryable applications, TKey key, string principal, string state) - => from authorization in authorizations.Include(authorization => authorization.Application) - where authorization.Subject == principal && - authorization.Status == state - join application in applications on authorization.Application.Id equals application.Id + => from authorization in authorizations.Include(authorization => authorization.Application).AsTracking() + where authorization.Subject == principal && authorization.Status == state + join application in applications.AsTracking() on authorization.Application.Id equals application.Id where application.Id.Equals(key) select authorization; @@ -318,11 +317,11 @@ namespace OpenIddict.EntityFrameworkCore IQueryable Query(IQueryable authorizations, IQueryable applications, TKey key, string principal, string state, string kind) - => from authorization in authorizations.Include(authorization => authorization.Application) + => from authorization in authorizations.Include(authorization => authorization.Application).AsTracking() where authorization.Subject == principal && authorization.Status == state && authorization.Type == kind - join application in applications on authorization.Application.Id equals application.Id + join application in applications.AsTracking() on authorization.Application.Id equals application.Id where application.Id.Equals(key) select authorization; @@ -416,7 +415,9 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentNullException(nameof(query)); } - return query(Authorizations.Include(authorization => authorization.Application), state).FirstOrDefaultAsync(cancellationToken); + return query( + Authorizations.Include(authorization => authorization.Application) + .AsTracking(), state).FirstOrDefaultAsync(cancellationToken); } /// @@ -441,7 +442,8 @@ namespace OpenIddict.EntityFrameworkCore } return ImmutableArray.CreateRange(await query( - Authorizations.Include(authorization => authorization.Application), state).ToListAsync(cancellationToken)); + Authorizations.Include(authorization => authorization.Application) + .AsTracking(), state).ToListAsync(cancellationToken)); } /// @@ -460,7 +462,7 @@ namespace OpenIddict.EntityFrameworkCore IList exceptions = null; IQueryable Query(IQueryable authorizations, int offset) - => (from authorization in authorizations.Include(authorization => authorization.Tokens) + => (from authorization in authorizations.Include(authorization => authorization.Tokens).AsTracking() where authorization.Status != OpenIddictConstants.Statuses.Valid || (authorization.Type == OpenIddictConstants.AuthorizationTypes.AdHoc && !authorization.Tokens.Any(token => token.Status == OpenIddictConstants.Statuses.Valid)) @@ -545,7 +547,8 @@ namespace OpenIddict.EntityFrameworkCore /// /// A that can be used to monitor the asynchronous operation. /// - public override async Task SetApplicationIdAsync([NotNull] TAuthorization authorization, [CanBeNull] string identifier, CancellationToken cancellationToken) + public override async Task SetApplicationIdAsync([NotNull] TAuthorization authorization, + [CanBeNull] string identifier, CancellationToken cancellationToken) { if (authorization == null) { diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictScopeStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictScopeStore.cs index 018f1b34..26750a04 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictScopeStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictScopeStore.cs @@ -186,7 +186,7 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentNullException(nameof(query)); } - return query(Scopes, state).FirstOrDefaultAsync(cancellationToken); + return query(Scopes.AsTracking(), state).FirstOrDefaultAsync(cancellationToken); } /// @@ -210,7 +210,7 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentNullException(nameof(query)); } - return ImmutableArray.CreateRange(await query(Scopes, state).ToListAsync(cancellationToken)); + return ImmutableArray.CreateRange(await query(Scopes.AsTracking(), state).ToListAsync(cancellationToken)); } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictTokenStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictTokenStore.cs index ff047cc4..46f106e9 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictTokenStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictTokenStore.cs @@ -189,8 +189,8 @@ namespace OpenIddict.EntityFrameworkCore // See https://github.com/openiddict/openiddict-core/issues/499 for more information. IQueryable Query(IQueryable applications, IQueryable tokens, TKey key) - => from token in tokens.Include(token => token.Application).Include(token => token.Authorization) - join application in applications on token.Application.Id equals application.Id + => from token in tokens.Include(token => token.Application).Include(token => token.Authorization).AsTracking() + join application in applications.AsTracking() on token.Application.Id equals application.Id where application.Id.Equals(key) select token; @@ -220,8 +220,8 @@ namespace OpenIddict.EntityFrameworkCore // See https://github.com/openiddict/openiddict-core/issues/499 for more information. IQueryable Query(IQueryable authorizations, IQueryable tokens, TKey key) - => from token in tokens.Include(token => token.Application).Include(token => token.Authorization) - join authorization in authorizations on token.Authorization.Id equals authorization.Id + => from token in tokens.Include(token => token.Application).Include(token => token.Authorization).AsTracking() + join authorization in authorizations.AsTracking() on token.Authorization.Id equals authorization.Id where authorization.Id.Equals(key) select token; @@ -317,7 +317,8 @@ namespace OpenIddict.EntityFrameworkCore return query( Tokens.Include(token => token.Application) - .Include(token => token.Authorization), state).FirstOrDefaultAsync(cancellationToken); + .Include(token => token.Authorization) + .AsTracking(), state).FirstOrDefaultAsync(cancellationToken); } /// @@ -379,7 +380,8 @@ namespace OpenIddict.EntityFrameworkCore return ImmutableArray.CreateRange(await query( Tokens.Include(token => token.Application) - .Include(token => token.Authorization), state).ToListAsync(cancellationToken)); + .Include(token => token.Authorization) + .AsTracking(), state).ToListAsync(cancellationToken)); } /// @@ -398,7 +400,7 @@ namespace OpenIddict.EntityFrameworkCore IList exceptions = null; IQueryable Query(IQueryable tokens, int offset) - => (from token in tokens + => (from token in tokens.AsTracking() where token.ExpirationDate < DateTimeOffset.UtcNow || token.Status != OpenIddictConstants.Statuses.Valid orderby token.Id @@ -481,7 +483,8 @@ namespace OpenIddict.EntityFrameworkCore /// /// A that can be used to monitor the asynchronous operation. /// - public override async Task SetApplicationIdAsync([NotNull] TToken token, [CanBeNull] string identifier, CancellationToken cancellationToken) + public override async Task SetApplicationIdAsync([NotNull] TToken token, + [CanBeNull] string identifier, CancellationToken cancellationToken) { if (token == null) { @@ -526,7 +529,8 @@ namespace OpenIddict.EntityFrameworkCore /// /// A that can be used to monitor the asynchronous operation. /// - public override async Task SetAuthorizationIdAsync([NotNull] TToken token, [CanBeNull] string identifier, CancellationToken cancellationToken) + public override async Task SetAuthorizationIdAsync([NotNull] TToken token, + [CanBeNull] string identifier, CancellationToken cancellationToken) { if (token == null) {