From 866c61f399acb01f5b1df370a9a94d1f736b8947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Chalet?= Date: Fri, 23 Feb 2018 23:52:38 +0100 Subject: [PATCH 1/2] Update the Entity Framework Core stores to be compatible with QueryTrackingBehavior.NoTracking --- .../Stores/OpenIddictAuthorizationStore.cs | 23 ++++++---- .../Stores/OpenIddictTokenStore.cs | 42 ++++++++++++------- .../Stores/OpenIddictApplicationStore.cs | 1 + .../Stores/OpenIddictAuthorizationStore.cs | 5 ++- .../Stores/OpenIddictTokenStore.cs | 6 ++- .../Stores/OpenIddictApplicationStore.cs | 13 +++--- .../Stores/OpenIddictAuthorizationStore.cs | 33 ++++++++------- .../Stores/OpenIddictScopeStore.cs | 4 +- .../Stores/OpenIddictTokenStore.cs | 22 ++++++---- 9 files changed, 88 insertions(+), 61 deletions(-) diff --git a/src/OpenIddict.Core/Stores/OpenIddictAuthorizationStore.cs b/src/OpenIddict.Core/Stores/OpenIddictAuthorizationStore.cs index 97dcc6cd..9b9a92f5 100644 --- a/src/OpenIddict.Core/Stores/OpenIddictAuthorizationStore.cs +++ b/src/OpenIddict.Core/Stores/OpenIddictAuthorizationStore.cs @@ -281,7 +281,7 @@ namespace OpenIddict.Core /// A that can be used to monitor the asynchronous operation, /// whose result returns the application identifier associated with the authorization. /// - public virtual async ValueTask GetApplicationIdAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) + public virtual ValueTask GetApplicationIdAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) { if (authorization == null) { @@ -290,17 +290,22 @@ namespace OpenIddict.Core if (authorization.Application != null) { - return ConvertIdentifierToString(authorization.Application.Id); + return new ValueTask(ConvertIdentifierToString(authorization.Application.Id)); } - IQueryable Query(IQueryable authorizations, TKey key) - => from element in authorizations - where element.Id.Equals(key) && - element.Application != null - select element.Application.Id; + async Task RetrieveApplicationIdAsync() + { + IQueryable Query(IQueryable authorizations, TKey key) + => from element in authorizations + where element.Id.Equals(key) && + element.Application != null + select element.Application.Id; + + return ConvertIdentifierToString(await GetAsync( + (authorizations, key) => Query(authorizations, key), authorization.Id, cancellationToken)); + } - return ConvertIdentifierToString(await GetAsync( - (authorizations, key) => Query(authorizations, key), authorization.Id, cancellationToken)); + return new ValueTask(RetrieveApplicationIdAsync()); } /// diff --git a/src/OpenIddict.Core/Stores/OpenIddictTokenStore.cs b/src/OpenIddict.Core/Stores/OpenIddictTokenStore.cs index 9ef06ac0..3334fb16 100644 --- a/src/OpenIddict.Core/Stores/OpenIddictTokenStore.cs +++ b/src/OpenIddict.Core/Stores/OpenIddictTokenStore.cs @@ -239,7 +239,7 @@ namespace OpenIddict.Core /// A that can be used to monitor the asynchronous operation, /// whose result returns the application identifier associated with the token. /// - public virtual async ValueTask GetApplicationIdAsync([NotNull] TToken token, CancellationToken cancellationToken) + public virtual ValueTask GetApplicationIdAsync([NotNull] TToken token, CancellationToken cancellationToken) { if (token == null) { @@ -248,16 +248,21 @@ namespace OpenIddict.Core if (token.Application != null) { - return ConvertIdentifierToString(token.Application.Id); + return new ValueTask(ConvertIdentifierToString(token.Application.Id)); } - IQueryable Query(IQueryable tokens, TKey key) - => from element in tokens - where element.Id.Equals(key) && - element.Application != null - select element.Application.Id; + async Task RetrieveApplicationIdAsync() + { + IQueryable Query(IQueryable tokens, TKey key) + => from element in tokens + where element.Id.Equals(key) && + element.Application != null + select element.Application.Id; + + return ConvertIdentifierToString(await GetAsync((tokens, key) => Query(tokens, key), token.Id, cancellationToken)); + } - return ConvertIdentifierToString(await GetAsync((tokens, key) => Query(tokens, key), token.Id, cancellationToken)); + return new ValueTask(RetrieveApplicationIdAsync()); } /// @@ -269,7 +274,7 @@ namespace OpenIddict.Core /// A that can be used to monitor the asynchronous operation, /// whose result returns the authorization identifier associated with the token. /// - public virtual async ValueTask GetAuthorizationIdAsync([NotNull] TToken token, CancellationToken cancellationToken) + public virtual ValueTask GetAuthorizationIdAsync([NotNull] TToken token, CancellationToken cancellationToken) { if (token == null) { @@ -278,16 +283,21 @@ namespace OpenIddict.Core if (token.Authorization != null) { - return ConvertIdentifierToString(token.Authorization.Id); + return new ValueTask(ConvertIdentifierToString(token.Authorization.Id)); } - IQueryable Query(IQueryable tokens, TKey key) - => from element in tokens - where element.Id.Equals(key) && - element.Authorization != null - select element.Authorization.Id; + async Task RetrieveAuthorizationIdAsync() + { + IQueryable Query(IQueryable tokens, TKey key) + => from element in tokens + where element.Id.Equals(key) && + element.Authorization != null + select element.Authorization.Id; + + return ConvertIdentifierToString(await GetAsync((tokens, key) => Query(tokens, key), token.Id, cancellationToken)); + } - return ConvertIdentifierToString(await GetAsync((tokens, key) => Query(tokens, key), token.Id, cancellationToken)); + return new ValueTask(RetrieveAuthorizationIdAsync()); } /// diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs index 6a1dc4f3..f2cbeb4a 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs @@ -168,6 +168,7 @@ namespace OpenIddict.EntityFramework Task> ListTokensAsync() => (from token in Tokens + where token.Authorization == null where token.Application.Id.Equals(application.Id) select token).ToListAsync(cancellationToken); diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs index 00939ab9..9f614ba7 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs @@ -167,7 +167,7 @@ namespace OpenIddict.EntityFramework where token.Authorization.Id.Equals(authorization.Id) select token).ToListAsync(cancellationToken); - // Remove all the tokens associated with the application. + // Remove all the tokens associated with the authorization. foreach (var token in await ListTokensAsync()) { Tokens.Remove(token); @@ -383,7 +383,8 @@ namespace OpenIddict.EntityFramework /// /// A that can be used to monitor the asynchronous operation. /// - public override async Task SetApplicationIdAsync([NotNull] TAuthorization authorization, [CanBeNull] string identifier, CancellationToken cancellationToken) + public override async Task SetApplicationIdAsync([NotNull] TAuthorization authorization, + [CanBeNull] string identifier, CancellationToken cancellationToken) { if (authorization == null) { diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictTokenStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictTokenStore.cs index c401c05b..31fcb1c9 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictTokenStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictTokenStore.cs @@ -407,7 +407,8 @@ namespace OpenIddict.EntityFramework /// /// A that can be used to monitor the asynchronous operation. /// - public override async Task SetApplicationIdAsync([NotNull] TToken token, [CanBeNull] string identifier, CancellationToken cancellationToken) + public override async Task SetApplicationIdAsync([NotNull] TToken token, + [CanBeNull] string identifier, CancellationToken cancellationToken) { if (token == null) { @@ -452,7 +453,8 @@ namespace OpenIddict.EntityFramework /// /// A that can be used to monitor the asynchronous operation. /// - public override async Task SetAuthorizationIdAsync([NotNull] TToken token, [CanBeNull] string identifier, CancellationToken cancellationToken) + public override async Task SetAuthorizationIdAsync([NotNull] TToken token, + [CanBeNull] string identifier, CancellationToken cancellationToken) { if (token == null) { diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs index fafc3ff1..7aa4db13 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs @@ -167,8 +167,8 @@ namespace OpenIddict.EntityFrameworkCore // See https://github.com/openiddict/openiddict-core/issues/499 for more information. Task> ListAuthorizationsAsync() - => (from authorization in Authorizations.Include(authorization => authorization.Tokens) - join element in Applications on authorization.Application.Id equals element.Id + => (from authorization in Authorizations.Include(authorization => authorization.Tokens).AsTracking() + join element in Applications.AsTracking() on authorization.Application.Id equals element.Id where element.Id.Equals(application.Id) select authorization).ToListAsync(cancellationToken); @@ -178,8 +178,9 @@ namespace OpenIddict.EntityFrameworkCore // See https://github.com/openiddict/openiddict-core/issues/499 for more information. Task> ListTokensAsync() - => (from token in Tokens - join element in Applications on token.Application.Id equals element.Id + => (from token in Tokens.AsTracking() + where token.Authorization == null + join element in Applications.AsTracking() on token.Application.Id equals element.Id where element.Id.Equals(application.Id) select token).ToListAsync(cancellationToken); @@ -246,7 +247,7 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentNullException(nameof(query)); } - return query(Applications, state).FirstOrDefaultAsync(cancellationToken); + return query(Applications.AsTracking(), state).FirstOrDefaultAsync(cancellationToken); } /// @@ -270,7 +271,7 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentNullException(nameof(query)); } - return ImmutableArray.CreateRange(await query(Applications, state).ToListAsync(cancellationToken)); + return ImmutableArray.CreateRange(await query(Applications.AsTracking(), state).ToListAsync(cancellationToken)); } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs index 654eef1f..850555e1 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs @@ -170,12 +170,12 @@ namespace OpenIddict.EntityFrameworkCore // See https://github.com/openiddict/openiddict-core/issues/499 for more information. Task> ListTokensAsync() - => (from token in Tokens - join element in Authorizations on token.Authorization.Id equals element.Id + => (from token in Tokens.AsTracking() + join element in Authorizations.AsTracking() on token.Authorization.Id equals element.Id where element.Id.Equals(authorization.Id) select token).ToListAsync(cancellationToken); - // Remove all the tokens associated with the application. + // Remove all the tokens associated with the authorization. foreach (var token in await ListTokensAsync()) { Context.Remove(token); @@ -217,9 +217,9 @@ namespace OpenIddict.EntityFrameworkCore IQueryable Query(IQueryable authorizations, IQueryable applications, TKey key, string principal) - => from authorization in authorizations.Include(authorization => authorization.Application) + => from authorization in authorizations.Include(authorization => authorization.Application).AsTracking() where authorization.Subject == principal - join application in applications on authorization.Application.Id equals application.Id + join application in applications.AsTracking() on authorization.Application.Id equals application.Id where application.Id.Equals(key) select authorization; @@ -264,10 +264,9 @@ namespace OpenIddict.EntityFrameworkCore IQueryable Query(IQueryable authorizations, IQueryable applications, TKey key, string principal, string state) - => from authorization in authorizations.Include(authorization => authorization.Application) - where authorization.Subject == principal && - authorization.Status == state - join application in applications on authorization.Application.Id equals application.Id + => from authorization in authorizations.Include(authorization => authorization.Application).AsTracking() + where authorization.Subject == principal && authorization.Status == state + join application in applications.AsTracking() on authorization.Application.Id equals application.Id where application.Id.Equals(key) select authorization; @@ -318,11 +317,11 @@ namespace OpenIddict.EntityFrameworkCore IQueryable Query(IQueryable authorizations, IQueryable applications, TKey key, string principal, string state, string kind) - => from authorization in authorizations.Include(authorization => authorization.Application) + => from authorization in authorizations.Include(authorization => authorization.Application).AsTracking() where authorization.Subject == principal && authorization.Status == state && authorization.Type == kind - join application in applications on authorization.Application.Id equals application.Id + join application in applications.AsTracking() on authorization.Application.Id equals application.Id where application.Id.Equals(key) select authorization; @@ -416,7 +415,9 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentNullException(nameof(query)); } - return query(Authorizations.Include(authorization => authorization.Application), state).FirstOrDefaultAsync(cancellationToken); + return query( + Authorizations.Include(authorization => authorization.Application) + .AsTracking(), state).FirstOrDefaultAsync(cancellationToken); } /// @@ -441,7 +442,8 @@ namespace OpenIddict.EntityFrameworkCore } return ImmutableArray.CreateRange(await query( - Authorizations.Include(authorization => authorization.Application), state).ToListAsync(cancellationToken)); + Authorizations.Include(authorization => authorization.Application) + .AsTracking(), state).ToListAsync(cancellationToken)); } /// @@ -460,7 +462,7 @@ namespace OpenIddict.EntityFrameworkCore IList exceptions = null; IQueryable Query(IQueryable authorizations, int offset) - => (from authorization in authorizations.Include(authorization => authorization.Tokens) + => (from authorization in authorizations.Include(authorization => authorization.Tokens).AsTracking() where authorization.Status != OpenIddictConstants.Statuses.Valid || (authorization.Type == OpenIddictConstants.AuthorizationTypes.AdHoc && !authorization.Tokens.Any(token => token.Status == OpenIddictConstants.Statuses.Valid)) @@ -545,7 +547,8 @@ namespace OpenIddict.EntityFrameworkCore /// /// A that can be used to monitor the asynchronous operation. /// - public override async Task SetApplicationIdAsync([NotNull] TAuthorization authorization, [CanBeNull] string identifier, CancellationToken cancellationToken) + public override async Task SetApplicationIdAsync([NotNull] TAuthorization authorization, + [CanBeNull] string identifier, CancellationToken cancellationToken) { if (authorization == null) { diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictScopeStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictScopeStore.cs index 018f1b34..26750a04 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictScopeStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictScopeStore.cs @@ -186,7 +186,7 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentNullException(nameof(query)); } - return query(Scopes, state).FirstOrDefaultAsync(cancellationToken); + return query(Scopes.AsTracking(), state).FirstOrDefaultAsync(cancellationToken); } /// @@ -210,7 +210,7 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentNullException(nameof(query)); } - return ImmutableArray.CreateRange(await query(Scopes, state).ToListAsync(cancellationToken)); + return ImmutableArray.CreateRange(await query(Scopes.AsTracking(), state).ToListAsync(cancellationToken)); } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictTokenStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictTokenStore.cs index ff047cc4..46f106e9 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictTokenStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictTokenStore.cs @@ -189,8 +189,8 @@ namespace OpenIddict.EntityFrameworkCore // See https://github.com/openiddict/openiddict-core/issues/499 for more information. IQueryable Query(IQueryable applications, IQueryable tokens, TKey key) - => from token in tokens.Include(token => token.Application).Include(token => token.Authorization) - join application in applications on token.Application.Id equals application.Id + => from token in tokens.Include(token => token.Application).Include(token => token.Authorization).AsTracking() + join application in applications.AsTracking() on token.Application.Id equals application.Id where application.Id.Equals(key) select token; @@ -220,8 +220,8 @@ namespace OpenIddict.EntityFrameworkCore // See https://github.com/openiddict/openiddict-core/issues/499 for more information. IQueryable Query(IQueryable authorizations, IQueryable tokens, TKey key) - => from token in tokens.Include(token => token.Application).Include(token => token.Authorization) - join authorization in authorizations on token.Authorization.Id equals authorization.Id + => from token in tokens.Include(token => token.Application).Include(token => token.Authorization).AsTracking() + join authorization in authorizations.AsTracking() on token.Authorization.Id equals authorization.Id where authorization.Id.Equals(key) select token; @@ -317,7 +317,8 @@ namespace OpenIddict.EntityFrameworkCore return query( Tokens.Include(token => token.Application) - .Include(token => token.Authorization), state).FirstOrDefaultAsync(cancellationToken); + .Include(token => token.Authorization) + .AsTracking(), state).FirstOrDefaultAsync(cancellationToken); } /// @@ -379,7 +380,8 @@ namespace OpenIddict.EntityFrameworkCore return ImmutableArray.CreateRange(await query( Tokens.Include(token => token.Application) - .Include(token => token.Authorization), state).ToListAsync(cancellationToken)); + .Include(token => token.Authorization) + .AsTracking(), state).ToListAsync(cancellationToken)); } /// @@ -398,7 +400,7 @@ namespace OpenIddict.EntityFrameworkCore IList exceptions = null; IQueryable Query(IQueryable tokens, int offset) - => (from token in tokens + => (from token in tokens.AsTracking() where token.ExpirationDate < DateTimeOffset.UtcNow || token.Status != OpenIddictConstants.Statuses.Valid orderby token.Id @@ -481,7 +483,8 @@ namespace OpenIddict.EntityFrameworkCore /// /// A that can be used to monitor the asynchronous operation. /// - public override async Task SetApplicationIdAsync([NotNull] TToken token, [CanBeNull] string identifier, CancellationToken cancellationToken) + public override async Task SetApplicationIdAsync([NotNull] TToken token, + [CanBeNull] string identifier, CancellationToken cancellationToken) { if (token == null) { @@ -526,7 +529,8 @@ namespace OpenIddict.EntityFrameworkCore /// /// A that can be used to monitor the asynchronous operation. /// - public override async Task SetAuthorizationIdAsync([NotNull] TToken token, [CanBeNull] string identifier, CancellationToken cancellationToken) + public override async Task SetAuthorizationIdAsync([NotNull] TToken token, + [CanBeNull] string identifier, CancellationToken cancellationToken) { if (token == null) { From cfcba5f79f6e2b5eff4d309a7ac86f81d7f82639 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Chalet?= Date: Sat, 17 Mar 2018 18:40:03 +0100 Subject: [PATCH 2/2] Update OpenIddictApplicationStore/OpenIddictAuthorizationStore.DeleteAsync() to use serializable transactions --- .../OpenIddictExtensions.cs | 3 +- .../Stores/OpenIddictApplicationStore.cs | 49 ++++++++++----- .../Stores/OpenIddictAuthorizationStore.cs | 36 +++++++++-- .../OpenIddictExtensions.cs | 3 +- .../Stores/OpenIddictApplicationStore.cs | 61 ++++++++++++++----- .../Stores/OpenIddictAuthorizationStore.cs | 46 ++++++++++++-- 6 files changed, 156 insertions(+), 42 deletions(-) diff --git a/src/OpenIddict.EntityFramework/OpenIddictExtensions.cs b/src/OpenIddict.EntityFramework/OpenIddictExtensions.cs index 431e5c7c..5e2cc25d 100644 --- a/src/OpenIddict.EntityFramework/OpenIddictExtensions.cs +++ b/src/OpenIddict.EntityFramework/OpenIddictExtensions.cs @@ -236,7 +236,8 @@ namespace Microsoft.Extensions.DependencyInjection builder.Entity() .HasMany(application => application.Tokens) .WithOptional(token => token.Authorization) - .Map(association => association.MapKey("AuthorizationId")); + .Map(association => association.MapKey("AuthorizationId")) + .WillCascadeOnDelete(); builder.Entity() .ToTable("OpenIddictAuthorizations"); diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs index f2cbeb4a..46ef278d 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictApplicationStore.cs @@ -7,6 +7,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Data; using System.Data.Entity; using System.Linq; using System.Threading; @@ -161,6 +162,19 @@ namespace OpenIddict.EntityFramework throw new ArgumentNullException(nameof(application)); } + DbContextTransaction CreateTransaction() + { + try + { + return Context.Database.BeginTransaction(IsolationLevel.Serializable); + } + + catch + { + return null; + } + } + Task> ListAuthorizationsAsync() => (from authorization in Authorizations.Include(authorization => authorization.Tokens) where authorization.Application.Id.Equals(application.Id) @@ -172,27 +186,34 @@ namespace OpenIddict.EntityFramework where token.Application.Id.Equals(application.Id) select token).ToListAsync(cancellationToken); - // Remove all the authorizations associated with the application and - // the tokens attached to these implicit or explicit authorizations. - foreach (var authorization in await ListAuthorizationsAsync()) + // To prevent an SQL exception from being thrown if a new associated entity is + // created after the existing entries have been listed, the following logic is + // executed in a serializable transaction, that will lock the affected tables. + using (var transaction = CreateTransaction()) { - foreach (var token in authorization.Tokens) + // Remove all the authorizations associated with the application and + // the tokens attached to these implicit or explicit authorizations. + foreach (var authorization in await ListAuthorizationsAsync()) + { + foreach (var token in authorization.Tokens) + { + Tokens.Remove(token); + } + + Authorizations.Remove(authorization); + } + + // Remove all the tokens associated with the application. + foreach (var token in await ListTokensAsync()) { Tokens.Remove(token); } - Authorizations.Remove(authorization); - } + Applications.Remove(application); - // Remove all the tokens associated with the application. - foreach (var token in await ListTokensAsync()) - { - Tokens.Remove(token); + await Context.SaveChangesAsync(cancellationToken); + transaction?.Commit(); } - - Applications.Remove(application); - - await Context.SaveChangesAsync(cancellationToken); } /// diff --git a/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs b/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs index 9f614ba7..0fca5218 100644 --- a/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs +++ b/src/OpenIddict.EntityFramework/Stores/OpenIddictAuthorizationStore.cs @@ -162,20 +162,40 @@ namespace OpenIddict.EntityFramework throw new ArgumentNullException(nameof(authorization)); } + DbContextTransaction CreateTransaction() + { + try + { + return Context.Database.BeginTransaction(IsolationLevel.Serializable); + } + + catch + { + return null; + } + } + Task> ListTokensAsync() => (from token in Tokens where token.Authorization.Id.Equals(authorization.Id) select token).ToListAsync(cancellationToken); - // Remove all the tokens associated with the authorization. - foreach (var token in await ListTokensAsync()) + // To prevent an SQL exception from being thrown if a new associated entity is + // created after the existing entries have been listed, the following logic is + // executed in a serializable transaction, that will lock the affected tables. + using (var transaction = CreateTransaction()) { - Tokens.Remove(token); - } + // Remove all the tokens associated with the authorization. + foreach (var token in await ListTokensAsync()) + { + Tokens.Remove(token); + } - Authorizations.Remove(authorization); + Authorizations.Remove(authorization); - await Context.SaveChangesAsync(cancellationToken); + await Context.SaveChangesAsync(cancellationToken); + transaction?.Commit(); + } } /// @@ -347,6 +367,10 @@ namespace OpenIddict.EntityFramework break; } + // Note: new tokens may be attached after the authorizations were retrieved + // from the database since the transaction level is deliberately limited to + // repeatable read instead of serializable for performance reasons). In this + // case, the operation will fail, which is considered an acceptable risk. Authorizations.RemoveRange(authorizations); Tokens.RemoveRange(authorizations.SelectMany(authorization => authorization.Tokens)); diff --git a/src/OpenIddict.EntityFrameworkCore/OpenIddictExtensions.cs b/src/OpenIddict.EntityFrameworkCore/OpenIddictExtensions.cs index a18074a7..0cca4c1c 100644 --- a/src/OpenIddict.EntityFrameworkCore/OpenIddictExtensions.cs +++ b/src/OpenIddict.EntityFrameworkCore/OpenIddictExtensions.cs @@ -264,7 +264,8 @@ namespace Microsoft.Extensions.DependencyInjection entity.HasMany(authorization => authorization.Tokens) .WithOne(token => token.Authorization) .HasForeignKey("AuthorizationId") - .IsRequired(required: false); + .IsRequired(required: false) + .OnDelete(DeleteBehavior.Cascade); entity.ToTable("OpenIddictAuthorizations"); }); diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs index 7aa4db13..8cc655e4 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictApplicationStore.cs @@ -7,11 +7,14 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Data; using System.Linq; using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Storage; using Microsoft.Extensions.Caching.Memory; using OpenIddict.Core; using OpenIddict.Models; @@ -161,6 +164,29 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentNullException(nameof(application)); } + async Task CreateTransactionAsync() + { + // Note: transactions that specify an explicit isolation level are only supported by + // relational providers and trying to use them with a different provider results in + // an invalid operation exception being thrown at runtime. To prevent that, a manual + // check is made to ensure the underlying transaction manager is relational. + var manager = Context.Database.GetService(); + if (manager is IRelationalTransactionManager) + { + try + { + return await Context.Database.BeginTransactionAsync(IsolationLevel.Serializable, cancellationToken); + } + + catch + { + return null; + } + } + + return null; + } + // Note: due to a bug in Entity Framework Core's query visitor, the authorizations can't be // filtered using authorization.Application.Id.Equals(key). To work around this issue, // this local method uses an explicit join before applying the equality check. @@ -184,27 +210,34 @@ namespace OpenIddict.EntityFrameworkCore where element.Id.Equals(application.Id) select token).ToListAsync(cancellationToken); - // Remove all the authorizations associated with the application and - // the tokens attached to these implicit or explicit authorizations. - foreach (var authorization in await ListAuthorizationsAsync()) + // To prevent an SQL exception from being thrown if a new associated entity is + // created after the existing entries have been listed, the following logic is + // executed in a serializable transaction, that will lock the affected tables. + using (var transaction = await CreateTransactionAsync()) { - foreach (var token in authorization.Tokens) + // Remove all the authorizations associated with the application and + // the tokens attached to these implicit or explicit authorizations. + foreach (var authorization in await ListAuthorizationsAsync()) + { + foreach (var token in authorization.Tokens) + { + Context.Remove(token); + } + + Context.Remove(authorization); + } + + // Remove all the tokens associated with the application. + foreach (var token in await ListTokensAsync()) { Context.Remove(token); } - Context.Remove(authorization); - } + Context.Remove(application); - // Remove all the tokens associated with the application. - foreach (var token in await ListTokensAsync()) - { - Context.Remove(token); + await Context.SaveChangesAsync(cancellationToken); + transaction?.Commit(); } - - Context.Remove(application); - - await Context.SaveChangesAsync(cancellationToken); } /// diff --git a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs index 850555e1..5b0ed0ed 100644 --- a/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs +++ b/src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictAuthorizationStore.cs @@ -164,6 +164,29 @@ namespace OpenIddict.EntityFrameworkCore throw new ArgumentNullException(nameof(authorization)); } + async Task CreateTransactionAsync() + { + // Note: transactions that specify an explicit isolation level are only supported by + // relational providers and trying to use them with a different provider results in + // an invalid operation exception being thrown at runtime. To prevent that, a manual + // check is made to ensure the underlying transaction manager is relational. + var manager = Context.Database.GetService(); + if (manager is IRelationalTransactionManager) + { + try + { + return await Context.Database.BeginTransactionAsync(IsolationLevel.Serializable, cancellationToken); + } + + catch + { + return null; + } + } + + return null; + } + // Note: due to a bug in Entity Framework Core's query visitor, the tokens can't be // filtered using token.Application.Id.Equals(key). To work around this issue, // this local method uses an explicit join before applying the equality check. @@ -175,15 +198,22 @@ namespace OpenIddict.EntityFrameworkCore where element.Id.Equals(authorization.Id) select token).ToListAsync(cancellationToken); - // Remove all the tokens associated with the authorization. - foreach (var token in await ListTokensAsync()) + // To prevent an SQL exception from being thrown if a new associated entity is + // created after the existing entries have been listed, the following logic is + // executed in a serializable transaction, that will lock the affected tables. + using (var transaction = await CreateTransactionAsync()) { - Context.Remove(token); - } + // Remove all the tokens associated with the authorization. + foreach (var token in await ListTokensAsync()) + { + Context.Remove(token); + } - Context.Remove(authorization); + Context.Remove(authorization); - await Context.SaveChangesAsync(cancellationToken); + await Context.SaveChangesAsync(cancellationToken); + transaction?.Commit(); + } } /// @@ -511,6 +541,10 @@ namespace OpenIddict.EntityFrameworkCore break; } + // Note: new tokens may be attached after the authorizations were retrieved + // from the database since the transaction level is deliberately limited to + // repeatable read instead of serializable for performance reasons). In this + // case, the operation will fail, which is considered an acceptable risk. Context.RemoveRange(authorizations); Context.RemoveRange(authorizations.SelectMany(authorization => authorization.Tokens));