From 002eff080fa57e950f01d7fd99c379c9b65a79b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Chalet?= Date: Sat, 24 Oct 2015 19:28:40 +0200 Subject: [PATCH] Update OpenIddictStore to support database contexts derived from DbContext --- src/OpenIddict.EF/OpenIddictExtensions.cs | 62 +++++++++++++++-------- src/OpenIddict.EF/OpenIddictStore.cs | 13 +++-- 2 files changed, 49 insertions(+), 26 deletions(-) diff --git a/src/OpenIddict.EF/OpenIddictExtensions.cs b/src/OpenIddict.EF/OpenIddictExtensions.cs index 9d970547..e136dd63 100644 --- a/src/OpenIddict.EF/OpenIddictExtensions.cs +++ b/src/OpenIddict.EF/OpenIddictExtensions.cs @@ -9,6 +9,7 @@ using System.Linq; using System.Reflection; using Microsoft.AspNet.Identity; using Microsoft.AspNet.Identity.EntityFramework; +using Microsoft.Data.Entity; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Internal; using OpenIddict; @@ -16,35 +17,52 @@ using OpenIddict; namespace Microsoft.AspNet.Builder { public static class OpenIddictExtensions { public static OpenIddictBuilder AddEntityFrameworkStore([NotNull] this OpenIddictBuilder builder) { - // Resolve the key type from the user type definition. - var keyType = ResolveKeyType(builder); - builder.Services.AddScoped( typeof(IOpenIddictStore<,>).MakeGenericType(builder.UserType, builder.ApplicationType), - typeof(OpenIddictStore<,,,>).MakeGenericType(builder.UserType, builder.ApplicationType, builder.RoleType, keyType)); - - var type = typeof(OpenIddictContext<,,,>).MakeGenericType(new[] { - /* TUser: */ builder.UserType, - /* TApplication: */ builder.ApplicationType, - /* TRole: */ builder.RoleType, - /* TKey: */ keyType - }); + typeof(OpenIddictStore<,,,,>).MakeGenericType( + /* TUser: */ builder.UserType, + /* TApplication: */ builder.ApplicationType, + /* TRole: */ builder.RoleType, + /* TContext: */ ResolveContextType(builder), + /* TKey: */ ResolveKeyType(builder))); + + return builder; + } + + private static Type ResolveContextType([NotNull] OpenIddictBuilder builder) { + var service = (from registration in builder.Services + where registration.ServiceType.IsConstructedGenericType + let definition = registration.ServiceType.GetGenericTypeDefinition() + where definition == typeof(IUserStore<>) + select registration.ImplementationType).SingleOrDefault(); - builder.Services.AddScoped(type, provider => { - // Resolve the user store from the parent container and extract the associated context. - dynamic store = provider.GetRequiredService(typeof(IUserStore<>).MakeGenericType(builder.UserType)); + if (service == null) { + throw new InvalidOperationException( + "The type of the database context cannot be automatically inferred. " + + "Make sure 'AddOpenIddict()' is the last chained call when configuring the services."); + } - dynamic context = store?.Context; - if (!type.GetTypeInfo().IsAssignableFrom(context?.GetType())) { - throw new InvalidOperationException( - "Only EntityFramework contexts derived from " + - "OpenIddictContext can be used with OpenIddict."); + TypeInfo type; + for (type = service.GetTypeInfo(); type != null; type = type.BaseType?.GetTypeInfo()) { + if (!type.IsGenericType) { + continue; } - return context; - }); + var definition = type.GetGenericTypeDefinition(); + if (definition == null) { + continue; + } - return builder; + if (definition != typeof(UserStore<,,,>)) { + continue; + } + + return (from argument in type.AsType().GetGenericArguments() + where typeof(DbContext).IsAssignableFrom(argument) + select argument).Single(); + } + + throw new InvalidOperationException("The type of the database context cannot be automatically inferred."); } private static Type ResolveKeyType([NotNull] OpenIddictBuilder builder) { diff --git a/src/OpenIddict.EF/OpenIddictStore.cs b/src/OpenIddict.EF/OpenIddictStore.cs index 7d785bdd..140607a0 100644 --- a/src/OpenIddict.EF/OpenIddictStore.cs +++ b/src/OpenIddict.EF/OpenIddictStore.cs @@ -6,21 +6,26 @@ using Microsoft.Data.Entity; using OpenIddict.Models; namespace OpenIddict { - public class OpenIddictStore : UserStore, TKey>, IOpenIddictStore + public class OpenIddictStore : UserStore, IOpenIddictStore where TUser : IdentityUser where TApplication : Application where TRole : IdentityRole + where TContext : DbContext where TKey : IEquatable { - public OpenIddictStore(OpenIddictContext context) + public OpenIddictStore(TContext context) : base(context) { } + public DbSet Applications { + get { return Context.Set(); } + } + public virtual Task FindApplicationByIdAsync(string identifier, CancellationToken cancellationToken) { - return Context.Applications.SingleOrDefaultAsync(application => application.ApplicationID == identifier, cancellationToken); + return Applications.SingleOrDefaultAsync(application => application.ApplicationID == identifier, cancellationToken); } public virtual Task FindApplicationByLogoutRedirectUri(string url, CancellationToken cancellationToken) { - return Context.Applications.SingleOrDefaultAsync(application => application.LogoutRedirectUri == url, cancellationToken); + return Applications.SingleOrDefaultAsync(application => application.LogoutRedirectUri == url, cancellationToken); } public virtual Task GetApplicationTypeAsync(TApplication application, CancellationToken cancellationToken) {