// ==========================================================================
//  Squidex Headless CMS
// ==========================================================================
//  Copyright (c) Squidex UG (haftungsbeschränkt)
//  All rights reserved. Licensed under the MIT license.
// ==========================================================================

using System;
using System.Collections.Generic;
using System.Linq;

namespace Squidex.Infrastructure
{
    /// <summary>
    /// Extension methods for sequences, collections and dictionaries.
    /// </summary>
    public static class CollectionExtensions
    {
        /// <summary>
        /// Checks whether two collections contain the same elements, using the default equality comparer.
        /// </summary>
        /// <remarks>
        /// Returns <c>true</c> only when both collections have the same count and the number of distinct
        /// elements they share equals <c>other.Count</c>. Because <c>Intersect</c> yields distinct elements,
        /// the result is <c>false</c> whenever either collection contains duplicates.
        /// </remarks>
        /// <param name="source">The first collection.</param>
        /// <param name="other">The collection to compare with.</param>
        /// <returns><c>true</c> if both collections hold the same, duplicate-free set of elements.</returns>
        public static bool SetEquals<T>(this IReadOnlyCollection<T> source, IReadOnlyCollection<T> other)
        {
            return source.Count == other.Count && source.Intersect(other).Count() == other.Count;
        }

        /// <summary>
        /// Checks whether two collections contain the same elements, comparing them with <paramref name="comparer"/>.
        /// </summary>
        /// <remarks>
        /// Same semantics as the overload without a comparer: the result is <c>false</c> whenever either
        /// collection contains duplicates (as judged by <paramref name="comparer"/>).
        /// </remarks>
        /// <param name="source">The first collection.</param>
        /// <param name="other">The collection to compare with.</param>
        /// <param name="comparer">The comparer used to match elements.</param>
        /// <returns><c>true</c> if both collections hold the same, duplicate-free set of elements.</returns>
        public static bool SetEquals<T>(this IReadOnlyCollection<T> source, IReadOnlyCollection<T> other, IEqualityComparer<T> comparer)
        {
            return source.Count == other.Count && source.Intersect(other, comparer).Count() == other.Count;
        }

        /// <summary>
        /// Reorders a result list so that its items follow the order of <paramref name="ids"/>, keeping the original total.
        /// </summary>
        /// <param name="input">The result list to reorder.</param>
        /// <param name="idProvider">Extracts the id of an item.</param>
        /// <param name="ids">The ids in the desired order. Ids without a matching item are skipped.</param>
        /// <returns>A new result list created from <c>input.Total</c> and the reordered items.</returns>
        public static IResultList<T> SortSet<T, TKey>(this IResultList<T> input, Func<T, TKey> idProvider, IReadOnlyList<TKey> ids) where T : class
        {
            return ResultList.Create(input.Total, SortList(input, idProvider, ids));
        }

        /// <summary>
        /// Returns the items of <paramref name="input"/> in the order given by <paramref name="ids"/>.
        /// </summary>
        /// <remarks>
        /// For each id, the first item whose id is equal to it (via <c>object.Equals</c>) is taken; ids without
        /// a match are skipped. The result is evaluated lazily and scans <paramref name="input"/> once per id.
        /// </remarks>
        /// <param name="input">The items to reorder.</param>
        /// <param name="idProvider">Extracts the id of an item.</param>
        /// <param name="ids">The ids in the desired order.</param>
        /// <returns>The matching items, ordered like <paramref name="ids"/>.</returns>
        public static IEnumerable<T> SortList<T, TKey>(this IEnumerable<T> input, Func<T, TKey> idProvider, IReadOnlyList<TKey> ids) where T : class
        {
            return ids.Select(id => input.FirstOrDefault(x => Equals(idProvider(x), id))).NotNull();
        }

        /// <summary>
        /// Returns every element that occurs more than once in <paramref name="input"/> (each reported once).
        /// </summary>
        /// <param name="input">The sequence to inspect.</param>
        /// <returns>The duplicated elements.</returns>
        public static IEnumerable<T> Duplicates<T>(this IEnumerable<T> input)
        {
            return input.GroupBy(x => x).Where(x => x.Count() > 1).Select(x => x.Key);
        }

        /// <summary>
        /// Returns the zero-based position of the first element matching <paramref name="predicate"/>.
        /// </summary>
        /// <param name="input">The sequence to search.</param>
        /// <param name="predicate">The condition to test.</param>
        /// <returns>The index of the first match, or <c>-1</c> if no element matches.</returns>
        public static int IndexOf<T>(this IEnumerable<T> input, Func<T, bool> predicate)
        {
            var i = 0;

            foreach (var item in input)
            {
                if (predicate(item))
                {
                    return i;
                }

                i++;
            }

            return -1;
        }

        /// <summary>
        /// Returns every key (computed by <paramref name="selector"/>) that is produced by more than one element.
        /// </summary>
        /// <param name="input">The sequence to inspect.</param>
        /// <param name="selector">Computes the key that is checked for duplicates.</param>
        /// <returns>The duplicated keys, each reported once.</returns>
        public static IEnumerable<TKey> Duplicates<T, TKey>(this IEnumerable<T> input, Func<T, TKey> selector)
        {
            return input.GroupBy(selector).Where(x => x.Count() > 1).Select(x => x.Key);
        }

        /// <summary>
        /// Adds every element of <paramref name="source"/> to <paramref name="target"/>, in enumeration order.
        /// </summary>
        /// <param name="target">The collection to add to.</param>
        /// <param name="source">The elements to add.</param>
        public static void AddRange<T>(this ICollection<T> target, IEnumerable<T> source)
        {
            foreach (var value in source)
            {
                target.Add(value);
            }
        }

        /// <summary>
        /// Returns the elements of <paramref name="enumerable"/> in random order, materialized into a list.
        /// </summary>
        /// <param name="enumerable">The elements to shuffle.</param>
        /// <returns>A new list with the shuffled elements.</returns>
        public static IEnumerable<T> Shuffle<T>(this IEnumerable<T> enumerable)
        {
            var random = new Random();

            return enumerable.OrderBy(x => random.Next()).ToList();
        }

        /// <summary>
        /// Returns <paramref name="source"/>, or an empty sequence if it is <c>null</c>.
        /// </summary>
        /// <param name="source">The sequence, may be <c>null</c>.</param>
        /// <returns>A non-null sequence.</returns>
        public static IEnumerable<T> OrEmpty<T>(this IEnumerable<T>? source)
        {
            return source ?? Enumerable.Empty<T>();
        }

        /// <summary>
        /// Filters out <c>null</c> elements.
        /// </summary>
        /// <param name="source">The sequence to filter.</param>
        /// <returns>The non-null elements, in their original order.</returns>
        public static IEnumerable<T> NotNull<T>(this IEnumerable<T?> source) where T : class
        {
            return source.Where(x => x != null)!;
        }

        /// <summary>
        /// Appends a single <paramref name="value"/> to the end of <paramref name="source"/>.
        /// </summary>
        /// <param name="source">The sequence to extend.</param>
        /// <param name="value">The value to append.</param>
        /// <returns>A lazily evaluated sequence ending with <paramref name="value"/>.</returns>
        public static IEnumerable<T> Concat<T>(this IEnumerable<T> source, T value)
        {
            return source.Concat(Enumerable.Repeat(value, 1));
        }

        /// <summary>
        /// Converts every element of an array into a new array of the same length.
        /// </summary>
        /// <param name="value">The source array.</param>
        /// <param name="convert">The conversion applied to each element.</param>
        /// <returns>A new array holding the converted elements at the same indices.</returns>
        public static TResult[] Map<T, TResult>(this T[] value, Func<T, TResult> convert)
        {
            var result = new TResult[value.Length];

            for (var i = 0; i < value.Length; i++)
            {
                result[i] = convert(value[i]);
            }

            return result;
        }

        /// <summary>
        /// Computes an order-sensitive hash code of a sequence using the default equality comparer.
        /// </summary>
        /// <param name="collection">The sequence to hash.</param>
        /// <returns>The combined hash code; <c>null</c> elements are skipped.</returns>
        public static int SequentialHashCode<T>(this IEnumerable<T> collection)
        {
            return collection.SequentialHashCode(EqualityComparer<T>.Default);
        }

        /// <summary>
        /// Computes an order-sensitive hash code of a sequence using <paramref name="comparer"/>.
        /// </summary>
        /// <param name="collection">The sequence to hash.</param>
        /// <param name="comparer">Provides the hash code of each element.</param>
        /// <returns>The combined hash code; <c>null</c> elements are skipped.</returns>
        public static int SequentialHashCode<T>(this IEnumerable<T> collection, IEqualityComparer<T> comparer)
        {
            var hashCode = 17;

            // The multiply/add is meant to wrap around; unchecked keeps it from throwing
            // OverflowException in builds that enable arithmetic overflow checking.
            unchecked
            {
                foreach (var item in collection)
                {
                    if (!Equals(item, null))
                    {
                        hashCode = (hashCode * 23) + comparer.GetHashCode(item);
                    }
                }
            }

            return hashCode;
        }

        /// <summary>
        /// Computes an order-insensitive hash code of a sequence using the default equality comparer.
        /// </summary>
        /// <param name="collection">The sequence to hash.</param>
        /// <returns>The combined hash code; <c>null</c> elements are skipped.</returns>
        public static int OrderedHashCode<T>(this IEnumerable<T> collection) where T : notnull
        {
            return collection.OrderedHashCode(EqualityComparer<T>.Default);
        }

        /// <summary>
        /// Computes an order-insensitive hash code of a sequence using <paramref name="comparer"/>.
        /// </summary>
        /// <remarks>
        /// The element hash codes are sorted before they are combined, so permutations of the same elements
        /// produce the same result.
        /// </remarks>
        /// <param name="collection">The sequence to hash.</param>
        /// <param name="comparer">Provides the hash code of each element.</param>
        /// <returns>The combined hash code; <c>null</c> elements are skipped.</returns>
        public static int OrderedHashCode<T>(this IEnumerable<T> collection, IEqualityComparer<T> comparer) where T : notnull
        {
            Guard.NotNull(comparer, nameof(comparer));

            // Use the supplied comparer (not object.GetHashCode) so the hash agrees with the
            // equality the caller asked for, e.g. a case-insensitive string comparer.
            var hashCodes = collection.Where(x => !Equals(x, null)).Select(x => comparer.GetHashCode(x)).OrderBy(x => x).ToArray();

            var hashCode = 17;

            unchecked
            {
                foreach (var code in hashCodes)
                {
                    hashCode = (hashCode * 23) + code;
                }
            }

            return hashCode;
        }

        /// <summary>
        /// Computes an order-insensitive hash code of a dictionary using the default key and value comparers.
        /// </summary>
        /// <param name="dictionary">The dictionary to hash.</param>
        /// <returns>The combined hash code of all entries.</returns>
        public static int DictionaryHashCode<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary) where TKey : notnull
        {
            return DictionaryHashCode(dictionary, EqualityComparer<TKey>.Default, EqualityComparer<TValue>.Default);
        }

        /// <summary>
        /// Computes an order-insensitive hash code of a dictionary using the given key and value comparers.
        /// </summary>
        /// <remarks>
        /// Each entry is hashed on its own (key hash, combined with the value hash when the value is not <c>null</c>)
        /// and the entry hashes are summed. Summation is commutative, so the result does not depend on enumeration
        /// order and does not require <typeparamref name="TKey"/> to be sortable. Dictionaries that are equal under
        /// <c>EqualsDictionary</c> with the same comparers therefore get the same hash code.
        /// </remarks>
        /// <param name="dictionary">The dictionary to hash.</param>
        /// <param name="keyComparer">Provides the hash code of each key.</param>
        /// <param name="valueComparer">Provides the hash code of each non-null value.</param>
        /// <returns>The combined hash code of all entries.</returns>
        public static int DictionaryHashCode<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, IEqualityComparer<TKey> keyComparer, IEqualityComparer<TValue> valueComparer) where TKey : notnull
        {
            var hashCode = 17;

            unchecked
            {
                foreach (var (key, value) in dictionary)
                {
                    var entryHashCode = keyComparer.GetHashCode(key);

                    if (!Equals(value, null))
                    {
                        entryHashCode = (entryHashCode * 23) + valueComparer.GetHashCode(value);
                    }

                    hashCode += entryHashCode;
                }
            }

            return hashCode;
        }

        /// <summary>
        /// Checks whether two dictionaries contain the same entries, using the default key and value comparers.
        /// </summary>
        /// <param name="dictionary">The first dictionary.</param>
        /// <param name="other">The dictionary to compare with; <c>null</c> yields <c>false</c>.</param>
        /// <returns><c>true</c> if both dictionaries have the same count and matching entries.</returns>
        public static bool EqualsDictionary<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, IReadOnlyDictionary<TKey, TValue>? other) where TKey : notnull
        {
            return EqualsDictionary(dictionary, other, EqualityComparer<TKey>.Default, EqualityComparer<TValue>.Default);
        }

        /// <summary>
        /// Checks whether two dictionaries contain the same entries, using the given key and value comparers.
        /// </summary>
        /// <param name="dictionary">The first dictionary.</param>
        /// <param name="other">The dictionary to compare with; <c>null</c> yields <c>false</c>.</param>
        /// <param name="keyComparer">Compares keys.</param>
        /// <param name="valueComparer">Compares values.</param>
        /// <returns><c>true</c> if both dictionaries have the same count and every entry of
        /// <paramref name="dictionary"/> has an equal entry in <paramref name="other"/>.</returns>
        public static bool EqualsDictionary<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, IReadOnlyDictionary<TKey, TValue>? other, IEqualityComparer<TKey> keyComparer, IEqualityComparer<TValue> valueComparer) where TKey : notnull
        {
            if (other == null)
            {
                return false;
            }

            if (dictionary.Count != other.Count)
            {
                return false;
            }

            var comparer = new KeyValuePairComparer<TKey, TValue>(keyComparer, valueComparer);

            return !dictionary.Except(other, comparer).Any();
        }

        /// <summary>
        /// Copies a read-only dictionary into a new mutable <see cref="Dictionary{TKey, TValue}"/>.
        /// </summary>
        /// <param name="dictionary">The dictionary to copy.</param>
        /// <returns>A new dictionary with the same entries.</returns>
        public static Dictionary<TKey, TValue> ToDictionary<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary) where TKey : notnull
        {
            return dictionary.ToDictionary(x => x.Key, x => x.Value);
        }

        /// <summary>
        /// Returns the value stored for <paramref name="key"/>, or <c>default</c> if there is none. The dictionary is not modified.
        /// </summary>
        /// <param name="dictionary">The dictionary to read.</param>
        /// <param name="key">The key to look up.</param>
        /// <returns>The stored value or <c>default</c>.</returns>
        public static TValue GetOrDefault<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull
        {
            return dictionary.GetOrCreate(key, _ => default!);
        }

        /// <summary>
        /// Returns the value stored for <paramref name="key"/>; if there is none, stores <c>default</c> under that key and returns it.
        /// </summary>
        /// <param name="dictionary">The dictionary to read and possibly modify.</param>
        /// <param name="key">The key to look up.</param>
        /// <returns>The stored or newly added value.</returns>
        public static TValue GetOrAddDefault<TKey, TValue>(this IDictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull
        {
            return dictionary.GetOrAdd(key, _ => default!);
        }

        /// <summary>
        /// Returns the value stored for <paramref name="key"/>, or a new instance if there is none. The dictionary is not modified.
        /// </summary>
        /// <param name="dictionary">The dictionary to read.</param>
        /// <param name="key">The key to look up.</param>
        /// <returns>The stored value or a freshly constructed one.</returns>
        public static TValue GetOrNew<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull where TValue : class, new()
        {
            return dictionary.GetOrCreate(key, _ => new TValue());
        }

        /// <summary>
        /// Returns the value stored for <paramref name="key"/>; if there is none, stores a new instance under that key and returns it.
        /// </summary>
        /// <param name="dictionary">The dictionary to read and possibly modify.</param>
        /// <param name="key">The key to look up.</param>
        /// <returns>The stored or newly added value.</returns>
        public static TValue GetOrAddNew<TKey, TValue>(this IDictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull where TValue : class, new()
        {
            return dictionary.GetOrAdd(key, _ => new TValue());
        }

        /// <summary>
        /// Returns the value stored for <paramref name="key"/>, or the result of <paramref name="creator"/> if there is none.
        /// The created value is not added to the dictionary.
        /// </summary>
        /// <param name="dictionary">The dictionary to read.</param>
        /// <param name="key">The key to look up.</param>
        /// <param name="creator">Creates the fallback value; only called when the key is missing.</param>
        /// <returns>The stored or created value.</returns>
        public static TValue GetOrCreate<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, Func<TKey, TValue> creator) where TKey : notnull
        {
            if (!dictionary.TryGetValue(key, out var result))
            {
                result = creator(key);
            }

            return result;
        }

        /// <summary>
        /// Returns the value stored for <paramref name="key"/>; if there is none, stores <paramref name="fallback"/> under that key and returns it.
        /// </summary>
        /// <param name="dictionary">The dictionary to read and possibly modify.</param>
        /// <param name="key">The key to look up.</param>
        /// <param name="fallback">The value to add when the key is missing.</param>
        /// <returns>The stored or newly added value.</returns>
        public static TValue GetOrAdd<TKey, TValue>(this IDictionary<TKey, TValue> dictionary, TKey key, TValue fallback) where TKey : notnull
        {
            if (!dictionary.TryGetValue(key, out var result))
            {
                result = fallback;

                dictionary.Add(key, result);
            }

            return result;
        }

        /// <summary>
        /// Returns the value stored for <paramref name="key"/>; if there is none, stores the result of <paramref name="creator"/> and returns it.
        /// </summary>
        /// <param name="dictionary">The dictionary to read and possibly modify.</param>
        /// <param name="key">The key to look up.</param>
        /// <param name="creator">Creates the value to add; only called when the key is missing.</param>
        /// <returns>The stored or newly added value.</returns>
        public static TValue GetOrAdd<TKey, TValue>(this IDictionary<TKey, TValue> dictionary, TKey key, Func<TKey, TValue> creator) where TKey : notnull
        {
            if (!dictionary.TryGetValue(key, out var result))
            {
                result = creator(key);

                dictionary.Add(key, result);
            }

            return result;
        }

        /// <summary>
        /// Returns the value stored for <paramref name="key"/>; if there is none, stores the result of
        /// <paramref name="creator"/> (which also receives <paramref name="context"/>) and returns it.
        /// </summary>
        /// <param name="dictionary">The dictionary to read and possibly modify.</param>
        /// <param name="key">The key to look up.</param>
        /// <param name="context">Extra state passed to <paramref name="creator"/>, which avoids a capturing lambda.</param>
        /// <param name="creator">Creates the value to add; only called when the key is missing.</param>
        /// <returns>The stored or newly added value.</returns>
        public static TValue GetOrAdd<TKey, TContext, TValue>(this IDictionary<TKey, TValue> dictionary, TKey key, TContext context, Func<TKey, TContext, TValue> creator) where TKey : notnull
        {
            if (!dictionary.TryGetValue(key, out var result))
            {
                result = creator(key, context);

                dictionary.Add(key, result);
            }

            return result;
        }

        /// <summary>
        /// Invokes <paramref name="action"/> for every element together with its zero-based index.
        /// </summary>
        /// <param name="collection">The sequence to iterate.</param>
        /// <param name="action">Receives each element and its index.</param>
        public static void Foreach<T>(this IEnumerable<T> collection, Action<T, int> action)
        {
            var index = 0;

            foreach (var item in collection)
            {
                action(item, index);

                index++;
            }
        }

        /// <summary>
        /// Compares key/value pairs by comparing keys and values with separate comparers.
        /// </summary>
        public sealed class KeyValuePairComparer<TKey, TValue> : IEqualityComparer<KeyValuePair<TKey, TValue>> where TKey : notnull
        {
            private readonly IEqualityComparer<TKey> keyComparer;
            private readonly IEqualityComparer<TValue> valueComparer;

            /// <summary>
            /// Creates a comparer from a key comparer and a value comparer.
            /// </summary>
            /// <param name="keyComparer">Compares and hashes keys.</param>
            /// <param name="valueComparer">Compares and hashes values.</param>
            public KeyValuePairComparer(IEqualityComparer<TKey> keyComparer, IEqualityComparer<TValue> valueComparer)
            {
                this.keyComparer = keyComparer;
                this.valueComparer = valueComparer;
            }

            /// <summary>
            /// Two pairs are equal when both their keys and their values are equal.
            /// </summary>
            public bool Equals(KeyValuePair<TKey, TValue> x, KeyValuePair<TKey, TValue> y)
            {
                return keyComparer.Equals(x.Key, y.Key) && valueComparer.Equals(x.Value, y.Value);
            }

            /// <summary>
            /// Combines the key hash with the value hash (0 for a <c>null</c> value).
            /// </summary>
            public int GetHashCode(KeyValuePair<TKey, TValue> obj)
            {
                return keyComparer.GetHashCode(obj.Key) ^ ValueHashCode(obj);
            }

            // Null values hash to 0 so the value comparer is never called with null.
            private int ValueHashCode(KeyValuePair<TKey, TValue> obj)
            {
                if (Equals(obj.Value, null))
                {
                    return 0;
                }

                return valueComparer.GetHashCode(obj.Value);
            }
        }
    }
}