// ==========================================================================
// Squidex Headless CMS
// ==========================================================================
// Copyright (c) Squidex UG (haftungsbeschränkt)
// All rights reserved. Licensed under the MIT license.
// ==========================================================================

using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Identity;
using Squidex.Infrastructure;

namespace Squidex.Domain.Users
{
    /// <summary>
    /// Extensions over <see cref="UserManager{TUser}"/> for loading users together with their claims
    /// and for create/update/lock operations that report failed <see cref="IdentityResult"/>s as
    /// <see cref="ValidationException"/>.
    /// </summary>
    public static class UserManagerExtensions
    {
        /// <summary>
        /// Resolves the user whose id is carried by <paramref name="principal"/>, including its claims.
        /// </summary>
        /// <returns>
        /// The user with claims, or <c>null</c> when <paramref name="principal"/> is <c>null</c>,
        /// carries no user id, or no user with that id exists.
        /// </returns>
        public static async Task<UserWithClaims> GetUserWithClaimsAsync(this UserManager<IdentityUser> userManager, ClaimsPrincipal principal)
        {
            if (principal == null)
            {
                return null;
            }

            var user = await userManager.FindByIdWithClaimsAsync(userManager.GetUserId(principal));

            return user;
        }

        /// <summary>
        /// Loads the claims of <paramref name="user"/> and wraps both in a <see cref="UserWithClaims"/>.
        /// </summary>
        /// <returns>The wrapped user, or <c>null</c> when <paramref name="user"/> is <c>null</c>.</returns>
        public static async Task<UserWithClaims> ResolveUserAsync(this UserManager<IdentityUser> userManager, IdentityUser user)
        {
            if (user == null)
            {
                return null;
            }

            var claims = await userManager.GetClaimsAsync(user);

            return new UserWithClaims(user, claims);
        }

        /// <summary>
        /// Finds a user by id and loads its claims.
        /// </summary>
        /// <returns>The user with claims, or <c>null</c> when <paramref name="id"/> is <c>null</c> or unknown.</returns>
        public static async Task<UserWithClaims> FindByIdWithClaimsAsync(this UserManager<IdentityUser> userManager, string id)
        {
            if (id == null)
            {
                return null;
            }

            var user = await userManager.FindByIdAsync(id);

            return await userManager.ResolveUserAsync(user);
        }

        /// <summary>
        /// Finds a user by email and loads its claims.
        /// The duplicated "Async" suffix in the name is kept so existing callers keep compiling.
        /// </summary>
        /// <returns>The user with claims, or <c>null</c> when <paramref name="email"/> is <c>null</c> or unknown.</returns>
        public static async Task<UserWithClaims> FindByEmailWithClaimsAsyncAsync(this UserManager<IdentityUser> userManager, string email)
        {
            if (email == null)
            {
                return null;
            }

            var user = await userManager.FindByEmailAsync(email);

            return await userManager.ResolveUserAsync(user);
        }

        /// <summary>
        /// Counts the users whose normalized email contains the normalized <paramref name="email"/>,
        /// or all users when <paramref name="email"/> is null or whitespace.
        /// </summary>
        /// <remarks>The query runs synchronously; the returned task is already completed.</remarks>
        public static Task<long> CountByEmailAsync(this UserManager<IdentityUser> userManager, string email = null)
        {
            var count = QueryUsers(userManager, email).LongCount();

            return Task.FromResult(count);
        }

        /// <summary>
        /// Returns one page of users matching <paramref name="email"/> (see <see cref="QueryUsers"/>), with claims loaded.
        /// </summary>
        /// <param name="userManager">The user manager to query.</param>
        /// <param name="email">Optional email fragment; null or whitespace matches all users.</param>
        /// <param name="take">Maximum number of users to return.</param>
        /// <param name="skip">Number of matching users to skip.</param>
        /// <remarks>
        /// NOTE(review): the query has no explicit ordering, so page contents depend on the store's
        /// natural order.
        /// </remarks>
        public static async Task<List<UserWithClaims>> QueryByEmailAsync(this UserManager<IdentityUser> userManager, string email = null, int take = 10, int skip = 0)
        {
            var users = QueryUsers(userManager, email).Skip(skip).Take(take).ToList();

            var result = await userManager.ResolveUsersAsync(users);

            return result.ToList();
        }

        /// <summary>
        /// Loads the claims for every user in <paramref name="users"/>.
        /// </summary>
        /// <returns>The resolved users, in the same order as <paramref name="users"/>.</returns>
        public static Task<UserWithClaims[]> ResolveUsersAsync(this UserManager<IdentityUser> userManager, IEnumerable<IdentityUser> users)
        {
            // All lookups are started before any is awaited and share the same manager instance.
            return Task.WhenAll(users.Select(user => userManager.ResolveUserAsync(user)));
        }

        /// <summary>
        /// Builds the user query, optionally filtered to users whose normalized email contains the
        /// normalized form of <paramref name="email"/>.
        /// </summary>
        public static IQueryable<IdentityUser> QueryUsers(UserManager<IdentityUser> userManager, string email = null)
        {
            var result = userManager.Users;

            if (!string.IsNullOrWhiteSpace(email))
            {
                var normalizedEmail = userManager.NormalizeKey(email);

                // Users without an email have a null NormalizedEmail; exclude them instead of
                // dereferencing null when the predicate is evaluated in memory.
                result = result.Where(x => x.NormalizedEmail != null && x.NormalizedEmail.Contains(normalizedEmail));
            }

            return result;
        }

        /// <summary>
        /// Creates a user from <paramref name="values"/>, then adds its claims and, if given, its password.
        /// If any step fails, the user is deleted again (best effort) and the original error is rethrown.
        /// </summary>
        /// <exception cref="ValidationException">A step returned a failed <see cref="IdentityResult"/>.</exception>
        public static async Task<IdentityUser> CreateAsync(this UserManager<IdentityUser> userManager, IUserFactory factory, UserValues values)
        {
            var user = factory.Create(values.Email);

            try
            {
                await DoChecked(() => userManager.CreateAsync(user), "Cannot create user.");

                var claims = values.ToClaims().ToList();

                if (claims.Count > 0)
                {
                    await DoChecked(() => userManager.AddClaimsAsync(user, claims), "Cannot add user.");
                }

                if (!string.IsNullOrWhiteSpace(values.Password))
                {
                    await DoChecked(() => userManager.AddPasswordAsync(user, values.Password), "Cannot create user.");
                }
            }
            catch
            {
                // The rollback must not replace the exception that explains why creation failed.
                await TryDeleteAsync(userManager, user);

                throw;
            }

            return user;
        }

        /// <summary>
        /// Best-effort deletion used to roll back a partially created user. Any failure is swallowed
        /// so the caller can rethrow the exception that triggered the rollback.
        /// </summary>
        private static async Task TryDeleteAsync(UserManager<IdentityUser> userManager, IdentityUser user)
        {
            try
            {
                await userManager.DeleteAsync(user);
            }
            catch
            {
                // Intentionally ignored: rollback is best effort.
            }
        }

        /// <summary>
        /// Updates the user with the given id from <paramref name="values"/>.
        /// </summary>
        /// <exception cref="DomainObjectNotFoundException">No user with <paramref name="id"/> exists.</exception>
        /// <exception cref="ValidationException">An update step returned a failed <see cref="IdentityResult"/>.</exception>
        public static async Task UpdateAsync(this UserManager<IdentityUser> userManager, string id, UserValues values)
        {
            var user = await userManager.FindByIdAsync(id);

            if (user == null)
            {
                throw new DomainObjectNotFoundException(id, typeof(IdentityUser));
            }

            await UpdateAsync(userManager, user, values);
        }

        /// <summary>
        /// Like <see cref="UpdateAsync(UserManager{IdentityUser}, IdentityUser, UserValues)"/>, but reports
        /// validation failures as a failed <see cref="IdentityResult"/> instead of throwing.
        /// </summary>
        /// <exception cref="DomainObjectNotFoundException"><paramref name="user"/> is <c>null</c>.</exception>
        public static async Task<IdentityResult> UpdateSafeAsync(this UserManager<IdentityUser> userManager, IdentityUser user, UserValues values)
        {
            try
            {
                await userManager.UpdateAsync(user, values);

                return IdentityResult.Success;
            }
            catch (ValidationException ex)
            {
                return IdentityResult.Failed(ex.Errors.Select(x => new IdentityError { Description = x.Message }).ToArray());
            }
        }

        /// <summary>
        /// Applies <paramref name="values"/> to <paramref name="user"/> in this order: email and user
        /// name (when a different, non-blank email is given), then the claims from
        /// <c>values.ToClaims()</c>, then the password (when one is given).
        /// </summary>
        /// <exception cref="DomainObjectNotFoundException"><paramref name="user"/> is <c>null</c>.</exception>
        /// <exception cref="ValidationException">An update step returned a failed <see cref="IdentityResult"/>.</exception>
        public static async Task UpdateAsync(this UserManager<IdentityUser> userManager, IdentityUser user, UserValues values)
        {
            if (user == null)
            {
                throw new DomainObjectNotFoundException("Id", typeof(IdentityUser));
            }

            if (!string.IsNullOrWhiteSpace(values.Email) && values.Email != user.Email)
            {
                await DoChecked(() => userManager.SetEmailAsync(user, values.Email), "Cannot update email.");
                await DoChecked(() => userManager.SetUserNameAsync(user, values.Email), "Cannot update email.");
            }

            await DoChecked(() => userManager.SyncClaimsAsync(user, values.ToClaims().ToList()), "Cannot update user.");

            if (!string.IsNullOrWhiteSpace(values.Password))
            {
                // RemovePasswordAsync saves the user immediately. Validate the new password first so a
                // password that AddPasswordAsync would reject cannot leave the account without one.
                await DoChecked(() => ValidatePasswordAsync(userManager, user, values.Password), "Cannot replace password.");

                await DoChecked(() => userManager.RemovePasswordAsync(user), "Cannot replace password.");
                await DoChecked(() => userManager.AddPasswordAsync(user, values.Password), "Cannot replace password.");
            }
        }

        /// <summary>
        /// Runs every validator in <see cref="UserManager{TUser}.PasswordValidators"/> against
        /// <paramref name="password"/> and merges their errors. The user is not modified.
        /// </summary>
        private static async Task<IdentityResult> ValidatePasswordAsync(UserManager<IdentityUser> userManager, IdentityUser user, string password)
        {
            var errors = new List<IdentityError>();

            foreach (var validator in userManager.PasswordValidators)
            {
                var result = await validator.ValidateAsync(userManager, user, password);

                if (!result.Succeeded)
                {
                    errors.AddRange(result.Errors);
                }
            }

            return errors.Count > 0 ? IdentityResult.Failed(errors.ToArray()) : IdentityResult.Success;
        }

        /// <summary>
        /// Locks the user by setting its lockout end 100 years in the future.
        /// </summary>
        /// <exception cref="DomainObjectNotFoundException">No user with <paramref name="id"/> exists.</exception>
        /// <exception cref="ValidationException">The lockout date could not be set.</exception>
        public static async Task LockAsync(this UserManager<IdentityUser> userManager, string id)
        {
            var user = await userManager.FindByIdAsync(id);

            if (user == null)
            {
                throw new DomainObjectNotFoundException(id, typeof(IdentityUser));
            }

            await DoChecked(() => userManager.SetLockoutEndDateAsync(user, DateTimeOffset.UtcNow.AddYears(100)), "Cannot lock user.");
        }

        /// <summary>
        /// Unlocks the user by clearing its lockout end date.
        /// </summary>
        /// <exception cref="DomainObjectNotFoundException">No user with <paramref name="id"/> exists.</exception>
        /// <exception cref="ValidationException">The lockout date could not be cleared.</exception>
        public static async Task UnlockAsync(this UserManager<IdentityUser> userManager, string id)
        {
            var user = await userManager.FindByIdAsync(id);

            if (user == null)
            {
                throw new DomainObjectNotFoundException(id, typeof(IdentityUser));
            }

            await DoChecked(() => userManager.SetLockoutEndDateAsync(user, null), "Cannot unlock user.");
        }

        /// <summary>
        /// Executes <paramref name="action"/> and converts a failed <see cref="IdentityResult"/> into a
        /// <see cref="ValidationException"/> with <paramref name="message"/> and one error per identity error.
        /// </summary>
        private static async Task DoChecked(Func<Task<IdentityResult>> action, string message)
        {
            var result = await action();

            if (!result.Succeeded)
            {
                throw new ValidationException(message, result.Errors.Select(x => new ValidationError(x.Description)).ToArray());
            }
        }

        /// <summary>
        /// Replaces, for every claim type present in <paramref name="claims"/>, the user's existing claims
        /// of that type with the given ones. Claims of other types are left untouched.
        /// </summary>
        /// <returns>
        /// <see cref="IdentityResult.Success"/> when <paramref name="claims"/> is null or empty; otherwise the
        /// first failed result from removing old claims, or the result of adding the new ones.
        /// </returns>
        public static async Task<IdentityResult> SyncClaimsAsync(this UserManager<IdentityUser> userManager, IdentityUser user, IEnumerable<Claim> claims)
        {
            // Materialize once: the sequence is inspected here and then handed to AddClaimsAsync.
            var newClaims = claims?.ToList() ?? new List<Claim>();

            if (newClaims.Count == 0)
            {
                return IdentityResult.Success;
            }

            var newClaimTypes = new HashSet<string>(newClaims.Select(x => x.Type));

            var oldClaims = await userManager.GetClaimsAsync(user);
            var oldClaimsToRemove = oldClaims.Where(x => newClaimTypes.Contains(x.Type)).ToList();

            if (oldClaimsToRemove.Count > 0)
            {
                var result = await userManager.RemoveClaimsAsync(user, oldClaimsToRemove);

                if (!result.Succeeded)
                {
                    return result;
                }
            }

            return await userManager.AddClaimsAsync(user, newClaims);
        }
    }
}