diff --git a/OpenIddict.sln b/OpenIddict.sln index 093df77b..17758e90 100644 --- a/OpenIddict.sln +++ b/OpenIddict.sln @@ -67,6 +67,12 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "shared", "shared", "{D8075F EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenIddict.Extensions", "shared\OpenIddict.Extensions\OpenIddict.Extensions.csproj", "{B90761B9-7582-44CB-AB0D-3C4058693227}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenIddict.NHibernate", "src\OpenIddict.NHibernate\OpenIddict.NHibernate.csproj", "{17BFF448-F11F-40D6-B658-BD81B306D2CA}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenIddict.NHibernate.Models", "src\OpenIddict.NHibernate.Models\OpenIddict.NHibernate.Models.csproj", "{22882DA6-6A5F-4E48-8BDC-7248B1DE5D14}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OpenIddict.NHibernate.Tests", "test\OpenIddict.NHibernate.Tests\OpenIddict.NHibernate.Tests.csproj", "{B99BCBEC-9771-4C68-96E2-1A54E9BC432D}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -165,6 +171,18 @@ Global {B90761B9-7582-44CB-AB0D-3C4058693227}.Debug|Any CPU.Build.0 = Debug|Any CPU {B90761B9-7582-44CB-AB0D-3C4058693227}.Release|Any CPU.ActiveCfg = Release|Any CPU {B90761B9-7582-44CB-AB0D-3C4058693227}.Release|Any CPU.Build.0 = Release|Any CPU + {17BFF448-F11F-40D6-B658-BD81B306D2CA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {17BFF448-F11F-40D6-B658-BD81B306D2CA}.Debug|Any CPU.Build.0 = Debug|Any CPU + {17BFF448-F11F-40D6-B658-BD81B306D2CA}.Release|Any CPU.ActiveCfg = Release|Any CPU + {17BFF448-F11F-40D6-B658-BD81B306D2CA}.Release|Any CPU.Build.0 = Release|Any CPU + {22882DA6-6A5F-4E48-8BDC-7248B1DE5D14}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {22882DA6-6A5F-4E48-8BDC-7248B1DE5D14}.Debug|Any CPU.Build.0 = Debug|Any CPU + {22882DA6-6A5F-4E48-8BDC-7248B1DE5D14}.Release|Any CPU.ActiveCfg = Release|Any CPU + {22882DA6-6A5F-4E48-8BDC-7248B1DE5D14}.Release|Any CPU.Build.0 = Release|Any CPU + {B99BCBEC-9771-4C68-96E2-1A54E9BC432D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B99BCBEC-9771-4C68-96E2-1A54E9BC432D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B99BCBEC-9771-4C68-96E2-1A54E9BC432D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B99BCBEC-9771-4C68-96E2-1A54E9BC432D}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -193,6 +211,9 @@ Global {8FACE85E-EF8F-4AB1-85DD-4010D5E2165D} = {5FC71D6A-A994-4F62-977F-88A7D25379D7} {27F603EF-D335-445B-9443-6B5A6CA3C110} = {5FC71D6A-A994-4F62-977F-88A7D25379D7} {B90761B9-7582-44CB-AB0D-3C4058693227} = {D8075F1F-6257-463B-B481-BDC7C5ABA292} + {17BFF448-F11F-40D6-B658-BD81B306D2CA} = {D544447C-D701-46BB-9A5B-C76C612A596B} + {22882DA6-6A5F-4E48-8BDC-7248B1DE5D14} = {D544447C-D701-46BB-9A5B-C76C612A596B} + {B99BCBEC-9771-4C68-96E2-1A54E9BC432D} = {5FC71D6A-A994-4F62-977F-88A7D25379D7} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {A710059F-0466-4D48-9B3A-0EF4F840B616} diff --git a/build/dependencies.props b/build/dependencies.props index 3bcaf1bd..94e21b11 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -16,6 +16,7 @@ 4.7.63 2.0.0 2.0.0 + 5.2.2 2.0.0 4.4.0 15.3.0 diff --git a/src/OpenIddict.NHibernate.Models/OpenIddict.NHibernate.Models.csproj b/src/OpenIddict.NHibernate.Models/OpenIddict.NHibernate.Models.csproj new file mode 100644 index 00000000..c27210f7 --- /dev/null +++ b/src/OpenIddict.NHibernate.Models/OpenIddict.NHibernate.Models.csproj @@ -0,0 +1,15 @@ + + + + + + netstandard2.0 + + + + Relational entities for the NHibernate 5.x stores. + Kévin Chalet + aspnetcore;authentication;jwt;openidconnect;openiddict;security + + + diff --git a/src/OpenIddict.NHibernate.Models/OpenIddictApplication.cs b/src/OpenIddict.NHibernate.Models/OpenIddictApplication.cs new file mode 100644 index 00000000..f0b1b614 --- /dev/null +++ b/src/OpenIddict.NHibernate.Models/OpenIddictApplication.cs @@ -0,0 +1,114 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Collections.Generic; +using System.Diagnostics; + +namespace OpenIddict.NHibernate.Models +{ + /// + /// Represents an OpenIddict application. + /// + public class OpenIddictApplication : OpenIddictApplication + { + public OpenIddictApplication() + { + // Generate a new string identifier. + Id = Guid.NewGuid().ToString(); + } + } + + /// + /// Represents an OpenIddict application. + /// + public class OpenIddictApplication : OpenIddictApplication, OpenIddictToken> + where TKey : IEquatable + { } + + /// + /// Represents an OpenIddict application. + /// + [DebuggerDisplay("Id = {Id.ToString(),nq} ; ClientId = {ClientId,nq} ; Type = {Type,nq}")] + public class OpenIddictApplication where TKey : IEquatable + { + /// + /// Gets or sets the list of the authorizations associated with this application. + /// + public virtual IList Authorizations { get; set; } = new List(); + + /// + /// Gets or sets the client identifier + /// associated with the current application. + /// + public virtual string ClientId { get; set; } + + /// + /// Gets or sets the client secret associated with the current application. + /// Note: depending on the application manager used to create this instance, + /// this property may be hashed or encrypted for security reasons. + /// + public virtual string ClientSecret { get; set; } + + /// + /// Gets or sets the consent type + /// associated with the current application. + /// + public virtual string ConsentType { get; set; } + + /// + /// Gets or sets the display name + /// associated with the current application. + /// + public virtual string DisplayName { get; set; } + + /// + /// Gets or sets the unique identifier + /// associated with the current application. + /// + public virtual TKey Id { get; set; } + + /// + /// Gets or sets the permissions associated with the + /// current application, serialized as a JSON array. + /// + public virtual string Permissions { get; set; } + + /// + /// Gets or sets the logout callback URLs associated with + /// the current application, serialized as a JSON array. + /// + public virtual string PostLogoutRedirectUris { get; set; } + + /// + /// Gets or sets the additional properties serialized as a JSON object, + /// or null if no bag was associated with the current application. + /// + public virtual string Properties { get; set; } + + /// + /// Gets or sets the callback URLs associated with the + /// current application, serialized as a JSON array. + /// + public virtual string RedirectUris { get; set; } + + /// + /// Gets or sets the list of the tokens associated with this application. + /// + public virtual IList Tokens { get; set; } = new List(); + + /// + /// Gets or sets the application type + /// associated with the current application. + /// + public virtual string Type { get; set; } + + /// + /// Gets or sets the entity version, used as a concurrency token. + /// + public virtual int Version { get; set; } + } +} \ No newline at end of file diff --git a/src/OpenIddict.NHibernate.Models/OpenIddictAuthorization.cs b/src/OpenIddict.NHibernate.Models/OpenIddictAuthorization.cs new file mode 100644 index 00000000..e6d11825 --- /dev/null +++ b/src/OpenIddict.NHibernate.Models/OpenIddictAuthorization.cs @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Collections.Generic; +using System.Diagnostics; + +namespace OpenIddict.NHibernate.Models +{ + /// + /// Represents an OpenIddict authorization. + /// + public class OpenIddictAuthorization : OpenIddictAuthorization + { + public OpenIddictAuthorization() + { + // Generate a new string identifier. + Id = Guid.NewGuid().ToString(); + } + } + + /// + /// Represents an OpenIddict authorization. + /// + public class OpenIddictAuthorization : OpenIddictAuthorization, OpenIddictToken> + where TKey : IEquatable + { } + + /// + /// Represents an OpenIddict authorization. + /// + [DebuggerDisplay("Id = {Id.ToString(),nq} ; Subject = {Subject,nq} ; Type = {Type,nq} ; Status = {Status,nq}")] + public class OpenIddictAuthorization where TKey : IEquatable + { + /// + /// Gets or sets the application associated with the current authorization. + /// + public virtual TApplication Application { get; set; } + + /// + /// Gets or sets the unique identifier + /// associated with the current authorization. + /// + public virtual TKey Id { get; set; } + + /// + /// Gets or sets the additional properties serialized as a JSON object, + /// or null if no bag was associated with the current authorization. + /// + public virtual string Properties { get; set; } + + /// + /// Gets or sets the scopes associated with the current + /// authorization, serialized as a JSON array. + /// + public virtual string Scopes { get; set; } + + /// + /// Gets or sets the status of the current authorization. + /// + public virtual string Status { get; set; } + + /// + /// Gets or sets the subject associated with the current authorization. + /// + public virtual string Subject { get; set; } + + /// + /// Gets or sets the list of tokens + /// associated with the current authorization. + /// + public virtual IList Tokens { get; set; } = new List(); + + /// + /// Gets or sets the type of the current authorization. + /// + public virtual string Type { get; set; } + + /// + /// Gets or sets the entity version, used as a concurrency token. + /// + public virtual int Version { get; set; } + } +} diff --git a/src/OpenIddict.NHibernate.Models/OpenIddictScope.cs b/src/OpenIddict.NHibernate.Models/OpenIddictScope.cs new file mode 100644 index 00000000..606d3a27 --- /dev/null +++ b/src/OpenIddict.NHibernate.Models/OpenIddictScope.cs @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Diagnostics; + +namespace OpenIddict.NHibernate.Models +{ + /// + /// Represents an OpenIddict scope. + /// + public class OpenIddictScope : OpenIddictScope + { + public OpenIddictScope() + { + // Generate a new string identifier. + Id = Guid.NewGuid().ToString(); + } + } + + /// + /// Represents an OpenIddict scope. + /// + [DebuggerDisplay("Id = {Id.ToString(),nq} ; Name = {Name,nq}")] + public class OpenIddictScope where TKey : IEquatable + { + /// + /// Gets or sets the public description + /// associated with the current scope. + /// + public virtual string Description { get; set; } + + /// + /// Gets or sets the display name + /// associated with the current scope. + /// + public virtual string DisplayName { get; set; } + + /// + /// Gets or sets the unique identifier + /// associated with the current scope. + /// + public virtual TKey Id { get; set; } + + /// + /// Gets or sets the unique name + /// associated with the current scope. + /// + public virtual string Name { get; set; } + + /// + /// Gets or sets the additional properties serialized as a JSON object, + /// or null if no bag was associated with the current scope. + /// + public virtual string Properties { get; set; } + + /// + /// Gets or sets the resources associated with the + /// current scope, serialized as a JSON array. + /// + public virtual string Resources { get; set; } + + /// + /// Gets or sets the entity version, used as a concurrency token. + /// + public virtual int Version { get; set; } + } +} diff --git a/src/OpenIddict.NHibernate.Models/OpenIddictToken.cs b/src/OpenIddict.NHibernate.Models/OpenIddictToken.cs new file mode 100644 index 00000000..1b248f23 --- /dev/null +++ b/src/OpenIddict.NHibernate.Models/OpenIddictToken.cs @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Diagnostics; + +namespace OpenIddict.NHibernate.Models +{ + /// + /// Represents an OpenIddict token. + /// + public class OpenIddictToken : OpenIddictToken + { + public OpenIddictToken() + { + // Generate a new string identifier. + Id = Guid.NewGuid().ToString(); + } + } + + /// + /// Represents an OpenIddict token. + /// + public class OpenIddictToken : OpenIddictToken, OpenIddictAuthorization> + where TKey : IEquatable + { + } + + /// + /// Represents an OpenIddict token. + /// + [DebuggerDisplay("Id = {Id.ToString(),nq} ; Subject = {Subject,nq} ; Type = {Type,nq} ; Status = {Status,nq}")] + public class OpenIddictToken where TKey : IEquatable + { + /// + /// Gets or sets the application associated with the current token. + /// + public virtual TApplication Application { get; set; } + + /// + /// Gets or sets the authorization associated with the current token. + /// + public virtual TAuthorization Authorization { get; set; } + + /// + /// Gets or sets the date on which the token + /// will start to be considered valid. + /// + public virtual DateTimeOffset? CreationDate { get; set; } + + /// + /// Gets or sets the date on which the token + /// will no longer be considered valid. + /// + public virtual DateTimeOffset? ExpirationDate { get; set; } + + /// + /// Gets or sets the unique identifier + /// associated with the current token. + /// + public virtual TKey Id { get; set; } + + /// + /// Gets or sets the payload of the current token, if applicable. + /// Note: this property is only used for reference tokens + /// and may be encrypted for security reasons. + /// + public virtual string Payload { get; set; } + + /// + /// Gets or sets the additional properties serialized as a JSON object, + /// or null if no bag was associated with the current token. + /// + public virtual string Properties { get; set; } + + /// + /// Gets or sets the reference identifier associated + /// with the current token, if applicable. + /// Note: this property is only used for reference tokens + /// and may be hashed or encrypted for security reasons. + /// + public virtual string ReferenceId { get; set; } + + /// + /// Gets or sets the status of the current token. + /// + public virtual string Status { get; set; } + + /// + /// Gets or sets the subject associated with the current token. + /// + public virtual string Subject { get; set; } + + /// + /// Gets or sets the type of the current token. + /// + public virtual string Type { get; set; } + + /// + /// Gets or sets the entity version, used as a concurrency token. + /// + public virtual int Version { get; set; } + } +} diff --git a/src/OpenIddict.NHibernate/IOpenIddictNHibernateContext.cs b/src/OpenIddict.NHibernate/IOpenIddictNHibernateContext.cs new file mode 100644 index 00000000..930daffa --- /dev/null +++ b/src/OpenIddict.NHibernate/IOpenIddictNHibernateContext.cs @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System.Threading; +using System.Threading.Tasks; +using NHibernate; + +namespace OpenIddict.NHibernate +{ + /// + /// Exposes the NHibernate session used by the OpenIddict stores. + /// + public interface IOpenIddictNHibernateContext + { + /// + /// Gets the . + /// + /// + /// A that can be used to monitor the + /// asynchronous operation, whose result returns the NHibernate session. + /// + ValueTask GetSessionAsync(CancellationToken cancellationToken); + } +} diff --git a/src/OpenIddict.NHibernate/Mappings/OpenIddictApplicationMapping.cs b/src/OpenIddict.NHibernate/Mappings/OpenIddictApplicationMapping.cs new file mode 100644 index 00000000..4e89af34 --- /dev/null +++ b/src/OpenIddict.NHibernate/Mappings/OpenIddictApplicationMapping.cs @@ -0,0 +1,101 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.ComponentModel; +using NHibernate.Mapping.ByCode; +using NHibernate.Mapping.ByCode.Conformist; +using OpenIddict.NHibernate.Models; + +namespace OpenIddict.NHibernate +{ + /// + /// Defines a relational mapping for the Application entity. + /// + /// The type of the Application entity. + /// The type of the Authorization entity. + /// The type of the Token entity. + /// The type of the Key entity. + [EditorBrowsable(EditorBrowsableState.Never)] + public class OpenIddictApplicationMapping : ClassMapping + where TApplication : OpenIddictApplication + where TAuthorization : OpenIddictAuthorization + where TToken : OpenIddictToken + where TKey : IEquatable + { + public OpenIddictApplicationMapping() + { + Id(application => application.Id, map => + { + map.Generator(Generators.Identity); + }); + + Version(application => application.Version, map => + { + map.Insert(true); + }); + + Property(application => application.ClientId, map => + { + map.NotNullable(true); + map.Unique(true); + }); + + Property(application => application.ClientSecret); + + Property(application => application.ConsentType); + + Property(application => application.DisplayName); + + Property(application => application.Permissions, map => + { + map.Length(10000); + }); + + Property(application => application.PostLogoutRedirectUris, map => + { + map.Length(10000); + }); + + Property(application => application.Properties, map => + { + map.Length(10000); + }); + + Property(application => application.RedirectUris, map => + { + map.Length(10000); + }); + + Property(application => application.Type, map => + { + map.NotNullable(true); + }); + + Bag(application => application.Authorizations, + map => + { + map.Key(key => key.Column("ApplicationId")); + }, + map => + { + map.OneToMany(); + }); + + Bag(application => application.Tokens, + map => + { + map.Key(key => key.Column("ApplicationId")); + }, + map => + { + map.OneToMany(); + }); + + Table("OpenIddictApplications"); + } + } +} diff --git a/src/OpenIddict.NHibernate/Mappings/OpenIddictAuthorizationMapping.cs b/src/OpenIddict.NHibernate/Mappings/OpenIddictAuthorizationMapping.cs new file mode 100644 index 00000000..ddb5dcf6 --- /dev/null +++ b/src/OpenIddict.NHibernate/Mappings/OpenIddictAuthorizationMapping.cs @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.ComponentModel; +using NHibernate.Mapping.ByCode; +using NHibernate.Mapping.ByCode.Conformist; +using OpenIddict.NHibernate.Models; + +namespace OpenIddict.NHibernate +{ + /// + /// Defines a relational mapping for the Authorization entity. + /// + /// The type of the Authorization entity. + /// The type of the Application entity. + /// The type of the Token entity. + /// The type of the Key entity. + [EditorBrowsable(EditorBrowsableState.Never)] + public class OpenIddictAuthorizationMapping : ClassMapping + where TAuthorization : OpenIddictAuthorization + where TApplication : OpenIddictApplication + where TToken : OpenIddictToken + where TKey : IEquatable + { + public OpenIddictAuthorizationMapping() + { + Id(authorization => authorization.Id, map => + { + map.Generator(Generators.Identity); + }); + + Version(authorization => authorization.Version, map => + { + map.Insert(true); + }); + + Property(authorization => authorization.Properties, map => + { + map.Length(10000); + }); + + Property(authorization => authorization.Scopes, map => + { + map.Length(10000); + }); + + Property(authorization => authorization.Status, map => + { + map.NotNullable(true); + }); + + Property(authorization => authorization.Subject, map => + { + map.NotNullable(true); + }); + + Property(authorization => authorization.Type, map => + { + map.NotNullable(true); + }); + + ManyToOne(authorization => authorization.Application, map => + { + map.ForeignKey("ApplicationId"); + }); + + Bag(authorization => authorization.Tokens, + map => + { + map.Key(key => key.Column("AuthorizationId")); + }, + map => + { + map.OneToMany(); + }); + + Table("OpenIddictAuthorizations"); + } + } +} diff --git a/src/OpenIddict.NHibernate/Mappings/OpenIddictScopeMapping.cs b/src/OpenIddict.NHibernate/Mappings/OpenIddictScopeMapping.cs new file mode 100644 index 00000000..04dd769e --- /dev/null +++ b/src/OpenIddict.NHibernate/Mappings/OpenIddictScopeMapping.cs @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.ComponentModel; +using NHibernate.Mapping.ByCode; +using NHibernate.Mapping.ByCode.Conformist; +using OpenIddict.NHibernate.Models; + +namespace OpenIddict.NHibernate +{ + /// + /// Defines a relational mapping for the Scope entity. + /// + /// The type of the Scope entity. + /// The type of the Key entity. + [EditorBrowsable(EditorBrowsableState.Never)] + public class OpenIddictScopeMapping : ClassMapping + where TScope : OpenIddictScope + where TKey : IEquatable + { + public OpenIddictScopeMapping() + { + Id(scope => scope.Id, map => + { + map.Generator(Generators.Identity); + }); + + Version(scope => scope.Version, map => + { + map.Insert(true); + }); + + Property(scope => scope.Description, map => + { + map.Length(10000); + }); + + Property(scope => scope.DisplayName); + + Property(scope => scope.Name, map => + { + map.NotNullable(true); + map.Unique(true); + }); + + Property(scope => scope.Properties, map => + { + map.Length(10000); + }); + + Property(scope => scope.Resources, map => + { + map.Length(10000); + }); + + Table("OpenIddictScopes"); + } + } +} diff --git a/src/OpenIddict.NHibernate/Mappings/OpenIddictTokenMapping.cs b/src/OpenIddict.NHibernate/Mappings/OpenIddictTokenMapping.cs new file mode 100644 index 00000000..6bc9c798 --- /dev/null +++ b/src/OpenIddict.NHibernate/Mappings/OpenIddictTokenMapping.cs @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.ComponentModel; +using NHibernate.Mapping.ByCode; +using NHibernate.Mapping.ByCode.Conformist; +using OpenIddict.NHibernate.Models; + +namespace OpenIddict.NHibernate +{ + /// + /// Defines a relational mapping for the Token entity. + /// + /// The type of the Token entity. + /// The type of the Application entity. + /// The type of the Authorization entity. + /// The type of the Key entity. + [EditorBrowsable(EditorBrowsableState.Never)] + public class OpenIddictTokenMapping : ClassMapping + where TToken : OpenIddictToken + where TApplication : OpenIddictApplication + where TAuthorization : OpenIddictAuthorization + where TKey : IEquatable + { + public OpenIddictTokenMapping() + { + Id(token => token.Id, map => + { + map.Generator(Generators.Identity); + }); + + Version(token => token.Version, map => + { + map.Insert(true); + }); + + Property(token => token.CreationDate); + + Property(token => token.ExpirationDate); + + Property(token => token.Payload, map => + { + map.Length(10000); + }); + + Property(token => token.Properties, map => + { + map.Length(10000); + }); + + Property(token => token.ReferenceId); + + Property(token => token.Status, map => + { + map.NotNullable(true); + }); + + Property(token => token.Subject, map => + { + map.NotNullable(true); + }); + + Property(token => token.Type, map => + { + map.NotNullable(true); + }); + + ManyToOne(token => token.Application, map => + { + map.Column("ApplicationId"); + }); + + ManyToOne(token => token.Authorization, map => + { + map.Column("AuthorizationId"); + }); + + Table("OpenIddictTokens"); + } + } +} diff --git a/src/OpenIddict.NHibernate/OpenIddict.NHibernate.csproj b/src/OpenIddict.NHibernate/OpenIddict.NHibernate.csproj new file mode 100644 index 00000000..cb2c7c6b --- /dev/null +++ b/src/OpenIddict.NHibernate/OpenIddict.NHibernate.csproj @@ -0,0 +1,29 @@ + + + + + + netstandard2.0 + + + + NHibernate 5.x stores for OpenIddict. + Kévin Chalet + aspnetcore;authentication;jwt;openidconnect;openiddict;security + + + + + + + + + + + + + + + + + diff --git a/src/OpenIddict.NHibernate/OpenIddictNHibernateBuilder.cs b/src/OpenIddict.NHibernate/OpenIddictNHibernateBuilder.cs new file mode 100644 index 00000000..5bb79dfe --- /dev/null +++ b/src/OpenIddict.NHibernate/OpenIddictNHibernateBuilder.cs @@ -0,0 +1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.ComponentModel; +using JetBrains.Annotations; +using NHibernate; +using OpenIddict.Core; +using OpenIddict.NHibernate; +using OpenIddict.NHibernate.Models; + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Exposes the necessary methods required to configure the OpenIddict NHibernate services. + /// + public class OpenIddictNHibernateBuilder + { + /// + /// Initializes a new instance of . + /// + /// The services collection. + public OpenIddictNHibernateBuilder([NotNull] IServiceCollection services) + { + if (services == null) + { + throw new ArgumentNullException(nameof(services)); + } + + Services = services; + } + + /// + /// Gets the services collection. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public IServiceCollection Services { get; } + + /// + /// Amends the default OpenIddict NHibernate configuration. + /// + /// The delegate used to configure the OpenIddict options. + /// This extension can be safely called multiple times. + /// The . + public OpenIddictNHibernateBuilder Configure([NotNull] Action configuration) + { + if (configuration == null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + Services.Configure(configuration); + + return this; + } + + /// + /// Configures the NHibernate stores to use the specified session factory + /// instead of retrieving it from the dependency injection container. + /// + /// The . + /// The . + public OpenIddictNHibernateBuilder UseSessionFactory([NotNull] ISessionFactory factory) + { + if (factory == null) + { + throw new ArgumentNullException(nameof(factory)); + } + + return Configure(options => options.SessionFactory = factory); + } + + /// + /// Configures OpenIddict to use the default OpenIddict Entity Framework entities, with the specified key type. + /// + /// The . + public OpenIddictNHibernateBuilder ReplaceDefaultEntities() + where TKey : IEquatable + => ReplaceDefaultEntities, + OpenIddictAuthorization, + OpenIddictScope, + OpenIddictToken, TKey>(); + + /// + /// Configures OpenIddict to use the specified entities, derived from the default OpenIddict Entity Framework entities. + /// + /// The . + public OpenIddictNHibernateBuilder ReplaceDefaultEntities() + where TApplication : OpenIddictApplication + where TAuthorization : OpenIddictAuthorization + where TScope : OpenIddictScope + where TToken : OpenIddictToken + where TKey : IEquatable + { + Services.Configure(options => + { + options.DefaultApplicationType = typeof(TApplication); + options.DefaultAuthorizationType = typeof(TAuthorization); + options.DefaultScopeType = typeof(TScope); + options.DefaultTokenType = typeof(TToken); + }); + + return this; + } + } +} diff --git a/src/OpenIddict.NHibernate/OpenIddictNHibernateContext.cs b/src/OpenIddict.NHibernate/OpenIddictNHibernateContext.cs new file mode 100644 index 00000000..92c84041 --- /dev/null +++ b/src/OpenIddict.NHibernate/OpenIddictNHibernateContext.cs @@ -0,0 +1,125 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using JetBrains.Annotations; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using NHibernate; + +namespace OpenIddict.NHibernate +{ + /// + /// Exposes the NHibernate session used by the OpenIddict stores. + /// + public class OpenIddictNHibernateContext : IOpenIddictNHibernateContext, IDisposable + { + private readonly IOptionsMonitor _options; + private readonly IServiceProvider _provider; + private ISession _session; + + public OpenIddictNHibernateContext( + [NotNull] IOptionsMonitor options, + [NotNull] IServiceProvider provider) + { + _options = options; + _provider = provider; + } + + /// + /// Disposes the session held by this instance, if applicable. + /// + public void Dispose() => _session?.Dispose(); + + /// + /// Gets the . + /// + /// + /// A that can be used to monitor the + /// asynchronous operation, whose result returns the NHibernate session. + /// + /// + /// If a session factory was explicitly set in the OpenIddict NHibernate options, + /// a new session, specific to the OpenIddict stores is automatically opened + /// and disposed when the ambient scope is collected. If no session factory + /// was set, the session is retrieved from the dependency injection container + /// and a derived instance disabling automatic flush is managed by the context. + /// + public ValueTask GetSessionAsync(CancellationToken cancellationToken) + { + if (_session != null) + { + return new ValueTask(_session); + } + + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + var options = _options.CurrentValue; + if (options == null) + { + throw new InvalidOperationException("The OpenIddict NHibernate options cannot be retrieved."); + } + + // Note: by default, NHibernate is natively configured to perform automatic flushes + // on queries when it determines stale data may be returned during their execution. + // Combined with implicit entity updates, this feature is inconvenient for OpenIddict + // as it may result in updated entities being persisted before they are explicitly + // validated by the core managers and marked as updated by the NHibernate stores. + // To ensure this doesn't interfere with OpenIddict, automatic flush is disabled. + + var factory = options.SessionFactory; + if (factory == null) + { + var session = _provider.GetService(); + if (session != null) + { + // If the flush mode is already set to manual, avoid creating a sub-session. + // If the session must be derived, all the parameters are inherited from + // the original session (except the flush mode, explicitly set to manual). + if (session.FlushMode != FlushMode.Manual) + { + session = _session = session.SessionWithOptions() + .AutoClose() + .AutoJoinTransaction() + .Connection() + .ConnectionReleaseMode() + .FlushMode(FlushMode.Manual) + .Interceptor() + .OpenSession(); + } + + return new ValueTask(session); + } + + factory = _provider.GetService(); + } + + if (factory == null) + { + throw new InvalidOperationException(new StringBuilder() + .AppendLine("No suitable NHibernate session or session factory can be found.") + .Append("To configure the OpenIddict NHibernate stores to use a specific factory, use ") + .Append("'services.AddOpenIddict().AddCore().UseNHibernate().UseSessionFactory()' or register an ") + .Append("'ISession'/'ISessionFactory' in the dependency injection container in 'ConfigureServices()'.") + .ToString()); + } + + else + { + var session = factory.OpenSession(); + session.FlushMode = FlushMode.Manual; + + return new ValueTask(_session = session); + } + } + } +} diff --git a/src/OpenIddict.NHibernate/OpenIddictNHibernateExtensions.cs b/src/OpenIddict.NHibernate/OpenIddictNHibernateExtensions.cs new file mode 100644 index 00000000..99ea9061 --- /dev/null +++ b/src/OpenIddict.NHibernate/OpenIddictNHibernateExtensions.cs @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using JetBrains.Annotations; +using Microsoft.Extensions.DependencyInjection.Extensions; +using OpenIddict.NHibernate; +using OpenIddict.NHibernate.Models; + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Exposes extensions allowing to register the OpenIddict NHibernate services. + /// + public static class OpenIddictNHibernateExtensions + { + /// + /// Registers the NHibernate stores services in the DI container and + /// configures OpenIddict to use the NHibernate entities by default. + /// + /// The services builder used by OpenIddict to register new services. + /// This extension can be safely called multiple times. + /// The . + public static OpenIddictNHibernateBuilder UseNHibernate([NotNull] this OpenIddictCoreBuilder builder) + { + if (builder == null) + { + throw new ArgumentNullException(nameof(builder)); + } + + // Since NHibernate may be used with databases performing case-insensitive or + // culture-sensitive comparisons, ensure the additional filtering logic is enforced + // in case case-sensitive stores were registered before this extension was called. + builder.Configure(options => options.DisableAdditionalFiltering = false); + + builder.SetDefaultApplicationEntity() + .SetDefaultAuthorizationEntity() + .SetDefaultScopeEntity() + .SetDefaultTokenEntity(); + + builder.ReplaceApplicationStoreResolver() + .ReplaceAuthorizationStoreResolver() + .ReplaceScopeStoreResolver() + .ReplaceTokenStoreResolver(); + + builder.Services.TryAddScoped(typeof(OpenIddictApplicationStore<,,,>)); + builder.Services.TryAddScoped(typeof(OpenIddictAuthorizationStore<,,,>)); + builder.Services.TryAddScoped(typeof(OpenIddictScopeStore<,>)); + builder.Services.TryAddScoped(typeof(OpenIddictTokenStore<,,,>)); + + builder.Services.TryAddScoped(); + + return new OpenIddictNHibernateBuilder(builder.Services); + } + + /// + /// Registers the NHibernate stores services in the DI container and + /// configures OpenIddict to use the NHibernate entities by default. + /// + /// The services builder used by OpenIddict to register new services. + /// The configuration delegate used to configure the NHibernate services. + /// This extension can be safely called multiple times. + /// The . + public static OpenIddictCoreBuilder UseNHibernate( + [NotNull] this OpenIddictCoreBuilder builder, + [NotNull] Action configuration) + { + if (builder == null) + { + throw new ArgumentNullException(nameof(builder)); + } + + if (configuration == null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + configuration(builder.UseNHibernate()); + + return builder; + } + } +} diff --git a/src/OpenIddict.NHibernate/OpenIddictNHibernateHelpers.cs b/src/OpenIddict.NHibernate/OpenIddictNHibernateHelpers.cs new file mode 100644 index 00000000..3d6d39e9 --- /dev/null +++ b/src/OpenIddict.NHibernate/OpenIddictNHibernateHelpers.cs @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using JetBrains.Annotations; +using NHibernate.Mapping.ByCode; +using OpenIddict.NHibernate; +using OpenIddict.NHibernate.Models; + +namespace NHibernate.Cfg +{ + /// + /// Exposes extensions allowing to register the OpenIddict NHibernate mappings. + /// + public static class OpenIddictNHibernateHelpers + { + /// + /// Registers the OpenIddict entity mappings in the NHibernate + /// configuration using the default entities and the default key type. + /// + /// The NHibernate configuration builder. + /// The . + public static Configuration UseOpenIddict([NotNull] this Configuration configuration) + => configuration.UseOpenIddict(); + + /// + /// Registers the OpenIddict entity mappings in the NHibernate + /// configuration using the default entities and the specified key type. + /// + /// The NHibernate configuration builder. + /// The . + public static Configuration UseOpenIddict([NotNull] this Configuration configuration) + where TKey : IEquatable + => configuration.UseOpenIddict, + OpenIddictAuthorization, + OpenIddictScope, + OpenIddictToken, TKey>(); + + /// + /// Registers the OpenIddict entity mappings in the NHibernate + /// configuration using the specified entities and the specified key type. + /// + /// The NHibernate configuration builder. + /// The . + public static Configuration UseOpenIddict([NotNull] this Configuration configuration) + where TApplication : OpenIddictApplication + where TAuthorization : OpenIddictAuthorization + where TScope : OpenIddictScope + where TToken : OpenIddictToken + where TKey : IEquatable + { + if (configuration == null) + { + throw new ArgumentNullException(nameof(configuration)); + } + + var mapper = new ModelMapper(); + mapper.AddMapping>(); + mapper.AddMapping>(); + mapper.AddMapping>(); + mapper.AddMapping>(); + + configuration.AddMapping(mapper.CompileMappingForAllExplicitlyAddedEntities()); + + return configuration; + } + } +} diff --git a/src/OpenIddict.NHibernate/OpenIddictNHibernateOptions.cs b/src/OpenIddict.NHibernate/OpenIddictNHibernateOptions.cs new file mode 100644 index 00000000..5586c194 --- /dev/null +++ b/src/OpenIddict.NHibernate/OpenIddictNHibernateOptions.cs @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using NHibernate; + +namespace OpenIddict.NHibernate +{ + /// + /// Provides various settings needed to configure the OpenIddict NHibernate integration. + /// + public class OpenIddictNHibernateOptions + { + /// + /// Gets or sets the session factory used by the OpenIddict NHibernate stores. + /// If none is explicitly set, the session factory is resolved from the DI container. + /// + public ISessionFactory SessionFactory { get; set; } + } +} diff --git a/src/OpenIddict.NHibernate/Resolvers/OpenIddictApplicationStoreResolver.cs b/src/OpenIddict.NHibernate/Resolvers/OpenIddictApplicationStoreResolver.cs new file mode 100644 index 00000000..3d91cdad --- /dev/null +++ b/src/OpenIddict.NHibernate/Resolvers/OpenIddictApplicationStoreResolver.cs @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Collections.Concurrent; +using System.Text; +using JetBrains.Annotations; +using Microsoft.Extensions.DependencyInjection; +using OpenIddict.Abstractions; +using OpenIddict.Core; +using OpenIddict.Extensions; +using OpenIddict.NHibernate.Models; + +namespace OpenIddict.NHibernate +{ + /// + /// Exposes a method allowing to resolve an application store. + /// + public class OpenIddictApplicationStoreResolver : IOpenIddictApplicationStoreResolver + { + private static readonly ConcurrentDictionary _cache = new ConcurrentDictionary(); + private readonly IServiceProvider _provider; + + public OpenIddictApplicationStoreResolver([NotNull] IServiceProvider provider) + => _provider = provider; + + /// + /// Returns an application store compatible with the specified application type or throws an + /// if no store can be built using the specified type. + /// + /// The type of the Application entity. + /// An . + public IOpenIddictApplicationStore Get() where TApplication : class + { + var store = _provider.GetService>(); + if (store != null) + { + return store; + } + + var type = _cache.GetOrAdd(typeof(TApplication), key => + { + var root = OpenIddictHelpers.FindGenericBaseType(key, typeof(OpenIddictApplication<,,>)); + if (root == null) + { + throw new InvalidOperationException(new StringBuilder() + .AppendLine("The specified application type is not compatible with the NHibernate stores.") + .Append("When enabling the NHibernate stores, make sure you use the built-in ") + .Append("'OpenIddictApplication' entity (from the 'OpenIddict.NHibernate.Models' package) ") + .Append("or a custom entity that inherits from the generic 'OpenIddictApplication' entity.") + .ToString()); + } + + return typeof(OpenIddictApplicationStore<,,,>).MakeGenericType( + /* TApplication: */ key, + /* TAuthorization: */ root.GenericTypeArguments[1], + /* TToken: */ root.GenericTypeArguments[2], + /* TKey: */ root.GenericTypeArguments[0]); + }); + + return (IOpenIddictApplicationStore) _provider.GetRequiredService(type); + } + } +} diff --git a/src/OpenIddict.NHibernate/Resolvers/OpenIddictAuthorizationStoreResolver.cs b/src/OpenIddict.NHibernate/Resolvers/OpenIddictAuthorizationStoreResolver.cs new file mode 100644 index 00000000..4960e224 --- /dev/null +++ b/src/OpenIddict.NHibernate/Resolvers/OpenIddictAuthorizationStoreResolver.cs @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Collections.Concurrent; +using System.Text; +using JetBrains.Annotations; +using Microsoft.Extensions.DependencyInjection; +using OpenIddict.Abstractions; +using OpenIddict.Core; +using OpenIddict.Extensions; +using OpenIddict.NHibernate.Models; + +namespace OpenIddict.NHibernate +{ + /// + /// Exposes a method allowing to resolve an authorization store. + /// + public class OpenIddictAuthorizationStoreResolver : IOpenIddictAuthorizationStoreResolver + { + private static readonly ConcurrentDictionary _cache = new ConcurrentDictionary(); + private readonly IServiceProvider _provider; + + public OpenIddictAuthorizationStoreResolver([NotNull] IServiceProvider provider) + => _provider = provider; + + /// + /// Returns an authorization store compatible with the specified authorization type or throws an + /// if no store can be built using the specified type. + /// + /// The type of the Authorization entity. + /// An . + public IOpenIddictAuthorizationStore Get() where TAuthorization : class + { + var store = _provider.GetService>(); + if (store != null) + { + return store; + } + + var type = _cache.GetOrAdd(typeof(TAuthorization), key => + { + var root = OpenIddictHelpers.FindGenericBaseType(key, typeof(OpenIddictAuthorization<,,>)); + if (root == null) + { + throw new InvalidOperationException(new StringBuilder() + .AppendLine("The specified authorization type is not compatible with the NHibernate stores.") + .Append("When enabling the NHibernate stores, make sure you use the built-in ") + .Append("'OpenIddictAuthorization' entity (from the 'OpenIddict.NHibernate.Models' package) ") + .Append("or a custom entity that inherits from the generic 'OpenIddictAuthorization' entity.") + .ToString()); + } + + return typeof(OpenIddictAuthorizationStore<,,,>).MakeGenericType( + /* TAuthorization: */ key, + /* TApplication: */ root.GenericTypeArguments[1], + /* TToken: */ root.GenericTypeArguments[2], + /* TKey: */ root.GenericTypeArguments[0]); + }); + + return (IOpenIddictAuthorizationStore) _provider.GetRequiredService(type); + } + } +} diff --git a/src/OpenIddict.NHibernate/Resolvers/OpenIddictScopeStoreResolver.cs b/src/OpenIddict.NHibernate/Resolvers/OpenIddictScopeStoreResolver.cs new file mode 100644 index 00000000..cdf94b9c --- /dev/null +++ b/src/OpenIddict.NHibernate/Resolvers/OpenIddictScopeStoreResolver.cs @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Collections.Concurrent; +using System.Text; +using JetBrains.Annotations; +using Microsoft.Extensions.DependencyInjection; +using OpenIddict.Abstractions; +using OpenIddict.Core; +using OpenIddict.Extensions; +using OpenIddict.NHibernate.Models; + +namespace OpenIddict.NHibernate +{ + /// + /// Exposes a method allowing to resolve a scope store. + /// + public class OpenIddictScopeStoreResolver : IOpenIddictScopeStoreResolver + { + private static readonly ConcurrentDictionary _cache = new ConcurrentDictionary(); + private readonly IServiceProvider _provider; + + public OpenIddictScopeStoreResolver([NotNull] IServiceProvider provider) + => _provider = provider; + + /// + /// Returns a scope store compatible with the specified scope type or throws an + /// if no store can be built using the specified type. + /// + /// The type of the Scope entity. + /// An . + public IOpenIddictScopeStore Get() where TScope : class + { + var store = _provider.GetService>(); + if (store != null) + { + return store; + } + + var type = _cache.GetOrAdd(typeof(TScope), key => + { + var root = OpenIddictHelpers.FindGenericBaseType(key, typeof(OpenIddictScope<>)); + if (root == null) + { + throw new InvalidOperationException(new StringBuilder() + .AppendLine("The specified scope type is not compatible with the NHibernate stores.") + .Append("When enabling the NHibernate stores, make sure you use the built-in ") + .Append("'OpenIddictScope' entity (from the 'OpenIddict.NHibernate.Models' package) ") + .Append("or a custom entity that inherits from the generic 'OpenIddictScope' entity.") + .ToString()); + } + + return typeof(OpenIddictScopeStore<,>).MakeGenericType( + /* TScope: */ key, + /* TKey: */ root.GenericTypeArguments[0]); + }); + + return (IOpenIddictScopeStore) _provider.GetRequiredService(type); + } + } +} diff --git a/src/OpenIddict.NHibernate/Resolvers/OpenIddictTokenStoreResolver.cs b/src/OpenIddict.NHibernate/Resolvers/OpenIddictTokenStoreResolver.cs new file mode 100644 index 00000000..aeb52f4e --- /dev/null +++ b/src/OpenIddict.NHibernate/Resolvers/OpenIddictTokenStoreResolver.cs @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Collections.Concurrent; +using System.Text; +using JetBrains.Annotations; +using Microsoft.Extensions.DependencyInjection; +using OpenIddict.Abstractions; +using OpenIddict.Core; +using OpenIddict.Extensions; +using OpenIddict.NHibernate.Models; + +namespace OpenIddict.NHibernate +{ + /// + /// Exposes a method allowing to resolve a token store. + /// + public class OpenIddictTokenStoreResolver : IOpenIddictTokenStoreResolver + { + private static readonly ConcurrentDictionary _cache = new ConcurrentDictionary(); + private readonly IServiceProvider _provider; + + public OpenIddictTokenStoreResolver([NotNull] IServiceProvider provider) + => _provider = provider; + + /// + /// Returns a token store compatible with the specified token type or throws an + /// if no store can be built using the specified type. + /// + /// The type of the Token entity. + /// An . + public IOpenIddictTokenStore Get() where TToken : class + { + var store = _provider.GetService>(); + if (store != null) + { + return store; + } + + var type = _cache.GetOrAdd(typeof(TToken), key => + { + var root = OpenIddictHelpers.FindGenericBaseType(key, typeof(OpenIddictToken<,,>)); + if (root == null) + { + throw new InvalidOperationException(new StringBuilder() + .AppendLine("The specified token type is not compatible with the NHibernate stores.") + .Append("When enabling the NHibernate stores, make sure you use the built-in ") + .Append("'OpenIddictToken' entity (from the 'OpenIddict.NHibernate.Models' package) ") + .Append("or a custom entity that inherits from the generic 'OpenIddictToken' entity.") + .ToString()); + } + + return typeof(OpenIddictTokenStore<,,,>).MakeGenericType( + /* TToken: */ key, + /* TApplication: */ root.GenericTypeArguments[1], + /* TAuthorization: */ root.GenericTypeArguments[2], + /* TKey: */ root.GenericTypeArguments[0]); + }); + + return (IOpenIddictTokenStore) _provider.GetRequiredService(type); + } + } +} diff --git a/src/OpenIddict.NHibernate/Stores/OpenIddictApplicationStore.cs b/src/OpenIddict.NHibernate/Stores/OpenIddictApplicationStore.cs new file mode 100644 index 00000000..72579b40 --- /dev/null +++ b/src/OpenIddict.NHibernate/Stores/OpenIddictApplicationStore.cs @@ -0,0 +1,982 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Collections.Immutable; +using System.ComponentModel; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using JetBrains.Annotations; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using NHibernate; +using NHibernate.Linq; +using OpenIddict.Abstractions; +using OpenIddict.NHibernate.Models; + +namespace OpenIddict.NHibernate +{ + /// + /// Provides methods allowing to manage the applications stored in a database. + /// + public class OpenIddictApplicationStore : OpenIddictApplicationStore + { + public OpenIddictApplicationStore( + [NotNull] IMemoryCache cache, + [NotNull] IOpenIddictNHibernateContext context, + [NotNull] IOptionsMonitor options) + : base(cache, context, options) + { + } + } + + /// + /// Provides methods allowing to manage the applications stored in a database. + /// + /// The type of the entity primary keys. + public class OpenIddictApplicationStore : OpenIddictApplicationStore, + OpenIddictAuthorization, + OpenIddictToken, TKey> + where TKey : IEquatable + { + public OpenIddictApplicationStore( + [NotNull] IMemoryCache cache, + [NotNull] IOpenIddictNHibernateContext context, + [NotNull] IOptionsMonitor options) + : base(cache, context, options) + { + } + } + + /// + /// Provides methods allowing to manage the applications stored in a database. + /// + /// The type of the Application entity. + /// The type of the Authorization entity. + /// The type of the Token entity. + /// The type of the entity primary keys. + public class OpenIddictApplicationStore : IOpenIddictApplicationStore + where TApplication : OpenIddictApplication + where TAuthorization : OpenIddictAuthorization + where TToken : OpenIddictToken + where TKey : IEquatable + { + public OpenIddictApplicationStore( + [NotNull] IMemoryCache cache, + [NotNull] IOpenIddictNHibernateContext context, + [NotNull] IOptionsMonitor options) + { + Cache = cache; + Context = context; + Options = options; + } + + /// + /// Gets the memory cache associated with the current store. + /// + protected IMemoryCache Cache { get; } + + /// + /// Gets the database context associated with the current store. + /// + protected IOpenIddictNHibernateContext Context { get; } + + /// + /// Gets the options associated with the current store. + /// + protected IOptionsMonitor Options { get; } + + /// + /// Determines the number of applications that exist in the database. + /// + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the number of applications in the database. + /// + public virtual async Task CountAsync(CancellationToken cancellationToken) + { + var session = await Context.GetSessionAsync(cancellationToken); + return await session.Query().LongCountAsync(cancellationToken); + } + + /// + /// Determines the number of applications that match the specified query. + /// + /// The result type. + /// The query to execute. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the number of applications that match the specified query. + /// + public virtual async Task CountAsync([NotNull] Func, IQueryable> query, CancellationToken cancellationToken) + { + if (query == null) + { + throw new ArgumentNullException(nameof(query)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + return await query(session.Query()).LongCountAsync(cancellationToken); + } + + /// + /// Creates a new application. + /// + /// The application to create. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task CreateAsync([NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + await session.PersistAsync(application, cancellationToken); + await session.FlushAsync(cancellationToken); + } + + /// + /// Removes an existing application. + /// + /// The application to delete. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task DeleteAsync([NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + try + { + // Delete all the tokens associated with the application. + await (from authorization in session.Query() + where authorization.Application.Id.Equals(application.Id) + select authorization).DeleteAsync(cancellationToken); + + // Delete all the tokens associated with the application. + await (from token in session.Query() + where token.Application.Id.Equals(application.Id) + select token).DeleteAsync(cancellationToken); + + await session.DeleteAsync(application, cancellationToken); + await session.FlushAsync(cancellationToken); + } + + catch (StaleObjectStateException exception) + { + throw new OpenIddictExceptions.ConcurrencyException(new StringBuilder() + .AppendLine("The application was concurrently updated and cannot be persisted in its current state.") + .Append("Reload the application from the database and retry the operation.") + .ToString(), exception); + } + } + + /// + /// Retrieves an application using its client identifier. + /// + /// The client identifier associated with the application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the client application corresponding to the identifier. + /// + public virtual async Task FindByClientIdAsync([NotNull] string identifier, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(identifier)) + { + throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + return await (from application in session.Query() + where application.ClientId == identifier + select application).FirstOrDefaultAsync(cancellationToken); + } + + /// + /// Retrieves an application using its unique identifier. + /// + /// The unique identifier associated with the application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the client application corresponding to the identifier. + /// + public virtual async Task FindByIdAsync([NotNull] string identifier, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(identifier)) + { + throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + return await session.GetAsync(ConvertIdentifierFromString(identifier), cancellationToken); + } + + /// + /// Retrieves all the applications associated with the specified post_logout_redirect_uri. + /// + /// The post_logout_redirect_uri associated with the applications. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, whose result + /// returns the client applications corresponding to the specified post_logout_redirect_uri. + /// + public virtual async Task> FindByPostLogoutRedirectUriAsync([NotNull] string address, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(address)) + { + throw new ArgumentException("The address cannot be null or empty.", nameof(address)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + // To optimize the efficiency of the query a bit, only applications whose stringified + // PostLogoutRedirectUris contains the specified URL are returned. Once the applications + // are retrieved, a second pass is made to ensure only valid elements are returned. + // Implementers that use this method in a hot path may want to override this method + // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. + var applications = await (from application in session.Query() + where application.PostLogoutRedirectUris.Contains(address) + select application).ToListAsync(cancellationToken); + + var builder = ImmutableArray.CreateBuilder(); + + foreach (var application in applications) + { + foreach (var uri in await GetPostLogoutRedirectUrisAsync(application, cancellationToken)) + { + // Note: the post_logout_redirect_uri must be compared + // using case-sensitive "Simple String Comparison". + if (string.Equals(uri, address, StringComparison.Ordinal)) + { + builder.Add(application); + + break; + } + } + } + + return builder.Count == builder.Capacity ? + builder.MoveToImmutable() : + builder.ToImmutable(); + } + + /// + /// Retrieves all the applications associated with the specified redirect_uri. + /// + /// The redirect_uri associated with the applications. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, whose result + /// returns the client applications corresponding to the specified redirect_uri. + /// + public virtual async Task> FindByRedirectUriAsync([NotNull] string address, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(address)) + { + throw new ArgumentException("The address cannot be null or empty.", nameof(address)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + // To optimize the efficiency of the query a bit, only applications whose stringified + // RedirectUris property contains the specified URL are returned. Once the applications + // are retrieved, a second pass is made to ensure only valid elements are returned. + // Implementers that use this method in a hot path may want to override this method + // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. + var applications = await (from application in session.Query() + where application.RedirectUris.Contains(address) + select application).ToListAsync(cancellationToken); + + var builder = ImmutableArray.CreateBuilder(); + + foreach (var application in applications) + { + foreach (var uri in await GetRedirectUrisAsync(application, cancellationToken)) + { + // Note: the redirect_uri must be compared using case-sensitive "Simple String Comparison". + // See http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest for more information. + if (string.Equals(uri, address, StringComparison.Ordinal)) + { + builder.Add(application); + + break; + } + } + } + + return builder.Count == builder.Capacity ? + builder.MoveToImmutable() : + builder.ToImmutable(); + } + + /// + /// Executes the specified query and returns the first element. + /// + /// The state type. + /// The result type. + /// The query to execute. + /// The optional state. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the first element returned when executing the query. + /// + public virtual async Task GetAsync( + [NotNull] Func, TState, IQueryable> query, + [CanBeNull] TState state, CancellationToken cancellationToken) + { + if (query == null) + { + throw new ArgumentNullException(nameof(query)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + return await query(session.Query(), state).FirstOrDefaultAsync(cancellationToken); + } + + /// + /// Retrieves the client identifier associated with an application. + /// + /// The application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the client identifier associated with the application. + /// + public virtual ValueTask GetClientIdAsync([NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + return new ValueTask(application.ClientId); + } + + /// + /// Retrieves the client secret associated with an application. + /// Note: depending on the manager used to create the application, + /// the client secret may be hashed for security reasons. + /// + /// The application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the client secret associated with the application. + /// + public virtual ValueTask GetClientSecretAsync([NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + return new ValueTask(application.ClientSecret); + } + + /// + /// Retrieves the client type associated with an application. + /// + /// The application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the client type of the application (by default, "public"). + /// + public virtual ValueTask GetClientTypeAsync([NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + return new ValueTask(application.Type); + } + + /// + /// Retrieves the consent type associated with an application. + /// + /// The application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the consent type of the application (by default, "explicit"). + /// + public virtual ValueTask GetConsentTypeAsync([NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + return new ValueTask(application.ConsentType); + } + + /// + /// Retrieves the display name associated with an application. + /// + /// The application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the display name associated with the application. + /// + public virtual ValueTask GetDisplayNameAsync([NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + return new ValueTask(application.DisplayName); + } + + /// + /// Retrieves the unique identifier associated with an application. + /// + /// The application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the unique identifier associated with the application. + /// + public virtual ValueTask GetIdAsync([NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + return new ValueTask(ConvertIdentifierToString(application.Id)); + } + + /// + /// Retrieves the permissions associated with an application. + /// + /// The application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the permissions associated with the application. + /// + public virtual ValueTask> GetPermissionsAsync([NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + if (string.IsNullOrEmpty(application.Permissions)) + { + return new ValueTask>(ImmutableArray.Create()); + } + + // Note: parsing the stringified permissions is an expensive operation. + // To mitigate that, the resulting array is stored in the memory cache. + var key = string.Concat("0347e0aa-3a26-410a-97e8-a83bdeb21a1f", "\x1e", application.Permissions); + var permissions = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.High) + .SetSlidingExpiration(TimeSpan.FromMinutes(1)); + + return JArray.Parse(application.Permissions) + .Select(element => (string) element) + .ToImmutableArray(); + }); + + return new ValueTask>(permissions); + } + + /// + /// Retrieves the logout callback addresses associated with an application. + /// + /// The application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the post_logout_redirect_uri associated with the application. + /// + public virtual ValueTask> GetPostLogoutRedirectUrisAsync([NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + if (string.IsNullOrEmpty(application.PostLogoutRedirectUris)) + { + return new ValueTask>(ImmutableArray.Create()); + } + + // Note: parsing the stringified addresses is an expensive operation. + // To mitigate that, the resulting array is stored in the memory cache. + var key = string.Concat("fb14dfb9-9216-4b77-bfa9-7e85f8201ff4", "\x1e", application.PostLogoutRedirectUris); + var addresses = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.High) + .SetSlidingExpiration(TimeSpan.FromMinutes(1)); + + return JArray.Parse(application.PostLogoutRedirectUris) + .Select(element => (string) element) + .ToImmutableArray(); + }); + + return new ValueTask>(addresses); + } + + /// + /// Retrieves the additional properties associated with an application. + /// + /// The application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the additional properties associated with the application. + /// + public virtual ValueTask GetPropertiesAsync([NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + if (string.IsNullOrEmpty(application.Properties)) + { + return new ValueTask(new JObject()); + } + + return new ValueTask(JObject.Parse(application.Properties)); + } + + /// + /// Retrieves the callback addresses associated with an application. + /// + /// The application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the redirect_uri associated with the application. + /// + public virtual ValueTask> GetRedirectUrisAsync([NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + if (string.IsNullOrEmpty(application.RedirectUris)) + { + return new ValueTask>(ImmutableArray.Create()); + } + + // Note: parsing the stringified addresses is an expensive operation. + // To mitigate that, the resulting array is stored in the memory cache. + var key = string.Concat("851d6f08-2ee0-4452-bbe5-ab864611ecaa", "\x1e", application.RedirectUris); + var addresses = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.High) + .SetSlidingExpiration(TimeSpan.FromMinutes(1)); + + return JArray.Parse(application.RedirectUris) + .Select(element => (string) element) + .ToImmutableArray(); + }); + + return new ValueTask>(addresses); + } + + /// + /// Instantiates a new application. + /// + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the instantiated application, that can be persisted in the database. + /// + public virtual ValueTask InstantiateAsync(CancellationToken cancellationToken) + { + try + { + return new ValueTask(Activator.CreateInstance()); + } + + catch (MemberAccessException exception) + { + return new ValueTask(Task.FromException( + new InvalidOperationException(new StringBuilder() + .AppendLine("An error occurred while trying to create a new application instance.") + .Append("Make sure that the application entity is not abstract and has a public parameterless constructor ") + .Append("or create a custom application store that overrides 'InstantiateAsync()' to use a custom factory.") + .ToString(), exception))); + } + } + + /// + /// Executes the specified query and returns all the corresponding elements. + /// + /// The number of results to return. + /// The number of results to skip. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the elements returned when executing the specified query. + /// + public virtual async Task> ListAsync( + [CanBeNull] int? count, [CanBeNull] int? offset, CancellationToken cancellationToken) + { + var session = await Context.GetSessionAsync(cancellationToken); + var query = session.Query() + .OrderBy(application => application.Id) + .AsQueryable(); + + if (offset.HasValue) + { + query = query.Skip(offset.Value); + } + + if (count.HasValue) + { + query = query.Take(count.Value); + } + + return ImmutableArray.CreateRange(await query.ToListAsync(cancellationToken)); + } + + /// + /// Executes the specified query and returns all the corresponding elements. + /// + /// The state type. + /// The result type. + /// The query to execute. + /// The optional state. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the elements returned when executing the specified query. + /// + public virtual async Task> ListAsync( + [NotNull] Func, TState, IQueryable> query, + [CanBeNull] TState state, CancellationToken cancellationToken) + { + if (query == null) + { + throw new ArgumentNullException(nameof(query)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + return ImmutableArray.CreateRange(await query(session.Query(), state).ToListAsync(cancellationToken)); + } + + /// + /// Sets the client identifier associated with an application. + /// + /// The application. + /// The client identifier associated with the application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetClientIdAsync([NotNull] TApplication application, + [CanBeNull] string identifier, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + application.ClientId = identifier; + + return Task.CompletedTask; + } + + /// + /// Sets the client secret associated with an application. + /// Note: depending on the manager used to create the application, + /// the client secret may be hashed for security reasons. + /// + /// The application. + /// The client secret associated with the application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetClientSecretAsync([NotNull] TApplication application, + [CanBeNull] string secret, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + application.ClientSecret = secret; + + return Task.CompletedTask; + } + + /// + /// Sets the client type associated with an application. + /// + /// The application. + /// The client type associated with the application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetClientTypeAsync([NotNull] TApplication application, + [CanBeNull] string type, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + application.Type = type; + + return Task.CompletedTask; + } + + /// + /// Sets the consent type associated with an application. + /// + /// The application. + /// The consent type associated with the application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetConsentTypeAsync([NotNull] TApplication application, + [CanBeNull] string type, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + application.ConsentType = type; + + return Task.CompletedTask; + } + + /// + /// Sets the display name associated with an application. + /// + /// The application. + /// The display name associated with the application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetDisplayNameAsync([NotNull] TApplication application, + [CanBeNull] string name, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + application.DisplayName = name; + + return Task.CompletedTask; + } + + /// + /// Sets the permissions associated with an application. + /// + /// The application. + /// The permissions associated with the application + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetPermissionsAsync([NotNull] TApplication application, ImmutableArray permissions, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + if (permissions.IsDefaultOrEmpty) + { + application.Permissions = null; + + return Task.CompletedTask; + } + + application.Permissions = new JArray(permissions.ToArray()).ToString(Formatting.None); + + return Task.CompletedTask; + } + + /// + /// Sets the logout callback addresses associated with an application. + /// + /// The application. + /// The logout callback addresses associated with the application + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetPostLogoutRedirectUrisAsync([NotNull] TApplication application, + ImmutableArray addresses, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + if (addresses.IsDefaultOrEmpty) + { + application.PostLogoutRedirectUris = null; + + return Task.CompletedTask; + } + + application.PostLogoutRedirectUris = new JArray(addresses.ToArray()).ToString(Formatting.None); + + return Task.CompletedTask; + } + + /// + /// Sets the additional properties associated with an application. + /// + /// The application. + /// The additional properties associated with the application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetPropertiesAsync([NotNull] TApplication application, [CanBeNull] JObject properties, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + if (properties == null) + { + application.Properties = null; + + return Task.CompletedTask; + } + + application.Properties = properties.ToString(Formatting.None); + + return Task.CompletedTask; + } + + /// + /// Sets the callback addresses associated with an application. + /// + /// The application. + /// The callback addresses associated with the application + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetRedirectUrisAsync([NotNull] TApplication application, + ImmutableArray addresses, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + if (addresses.IsDefaultOrEmpty) + { + application.RedirectUris = null; + + return Task.CompletedTask; + } + + application.RedirectUris = new JArray(addresses.ToArray()).ToString(Formatting.None); + + return Task.CompletedTask; + } + + /// + /// Updates an existing application. + /// + /// The application to update. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task UpdateAsync([NotNull] TApplication application, CancellationToken cancellationToken) + { + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + try + { + await session.UpdateAsync(application, cancellationToken); + await session.FlushAsync(cancellationToken); + } + + catch (StaleObjectStateException exception) + { + throw new OpenIddictExceptions.ConcurrencyException(new StringBuilder() + .AppendLine("The application was concurrently updated and cannot be persisted in its current state.") + .Append("Reload the application from the database and retry the operation.") + .ToString(), exception); + } + } + + /// + /// Converts the provided identifier to a strongly typed key object. + /// + /// The identifier to convert. + /// An instance of representing the provided identifier. + public virtual TKey ConvertIdentifierFromString([CanBeNull] string identifier) + { + if (string.IsNullOrEmpty(identifier)) + { + return default; + } + + return (TKey) TypeDescriptor.GetConverter(typeof(TKey)).ConvertFromInvariantString(identifier); + } + + /// + /// Converts the provided identifier to its string representation. + /// + /// The identifier to convert. + /// A representation of the provided identifier. + public virtual string ConvertIdentifierToString([CanBeNull] TKey identifier) + { + if (Equals(identifier, default(TKey))) + { + return null; + } + + return TypeDescriptor.GetConverter(typeof(TKey)).ConvertToInvariantString(identifier); + } + } +} \ No newline at end of file diff --git a/src/OpenIddict.NHibernate/Stores/OpenIddictAuthorizationStore.cs b/src/OpenIddict.NHibernate/Stores/OpenIddictAuthorizationStore.cs new file mode 100644 index 00000000..f96a3ecb --- /dev/null +++ b/src/OpenIddict.NHibernate/Stores/OpenIddictAuthorizationStore.cs @@ -0,0 +1,935 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Collections.Immutable; +using System.ComponentModel; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using JetBrains.Annotations; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using NHibernate; +using NHibernate.Linq; +using OpenIddict.Abstractions; +using OpenIddict.NHibernate.Models; + +namespace OpenIddict.NHibernate +{ + /// + /// Provides methods allowing to manage the authorizations stored in a database. + /// + public class OpenIddictAuthorizationStore : OpenIddictAuthorizationStore + { + public OpenIddictAuthorizationStore( + [NotNull] IMemoryCache cache, + [NotNull] IOpenIddictNHibernateContext context, + [NotNull] IOptionsMonitor options) + : base(cache, context, options) + { + } + } + + /// + /// Provides methods allowing to manage the authorizations stored in a database. + /// + /// The type of the entity primary keys. + public class OpenIddictAuthorizationStore : OpenIddictAuthorizationStore, + OpenIddictApplication, + OpenIddictToken, TKey> + where TKey : IEquatable + { + public OpenIddictAuthorizationStore( + [NotNull] IMemoryCache cache, + [NotNull] IOpenIddictNHibernateContext context, + [NotNull] IOptionsMonitor options) + : base(cache, context, options) + { + } + } + + /// + /// Provides methods allowing to manage the authorizations stored in a database. + /// + /// The type of the Authorization entity. + /// The type of the Application entity. + /// The type of the Token entity. + /// The type of the entity primary keys. + public class OpenIddictAuthorizationStore : IOpenIddictAuthorizationStore + where TAuthorization : OpenIddictAuthorization + where TApplication : OpenIddictApplication + where TToken : OpenIddictToken + where TKey : IEquatable + { + public OpenIddictAuthorizationStore( + [NotNull] IMemoryCache cache, + [NotNull] IOpenIddictNHibernateContext context, + [NotNull] IOptionsMonitor options) + { + Cache = cache; + Context = context; + Options = options; + } + + /// + /// Gets the memory cache associated with the current store. + /// + protected IMemoryCache Cache { get; } + + /// + /// Gets the database context associated with the current store. + /// + protected IOpenIddictNHibernateContext Context { get; } + + /// + /// Gets the options associated with the current store. + /// + protected IOptionsMonitor Options { get; } + + /// + /// Determines the number of authorizations that exist in the database. + /// + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the number of authorizations in the database. + /// + public virtual async Task CountAsync(CancellationToken cancellationToken) + { + var session = await Context.GetSessionAsync(cancellationToken); + return await session.Query().LongCountAsync(cancellationToken); + } + + /// + /// Determines the number of authorizations that match the specified query. + /// + /// The result type. + /// The query to execute. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the number of authorizations that match the specified query. + /// + public virtual async Task CountAsync([NotNull] Func, IQueryable> query, CancellationToken cancellationToken) + { + if (query == null) + { + throw new ArgumentNullException(nameof(query)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + return await query(session.Query()).LongCountAsync(cancellationToken); + } + + /// + /// Creates a new authorization. + /// + /// The authorization to create. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task CreateAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + await session.SaveAsync(authorization, cancellationToken); + await session.FlushAsync(cancellationToken); + } + + /// + /// Removes an existing authorization. + /// + /// The authorization to delete. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task DeleteAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + try + { + // Delete all the tokens associated with the authorization. + await (from token in session.Query() + where token.Authorization.Id.Equals(authorization.Id) + select token).DeleteAsync(cancellationToken); + + await session.DeleteAsync(authorization, cancellationToken); + await session.FlushAsync(cancellationToken); + } + + catch (StaleObjectStateException exception) + { + throw new OpenIddictExceptions.ConcurrencyException(new StringBuilder() + .AppendLine("The authorization was concurrently updated and cannot be persisted in its current state.") + .Append("Reload the authorization from the database and retry the operation.") + .ToString(), exception); + } + } + + /// + /// Retrieves the authorizations corresponding to the specified + /// subject and associated with the application identifier. + /// + /// The subject associated with the authorization. + /// The client associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the authorizations corresponding to the subject/client. + /// + public virtual async Task> FindAsync( + [NotNull] string subject, [NotNull] string client, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(subject)) + { + throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); + } + + if (string.IsNullOrEmpty(client)) + { + throw new ArgumentException("The client cannot be null or empty.", nameof(client)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + var key = ConvertIdentifierFromString(client); + + return ImmutableArray.CreateRange( + await (from authorization in session.Query().Fetch(authorization => authorization.Application) + where authorization.Application != null && + authorization.Application.Id.Equals(key) && + authorization.Subject == subject + select authorization).ToListAsync(cancellationToken)); + } + + /// + /// Retrieves the authorizations matching the specified parameters. + /// + /// The subject associated with the authorization. + /// The client associated with the authorization. + /// The authorization status. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the authorizations corresponding to the criteria. + /// + public virtual async Task> FindAsync( + [NotNull] string subject, [NotNull] string client, + [NotNull] string status, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(subject)) + { + throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); + } + + if (string.IsNullOrEmpty(client)) + { + throw new ArgumentException("The client cannot be null or empty.", nameof(client)); + } + + if (string.IsNullOrEmpty(status)) + { + throw new ArgumentException("The status cannot be null or empty.", nameof(status)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + var key = ConvertIdentifierFromString(client); + + return ImmutableArray.CreateRange( + await (from authorization in session.Query().Fetch(authorization => authorization.Application) + where authorization.Application != null && + authorization.Application.Id.Equals(key) && + authorization.Subject == subject && + authorization.Status == status + select authorization).ToListAsync(cancellationToken)); + } + + /// + /// Retrieves the authorizations matching the specified parameters. + /// + /// The subject associated with the authorization. + /// The client associated with the authorization. + /// The authorization status. + /// The authorization type. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the authorizations corresponding to the criteria. + /// + public virtual async Task> FindAsync( + [NotNull] string subject, [NotNull] string client, + [NotNull] string status, [NotNull] string type, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(subject)) + { + throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); + } + + if (string.IsNullOrEmpty(client)) + { + throw new ArgumentException("The client identifier cannot be null or empty.", nameof(client)); + } + + if (string.IsNullOrEmpty(status)) + { + throw new ArgumentException("The status cannot be null or empty.", nameof(status)); + } + + if (string.IsNullOrEmpty(type)) + { + throw new ArgumentException("The type cannot be null or empty.", nameof(type)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + var key = ConvertIdentifierFromString(client); + + return ImmutableArray.CreateRange( + await (from authorization in session.Query().Fetch(authorization => authorization.Application) + where authorization.Application != null && + authorization.Application.Id.Equals(key) && + authorization.Subject == subject && + authorization.Status == status && + authorization.Type == type + select authorization).ToListAsync(cancellationToken)); + } + + /// + /// Retrieves the authorizations matching the specified parameters. + /// + /// The subject associated with the authorization. + /// The client associated with the authorization. + /// The authorization status. + /// The authorization type. + /// The minimal scopes associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the authorizations corresponding to the criteria. + /// + public virtual async Task> FindAsync( + [NotNull] string subject, [NotNull] string client, + [NotNull] string status, [NotNull] string type, + ImmutableArray scopes, CancellationToken cancellationToken) + { + var authorizations = await FindAsync(subject, client, status, type, cancellationToken); + if (authorizations.IsEmpty) + { + return ImmutableArray.Create(); + } + + var builder = ImmutableArray.CreateBuilder(authorizations.Length); + + foreach (var authorization in authorizations) + { + async Task HasScopesAsync() + => (await GetScopesAsync(authorization, cancellationToken)) + .ToImmutableHashSet(StringComparer.Ordinal) + .IsSupersetOf(scopes); + + if (await HasScopesAsync()) + { + builder.Add(authorization); + } + } + + return builder.Count == builder.Capacity ? + builder.MoveToImmutable() : + builder.ToImmutable(); + } + + /// + /// Retrieves the list of authorizations corresponding to the specified application identifier. + /// + /// The application identifier associated with the authorizations. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the authorizations corresponding to the specified application. + /// + public virtual async Task> FindByApplicationIdAsync( + [NotNull] string identifier, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(identifier)) + { + throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + var key = ConvertIdentifierFromString(identifier); + + return ImmutableArray.CreateRange( + await (from authorization in session.Query().Fetch(authorization => authorization.Application) + where authorization.Application != null && + authorization.Application.Id.Equals(key) + select authorization).ToListAsync(cancellationToken)); + } + + /// + /// Retrieves an authorization using its unique identifier. + /// + /// The unique identifier associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the authorization corresponding to the identifier. + /// + public virtual async Task FindByIdAsync([NotNull] string identifier, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(identifier)) + { + throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + return await session.GetAsync(ConvertIdentifierFromString(identifier), cancellationToken); + } + + /// + /// Retrieves all the authorizations corresponding to the specified subject. + /// + /// The subject associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the authorizations corresponding to the specified subject. + /// + public virtual async Task> FindBySubjectAsync( + [NotNull] string subject, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(subject)) + { + throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + return ImmutableArray.CreateRange( + await (from authorization in session.Query().Fetch(authorization => authorization.Application) + where authorization.Subject == subject + select authorization).ToListAsync(cancellationToken)); + } + + /// + /// Retrieves the optional application identifier associated with an authorization. + /// + /// The authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the application identifier associated with the authorization. + /// + public virtual ValueTask GetApplicationIdAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + if (authorization.Application == null) + { + return new ValueTask(result: null); + } + + return new ValueTask(ConvertIdentifierToString(authorization.Application.Id)); + } + + /// + /// Executes the specified query and returns the first element. + /// + /// The state type. + /// The result type. + /// The query to execute. + /// The optional state. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the first element returned when executing the query. + /// + public virtual async Task GetAsync( + [NotNull] Func, TState, IQueryable> query, + [CanBeNull] TState state, CancellationToken cancellationToken) + { + if (query == null) + { + throw new ArgumentNullException(nameof(query)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + return await query(session.Query() + .Fetch(authorization => authorization.Application), state).FirstOrDefaultAsync(cancellationToken); + } + + /// + /// Retrieves the unique identifier associated with an authorization. + /// + /// The authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the unique identifier associated with the authorization. + /// + public virtual ValueTask GetIdAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + return new ValueTask(ConvertIdentifierToString(authorization.Id)); + } + + /// + /// Retrieves the additional properties associated with an authorization. + /// + /// The authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the additional properties associated with the authorization. + /// + public virtual ValueTask GetPropertiesAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + if (string.IsNullOrEmpty(authorization.Properties)) + { + return new ValueTask(new JObject()); + } + + return new ValueTask(JObject.Parse(authorization.Properties)); + } + + /// + /// Retrieves the scopes associated with an authorization. + /// + /// The authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the scopes associated with the specified authorization. + /// + public virtual ValueTask> GetScopesAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + if (string.IsNullOrEmpty(authorization.Scopes)) + { + return new ValueTask>(ImmutableArray.Create()); + } + + return new ValueTask>(JArray.Parse(authorization.Scopes).Select(element => (string) element).ToImmutableArray()); + } + + /// + /// Retrieves the status associated with an authorization. + /// + /// The authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the status associated with the specified authorization. + /// + public virtual ValueTask GetStatusAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + return new ValueTask(authorization.Status); + } + + /// + /// Retrieves the subject associated with an authorization. + /// + /// The authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the subject associated with the specified authorization. + /// + public virtual ValueTask GetSubjectAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + return new ValueTask(authorization.Subject); + } + + /// + /// Retrieves the type associated with an authorization. + /// + /// The authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the type associated with the specified authorization. + /// + public virtual ValueTask GetTypeAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + return new ValueTask(authorization.Type); + } + + /// + /// Instantiates a new authorization. + /// + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the instantiated authorization, that can be persisted in the database. + /// + public virtual ValueTask InstantiateAsync(CancellationToken cancellationToken) + { + try + { + return new ValueTask(Activator.CreateInstance()); + } + + catch (MemberAccessException exception) + { + return new ValueTask(Task.FromException( + new InvalidOperationException(new StringBuilder() + .AppendLine("An error occurred while trying to create a new authorization instance.") + .Append("Make sure that the authorization entity is not abstract and has a public parameterless constructor ") + .Append("or create a custom authorization store that overrides 'InstantiateAsync()' to use a custom factory.") + .ToString(), exception))); + } + } + + /// + /// Executes the specified query and returns all the corresponding elements. + /// + /// The number of results to return. + /// The number of results to skip. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the elements returned when executing the specified query. + /// + public virtual async Task> ListAsync( + [CanBeNull] int? count, [CanBeNull] int? offset, CancellationToken cancellationToken) + { + var session = await Context.GetSessionAsync(cancellationToken); + var query = session.Query() + .Fetch(authorization => authorization.Application) + .OrderBy(authorization => authorization.Id) + .AsQueryable(); + + if (offset.HasValue) + { + query = query.Skip(offset.Value); + } + + if (count.HasValue) + { + query = query.Take(count.Value); + } + + return ImmutableArray.CreateRange(await query.ToListAsync(cancellationToken)); + } + + /// + /// Executes the specified query and returns all the corresponding elements. + /// + /// The state type. + /// The result type. + /// The query to execute. + /// The optional state. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the elements returned when executing the specified query. + /// + public virtual async Task> ListAsync( + [NotNull] Func, TState, IQueryable> query, + [CanBeNull] TState state, CancellationToken cancellationToken) + { + if (query == null) + { + throw new ArgumentNullException(nameof(query)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + return ImmutableArray.CreateRange(await query( + session.Query().Fetch(authorization => authorization.Application), state).ToListAsync(cancellationToken)); + } + + /// + /// Removes the authorizations that are marked as invalid and the ad-hoc ones that have no valid/nonexpired token attached. + /// + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task PruneAsync(CancellationToken cancellationToken) + { + var session = await Context.GetSessionAsync(cancellationToken); + + await (from token in session.Query() + where token.Status != OpenIddictConstants.Statuses.Valid || + token.ExpirationDate < DateTimeOffset.UtcNow + select token).DeleteAsync(cancellationToken); + + await (from authorization in session.Query() + where authorization.Status != OpenIddictConstants.Statuses.Valid || + (authorization.Type == OpenIddictConstants.AuthorizationTypes.AdHoc && !authorization.Tokens.Any()) + select authorization).DeleteAsync(cancellationToken); + + await session.FlushAsync(cancellationToken); + } + + /// + /// Sets the application identifier associated with an authorization. + /// + /// The authorization. + /// The unique identifier associated with the client application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task SetApplicationIdAsync([NotNull] TAuthorization authorization, + [CanBeNull] string identifier, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + if (!string.IsNullOrEmpty(identifier)) + { + authorization.Application = await session.LoadAsync(ConvertIdentifierFromString(identifier), cancellationToken); + } + + else + { + authorization.Application = null; + } + } + + /// + /// Sets the additional properties associated with an authorization. + /// + /// The authorization. + /// The additional properties associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetPropertiesAsync([NotNull] TAuthorization authorization, [CanBeNull] JObject properties, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + if (properties == null) + { + authorization.Properties = null; + + return Task.CompletedTask; + } + + authorization.Properties = properties.ToString(Formatting.None); + + return Task.CompletedTask; + } + + /// + /// Sets the scopes associated with an authorization. + /// + /// The authorization. + /// The scopes associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetScopesAsync([NotNull] TAuthorization authorization, + ImmutableArray scopes, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + if (scopes.IsDefaultOrEmpty) + { + authorization.Scopes = null; + + return Task.CompletedTask; + } + + authorization.Scopes = new JArray(scopes.ToArray()).ToString(Formatting.None); + + return Task.CompletedTask; + } + + /// + /// Sets the status associated with an authorization. + /// + /// The authorization. + /// The status associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetStatusAsync([NotNull] TAuthorization authorization, + [CanBeNull] string status, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + authorization.Status = status; + + return Task.CompletedTask; + } + + /// + /// Sets the subject associated with an authorization. + /// + /// The authorization. + /// The subject associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetSubjectAsync([NotNull] TAuthorization authorization, + [CanBeNull] string subject, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + authorization.Subject = subject; + + return Task.CompletedTask; + } + + /// + /// Sets the type associated with an authorization. + /// + /// The authorization. + /// The type associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetTypeAsync([NotNull] TAuthorization authorization, + [CanBeNull] string type, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + authorization.Type = type; + + return Task.CompletedTask; + } + + /// + /// Updates an existing authorization. + /// + /// The authorization to update. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task UpdateAsync([NotNull] TAuthorization authorization, CancellationToken cancellationToken) + { + if (authorization == null) + { + throw new ArgumentNullException(nameof(authorization)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + try + { + await session.UpdateAsync(authorization, cancellationToken); + await session.FlushAsync(cancellationToken); + } + + catch (StaleObjectStateException exception) + { + throw new OpenIddictExceptions.ConcurrencyException(new StringBuilder() + .AppendLine("The authorization was concurrently updated and cannot be persisted in its current state.") + .Append("Reload the authorization from the database and retry the operation.") + .ToString(), exception); + } + } + + /// + /// Converts the provided identifier to a strongly typed key object. + /// + /// The identifier to convert. + /// An instance of representing the provided identifier. + public virtual TKey ConvertIdentifierFromString([CanBeNull] string identifier) + { + if (string.IsNullOrEmpty(identifier)) + { + return default; + } + + return (TKey) TypeDescriptor.GetConverter(typeof(TKey)).ConvertFromInvariantString(identifier); + } + + /// + /// Converts the provided identifier to its string representation. + /// + /// The identifier to convert. + /// A representation of the provided identifier. + public virtual string ConvertIdentifierToString([CanBeNull] TKey identifier) + { + if (Equals(identifier, default(TKey))) + { + return null; + } + + return TypeDescriptor.GetConverter(typeof(TKey)).ConvertToInvariantString(identifier); + } + } +} \ No newline at end of file diff --git a/src/OpenIddict.NHibernate/Stores/OpenIddictScopeStore.cs b/src/OpenIddict.NHibernate/Stores/OpenIddictScopeStore.cs new file mode 100644 index 00000000..0562f8ff --- /dev/null +++ b/src/OpenIddict.NHibernate/Stores/OpenIddictScopeStore.cs @@ -0,0 +1,712 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Collections.Immutable; +using System.ComponentModel; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using JetBrains.Annotations; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using NHibernate; +using NHibernate.Linq; +using OpenIddict.Abstractions; +using OpenIddict.NHibernate.Models; + +namespace OpenIddict.NHibernate +{ + /// + /// Provides methods allowing to manage the scopes stored in a database. + /// + public class OpenIddictScopeStore : OpenIddictScopeStore + { + public OpenIddictScopeStore( + [NotNull] IMemoryCache cache, + [NotNull] IOpenIddictNHibernateContext context, + [NotNull] IOptionsMonitor options) + : base(cache, context, options) + { + } + } + + /// + /// Provides methods allowing to manage the scopes stored in a database. + /// + /// The type of the entity primary keys. + public class OpenIddictScopeStore : OpenIddictScopeStore, TKey> + where TKey : IEquatable + { + public OpenIddictScopeStore( + [NotNull] IMemoryCache cache, + [NotNull] IOpenIddictNHibernateContext context, + [NotNull] IOptionsMonitor options) + : base(cache, context, options) + { + } + } + + /// + /// Provides methods allowing to manage the scopes stored in a database. + /// + /// The type of the Scope entity. + /// The type of the entity primary keys. + public class OpenIddictScopeStore : IOpenIddictScopeStore + where TScope : OpenIddictScope + where TKey : IEquatable + { + public OpenIddictScopeStore( + [NotNull] IMemoryCache cache, + [NotNull] IOpenIddictNHibernateContext context, + [NotNull] IOptionsMonitor options) + { + Cache = cache; + Context = context; + Options = options; + } + + /// + /// Gets the memory cache associated with the current store. + /// + protected IMemoryCache Cache { get; } + + /// + /// Gets the database context associated with the current store. + /// + protected IOpenIddictNHibernateContext Context { get; } + + /// + /// Gets the options associated with the current store. + /// + protected IOptionsMonitor Options { get; } + + /// + /// Determines the number of scopes that exist in the database. + /// + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the number of scopes in the database. + /// + public virtual async Task CountAsync(CancellationToken cancellationToken) + { + var session = await Context.GetSessionAsync(cancellationToken); + return await session.Query().LongCountAsync(cancellationToken); + } + + /// + /// Determines the number of scopes that match the specified query. + /// + /// The result type. + /// The query to execute. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the number of scopes that match the specified query. + /// + public virtual async Task CountAsync([NotNull] Func, IQueryable> query, CancellationToken cancellationToken) + { + if (query == null) + { + throw new ArgumentNullException(nameof(query)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + return await query(session.Query()).LongCountAsync(cancellationToken); + } + + /// + /// Creates a new scope. + /// + /// The scope to create. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task CreateAsync([NotNull] TScope scope, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + await session.SaveAsync(scope, cancellationToken); + await session.FlushAsync(cancellationToken); + } + + /// + /// Removes an existing scope. + /// + /// The scope to delete. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task DeleteAsync([NotNull] TScope scope, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + try + { + await session.DeleteAsync(scope, cancellationToken); + await session.FlushAsync(cancellationToken); + } + + catch (StaleObjectStateException exception) + { + throw new OpenIddictExceptions.ConcurrencyException(new StringBuilder() + .AppendLine("The scope was concurrently updated and cannot be persisted in its current state.") + .Append("Reload the scope from the database and retry the operation.") + .ToString(), exception); + } + } + + /// + /// Retrieves a scope using its unique identifier. + /// + /// The unique identifier associated with the scope. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the scope corresponding to the identifier. + /// + public virtual async Task FindByIdAsync([NotNull] string identifier, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(identifier)) + { + throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + return await session.GetAsync(ConvertIdentifierFromString(identifier), cancellationToken); + } + + /// + /// Retrieves a scope using its name. + /// + /// The name associated with the scope. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the scope corresponding to the specified name. + /// + public virtual async Task FindByNameAsync([NotNull] string name, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(name)) + { + throw new ArgumentException("The scope name cannot be null or empty.", nameof(name)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + return await (from scope in session.Query() + where scope.Name == name + select scope).FirstOrDefaultAsync(cancellationToken); + } + + /// + /// Retrieves a list of scopes using their name. + /// + /// The names associated with the scopes. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the scopes corresponding to the specified names. + /// + public virtual async Task> FindByNamesAsync( + ImmutableArray names, CancellationToken cancellationToken) + { + if (names.Any(name => string.IsNullOrEmpty(name))) + { + throw new ArgumentException("Scope names cannot be null or empty.", nameof(names)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + return ImmutableArray.CreateRange( + await (from scope in session.Query() + where names.Contains(scope.Name) + select scope).ToListAsync(cancellationToken)); + } + + /// + /// Retrieves all the scopes that contain the specified resource. + /// + /// The resource associated with the scopes. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the scopes associated with the specified resource. + /// + public virtual async Task> FindByResourceAsync( + [NotNull] string resource, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(resource)) + { + throw new ArgumentException("The resource cannot be null or empty.", nameof(resource)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + // To optimize the efficiency of the query a bit, only scopes whose stringified + // Resources column contains the specified resource are returned. Once the scopes + // are retrieved, a second pass is made to ensure only valid elements are returned. + // Implementers that use this method in a hot path may want to override this method + // to use SQL Server 2016 functions like JSON_VALUE to make the query more efficient. + var scopes = await (from scope in session.Query() + where scope.Resources.Contains(resource) + select scope).ToListAsync(cancellationToken); + + var builder = ImmutableArray.CreateBuilder(); + + foreach (var scope in scopes) + { + var resources = await GetResourcesAsync(scope, cancellationToken); + if (resources.Contains(resource, StringComparer.OrdinalIgnoreCase)) + { + builder.Add(scope); + } + } + + return builder.ToImmutable(); + } + + /// + /// Executes the specified query and returns the first element. + /// + /// The state type. + /// The result type. + /// The query to execute. + /// The optional state. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the first element returned when executing the query. + /// + public virtual async Task GetAsync( + [NotNull] Func, TState, IQueryable> query, + [CanBeNull] TState state, CancellationToken cancellationToken) + { + if (query == null) + { + throw new ArgumentNullException(nameof(query)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + return await query(session.Query(), state).FirstOrDefaultAsync(cancellationToken); + } + + /// + /// Retrieves the description associated with a scope. + /// + /// The scope. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the description associated with the specified scope. + /// + public virtual ValueTask GetDescriptionAsync([NotNull] TScope scope, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + return new ValueTask(scope.Description); + } + + /// + /// Retrieves the display name associated with a scope. + /// + /// The scope. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the display name associated with the scope. + /// + public virtual ValueTask GetDisplayNameAsync([NotNull] TScope scope, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + return new ValueTask(scope.DisplayName); + } + + /// + /// Retrieves the unique identifier associated with a scope. + /// + /// The scope. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the unique identifier associated with the scope. + /// + public virtual ValueTask GetIdAsync([NotNull] TScope scope, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + return new ValueTask(ConvertIdentifierToString(scope.Id)); + } + + /// + /// Retrieves the name associated with a scope. + /// + /// The scope. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the name associated with the specified scope. + /// + public virtual ValueTask GetNameAsync([NotNull] TScope scope, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + return new ValueTask(scope.Name); + } + + /// + /// Retrieves the additional properties associated with a scope. + /// + /// The scope. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the additional properties associated with the scope. + /// + public virtual ValueTask GetPropertiesAsync([NotNull] TScope scope, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + if (string.IsNullOrEmpty(scope.Properties)) + { + return new ValueTask(new JObject()); + } + + return new ValueTask(JObject.Parse(scope.Properties)); + } + + /// + /// Retrieves the resources associated with a scope. + /// + /// The scope. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the resources associated with the scope. + /// + public virtual ValueTask> GetResourcesAsync([NotNull] TScope scope, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + if (string.IsNullOrEmpty(scope.Resources)) + { + return new ValueTask>(ImmutableArray.Create()); + } + + // Note: parsing the stringified resources is an expensive operation. + // To mitigate that, the resulting array is stored in the memory cache. + var key = string.Concat("b6148250-aede-4fb9-a621-07c9bcf238c3", "\x1e", scope.Resources); + var resources = Cache.GetOrCreate(key, entry => + { + entry.SetPriority(CacheItemPriority.High) + .SetSlidingExpiration(TimeSpan.FromMinutes(1)); + + return JArray.Parse(scope.Resources) + .Select(element => (string) element) + .ToImmutableArray(); + }); + + return new ValueTask>(resources); + } + + /// + /// Instantiates a new scope. + /// + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the instantiated scope, that can be persisted in the database. + /// + public virtual ValueTask InstantiateAsync(CancellationToken cancellationToken) + { + try + { + return new ValueTask(Activator.CreateInstance()); + } + + catch (MemberAccessException exception) + { + return new ValueTask(Task.FromException( + new InvalidOperationException(new StringBuilder() + .AppendLine("An error occurred while trying to create a new scope instance.") + .Append("Make sure that the scope entity is not abstract and has a public parameterless constructor ") + .Append("or create a custom scope store that overrides 'InstantiateAsync()' to use a custom factory.") + .ToString(), exception))); + } + } + + /// + /// Executes the specified query and returns all the corresponding elements. + /// + /// The number of results to return. + /// The number of results to skip. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the elements returned when executing the specified query. + /// + public virtual async Task> ListAsync( + [CanBeNull] int? count, [CanBeNull] int? offset, CancellationToken cancellationToken) + { + var session = await Context.GetSessionAsync(cancellationToken); + var query = session.Query() + .OrderBy(scope => scope.Id) + .AsQueryable(); + + if (offset.HasValue) + { + query = query.Skip(offset.Value); + } + + if (count.HasValue) + { + query = query.Take(count.Value); + } + + return ImmutableArray.CreateRange(await query.ToListAsync(cancellationToken)); + } + + /// + /// Executes the specified query and returns all the corresponding elements. + /// + /// The state type. + /// The result type. + /// The query to execute. + /// The optional state. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the elements returned when executing the specified query. + /// + public virtual async Task> ListAsync( + [NotNull] Func, TState, IQueryable> query, + [CanBeNull] TState state, CancellationToken cancellationToken) + { + if (query == null) + { + throw new ArgumentNullException(nameof(query)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + return ImmutableArray.CreateRange(await query(session.Query(), state).ToListAsync(cancellationToken)); + } + + /// + /// Sets the description associated with a scope. + /// + /// The scope. + /// The description associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetDescriptionAsync([NotNull] TScope scope, [CanBeNull] string description, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + scope.Description = description; + + return Task.CompletedTask; + } + + /// + /// Sets the display name associated with a scope. + /// + /// The scope. + /// The display name associated with the scope. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetDisplayNameAsync([NotNull] TScope scope, [CanBeNull] string name, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + scope.DisplayName = name; + + return Task.CompletedTask; + } + + /// + /// Sets the name associated with a scope. + /// + /// The scope. + /// The name associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetNameAsync([NotNull] TScope scope, [CanBeNull] string name, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + scope.Name = name; + + return Task.CompletedTask; + } + + /// + /// Sets the additional properties associated with a scope. + /// + /// The scope. + /// The additional properties associated with the scope. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetPropertiesAsync([NotNull] TScope scope, [CanBeNull] JObject properties, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + if (properties == null) + { + scope.Properties = null; + + return Task.CompletedTask; + } + + scope.Properties = properties.ToString(Formatting.None); + + return Task.CompletedTask; + } + + /// + /// Sets the resources associated with a scope. + /// + /// The scope. + /// The resources associated with the scope. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetResourcesAsync([NotNull] TScope scope, ImmutableArray resources, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + if (resources.IsDefaultOrEmpty) + { + scope.Resources = null; + + return Task.CompletedTask; + } + + scope.Resources = new JArray(resources.ToArray()).ToString(Formatting.None); + + return Task.CompletedTask; + } + + /// + /// Updates an existing scope. + /// + /// The scope to update. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task UpdateAsync([NotNull] TScope scope, CancellationToken cancellationToken) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + try + { + await session.UpdateAsync(scope, cancellationToken); + await session.FlushAsync(cancellationToken); + } + + catch (StaleObjectStateException exception) + { + throw new OpenIddictExceptions.ConcurrencyException(new StringBuilder() + .AppendLine("The scope was concurrently updated and cannot be persisted in its current state.") + .Append("Reload the scope from the database and retry the operation.") + .ToString(), exception); + } + } + + /// + /// Converts the provided identifier to a strongly typed key object. + /// + /// The identifier to convert. + /// An instance of representing the provided identifier. + public virtual TKey ConvertIdentifierFromString([CanBeNull] string identifier) + { + if (string.IsNullOrEmpty(identifier)) + { + return default; + } + + return (TKey) TypeDescriptor.GetConverter(typeof(TKey)).ConvertFromInvariantString(identifier); + } + + /// + /// Converts the provided identifier to its string representation. + /// + /// The identifier to convert. + /// A representation of the provided identifier. + public virtual string ConvertIdentifierToString([CanBeNull] TKey identifier) + { + if (Equals(identifier, default(TKey))) + { + return null; + } + + return TypeDescriptor.GetConverter(typeof(TKey)).ConvertToInvariantString(identifier); + } + } +} \ No newline at end of file diff --git a/src/OpenIddict.NHibernate/Stores/OpenIddictTokenStore.cs b/src/OpenIddict.NHibernate/Stores/OpenIddictTokenStore.cs new file mode 100644 index 00000000..7cf481e8 --- /dev/null +++ b/src/OpenIddict.NHibernate/Stores/OpenIddictTokenStore.cs @@ -0,0 +1,1111 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Collections.Immutable; +using System.ComponentModel; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using JetBrains.Annotations; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using NHibernate; +using NHibernate.Linq; +using OpenIddict.Abstractions; +using OpenIddict.NHibernate.Models; + +namespace OpenIddict.NHibernate +{ + /// + /// Provides methods allowing to manage the tokens stored in a database. + /// + public class OpenIddictTokenStore : OpenIddictTokenStore + { + public OpenIddictTokenStore( + [NotNull] IMemoryCache cache, + [NotNull] IOpenIddictNHibernateContext context, + [NotNull] IOptionsMonitor options) + : base(cache, context, options) + { + } + } + + /// + /// Provides methods allowing to manage the tokens stored in a database. + /// + /// The type of the entity primary keys. + public class OpenIddictTokenStore : OpenIddictTokenStore, + OpenIddictApplication, + OpenIddictAuthorization, TKey> + where TKey : IEquatable + { + public OpenIddictTokenStore( + [NotNull] IMemoryCache cache, + [NotNull] IOpenIddictNHibernateContext context, + [NotNull] IOptionsMonitor options) + : base(cache, context, options) + { + } + } + + /// + /// Provides methods allowing to manage the tokens stored in a database. + /// + /// The type of the Token entity. + /// The type of the Application entity. + /// The type of the Authorization entity. + /// The type of the entity primary keys. + public class OpenIddictTokenStore : IOpenIddictTokenStore + where TToken : OpenIddictToken + where TApplication : OpenIddictApplication + where TAuthorization : OpenIddictAuthorization + where TKey : IEquatable + { + public OpenIddictTokenStore( + [NotNull] IMemoryCache cache, + [NotNull] IOpenIddictNHibernateContext context, + [NotNull] IOptionsMonitor options) + { + Cache = cache; + Context = context; + Options = options; + } + + /// + /// Gets the memory cache associated with the current store. + /// + protected IMemoryCache Cache { get; } + + /// + /// Gets the database context associated with the current store. + /// + protected IOpenIddictNHibernateContext Context { get; } + + /// + /// Gets the options associated with the current store. + /// + protected IOptionsMonitor Options { get; } + + /// + /// Determines the number of tokens that exist in the database. + /// + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the number of applications in the database. + /// + public virtual async Task CountAsync(CancellationToken cancellationToken) + { + var session = await Context.GetSessionAsync(cancellationToken); + return await session.Query().LongCountAsync(cancellationToken); + } + + /// + /// Determines the number of tokens that match the specified query. + /// + /// The result type. + /// The query to execute. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the number of tokens that match the specified query. + /// + public virtual async Task CountAsync([NotNull] Func, IQueryable> query, CancellationToken cancellationToken) + { + if (query == null) + { + throw new ArgumentNullException(nameof(query)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + return await query(session.Query()).LongCountAsync(cancellationToken); + } + + /// + /// Creates a new token. + /// + /// The token to create. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task CreateAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + await session.SaveAsync(token, cancellationToken); + await session.FlushAsync(cancellationToken); + } + + /// + /// Removes a token. + /// + /// The token to delete. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task DeleteAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + try + { + await session.DeleteAsync(token, cancellationToken); + await session.FlushAsync(cancellationToken); + } + + catch (StaleObjectStateException exception) + { + throw new OpenIddictExceptions.ConcurrencyException(new StringBuilder() + .AppendLine("The token was concurrently updated and cannot be persisted in its current state.") + .Append("Reload the token from the database and retry the operation.") + .ToString(), exception); + } + } + + /// + /// Retrieves the tokens corresponding to the specified + /// subject and associated with the application identifier. + /// + /// The subject associated with the token. + /// The client associated with the token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the tokens corresponding to the subject/client. + /// + public virtual async Task> FindAsync([NotNull] string subject, + [NotNull] string client, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(subject)) + { + throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); + } + + if (string.IsNullOrEmpty(client)) + { + throw new ArgumentException("The client cannot be null or empty.", nameof(client)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + var key = ConvertIdentifierFromString(client); + + return ImmutableArray.CreateRange( + await (from token in session.Query() + .Fetch(token => token.Application) + .Fetch(token => token.Authorization) + where token.Application != null && + token.Application.Id.Equals(key) && + token.Subject == subject + select token).ToListAsync(cancellationToken)); + } + + /// + /// Retrieves the tokens matching the specified parameters. + /// + /// The subject associated with the token. + /// The client associated with the token. + /// The token status. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the tokens corresponding to the criteria. + /// + public virtual async Task> FindAsync( + [NotNull] string subject, [NotNull] string client, + [NotNull] string status, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(subject)) + { + throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); + } + + if (string.IsNullOrEmpty(client)) + { + throw new ArgumentException("The client cannot be null or empty.", nameof(client)); + } + + if (string.IsNullOrEmpty(status)) + { + throw new ArgumentException("The status cannot be null or empty.", nameof(status)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + var key = ConvertIdentifierFromString(client); + + return ImmutableArray.CreateRange( + await (from token in session.Query() + .Fetch(token => token.Application) + .Fetch(token => token.Authorization) + where token.Application != null && + token.Application.Id.Equals(key) && + token.Subject == subject && + token.Status == status + select token).ToListAsync(cancellationToken)); + } + + /// + /// Retrieves the tokens matching the specified parameters. + /// + /// The subject associated with the token. + /// The client associated with the token. + /// The token status. + /// The token type. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the tokens corresponding to the criteria. + /// + public virtual async Task> FindAsync( + [NotNull] string subject, [NotNull] string client, + [NotNull] string status, [NotNull] string type, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(subject)) + { + throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); + } + + if (string.IsNullOrEmpty(client)) + { + throw new ArgumentException("The client identifier cannot be null or empty.", nameof(client)); + } + + if (string.IsNullOrEmpty(status)) + { + throw new ArgumentException("The status cannot be null or empty.", nameof(status)); + } + + if (string.IsNullOrEmpty(type)) + { + throw new ArgumentException("The type cannot be null or empty.", nameof(type)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + var key = ConvertIdentifierFromString(client); + + return ImmutableArray.CreateRange( + await (from token in session.Query() + .Fetch(token => token.Application) + .Fetch(token => token.Authorization) + where token.Application != null && + token.Application.Id.Equals(key) && + token.Subject == subject && + token.Status == status && + token.Type == type + select token).ToListAsync(cancellationToken)); + } + + /// + /// Retrieves the list of tokens corresponding to the specified application identifier. + /// + /// The application identifier associated with the tokens. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the tokens corresponding to the specified application. + /// + public virtual async Task> FindByApplicationIdAsync([NotNull] string identifier, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(identifier)) + { + throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + var key = ConvertIdentifierFromString(identifier); + + return ImmutableArray.CreateRange( + await (from token in session.Query() + .Fetch(token => token.Application) + .Fetch(token => token.Authorization) + where token.Application != null && + token.Application.Id.Equals(key) + select token).ToListAsync(cancellationToken)); + } + + /// + /// Retrieves the list of tokens corresponding to the specified authorization identifier. + /// + /// The authorization identifier associated with the tokens. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the tokens corresponding to the specified authorization. + /// + public virtual async Task> FindByAuthorizationIdAsync([NotNull] string identifier, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(identifier)) + { + throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + var key = ConvertIdentifierFromString(identifier); + + return ImmutableArray.CreateRange( + await (from token in session.Query() + .Fetch(token => token.Application) + .Fetch(token => token.Authorization) + where token.Authorization != null && + token.Authorization.Id.Equals(key) + select token).ToListAsync(cancellationToken)); + } + + /// + /// Retrieves a token using its unique identifier. + /// + /// The unique identifier associated with the token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the token corresponding to the unique identifier. + /// + public virtual async Task FindByIdAsync([NotNull] string identifier, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(identifier)) + { + throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + return await session.GetAsync(ConvertIdentifierFromString(identifier), cancellationToken); + } + + /// + /// Retrieves the list of tokens corresponding to the specified reference identifier. + /// Note: the reference identifier may be hashed or encrypted for security reasons. + /// + /// The reference identifier associated with the tokens. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the tokens corresponding to the specified reference identifier. + /// + public virtual async Task FindByReferenceIdAsync([NotNull] string identifier, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(identifier)) + { + throw new ArgumentException("The identifier cannot be null or empty.", nameof(identifier)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + return await (from token in session.Query() + .Fetch(token => token.Application) + .Fetch(token => token.Authorization) + where token.ReferenceId == identifier + select token).FirstOrDefaultAsync(cancellationToken); + } + + /// + /// Retrieves the list of tokens corresponding to the specified subject. + /// + /// The subject associated with the tokens. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the tokens corresponding to the specified subject. + /// + public virtual async Task> FindBySubjectAsync([NotNull] string subject, CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(subject)) + { + throw new ArgumentException("The subject cannot be null or empty.", nameof(subject)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + return ImmutableArray.CreateRange( + await (from token in session.Query() + .Fetch(token => token.Application) + .Fetch(token => token.Authorization) + where token.Subject == subject + select token).ToListAsync(cancellationToken)); + } + + /// + /// Retrieves the optional application identifier associated with a token. + /// + /// The token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the application identifier associated with the token. + /// + public virtual ValueTask GetApplicationIdAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + if (token.Application == null) + { + return new ValueTask(result: null); + } + + return new ValueTask(ConvertIdentifierToString(token.Application.Id)); + } + + /// + /// Executes the specified query and returns the first element. + /// + /// The state type. + /// The result type. + /// The query to execute. + /// The optional state. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the first element returned when executing the query. + /// + public virtual async Task GetAsync( + [NotNull] Func, TState, IQueryable> query, + [CanBeNull] TState state, CancellationToken cancellationToken) + { + if (query == null) + { + throw new ArgumentNullException(nameof(query)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + return await query( + session.Query().Fetch(token => token.Application) + .Fetch(token => token.Authorization), state).FirstOrDefaultAsync(cancellationToken); + } + + /// + /// Retrieves the optional authorization identifier associated with a token. + /// + /// The token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the authorization identifier associated with the token. + /// + public virtual ValueTask GetAuthorizationIdAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + if (token.Authorization == null) + { + return new ValueTask(result: null); + } + + return new ValueTask(ConvertIdentifierToString(token.Authorization.Id)); + } + + /// + /// Retrieves the creation date associated with a token. + /// + /// The token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the creation date associated with the specified token. + /// + public virtual ValueTask GetCreationDateAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + return new ValueTask(token.CreationDate); + } + + /// + /// Retrieves the expiration date associated with a token. + /// + /// The token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the expiration date associated with the specified token. + /// + public virtual ValueTask GetExpirationDateAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + return new ValueTask(token.ExpirationDate); + } + + /// + /// Retrieves the unique identifier associated with a token. + /// + /// The token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the unique identifier associated with the token. + /// + public virtual ValueTask GetIdAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + return new ValueTask(ConvertIdentifierToString(token.Id)); + } + + /// + /// Retrieves the payload associated with a token. + /// + /// The token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the payload associated with the specified token. + /// + public virtual ValueTask GetPayloadAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + return new ValueTask(token.Payload); + } + + /// + /// Retrieves the additional properties associated with a token. + /// + /// The token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the additional properties associated with the token. + /// + public virtual ValueTask GetPropertiesAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + if (string.IsNullOrEmpty(token.Properties)) + { + return new ValueTask(new JObject()); + } + + return new ValueTask(JObject.Parse(token.Properties)); + } + + /// + /// Retrieves the reference identifier associated with a token. + /// Note: depending on the manager used to create the token, + /// the reference identifier may be hashed for security reasons. + /// + /// The token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the reference identifier associated with the specified token. + /// + public virtual ValueTask GetReferenceIdAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + return new ValueTask(token.ReferenceId); + } + + /// + /// Retrieves the status associated with a token. + /// + /// The token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the status associated with the specified token. + /// + public virtual ValueTask GetStatusAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + return new ValueTask(token.Status); + } + + /// + /// Retrieves the subject associated with a token. + /// + /// The token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the subject associated with the specified token. + /// + public virtual ValueTask GetSubjectAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + return new ValueTask(token.Subject); + } + + /// + /// Retrieves the token type associated with a token. + /// + /// The token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the token type associated with the specified token. + /// + public virtual ValueTask GetTypeAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + return new ValueTask(token.Type); + } + + /// + /// Instantiates a new token. + /// + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns the instantiated token, that can be persisted in the database. + /// + public virtual ValueTask InstantiateAsync(CancellationToken cancellationToken) + { + try + { + return new ValueTask(Activator.CreateInstance()); + } + + catch (MemberAccessException exception) + { + return new ValueTask(Task.FromException( + new InvalidOperationException(new StringBuilder() + .AppendLine("An error occurred while trying to create a new token instance.") + .Append("Make sure that the token entity is not abstract and has a public parameterless constructor ") + .Append("or create a custom token store that overrides 'InstantiateAsync()' to use a custom factory.") + .ToString(), exception))); + } + } + + /// + /// Executes the specified query and returns all the corresponding elements. + /// + /// The number of results to return. + /// The number of results to skip. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the elements returned when executing the specified query. + /// + public virtual async Task> ListAsync( + [CanBeNull] int? count, [CanBeNull] int? offset, CancellationToken cancellationToken) + { + var session = await Context.GetSessionAsync(cancellationToken); + var query = session.Query() + .Fetch(token => token.Application) + .Fetch(token => token.Authorization) + .OrderBy(token => token.Id) + .AsQueryable(); + + if (offset.HasValue) + { + query = query.Skip(offset.Value); + } + + if (count.HasValue) + { + query = query.Take(count.Value); + } + + return ImmutableArray.CreateRange(await query.ToListAsync(cancellationToken)); + } + + /// + /// Executes the specified query and returns all the corresponding elements. + /// + /// The state type. + /// The result type. + /// The query to execute. + /// The optional state. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation, + /// whose result returns all the elements returned when executing the specified query. + /// + public virtual async Task> ListAsync( + [NotNull] Func, TState, IQueryable> query, + [CanBeNull] TState state, CancellationToken cancellationToken) + { + if (query == null) + { + throw new ArgumentNullException(nameof(query)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + return ImmutableArray.CreateRange(await query( + session.Query().Fetch(token => token.Application) + .Fetch(token => token.Authorization), state).ToListAsync(cancellationToken)); + } + + /// + /// Removes the tokens that are marked as expired or invalid. + /// + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task PruneAsync(CancellationToken cancellationToken) + { + var session = await Context.GetSessionAsync(cancellationToken); + + await (from token in session.Query() + where token.Status != OpenIddictConstants.Statuses.Valid || + token.ExpirationDate < DateTimeOffset.UtcNow + select token).DeleteAsync(cancellationToken); + + await session.FlushAsync(cancellationToken); + } + + /// + /// Sets the application identifier associated with a token. + /// + /// The token. + /// The unique identifier associated with the client application. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task SetApplicationIdAsync([NotNull] TToken token, + [CanBeNull] string identifier, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + if (!string.IsNullOrEmpty(identifier)) + { + token.Application = await session.LoadAsync(ConvertIdentifierFromString(identifier), cancellationToken); + } + + else + { + token.Application = null; + } + } + + /// + /// Sets the authorization identifier associated with a token. + /// + /// The token. + /// The unique identifier associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task SetAuthorizationIdAsync([NotNull] TToken token, + [CanBeNull] string identifier, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + if (!string.IsNullOrEmpty(identifier)) + { + token.Authorization = await session.LoadAsync(ConvertIdentifierFromString(identifier), cancellationToken); + } + + else + { + token.Authorization = null; + } + } + + /// + /// Sets the creation date associated with a token. + /// + /// The token. + /// The creation date. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetCreationDateAsync([NotNull] TToken token, + [CanBeNull] DateTimeOffset? date, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + token.CreationDate = date; + + return Task.CompletedTask; + } + + /// + /// Sets the expiration date associated with a token. + /// + /// The token. + /// The expiration date. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetExpirationDateAsync([NotNull] TToken token, + [CanBeNull] DateTimeOffset? date, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + token.ExpirationDate = date; + + return Task.CompletedTask; + } + + /// + /// Sets the payload associated with a token. + /// + /// The token. + /// The payload associated with the token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetPayloadAsync([NotNull] TToken token, [CanBeNull] string payload, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + token.Payload = payload; + + return Task.CompletedTask; + } + + /// + /// Sets the additional properties associated with a token. + /// + /// The token. + /// The additional properties associated with the token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetPropertiesAsync([NotNull] TToken token, [CanBeNull] JObject properties, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + if (properties == null) + { + token.Properties = null; + + return Task.CompletedTask; + } + + token.Properties = properties.ToString(Formatting.None); + + return Task.CompletedTask; + } + + /// + /// Sets the reference identifier associated with a token. + /// Note: depending on the manager used to create the token, + /// the reference identifier may be hashed for security reasons. + /// + /// The token. + /// The reference identifier associated with the token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetReferenceIdAsync([NotNull] TToken token, [CanBeNull] string identifier, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + token.ReferenceId = identifier; + + return Task.CompletedTask; + } + + /// + /// Sets the status associated with a token. + /// + /// The token. + /// The status associated with the authorization. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetStatusAsync([NotNull] TToken token, [CanBeNull] string status, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + token.Status = status; + + return Task.CompletedTask; + } + + /// + /// Sets the subject associated with a token. + /// + /// The token. + /// The subject associated with the token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetSubjectAsync([NotNull] TToken token, [CanBeNull] string subject, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + token.Subject = subject; + + return Task.CompletedTask; + } + + /// + /// Sets the token type associated with a token. + /// + /// The token. + /// The token type associated with the token. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual Task SetTypeAsync([NotNull] TToken token, [CanBeNull] string type, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + token.Type = type; + + return Task.CompletedTask; + } + + /// + /// Updates an existing token. + /// + /// The token to update. + /// The that can be used to abort the operation. + /// + /// A that can be used to monitor the asynchronous operation. + /// + public virtual async Task UpdateAsync([NotNull] TToken token, CancellationToken cancellationToken) + { + if (token == null) + { + throw new ArgumentNullException(nameof(token)); + } + + var session = await Context.GetSessionAsync(cancellationToken); + + try + { + await session.UpdateAsync(token, cancellationToken); + await session.FlushAsync(cancellationToken); + } + + catch (StaleObjectStateException exception) + { + throw new OpenIddictExceptions.ConcurrencyException(new StringBuilder() + .AppendLine("The token was concurrently updated and cannot be persisted in its current state.") + .Append("Reload the token from the database and retry the operation.") + .ToString(), exception); + } + } + + /// + /// Converts the provided identifier to a strongly typed key object. + /// + /// The identifier to convert. + /// An instance of representing the provided identifier. + public virtual TKey ConvertIdentifierFromString([CanBeNull] string identifier) + { + if (string.IsNullOrEmpty(identifier)) + { + return default; + } + + return (TKey) TypeDescriptor.GetConverter(typeof(TKey)).ConvertFromInvariantString(identifier); + } + + /// + /// Converts the provided identifier to its string representation. + /// + /// The identifier to convert. + /// A representation of the provided identifier. + public virtual string ConvertIdentifierToString([CanBeNull] TKey identifier) + { + if (Equals(identifier, default(TKey))) + { + return null; + } + + return TypeDescriptor.GetConverter(typeof(TKey)).ConvertToInvariantString(identifier); + } + } +} \ No newline at end of file diff --git a/test/OpenIddict.NHibernate.Tests/OpenIddict.NHibernate.Tests.csproj b/test/OpenIddict.NHibernate.Tests/OpenIddict.NHibernate.Tests.csproj new file mode 100644 index 00000000..755404e1 --- /dev/null +++ b/test/OpenIddict.NHibernate.Tests/OpenIddict.NHibernate.Tests.csproj @@ -0,0 +1,26 @@ + + + + + + netcoreapp2.0;net461 + netcoreapp2.0 + + + + + + + + + + + + + + + + + + + diff --git a/test/OpenIddict.NHibernate.Tests/OpenIddictNHibernateBuilderTests.cs b/test/OpenIddict.NHibernate.Tests/OpenIddictNHibernateBuilderTests.cs new file mode 100644 index 00000000..2ff20afd --- /dev/null +++ b/test/OpenIddict.NHibernate.Tests/OpenIddictNHibernateBuilderTests.cs @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Moq; +using NHibernate; +using OpenIddict.Core; +using OpenIddict.NHibernate.Models; +using Xunit; + +namespace OpenIddict.NHibernate.Tests +{ + public class OpenIddictNHibernateBuilderTests + { + [Fact] + public void Constructor_ThrowsAnExceptionForNullServices() + { + // Arrange + var services = (IServiceCollection) null; + + // Act and assert + var exception = Assert.Throws(() => new OpenIddictNHibernateBuilder(services)); + + Assert.Equal("services", exception.ParamName); + } + + [Fact] + public void ReplaceDefaultEntities_EntitiesAreCorrectlyReplaced() + { + // Arrange + var services = CreateServices(); + var builder = CreateBuilder(services); + + // Act + builder.ReplaceDefaultEntities(); + + // Assert + var provider = services.BuildServiceProvider(); + var options = provider.GetRequiredService>().CurrentValue; + + Assert.Equal(typeof(CustomApplication), options.DefaultApplicationType); + Assert.Equal(typeof(CustomAuthorization), options.DefaultAuthorizationType); + Assert.Equal(typeof(CustomScope), options.DefaultScopeType); + Assert.Equal(typeof(CustomToken), options.DefaultTokenType); + } + + [Fact] + public void ReplaceDefaultEntities_AllowsSpecifyingCustomKeyType() + { + // Arrange + var services = CreateServices(); + var builder = CreateBuilder(services); + + // Act + builder.ReplaceDefaultEntities(); + + // Assert + var provider = services.BuildServiceProvider(); + var options = provider.GetRequiredService>().CurrentValue; + + Assert.Equal(typeof(OpenIddictApplication), options.DefaultApplicationType); + Assert.Equal(typeof(OpenIddictAuthorization), options.DefaultAuthorizationType); + Assert.Equal(typeof(OpenIddictScope), options.DefaultScopeType); + Assert.Equal(typeof(OpenIddictToken), options.DefaultTokenType); + } + + [Fact] + public void UseSessionFactory_ThrowsAnExceptionForNullFactory() + { + // Arrange + var services = CreateServices(); + var builder = CreateBuilder(services); + + // Act and assert + var exception = Assert.Throws(delegate + { + return builder.UseSessionFactory(factory: null); + }); + + Assert.Equal("factory", exception.ParamName); + } + + [Fact] + public void UseSessionFactory_SetsDbContextTypeInOptions() + { + // Arrange + var services = CreateServices(); + var builder = CreateBuilder(services); + var factory = Mock.Of(); + + // Act + builder.UseSessionFactory(factory); + + // Assert + var provider = services.BuildServiceProvider(); + var options = provider.GetRequiredService>().CurrentValue; + + Assert.Same(factory, options.SessionFactory); + } + + private static OpenIddictNHibernateBuilder CreateBuilder(IServiceCollection services) + => services.AddOpenIddict().AddCore().UseNHibernate(); + + private static IServiceCollection CreateServices() + { + var services = new ServiceCollection(); + services.AddOptions(); + + return services; + } + + public class CustomApplication : OpenIddictApplication { } + public class CustomAuthorization : OpenIddictAuthorization { } + public class CustomScope : OpenIddictScope { } + public class CustomToken : OpenIddictToken { } + } +} diff --git a/test/OpenIddict.NHibernate.Tests/OpenIddictNHibernateContextTests.cs b/test/OpenIddict.NHibernate.Tests/OpenIddictNHibernateContextTests.cs new file mode 100644 index 00000000..293f8054 --- /dev/null +++ b/test/OpenIddict.NHibernate.Tests/OpenIddictNHibernateContextTests.cs @@ -0,0 +1,240 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Moq; +using NHibernate; +using Xunit; + +namespace OpenIddict.NHibernate.Tests +{ + public class OpenIddictNHibernateContextTests + { + [Fact] + public async Task GetSessionAsync_ThrowsAnExceptionForCanceledToken() + { + // Arrange + var services = new ServiceCollection(); + var provider = services.BuildServiceProvider(); + + var options = Mock.Of>(); + var token = new CancellationToken(canceled: true); + + var context = new OpenIddictNHibernateContext(options, provider); + + // Act and assert + var exception = await Assert.ThrowsAsync(async delegate + { + await context.GetSessionAsync(token); + }); + + Assert.Equal(token, exception.CancellationToken); + } + + [Fact] + public async Task GetSessionAsync_UsesSessionRegisteredInDependencyInjectionContainer() + { + // Arrange + var services = new ServiceCollection(); + + var session = new Mock(); + var factory = new Mock(); + + services.AddSingleton(session.Object); + services.AddSingleton(factory.Object); + + var provider = services.BuildServiceProvider(); + + var options = Mock.Of>( + mock => mock.CurrentValue == new OpenIddictNHibernateOptions + { + SessionFactory = null + }); + + var context = new OpenIddictNHibernateContext(options, provider); + + // Act and assert + Assert.Same(session.Object, await context.GetSessionAsync(CancellationToken.None)); + factory.Verify(mock => mock.OpenSession(), Times.Never()); + } + + [Theory] + [InlineData(FlushMode.Always)] + [InlineData(FlushMode.Auto)] + [InlineData(FlushMode.Commit)] + [InlineData(FlushMode.Unspecified)] + public async Task GetSessionAsync_CreatesSubSessionWhenFlushModeIsNotManual(FlushMode mode) + { + // Arrange + var services = new ServiceCollection(); + + var session = new Mock(); + session.SetupProperty(mock => mock.FlushMode, mode); + + var builder = new Mock(); + builder.Setup(mock => mock.AutoClose()) + .Returns(builder.Object); + builder.Setup(mock => mock.AutoJoinTransaction()) + .Returns(builder.Object); + builder.Setup(mock => mock.Connection()) + .Returns(builder.Object); + builder.Setup(mock => mock.ConnectionReleaseMode()) + .Returns(builder.Object); + builder.Setup(mock => mock.FlushMode(FlushMode.Manual)) + .Returns(builder.Object); + builder.Setup(mock => mock.Interceptor()) + .Returns(builder.Object); + builder.Setup(mock => mock.OpenSession()) + .Returns(session.Object); + + session.Setup(mock => mock.SessionWithOptions()) + .Returns(builder.Object); + + var factory = new Mock(); + + services.AddSingleton(session.Object); + services.AddSingleton(factory.Object); + + var provider = services.BuildServiceProvider(); + + var options = Mock.Of>( + mock => mock.CurrentValue == new OpenIddictNHibernateOptions + { + SessionFactory = null + }); + + var context = new OpenIddictNHibernateContext(options, provider); + + // Act and assert + Assert.Same(session.Object, await context.GetSessionAsync(CancellationToken.None)); + builder.Verify(mock => mock.AutoClose(), Times.Once()); + builder.Verify(mock => mock.AutoJoinTransaction(), Times.Once()); + builder.Verify(mock => mock.Connection(), Times.Once()); + builder.Verify(mock => mock.ConnectionReleaseMode(), Times.Once()); + builder.Verify(mock => mock.FlushMode(FlushMode.Manual), Times.Once()); + builder.Verify(mock => mock.Interceptor(), Times.Once()); + builder.Verify(mock => mock.OpenSession(), Times.Once()); + factory.Verify(mock => mock.OpenSession(), Times.Never()); + } + + [Fact] + public async Task GetSessionAsync_UsesSessionFactoryRegisteredInDependencyInjectionContainer() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(Mock.Of()); + + var session = new Mock(); + var factory = new Mock(); + factory.Setup(mock => mock.OpenSession()) + .Returns(session.Object); + + var provider = services.BuildServiceProvider(); + + var options = Mock.Of>( + mock => mock.CurrentValue == new OpenIddictNHibernateOptions + { + SessionFactory = factory.Object + }); + + var context = new OpenIddictNHibernateContext(options, provider); + + // Act and assert + Assert.Same(session.Object, await context.GetSessionAsync(CancellationToken.None)); + factory.Verify(mock => mock.OpenSession(), Times.Once()); + session.VerifySet(mock => mock.FlushMode = FlushMode.Manual, Times.Once()); + } + + [Fact] + public async Task GetSessionAsync_ThrowsAnExceptionWhenSessionFactoryCannotBeFound() + { + // Arrange + var services = new ServiceCollection(); + var provider = services.BuildServiceProvider(); + + var options = Mock.Of>( + mock => mock.CurrentValue == new OpenIddictNHibernateOptions + { + SessionFactory = null + }); + + var context = new OpenIddictNHibernateContext(options, provider); + + // Act and assert + var exception = await Assert.ThrowsAsync(async delegate + { + await context.GetSessionAsync(CancellationToken.None); + }); + + Assert.Equal(new StringBuilder() + .AppendLine("No suitable NHibernate session or session factory can be found.") + .Append("To configure the OpenIddict NHibernate stores to use a specific factory, use ") + .Append("'services.AddOpenIddict().AddCore().UseNHibernate().UseSessionFactory()' or register an ") + .Append("'ISession'/'ISessionFactory' in the dependency injection container in 'ConfigureServices()'.") + .ToString(), exception.Message); + } + + [Fact] + public async Task GetSessionAsync_PrefersSessionFactoryRegisteredInOptionsToSessionRegisteredInDependencyInjectionContainer() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(Mock.Of()); + + var session = new Mock(); + var factory = new Mock(); + factory.Setup(mock => mock.OpenSession()) + .Returns(session.Object); + + var provider = services.BuildServiceProvider(); + + var options = Mock.Of>( + mock => mock.CurrentValue == new OpenIddictNHibernateOptions + { + SessionFactory = factory.Object + }); + + var context = new OpenIddictNHibernateContext(options, provider); + + // Act and assert + Assert.Same(session.Object, await context.GetSessionAsync(CancellationToken.None)); + factory.Verify(mock => mock.OpenSession(), Times.Once()); + session.VerifySet(mock => mock.FlushMode = FlushMode.Manual, Times.Once()); + } + + [Fact] + public async Task GetSessionAsync_ReturnsCachedSession() + { + // Arrange + var services = new ServiceCollection(); + var provider = services.BuildServiceProvider(); + + var factory = new Mock(); + factory.Setup(mock => mock.OpenSession()) + .Returns(() => Mock.Of()); + + var options = Mock.Of>( + mock => mock.CurrentValue == new OpenIddictNHibernateOptions + { + SessionFactory = factory.Object + }); + + var context = new OpenIddictNHibernateContext(options, provider); + + // Act and assert + Assert.Same( + await context.GetSessionAsync(CancellationToken.None), + await context.GetSessionAsync(CancellationToken.None)); + + factory.Verify(mock => mock.OpenSession(), Times.Once()); + } + } +} diff --git a/test/OpenIddict.NHibernate.Tests/OpenIddictNHibernateExtensionsTests.cs b/test/OpenIddict.NHibernate.Tests/OpenIddictNHibernateExtensionsTests.cs new file mode 100644 index 00000000..50d06a32 --- /dev/null +++ b/test/OpenIddict.NHibernate.Tests/OpenIddictNHibernateExtensionsTests.cs @@ -0,0 +1,102 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using OpenIddict.Abstractions; +using OpenIddict.Core; +using OpenIddict.NHibernate.Models; +using Xunit; + +namespace OpenIddict.NHibernate.Tests +{ + public class OpenIddictNHibernateExtensionsTests + { + [Fact] + public void UseNHibernate_ThrowsAnExceptionForNullBuilder() + { + // Arrange + var builder = (OpenIddictCoreBuilder) null; + + // Act and assert + var exception = Assert.Throws(() => builder.UseNHibernate()); + + Assert.Equal("builder", exception.ParamName); + } + + [Fact] + public void UseNHibernate_ThrowsAnExceptionForNullConfiguration() + { + // Arrange + var services = new ServiceCollection(); + var builder = new OpenIddictCoreBuilder(services); + + // Act and assert + var exception = Assert.Throws(() => builder.UseNHibernate(configuration: null)); + + Assert.Equal("configuration", exception.ParamName); + } + + [Fact] + public void UseNHibernate_RegistersDefaultEntities() + { + // Arrange + var services = new ServiceCollection().AddOptions(); + var builder = new OpenIddictCoreBuilder(services); + + // Act + builder.UseNHibernate(); + + // Assert + var provider = services.BuildServiceProvider(); + var options = provider.GetRequiredService>().CurrentValue; + + Assert.Equal(typeof(OpenIddictApplication), options.DefaultApplicationType); + Assert.Equal(typeof(OpenIddictAuthorization), options.DefaultAuthorizationType); + Assert.Equal(typeof(OpenIddictScope), options.DefaultScopeType); + Assert.Equal(typeof(OpenIddictToken), options.DefaultTokenType); + } + + [Theory] + [InlineData(typeof(IOpenIddictApplicationStoreResolver), typeof(OpenIddictApplicationStoreResolver))] + [InlineData(typeof(IOpenIddictAuthorizationStoreResolver), typeof(OpenIddictAuthorizationStoreResolver))] + [InlineData(typeof(IOpenIddictScopeStoreResolver), typeof(OpenIddictScopeStoreResolver))] + [InlineData(typeof(IOpenIddictTokenStoreResolver), typeof(OpenIddictTokenStoreResolver))] + public void UseNHibernate_RegistersNHibernateStoreResolvers(Type serviceType, Type implementationType) + { + // Arrange + var services = new ServiceCollection(); + var builder = new OpenIddictCoreBuilder(services); + + // Act + builder.UseNHibernate(); + + // Assert + Assert.Contains(services, service => service.ServiceType == serviceType && + service.ImplementationType == implementationType); + } + + [Theory] + [InlineData(typeof(OpenIddictApplicationStore<,,,>))] + [InlineData(typeof(OpenIddictAuthorizationStore<,,,>))] + [InlineData(typeof(OpenIddictScopeStore<,>))] + [InlineData(typeof(OpenIddictTokenStore<,,,>))] + public void UseNHibernate_RegistersNHibernateStore(Type type) + { + // Arrange + var services = new ServiceCollection(); + var builder = new OpenIddictCoreBuilder(services); + + // Act + builder.UseNHibernate(); + + // Assert + Assert.Contains(services, service => service.ServiceType == type && service.ImplementationType == type); + } + } +} diff --git a/test/OpenIddict.NHibernate.Tests/Resolvers/OpenIddictApplicationStoreResolverTests.cs b/test/OpenIddict.NHibernate.Tests/Resolvers/OpenIddictApplicationStoreResolverTests.cs new file mode 100644 index 00000000..0372021a --- /dev/null +++ b/test/OpenIddict.NHibernate.Tests/Resolvers/OpenIddictApplicationStoreResolverTests.cs @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Text; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Moq; +using OpenIddict.Abstractions; +using OpenIddict.NHibernate.Models; +using Xunit; + +namespace OpenIddict.NHibernate.Tests +{ + public class OpenIddictApplicationStoreResolverTests + { + [Fact] + public void Get_ReturnsCustomStoreCorrespondingToTheSpecifiedTypeWhenAvailable() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(Mock.Of>()); + + var provider = services.BuildServiceProvider(); + var resolver = new OpenIddictApplicationStoreResolver(provider); + + // Act and assert + Assert.NotNull(resolver.Get()); + } + + [Fact] + public void Get_ThrowsAnExceptionForInvalidEntityType() + { + // Arrange + var services = new ServiceCollection(); + + var provider = services.BuildServiceProvider(); + var resolver = new OpenIddictApplicationStoreResolver(provider); + + // Act and assert + var exception = Assert.Throws(() => resolver.Get()); + + Assert.Equal(new StringBuilder() + .AppendLine("The specified application type is not compatible with the NHibernate stores.") + .Append("When enabling the NHibernate stores, make sure you use the built-in ") + .Append("'OpenIddictApplication' entity (from the 'OpenIddict.NHibernate.Models' package) ") + .Append("or a custom entity that inherits from the generic 'OpenIddictApplication' entity.") + .ToString(), exception.Message); + } + + [Fact] + public void Get_ReturnsDefaultStoreCorrespondingToTheSpecifiedTypeWhenAvailable() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(Mock.Of>()); + services.AddSingleton(CreateStore()); + + var provider = services.BuildServiceProvider(); + var resolver = new OpenIddictApplicationStoreResolver(provider); + + // Act and assert + Assert.NotNull(resolver.Get()); + } + + private static OpenIddictApplicationStore CreateStore() + => new Mock>( + Mock.Of(), + Mock.Of(), + Mock.Of>()).Object; + + public class CustomApplication { } + + public class MyApplication : OpenIddictApplication { } + public class MyAuthorization : OpenIddictAuthorization { } + public class MyScope : OpenIddictScope { } + public class MyToken : OpenIddictToken { } + } +} diff --git a/test/OpenIddict.NHibernate.Tests/Resolvers/OpenIddictAuthorizationStoreResolverTests.cs b/test/OpenIddict.NHibernate.Tests/Resolvers/OpenIddictAuthorizationStoreResolverTests.cs new file mode 100644 index 00000000..1eb9b931 --- /dev/null +++ b/test/OpenIddict.NHibernate.Tests/Resolvers/OpenIddictAuthorizationStoreResolverTests.cs @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Text; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Moq; +using OpenIddict.Abstractions; +using OpenIddict.NHibernate.Models; +using Xunit; + +namespace OpenIddict.NHibernate.Tests +{ + public class OpenIddictAuthorizationStoreResolverTests + { + [Fact] + public void Get_ReturnsCustomStoreCorrespondingToTheSpecifiedTypeWhenAvailable() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(Mock.Of>()); + + var provider = services.BuildServiceProvider(); + var resolver = new OpenIddictAuthorizationStoreResolver(provider); + + // Act and assert + Assert.NotNull(resolver.Get()); + } + + [Fact] + public void Get_ThrowsAnExceptionForInvalidEntityType() + { + // Arrange + var services = new ServiceCollection(); + + var provider = services.BuildServiceProvider(); + var resolver = new OpenIddictAuthorizationStoreResolver(provider); + + // Act and assert + var exception = Assert.Throws(() => resolver.Get()); + + Assert.Equal(new StringBuilder() + .AppendLine("The specified authorization type is not compatible with the NHibernate stores.") + .Append("When enabling the NHibernate stores, make sure you use the built-in ") + .Append("'OpenIddictAuthorization' entity (from the 'OpenIddict.NHibernate.Models' package) ") + .Append("or a custom entity that inherits from the generic 'OpenIddictAuthorization' entity.") + .ToString(), exception.Message); + } + + [Fact] + public void Get_ReturnsDefaultStoreCorrespondingToTheSpecifiedTypeWhenAvailable() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(Mock.Of>()); + services.AddSingleton(CreateStore()); + + var provider = services.BuildServiceProvider(); + var resolver = new OpenIddictAuthorizationStoreResolver(provider); + + // Act and assert + Assert.NotNull(resolver.Get()); + } + + private static OpenIddictAuthorizationStore CreateStore() + => new Mock>( + Mock.Of(), + Mock.Of(), + Mock.Of>()).Object; + + public class CustomAuthorization { } + + public class MyApplication : OpenIddictApplication { } + public class MyAuthorization : OpenIddictAuthorization { } + public class MyScope : OpenIddictScope { } + public class MyToken : OpenIddictToken { } + } +} diff --git a/test/OpenIddict.NHibernate.Tests/Resolvers/OpenIddictScopeStoreResolverTests.cs b/test/OpenIddict.NHibernate.Tests/Resolvers/OpenIddictScopeStoreResolverTests.cs new file mode 100644 index 00000000..812dd97e --- /dev/null +++ b/test/OpenIddict.NHibernate.Tests/Resolvers/OpenIddictScopeStoreResolverTests.cs @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Text; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Moq; +using OpenIddict.Abstractions; +using OpenIddict.NHibernate.Models; +using Xunit; + +namespace OpenIddict.NHibernate.Tests +{ + public class OpenIddictScopeStoreResolverTests + { + [Fact] + public void Get_ReturnsCustomStoreCorrespondingToTheSpecifiedTypeWhenAvailable() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(Mock.Of>()); + + var provider = services.BuildServiceProvider(); + var resolver = new OpenIddictScopeStoreResolver(provider); + + // Act and assert + Assert.NotNull(resolver.Get()); + } + + [Fact] + public void Get_ThrowsAnExceptionForInvalidEntityType() + { + // Arrange + var services = new ServiceCollection(); + + var provider = services.BuildServiceProvider(); + var resolver = new OpenIddictScopeStoreResolver(provider); + + // Act and assert + var exception = Assert.Throws(() => resolver.Get()); + + Assert.Equal(new StringBuilder() + .AppendLine("The specified scope type is not compatible with the NHibernate stores.") + .Append("When enabling the NHibernate stores, make sure you use the built-in ") + .Append("'OpenIddictScope' entity (from the 'OpenIddict.NHibernate.Models' package) ") + .Append("or a custom entity that inherits from the generic 'OpenIddictScope' entity.") + .ToString(), exception.Message); + } + + [Fact] + public void Get_ReturnsDefaultStoreCorrespondingToTheSpecifiedTypeWhenAvailable() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(Mock.Of>()); + services.AddSingleton(CreateStore()); + + var provider = services.BuildServiceProvider(); + var resolver = new OpenIddictScopeStoreResolver(provider); + + // Act and assert + Assert.NotNull(resolver.Get()); + } + + private static OpenIddictScopeStore CreateStore() + => new Mock>( + Mock.Of(), + Mock.Of(), + Mock.Of>()).Object; + + public class CustomScope { } + + public class MyApplication : OpenIddictApplication { } + public class MyAuthorization : OpenIddictAuthorization { } + public class MyScope : OpenIddictScope { } + public class MyToken : OpenIddictToken { } + } +} diff --git a/test/OpenIddict.NHibernate.Tests/Resolvers/OpenIddictTokenStoreResolverTests.cs b/test/OpenIddict.NHibernate.Tests/Resolvers/OpenIddictTokenStoreResolverTests.cs new file mode 100644 index 00000000..f69f495c --- /dev/null +++ b/test/OpenIddict.NHibernate.Tests/Resolvers/OpenIddictTokenStoreResolverTests.cs @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + * See https://github.com/openiddict/openiddict-core for more information concerning + * the license and the contributors participating to this project. + */ + +using System; +using System.Text; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Moq; +using OpenIddict.Abstractions; +using OpenIddict.NHibernate.Models; +using Xunit; + +namespace OpenIddict.NHibernate.Tests +{ + public class OpenIddictTokenStoreResolverTests + { + [Fact] + public void Get_ReturnsCustomStoreCorrespondingToTheSpecifiedTypeWhenAvailable() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(Mock.Of>()); + + var provider = services.BuildServiceProvider(); + var resolver = new OpenIddictTokenStoreResolver(provider); + + // Act and assert + Assert.NotNull(resolver.Get()); + } + + [Fact] + public void Get_ThrowsAnExceptionForInvalidEntityType() + { + // Arrange + var services = new ServiceCollection(); + + var provider = services.BuildServiceProvider(); + var resolver = new OpenIddictTokenStoreResolver(provider); + + // Act and assert + var exception = Assert.Throws(() => resolver.Get()); + + Assert.Equal(new StringBuilder() + .AppendLine("The specified token type is not compatible with the NHibernate stores.") + .Append("When enabling the NHibernate stores, make sure you use the built-in ") + .Append("'OpenIddictToken' entity (from the 'OpenIddict.NHibernate.Models' package) ") + .Append("or a custom entity that inherits from the generic 'OpenIddictToken' entity.") + .ToString(), exception.Message); + } + + [Fact] + public void Get_ReturnsDefaultStoreCorrespondingToTheSpecifiedTypeWhenAvailable() + { + // Arrange + var services = new ServiceCollection(); + services.AddSingleton(Mock.Of>()); + services.AddSingleton(CreateStore()); + + var provider = services.BuildServiceProvider(); + var resolver = new OpenIddictTokenStoreResolver(provider); + + // Act and assert + Assert.NotNull(resolver.Get()); + } + + private static OpenIddictTokenStore CreateStore() + => new Mock>( + Mock.Of(), + Mock.Of(), + Mock.Of>()).Object; + + public class CustomToken { } + + public class MyApplication : OpenIddictApplication { } + public class MyAuthorization : OpenIddictAuthorization { } + public class MyScope : OpenIddictScope { } + public class MyToken : OpenIddictToken { } + } +}