Browse Source

Update the EF 6/EF Core stores to retrieve the entities from the change tracker when available

pull/1975/head
Kévin Chalet 2 years ago
parent
commit
b0371ae00e
  1. 34
      src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkApplicationStore.cs
  2. 16
      src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkAuthorizationStore.cs
  3. 32
      src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkScopeStore.cs
  4. 32
      src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkTokenStore.cs
  5. 33
      src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreApplicationStore.cs
  6. 16
      src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreAuthorizationStore.cs
  7. 32
      src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreScopeStore.cs
  8. 32
      src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreTokenStore.cs

34
src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkApplicationStore.cs

@ -210,31 +210,47 @@ public class OpenIddictEntityFrameworkApplicationStore<TApplication, TAuthorizat
}
/// <inheritdoc/>
public virtual async ValueTask<TApplication?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
public virtual ValueTask<TApplication?> FindByClientIdAsync(string identifier, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(identifier))
{
throw new ArgumentException(SR.GetResourceString(SR.ID0195), nameof(identifier));
}
var key = ConvertIdentifierFromString(identifier);
return GetTrackedEntity() is TApplication application ? new(application) : new(QueryAsync());
TApplication? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TApplication>()
where string.Equals(entry.Entity.ClientId, identifier, StringComparison.Ordinal)
select entry.Entity).FirstOrDefault();
return await (from application in Applications
where application.Id!.Equals(key)
select application).FirstOrDefaultAsync(cancellationToken);
Task<TApplication?> QueryAsync() =>
(from application in Applications
where application.ClientId == identifier
select application).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>
public virtual async ValueTask<TApplication?> FindByClientIdAsync(string identifier, CancellationToken cancellationToken)
public virtual ValueTask<TApplication?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(identifier))
{
throw new ArgumentException(SR.GetResourceString(SR.ID0195), nameof(identifier));
}
return await (from application in Applications
where application.ClientId == identifier
select application).FirstOrDefaultAsync(cancellationToken);
var key = ConvertIdentifierFromString(identifier);
return GetTrackedEntity() is TApplication application ? new(application) : new(QueryAsync());
TApplication? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TApplication>()
where entry.Entity.Id is TKey identifier && identifier.Equals(key)
select entry.Entity).FirstOrDefault();
Task<TApplication?> QueryAsync() =>
(from application in Applications
where application.Id!.Equals(key)
select application).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>

16
src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkAuthorizationStore.cs

@ -336,7 +336,7 @@ public class OpenIddictEntityFrameworkAuthorizationStore<TAuthorization, TApplic
}
/// <inheritdoc/>
public virtual async ValueTask<TAuthorization?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
public virtual ValueTask<TAuthorization?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(identifier))
{
@ -345,9 +345,17 @@ public class OpenIddictEntityFrameworkAuthorizationStore<TAuthorization, TApplic
var key = ConvertIdentifierFromString(identifier);
return await (from authorization in Authorizations.Include(authorization => authorization.Application)
where authorization.Id!.Equals(key)
select authorization).FirstOrDefaultAsync(cancellationToken);
return GetTrackedEntity() is TAuthorization authorization ? new(authorization) : new(QueryAsync());
TAuthorization? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TAuthorization>()
where entry.Entity.Id is TKey identifier && identifier.Equals(key)
select entry.Entity).FirstOrDefault();
Task<TAuthorization?> QueryAsync() =>
(from authorization in Authorizations.Include(authorization => authorization.Application)
where authorization.Id!.Equals(key)
select authorization).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>

32
src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkScopeStore.cs

@ -130,7 +130,7 @@ public class OpenIddictEntityFrameworkScopeStore<TScope, TContext, TKey> : IOpen
}
/// <inheritdoc/>
public virtual async ValueTask<TScope?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
public virtual ValueTask<TScope?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(identifier))
{
@ -139,22 +139,38 @@ public class OpenIddictEntityFrameworkScopeStore<TScope, TContext, TKey> : IOpen
var key = ConvertIdentifierFromString(identifier);
return await (from scope in Scopes
where scope.Id!.Equals(key)
select scope).FirstOrDefaultAsync(cancellationToken);
return GetTrackedEntity() is TScope scope ? new(scope) : new(QueryAsync());
TScope? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TScope>()
where entry.Entity.Id is TKey identifier && identifier.Equals(key)
select entry.Entity).FirstOrDefault();
Task<TScope?> QueryAsync() =>
(from scope in Scopes
where scope.Id!.Equals(key)
select scope).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>
public virtual async ValueTask<TScope?> FindByNameAsync(string name, CancellationToken cancellationToken)
public virtual ValueTask<TScope?> FindByNameAsync(string name, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(name))
{
throw new ArgumentException(SR.GetResourceString(SR.ID0202), nameof(name));
}
return await (from scope in Scopes
where scope.Name == name
select scope).FirstOrDefaultAsync(cancellationToken);
return GetTrackedEntity() is TScope scope ? new(scope) : new(QueryAsync());
TScope? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TScope>()
where string.Equals(entry.Entity.Name, name, StringComparison.Ordinal)
select entry.Entity).FirstOrDefault();
Task<TScope?> QueryAsync() =>
(from scope in Scopes
where scope.Name == name
select scope).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>

32
src/OpenIddict.EntityFramework/Stores/OpenIddictEntityFrameworkTokenStore.cs

@ -262,7 +262,7 @@ public class OpenIddictEntityFrameworkTokenStore<TToken, TApplication, TAuthoriz
}
/// <inheritdoc/>
public virtual async ValueTask<TToken?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
public virtual ValueTask<TToken?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(identifier))
{
@ -271,22 +271,38 @@ public class OpenIddictEntityFrameworkTokenStore<TToken, TApplication, TAuthoriz
var key = ConvertIdentifierFromString(identifier);
return await (from token in Tokens.Include(token => token.Application).Include(token => token.Authorization)
where token.Id!.Equals(key)
select token).FirstOrDefaultAsync(cancellationToken);
return GetTrackedEntity() is TToken token ? new(token) : new(QueryAsync());
TToken? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TToken>()
where entry.Entity.Id is TKey identifier && identifier.Equals(key)
select entry.Entity).FirstOrDefault();
Task<TToken?> QueryAsync() =>
(from token in Tokens.Include(token => token.Application).Include(token => token.Authorization)
where token.Id!.Equals(key)
select token).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>
public virtual async ValueTask<TToken?> FindByReferenceIdAsync(string identifier, CancellationToken cancellationToken)
public virtual ValueTask<TToken?> FindByReferenceIdAsync(string identifier, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(identifier))
{
throw new ArgumentException(SR.GetResourceString(SR.ID0195), nameof(identifier));
}
return await (from token in Tokens.Include(token => token.Application).Include(token => token.Authorization)
where token.ReferenceId == identifier
select token).FirstOrDefaultAsync(cancellationToken);
return GetTrackedEntity() is TToken token ? new(token) : new(QueryAsync());
TToken? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TToken>()
where string.Equals(entry.Entity.ReferenceId, identifier, StringComparison.Ordinal)
select entry.Entity).FirstOrDefault();
Task<TToken?> QueryAsync() =>
(from token in Tokens.Include(token => token.Application).Include(token => token.Authorization)
where token.ReferenceId == identifier
select token).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>

33
src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreApplicationStore.cs

@ -9,7 +9,6 @@ using System.ComponentModel;
using System.Data;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Net;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Encodings.Web;
@ -280,20 +279,28 @@ public class OpenIddictEntityFrameworkCoreApplicationStore<TApplication, TAuthor
}
/// <inheritdoc/>
public virtual async ValueTask<TApplication?> FindByClientIdAsync(string identifier, CancellationToken cancellationToken)
public virtual ValueTask<TApplication?> FindByClientIdAsync(string identifier, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(identifier))
{
throw new ArgumentException(SR.GetResourceString(SR.ID0195), nameof(identifier));
}
return await (from application in Applications.AsTracking()
where application.ClientId == identifier
select application).FirstOrDefaultAsync(cancellationToken);
return GetTrackedEntity() is TApplication application ? new(application) : new(QueryAsync());
TApplication? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TApplication>()
where string.Equals(entry.Entity.ClientId, identifier, StringComparison.Ordinal)
select entry.Entity).FirstOrDefault();
Task<TApplication?> QueryAsync() =>
(from application in Applications.AsTracking()
where application.ClientId == identifier
select application).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>
public virtual async ValueTask<TApplication?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
public virtual ValueTask<TApplication?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(identifier))
{
@ -302,9 +309,17 @@ public class OpenIddictEntityFrameworkCoreApplicationStore<TApplication, TAuthor
var key = ConvertIdentifierFromString(identifier);
return await (from application in Applications.AsTracking()
where application.Id!.Equals(key)
select application).FirstOrDefaultAsync(cancellationToken);
return GetTrackedEntity() is TApplication application ? new(application) : new(QueryAsync());
TApplication? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TApplication>()
where entry.Entity.Id is TKey identifier && identifier.Equals(key)
select entry.Entity).FirstOrDefault();
Task<TApplication?> QueryAsync() =>
(from application in Applications.AsTracking()
where application.Id!.Equals(key)
select application).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>

16
src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreAuthorizationStore.cs

@ -423,7 +423,7 @@ public class OpenIddictEntityFrameworkCoreAuthorizationStore<TAuthorization, TAp
}
/// <inheritdoc/>
public virtual async ValueTask<TAuthorization?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
public virtual ValueTask<TAuthorization?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(identifier))
{
@ -432,9 +432,17 @@ public class OpenIddictEntityFrameworkCoreAuthorizationStore<TAuthorization, TAp
var key = ConvertIdentifierFromString(identifier);
return await (from authorization in Authorizations.Include(authorization => authorization.Application).AsTracking()
where authorization.Id!.Equals(key)
select authorization).FirstOrDefaultAsync(cancellationToken);
return GetTrackedEntity() is TAuthorization authorization ? new(authorization) : new(QueryAsync());
TAuthorization? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TAuthorization>()
where entry.Entity.Id is TKey identifier && identifier.Equals(key)
select entry.Entity).FirstOrDefault();
Task<TAuthorization?> QueryAsync() =>
(from authorization in Authorizations.Include(authorization => authorization.Application).AsTracking()
where authorization.Id!.Equals(key)
select authorization).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>

32
src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreScopeStore.cs

@ -146,7 +146,7 @@ public class OpenIddictEntityFrameworkCoreScopeStore<TScope, TContext, TKey> : I
}
/// <inheritdoc/>
public virtual async ValueTask<TScope?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
public virtual ValueTask<TScope?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(identifier))
{
@ -155,22 +155,38 @@ public class OpenIddictEntityFrameworkCoreScopeStore<TScope, TContext, TKey> : I
var key = ConvertIdentifierFromString(identifier);
return await (from scope in Scopes.AsTracking()
where scope.Id!.Equals(key)
select scope).FirstOrDefaultAsync(cancellationToken);
return GetTrackedEntity() is TScope scope ? new(scope) : new(QueryAsync());
TScope? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TScope>()
where entry.Entity.Id is TKey identifier && identifier.Equals(key)
select entry.Entity).FirstOrDefault();
Task<TScope?> QueryAsync() =>
(from scope in Scopes.AsTracking()
where scope.Id!.Equals(key)
select scope).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>
public virtual async ValueTask<TScope?> FindByNameAsync(string name, CancellationToken cancellationToken)
public virtual ValueTask<TScope?> FindByNameAsync(string name, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(name))
{
throw new ArgumentException(SR.GetResourceString(SR.ID0202), nameof(name));
}
return await (from scope in Scopes.AsTracking()
where scope.Name == name
select scope).FirstOrDefaultAsync(cancellationToken);
return GetTrackedEntity() is TScope scope ? new(scope) : new(QueryAsync());
TScope? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TScope>()
where string.Equals(entry.Entity.Name, name, StringComparison.Ordinal)
select entry.Entity).FirstOrDefault();
Task<TScope?> QueryAsync() =>
(from scope in Scopes.AsTracking()
where scope.Name == name
select scope).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>

32
src/OpenIddict.EntityFrameworkCore/Stores/OpenIddictEntityFrameworkCoreTokenStore.cs

@ -311,7 +311,7 @@ public class OpenIddictEntityFrameworkCoreTokenStore<TToken, TApplication, TAuth
}
/// <inheritdoc/>
public virtual async ValueTask<TToken?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
public virtual ValueTask<TToken?> FindByIdAsync(string identifier, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(identifier))
{
@ -320,22 +320,38 @@ public class OpenIddictEntityFrameworkCoreTokenStore<TToken, TApplication, TAuth
var key = ConvertIdentifierFromString(identifier);
return await (from token in Tokens.Include(token => token.Application).Include(token => token.Authorization).AsTracking()
where token.Id!.Equals(key)
select token).FirstOrDefaultAsync(cancellationToken);
return GetTrackedEntity() is TToken token ? new(token) : new(QueryAsync());
TToken? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TToken>()
where entry.Entity.Id is TKey identifier && identifier.Equals(key)
select entry.Entity).FirstOrDefault();
Task<TToken?> QueryAsync() =>
(from token in Tokens.Include(token => token.Application).Include(token => token.Authorization).AsTracking()
where token.Id!.Equals(key)
select token).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>
public virtual async ValueTask<TToken?> FindByReferenceIdAsync(string identifier, CancellationToken cancellationToken)
public virtual ValueTask<TToken?> FindByReferenceIdAsync(string identifier, CancellationToken cancellationToken)
{
if (string.IsNullOrEmpty(identifier))
{
throw new ArgumentException(SR.GetResourceString(SR.ID0195), nameof(identifier));
}
return await (from token in Tokens.Include(token => token.Application).Include(token => token.Authorization).AsTracking()
where token.ReferenceId == identifier
select token).FirstOrDefaultAsync(cancellationToken);
return GetTrackedEntity() is TToken token ? new(token) : new(QueryAsync());
TToken? GetTrackedEntity() =>
(from entry in Context.ChangeTracker.Entries<TToken>()
where string.Equals(entry.Entity.ReferenceId, identifier, StringComparison.Ordinal)
select entry.Entity).FirstOrDefault();
Task<TToken?> QueryAsync() =>
(from token in Tokens.Include(token => token.Application).Include(token => token.Authorization).AsTracking()
where token.ReferenceId == identifier
select token).FirstOrDefaultAsync(cancellationToken);
}
/// <inheritdoc/>

Loading…
Cancel
Save