diff --git a/framework/src/Volo.Abp.AspNetCore.SignalR/Volo/Abp/AspNetCore/SignalR/Authentication/AbpAuthenticationHubFilter.cs b/framework/src/Volo.Abp.AspNetCore.SignalR/Volo/Abp/AspNetCore/SignalR/Authentication/AbpAuthenticationHubFilter.cs index baac1b538f..c7b8c029cf 100644 --- a/framework/src/Volo.Abp.AspNetCore.SignalR/Volo/Abp/AspNetCore/SignalR/Authentication/AbpAuthenticationHubFilter.cs +++ b/framework/src/Volo.Abp.AspNetCore.SignalR/Volo/Abp/AspNetCore/SignalR/Authentication/AbpAuthenticationHubFilter.cs @@ -13,7 +13,9 @@ public class AbpAuthenticationHubFilter : IHubFilter public virtual async ValueTask InvokeMethodAsync(HubInvocationContext invocationContext, Func> next) { var currentPrincipalAccessor = invocationContext.ServiceProvider.GetRequiredService(); - using (currentPrincipalAccessor.Change((await GetDynamicClaimsPrincipalAsync(invocationContext.Context.User, invocationContext.ServiceProvider))!)) + var claimsPrincipal = invocationContext.Context.User; + await HandleDynamicClaimsPrincipalAsync(claimsPrincipal, invocationContext.ServiceProvider, invocationContext.Context); + using (currentPrincipalAccessor.Change(claimsPrincipal!)) { return await next(invocationContext); } @@ -22,7 +24,9 @@ public class AbpAuthenticationHubFilter : IHubFilter public virtual async Task OnConnectedAsync(HubLifetimeContext context, Func next) { var currentPrincipalAccessor = context.ServiceProvider.GetRequiredService(); - using (currentPrincipalAccessor.Change((await GetDynamicClaimsPrincipalAsync(context.Context.User, context.ServiceProvider))!)) + var claimsPrincipal = context.Context.User; + await HandleDynamicClaimsPrincipalAsync(claimsPrincipal, context.ServiceProvider, context.Context); + using (currentPrincipalAccessor.Change(claimsPrincipal!)) { await next(context); } @@ -31,27 +35,29 @@ public class AbpAuthenticationHubFilter : IHubFilter public virtual async Task OnDisconnectedAsync(HubLifetimeContext context, Exception? exception, Func next) { var currentPrincipalAccessor = context.ServiceProvider.GetRequiredService(); - using (currentPrincipalAccessor.Change((await GetDynamicClaimsPrincipalAsync(context.Context.User, context.ServiceProvider))!)) + var claimsPrincipal = context.Context.User; + await HandleDynamicClaimsPrincipalAsync(claimsPrincipal, context.ServiceProvider, context.Context); + using (currentPrincipalAccessor.Change(claimsPrincipal!)) { await next(context, exception); } } - protected virtual async Task GetDynamicClaimsPrincipalAsync(ClaimsPrincipal? claimsPrincipal, IServiceProvider serviceProvider) + protected virtual async Task HandleDynamicClaimsPrincipalAsync(ClaimsPrincipal? claimsPrincipal, IServiceProvider serviceProvider, HubCallerContext hubCallerContext) { - if (claimsPrincipal == null) - { - return claimsPrincipal; - } - - if (claimsPrincipal.Identity != null && + if (claimsPrincipal?.Identity != null && claimsPrincipal.Identity.IsAuthenticated && serviceProvider.GetRequiredService>().Value.IsDynamicClaimsEnabled) { - var abpClaimsPrincipalFactory = serviceProvider.GetRequiredService(); - claimsPrincipal = await abpClaimsPrincipalFactory.CreateDynamicAsync(claimsPrincipal); - } + claimsPrincipal = claimsPrincipal.Identity is ClaimsIdentity identity + ? new ClaimsPrincipal(new ClaimsIdentity(claimsPrincipal.Claims, claimsPrincipal.Identity.AuthenticationType, identity.NameClaimType, identity.RoleClaimType)) + : new ClaimsPrincipal(new ClaimsIdentity(claimsPrincipal.Claims, claimsPrincipal.Identity.AuthenticationType)); - return claimsPrincipal; + claimsPrincipal = await serviceProvider.GetRequiredService().CreateDynamicAsync(claimsPrincipal); + if (claimsPrincipal.Identity?.IsAuthenticated == false) + { + hubCallerContext.Abort(); + } + } } }