diff --git a/Avalonia.Desktop.slnf b/Avalonia.Desktop.slnf index 35eeab33ac..b32d6e3d70 100644 --- a/Avalonia.Desktop.slnf +++ b/Avalonia.Desktop.slnf @@ -30,15 +30,16 @@ "src\\Avalonia.MicroCom\\Avalonia.MicroCom.csproj", "src\\Avalonia.Native\\Avalonia.Native.csproj", "src\\Avalonia.OpenGL\\Avalonia.OpenGL.csproj", - "src\\Avalonia.Vulkan\\Avalonia.Vulkan.csproj", "src\\Avalonia.ReactiveUI\\Avalonia.ReactiveUI.csproj", "src\\Avalonia.Remote.Protocol\\Avalonia.Remote.Protocol.csproj", "src\\Avalonia.Themes.Fluent\\Avalonia.Themes.Fluent.csproj", "src\\Avalonia.Themes.Simple\\Avalonia.Themes.Simple.csproj", + "src\\Avalonia.Vulkan\\Avalonia.Vulkan.csproj", "src\\Avalonia.X11\\Avalonia.X11.csproj", "src\\Headless\\Avalonia.Headless.Vnc\\Avalonia.Headless.Vnc.csproj", "src\\Headless\\Avalonia.Headless\\Avalonia.Headless.csproj", "src\\Linux\\Avalonia.LinuxFramebuffer\\Avalonia.LinuxFramebuffer.csproj", + "src\\Linux\\Tmds.DBus.Protocol\\Tmds.DBus.Protocol.csproj", "src\\Linux\\Tmds.DBus.SourceGenerator\\Tmds.DBus.SourceGenerator.csproj", "src\\Markup\\Avalonia.Markup.Xaml.Loader\\Avalonia.Markup.Xaml.Loader.csproj", "src\\Markup\\Avalonia.Markup.Xaml\\Avalonia.Markup.Xaml.csproj", diff --git a/Avalonia.sln b/Avalonia.sln index 395a6fe559..5e98d26cc6 100644 --- a/Avalonia.sln +++ b/Avalonia.sln @@ -304,6 +304,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Avalonia.RenderTests.WpfCom EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tmds.DBus.SourceGenerator", "src\Linux\Tmds.DBus.SourceGenerator\Tmds.DBus.SourceGenerator.csproj", "{12AE6CBC-C0A1-4BEF-AED6-81E566AAE7EB}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tmds.DBus.Protocol", "src\Linux\Tmds.DBus.Protocol\Tmds.DBus.Protocol.csproj", "{9A7672D7-77B9-4D50-AF22-6C5049C4712B}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -708,6 +710,10 @@ Global {12AE6CBC-C0A1-4BEF-AED6-81E566AAE7EB}.Debug|Any CPU.Build.0 = Debug|Any CPU {12AE6CBC-C0A1-4BEF-AED6-81E566AAE7EB}.Release|Any CPU.ActiveCfg = Release|Any CPU {12AE6CBC-C0A1-4BEF-AED6-81E566AAE7EB}.Release|Any CPU.Build.0 = Release|Any CPU + {9A7672D7-77B9-4D50-AF22-6C5049C4712B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9A7672D7-77B9-4D50-AF22-6C5049C4712B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9A7672D7-77B9-4D50-AF22-6C5049C4712B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9A7672D7-77B9-4D50-AF22-6C5049C4712B}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -796,6 +802,7 @@ Global {DA5F1FF9-4259-4C54-B443-85CFA226EE6A} = {9CCA131B-DE95-4D44-8788-C3CAE28574CD} {9AE1B827-21AC-4063-AB22-C8804B7F931E} = {C5A00AC3-B34C-4564-9BDD-2DA473EF4D8B} {12AE6CBC-C0A1-4BEF-AED6-81E566AAE7EB} = {86C53C40-57AA-45B8-AD42-FAE0EFDF0F2B} + {9A7672D7-77B9-4D50-AF22-6C5049C4712B} = {86C53C40-57AA-45B8-AD42-FAE0EFDF0F2B} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {87366D66-1391-4D90-8999-95A620AD786A} diff --git a/src/Avalonia.FreeDesktop/Avalonia.FreeDesktop.csproj b/src/Avalonia.FreeDesktop/Avalonia.FreeDesktop.csproj index 536d9238a2..e45b11fec1 100644 --- a/src/Avalonia.FreeDesktop/Avalonia.FreeDesktop.csproj +++ b/src/Avalonia.FreeDesktop/Avalonia.FreeDesktop.csproj @@ -13,12 +13,12 @@ - + @@ -51,8 +51,6 @@ - - diff --git a/src/Linux/Tmds.DBus.Protocol/ActionException.cs b/src/Linux/Tmds.DBus.Protocol/ActionException.cs new file mode 100644 index 0000000000..48b9225306 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/ActionException.cs @@ -0,0 +1,17 @@ +namespace Tmds.DBus.Protocol; + +public static class ActionException +{ + // Exception used when the IDisposable returned by AddMatchAsync gets disposed. + public static bool IsObserverDisposed(Exception exception) + => object.ReferenceEquals(exception, DBusConnection.ObserverDisposedException); + + // Exception used when the Connection gets disposed. + public static bool IsConnectionDisposed(Exception exception) + // note: Connection.DisposedException is only ever used as an InnerException of DisconnectedException, + // so we directly check for that. + => object.ReferenceEquals(exception?.InnerException, Connection.DisposedException); + + public static bool IsDisposed(Exception exception) + => IsObserverDisposed(exception) || IsConnectionDisposed(exception); +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Address.cs b/src/Linux/Tmds.DBus.Protocol/Address.cs new file mode 100644 index 0000000000..c43b7ca38e --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Address.cs @@ -0,0 +1,220 @@ +using System.IO.MemoryMappedFiles; + +namespace Tmds.DBus.Protocol; + +public static class Address +{ + private static bool _systemAddressResolved = false; + private static string? _systemAddress = null; + private static bool _sessionAddressResolved = false; + private static string? _sessionAddress = null; + + public static string? System + { + get + { + if (_systemAddressResolved) + { + return _systemAddress; + } + + _systemAddress = Environment.GetEnvironmentVariable("DBUS_SYSTEM_BUS_ADDRESS"); + + if (string.IsNullOrEmpty(_systemAddress) && !PlatformDetection.IsWindows()) + { + _systemAddress = "unix:path=/var/run/dbus/system_bus_socket"; + } + + _systemAddressResolved = true; + return _systemAddress; + } + } + + public static string? Session + { + get + { + if (_sessionAddressResolved) + { + return _sessionAddress; + } + + _sessionAddress = Environment.GetEnvironmentVariable("DBUS_SESSION_BUS_ADDRESS"); + + if (string.IsNullOrEmpty(_sessionAddress)) + { + if (PlatformDetection.IsWindows()) + { + _sessionAddress = GetSessionBusAddressFromSharedMemory(); + } + else + { + _sessionAddress = GetSessionBusAddressFromX11(); + } + } + + _sessionAddressResolved = true; + return _sessionAddress; + } + } + + private static string? GetSessionBusAddressFromX11() + { + if (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("DISPLAY"))) + { + var display = XOpenDisplay(null); + if (display == IntPtr.Zero) + { + return null; + } + string username; + unsafe + { + const int BufLen = 1024; + byte* stackBuf = stackalloc byte[BufLen]; + Passwd passwd; + IntPtr result; + getpwuid_r(getuid(), out passwd, stackBuf, BufLen, out result); + if (result != IntPtr.Zero) + { + username = Marshal.PtrToStringAnsi(passwd.Name)!; + } + else + { + return null; + } + } + var machineId = DBusEnvironment.MachineId.Replace("-", string.Empty); + var selectionName = $"_DBUS_SESSION_BUS_SELECTION_{username}_{machineId}"; + var selectionAtom = XInternAtom(display, selectionName, false); + if (selectionAtom == IntPtr.Zero) + { + return null; + } + var owner = XGetSelectionOwner(display, selectionAtom); + if (owner == IntPtr.Zero) + { + return null; + } + var addressAtom = XInternAtom(display, "_DBUS_SESSION_BUS_ADDRESS", false); + if (addressAtom == IntPtr.Zero) + { + return null; + } + + IntPtr actualReturnType; + IntPtr actualReturnFormat; + IntPtr nrItemsReturned; + IntPtr bytesAfterReturn; + IntPtr propReturn; + + int rv = XGetWindowProperty(display, owner, addressAtom, 0, 1024, false, (IntPtr)31 /* XA_STRING */, + out actualReturnType, out actualReturnFormat, out nrItemsReturned, out bytesAfterReturn, out propReturn); + string? address = rv == 0 ? Marshal.PtrToStringAnsi(propReturn) : null; + if (propReturn != IntPtr.Zero) + { + XFree(propReturn); + } + + XCloseDisplay(display); + + return address; + } + else + { + return null; + } + } + + private static string? GetSessionBusAddressFromSharedMemory() + { + string? result = ReadSharedMemoryString("DBusDaemonAddressInfo", 255); + if (string.IsNullOrEmpty(result)) + { + result = ReadSharedMemoryString("DBusDaemonAddressInfoDebug", 255); + } + return result; + } + + private static string? ReadSharedMemoryString(string id, long maxlen = -1) + { + if (!PlatformDetection.IsWindows()) + { + return null; + } + MemoryMappedFile? shmem; + try + { + shmem = MemoryMappedFile.OpenExisting(id); + } + catch + { + shmem = null; + } + if (shmem == null) + { + return null; + } + + MemoryMappedViewStream s = shmem.CreateViewStream(); + long len = s.Length; + if (maxlen >= 0 && len > maxlen) + { + len = maxlen; + } + if (len == 0) + { + return string.Empty; + } + if (len > int.MaxValue) + { + len = int.MaxValue; + } + byte[] bytes = new byte[len]; + int count = s.Read(bytes, 0, (int)len); + if (count <= 0) + { + return string.Empty; + } + + count = 0; + while (count < len && bytes[count] != 0) + { + count++; + } + + return Encoding.UTF8.GetString(bytes, 0, count); + } + + struct Passwd + { + public IntPtr Name; + public IntPtr Password; + public uint UserID; + public uint GroupID; + public IntPtr UserInfo; + public IntPtr HomeDir; + public IntPtr Shell; + } + + [DllImport("libc")] + private static extern unsafe int getpwuid_r(uint uid, out Passwd pwd, byte* buf, int bufLen, out IntPtr result); + [DllImport("libc")] + private static extern uint getuid(); + + [DllImport("libX11")] + private static extern IntPtr XOpenDisplay(string? name); + [DllImport("libX11")] + private static extern int XCloseDisplay(IntPtr display); + [DllImport("libX11")] + private static extern IntPtr XInternAtom(IntPtr display, string atom_name, bool only_if_exists); + [DllImport("libX11")] + private static extern int XGetWindowProperty(IntPtr display, IntPtr w, IntPtr property, + int long_offset, int long_length, bool delete, IntPtr req_type, + out IntPtr actual_type_return, out IntPtr actual_format_return, + out IntPtr nitems_return, out IntPtr bytes_after_return, out IntPtr prop_return); + [DllImport("libX11")] + private static extern int XFree(IntPtr data); + [DllImport("libX11")] + private static extern IntPtr XGetSelectionOwner(IntPtr display, IntPtr Atom); +} diff --git a/src/Linux/Tmds.DBus.Protocol/AddressReader.cs b/src/Linux/Tmds.DBus.Protocol/AddressReader.cs new file mode 100644 index 0000000000..7ccf3fa168 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/AddressReader.cs @@ -0,0 +1,191 @@ +namespace Tmds.DBus.Protocol; + +static class AddressParser +{ + public struct AddressEntry + { + internal string String { get; } + internal int Offset { get; } + internal int Count { get; } + + internal AddressEntry(string s, int offset, int count) => + (String, Offset, Count) = (s, offset, count); + + internal ReadOnlySpan AsSpan() => String.AsSpan(Offset, Count); + + public override string ToString() => AsSpan().AsString(); + } + + public static bool TryGetNextEntry(string addresses, ref AddressEntry address) + { + int offset = address.String is null ? 0 : address.Offset + address.Count + 1; + if (offset >= addresses.Length - 1) + { + return false; + } + ReadOnlySpan span = addresses.AsSpan().Slice(offset); + int length = span.IndexOf(';'); + if (length == -1) + { + length = span.Length; + } + address = new AddressEntry(addresses, offset, length); + return true; + } + + public static bool IsType(AddressEntry address, string type) + { + ReadOnlySpan span = address.AsSpan(); + return span.Length > type.Length && span[type.Length] == ':' && span.StartsWith(type.AsSpan()); + } + + public static void ParseTcpProperties(AddressEntry address, out string host, out int? port, out Guid guid) + { + host = null!; + port = null; + guid = default; + ReadOnlySpan properties = GetProperties(address); + while (TryParseProperty(ref properties, out ReadOnlySpan key, out ReadOnlySpan value)) + { + if (key.SequenceEqual("host".AsSpan())) + { + host = Unescape(value); + } + else if (key.SequenceEqual("port".AsSpan())) + { + port = int.Parse(Unescape(value)); + } + else if (key.SequenceEqual("guid".AsSpan())) + { + guid = Guid.ParseExact(Unescape(value), "N"); + } + } + if (host is null) + { + host = "localhost"; + } + } + + public static void ParseUnixProperties(AddressEntry address, out string path, out Guid guid) + { + path = null!; + bool isAbstract = false; + guid = default; + ReadOnlySpan properties = GetProperties(address); + while (TryParseProperty(ref properties, out ReadOnlySpan key, out ReadOnlySpan value)) + { + if (key.SequenceEqual("path".AsSpan())) + { + path = Unescape(value); + } + else if (key.SequenceEqual("abstract".AsSpan())) + { + isAbstract = true; + path = Unescape(value); + } + else if (key.SequenceEqual("guid".AsSpan())) + { + guid = Guid.ParseExact(Unescape(value), "N"); + } + } + if (string.IsNullOrEmpty(path)) + { + throw new ArgumentException("path"); + } + if (isAbstract) + { + path = (char)'\0' + path; + } + } + + private static ReadOnlySpan GetProperties(AddressEntry address) + { + ReadOnlySpan span = address.AsSpan(); + int colonPos = span.IndexOf(':'); + if (colonPos == -1) + { + throw new FormatException("No colon found."); + } + return span.Slice(colonPos + 1); + } + + public static bool TryParseProperty(ref ReadOnlySpan properties, out ReadOnlySpan key, out ReadOnlySpan value) + { + if (properties.Length == 0) + { + key = default; + value = default; + return false; + } + int end = properties.IndexOf(','); + ReadOnlySpan property; + if (end == -1) + { + property = properties; + properties = default; + } + else + { + property = properties.Slice(0, end); + properties = properties.Slice(end + 1); + } + int equalPos = property.IndexOf('='); + if (equalPos == -1) + { + throw new FormatException("No equals sign found."); + } + key = property.Slice(0, equalPos); + value = property.Slice(equalPos + 1); + return true; + } + + private static string Unescape(ReadOnlySpan value) + { + if (!value.Contains("%".AsSpan(), StringComparison.Ordinal)) + { + return value.AsString(); + } + Span unescaped = stackalloc char[Constants.StackAllocCharThreshold]; + int pos = 0; + for (int i = 0; i < value.Length;) + { + char c = value[i++]; + if (c != '%') + { + unescaped[pos++] = c; + } + else if (i + 2 < value.Length) + { + int a = FromHexChar(value[i++]); + int b = FromHexChar(value[i++]); + if (a == -1 || b == -1) + { + throw new FormatException("Invalid hex char."); + } + unescaped[pos++] = (char)((a << 4) + b); + } + else + { + throw new FormatException("Escape sequence is too short."); + } + } + return unescaped.Slice(0, pos).AsString(); + + static int FromHexChar(char c) + { + if (c >= '0' && c <= '9') + { + return c - '0'; + } + if (c >= 'A' && c <= 'F') + { + return c - 'A' + 10; + } + if (c >= 'a' && c <= 'f') + { + return c - 'a' + 10; + } + return -1; + } + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Array.cs b/src/Linux/Tmds.DBus.Protocol/Array.cs new file mode 100644 index 0000000000..42ab3d5fcf --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Array.cs @@ -0,0 +1,88 @@ +using System.Collections; + +namespace Tmds.DBus.Protocol; + +// Using obsolete generic write members +#pragma warning disable CS0618 + +public sealed class Array : IDBusWritable, IList + where T : notnull +{ + private readonly List _values; + + public Array() : + this(new List()) + { } + + public Array(int capacity) : + this(new List(capacity)) + { } + + public Array(IEnumerable collection) : + this(new List(collection)) + { } + + private Array(List values) + { + TypeModel.EnsureSupportedVariantType(); + _values = values; + } + + public void Add(T item) + => _values.Add(item); + + public void Clear() + => _values.Clear(); + + public int Count => _values.Count; + + bool ICollection.IsReadOnly + => false; + + public T this[int index] + { + get => _values[index]; + set => _values[index] = value; + } + + IEnumerator IEnumerable.GetEnumerator() + => _values.GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() + => _values.GetEnumerator(); + + public int IndexOf(T item) + => _values.IndexOf(item); + + public void Insert(int index, T item) + => _values.Insert(index, item); + + public void RemoveAt(int index) + => _values.RemoveAt(index); + + public bool Contains(T item) + => _values.Contains(item); + + public void CopyTo(T[] array, int arrayIndex) + => _values.CopyTo(array, arrayIndex); + + public bool Remove(T item) + => _values.Remove(item); + + public Variant AsVariant() + => Variant.FromArray(this); + + public static implicit operator Variant(Array value) + => value.AsVariant(); + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // this is a supported variant type. + void IDBusWritable.WriteTo(ref MessageWriter writer) + { +#if NET5_0_OR_GREATER + Span span = CollectionsMarshal.AsSpan(_values); + writer.WriteArray(span); +#else + writer.WriteArray(_values); +#endif + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/ClientConnectionOptions.cs b/src/Linux/Tmds.DBus.Protocol/ClientConnectionOptions.cs new file mode 100644 index 0000000000..129947db9d --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/ClientConnectionOptions.cs @@ -0,0 +1,36 @@ +namespace Tmds.DBus.Protocol; + +public class ClientConnectionOptions : ConnectionOptions +{ + private string _address; + + public ClientConnectionOptions(string address) + { + if (address == null) + throw new ArgumentNullException(nameof(address)); + _address = address; + } + + protected ClientConnectionOptions() + { + _address = string.Empty; + } + + public bool AutoConnect { get; set; } + + internal bool IsShared { get; set; } + + protected internal virtual ValueTask SetupAsync(CancellationToken cancellationToken) + { + return new ValueTask( + new ClientSetupResult(_address) + { + SupportsFdPassing = true, + UserId = DBusEnvironment.UserId, + MachineId = DBusEnvironment.MachineId + }); + } + + protected internal virtual void Teardown(object? token) + { } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/ClientSetupResult.cs b/src/Linux/Tmds.DBus.Protocol/ClientSetupResult.cs new file mode 100644 index 0000000000..d375d7d66e --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/ClientSetupResult.cs @@ -0,0 +1,19 @@ +namespace Tmds.DBus.Protocol; + +public class ClientSetupResult +{ + public ClientSetupResult(string address) + { + ConnectionAddress = address ?? throw new ArgumentNullException(nameof(address)); + } + + public string ConnectionAddress { get; } + + public object? TeardownToken { get; set; } + + public string? UserId { get; set; } + + public string? MachineId { get; set; } + + public bool SupportsFdPassing { get; set; } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/CloseSafeHandle.cs b/src/Linux/Tmds.DBus.Protocol/CloseSafeHandle.cs new file mode 100644 index 0000000000..1e92132ada --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/CloseSafeHandle.cs @@ -0,0 +1,17 @@ +namespace Tmds.DBus.Protocol; + +sealed class CloseSafeHandle : SafeHandle +{ + public CloseSafeHandle() : + base(new IntPtr(-1), ownsHandle: true) + { } + + public override bool IsInvalid + => handle == new IntPtr(-1); + + protected override bool ReleaseHandle() + => close(handle.ToInt32()) == 0; + + [DllImport("libc", SetLastError = true)] + internal static extern int close(int fd); +} diff --git a/src/Linux/Tmds.DBus.Protocol/ConnectException.cs b/src/Linux/Tmds.DBus.Protocol/ConnectException.cs new file mode 100644 index 0000000000..f035990402 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/ConnectException.cs @@ -0,0 +1,10 @@ +namespace Tmds.DBus.Protocol; + +public class ConnectException : Exception +{ + public ConnectException(string message) : base(message) + { } + + public ConnectException(string message, Exception innerException) : base(message, innerException) + { } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Connection.DBus.cs b/src/Linux/Tmds.DBus.Protocol/Connection.DBus.cs new file mode 100644 index 0000000000..4ceceb5c70 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Connection.DBus.cs @@ -0,0 +1,112 @@ +using System.Threading.Channels; + +namespace Tmds.DBus.Protocol; + +public partial class Connection +{ + public const string DBusObjectPath = "/org/freedesktop/DBus"; + public const string DBusServiceName = "org.freedesktop.DBus"; + public const string DBusInterface = "org.freedesktop.DBus"; + + public Task ListServicesAsync() + { + return CallMethodAsync(CreateMessage(), (Message m, object? s) => m.GetBodyReader().ReadArrayOfString()); + MessageBuffer CreateMessage() + { + using var writer = GetMessageWriter(); + writer.WriteMethodCallHeader( + destination: DBusServiceName, + path: DBusObjectPath, + @interface: DBusInterface, + member: "ListNames"); + return writer.CreateMessage(); + } + } + + public Task ListActivatableServicesAsync() + { + return CallMethodAsync(CreateMessage(), (Message m, object? s) => m.GetBodyReader().ReadArrayOfString()); + MessageBuffer CreateMessage() + { + using var writer = GetMessageWriter(); + writer.WriteMethodCallHeader( + destination: DBusServiceName, + path: DBusObjectPath, + @interface: DBusInterface, + member: "ListActivatableNames"); + return writer.CreateMessage(); + } + } + + public async Task BecomeMonitorAsync(Action handler, IEnumerable? rules = null) + { + if (_connectionOptions.IsShared) + { + throw new InvalidOperationException("Cannot become monitor on a shared connection."); + } + + DBusConnection connection = await ConnectCoreAsync().ConfigureAwait(false); + await connection.BecomeMonitorAsync(handler, rules).ConfigureAwait(false); + } + + public static async IAsyncEnumerable MonitorBusAsync(string address, IEnumerable? rules = null, [EnumeratorCancellation]CancellationToken ct = default) + { + ct.ThrowIfCancellationRequested(); + + var channel = Channel.CreateUnbounded( + new UnboundedChannelOptions() + { + AllowSynchronousContinuations = true, + SingleReader = true, + SingleWriter = true, + } + ); + + using var connection = new Connection(address); + using CancellationTokenRegistration ctr = +#if NETCOREAPP3_1_OR_GREATER + ct.UnsafeRegister(c => ((Connection)c!).Dispose(), connection); +#else + ct.Register(c => ((Connection)c!).Dispose(), connection); +#endif + try + { + await connection.ConnectAsync().ConfigureAwait(false); + + await connection.BecomeMonitorAsync( + (Exception? ex, DisposableMessage message) => + { + if (ex is not null) + { + if (ct.IsCancellationRequested) + { + ex = new OperationCanceledException(ct); + } + channel.Writer.TryComplete(ex); + return; + } + + if (!channel.Writer.TryWrite(message)) + { + message.Dispose(); + } + }, + rules + ).ConfigureAwait(false); + } + catch + { + ct.ThrowIfCancellationRequested(); + + throw; + } + + while (await channel.Reader.WaitToReadAsync().ConfigureAwait(false)) + { + if (channel.Reader.TryRead(out DisposableMessage msg)) + { + yield return msg; + } + } + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Connection.cs b/src/Linux/Tmds.DBus.Protocol/Connection.cs new file mode 100644 index 0000000000..973c40738e --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Connection.cs @@ -0,0 +1,335 @@ +namespace Tmds.DBus.Protocol; + +public delegate T MessageValueReader(Message message, object? state); + +public partial class Connection : IDisposable +{ + internal static readonly Exception DisposedException = new ObjectDisposedException(typeof(Connection).FullName); + private static Connection? s_systemConnection; + private static Connection? s_sessionConnection; + + public static Connection System => s_systemConnection ?? CreateConnection(ref s_systemConnection, Address.System); + public static Connection Session => s_sessionConnection ?? CreateConnection(ref s_sessionConnection, Address.Session); + + public string? UniqueName => GetConnection().UniqueName; + + enum ConnectionState + { + Created, + Connecting, + Connected, + Disconnected + } + + private readonly object _gate = new object(); + private readonly ClientConnectionOptions _connectionOptions; + private DBusConnection? _connection; + private CancellationTokenSource? _connectCts; + private Task? _connectingTask; + private ClientSetupResult? _setupResult; + private ConnectionState _state; + private bool _disposed; + private int _nextSerial; + + public Connection(string address) : + this(new ClientConnectionOptions(address)) + { } + + public Connection(ConnectionOptions connectionOptions) + { + if (connectionOptions == null) + throw new ArgumentNullException(nameof(connectionOptions)); + + _connectionOptions = (ClientConnectionOptions)connectionOptions; + } + + // For tests. + internal void Connect(IMessageStream stream) + { + _connection = new DBusConnection(this, DBusEnvironment.MachineId); + _connection.Connect(stream); + _state = ConnectionState.Connected; + } + + public async ValueTask ConnectAsync() + { + await ConnectCoreAsync(explicitConnect: true).ConfigureAwait(false); + } + + private ValueTask ConnectCoreAsync(bool explicitConnect = false) + { + lock (_gate) + { + ThrowHelper.ThrowIfDisposed(_disposed, this); + + ConnectionState state = _state; + + if (state == ConnectionState.Connected) + { + return new ValueTask(_connection!); + } + + if (!_connectionOptions.AutoConnect) + { + DBusConnection? connection = _connection; + if (!explicitConnect && _state == ConnectionState.Disconnected && connection is not null) + { + throw new DisconnectedException(connection.DisconnectReason); + } + + if (!explicitConnect || _state != ConnectionState.Created) + { + throw new InvalidOperationException("Can only connect once using an explicit call."); + } + } + + if (state == ConnectionState.Connecting) + { + return new ValueTask(_connectingTask!); + } + + _state = ConnectionState.Connecting; + _connectingTask = DoConnectAsync(); + + return new ValueTask(_connectingTask); + } + } + + private async Task DoConnectAsync() + { + Debug.Assert(Monitor.IsEntered(_gate)); + + DBusConnection? connection = null; + try + { + _connectCts = new(); + _setupResult = await _connectionOptions.SetupAsync(_connectCts.Token).ConfigureAwait(false); + connection = _connection = new DBusConnection(this, _setupResult.MachineId ?? DBusEnvironment.MachineId); + + await connection.ConnectAsync(_setupResult.ConnectionAddress, _setupResult.UserId, _setupResult.SupportsFdPassing, _connectCts.Token).ConfigureAwait(false); + + lock (_gate) + { + ThrowHelper.ThrowIfDisposed(_disposed, this); + + if (_connection == connection && _state == ConnectionState.Connecting) + { + _connectingTask = null; + _connectCts = null; + _state = ConnectionState.Connected; + } + else + { + throw new DisconnectedException(connection.DisconnectReason); + } + } + + return connection; + } + catch (Exception exception) + { + Disconnect(exception, connection); + + // Prefer throwing ObjectDisposedException. + ThrowHelper.ThrowIfDisposed(_disposed, this); + + // Throw DisconnectedException or ConnectException. + if (exception is DisconnectedException || exception is ConnectException) + { + throw; + } + else + { + throw new ConnectException(exception.Message, exception); + } + } + } + + public void Dispose() + { + lock (_gate) + { + if (_disposed) + { + return; + } + _disposed = true; + } + + Disconnect(DisposedException); + } + + internal void Disconnect(Exception disconnectReason, DBusConnection? trigger = null) + { + DBusConnection? connection; + ClientSetupResult? setupResult; + CancellationTokenSource? connectCts; + lock (_gate) + { + if (trigger is not null && trigger != _connection) + { + // Already disconnected from this stream. + return; + } + + ConnectionState state = _state; + if (state == ConnectionState.Disconnected) + { + return; + } + + _state = ConnectionState.Disconnected; + + connection = _connection; + setupResult = _setupResult; + connectCts = _connectCts; + + _connectingTask = null; + _setupResult = null; + _connectCts = null; + + if (connection is not null) + { + connection.DisconnectReason = disconnectReason; + } + } + + connectCts?.Cancel(); + connection?.Dispose(); + if (setupResult != null) + { + _connectionOptions.Teardown(setupResult.TeardownToken); + } + } + + public async Task CallMethodAsync(MessageBuffer message) + { + DBusConnection connection; + try + { + connection = await ConnectCoreAsync().ConfigureAwait(false); + } + catch + { + message.ReturnToPool(); + throw; + } + await connection.CallMethodAsync(message).ConfigureAwait(false); + } + + public async Task CallMethodAsync(MessageBuffer message, MessageValueReader reader, object? readerState = null) + { + DBusConnection connection; + try + { + connection = await ConnectCoreAsync().ConfigureAwait(false); + } + catch + { + message.ReturnToPool(); + throw; + } + return await connection.CallMethodAsync(message, reader, readerState).ConfigureAwait(false); + } + + [Obsolete("Use an overload that accepts ObserverFlags.")] + public ValueTask AddMatchAsync(MatchRule rule, MessageValueReader reader, Action handler, object? readerState = null, object? handlerState = null, bool emitOnCapturedContext = true, bool subscribe = true) + => AddMatchAsync(rule, reader, handler, readerState, handlerState, emitOnCapturedContext, ObserverFlags.EmitOnDispose | (!subscribe ? ObserverFlags.NoSubscribe : default)); + + public ValueTask AddMatchAsync(MatchRule rule, MessageValueReader reader, Action handler, ObserverFlags flags, object? readerState = null, object? handlerState = null, bool emitOnCapturedContext = true) + => AddMatchAsync(rule, reader, handler, readerState, handlerState, emitOnCapturedContext, flags); + + public ValueTask AddMatchAsync(MatchRule rule, MessageValueReader reader, Action handler, object? readerState, object? handlerState, bool emitOnCapturedContext, ObserverFlags flags) + => AddMatchAsync(rule, reader, handler, readerState, handlerState, emitOnCapturedContext ? SynchronizationContext.Current : null, flags); + + public async ValueTask AddMatchAsync(MatchRule rule, MessageValueReader reader, Action handler, object? readerState , object? handlerState, SynchronizationContext? synchronizationContext, ObserverFlags flags) + { + DBusConnection connection = await ConnectCoreAsync().ConfigureAwait(false); + return await connection.AddMatchAsync(synchronizationContext, rule, reader, handler, readerState, handlerState, flags).ConfigureAwait(false); + } + + public void AddMethodHandler(IMethodHandler methodHandler) + => UpdateMethodHandlers((dictionary, handler) => dictionary.AddMethodHandler(handler), methodHandler); + + public void AddMethodHandlers(IReadOnlyList methodHandlers) + => UpdateMethodHandlers((dictionary, handlers) => dictionary.AddMethodHandlers(handlers), methodHandlers); + + public void RemoveMethodHandler(string path) + => UpdateMethodHandlers((dictionary, path) => dictionary.RemoveMethodHandler(path), path); + + public void RemoveMethodHandlers(IEnumerable paths) + => UpdateMethodHandlers((dictionary, paths) => dictionary.RemoveMethodHandlers(paths), paths); + + private void UpdateMethodHandlers(Action update, T state) + => GetConnection().UpdateMethodHandlers(update, state); + + private static Connection CreateConnection(ref Connection? field, string? address) + { + address = address ?? "unix:"; + var connection = Volatile.Read(ref field); + if (connection is not null) + { + return connection; + } + var newConnection = new Connection(new ClientConnectionOptions(address) { AutoConnect = true, IsShared = true }); + connection = Interlocked.CompareExchange(ref field, newConnection, null); + if (connection != null) + { + newConnection.Dispose(); + return connection; + } + return newConnection; + } + + public MessageWriter GetMessageWriter() => new MessageWriter(MessageBufferPool.Shared, GetNextSerial()); + + public bool TrySendMessage(MessageBuffer message) + { + DBusConnection? connection = GetConnection(ifConnected: true); + if (connection is null) + { + message.ReturnToPool(); + return false; + } + connection.SendMessage(message); + return true; + } + + public Task DisconnectedAsync() + { + DBusConnection connection = GetConnection(); + return connection.DisconnectedAsync(); + } + + private DBusConnection GetConnection() => GetConnection(ifConnected: false)!; + + private DBusConnection? GetConnection(bool ifConnected) + { + lock (_gate) + { + ThrowHelper.ThrowIfDisposed(_disposed, this); + + if (_connectionOptions.AutoConnect) + { + throw new InvalidOperationException("Method cannot be used on autoconnect connections."); + } + + ConnectionState state = _state; + + if (state == ConnectionState.Created || + state == ConnectionState.Connecting) + { + throw new InvalidOperationException("Connect before using this method."); + } + + if (ifConnected && state != ConnectionState.Connected) + { + return null; + } + + return _connection; + } + } + + internal uint GetNextSerial() => (uint)Interlocked.Increment(ref _nextSerial); +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/ConnectionOptions.cs b/src/Linux/Tmds.DBus.Protocol/ConnectionOptions.cs new file mode 100644 index 0000000000..40c6e7a9f0 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/ConnectionOptions.cs @@ -0,0 +1,7 @@ +namespace Tmds.DBus.Protocol; + +public abstract class ConnectionOptions +{ + internal ConnectionOptions() + { } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Constants.cs b/src/Linux/Tmds.DBus.Protocol/Constants.cs new file mode 100644 index 0000000000..9382ef13d3 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Constants.cs @@ -0,0 +1,7 @@ +namespace Tmds.DBus.Protocol; + +static class Constants +{ + public const int StackAllocByteThreshold = 512; + public const int StackAllocCharThreshold = StackAllocByteThreshold / 2; +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/DBusConnection.cs b/src/Linux/Tmds.DBus.Protocol/DBusConnection.cs new file mode 100644 index 0000000000..3ddcbb5a9d --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/DBusConnection.cs @@ -0,0 +1,1290 @@ +using System.Net; +using System.Net.Sockets; +using System.Threading.Tasks.Sources; + +#pragma warning disable VSTHRD100 // Avoid "async void" methods + +namespace Tmds.DBus.Protocol; + +class DBusConnection : IDisposable +{ + private delegate void MessageReceivedHandler(Exception? exception, Message message, object? state); + + sealed class MyValueTaskSource : IValueTaskSource, IValueTaskSource + { + private ManualResetValueTaskSourceCore _core; + private volatile bool _continuationSet; + + public void SetResult(T result) + { + // Ensure we complete the Task from the read loop. + SpinWait wait = new(); + while (!_continuationSet) + { + wait.SpinOnce(); + } + _core.SetResult(result); + } + + public void SetException(Exception exception) => _core.SetException(exception); + + public ValueTaskSourceStatus GetStatus(short token) => _core.GetStatus(token); + + public void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) + { + _core.OnCompleted(continuation, state, token, flags); + _continuationSet = true; + } + + T IValueTaskSource.GetResult(short token) => _core.GetResult(token); + + void IValueTaskSource.GetResult(short token) => _core.GetResult(token); + } + + enum ConnectionState + { + Created, + Connecting, + Connected, + Disconnected + } + + delegate void MessageHandlerDelegate(Exception? exception, Message message, object? state1, object? state2, object? state3); + + readonly struct MessageHandler + { + public MessageHandler(MessageHandlerDelegate handler, object? state1 = null, object? state2 = null, object? state3 = null) + { + _delegate = handler; + _state1 = state1; + _state2 = state2; + _state3 = state3; + } + + public void Invoke(Exception? exception, Message message) + { + _delegate(exception, message, _state1, _state2, _state3); + } + + public bool HasValue => _delegate is not null; + + private readonly MessageHandlerDelegate _delegate; + private readonly object? _state1; + private readonly object? _state2; + private readonly object? _state3; + } + + delegate void MessageHandlerDelegate4(Exception? exception, Message message, object? state1, object? state2, object? state3, object? state4); + + readonly struct MessageHandler4 + { + public MessageHandler4(MessageHandlerDelegate4 handler, object? state1 = null, object? state2 = null, object? state3 = null, object? state4 = null) + { + _delegate = handler; + _state1 = state1; + _state2 = state2; + _state3 = state3; + _state4 = state4; + } + + public void Invoke(Exception? exception, Message message) + { + _delegate(exception, message, _state1, _state2, _state3, _state4); + } + + public bool HasValue => _delegate is not null; + + private readonly MessageHandlerDelegate4 _delegate; + private readonly object? _state1; + private readonly object? _state2; + private readonly object? _state3; + private readonly object? _state4; + } + + private readonly object _gate = new object(); + private readonly Connection _parentConnection; + private readonly Dictionary _pendingCalls; + private readonly CancellationTokenSource _connectCts; + private readonly Dictionary _matchMakers; + private readonly List _matchedObservers; + private readonly PathNodeDictionary _pathNodes; + private readonly string _machineId; + + private IMessageStream? _messageStream; + private ConnectionState _state; + private Exception? _disconnectReason; + private string? _localName; + private Message? _currentMessage; + private Observer? _currentObserver; + private SynchronizationContext? _currentSynchronizationContext; + private TaskCompletionSource? _disconnectedTcs; + private CancellationTokenSource _abortedCts; + private bool _isMonitor; + private Action? _monitorHandler; + + public string? UniqueName => _localName; + + public Exception DisconnectReason + { + get => _disconnectReason ?? new ObjectDisposedException(GetType().FullName); + set => Interlocked.CompareExchange(ref _disconnectReason, value, null); + } + + public bool RemoteIsBus => _localName is not null; + + public DBusConnection(Connection parent, string machineId) + { + _parentConnection = parent; + _connectCts = new(); + _pendingCalls = new(); + _matchMakers = new(); + _matchedObservers = new(); + _pathNodes = new(); + _machineId = machineId; + _abortedCts = new(); + } + + // For tests. + internal void Connect(IMessageStream stream) + { + _messageStream = stream; + + stream.ReceiveMessages( + static (Exception? exception, Message message, DBusConnection connection) => + connection.HandleMessages(exception, message), this); + + _state = ConnectionState.Connected; + } + + public async ValueTask ConnectAsync(string address, string? userId, bool supportsFdPassing, CancellationToken cancellationToken) + { + _state = ConnectionState.Connecting; + Exception? firstException = null; + + AddressParser.AddressEntry addr = default; + while (AddressParser.TryGetNextEntry(address, ref addr)) + { + Socket? socket = null; + EndPoint? endpoint = null; + Guid guid = default; + + if (AddressParser.IsType(addr, "unix")) + { + AddressParser.ParseUnixProperties(addr, out string path, out guid); + socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified); + endpoint = new UnixDomainSocketEndPoint(path); + } + else if (AddressParser.IsType(addr, "tcp")) + { + AddressParser.ParseTcpProperties(addr, out string host, out int? port, out guid); + if (!port.HasValue) + { + throw new ArgumentException("port"); + } + socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + endpoint = new DnsEndPoint(host, port.Value); + } + + if (socket is null) + { + continue; + } + + try + { + await socket.ConnectAsync(endpoint!, cancellationToken).ConfigureAwait(false); + + MessageStream stream; + lock (_gate) + { + if (_state != ConnectionState.Connecting) + { + throw new DisconnectedException(DisconnectReason); + } + _messageStream = stream = new MessageStream(socket); + } + + await stream.DoClientAuthAsync(guid, userId, supportsFdPassing).ConfigureAwait(false); + + stream.ReceiveMessages( + static (Exception? exception, Message message, DBusConnection connection) => + connection.HandleMessages(exception, message), this); + + lock (_gate) + { + if (_state != ConnectionState.Connecting) + { + throw new DisconnectedException(DisconnectReason); + } + _state = ConnectionState.Connected; + } + + _localName = await GetLocalNameAsync().ConfigureAwait(false); + + return; + } + catch (Exception exception) + { + socket.Dispose(); + firstException ??= exception; + } + } + + if (firstException is not null) + { + throw firstException; + } + + throw new ArgumentException("No addresses were found", nameof(address)); + } + + private async Task GetLocalNameAsync() + { + MyValueTaskSource vts = new(); + + await CallMethodAsync( + message: CreateHelloMessage(), + static (Exception? exception, Message message, object? state) => + { + var vtsState = (MyValueTaskSource)state!; + + if (exception is not null) + { + vtsState.SetException(exception); + } + else if (message.MessageType == MessageType.MethodReturn) + { + vtsState.SetResult(message.GetBodyReader().ReadString().ToString()); + } + else + { + vtsState.SetResult(null); + } + }, vts).ConfigureAwait(false); + + return await new ValueTask(vts, token: 0).ConfigureAwait(false); + + MessageBuffer CreateHelloMessage() + { + using var writer = GetMessageWriter(); + + writer.WriteMethodCallHeader( + destination: "org.freedesktop.DBus", + path: "/org/freedesktop/DBus", + @interface: "org.freedesktop.DBus", + member: "Hello"); + + return writer.CreateMessage(); + } + } + + private async void HandleMessages(Exception? exception, Message message) + { + if (exception is not null) + { + _parentConnection.Disconnect(exception, this); + } + else + { + try + { + bool returnMessageToPool = true; + MessageHandler pendingCall = default; + IMethodHandler? methodHandler = null; + Action? monitor = null; + bool isMethodCall = message.MessageType == MessageType.MethodCall; + MethodContext? methodContext = null; + + lock (_gate) + { + if (_state == ConnectionState.Disconnected) + { + return; + } + + monitor = _monitorHandler; + + if (monitor is null) + { + if (message.ReplySerial.HasValue) + { + _pendingCalls.Remove(message.ReplySerial.Value, out pendingCall); + } + + foreach (var matchMaker in _matchMakers.Values) + { + if (matchMaker.Matches(message)) + { + _matchedObservers.AddRange(matchMaker.Observers); + } + } + + if (isMethodCall) + { + methodContext = new MethodContext(_parentConnection, message, _abortedCts.Token); // TODO: pool. + + if (message.PathIsSet) + { + if (_pathNodes.TryGetValue(message.PathAsString!, out PathNode? node)) + { + methodHandler = node.MethodHandler; + + bool isDBusIntrospect = message.Member.SequenceEqual("Introspect"u8) && + message.Interface.SequenceEqual("org.freedesktop.DBus.Introspectable"u8); + methodContext.IsDBusIntrospectRequest = isDBusIntrospect; + if (isDBusIntrospect) + { + node.CopyChildNamesTo(methodContext); + } + } + } + } + } + } + + if (monitor is not null) + { + lock (monitor) + { + if (_monitorHandler is not null) + { + returnMessageToPool = false; + monitor(null, new DisposableMessage(message)); + } + } + } + else + { + if (_matchedObservers.Count != 0) + { + foreach (var observer in _matchedObservers) + { + observer.Emit(message); + } + _matchedObservers.Clear(); + } + + if (pendingCall.HasValue) + { + pendingCall.Invoke(null, message); + } + + if (isMethodCall) + { + Debug.Assert(methodContext is not null); + if (methodHandler is not null) + { +// Suppress methodContext nullability warnings. +#if NETSTANDARD2_0 +#pragma warning disable CS8604 +#endif + bool runHandlerSynchronously = methodHandler.RunMethodHandlerSynchronously(message); + if (runHandlerSynchronously) + { + await methodHandler.HandleMethodAsync(methodContext).ConfigureAwait(false); + HandleNoReplySent(methodContext); + } + else + { + returnMessageToPool = false; + RunMethodHandler(methodHandler, methodContext); + } + } + else + { + HandleNoReplySent(methodContext); + } +#if NETSTANDARD2_0 +#pragma warning restore CS8604 +#endif + } + } + + if (returnMessageToPool) + { + message.ReturnToPool(); + } + } + catch (Exception ex) + { + _parentConnection.Disconnect(ex, this); + } + } + } + + private void HandleNoReplySent(MethodContext context) + { + if (context.ReplySent || context.NoReplyExpected) + { + return; + } + + if (context.IsDBusIntrospectRequest) + { + context.ReplyIntrospectXml(interfaceXmls: []); + return; + } + + var request = context.Request; + + if (request.Interface.SequenceEqual("org.freedesktop.DBus.Peer"u8)) + { + if (request.Member.SequenceEqual("Ping"u8)) + { + using var writer = context.CreateReplyWriter(null); + context.Reply(writer.CreateMessage()); + return; + } + else if (request.Member.SequenceEqual("GetMachineId"u8)) + { + using var writer = context.CreateReplyWriter("s"); + writer.WriteString(_machineId); + context.Reply(writer.CreateMessage()); + return; + } + } + + context.ReplyError("org.freedesktop.DBus.Error.UnknownMethod", + $"Method \"{request.MemberAsString}\" with signature \"{request.SignatureAsString}\" on interface \"{request.InterfaceAsString}\" doesn't exist"); + } + + private async void RunMethodHandler(IMethodHandler methodHandler, MethodContext context) + { + try + { + await methodHandler.HandleMethodAsync(context).ConfigureAwait(false); + HandleNoReplySent(context); + context.Request.ReturnToPool(); + } + catch (Exception ex) + { + _parentConnection.Disconnect(ex, this); + } + } + + private void EmitOnSynchronizationContextHelper(Observer observer, SynchronizationContext synchronizationContext, Message message) + { + _currentMessage = message; + _currentObserver = observer; + _currentSynchronizationContext = synchronizationContext; + +#pragma warning disable VSTHRD001 // Await JoinableTaskFactory.SwitchToMainThreadAsync() to switch to the UI thread instead of APIs that can deadlock or require specifying a priority. + // note: Send blocks the current thread until the SynchronizationContext ran the delegate. + synchronizationContext.Send(static o => + { + SynchronizationContext? previousContext = SynchronizationContext.Current; + try + { + DBusConnection conn = (DBusConnection)o!; + SynchronizationContext.SetSynchronizationContext(conn._currentSynchronizationContext); + conn._currentObserver!.InvokeHandler(conn._currentMessage!); + } + finally + { + SynchronizationContext.SetSynchronizationContext(previousContext); + } + }, this); + + _currentMessage = null; + _currentObserver = null; + _currentSynchronizationContext = null; + } + + public void UpdateMethodHandlers(Action update, T state) + { + lock (_gate) + { + update(_pathNodes, state); + } + } + + public void Dispose() + { + Action? monitor = null; + + lock (_gate) + { + if (_state == ConnectionState.Disconnected) + { + return; + } + _state = ConnectionState.Disconnected; + monitor = _monitorHandler; + } + + Exception disconnectReason = DisconnectReason; + + _messageStream?.Close(disconnectReason); + + _abortedCts.Cancel(); + + if (_pendingCalls is not null) + { + foreach (var pendingCall in _pendingCalls.Values) + { + pendingCall.Invoke(new DisconnectedException(disconnectReason), null!); + } + _pendingCalls.Clear(); + } + + foreach (var matchMaker in _matchMakers.Values) + { + foreach (var observer in matchMaker.Observers) + { + bool emitException = !object.ReferenceEquals(disconnectReason, Connection.DisposedException) || + observer.EmitOnConnectionDispose; + Exception? exception = emitException ? new DisconnectedException(disconnectReason) : null; + observer.Dispose(exception, removeObserver: false); + } + } + _matchMakers.Clear(); + + if (monitor is not null) + { + lock (monitor) + { + _monitorHandler = null; + monitor(new DisconnectedException(disconnectReason), new DisposableMessage(null)); + } + } + + _disconnectedTcs?.SetResult(GetWaitForDisconnectException()); + } + + private ValueTask CallMethodAsync(MessageBuffer message, MessageReceivedHandler returnHandler, object? state) + { + MessageHandlerDelegate fn = static (Exception? exception, Message message, object? state1, object? state2, object? state3) => + { + ((MessageReceivedHandler)state1!)(exception, message, state2); + }; + MessageHandler handler = new(fn, returnHandler, state); + + return CallMethodAsync(message, handler); + } + + private async ValueTask CallMethodAsync(MessageBuffer message, MessageHandler handler) + { + bool messageSent = false; + try + { + lock (_gate) + { + if (_state != ConnectionState.Connected) + { + throw new DisconnectedException(DisconnectReason!); + } + if (_isMonitor) + { + throw new InvalidOperationException("Cannot send messages on monitor connection."); + } + if ((message.MessageFlags & MessageFlags.NoReplyExpected) == 0) + { + _pendingCalls.Add(message.Serial, handler); + } + } + + messageSent = await _messageStream!.TrySendMessageAsync(message).ConfigureAwait(false); + } + finally + { + if (!messageSent) + { + message.ReturnToPool(); + } + } + } + + public async Task CallMethodAsync(MessageBuffer message, MessageValueReader valueReader, object? state = null) + { + MessageHandlerDelegate fn = static (Exception? exception, Message message, object? state1, object? state2, object? state3) => + { + var valueReaderState = (MessageValueReader)state1!; + var vtsState = (MyValueTaskSource)state2!; + + if (exception is not null) + { + vtsState.SetException(exception); + } + else if (message.MessageType == MessageType.MethodReturn) + { + try + { + vtsState.SetResult(valueReaderState(message, state3)); + } + catch (Exception ex) + { + vtsState.SetException(ex); + } + } + else if (message.MessageType == MessageType.Error) + { + vtsState.SetException(CreateDBusExceptionForErrorMessage(message)); + } + else + { + vtsState.SetException(new ProtocolException($"Unexpected reply type: {message.MessageType}.")); + } + }; + + MyValueTaskSource vts = new(); + MessageHandler handler = new(fn, valueReader, vts, state); + + await CallMethodAsync(message, handler).ConfigureAwait(false); + + return await new ValueTask(vts, 0).ConfigureAwait(false); + } + + public async Task CallMethodAsync(MessageBuffer message) + { + MyValueTaskSource vts = new(); + + await CallMethodAsync(message, + static (Exception? exception, Message message, object? state) => CompleteCallValueTaskSource(exception, message, state), vts).ConfigureAwait(false); + + await new ValueTask(vts, 0).ConfigureAwait(false); + } + + private static void CompleteCallValueTaskSource(Exception? exception, Message message, object? vts) + { + var vtsState = (MyValueTaskSource)vts!; + + if (exception is not null) + { + vtsState.SetException(exception); + } + else if (message.MessageType == MessageType.MethodReturn) + { + vtsState.SetResult(null); + } + else if (message.MessageType == MessageType.Error) + { + vtsState.SetException(CreateDBusExceptionForErrorMessage(message)); + } + else + { + vtsState.SetException(new ProtocolException($"Unexpected reply type: {message.MessageType}.")); + } + } + + private static DBusException CreateDBusExceptionForErrorMessage(Message message) + { + string errorName = message.ErrorNameAsString ?? "<>."; + string errMessage = errorName; + if (message.SignatureIsSet && message.Signature.Length > 0 && (DBusType)message.Signature[0] == DBusType.String) + { + errMessage = message.GetBodyReader().ReadString(); + } + return new DBusException(errorName, errMessage); + } + + public async Task BecomeMonitorAsync(Action handler, IEnumerable? rules) + { + Task reply; + + lock (_gate) + { + if (_state != ConnectionState.Connected) + { + throw new DisconnectedException(DisconnectReason!); + } + if (!RemoteIsBus) + { + throw new InvalidOperationException("The remote is not a bus."); + } + if (_matchMakers.Count != 0) + { + throw new InvalidOperationException("The connection has observers."); + } + if (_pendingCalls.Count != 0) + { + throw new InvalidOperationException("The connection has pending method calls."); + } + + HashSet? ruleStrings = null; + if (rules is not null) + { + ruleStrings = new(); + foreach (var rule in rules) + { + ruleStrings.Add(rule.ToString()); + } + } + + reply = CallMethodAsync(CreateMessage(ruleStrings)); + _isMonitor = true; + } + + try + { + await reply.ConfigureAwait(false); + lock (_gate) + { + _messageStream!.BecomeMonitor(); + _monitorHandler = handler; + } + } + catch + { + lock (_gate) + { + _isMonitor = false; + } + + throw; + } + + MessageBuffer CreateMessage(IEnumerable? rules) + { + using var writer = GetMessageWriter(); + writer.WriteMethodCallHeader( + destination: Connection.DBusServiceName, + path: Connection.DBusObjectPath, + @interface: "org.freedesktop.DBus.Monitoring", + signature: "asu", + member: "BecomeMonitor"); + writer.WriteArray(rules ?? Array.Empty()); + writer.WriteUInt32(0); + return writer.CreateMessage(); + } + } + + public ValueTask AddMatchAsync(SynchronizationContext? synchronizationContext, MatchRule rule, MessageValueReader valueReader, Action valueHandler, object? readerState, object? handlerState, ObserverFlags flags) + { + MessageHandlerDelegate4 fn = static (Exception? exception, Message message, object? reader, object? handler, object? rs, object? hs) => + { + var valueHandlerState = (Action)handler!; + if (exception is not null) + { + valueHandlerState(exception, default(T)!, rs, hs); + } + else + { + var valueReaderState = (MessageValueReader)reader!; + T value = valueReaderState(message, rs); + valueHandlerState(null, value, rs, hs); + } + }; + + return AddMatchAsync(synchronizationContext, rule, new(fn, valueReader, valueHandler, readerState, handlerState), flags); + } + + private async ValueTask AddMatchAsync(SynchronizationContext? synchronizationContext, MatchRule rule, MessageHandler4 handler, ObserverFlags flags) + { + MatchRuleData data = rule.Data; + MatchMaker? matchMaker; + string ruleString; + Observer observer; + MessageBuffer? addMatchMessage = null; + bool subscribe; + + lock (_gate) + { + if (_state != ConnectionState.Connected) + { + throw new DisconnectedException(DisconnectReason!); + } + if (!RemoteIsBus) + { + flags |= ObserverFlags.NoSubscribe; + } + if (_isMonitor) + { + throw new InvalidOperationException("Cannot add subscriptions on a monitor connection."); + } + + ruleString = data.GetRuleString(); + + if (!_matchMakers.TryGetValue(ruleString, out matchMaker)) + { + matchMaker = new MatchMaker(this, ruleString, data); + _matchMakers.Add(ruleString, matchMaker); + } + + observer = new Observer(synchronizationContext, matchMaker, handler, flags); + matchMaker.Observers.Add(observer); + + subscribe = observer.Subscribes; + bool sendMessage = subscribe && matchMaker.AddMatchTcs is null; + if (sendMessage) + { + addMatchMessage = CreateAddMatchMessage(matchMaker.RuleString); + matchMaker.AddMatchTcs = new(); + + MessageHandlerDelegate fn = static (Exception? exception, Message message, object? state1, object? state2, object? state3) => + { + var mm = (MatchMaker)state1!; + if (message.MessageType == MessageType.MethodReturn) + { + mm.HasSubscribed = true; + } + CompleteCallValueTaskSource(exception, message, mm.AddMatchTcs!); + }; + + _pendingCalls.Add(addMatchMessage.Serial, new(fn, matchMaker)); + } + } + + if (subscribe) + { + if (addMatchMessage is not null) + { + if (!await _messageStream!.TrySendMessageAsync(addMatchMessage).ConfigureAwait(false)) + { + addMatchMessage.ReturnToPool(); + } + } + + try + { + await matchMaker.AddMatchTask!.ConfigureAwait(false); + } + catch + { + observer.Dispose(exception: null); + + throw; + } + } + + return observer; + + MessageBuffer CreateAddMatchMessage(string ruleString) + { + using var writer = GetMessageWriter(); + + writer.WriteMethodCallHeader( + destination: "org.freedesktop.DBus", + path: "/org/freedesktop/DBus", + @interface: "org.freedesktop.DBus", + member: "AddMatch", + signature: "s"); + + writer.WriteString(ruleString); + + return writer.CreateMessage(); + } + } + + internal static readonly ObjectDisposedException ObserverDisposedException = new ObjectDisposedException(typeof(Observer).FullName); + + sealed class Observer : IDisposable + { + private readonly object _gate = new object(); + private readonly SynchronizationContext? _synchronizationContext; + private readonly MatchMaker _matchMaker; + private readonly MessageHandler4 _messageHandler; + private readonly ObserverFlags _flags; + private bool _disposed; + + public bool Subscribes => (_flags & ObserverFlags.NoSubscribe) == 0; + public bool EmitOnConnectionDispose => (_flags & ObserverFlags.EmitOnConnectionDispose) != 0; + public bool EmitOnObserverDispose => (_flags & ObserverFlags.EmitOnObserverDispose) != 0; + + public Observer(SynchronizationContext? synchronizationContext, MatchMaker matchMaker, in MessageHandler4 messageHandler, ObserverFlags flags) + { + _synchronizationContext = synchronizationContext; + _matchMaker = matchMaker; + _messageHandler = messageHandler; + _flags = flags; + } + + public void Dispose() => + Dispose(EmitOnObserverDispose ? ObserverDisposedException : null); + + public void Dispose(Exception? exception, bool removeObserver = true) + { + lock (_gate) + { + if (_disposed) + { + return; + } + _disposed = true; + } + + if (exception is not null) + { + Emit(exception); + } + + if (removeObserver) + { + _matchMaker.Connection.RemoveObserver(_matchMaker, this); + } + } + + public void Emit(Message message) + { + if (_synchronizationContext is null) + { + InvokeHandler(message); + } + else + { + _matchMaker.Connection.EmitOnSynchronizationContextHelper(this, _synchronizationContext, message); + } + } + + private void Emit(Exception exception) + { + if (_synchronizationContext is null || + SynchronizationContext.Current == _synchronizationContext) + { + _messageHandler.Invoke(exception, null!); + } + else + { + _synchronizationContext.Send( + delegate + { + SynchronizationContext? previousContext = SynchronizationContext.Current; + try + { + SynchronizationContext.SetSynchronizationContext(_synchronizationContext); + _messageHandler.Invoke(exception, null!); + } + finally + { + SynchronizationContext.SetSynchronizationContext(previousContext); + } + }, null); + } + } + + internal void InvokeHandler(Message message) + { + if (Subscribes && !_matchMaker.HasSubscribed) + { + return; + } + + lock (_gate) + { + if (_disposed) + { + return; + } + + _messageHandler.Invoke(null, message); + } + } + } + + private async void RemoveObserver(MatchMaker matchMaker, Observer observer) + { + string ruleString = matchMaker.RuleString; + bool sendMessage = false; + + lock (_gate) + { + if (_state == ConnectionState.Disconnected) + { + return; + } + + if (_matchMakers.TryGetValue(ruleString, out _)) + { + matchMaker.Observers.Remove(observer); + sendMessage = matchMaker.AddMatchTcs is not null && matchMaker.HasSubscribers; + if (sendMessage) + { + _matchMakers.Remove(ruleString); + } + } + } + + if (sendMessage) + { + var message = CreateRemoveMatchMessage(); + if (!await _messageStream!.TrySendMessageAsync(message).ConfigureAwait(false)) + { + message.ReturnToPool(); + } + } + + MessageBuffer CreateRemoveMatchMessage() + { + using var writer = GetMessageWriter(); + + writer.WriteMethodCallHeader( + destination: "org.freedesktop.DBus", + path: "/org/freedesktop/DBus", + @interface: "org.freedesktop.DBus", + member: "RemoveMatch", + signature: "s", + flags: MessageFlags.NoReplyExpected); + + writer.WriteString(ruleString); + + return writer.CreateMessage(); + } + } + + sealed class MatchMaker + { + private readonly MessageType? _type; + private readonly byte[]? _sender; + private readonly byte[]? _interface; + private readonly byte[]? _member; + private readonly byte[]? _path; + private readonly byte[]? _pathNamespace; + private readonly byte[]? _destination; + private readonly byte[]? _arg0; + private readonly byte[]? _arg0Path; + private readonly byte[]? _arg0Namespace; + private readonly string _rule; + + private MyValueTaskSource? _vts; + + public List Observers { get; } = new(); + + public MyValueTaskSource? AddMatchTcs + { + get => _vts; + set + { + _vts = value; + if (value != null) + { + AddMatchTask = new ValueTask(value, token: 0).AsTask(); + } + } + } + + public Task? AddMatchTask { get; private set; } + + public bool HasSubscribed { get; set; } + + public DBusConnection Connection { get; } + + public MatchMaker(DBusConnection connection, string rule, in MatchRuleData data) + { + Connection = connection; + _rule = rule; + + _type = data.MessageType; + + if (data.Sender is not null && data.Sender.StartsWith(":")) + { + _sender = Encoding.UTF8.GetBytes(data.Sender); + } + if (data.Interface is not null) + { + _interface = Encoding.UTF8.GetBytes(data.Interface); + } + if (data.Member is not null) + { + _member = Encoding.UTF8.GetBytes(data.Member); + } + if (data.Path is not null) + { + _path = Encoding.UTF8.GetBytes(data.Path); + } + if (data.PathNamespace is not null) + { + _pathNamespace = Encoding.UTF8.GetBytes(data.PathNamespace); + } + if (data.Destination is not null) + { + _destination = Encoding.UTF8.GetBytes(data.Destination); + } + if (data.Arg0 is not null) + { + _arg0 = Encoding.UTF8.GetBytes(data.Arg0); + } + if (data.Arg0Path is not null) + { + _arg0Path = Encoding.UTF8.GetBytes(data.Arg0Path); + } + if (data.Arg0Namespace is not null) + { + _arg0Namespace = Encoding.UTF8.GetBytes(data.Arg0Namespace); + } + } + + public string RuleString => _rule; + + public bool HasSubscribers + { + get + { + if (Observers.Count == 0) + { + return false; + } + foreach (var observer in Observers) + { + if (observer.Subscribes) + { + return true; + } + } + return false; + } + } + + public override string ToString() => _rule; + + internal bool Matches(Message message) + { + if (_type.HasValue && _type != message.MessageType) + { + return false; + } + + if (_sender is not null && !IsEqual(_sender, message.Sender)) + { + return false; + } + + if (_interface is not null && !IsEqual(_interface, message.Interface)) + { + return false; + } + + if (_member is not null && !IsEqual(_member, message.Member)) + { + return false; + } + + if (_path is not null && !IsEqual(_path, message.Path)) + { + return false; + } + + if (_destination is not null && !IsEqual(_destination, message.Destination)) + { + return false; + } + + if (_pathNamespace is not null && (!message.PathIsSet || !IsEqualOrChildOfPath(message.Path, _pathNamespace))) + { + return false; + } + + if (_arg0Namespace is not null || + _arg0 is not null || + _arg0Path is not null) + { + if (message.Signature.Length == 0) + { + return false; + } + + DBusType arg0Type = (DBusType)message.Signature![0]; + + if (arg0Type != DBusType.String && + arg0Type != DBusType.ObjectPath) + { + return false; + } + + ReadOnlySpan arg0 = message.GetBodyReader().ReadStringAsSpan(); + + if (_arg0Path is not null && !IsEqualParentOrChildOfPath(arg0, _arg0Path)) + { + return false; + } + + if (arg0Type != DBusType.String) + { + return false; + } + + if (_arg0 is not null && !IsEqual(_arg0, arg0)) + { + return false; + } + + if (_arg0Namespace is not null && !IsEqualOrChildOfName(arg0, _arg0Namespace)) + { + return false; + } + } + + return true; + } + + private static bool IsEqualOrChildOfName(ReadOnlySpan lhs, ReadOnlySpan rhs) + { + return lhs.StartsWith(rhs) && (lhs.Length == rhs.Length || lhs[rhs.Length] == '.'); + } + + private static bool IsEqualOrChildOfPath(ReadOnlySpan lhs, ReadOnlySpan rhs) + { + return lhs.StartsWith(rhs) && (lhs.Length == rhs.Length || lhs[rhs.Length] == '/'); + } + + private static bool IsEqualParentOrChildOfPath(ReadOnlySpan lhs, ReadOnlySpan rhs) + { + if (rhs.Length < lhs.Length) + { + return rhs[rhs.Length - 1] == '/' && lhs.StartsWith(rhs); + } + else if (lhs.Length < rhs.Length) + { + return lhs[lhs.Length - 1] == '/' && rhs.StartsWith(lhs); + } + else + { + return IsEqual(lhs, rhs); + } + } + + private static bool IsEqual(ReadOnlySpan lhs, ReadOnlySpan rhs) + { + return lhs.SequenceEqual(rhs); + } + } + + public MessageWriter GetMessageWriter() => _parentConnection.GetMessageWriter(); + + public async void SendMessage(MessageBuffer message) + { + bool messageSent = await _messageStream!.TrySendMessageAsync(message).ConfigureAwait(false); + if (!messageSent) + { + message.ReturnToPool(); + } + } + + public Task DisconnectedAsync() + { + lock (_gate) + { + if (_disconnectedTcs is null) + { + if (_state == ConnectionState.Disconnected) + { + return Task.FromResult(GetWaitForDisconnectException()); + } + else + { + _disconnectedTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + } + } + return _disconnectedTcs.Task; + } + } + + private Exception? GetWaitForDisconnectException() + => _disconnectReason is ObjectDisposedException ? null : _disconnectReason; + + private void SendErrorReplyMessage(Message methodCall, string errorName, string errorMsg) + { + SendMessage(CreateErrorMessage(methodCall, errorName, errorMsg)); + + MessageBuffer CreateErrorMessage(Message methodCall, string errorName, string errorMsg) + { + using var writer = GetMessageWriter(); + + writer.WriteError( + replySerial: methodCall.Serial, + destination: methodCall.Sender, + errorName: errorName, + errorMsg: errorMsg); + + return writer.CreateMessage(); + } + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/DBusEnvironment.cs b/src/Linux/Tmds.DBus.Protocol/DBusEnvironment.cs new file mode 100644 index 0000000000..122bcb21e0 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/DBusEnvironment.cs @@ -0,0 +1,50 @@ +namespace Tmds.DBus.Protocol; + +static class DBusEnvironment +{ + public static string? UserId + { + get + { + if (PlatformDetection.IsWindows()) + { +#if NET6_0_OR_GREATER + return System.Security.Principal.WindowsIdentity.GetCurrent().User?.Value; +#else + throw new NotSupportedException("Cannot determine Windows UserId. You must manually assign it."); +#endif + } + else + { + return geteuid().ToString(); + } + } + } + + private static string? _machineId; + + public static string MachineId + { + get + { + if (_machineId == null) + { + const string MachineUuidPath = @"/var/lib/dbus/machine-id"; + + if (File.Exists(MachineUuidPath)) + { + _machineId = Guid.Parse(File.ReadAllText(MachineUuidPath).Substring(0, 32)).ToString("N"); + } + else + { + _machineId = Guid.Empty.ToString("N"); + } + } + + return _machineId; + } + } + + [DllImport("libc")] + internal static extern uint geteuid(); +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/DBusException.cs b/src/Linux/Tmds.DBus.Protocol/DBusException.cs new file mode 100644 index 0000000000..a6dceb4345 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/DBusException.cs @@ -0,0 +1,15 @@ +namespace Tmds.DBus.Protocol; + +public class DBusException : Exception +{ + public DBusException(string errorName, string errorMessage) : + base($"{errorName}: {errorMessage}") + { + ErrorName = errorName; + ErrorMessage = errorMessage; + } + + public string ErrorName { get; } + + public string ErrorMessage { get; } +} diff --git a/src/Linux/Tmds.DBus.Protocol/DBusType.cs b/src/Linux/Tmds.DBus.Protocol/DBusType.cs new file mode 100644 index 0000000000..6527edf838 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/DBusType.cs @@ -0,0 +1,23 @@ +namespace Tmds.DBus.Protocol; + +public enum DBusType : byte +{ + Invalid = 0, + Byte = (byte)'y', + Bool = (byte)'b', + Int16 = (byte)'n', + UInt16 = (byte)'q', + Int32 = (byte)'i', + UInt32 = (byte)'u', + Int64 = (byte)'x', + UInt64 = (byte)'t', + Double = (byte)'d', + String = (byte)'s', + ObjectPath = (byte)'o', + Signature = (byte)'g', + Array = (byte)'a', + Struct = (byte)'(', + Variant = (byte)'v', + DictEntry = (byte)'{', + UnixFd = (byte)'h', +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Dict.cs b/src/Linux/Tmds.DBus.Protocol/Dict.cs new file mode 100644 index 0000000000..3d6472ceae --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Dict.cs @@ -0,0 +1,89 @@ +using System.Collections; + +namespace Tmds.DBus.Protocol; + +// Using obsolete generic write members +#pragma warning disable CS0618 + +public sealed class Dict : IDBusWritable, IDictionary + where TKey : notnull + where TValue : notnull +{ + private readonly Dictionary _dict; + + public Dict() : + this(new Dictionary()) + { } + + public Dict(IDictionary dictionary) : + this(new Dictionary(dictionary)) + { } + + private Dict(Dictionary value) + { + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + _dict = value; + } + + public Variant AsVariant() => Variant.FromDict(this); + + public static implicit operator Variant(Dict value) + => value.AsVariant(); + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // this is a supported variant type. + void IDBusWritable.WriteTo(ref MessageWriter writer) + => writer.WriteDictionary(_dict); + + + ICollection IDictionary.Keys => _dict.Keys; + + ICollection IDictionary.Values => _dict.Values; + + public int Count => _dict.Count; + + public TValue this[TKey key] + { + get => _dict[key]; + set => _dict[key] = value; + } + + public void Add(TKey key, TValue value) + => _dict.Add(key, value); + + public bool ContainsKey(TKey key) + => _dict.ContainsKey(key); + + public bool Remove(TKey key) + => _dict.Remove(key); + + public bool TryGetValue(TKey key, +#if NET + [MaybeNullWhen(false)] +#endif + out TValue value) + => _dict.TryGetValue(key, out value); + + public void Clear() + => _dict.Clear(); + + void ICollection>.Add(KeyValuePair item) + => ((ICollection>)_dict).Add(item); + + bool ICollection>.Contains(KeyValuePair item) + => ((ICollection>)_dict).Contains(item); + + void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) + => ((ICollection>)_dict).CopyTo(array, arrayIndex); + + bool ICollection>.Remove(KeyValuePair item) + => ((ICollection>)_dict).Remove(item); + + bool ICollection>.IsReadOnly => false; + + IEnumerator> IEnumerable>.GetEnumerator() + => _dict.GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() + => _dict.GetEnumerator(); +} diff --git a/src/Linux/Tmds.DBus.Protocol/DisconnectedException.cs b/src/Linux/Tmds.DBus.Protocol/DisconnectedException.cs new file mode 100644 index 0000000000..0f51c66db2 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/DisconnectedException.cs @@ -0,0 +1,6 @@ +namespace Tmds.DBus.Protocol; +public class DisconnectedException : Exception +{ + internal DisconnectedException(Exception innerException) : base(innerException.Message, innerException) + { } +} diff --git a/src/Linux/Tmds.DBus.Protocol/DisposableMessage.cs b/src/Linux/Tmds.DBus.Protocol/DisposableMessage.cs new file mode 100644 index 0000000000..86c6a346ca --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/DisposableMessage.cs @@ -0,0 +1,18 @@ +namespace Tmds.DBus.Protocol; + +public struct DisposableMessage : IDisposable +{ + private Message? _message; + + internal DisposableMessage(Message? message) + => _message = message; + + public Message Message + => _message ?? throw new ObjectDisposedException(typeof(Message).FullName); + + public void Dispose() + { + _message?.ReturnToPool(); + _message = null; + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/Feature.cs b/src/Linux/Tmds.DBus.Protocol/Feature.cs new file mode 100644 index 0000000000..00597dc6a3 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Feature.cs @@ -0,0 +1,14 @@ +namespace Tmds.DBus.Protocol; + +static class Feature +{ + public static bool IsDynamicCodeEnabled => +#if NETSTANDARD2_0 + true +#else + System.Runtime.CompilerServices.RuntimeFeature.IsDynamicCodeSupported +#endif + && EnableDynamicCode; + + private static readonly bool EnableDynamicCode = Environment.GetEnvironmentVariable("TMDS_DBUS_PROTOCOL_DYNAMIC_CODE") != "0"; +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/GlobalUsings.cs b/src/Linux/Tmds.DBus.Protocol/GlobalUsings.cs new file mode 100644 index 0000000000..7df10d47b6 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/GlobalUsings.cs @@ -0,0 +1,14 @@ +global using System; +global using System.Buffers; +global using System.Buffers.Binary; +global using System.Collections.Generic; +global using System.Diagnostics; +global using System.Diagnostics.CodeAnalysis; +global using System.IO; +global using System.Linq; +global using System.Runtime.CompilerServices; +global using System.Runtime.InteropServices; +global using System.Text; +global using System.Threading; +global using System.Threading.Tasks; +global using Nerdbank.Streams; \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/IDBusWritable.cs b/src/Linux/Tmds.DBus.Protocol/IDBusWritable.cs new file mode 100644 index 0000000000..14f2f6f0ef --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/IDBusWritable.cs @@ -0,0 +1,6 @@ +namespace Tmds.DBus.Protocol; + +public interface IDBusWritable +{ + void WriteTo(ref MessageWriter writer); +} diff --git a/src/Linux/Tmds.DBus.Protocol/IMessageStream.cs b/src/Linux/Tmds.DBus.Protocol/IMessageStream.cs new file mode 100644 index 0000000000..5337180eda --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/IMessageStream.cs @@ -0,0 +1,14 @@ +namespace Tmds.DBus.Protocol; + +interface IMessageStream +{ + public delegate void MessageReceivedHandler(Exception? closeReason, Message message, T state); + + void ReceiveMessages(MessageReceivedHandler handler, T state); + + ValueTask TrySendMessageAsync(MessageBuffer message); + + void BecomeMonitor(); + + void Close(Exception closeReason); +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/IMethodHandler.cs b/src/Linux/Tmds.DBus.Protocol/IMethodHandler.cs new file mode 100644 index 0000000000..5d45a986d2 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/IMethodHandler.cs @@ -0,0 +1,13 @@ +namespace Tmds.DBus.Protocol; + +public interface IMethodHandler +{ + // Path that is handled by this method handler. + string Path { get; } + + // The message argument is only valid during the call. It must not be stored to extend its lifetime. + ValueTask HandleMethodAsync(MethodContext context); + + // Controls whether to wait for the handler method to finish executing before reading more messages. + bool RunMethodHandlerSynchronously(Message message); +} diff --git a/src/Linux/Tmds.DBus.Protocol/IMethodHandlerDictionary.cs b/src/Linux/Tmds.DBus.Protocol/IMethodHandlerDictionary.cs new file mode 100644 index 0000000000..00bf17e4b8 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/IMethodHandlerDictionary.cs @@ -0,0 +1,9 @@ +namespace Tmds.DBus.Protocol; + +interface IMethodHandlerDictionary +{ + void AddMethodHandlers(IReadOnlyList methodHandlers); + void AddMethodHandler(IMethodHandler methodHandler); + void RemoveMethodHandler(string path); + void RemoveMethodHandlers(IEnumerable paths); +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/IntrospectionXml.cs b/src/Linux/Tmds.DBus.Protocol/IntrospectionXml.cs new file mode 100644 index 0000000000..6256963d21 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/IntrospectionXml.cs @@ -0,0 +1,51 @@ +namespace Tmds.DBus.Protocol; + +public static class IntrospectionXml +{ + public static ReadOnlyMemory DBusProperties { get; } = + """ + + + + + + + + + + + + + + + + + + + + + + + """u8.ToArray(); + + public static ReadOnlyMemory DBusIntrospectable { get; } = + """ + + + + + + + """u8.ToArray(); + + public static ReadOnlyMemory DBusPeer { get; } = + """ + + + + + + + + """u8.ToArray(); +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/MatchRule.cs b/src/Linux/Tmds.DBus.Protocol/MatchRule.cs new file mode 100644 index 0000000000..ce21c9a941 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MatchRule.cs @@ -0,0 +1,120 @@ +namespace Tmds.DBus.Protocol; + +struct MatchRuleData +{ + public MessageType? MessageType { get; set; } + + public string? Sender { get; set; } + + public string? Interface { get; set; } + + public string? Member { get; set; } + + public string? Path { get; set; } + + public string? PathNamespace { get; set; } + + public string? Destination { get; set; } + + public string? Arg0 { get; set; } + + public string? Arg0Path { get; set; } + + public string? Arg0Namespace { get; set; } + + public string GetRuleString() + { + var sb = new StringBuilder(); // TODO (perf): pool + + if (MessageType.HasValue) + { + string? typeMatch = MessageType switch + { + Protocol.MessageType.MethodCall => "type=method_call", + Protocol.MessageType.MethodReturn => "type=method_return", + Protocol.MessageType.Error => "type=error", + Protocol.MessageType.Signal => "type=signal", + _ => null + }; + + if (typeMatch is not null) + { + sb.Append(typeMatch); + } + } + + Append(sb, "sender", Sender); + Append(sb, "interface", Interface); + Append(sb, "member", Member); + Append(sb, "path", Path); + Append(sb, "pathNamespace", PathNamespace); + Append(sb, "destination", Destination); + Append(sb, "arg0", Arg0); + Append(sb, "arg0Path", Arg0Path); + Append(sb, "arg0Namespace", Arg0Namespace); + + return sb.ToString(); + + static void Append(StringBuilder sb, string key, string? value) + { + if (value is null) + { + return; + } + + sb.Append($"{(sb.Length > 0 ? ',' : "")}{key}="); + + bool quoting = false; + + ReadOnlySpan span = value.AsSpan(); + while (!span.IsEmpty) + { + int specialPos = span.IndexOfAny((ReadOnlySpan)new char[] { ',', '\'' }); + if (specialPos == -1) + { + sb.Append(span); + break; + } + bool isComma = span[specialPos] == ','; + if (isComma && !quoting || + !isComma && quoting) + { + sb.Append("'"); + quoting = !quoting; + } + sb.Append(span.Slice(0, specialPos + (isComma ? 1 : 0))); + if (!isComma) + { + sb.Append("\\'"); + } + span = span.Slice(specialPos + 1); + } + + if (quoting) + { + sb.Append("'"); + quoting = false; + } + } + } +} + +public sealed class MatchRule +{ + private MatchRuleData _data; + + internal MatchRuleData Data => _data; + + public MessageType? Type { get => _data.MessageType; set => _data.MessageType = value; } + public string? Sender { get => _data.Sender; set => _data.Sender = value; } + public string? Interface { get => _data.Interface; set => _data.Interface = value; } + public string? Member { get => _data.Member; set => _data.Member = value; } + public string? Path { get => _data.Path; set => _data.Path = value; } + public string? PathNamespace { get => _data.PathNamespace; set => _data.PathNamespace = value; } + public string? Destination { get => _data.Destination; set => _data.Destination = value; } + public string? Arg0 { get => _data.Arg0; set => _data.Arg0 = value; } + public string? Arg0Path { get => _data.Arg0Path; set => _data.Arg0Path = value; } + public string? Arg0Namespace { get => _data.Arg0Namespace; set => _data.Arg0Namespace = value; } + + public override string ToString() => _data.GetRuleString(); +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Message.cs b/src/Linux/Tmds.DBus.Protocol/Message.cs new file mode 100644 index 0000000000..c9c9f2fd2c --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Message.cs @@ -0,0 +1,242 @@ +namespace Tmds.DBus.Protocol; + +public sealed class Message +{ + private const int HeaderFieldsLengthOffset = 12; + + private readonly MessagePool _pool; + private readonly Sequence _data; + + private UnixFdCollection? _handles; + private ReadOnlySequence _body; + + public bool IsBigEndian { get; private set; } + public uint Serial { get; private set; } + public MessageFlags MessageFlags { get; private set; } + public MessageType MessageType { get; private set; } + + public uint? ReplySerial { get; private set; } + public int UnixFdCount { get; private set; } + + private HeaderBuffer _path; + private HeaderBuffer _interface; + private HeaderBuffer _member; + private HeaderBuffer _errorName; + private HeaderBuffer _destination; + private HeaderBuffer _sender; + private HeaderBuffer _signature; + + public string? PathAsString => _path.ToString(); + public string? InterfaceAsString => _interface.ToString(); + public string? MemberAsString => _member.ToString(); + public string? ErrorNameAsString => _errorName.ToString(); + public string? DestinationAsString => _destination.ToString(); + public string? SenderAsString => _sender.ToString(); + public string? SignatureAsString => _signature.ToString(); + + public ReadOnlySpan Path => _path.Span; + public ReadOnlySpan Interface => _interface.Span; + public ReadOnlySpan Member => _member.Span; + public ReadOnlySpan ErrorName => _errorName.Span; + public ReadOnlySpan Destination => _destination.Span; + public ReadOnlySpan Sender => _sender.Span; + public ReadOnlySpan Signature => _signature.Span; + + public bool PathIsSet => _path.IsSet; + public bool InterfaceIsSet => _interface.IsSet; + public bool MemberIsSet => _member.IsSet; + public bool ErrorNameIsSet => _errorName.IsSet; + public bool DestinationIsSet => _destination.IsSet; + public bool SenderIsSet => _sender.IsSet; + public bool SignatureIsSet => _signature.IsSet; + + struct HeaderBuffer + { + private byte[] _buffer; + private int _length; + private string? _string; + + public Span Span => new Span(_buffer, 0, Math.Max(_length, 0)); + + public void Set(ReadOnlySpan data) + { + _string = null; + if (_buffer is null || data.Length > _buffer.Length) + { + _buffer = new byte[data.Length]; + } + data.CopyTo(_buffer); + _length = data.Length; + } + + public void Clear() + { + _length = -1; + _string = null; + } + + public override string? ToString() + { + return _length == -1 ? null : _string ??= Encoding.UTF8.GetString(Span); + } + + public bool IsSet => _length != -1; + } + + public Reader GetBodyReader() => new Reader(IsBigEndian, _body, _handles, UnixFdCount); + + internal Message(MessagePool messagePool, Sequence sequence) + { + _pool = messagePool; + _data = sequence; + ClearHeaders(); + } + + internal void ReturnToPool() + { + _data.Reset(); + ClearHeaders(); + _handles?.DisposeHandles(); + _pool.Return(this); + } + + private void ClearHeaders() + { + ReplySerial = null; + UnixFdCount = 0; + + _path.Clear(); + _interface.Clear(); + _member.Clear(); + _errorName.Clear(); + _destination.Clear(); + _sender.Clear(); + _signature.Clear(); + } + + internal static Message? TryReadMessage(MessagePool messagePool, ref ReadOnlySequence sequence, UnixFdCollection? handles = null, bool isMonitor = false) + { + SequenceReader seqReader = new(sequence); + if (!seqReader.TryRead(out byte endianness) || + !seqReader.TryRead(out byte msgType) || + !seqReader.TryRead(out byte flags) || + !seqReader.TryRead(out byte version)) + { + return null; + } + + if (version != 1) + { + throw new NotSupportedException(); + } + + bool isBigEndian = endianness == 'B'; + + if (!TryReadUInt32(ref seqReader, isBigEndian, out uint bodyLength) || + !TryReadUInt32(ref seqReader, isBigEndian, out uint serial) || + !TryReadUInt32(ref seqReader, isBigEndian, out uint headerFieldLength)) + { + return null; + } + + headerFieldLength = (uint)ProtocolConstants.Align((int)headerFieldLength, DBusType.Struct); + + long totalLength = seqReader.Consumed + headerFieldLength + bodyLength; + + if (sequence.Length < totalLength) + { + return null; + } + + // Copy data so it has a lifetime independent of the source sequence. + var message = messagePool.Rent(); + Sequence dst = message._data; + do + { + ReadOnlySpan srcSpan = sequence.First.Span; + int length = (int)Math.Min(totalLength, srcSpan.Length); + Span dstSpan = dst.GetSpan(0); + length = Math.Min(length, dstSpan.Length); + srcSpan.Slice(0, length).CopyTo(dstSpan); + dst.Advance(length); + sequence = sequence.Slice(length); + totalLength -= length; + } while (totalLength > 0); + + message.IsBigEndian = isBigEndian; + message.Serial = serial; + message.MessageType = (MessageType)msgType; + message.MessageFlags = (MessageFlags)flags; + message.ParseHeader(handles, isMonitor); + + return message; + + static bool TryReadUInt32(ref SequenceReader seqReader, bool isBigEndian, out uint value) + { + int v; + bool rv = (isBigEndian && seqReader.TryReadBigEndian(out v) || seqReader.TryReadLittleEndian(out v)); + value = (uint)v; + return rv; + } + } + + private void ParseHeader(UnixFdCollection? handles, bool isMonitor) + { + var reader = new Reader(IsBigEndian, _data.AsReadOnlySequence); + reader.Advance(HeaderFieldsLengthOffset); + + ArrayEnd headersEnd = reader.ReadArrayStart(DBusType.Struct); + while (reader.HasNext(headersEnd)) + { + MessageHeader hdrType = (MessageHeader)reader.ReadByte(); + ReadOnlySpan sig = reader.ReadSignature(); + switch (hdrType) + { + case MessageHeader.Path: + _path.Set(reader.ReadObjectPathAsSpan()); + break; + case MessageHeader.Interface: + _interface.Set(reader.ReadStringAsSpan()); + break; + case MessageHeader.Member: + _member.Set(reader.ReadStringAsSpan()); + break; + case MessageHeader.ErrorName: + _errorName.Set(reader.ReadStringAsSpan()); + break; + case MessageHeader.ReplySerial: + ReplySerial = reader.ReadUInt32(); + break; + case MessageHeader.Destination: + _destination.Set(reader.ReadStringAsSpan()); + break; + case MessageHeader.Sender: + _sender.Set(reader.ReadStringAsSpan()); + break; + case MessageHeader.Signature: + _signature.Set(reader.ReadSignature()); + break; + case MessageHeader.UnixFds: + UnixFdCount = (int)reader.ReadUInt32(); + if (UnixFdCount > 0 && !isMonitor) + { + if (handles is null || UnixFdCount > handles.Count) + { + throw new ProtocolException("Received less handles than UNIX_FDS."); + } + if (_handles is null) + { + _handles = new(handles.IsRawHandleCollection); + } + handles.MoveTo(_handles, UnixFdCount); + } + break; + default: + throw new NotSupportedException(); + } + } + reader.AlignStruct(); + + _body = reader.UnreadSequence; + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/MessageBuffer.cs b/src/Linux/Tmds.DBus.Protocol/MessageBuffer.cs new file mode 100644 index 0000000000..dc8d5a755e --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageBuffer.cs @@ -0,0 +1,43 @@ +namespace Tmds.DBus.Protocol; + +public sealed class MessageBuffer : IDisposable +{ + private readonly MessageBufferPool _messagePool; + + private readonly Sequence _data; + + internal uint Serial { get; private set; } + + internal MessageFlags MessageFlags { get; private set; } + + internal UnixFdCollection? Handles { get; private set; } + + internal MessageBuffer(MessageBufferPool messagePool, Sequence sequence) + { + _messagePool = messagePool; + _data = sequence; + } + + internal void Init(uint serial, MessageFlags flags, UnixFdCollection? handles) + { + Serial = serial; + MessageFlags = flags; + Handles = handles; + } + + public void Dispose() => ReturnToPool(); + + internal void ReturnToPool() + { + _data.Reset(); + Handles?.DisposeHandles(); + Handles = null; + _messagePool.Return(this); + } + + // For writing data. + internal Sequence Sequence => _data; + + // For reading data. + internal ReadOnlySequence AsReadOnlySequence() => _data.AsReadOnlySequence; +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/MessageBufferPool.cs b/src/Linux/Tmds.DBus.Protocol/MessageBufferPool.cs new file mode 100644 index 0000000000..fd80f977d2 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageBufferPool.cs @@ -0,0 +1,42 @@ +namespace Tmds.DBus.Protocol; + +class MessageBufferPool +{ + private const int MinimumSpanLength = 512; + + public static readonly MessageBufferPool Shared = new MessageBufferPool(Environment.ProcessorCount * 2); + + private readonly int _maxSize; + private readonly Stack _pool = new Stack(); + + internal MessageBufferPool(int maxSize) + { + _maxSize = maxSize; + } + + public MessageBuffer Rent() + { + lock (_pool) + { + if (_pool.Count > 0) + { + return _pool.Pop(); + } + } + + var sequence = new Sequence(ArrayPool.Shared) { MinimumSpanLength = MinimumSpanLength }; + + return new MessageBuffer(this, sequence); + } + + internal void Return(MessageBuffer value) + { + lock (_pool) + { + if (_pool.Count < _maxSize) + { + _pool.Push(value); + } + } + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageFlags.cs b/src/Linux/Tmds.DBus.Protocol/MessageFlags.cs new file mode 100644 index 0000000000..6a597e1238 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageFlags.cs @@ -0,0 +1,10 @@ +namespace Tmds.DBus.Protocol; + +[Flags] +public enum MessageFlags : byte +{ + None = 0, + NoReplyExpected = 1, + NoAutoStart = 2, + AllowInteractiveAuthorization = 4 +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/MessageHeader.cs b/src/Linux/Tmds.DBus.Protocol/MessageHeader.cs new file mode 100644 index 0000000000..3a7b7b011f --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageHeader.cs @@ -0,0 +1,14 @@ +namespace Tmds.DBus.Protocol; + +internal enum MessageHeader : byte +{ + Path = 1, + Interface = 2, + Member = 3, + ErrorName = 4, + ReplySerial = 5, + Destination = 6, + Sender = 7, + Signature = 8, + UnixFds = 9 +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/MessagePool.cs b/src/Linux/Tmds.DBus.Protocol/MessagePool.cs new file mode 100644 index 0000000000..92e82741d9 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessagePool.cs @@ -0,0 +1,27 @@ +namespace Tmds.DBus.Protocol; + +class MessagePool +{ + private const int MinimumSpanLength = 512; + + private Message? _pooled = null; // Pool a single message. + + public Message Rent() + { + Message? rent = Interlocked.Exchange(ref _pooled, null); + + if (rent is not null) + { + return rent; + } + + var sequence = new Sequence(ArrayPool.Shared) { MinimumSpanLength = MinimumSpanLength }; + + return new Message(this, sequence); + } + + internal void Return(Message value) + { + Volatile.Write(ref _pooled, value); + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageStream.cs b/src/Linux/Tmds.DBus.Protocol/MessageStream.cs new file mode 100644 index 0000000000..284f76c4e3 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageStream.cs @@ -0,0 +1,403 @@ +using System.IO.Pipelines; +using System.Net.Sockets; +using System.Threading.Channels; + +namespace Tmds.DBus.Protocol; + +#pragma warning disable VSTHRD100 // Avoid "async void" methods + +class MessageStream : IMessageStream +{ + private static readonly ReadOnlyMemory OneByteArray = new[] { (byte)0 }; + private readonly Socket _socket; + private UnixFdCollection? _fdCollection; + private bool _supportsFdPassing; + private readonly MessagePool _messagePool; + + // Messages going out. + private readonly ChannelReader _messageReader; + private readonly ChannelWriter _messageWriter; + + // Bytes coming in. + private readonly PipeWriter _pipeWriter; + private readonly PipeReader _pipeReader; + + private Exception? _completionException; + private bool _isMonitor; + + public MessageStream(Socket socket) + { + _socket = socket; + Channel channel = Channel.CreateUnbounded(new UnboundedChannelOptions + { + AllowSynchronousContinuations = true, + SingleReader = true, + SingleWriter = false + }); + _messageReader = channel.Reader; + _messageWriter = channel.Writer; + var pipe = new Pipe(new PipeOptions(useSynchronizationContext: false)); + _pipeReader = pipe.Reader; + _pipeWriter = pipe.Writer; + _messagePool = new(); + } + + public void BecomeMonitor() + { + _isMonitor = true; + } + + private async void ReadFromSocketIntoPipe() + { + var writer = _pipeWriter; + Exception? exception = null; + try + { + while (true) + { + Memory memory = writer.GetMemory(1024); + int bytesRead = await _socket.ReceiveAsync(memory, _fdCollection).ConfigureAwait(false); + if (bytesRead == 0) + { + throw new IOException("Connection closed by peer"); + } + writer.Advance(bytesRead); + + await writer.FlushAsync().ConfigureAwait(false); + } + } + catch (Exception e) + { + exception = e; + } + writer.Complete(exception); + } + + private async void ReadMessagesIntoSocket() + { + while (true) + { + if (!await _messageReader.WaitToReadAsync().ConfigureAwait(false)) + { + // No more messages will be coming. + return; + } + var message = await _messageReader.ReadAsync().ConfigureAwait(false); + try + { + IReadOnlyList? handles = _supportsFdPassing ? message.Handles : null; + var buffer = message.AsReadOnlySequence(); + if (buffer.IsSingleSegment) + { + await _socket.SendAsync(buffer.First, handles).ConfigureAwait(false); + } + else + { + SequencePosition position = buffer.Start; + while (buffer.TryGet(ref position, out ReadOnlyMemory memory)) + { + await _socket.SendAsync(memory, handles).ConfigureAwait(false); + handles = null; + } + } + } + catch (Exception exception) + { + Close(exception); + return; + } + finally + { + message.ReturnToPool(); + } + } + } + + public async void ReceiveMessages(IMessageStream.MessageReceivedHandler handler, T state) + { + var reader = _pipeReader; + try + { + while (true) + { + ReadResult result = await reader.ReadAsync().ConfigureAwait(false); + ReadOnlySequence buffer = result.Buffer; + + ReadMessages(ref buffer, handler, state); + + reader.AdvanceTo(buffer.Start, buffer.End); + } + } + catch (Exception exception) + { + exception = CloseCore(exception); + OnException(exception, handler, state); + } + finally + { + _fdCollection?.Dispose(); + } + + void ReadMessages(ref ReadOnlySequence buffer, IMessageStream.MessageReceivedHandler handler, TState state) + { + Message? message; + while ((message = Message.TryReadMessage(_messagePool, ref buffer, _fdCollection, _isMonitor)) != null) + { + handler(closeReason: null, message, state); + } + } + + static void OnException(Exception exception, IMessageStream.MessageReceivedHandler handler, T state) + { + handler(exception, message: null!, state); + } + } + + private struct AuthenticationResult + { + public bool IsAuthenticated; + public bool SupportsFdPassing; + public Guid Guid; + } + + public async ValueTask DoClientAuthAsync(Guid guid, string? userId, bool supportsFdPassing) + { + ReadFromSocketIntoPipe(); + + // send 1 byte + await _socket.SendAsync(OneByteArray, SocketFlags.None).ConfigureAwait(false); + // auth + var authenticationResult = await SendAuthCommandsAsync(userId, supportsFdPassing).ConfigureAwait(false); + _supportsFdPassing = authenticationResult.SupportsFdPassing; + if (_supportsFdPassing) + { + _fdCollection = new(); + } + if (guid != Guid.Empty) + { + if (guid != authenticationResult.Guid) + { + throw new ConnectException("Authentication failure: Unexpected GUID"); + } + } + + ReadMessagesIntoSocket(); + } + + private async ValueTask SendAuthCommandsAsync(string? userId, bool supportsFdPassing) + { + AuthenticationResult result; + if (userId is not null) + { + string command = CreateAuthExternalCommand(userId); + + result = await SendAuthCommandAsync(command, supportsFdPassing).ConfigureAwait(false); + + if (result.IsAuthenticated) + { + return result; + } + } + + result = await SendAuthCommandAsync("AUTH ANONYMOUS\r\n", supportsFdPassing).ConfigureAwait(false); + if (result.IsAuthenticated) + { + return result; + } + + throw new ConnectException("Authentication failure"); + } + + private static string CreateAuthExternalCommand(string userId) + { + const string AuthExternal = "AUTH EXTERNAL "; + const string hexchars = "0123456789abcdef"; +#if NETSTANDARD2_0 + StringBuilder sb = new(); + sb.Append(AuthExternal); + for (int i = 0; i < userId.Length; i++) + { + byte b = (byte)userId[i]; + sb.Append(hexchars[(int)(b >> 4)]); + sb.Append(hexchars[(int)(b & 0xF)]); + } + sb.Append("\r\n"); + return sb.ToString(); +#else + return string.Create( + length: AuthExternal.Length + userId.Length * 2 + 2, userId, + static (Span span, string userId) => + { + AuthExternal.AsSpan().CopyTo(span); + span = span.Slice(AuthExternal.Length); + + for (int i = 0; i < userId.Length; i++) + { + byte b = (byte)userId[i]; + span[i * 2] = hexchars[(int)(b >> 4)]; + span[i * 2 + 1] = hexchars[(int)(b & 0xF)]; + } + span = span.Slice(userId.Length * 2); + + span[0] = '\r'; + span[1] = '\n'; + }); +#endif + } + + private async ValueTask SendAuthCommandAsync(string command, bool supportsFdPassing) + { + byte[] lineBuffer = ArrayPool.Shared.Rent(512); + try + { + AuthenticationResult result = default(AuthenticationResult); + await WriteAsync(command, lineBuffer).ConfigureAwait(false); + int lineLength = await ReadLineAsync(lineBuffer).ConfigureAwait(false); + + if (StartsWithAscii(lineBuffer, lineLength, "OK")) + { + result.IsAuthenticated = true; + result.Guid = ParseGuid(lineBuffer, lineLength); + + if (supportsFdPassing) + { + await WriteAsync("NEGOTIATE_UNIX_FD\r\n", lineBuffer).ConfigureAwait(false); + + lineLength = await ReadLineAsync(lineBuffer).ConfigureAwait(false); + + result.SupportsFdPassing = StartsWithAscii(lineBuffer, lineLength, "AGREE_UNIX_FD"); + } + + await WriteAsync("BEGIN\r\n", lineBuffer).ConfigureAwait(false); + return result; + } + else if (StartsWithAscii(lineBuffer, lineLength, "REJECTED")) + { + return result; + } + else + { + await WriteAsync("ERROR\r\n", lineBuffer).ConfigureAwait(false); + return result; + } + } + finally + { + ArrayPool.Shared.Return(lineBuffer); + } + + static bool StartsWithAscii(byte[] line, int length, string expected) + { + if (length < expected.Length) + { + return false; + } + for (int i = 0; i < expected.Length; i++) + { + if (line[i] != expected[i]) + { + return false; + } + } + return true; + } + + static Guid ParseGuid(byte[] line, int length) + { + Span span = new Span(line, 0, length); + int spaceIndex = span.IndexOf((byte)' '); + if (spaceIndex == -1) + { + return Guid.Empty; + } + span = span.Slice(spaceIndex + 1); + spaceIndex = span.IndexOf((byte)' '); + if (spaceIndex != -1) + { + span = span.Slice(0, spaceIndex); + } + Span charBuffer = stackalloc char[span.Length]; // TODO (low prio): check length + for (int i = 0; i < span.Length; i++) + { + // TODO (low prio): validate char + charBuffer[i] = (char)span[i]; + } +#if NETSTANDARD2_0 + return Guid.ParseExact(charBuffer.AsString(), "N"); +#else + return Guid.ParseExact(charBuffer, "N"); +#endif + } + } + + private async ValueTask WriteAsync(string message, Memory lineBuffer) + { + int length = Encoding.ASCII.GetBytes(message.AsSpan(), lineBuffer.Span); + lineBuffer = lineBuffer.Slice(0, length); + await _socket.SendAsync(lineBuffer, SocketFlags.None).ConfigureAwait(false); + } + + private async ValueTask ReadLineAsync(Memory lineBuffer) + { + var reader = _pipeReader; + while (true) + { + ReadResult result = await reader.ReadAsync().ConfigureAwait(false); + ReadOnlySequence buffer = result.Buffer; + + // TODO (low prio): check length. + + SequencePosition? position = buffer.PositionOf((byte)'\n'); + + if (!position.HasValue) + { + reader.AdvanceTo(buffer.Start, buffer.End); + continue; + } + + int length = CopyBuffer(buffer.Slice(0, position.Value), lineBuffer); + reader.AdvanceTo(buffer.GetPosition(1, position.Value)); + return length; + } + + int CopyBuffer(ReadOnlySequence src, Memory dst) + { + Span span = dst.Span; + src.CopyTo(span); + span = span.Slice(0, (int)src.Length); + if (!span.EndsWith((ReadOnlySpan)new byte[] { (byte)'\r' })) + { + throw new ProtocolException("Authentication messages from server must end with '\\r\\n'."); + } + if (span.Length == 1) + { + throw new ProtocolException("Received empty authentication message from server."); + } + return span.Length - 1; + } + } + + public async ValueTask TrySendMessageAsync(MessageBuffer message) + { + while (await _messageWriter.WaitToWriteAsync().ConfigureAwait(false)) + { + if (_messageWriter.TryWrite(message)) + return true; + } + + return false; + } + + public void Close(Exception closeReason) => CloseCore(closeReason); + + private Exception CloseCore(Exception closeReason) + { + Exception? previous = Interlocked.CompareExchange(ref _completionException, closeReason, null); + if (previous is null) + { + _socket?.Dispose(); + _messageWriter.Complete(); + } + return previous ?? closeReason; + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageType.cs b/src/Linux/Tmds.DBus.Protocol/MessageType.cs new file mode 100644 index 0000000000..ded2857270 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageType.cs @@ -0,0 +1,9 @@ +namespace Tmds.DBus.Protocol; + +public enum MessageType : byte +{ + MethodCall = 1, + MethodReturn = 2, + Error = 3, + Signal = 4 +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/MessageWriter.Array.cs b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Array.cs new file mode 100644 index 0000000000..3739a25793 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Array.cs @@ -0,0 +1,271 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct MessageWriter +{ + public void WriteArray(byte[] value) + => WriteArray(value.AsSpan()); + + public void WriteArray(ReadOnlySpan value) + => WriteArrayOfNumeric(value); + + public void WriteArray(IEnumerable value) + => WriteArrayOfT(value); + + public void WriteArray(short[] value) + => WriteArray(value.AsSpan()); + + public void WriteArray(ReadOnlySpan value) + => WriteArrayOfNumeric(value); + + public void WriteArray(IEnumerable value) + => WriteArrayOfT(value); + + public void WriteArray(ushort[] value) + => WriteArray(value.AsSpan()); + + public void WriteArray(ReadOnlySpan value) + => WriteArrayOfNumeric(value); + + public void WriteArray(IEnumerable value) + => WriteArrayOfT(value); + + public void WriteArray(int[] value) + => WriteArray(value.AsSpan()); + + public void WriteArray(ReadOnlySpan value) + => WriteArrayOfNumeric(value); + + public void WriteArray(IEnumerable value) + => WriteArrayOfT(value); + + public void WriteArray(uint[] value) + => WriteArray(value.AsSpan()); + + public void WriteArray(ReadOnlySpan value) + => WriteArrayOfNumeric(value); + + public void WriteArray(IEnumerable value) + => WriteArrayOfT(value); + + public void WriteArray(long[] value) + => WriteArray(value.AsSpan()); + + public void WriteArray(ReadOnlySpan value) + => WriteArrayOfNumeric(value); + + public void WriteArray(IEnumerable value) + => WriteArrayOfT(value); + + public void WriteArray(ulong[] value) + => WriteArray(value.AsSpan()); + + public void WriteArray(ReadOnlySpan value) + => WriteArrayOfNumeric(value); + + public void WriteArray(IEnumerable value) + => WriteArrayOfT(value); + + public void WriteArray(double[] value) + => WriteArray(value.AsSpan()); + + public void WriteArray(ReadOnlySpan value) + => WriteArrayOfNumeric(value); + + public void WriteArray(IEnumerable value) + => WriteArrayOfT(value); + + public void WriteArray(string[] value) + => WriteArray(value.AsSpan()); + + public void WriteArray(ReadOnlySpan value) + => WriteArrayOfT(value); + + public void WriteArray(IEnumerable value) + => WriteArrayOfT(value); + + public void WriteArray(Signature[] value) + => WriteArray(value.AsSpan()); + + public void WriteArray(ReadOnlySpan value) + => WriteArrayOfT(value); + + public void WriteArray(IEnumerable value) + => WriteArrayOfT(value); + + public void WriteArray(ObjectPath[] value) + => WriteArray(value.AsSpan()); + + public void WriteArray(ReadOnlySpan value) + => WriteArrayOfT(value); + + public void WriteArray(IEnumerable value) + => WriteArrayOfT(value); + + public void WriteArray(Variant[] value) + => WriteArray(value.AsSpan()); + + public void WriteArray(ReadOnlySpan value) + => WriteArrayOfT(value); + + public void WriteArray(IEnumerable value) + => WriteArrayOfT(value); + + public void WriteArray(SafeHandle[] value) + => WriteArray(value.AsSpan()); + + public void WriteArray(ReadOnlySpan value) + => WriteArrayOfT(value); + + public void WriteArray(IEnumerable value) + => WriteArrayOfT(value); + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteArray)] + [Obsolete(Strings.UseNonGenericWriteArrayObsolete)] + public void WriteArray(IEnumerable value) + where T : notnull + { + ArrayStart arrayStart = WriteArrayStart(TypeModel.GetTypeAlignment()); + foreach (var item in value) + { + Write(item); + } + WriteArrayEnd(arrayStart); + } + + internal void WriteArray(ReadOnlySpan value) + where T : notnull + { +#if NET || NETSTANDARD2_1_OR_GREATER + if (typeof(T) == typeof(byte)) + { + ReadOnlySpan span = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As(ref MemoryMarshal.GetReference(value)), value.Length); + WriteArray(span); + } + else if (typeof(T) == typeof(short)) + { + ReadOnlySpan span = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As(ref MemoryMarshal.GetReference(value)), value.Length); + WriteArray(span); + } + else if (typeof(T) == typeof(ushort)) + { + ReadOnlySpan span = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As(ref MemoryMarshal.GetReference(value)), value.Length); + WriteArray(span); + } + else if (typeof(T) == typeof(int)) + { + ReadOnlySpan span = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As(ref MemoryMarshal.GetReference(value)), value.Length); + WriteArray(span); + } + else if (typeof(T) == typeof(uint)) + { + ReadOnlySpan span = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As(ref MemoryMarshal.GetReference(value)), value.Length); + WriteArray(span); + } + else if (typeof(T) == typeof(long)) + { + ReadOnlySpan span = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As(ref MemoryMarshal.GetReference(value)), value.Length); + WriteArray(span); + } + else if (typeof(T) == typeof(ulong)) + { + ReadOnlySpan span = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As(ref MemoryMarshal.GetReference(value)), value.Length); + WriteArray(span); + } + else if (typeof(T) == typeof(double)) + { + ReadOnlySpan span = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As(ref MemoryMarshal.GetReference(value)), value.Length); + WriteArray(span); + } + else +#endif + { + WriteArrayOfT(value); + } + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteArray)] + [Obsolete(Strings.UseNonGenericWriteArrayObsolete)] + public void WriteArray(T[] value) + where T : notnull + { + if (typeof(T) == typeof(byte)) + { + WriteArray((byte[])(object)value); + } + else if (typeof(T) == typeof(short)) + { + WriteArray((short[])(object)value); + } + else if (typeof(T) == typeof(ushort)) + { + WriteArray((ushort[])(object)value); + } + else if (typeof(T) == typeof(int)) + { + WriteArray((int[])(object)value); + } + else if (typeof(T) == typeof(uint)) + { + WriteArray((uint[])(object)value); + } + else if (typeof(T) == typeof(long)) + { + WriteArray((long[])(object)value); + } + else if (typeof(T) == typeof(ulong)) + { + WriteArray((ulong[])(object)value); + } + else if (typeof(T) == typeof(double)) + { + WriteArray((double[])(object)value); + } + else + { + WriteArrayOfT(value.AsSpan()); + } + } + + private unsafe void WriteArrayOfNumeric(ReadOnlySpan value) where T : unmanaged + { + WriteInt32(value.Length * sizeof(T)); + if (sizeof(T) > 4) + { + WritePadding(sizeof(T)); + } + WriteRaw(MemoryMarshal.AsBytes(value)); + } + + private void WriteArrayOfT(ReadOnlySpan value) + where T : notnull + { + ArrayStart arrayStart = WriteArrayStart(TypeModel.GetTypeAlignment()); + foreach (var item in value) + { + Write(item); + } + WriteArrayEnd(arrayStart); + } + + private void WriteArrayOfT(IEnumerable value) + where T : notnull + { + if (value is T[] array) + { + WriteArrayOfT(array.AsSpan()); + return; + } + ArrayStart arrayStart = WriteArrayStart(TypeModel.GetTypeAlignment()); + foreach (var item in value) + { + Write(item); + } + WriteArrayEnd(arrayStart); + } + + private static void WriteArraySignature(ref MessageWriter writer) where T : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + writer.WriteSignature(TypeModel.GetSignature>(buffer)); + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageWriter.Basic.cs b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Basic.cs new file mode 100644 index 0000000000..2a6598ae3f --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Basic.cs @@ -0,0 +1,234 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct MessageWriter +{ + private const int MaxSizeHint = 4096; + + public void WriteBool(bool value) => WriteUInt32(value ? 1u : 0u); + + public void WriteByte(byte value) => WritePrimitiveCore(value, DBusType.Byte); + + public void WriteInt16(short value) => WritePrimitiveCore(value, DBusType.Int16); + + public void WriteUInt16(ushort value) => WritePrimitiveCore(value, DBusType.UInt16); + + public void WriteInt32(int value) => WritePrimitiveCore(value, DBusType.Int32); + + public void WriteUInt32(uint value) => WritePrimitiveCore(value, DBusType.UInt32); + + public void WriteInt64(long value) => WritePrimitiveCore(value, DBusType.Int64); + + public void WriteUInt64(ulong value) => WritePrimitiveCore(value, DBusType.UInt64); + + public void WriteDouble(double value) => WritePrimitiveCore(value, DBusType.Double); + + public void WriteString(Utf8Span value) => WriteStringCore(value); + + public void WriteString(string value) => WriteStringCore(value); + + public void WriteSignature(Utf8Span value) + { + ReadOnlySpan span = value; + int length = span.Length; + WriteByte((byte)length); + var dst = GetSpan(length); + span.CopyTo(dst); + Advance(length); + WriteByte((byte)0); + } + + public void WriteSignature(string s) + { + Span lengthSpan = GetSpan(1); + Advance(1); + int bytesWritten = WriteRaw(s); + lengthSpan[0] = (byte)bytesWritten; + WriteByte(0); + } + + public void WriteObjectPath(Utf8Span value) => WriteStringCore(value); + + public void WriteObjectPath(string value) => WriteStringCore(value); + + public void WriteVariantBool(bool value) + { + WriteSignature(ProtocolConstants.BooleanSignature); + WriteBool(value); + } + + public void WriteVariantByte(byte value) + { + WriteSignature(ProtocolConstants.ByteSignature); + WriteByte(value); + } + + public void WriteVariantInt16(short value) + { + WriteSignature(ProtocolConstants.Int16Signature); + WriteInt16(value); + } + + public void WriteVariantUInt16(ushort value) + { + WriteSignature(ProtocolConstants.UInt16Signature); + WriteUInt16(value); + } + + public void WriteVariantInt32(int value) + { + WriteSignature(ProtocolConstants.Int32Signature); + WriteInt32(value); + } + + public void WriteVariantUInt32(uint value) + { + WriteSignature(ProtocolConstants.UInt32Signature); + WriteUInt32(value); + } + + public void WriteVariantInt64(long value) + { + WriteSignature(ProtocolConstants.Int64Signature); + WriteInt64(value); + } + + public void WriteVariantUInt64(ulong value) + { + WriteSignature(ProtocolConstants.UInt64Signature); + WriteUInt64(value); + } + + public void WriteVariantDouble(double value) + { + WriteSignature(ProtocolConstants.DoubleSignature); + WriteDouble(value); + } + + public void WriteVariantString(Utf8Span value) + { + WriteSignature(ProtocolConstants.StringSignature); + WriteString(value); + } + + public void WriteVariantSignature(Utf8Span value) + { + WriteSignature(ProtocolConstants.SignatureSignature); + WriteSignature(value); + } + + public void WriteVariantObjectPath(Utf8Span value) + { + WriteSignature(ProtocolConstants.ObjectPathSignature); + WriteObjectPath(value); + } + + public void WriteVariantString(string value) + { + WriteSignature(ProtocolConstants.StringSignature); + WriteString(value); + } + + public void WriteVariantSignature(string value) + { + WriteSignature(ProtocolConstants.SignatureSignature); + WriteSignature(value); + } + + public void WriteVariantObjectPath(string value) + { + WriteSignature(ProtocolConstants.ObjectPathSignature); + WriteObjectPath(value); + } + + private void WriteStringCore(ReadOnlySpan span) + { + int length = span.Length; + WriteUInt32((uint)length); + var dst = GetSpan(length); + span.CopyTo(dst); + Advance(length); + WriteByte((byte)0); + } + + private void WriteStringCore(string s) + { + WritePadding(DBusType.UInt32); + Span lengthSpan = GetSpan(4); + Advance(4); + int bytesWritten = WriteRaw(s); + Unsafe.WriteUnaligned(ref MemoryMarshal.GetReference(lengthSpan), (uint)bytesWritten); + WriteByte(0); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void WritePrimitiveCore(T value, DBusType type) + { + WritePadding(type); + int length = ProtocolConstants.GetFixedTypeLength(type); + var span = GetSpan(length); + Unsafe.WriteUnaligned(ref MemoryMarshal.GetReference(span), value); + Advance(length); + } + + private int WriteRaw(ReadOnlySpan data) + { + int totalLength = data.Length; + if (totalLength <= MaxSizeHint) + { + var dst = GetSpan(totalLength); + data.CopyTo(dst); + Advance(totalLength); + return totalLength; + } + else + { + while (!data.IsEmpty) + { + var dst = GetSpan(1); + int length = Math.Min(data.Length, dst.Length); + data.Slice(0, length).CopyTo(dst); + Advance(length); + data = data.Slice(length); + } + return totalLength; + } + } + + private int WriteRaw(string data) + { + const int MaxUtf8BytesPerChar = 3; + + if (data.Length <= MaxSizeHint / MaxUtf8BytesPerChar) + { + ReadOnlySpan chars = data.AsSpan(); + int byteCount = Encoding.UTF8.GetByteCount(chars); + var dst = GetSpan(byteCount); + byteCount = Encoding.UTF8.GetBytes(data.AsSpan(), dst); + Advance(byteCount); + return byteCount; + } + else + { + ReadOnlySpan chars = data.AsSpan(); + Encoder encoder = Encoding.UTF8.GetEncoder(); + int totalLength = 0; + do + { + Debug.Assert(!chars.IsEmpty); + + var dst = GetSpan(MaxUtf8BytesPerChar); + encoder.Convert(chars, dst, flush: true, out int charsUsed, out int bytesUsed, out bool completed); + + Advance(bytesUsed); + totalLength += bytesUsed; + + if (completed) + { + return totalLength; + } + + chars = chars.Slice(charsUsed); + } while (true); + } + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageWriter.Dictionary.cs b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Dictionary.cs new file mode 100644 index 0000000000..34ca420c95 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Dictionary.cs @@ -0,0 +1,128 @@ +namespace Tmds.DBus.Protocol; + +// Using obsolete generic write members +#pragma warning disable CS0618 + +public ref partial struct MessageWriter +{ + public ArrayStart WriteDictionaryStart() + => WriteArrayStart(DBusType.Struct); + + public void WriteDictionaryEnd(ArrayStart start) + => WriteArrayEnd(start); + + public void WriteDictionaryEntryStart() + => WriteStructureStart(); + + // Write method for the common 'a{sv}' type. + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // It's safe to call WriteDictionary with these types. + public void WriteDictionary(IEnumerable> value) + => WriteDictionary(value); + + // Write method for the common 'a{sv}' type. + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // It's safe to call WriteDictionary with these types. + public void WriteDictionary(KeyValuePair[] value) + => WriteDictionary(value); + + // Write method for the common 'a{sv}' type. + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // It's safe to call WriteDictionary with these types. + public void WriteDictionary(Dictionary value) + => WriteDictionary(value); + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteDictionary)] + [Obsolete(Strings.UseNonGenericWriteDictionaryObsolete)] + public void WriteDictionary(IEnumerable> value) + where TKey : notnull + where TValue : notnull + { + ArrayStart arrayStart = WriteDictionaryStart(); + foreach (var item in value) + { + WriteDictionaryEntryStart(); + Write(item.Key); + Write(item.Value); + } + WriteDictionaryEnd(arrayStart); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteDictionary)] + [Obsolete(Strings.UseNonGenericWriteDictionaryObsolete)] + public void WriteDictionary(KeyValuePair[] value) + where TKey : notnull + where TValue : notnull + { + ArrayStart arrayStart = WriteDictionaryStart(); + foreach (var item in value) + { + WriteDictionaryEntryStart(); + Write(item.Key); + Write(item.Value); + } + WriteDictionaryEnd(arrayStart); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteDictionary)] + [Obsolete(Strings.UseNonGenericWriteDictionaryObsolete)] + public void WriteDictionary(Dictionary value) + where TKey : notnull + where TValue : notnull + { + ArrayStart arrayStart = WriteDictionaryStart(); + foreach (var item in value) + { + WriteDictionaryEntryStart(); + Write(item.Key); + Write(item.Value); + } + WriteDictionaryEnd(arrayStart); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteVariantDictionary)] + [Obsolete(Strings.UseNonGenericWriteVariantDictionaryObsolete)] + public void WriteVariantDictionary(IEnumerable> value) + where TKey : notnull + where TValue : notnull + { + WriteDictionarySignature(ref this); + WriteDictionary(value); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteVariantDictionary)] + [Obsolete(Strings.UseNonGenericWriteVariantDictionaryObsolete)] + public void WriteVariantDictionary(KeyValuePair[] value) + where TKey : notnull + where TValue : notnull + { + WriteDictionarySignature(ref this); + WriteDictionary(value); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteVariantDictionary)] + [Obsolete(Strings.UseNonGenericWriteVariantDictionaryObsolete)] + public void WriteVariantDictionary(Dictionary value) + where TKey : notnull + where TValue : notnull + { + WriteDictionarySignature(ref this); + WriteDictionary(value); + } + + // This method writes a Dictionary without using generics at the 'cost' of boxing. + // private void WriteDictionary(IDictionary value) + // { + // ArrayStart arrayStart = WriteDictionaryStart(); + // foreach (System.Collections.DictionaryEntry de in dictionary) + // { + // WriteDictionaryEntryStart(); + // Write(de.Key, asVariant: keyType == typeof(object)); + // Write(de.Value, asVariant: valueType == typeof(object)); + // } + // WriteDictionaryEnd(ref arrayStart); + // } + + private static void WriteDictionarySignature(ref MessageWriter writer) where TKey : notnull where TValue : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + writer.WriteSignature(TypeModel.GetSignature>(buffer)); + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageWriter.Handle.cs b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Handle.cs new file mode 100644 index 0000000000..a473feadc5 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Handle.cs @@ -0,0 +1,28 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct MessageWriter +{ + public void WriteHandle(SafeHandle value) + { + int idx = HandleCount; + AddHandle(value); + WriteInt32(idx); + } + + public void WriteVariantHandle(SafeHandle value) + { + WriteSignature(ProtocolConstants.UnixFdSignature); + WriteHandle(value); + } + + private int HandleCount => _handles?.Count ?? 0; + + private void AddHandle(SafeHandle handle) + { + if (_handles is null) + { + _handles = new(isRawHandleCollection: false); + } + _handles.AddHandle(handle); + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageWriter.Header.cs b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Header.cs new file mode 100644 index 0000000000..037588c95f --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Header.cs @@ -0,0 +1,215 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct MessageWriter +{ + public void WriteMethodCallHeader( + string? destination = null, + string? path = null, + string? @interface = null, + string? member = null, + string? signature = null, + MessageFlags flags = MessageFlags.None) + { + ArrayStart start = WriteHeaderStart(MessageType.MethodCall, flags); + + // Path. + if (path is not null) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Path); + WriteVariantObjectPath(path); + } + + // Interface. + if (@interface is not null) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Interface); + WriteVariantString(@interface); + } + + // Member. + if (member is not null) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Member); + WriteVariantString(member); + } + + // Destination. + if (destination is not null) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Destination); + WriteVariantString(destination); + } + + // Signature. + if (signature is not null) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Signature); + WriteVariantSignature(signature); + } + + WriteHeaderEnd(start); + } + + public void WriteMethodReturnHeader( + uint replySerial, + Utf8Span destination = default, + string? signature = null) + { + ArrayStart start = WriteHeaderStart(MessageType.MethodReturn, MessageFlags.None); + + // ReplySerial + WriteStructureStart(); + WriteByte((byte)MessageHeader.ReplySerial); + WriteVariantUInt32(replySerial); + + // Destination. + if (!destination.IsEmpty) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Destination); + WriteVariantString(destination); + } + + // Signature. + if (signature is not null) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Signature); + WriteVariantSignature(signature); + } + + WriteHeaderEnd(start); + } + + public void WriteError( + uint replySerial, + ReadOnlySpan destination = default, + string? errorName = null, + string? errorMsg = null) + { + ArrayStart start = WriteHeaderStart(MessageType.Error, MessageFlags.None); + + // ReplySerial + WriteStructureStart(); + WriteByte((byte)MessageHeader.ReplySerial); + WriteVariantUInt32(replySerial); + + // Destination. + if (!destination.IsEmpty) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Destination); + WriteVariantString(destination); + } + + // Error. + if (errorName is not null) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.ErrorName); + WriteVariantString(errorName); + } + + // Signature. + if (errorMsg is not null) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Signature); + WriteVariantSignature(ProtocolConstants.StringSignature); + } + + WriteHeaderEnd(start); + + if (errorMsg is not null) + { + WriteString(errorMsg); + } + } + + public void WriteSignalHeader( + string? destination = null, + string? path = null, + string? @interface = null, + string? member = null, + string? signature = null) + { + ArrayStart start = WriteHeaderStart(MessageType.Signal, MessageFlags.None); + + // Path. + if (path is not null) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Path); + WriteVariantObjectPath(path); + } + + // Interface. + if (@interface is not null) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Interface); + WriteVariantString(@interface); + } + + // Member. + if (member is not null) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Member); + WriteVariantString(member); + } + + // Destination. + if (destination is not null) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Destination); + WriteVariantString(destination); + } + + // Signature. + if (signature is not null) + { + WriteStructureStart(); + WriteByte((byte)MessageHeader.Signature); + WriteVariantSignature(signature); + } + + WriteHeaderEnd(start); + } + + private void WriteHeaderEnd(ArrayStart start) + { + WriteArrayEnd(start); + WritePadding(DBusType.Struct); + } + + private ArrayStart WriteHeaderStart(MessageType type, MessageFlags flags) + { + _flags = flags; + + WriteByte(BitConverter.IsLittleEndian ? (byte)'l' : (byte)'B'); // endianness + WriteByte((byte)type); + WriteByte((byte)flags); + WriteByte((byte)1); // version + WriteUInt32((uint)0); // length placeholder + Debug.Assert(_offset == LengthOffset + 4); + WriteUInt32(_serial); + Debug.Assert(_offset == SerialOffset + 4); + + // headers + ArrayStart start = WriteArrayStart(DBusType.Struct); + + // UnixFds + WriteStructureStart(); + WriteByte((byte)MessageHeader.UnixFds); + WriteVariantUInt32(0); // unix fd length placeholder + Debug.Assert(_offset == UnixFdLengthOffset + 4); + return start; + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageWriter.IntrospectionXml.cs b/src/Linux/Tmds.DBus.Protocol/MessageWriter.IntrospectionXml.cs new file mode 100644 index 0000000000..d343592e2f --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageWriter.IntrospectionXml.cs @@ -0,0 +1,75 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct MessageWriter +{ + private ReadOnlySpan IntrospectionHeader => + """ + + + + """u8; + + private ReadOnlySpan IntrospectionFooter => + """ + + + """u8; + + private ReadOnlySpan NodeNameStart => + """ + NodeNameEnd => + """ + "/> + + """u8; + + public void WriteIntrospectionXml(scoped ReadOnlySpan> interfaceXmls, IEnumerable childNames) + => WriteIntrospectionXml(interfaceXmls, baseInterfaceXmls: default, childNames: default, + childNamesEnumerable: childNames ?? throw new ArgumentNullException(nameof(childNames))); + + internal void WriteIntrospectionXml( + scoped ReadOnlySpan> interfaceXmls, + scoped ReadOnlySpan> baseInterfaceXmls, + scoped ReadOnlySpan childNames, + IEnumerable? childNamesEnumerable) + { + WritePadding(DBusType.UInt32); + Span lengthSpan = GetSpan(4); + Advance(4); + + int bytesWritten = 0; + bytesWritten += WriteRaw(IntrospectionHeader); + foreach (var xml in interfaceXmls) + { + bytesWritten += WriteRaw(xml.Span); + } + foreach (var xml in baseInterfaceXmls) + { + bytesWritten += WriteRaw(xml.Span); + } + // D-Bus names only consist of '[A-Z][a-z][0-9]_'. + // We don't need to any escaping for use as an XML attribute value. + foreach (var childName in childNames) + { + bytesWritten += WriteRaw(NodeNameStart); + bytesWritten += WriteRaw(childName); + bytesWritten += WriteRaw(NodeNameEnd); + } + if (childNamesEnumerable is not null) + { + foreach (var childName in childNamesEnumerable) + { + bytesWritten += WriteRaw(NodeNameStart); + bytesWritten += WriteRaw(childName); + bytesWritten += WriteRaw(NodeNameEnd); + } + } + bytesWritten += WriteRaw(IntrospectionFooter); + + Unsafe.WriteUnaligned(ref MemoryMarshal.GetReference(lengthSpan), (uint)bytesWritten); + WriteByte(0); + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageWriter.Struct.cs b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Struct.cs new file mode 100644 index 0000000000..436fda1743 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Struct.cs @@ -0,0 +1,299 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct MessageWriter +{ + [RequiresUnreferencedCode(Strings.UseNonGenericWriteStruct)] + [Obsolete(Strings.UseNonGenericWriteStructObsolete)] + public void WriteStruct(ValueTuple value) + where T1 : notnull + { + WriteStructureStart(); + Write(value.Item1); + } + + private static void WriteStructSignature(ref MessageWriter writer) + where T1 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + writer.WriteSignature(TypeModel.GetSignature>(buffer)); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteStruct)] + [Obsolete(Strings.UseNonGenericWriteStructObsolete)] + public void WriteStruct((T1, T2) value) + where T1 : notnull + where T2 : notnull + { + WriteStructureStart(); + Write(value.Item1); + Write(value.Item2); + } + + private static void WriteStructSignature(ref MessageWriter writer) + where T1 : notnull + where T2 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + writer.WriteSignature(TypeModel.GetSignature>(buffer)); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteStruct)] + [Obsolete(Strings.UseNonGenericWriteStructObsolete)] + public void WriteStruct((T1, T2, T3) value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + { + WriteStructureStart(); + Write(value.Item1); + Write(value.Item2); + Write(value.Item3); + } + + private static void WriteStructSignature(ref MessageWriter writer) + where T1 : notnull + where T2 : notnull + where T3 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + writer.WriteSignature(TypeModel.GetSignature>(buffer)); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteStruct)] + [Obsolete(Strings.UseNonGenericWriteStructObsolete)] + public void WriteStruct((T1, T2, T3, T4) value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + { + WriteStructureStart(); + Write(value.Item1); + Write(value.Item2); + Write(value.Item3); + Write(value.Item4); + } + + private static void WriteStructSignature(ref MessageWriter writer) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + writer.WriteSignature(TypeModel.GetSignature>(buffer)); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteStruct)] + [Obsolete(Strings.UseNonGenericWriteStructObsolete)] + public void WriteStruct((T1, T2, T3, T4, T5) value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + { + WriteStructureStart(); + Write(value.Item1); + Write(value.Item2); + Write(value.Item3); + Write(value.Item4); + Write(value.Item5); + } + + private static void WriteStructSignature(ref MessageWriter writer) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + writer.WriteSignature(TypeModel.GetSignature>(buffer)); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteStruct)] + [Obsolete(Strings.UseNonGenericWriteStructObsolete)] + public void WriteStruct((T1, T2, T3, T4, T5, T6) value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + { + WriteStructureStart(); + Write(value.Item1); + Write(value.Item2); + Write(value.Item3); + Write(value.Item4); + Write(value.Item5); + Write(value.Item6); + } + + private static void WriteStructSignature(ref MessageWriter writer) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + writer.WriteSignature(TypeModel.GetSignature>(buffer)); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteStruct)] + [Obsolete(Strings.UseNonGenericWriteStructObsolete)] + public void WriteStruct((T1, T2, T3, T4, T5, T6, T7) value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + { + WriteStructureStart(); + Write(value.Item1); + Write(value.Item2); + Write(value.Item3); + Write(value.Item4); + Write(value.Item5); + Write(value.Item6); + Write(value.Item7); + } + + private static void WriteStructSignature(ref MessageWriter writer) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + writer.WriteSignature(TypeModel.GetSignature>(buffer)); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteStruct)] + [Obsolete(Strings.UseNonGenericWriteStructObsolete)] + public void WriteStruct((T1, T2, T3, T4, T5, T6, T7, T8) value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + { + WriteStructureStart(); + Write(value.Item1); + Write(value.Item2); + Write(value.Item3); + Write(value.Item4); + Write(value.Item5); + Write(value.Item6); + Write(value.Item7); + Write(value.Rest.Item1); + } + + private static void WriteStructSignature(ref MessageWriter writer) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + writer.WriteSignature(TypeModel.GetSignature>(buffer)); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteStruct)] + [Obsolete(Strings.UseNonGenericWriteStructObsolete)] + public void WriteStruct((T1, T2, T3, T4, T5, T6, T7, T8, T9) value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + { + WriteStructureStart(); + Write(value.Item1); + Write(value.Item2); + Write(value.Item3); + Write(value.Item4); + Write(value.Item5); + Write(value.Item6); + Write(value.Item7); + Write(value.Rest.Item1); + Write(value.Rest.Item2); + } + + private static void WriteStructSignature(ref MessageWriter writer) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + writer.WriteSignature(TypeModel.GetSignature>(buffer)); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericWriteStruct)] + [Obsolete(Strings.UseNonGenericWriteStructObsolete)] + public void WriteStruct((T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + where T10 : notnull + { + WriteStructureStart(); + Write(value.Item1); + Write(value.Item2); + Write(value.Item3); + Write(value.Item4); + Write(value.Item5); + Write(value.Item6); + Write(value.Item7); + Write(value.Rest.Item1); + Write(value.Rest.Item2); + Write(value.Rest.Item3); + } + + private static void WriteStructSignature(ref MessageWriter writer) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + where T10 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + writer.WriteSignature(TypeModel.GetSignature>(buffer)); + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageWriter.Variant.Dynamic.cs b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Variant.Dynamic.cs new file mode 100644 index 0000000000..c7fa3ec6be --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Variant.Dynamic.cs @@ -0,0 +1,77 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct MessageWriter +{ + [RequiresUnreferencedCode(Strings.UseNonObjectWriteVariant)] + [Obsolete(Strings.UseNonObjectWriteVariantObsolete)] + public void WriteVariant(object value) + { + Type type = value.GetType(); + + if (type == typeof(byte)) + { + WriteVariantByte((byte)value); + return; + } + else if (type == typeof(bool)) + { + WriteVariantBool((bool)value); + return; + } + else if (type == typeof(short)) + { + WriteVariantInt16((short)value); + return; + } + else if (type == typeof(ushort)) + { + WriteVariantUInt16((ushort)value); + return; + } + else if (type == typeof(int)) + { + WriteVariantInt32((int)value); + return; + } + else if (type == typeof(uint)) + { + WriteVariantUInt32((uint)value); + return; + } + else if (type == typeof(long)) + { + WriteVariantInt64((long)value); + return; + } + else if (type == typeof(ulong)) + { + WriteVariantUInt64((ulong)value); + return; + } + else if (type == typeof(double)) + { + WriteVariantDouble((double)value); + return; + } + else if (type == typeof(string)) + { + WriteVariantString((string)value); + return; + } + else if (type == typeof(ObjectPath)) + { + WriteVariantObjectPath(value.ToString()!); + return; + } + else if (type == typeof(Signature)) + { + WriteVariantSignature(value.ToString()!); + return; + } + else + { + var typeWriter = TypeWriters.GetTypeWriter(type); + typeWriter.WriteVariant(ref this, value); + } + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageWriter.Variant.cs b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Variant.cs new file mode 100644 index 0000000000..99fecbe884 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageWriter.Variant.cs @@ -0,0 +1,9 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct MessageWriter +{ + public void WriteVariant(Variant value) + { + value.WriteTo(ref this); + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageWriter.WriteT.Dynamic.cs b/src/Linux/Tmds.DBus.Protocol/MessageWriter.WriteT.Dynamic.cs new file mode 100644 index 0000000000..d6c5113eaa --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageWriter.WriteT.Dynamic.cs @@ -0,0 +1,994 @@ +namespace Tmds.DBus.Protocol; + +// Code in this file is not trimmer friendly. +#pragma warning disable IL3050 +#pragma warning disable IL2026 +// Using obsolete generic write members +#pragma warning disable CS0618 + +public ref partial struct MessageWriter +{ + interface ITypeWriter + { + void WriteVariant(ref MessageWriter writer, object value); + } + + interface ITypeWriter : ITypeWriter + { + void Write(ref MessageWriter writer, T value); + } + + private void WriteDynamic(T value) where T : notnull + { + if (typeof(T) == typeof(object)) + { + WriteVariant((object)value); + return; + } + + var typeWriter = (ITypeWriter)TypeWriters.GetTypeWriter(typeof(T)); + typeWriter.Write(ref this, value); + } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddArrayTypeWriter() + where T : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddDictionaryTypeWriter() + where TKey : notnull + where TValue : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddValueTupleTypeWriter() + where T1 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddTupleTypeWriter() + where T1 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddValueTupleTypeWriter() + where T1 : notnull + where T2 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddTupleTypeWriter() + where T1 : notnull + where T2 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddValueTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddValueTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddValueTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddValueTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddValueTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddValueTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddValueTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddValueTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + where T10 : notnull + { } + + [Obsolete(Strings.AddTypeWriterMethodObsolete)] + public static void AddTupleTypeWriter() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + where T10 : notnull + { } + + static class TypeWriters + { + private static readonly Dictionary _typeWriters = new(); + + public static ITypeWriter GetTypeWriter(Type type) + { + lock (_typeWriters) + { + if (_typeWriters.TryGetValue(type, out ITypeWriter? writer)) + { + return writer; + } + writer = CreateWriterForType(type); + _typeWriters.Add(type, writer); + return writer; + } + } + + private static ITypeWriter CreateWriterForType(Type type) + { + // Struct (ValueTuple) + if (type.IsGenericType && type.FullName!.StartsWith("System.ValueTuple")) + { + switch (type.GenericTypeArguments.Length) + { + case 1: return CreateValueTupleTypeWriter(type.GenericTypeArguments[0]); + case 2: + return CreateValueTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1]); + case 3: + return CreateValueTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2]); + case 4: + return CreateValueTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3]); + case 5: + return CreateValueTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4]); + + case 6: + return CreateValueTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5]); + case 7: + return CreateValueTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6]); + case 8: + switch (type.GenericTypeArguments[7].GenericTypeArguments.Length) + { + case 1: + return CreateValueTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6], + type.GenericTypeArguments[7].GenericTypeArguments[0]); + case 2: + return CreateValueTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6], + type.GenericTypeArguments[7].GenericTypeArguments[0], + type.GenericTypeArguments[7].GenericTypeArguments[1]); + case 3: + return CreateValueTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6], + type.GenericTypeArguments[7].GenericTypeArguments[0], + type.GenericTypeArguments[7].GenericTypeArguments[1], + type.GenericTypeArguments[7].GenericTypeArguments[2]); + } + break; + } + } + // Struct (ValueTuple) + if (type.IsGenericType && type.FullName!.StartsWith("System.Tuple")) + { + switch (type.GenericTypeArguments.Length) + { + case 1: return CreateTupleTypeWriter(type.GenericTypeArguments[0]); + case 2: + return CreateTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1]); + case 3: + return CreateTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2]); + case 4: + return CreateTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3]); + case 5: + return CreateTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4]); + case 6: + return CreateTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5]); + case 7: + return CreateTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6]); + case 8: + switch (type.GenericTypeArguments[7].GenericTypeArguments.Length) + { + case 1: + return CreateTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6], + type.GenericTypeArguments[7].GenericTypeArguments[0]); + case 2: + return CreateTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6], + type.GenericTypeArguments[7].GenericTypeArguments[0], + type.GenericTypeArguments[7].GenericTypeArguments[1]); + case 3: + return CreateTupleTypeWriter(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6], + type.GenericTypeArguments[7].GenericTypeArguments[0], + type.GenericTypeArguments[7].GenericTypeArguments[1], + type.GenericTypeArguments[7].GenericTypeArguments[2]); + } + break; + } + } + + // Array/Dictionary type (IEnumerable<>/IEnumerable>) + Type? extractedType = TypeModel.ExtractGenericInterface(type, typeof(IEnumerable<>)); + if (extractedType != null) + { + if (_typeWriters.TryGetValue(extractedType, out ITypeWriter? writer)) + { + return writer; + } + + Type elementType = extractedType.GenericTypeArguments[0]; + if (elementType.IsGenericType && elementType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>)) + { + Type keyType = elementType.GenericTypeArguments[0]; + Type valueType = elementType.GenericTypeArguments[1]; + writer = CreateDictionaryTypeWriter(keyType, valueType); + } + else + { + writer = CreateArrayTypeWriter(elementType); + } + + if (type != extractedType) + { + _typeWriters.Add(extractedType, writer); + } + + return writer; + } + + ThrowNotSupportedType(type); + return default!; + } + + sealed class ArrayTypeWriter : ITypeWriter> + where T : notnull + { + public void Write(ref MessageWriter writer, IEnumerable value) + { + writer.WriteArray(value); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteArraySignature(ref writer); + writer.WriteArray((IEnumerable)value); + } + } + + private static ITypeWriter CreateArrayTypeWriter(Type elementType) + { + Type writerType = typeof(ArrayTypeWriter<>).MakeGenericType(new[] { elementType }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + sealed class DictionaryTypeWriter : ITypeWriter>> + where TKey : notnull + where TValue : notnull + { + public void Write(ref MessageWriter writer, IEnumerable> value) + { + writer.WriteDictionary(value); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteDictionarySignature(ref writer); + writer.WriteDictionary((IEnumerable>)value); + } + } + + private static ITypeWriter CreateDictionaryTypeWriter(Type keyType, Type valueType) + { + Type writerType = typeof(DictionaryTypeWriter<,>).MakeGenericType(new[] { keyType, valueType }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + sealed class ValueTupleTypeWriter : ITypeWriter> + where T1 : notnull + { + public void Write(ref MessageWriter writer, ValueTuple value) + { + writer.WriteStruct(new ValueTuple(value.Item1)); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (ValueTuple)value); + } + } + + sealed class TupleTypeWriter : ITypeWriter> + where T1 : notnull + { + public void Write(ref MessageWriter writer, Tuple value) + { + writer.WriteStruct(new ValueTuple(value.Item1)); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (Tuple)value); + } + } + + private static ITypeWriter CreateValueTupleTypeWriter(Type type1) + { + Type writerType = typeof(ValueTupleTypeWriter<>).MakeGenericType(new[] { type1 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + private static ITypeWriter CreateTupleTypeWriter(Type type1) + { + Type writerType = typeof(TupleTypeWriter<>).MakeGenericType(new[] { type1 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + sealed class ValueTupleTypeWriter : ITypeWriter> + where T1 : notnull + where T2 : notnull + { + public void Write(ref MessageWriter writer, ValueTuple value) + { + writer.WriteStruct(value); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (ValueTuple)value); + } + } + + sealed class TupleTypeWriter : ITypeWriter> + where T1 : notnull + where T2 : notnull + { + public void Write(ref MessageWriter writer, Tuple value) + { + writer.WriteStruct((value.Item1, value.Item2)); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (Tuple)value); + } + } + + private static ITypeWriter CreateValueTupleTypeWriter(Type type1, Type type2) + { + Type writerType = typeof(ValueTupleTypeWriter<,>).MakeGenericType(new[] { type1, type2 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + private static ITypeWriter CreateTupleTypeWriter(Type type1, Type type2) + { + Type writerType = typeof(TupleTypeWriter<,>).MakeGenericType(new[] { type1, type2 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + sealed class ValueTupleTypeWriter : ITypeWriter> + where T1 : notnull + where T2 : notnull + where T3 : notnull + { + public void Write(ref MessageWriter writer, ValueTuple value) + { + writer.WriteStruct(value); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (ValueTuple)value); + } + } + + sealed class TupleTypeWriter : ITypeWriter> + where T1 : notnull + where T2 : notnull + where T3 : notnull + { + public void Write(ref MessageWriter writer, Tuple value) + { + writer.WriteStruct((value.Item1, value.Item2, value.Item3)); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (Tuple)value); + } + } + + private static ITypeWriter CreateValueTupleTypeWriter(Type type1, Type type2, Type type3) + { + Type writerType = typeof(ValueTupleTypeWriter<,,>).MakeGenericType(new[] { type1, type2, type3 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + private static ITypeWriter CreateTupleTypeWriter(Type type1, Type type2, Type type3) + { + Type writerType = typeof(TupleTypeWriter<,,>).MakeGenericType(new[] { type1, type2, type3 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + sealed class ValueTupleTypeWriter : ITypeWriter> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + { + public void Write(ref MessageWriter writer, ValueTuple value) + { + writer.WriteStruct(value); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (ValueTuple)value); + } + } + + sealed class TupleTypeWriter : ITypeWriter> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + { + public void Write(ref MessageWriter writer, Tuple value) + { + writer.WriteStruct((value.Item1, value.Item2, value.Item3, value.Item4)); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (Tuple)value); + } + } + + private static ITypeWriter CreateValueTupleTypeWriter(Type type1, Type type2, Type type3, Type type4) + { + Type writerType = typeof(ValueTupleTypeWriter<,,,>).MakeGenericType(new[] { type1, type2, type3, type4 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + private static ITypeWriter CreateTupleTypeWriter(Type type1, Type type2, Type type3, Type type4) + { + Type writerType = typeof(TupleTypeWriter<,,,>).MakeGenericType(new[] { type1, type2, type3, type4 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + sealed class ValueTupleTypeWriter : ITypeWriter> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + { + public void Write(ref MessageWriter writer, ValueTuple value) + { + writer.WriteStruct(value); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (ValueTuple)value); + } + } + + sealed class TupleTypeWriter : ITypeWriter> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + { + public void Write(ref MessageWriter writer, Tuple value) + { + writer.WriteStruct((value.Item1, value.Item2, value.Item3, value.Item4, value.Item5)); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (Tuple)value); + } + } + + private static ITypeWriter CreateValueTupleTypeWriter(Type type1, Type type2, Type type3, Type type4, Type type5) + { + Type writerType = typeof(ValueTupleTypeWriter<,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + private static ITypeWriter CreateTupleTypeWriter(Type type1, Type type2, Type type3, Type type4, Type type5) + { + Type writerType = typeof(TupleTypeWriter<,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + sealed class ValueTupleTypeWriter : ITypeWriter> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + { + public void Write(ref MessageWriter writer, ValueTuple value) + { + writer.WriteStruct(value); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (ValueTuple)value); + } + } + + sealed class TupleTypeWriter : ITypeWriter> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + { + public void Write(ref MessageWriter writer, Tuple value) + { + writer.WriteStruct((value.Item1, value.Item2, value.Item3, value.Item4, value.Item5, value.Item6)); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (Tuple)value); + } + } + + private static ITypeWriter CreateValueTupleTypeWriter(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6) + { + Type writerType = typeof(ValueTupleTypeWriter<,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + private static ITypeWriter CreateTupleTypeWriter(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6) + { + Type writerType = typeof(TupleTypeWriter<,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + sealed class ValueTupleTypeWriter : ITypeWriter> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + { + public void Write(ref MessageWriter writer, ValueTuple value) + { + writer.WriteStruct(value); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (ValueTuple)value); + } + } + + sealed class TupleTypeWriter : ITypeWriter> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + { + public void Write(ref MessageWriter writer, Tuple value) + { + writer.WriteStruct((value.Item1, value.Item2, value.Item3, value.Item4, value.Item5, value.Item6, value.Item7)); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (Tuple)value); + } + } + + private static ITypeWriter CreateValueTupleTypeWriter(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7) + { + Type writerType = typeof(ValueTupleTypeWriter<,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + private static ITypeWriter CreateTupleTypeWriter(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7) + { + Type writerType = typeof(TupleTypeWriter<,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + sealed class ValueTupleTypeWriter : ITypeWriter>> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + { + public void Write(ref MessageWriter writer, (T1, T2, T3, T4, T5, T6, T7, T8) value) + { + writer.WriteStruct(value); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (ValueTuple>)value); + } + } + + sealed class TupleTypeWriter : ITypeWriter>> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + { + public void Write(ref MessageWriter writer, Tuple> value) + { + writer.WriteStruct((value.Item1, value.Item2, value.Item3, value.Item4, value.Item5, value.Item6, value.Item7, value.Rest.Item1)); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (Tuple>)value); + } + } + + private static ITypeWriter CreateValueTupleTypeWriter(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7, Type type8) + { + Type writerType = typeof(ValueTupleTypeWriter<,,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7, type8 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + private static ITypeWriter CreateTupleTypeWriter(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7, Type type8) + { + Type writerType = typeof(TupleTypeWriter<,,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7, type8 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + sealed class ValueTupleTypeWriter : ITypeWriter>> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + { + public void Write(ref MessageWriter writer, (T1, T2, T3, T4, T5, T6, T7, T8, T9) value) + { + writer.WriteStruct(value); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (ValueTuple>)value); + } + } + + sealed class TupleTypeWriter : ITypeWriter>> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + { + public void Write(ref MessageWriter writer, Tuple> value) + { + writer.WriteStruct((value.Item1, value.Item2, value.Item3, value.Item4, value.Item5, value.Item6, value.Item7, value.Rest.Item1, value.Rest.Item2)); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (Tuple>)value); + } + } + + private static ITypeWriter CreateValueTupleTypeWriter(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7, Type type8, Type type9) + { + Type writerType = typeof(ValueTupleTypeWriter<,,,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7, type8 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + private static ITypeWriter CreateTupleTypeWriter(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7, Type type8, Type type9) + { + Type writerType = typeof(TupleTypeWriter<,,,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7, type8, type9 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + sealed class ValueTupleTypeWriter : ITypeWriter>> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + where T10 : notnull + { + public void Write(ref MessageWriter writer, (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) value) + { + writer.WriteStruct(value); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (ValueTuple>)value); + } + } + + sealed class TupleTypeWriter : ITypeWriter>> + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + where T10 : notnull + { + public void Write(ref MessageWriter writer, Tuple> value) + { + writer.WriteStruct((value.Item1, value.Item2, value.Item3, value.Item4, value.Item5, value.Item6, value.Item7, value.Rest.Item1, value.Rest.Item2, value.Rest.Item3)); + } + + public void WriteVariant(ref MessageWriter writer, object value) + { + WriteStructSignature(ref writer); + Write(ref writer, (Tuple>)value); + } + } + + private static ITypeWriter CreateValueTupleTypeWriter(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7, Type type8, Type type9, Type type10) + { + Type writerType = typeof(ValueTupleTypeWriter<,,,,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7, type8, type10 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + + private static ITypeWriter CreateTupleTypeWriter(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7, Type type8, Type type9, Type type10) + { + Type writerType = typeof(TupleTypeWriter<,,,,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7, type8, type9, type10 }); + return (ITypeWriter)Activator.CreateInstance(writerType)!; + } + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageWriter.WriteT.cs b/src/Linux/Tmds.DBus.Protocol/MessageWriter.WriteT.cs new file mode 100644 index 0000000000..a7f2dee8ac --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageWriter.WriteT.cs @@ -0,0 +1,82 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct MessageWriter +{ + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Write(T value) where T : notnull + { + if (typeof(T) == typeof(byte)) + { + WriteByte((byte)(object)value); + } + else if (typeof(T) == typeof(bool)) + { + WriteBool((bool)(object)value); + } + else if (typeof(T) == typeof(short)) + { + WriteInt16((short)(object)value); + } + else if (typeof(T) == typeof(ushort)) + { + WriteUInt16((ushort)(object)value); + } + else if (typeof(T) == typeof(int)) + { + WriteInt32((int)(object)value); + } + else if (typeof(T) == typeof(uint)) + { + WriteUInt32((uint)(object)value); + } + else if (typeof(T) == typeof(long)) + { + WriteInt64((long)(object)value); + } + else if (typeof(T) == typeof(ulong)) + { + WriteUInt64((ulong)(object)value); + } + else if (typeof(T) == typeof(double)) + { + WriteDouble((double)(object)value); + } + else if (typeof(T) == typeof(string)) + { + WriteString((string)(object)value); + } + else if (typeof(T) == typeof(ObjectPath)) + { + WriteString(((ObjectPath)(object)value).ToString()); + } + else if (typeof(T) == typeof(Signature)) + { + WriteSignature(((Signature)(object)value).ToString()); + } + else if (typeof(T) == typeof(Variant)) + { + ((Variant)(object)value).WriteTo(ref this); + } + else if (typeof(T).IsAssignableTo(typeof(SafeHandle))) + { + WriteHandle((SafeHandle)(object)value); + } + else if (typeof(T).IsAssignableTo(typeof(IDBusWritable))) + { + (value as IDBusWritable)!.WriteTo(ref this); + } + else if (Feature.IsDynamicCodeEnabled) + { + WriteDynamic(value); + } + else + { + ThrowNotSupportedType(typeof(T)); + } + } + + private static void ThrowNotSupportedType(Type type) + { + throw new NotSupportedException($"Cannot write type {type.FullName}"); + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MessageWriter.cs b/src/Linux/Tmds.DBus.Protocol/MessageWriter.cs new file mode 100644 index 0000000000..950c4b669f --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MessageWriter.cs @@ -0,0 +1,193 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct MessageWriter +{ + private const int LengthOffset = 4; + private const int SerialOffset = 8; + private const int HeaderFieldsLengthOffset = 12; + private const int UnixFdLengthOffset = 20; + + private MessageBuffer _message; + private Sequence _data; + private UnixFdCollection? _handles; + private readonly uint _serial; + private MessageFlags _flags; + private Span _firstSpan; + private Span _span; + private int _offset; + private int _buffered; + + public MessageBuffer CreateMessage() + { + Flush(); + + Span span = _firstSpan; + + // Length + uint headerFieldsLength = Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(span.Slice(HeaderFieldsLengthOffset))); + uint pad = headerFieldsLength % 8; + if (pad != 0) + { + headerFieldsLength += (8 - pad); + } + uint length = (uint)_data.Length // Total length + - headerFieldsLength // Header fields + - 4 // Header fields length + - (uint)HeaderFieldsLengthOffset; // Preceeding header fields + Unsafe.WriteUnaligned(ref MemoryMarshal.GetReference(span.Slice(LengthOffset)), length); + + // UnixFdLength + Unsafe.WriteUnaligned(ref MemoryMarshal.GetReference(span.Slice(UnixFdLengthOffset)), (uint)HandleCount); + + uint serial = _serial; + MessageFlags flags = _flags; + ReadOnlySequence data = _data; + UnixFdCollection? handles = _handles; + var message = _message; + + _message = null!; + _handles = null; + _data = null!; + + message.Init(serial, flags, handles); + + return message; + } + + internal MessageWriter(MessageBufferPool messagePool, uint serial) + { + _message = messagePool.Rent(); + _data = _message.Sequence; + _handles = null; + _flags = default; + _offset = 0; + _buffered = 0; + _serial = serial; + _firstSpan = _span = _data.GetSpan(sizeHint: 0); + } + + public ArrayStart WriteArrayStart(DBusType elementType) + { + // Array length. + WritePadding(DBusType.UInt32); + Span lengthSpan = GetSpan(4); + Advance(4); + + WritePadding(elementType); + + return new ArrayStart(lengthSpan, _offset); + } + + public void WriteArrayEnd(ArrayStart start) + { + start.WriteLength(_offset); + } + + public void WriteStructureStart() + { + WritePadding(DBusType.Struct); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void Advance(int count) + { + _buffered += count; + _offset += count; + _span = _span.Slice(count); + } + + private void WritePadding(DBusType type) + { + int pad = ProtocolConstants.GetPadding(_offset, type); + if (pad != 0) + { + GetSpan(pad).Slice(0, pad).Fill(0); + Advance(pad); + } + } + + private void WritePadding(int alignment) + { + int pad = ProtocolConstants.GetPadding(_offset, alignment); + if (pad != 0) + { + GetSpan(pad).Slice(0, pad).Fill(0); + Advance(pad); + } + } + + private Span GetSpan(int sizeHint) + { + Ensure(sizeHint); + return _span; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void Ensure(int count = 1) + { + if (_span.Length < count) + { + EnsureMore(count); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void EnsureMore(int count = 0) + { + if (_buffered > 0) + { + Flush(); + } + + _span = _data.GetSpan(count); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void Flush() + { + var buffered = _buffered; + if (buffered > 0) + { + _buffered = 0; + _data.Advance(buffered); + _span = default; + } + } + + public void Dispose() + { + _message?.ReturnToPool(); + _handles?.Dispose(); + + _message = null!; + _data = null!; + _handles = null!; + } + + // For Tests. + internal ReadOnlySequence AsReadOnlySequence() + { + Flush(); + return _data.AsReadOnlySequence; + } + // For Tests. + internal UnixFdCollection? Handles => _handles; +} + +public ref struct ArrayStart +{ + private Span _span; + private int _offset; + + internal ArrayStart(Span lengthSpan, int offset) + { + _span = lengthSpan; + _offset = offset; + } + + internal void WriteLength(int offset) + { + uint length = (uint)(offset - _offset); + Unsafe.WriteUnaligned(ref MemoryMarshal.GetReference(_span), length); + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/MethodContext.cs b/src/Linux/Tmds.DBus.Protocol/MethodContext.cs new file mode 100644 index 0000000000..db8380b046 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/MethodContext.cs @@ -0,0 +1,97 @@ +namespace Tmds.DBus.Protocol; + +public class MethodContext +{ + internal MethodContext(Connection connection, Message request, CancellationToken requestAborted) + { + Connection = connection; + Request = request; + RequestAborted = requestAborted; + } + + public Message Request { get; } + public Connection Connection { get; } + public CancellationToken RequestAborted { get; } + + public bool ReplySent { get; private set; } + + public bool NoReplyExpected => (Request.MessageFlags & MessageFlags.NoReplyExpected) != 0; + + public bool IsDBusIntrospectRequest { get; internal set; } + + internal List? IntrospectChildNameList { get; set; } + + public MessageWriter CreateReplyWriter(string? signature) + { + var writer = Connection.GetMessageWriter(); + writer.WriteMethodReturnHeader( + replySerial: Request.Serial, + destination: Request.Sender, + signature: signature + ); + return writer; + } + + public void Reply(MessageBuffer message) + { + if (ReplySent || NoReplyExpected) + { + message.Dispose(); + if (ReplySent) + { + throw new InvalidOperationException("A reply has already been sent."); + } + } + + ReplySent = true; + Connection.TrySendMessage(message); + } + + public void ReplyError(string? errorName = null, + string? errorMsg = null) + { + using var writer = Connection.GetMessageWriter(); + writer.WriteError( + replySerial: Request.Serial, + destination: Request.Sender, + errorName: errorName, + errorMsg: errorMsg + ); + Reply(writer.CreateMessage()); + } + + public void ReplyIntrospectXml(ReadOnlySpan> interfaceXmls) + { + if (!IsDBusIntrospectRequest) + { + throw new InvalidOperationException($"Can not reply with introspection XML when {nameof(IsDBusIntrospectRequest)} is false."); + } + + using var writer = Connection.GetMessageWriter(); + writer.WriteMethodReturnHeader( + replySerial: Request.Serial, + destination: Request.Sender, + signature: "s" + ); + + // Add the Peer and Introspectable interfaces. + // Tools like D-Feet will list the paths separately as soon as there is an interface. + // We add the base interfaces only for the paths that we want to show up. + // Those are paths that have other interfaces, paths that are leaves. + bool includeBaseInterfaces = !interfaceXmls.IsEmpty || IntrospectChildNameList is null || IntrospectChildNameList.Count == 0; + ReadOnlySpan> baseInterfaceXmls = includeBaseInterfaces ? [ IntrospectionXml.DBusIntrospectable, IntrospectionXml.DBusPeer ] : [ ]; + + // Add the child names. +#if NET5_0_OR_GREATER + ReadOnlySpan childNames = CollectionsMarshal.AsSpan(IntrospectChildNameList); + IEnumerable? childNamesEnumerable = null; +#else + ReadOnlySpan childNames = default; + IEnumerable? childNamesEnumerable = IntrospectChildNameList; +#endif + + writer.WriteIntrospectionXml(interfaceXmls, baseInterfaceXmls, childNames, childNamesEnumerable); + + Reply(writer.CreateMessage()); + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Netstandard2_0Extensions.cs b/src/Linux/Tmds.DBus.Protocol/Netstandard2_0Extensions.cs new file mode 100644 index 0000000000..ea65bc332a --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Netstandard2_0Extensions.cs @@ -0,0 +1,214 @@ +using System.Net; +using System.Net.Sockets; + +namespace Tmds.DBus.Protocol; + +#if NETSTANDARD2_0 +static partial class NetstandardExtensions +{ + public static bool Remove(this Dictionary dictionary, TKey key, out TValue value) + { + if (dictionary.TryGetValue(key, out value)) + { + dictionary.Remove(key); + return true; + } + return false; + } + + public static unsafe int GetBytes(this Encoding encoding, ReadOnlySpan chars, Span bytes) + { + fixed (char* pChars = &GetNonNullPinnableReference(chars)) + fixed (byte* pBytes = &GetNonNullPinnableReference(bytes)) + { + return encoding.GetBytes(pChars, chars.Length, pBytes, bytes.Length); + } + } + + public static unsafe int GetChars(this Encoding encoding, ReadOnlySpan bytes, Span chars) + { + fixed (char* pChars = &GetNonNullPinnableReference(chars)) + fixed (byte* pBytes = &GetNonNullPinnableReference(bytes)) + { + return encoding.GetChars(pBytes, bytes.Length, pChars, chars.Length); + } + } + + public static unsafe string GetString(this Encoding encoding, ReadOnlySpan bytes) + { + fixed (byte* pBytes = &GetNonNullPinnableReference(bytes)) + { + return encoding.GetString(pBytes, bytes.Length); + } + } + + public static unsafe int GetCharCount(this Encoding encoding, ReadOnlySpan bytes) + { + fixed (byte* pBytes = &GetNonNullPinnableReference(bytes)) + { + return encoding.GetCharCount(pBytes, bytes.Length); + } + } + + public static unsafe int GetByteCount(this Encoding encoding, ReadOnlySpan chars) + { + fixed (char* pChars = &GetNonNullPinnableReference(chars)) + { + return encoding.GetByteCount(pChars, chars.Length); + } + } + + public static unsafe int GetByteCount(this Encoder encoder, ReadOnlySpan chars, bool flush) + { + fixed (char* pChars = &GetNonNullPinnableReference(chars)) + { + return encoder.GetByteCount(pChars, chars.Length, flush); + } + } + + public static unsafe void Convert(this Encoder encoder, ReadOnlySpan chars, Span bytes, bool flush, out int charsUsed, out int bytesUsed, out bool completed) + { + fixed (char* pChars = &GetNonNullPinnableReference(chars)) + fixed (byte* pBytes = &GetNonNullPinnableReference(bytes)) + { + encoder.Convert(pChars, chars.Length, pBytes, bytes.Length, flush, out charsUsed, out bytesUsed, out completed); + } + } + + public static unsafe void Append(this StringBuilder sb, ReadOnlySpan value) + { + fixed (char* ptr = value) + { + sb.Append(ptr, value.Length); + } + } + + public static unsafe string AsString(this ReadOnlySpan chars) + { + fixed (char* ptr = chars) + { + return new string(ptr, 0, chars.Length); + } + } + + public static unsafe string AsString(this Span chars) + => AsString((ReadOnlySpan)chars); + + public static async ValueTask ReceiveAsync(this Socket socket, Memory buffer, SocketFlags socketFlags) + { + if (MemoryMarshal.TryGetArray((ReadOnlyMemory)buffer, out var segment)) + return await SocketTaskExtensions.ReceiveAsync(socket, segment, socketFlags).ConfigureAwait(false); + + throw new NotSupportedException(); + } + + public static async ValueTask SendAsync(this Socket socket, ReadOnlyMemory buffer, SocketFlags socketFlags) + { + if (MemoryMarshal.TryGetArray(buffer, out var segment)) + return await SocketTaskExtensions.SendAsync(socket, segment, socketFlags).ConfigureAwait(false); + + throw new NotSupportedException(); + } + + /// + /// Returns a reference to the 0th element of the Span. If the Span is empty, returns a reference to fake non-null pointer. Such a reference can be used + /// for pinning but must never be dereferenced. This is useful for interop with methods that do not accept null pointers for zero-sized buffers. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe ref T GetNonNullPinnableReference(Span span) => ref (span.Length != 0) ? ref MemoryMarshal.GetReference(span) : ref Unsafe.AsRef((void*)1); + + /// + /// Returns a reference to the 0th element of the ReadOnlySpan. If the ReadOnlySpan is empty, returns a reference to fake non-null pointer. Such a reference + /// can be used for pinning but must never be dereferenced. This is useful for interop with methods that do not accept null pointers for zero-sized buffers. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static unsafe ref T GetNonNullPinnableReference(ReadOnlySpan span) => ref (span.Length != 0) ? ref MemoryMarshal.GetReference(span) : ref Unsafe.AsRef((void*)1); +} + +internal sealed class UnixDomainSocketEndPoint : EndPoint +{ + private const AddressFamily EndPointAddressFamily = AddressFamily.Unix; + + private static readonly Encoding s_pathEncoding = Encoding.UTF8; + private const int s_nativePathOffset = 2; + + private readonly string _path; + private readonly byte[] _encodedPath; + + public UnixDomainSocketEndPoint(string path) + { + if (path == null) + { + throw new ArgumentNullException(nameof(path)); + } + + _path = path; + _encodedPath = s_pathEncoding.GetBytes(_path); + + if (path.Length == 0) + { + throw new ArgumentOutOfRangeException( + nameof(path), path, + string.Format("The path '{0}' is of an invalid length for use with domain sockets on this platform. The length must be at least 1 characters.", path)); + } + } + + internal UnixDomainSocketEndPoint(SocketAddress socketAddress) + { + if (socketAddress == null) + { + throw new ArgumentNullException(nameof(socketAddress)); + } + + if (socketAddress.Family != EndPointAddressFamily) + { + throw new ArgumentOutOfRangeException(nameof(socketAddress)); + } + + if (socketAddress.Size > s_nativePathOffset) + { + _encodedPath = new byte[socketAddress.Size - s_nativePathOffset]; + for (int i = 0; i < _encodedPath.Length; i++) + { + _encodedPath[i] = socketAddress[s_nativePathOffset + i]; + } + + _path = s_pathEncoding.GetString(_encodedPath, 0, _encodedPath.Length); + } + else + { + _encodedPath = Array.Empty(); + _path = string.Empty; + } + } + + public override SocketAddress Serialize() + { + var result = new SocketAddress(AddressFamily.Unix, _encodedPath.Length + s_nativePathOffset); + + for (int index = 0; index < _encodedPath.Length; index++) + { + result[s_nativePathOffset + index] = _encodedPath[index]; + } + + return result; + } + + public override EndPoint Create(SocketAddress socketAddress) => new UnixDomainSocketEndPoint(socketAddress); + + public override AddressFamily AddressFamily => EndPointAddressFamily; + + public string Path => _path; + + public override string ToString() => _path; +} +#else +static partial class NetstandardExtensions +{ + public static string AsString(this ReadOnlySpan chars) + => new string(chars); + + public static unsafe string AsString(this Span chars) + => AsString((ReadOnlySpan)chars); +} +#endif diff --git a/src/Linux/Tmds.DBus.Protocol/Netstandard2_1Extensions.cs b/src/Linux/Tmds.DBus.Protocol/Netstandard2_1Extensions.cs new file mode 100644 index 0000000000..70694d0362 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Netstandard2_1Extensions.cs @@ -0,0 +1,53 @@ +using System.Net; +using System.Net.Sockets; +using System.Reflection; + +namespace Tmds.DBus.Protocol; + +#if NETSTANDARD2_0 || NETSTANDARD2_1 +static partial class NetstandardExtensions +{ + + private static PropertyInfo s_safehandleProperty = typeof(Socket).GetTypeInfo().GetDeclaredProperty("SafeHandle"); + + private const int MaxInputElementsPerIteration = 1 * 1024 * 1024; + + public static bool IsAssignableTo(this Type type, Type? targetType) + => targetType?.IsAssignableFrom(type) ?? false; + + public static SafeHandle GetSafeHandle(this Socket socket) + { + if (s_safehandleProperty != null) + { + return (SafeHandle)s_safehandleProperty.GetValue(socket, null); + } + ThrowHelper.ThrowNotSupportedException(); + return null!; + } + + public static async Task ConnectAsync(this Socket socket, EndPoint remoteEP, CancellationToken cancellationToken) + { + using var ctr = cancellationToken.Register(state => ((Socket)state!).Dispose(), socket, useSynchronizationContext: false); + try + { + await Task.Factory.FromAsync( + (targetEndPoint, callback, state) => ((Socket)state).BeginConnect(targetEndPoint, callback, state), + asyncResult => ((Socket)asyncResult.AsyncState).EndConnect(asyncResult), + remoteEP, + state: socket).ConfigureAwait(false); + } + catch (ObjectDisposedException) + { + cancellationToken.ThrowIfCancellationRequested(); + + throw; + } + } +} +#else +static partial class NetstandardExtensions +{ + public static SafeHandle GetSafeHandle(this Socket socket) + => socket.SafeHandle; +} +#endif diff --git a/src/Linux/Tmds.DBus.Protocol/ObjectPath.cs b/src/Linux/Tmds.DBus.Protocol/ObjectPath.cs new file mode 100644 index 0000000000..51753ea190 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/ObjectPath.cs @@ -0,0 +1,16 @@ +namespace Tmds.DBus.Protocol; + +public struct ObjectPath +{ + private string _value; + + public ObjectPath(string value) => _value = value; + + public override string ToString() => _value ?? ""; + + public static implicit operator string(ObjectPath value) => value._value; + + public static implicit operator ObjectPath(string value) => new ObjectPath(value); + + public Variant AsVariant() => new Variant(this); +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/ObserverFlags.cs b/src/Linux/Tmds.DBus.Protocol/ObserverFlags.cs new file mode 100644 index 0000000000..d14295957a --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/ObserverFlags.cs @@ -0,0 +1,12 @@ +namespace Tmds.DBus.Protocol; + +[Flags] +public enum ObserverFlags +{ + None = 0, + EmitOnConnectionDispose = 1, + EmitOnObserverDispose = 2, + NoSubscribe = 4, + + EmitOnDispose = EmitOnConnectionDispose | EmitOnObserverDispose, +} diff --git a/src/Linux/Tmds.DBus.Protocol/PathNodeDictionary.cs b/src/Linux/Tmds.DBus.Protocol/PathNodeDictionary.cs new file mode 100644 index 0000000000..443a042104 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/PathNodeDictionary.cs @@ -0,0 +1,311 @@ +namespace Tmds.DBus.Protocol; + +sealed class PathNode +{ + // _childNames is null when there are no child names + // a string if there is a single child name + // a List List.Count child names + private object? _childNames; + public IMethodHandler? MethodHandler; + public PathNode? Parent { get; set; } + + public int ChildNameCount => + _childNames switch + { + null => 0, + string => 1, + var list => ((List)list).Count + }; + + public void ClearChildNames() + { + Debug.Assert(ChildNameCount == 1, "Method isn't expected to be called unless there is 1 child name."); + if (_childNames is List list) + { + list.Clear(); + } + else + { + _childNames = null; + } + } + + public void RemoveChildName(string name) + { + Debug.Assert(ChildNameCount > 1, "Caller is expected to call ClearChildNames instead."); + var list = (List)_childNames!; + list.Remove(name); + } + + public void AddChildName(string value) + { + if (_childNames is null) + { + _childNames = value; + } + else if (_childNames is string first) + { + _childNames = new List() { first, value }; + } + else + { + ((List)_childNames).Add(value); + } + } + + public void CopyChildNamesTo(MethodContext methodContext) + { + Debug.Assert(methodContext.IntrospectChildNameList is null || methodContext.IntrospectChildNameList.Count == 0); + + if (_childNames is null) + { + return; + } + + methodContext.IntrospectChildNameList ??= new(); + if (_childNames is string s) + { + methodContext.IntrospectChildNameList.Add(s); + } + else + { + methodContext.IntrospectChildNameList.AddRange((List)_childNames); + } + } +} + +sealed class PathNodeDictionary : IMethodHandlerDictionary +{ + private readonly Dictionary _dictionary = new(); + + public bool TryGetValue(string path, [NotNullWhen(true)]out PathNode? pathNode) + => _dictionary.TryGetValue(path, out pathNode); + + // For tests: + public PathNode this[string path] + => _dictionary[path]; + public int Count + => _dictionary.Count; + + public void AddMethodHandlers(IReadOnlyList methodHandlers) + { + if (methodHandlers is null) + { + throw new ArgumentNullException(nameof(methodHandlers)); + } + + int registeredCount = 0; + try + { + for (int i = 0; i < methodHandlers.Count; i++) + { + IMethodHandler methodHandler = methodHandlers[i] ?? throw new ArgumentNullException("methodHandler"); + + AddMethodHandler(methodHandler); + + registeredCount++; + } + } + catch + { + RemoveMethodHandlers(methodHandlers, registeredCount); + + throw; + } + } + + + private PathNode GetOrCreateNode(string path) + { +#if NET6_0_OR_GREATER + ref PathNode? node = ref CollectionsMarshal.GetValueRefOrAddDefault(_dictionary, path, out bool exists); + if (exists) + { + return node!; + } + PathNode newNode = new PathNode(); + node = newNode; +#else + if (_dictionary.TryGetValue(path, out PathNode? node)) + { + return node; + } + PathNode newNode = new PathNode(); + _dictionary.Add(path, newNode); +#endif + string? parentPath = GetParentPath(path); + if (parentPath is not null) + { + PathNode parent = GetOrCreateNode(parentPath); + newNode.Parent = parent; + parent.AddChildName(GetChildName(path)); + } + + return newNode; + } + + private static string? GetParentPath(string path) + { + if (path.Length == 1) + { + return null; + } + + int index = path.LastIndexOf('/'); + Debug.Assert(index != -1); + + // When index == 0, return '/'. + index = Math.Max(index, 1); + + return path.Substring(0, index); + } + + private static string GetChildName(string path) + { + int index = path.LastIndexOf('/'); + return path.Substring(index + 1); + } + + private void RemoveMethodHandlers(IReadOnlyList methodHandlers, int count) + { + // We start by (optimistically) removing all nodes (assuming they form a tree that is pruned). + // If there are nodes that are still needed to serve as parent nodes, we'll add them back at the end. + (string Path, PathNode Node)[] nodes = new (string, PathNode)[count]; + int j = 0; + for (int i = 0; i < count; i++) + { + string path = methodHandlers[i].Path; + if (_dictionary.Remove(path, out PathNode? node)) + { + nodes[j++] = (path, node); + node.MethodHandler = null; + } + } + count = j; j = 0; + + // Reverse sort by path length to remove leaves before parents. + Array.Sort(nodes, 0, count, RemoveKeyComparerInstance); + for (int i = 0; i < count; i++) + { + var node = nodes[i]; + if (node.Node.ChildNameCount == 0) + { + RemoveFromParent(node.Path, node.Node); + } + else + { + nodes[j++] = node; + } + } + count = j; j = 0; + + // Add back the nodes that serve as parent nodes. + for (int i = 0; i < count; i++) + { + var node = nodes[i]; + _dictionary[node.Path] = node.Node; + } + } + + private void RemoveFromParent(string path, PathNode node) + { + PathNode? parent = node.Parent; + if (parent is null) + { + return; + } + Debug.Assert(parent.ChildNameCount >= 1, "node is expected to be a known child"); + if (parent.ChildNameCount == 1) // We're the only child. + { + if (parent.MethodHandler is not null) + { + // Parent is still needed for the MethodHandler. + parent.ClearChildNames(); + } + else + { +// Suppress netstandard2.0 nullability warnings around NetstandardExtensions.Remove. +#if NETSTANDARD2_0 +#pragma warning disable CS8620 +#pragma warning disable CS8604 +#endif + // Parent is no longer needed. + string parentPath = GetParentPath(path)!; + Debug.Assert(parentPath is not null); + _dictionary.Remove(parentPath, out PathNode? parentNode); + Debug.Assert(parentNode is not null); + RemoveFromParent(parentPath, parentNode); +#if NETSTANDARD2_0 +#pragma warning restore CS8620 +#pragma warning restore CS8604 +#endif + } + } + else + { + string childName = GetChildName(path); + parent.RemoveChildName(childName); + } + } + + public void AddMethodHandler(IMethodHandler methodHandler) + { + string path = methodHandler.Path ?? throw new ArgumentNullException(nameof(methodHandler.Path)); + + // Validate the path starts with '/' and has no empty sections. + // GetParentPath relies on this. + if (path[0] != '/' || path.IndexOf("//", StringComparison.Ordinal) != -1) + { + throw new FormatException($"The path '{path}' is not valid."); + } + + PathNode node = GetOrCreateNode(path); + + if (node.MethodHandler is not null) + { + throw new InvalidOperationException($"A method handler is already registered for the path '{path}'."); + } + node.MethodHandler = methodHandler; + } + + public void RemoveMethodHandler(string path) + { + if (path is null) + { + throw new ArgumentNullException(nameof(path)); + } + if (_dictionary.Remove(path, out PathNode? node)) + { + if (node.ChildNameCount > 0) + { + // Node is still needed for its children. + node.MethodHandler = null; + _dictionary.Add(path, node); + } + else + { + RemoveFromParent(path, node); + } + } + } + + public void RemoveMethodHandlers(IEnumerable paths) + { + if (paths is null) + { + throw new ArgumentNullException(nameof(paths)); + } + foreach (var path in paths) + { + RemoveMethodHandler(path); + } + } + + private static readonly RemoveKeyComparer RemoveKeyComparerInstance = new(); + + sealed class RemoveKeyComparer : IComparer<(string Path, PathNode Node)> + { + public int Compare((string Path, PathNode Node) x, (string Path, PathNode Node) y) + => x.Path.Length - y.Path.Length; + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/PlatformDetection.cs b/src/Linux/Tmds.DBus.Protocol/PlatformDetection.cs new file mode 100644 index 0000000000..5f272b84e1 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/PlatformDetection.cs @@ -0,0 +1,19 @@ +namespace Tmds.DBus.Protocol; +#if NET6_0_OR_GREATER +using System.Runtime.Versioning; +#endif + +static class PlatformDetection +{ +#if NET6_0_OR_GREATER + [SupportedOSPlatformGuard("windows")] +#endif + public static bool IsWindows() => +#if NET6_0_OR_GREATER + // IsWindows is marked with the NonVersionable attribute. + // This allows R2R to inline it and eliminate platform-specific branches. + OperatingSystem.IsWindows(); +#else + RuntimeInformation.IsOSPlatform(OSPlatform.Windows); +#endif +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Polyfill/DynamicallyAccessedMemberTypes.cs b/src/Linux/Tmds.DBus.Protocol/Polyfill/DynamicallyAccessedMemberTypes.cs new file mode 100644 index 0000000000..52dfc0d0e9 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Polyfill/DynamicallyAccessedMemberTypes.cs @@ -0,0 +1,25 @@ +#if !NET5_0_OR_GREATER + +namespace System.Diagnostics.CodeAnalysis; + +[Flags] +enum DynamicallyAccessedMemberTypes +{ + None = 0, + PublicParameterlessConstructor = 0x0001, + PublicConstructors = 0x0002 | PublicParameterlessConstructor, + NonPublicConstructors = 0x0004, + PublicMethods = 0x0008, + NonPublicMethods = 0x0010, + PublicFields = 0x0020, + NonPublicFields = 0x0040, + PublicNestedTypes = 0x0080, + NonPublicNestedTypes = 0x0100, + PublicProperties = 0x0200, + NonPublicProperties = 0x0400, + PublicEvents = 0x0800, + NonPublicEvents = 0x1000, + Interfaces = 0x2000, + All = ~None +} +#endif \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Polyfill/DynamicallyAccessedMembersAttribute.cs b/src/Linux/Tmds.DBus.Protocol/Polyfill/DynamicallyAccessedMembersAttribute.cs new file mode 100644 index 0000000000..8d4586996b --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Polyfill/DynamicallyAccessedMembersAttribute.cs @@ -0,0 +1,30 @@ +#if !NET5_0_OR_GREATER + +namespace System.Diagnostics.CodeAnalysis; + +using Targets = AttributeTargets; + +[ExcludeFromCodeCoverage] +[DebuggerNonUserCode] +[AttributeUsage( + validOn: Targets.Class | + Targets.Field | + Targets.GenericParameter | + Targets.Interface | + Targets.Method | + Targets.Parameter | + Targets.Property | + Targets.ReturnValue | + Targets.Struct, + Inherited = false)] + +sealed class DynamicallyAccessedMembersAttribute : + Attribute +{ + public DynamicallyAccessedMembersAttribute(DynamicallyAccessedMemberTypes memberTypes) => + MemberTypes = memberTypes; + + public DynamicallyAccessedMemberTypes MemberTypes { get; } +} + +#endif \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Polyfill/Nerdbank.Streams.Sequence.cs b/src/Linux/Tmds.DBus.Protocol/Polyfill/Nerdbank.Streams.Sequence.cs new file mode 100644 index 0000000000..20757e1421 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Polyfill/Nerdbank.Streams.Sequence.cs @@ -0,0 +1,530 @@ +// Copyright (c) Andrew Arnott. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Nerdbank.Streams +{ + using System; + using System.Buffers; + using System.Collections.Generic; + using System.ComponentModel; + using System.Diagnostics; + using System.Reflection; + using System.Runtime.CompilerServices; + + /// + /// Manages a sequence of elements, readily castable as a . + /// + /// The type of element stored by the sequence. + /// + /// Instance members are not thread-safe. + /// + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + internal class Sequence : IBufferWriter, IDisposable + { + private const int MaximumAutoGrowSize = 32 * 1024; + + private static readonly int DefaultLengthFromArrayPool = 1 + (4095 / Unsafe.SizeOf()); + + private static readonly ReadOnlySequence Empty = new ReadOnlySequence(SequenceSegment.Empty, 0, SequenceSegment.Empty, 0); + + private readonly Stack segmentPool = new Stack(); + + private readonly MemoryPool? memoryPool; + + private readonly ArrayPool? arrayPool; + + private SequenceSegment? first; + + private SequenceSegment? last; + + /// + /// Initializes a new instance of the class + /// that uses a private for recycling arrays. + /// + public Sequence() + : this(ArrayPool.Create()) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The pool to use for recycling backing arrays. + public Sequence(MemoryPool memoryPool) + { + this.memoryPool = memoryPool ?? throw new ArgumentNullException(nameof(memoryPool)); + } + + /// + /// Initializes a new instance of the class. + /// + /// The pool to use for recycling backing arrays. + public Sequence(ArrayPool arrayPool) + { + this.arrayPool = arrayPool ?? throw new ArgumentNullException(nameof(arrayPool)); + } + + /// + /// Gets or sets the minimum length for any array allocated as a segment in the sequence. + /// Any non-positive value allows the pool to determine the length of the array. + /// + /// The default value is 0. + /// + /// + /// Each time or is called, + /// previously allocated memory is used if it is large enough to satisfy the length demand. + /// If new memory must be allocated, the argument to one of these methods typically dictate + /// the length of array to allocate. When the caller uses very small values (just enough for its immediate need) + /// but the high level scenario can predict that a large amount of memory will be ultimately required, + /// it can be advisable to set this property to a value such that just a few larger arrays are allocated + /// instead of many small ones. + /// + /// + /// The in use may itself have a minimum array length as well, + /// in which case the higher of the two minimums dictate the minimum array size that will be allocated. + /// + /// + /// If is , this value may be automatically increased as the length of a sequence grows. + /// + /// + public int MinimumSpanLength { get; set; } = 0; + + /// + /// Gets or sets a value indicating whether the should be + /// intelligently increased as the length of the sequence grows. + /// + /// + /// This can help prevent long sequences made up of many very small arrays. + /// + public bool AutoIncreaseMinimumSpanLength { get; set; } = true; + + /// + /// Gets this sequence expressed as a . + /// + /// A read only sequence representing the data in this object. + public ReadOnlySequence AsReadOnlySequence => this; + + /// + /// Gets the length of the sequence. + /// + public long Length => this.AsReadOnlySequence.Length; + + /// + /// Gets the value to display in a debugger datatip. + /// + private string DebuggerDisplay => $"Length: {this.AsReadOnlySequence.Length}"; + + /// + /// Expresses this sequence as a . + /// + /// The sequence to convert. + public static implicit operator ReadOnlySequence(Sequence sequence) + { + return sequence.first is { } first && sequence.last is { } last + ? new ReadOnlySequence(first, first.Start, last, last!.End) + : Empty; + } + + /// + /// Removes all elements from the sequence from its beginning to the specified position, + /// considering that data to have been fully processed. + /// + /// + /// The position of the first element that has not yet been processed. + /// This is typically after reading all elements from that instance. + /// + public void AdvanceTo(SequencePosition position) + { + var firstSegment = (SequenceSegment?)position.GetObject(); + if (firstSegment == null) + { + // Emulate PipeReader behavior which is to just return for default(SequencePosition) + return; + } + + if (ReferenceEquals(firstSegment, SequenceSegment.Empty) && this.Length == 0) + { + // We were called with our own empty buffer segment. + return; + } + + int firstIndex = position.GetInteger(); + + // Before making any mutations, confirm that the block specified belongs to this sequence. + Sequence.SequenceSegment? current = this.first; + while (current != firstSegment && current != null) + { + current = current.Next; + } + + if (current == null) + throw new ArgumentException("Position does not represent a valid position in this sequence.", + nameof(position)); + + // Also confirm that the position is not a prior position in the block. + if (firstIndex < current.Start) + throw new ArgumentException("Position must not be earlier than current position.", nameof(position)); + + // Now repeat the loop, performing the mutations. + current = this.first; + while (current != firstSegment) + { + current = this.RecycleAndGetNext(current!); + } + + firstSegment.AdvanceTo(firstIndex); + + this.first = firstSegment.Length == 0 ? this.RecycleAndGetNext(firstSegment) : firstSegment; + + if (this.first == null) + { + this.last = null; + } + } + + /// + /// Advances the sequence to include the specified number of elements initialized into memory + /// returned by a prior call to . + /// + /// The number of elements written into memory. + public void Advance(int count) + { + SequenceSegment? last = this.last; + if(last==null) + throw new InvalidOperationException("Cannot advance before acquiring memory."); + last.Advance(count); + this.ConsiderMinimumSizeIncrease(); + } + + /// + /// Gets writable memory that can be initialized and added to the sequence via a subsequent call to . + /// + /// The size of the memory required, or 0 to just get a convenient (non-empty) buffer. + /// The requested memory. + public Memory GetMemory(int sizeHint) => this.GetSegment(sizeHint).RemainingMemory; + + /// + /// Gets writable memory that can be initialized and added to the sequence via a subsequent call to . + /// + /// The size of the memory required, or 0 to just get a convenient (non-empty) buffer. + /// The requested memory. + public Span GetSpan(int sizeHint) => this.GetSegment(sizeHint).RemainingSpan; + + /// + /// Adds an existing memory location to this sequence without copying. + /// + /// The memory to add. + /// + /// This *may* leave significant slack space in a previously allocated block if calls to + /// follow calls to or . + /// + public void Append(ReadOnlyMemory memory) + { + if (memory.Length > 0) + { + Sequence.SequenceSegment? segment = this.segmentPool.Count > 0 ? this.segmentPool.Pop() : new SequenceSegment(); + segment.AssignForeign(memory); + this.Append(segment); + } + } + + /// + /// Clears the entire sequence, recycles associated memory into pools, + /// and resets this instance for reuse. + /// This invalidates any previously produced by this instance. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + public void Dispose() => this.Reset(); + + /// + /// Clears the entire sequence and recycles associated memory into pools. + /// This invalidates any previously produced by this instance. + /// + public void Reset() + { + Sequence.SequenceSegment? current = this.first; + while (current != null) + { + current = this.RecycleAndGetNext(current); + } + + this.first = this.last = null; + } + + private SequenceSegment GetSegment(int sizeHint) + { + if (sizeHint < 0) + throw new ArgumentOutOfRangeException(nameof(sizeHint)); + int? minBufferSize = null; + if (sizeHint == 0) + { + if (this.last == null || this.last.WritableBytes == 0) + { + // We're going to need more memory. Take whatever size the pool wants to give us. + minBufferSize = -1; + } + } + else + { + if (this.last == null || this.last.WritableBytes < sizeHint) + { + minBufferSize = Math.Max(this.MinimumSpanLength, sizeHint); + } + } + + if (minBufferSize.HasValue) + { + Sequence.SequenceSegment? segment = this.segmentPool.Count > 0 ? this.segmentPool.Pop() : new SequenceSegment(); + if (this.arrayPool != null) + { + segment.Assign(this.arrayPool.Rent(minBufferSize.Value == -1 ? DefaultLengthFromArrayPool : minBufferSize.Value)); + } + else + { + segment.Assign(this.memoryPool!.Rent(minBufferSize.Value)); + } + + this.Append(segment); + } + + return this.last!; + } + + private void Append(SequenceSegment segment) + { + if (this.last == null) + { + this.first = this.last = segment; + } + else + { + if (this.last.Length > 0) + { + // Add a new block. + this.last.SetNext(segment); + } + else + { + // The last block is completely unused. Replace it instead of appending to it. + Sequence.SequenceSegment? current = this.first; + if (this.first != this.last) + { + while (current!.Next != this.last) + { + current = current.Next; + } + } + else + { + this.first = segment; + } + + current!.SetNext(segment); + this.RecycleAndGetNext(this.last); + } + + this.last = segment; + } + } + + private SequenceSegment? RecycleAndGetNext(SequenceSegment segment) + { + Sequence.SequenceSegment? recycledSegment = segment; + Sequence.SequenceSegment? nextSegment = segment.Next; + recycledSegment.ResetMemory(this.arrayPool); + this.segmentPool.Push(recycledSegment); + return nextSegment; + } + + private void ConsiderMinimumSizeIncrease() + { + if (this.AutoIncreaseMinimumSpanLength && this.MinimumSpanLength < MaximumAutoGrowSize) + { + int autoSize = Math.Min(MaximumAutoGrowSize, (int)Math.Min(int.MaxValue, this.Length / 2)); + if (this.MinimumSpanLength < autoSize) + { + this.MinimumSpanLength = autoSize; + } + } + } + + private class SequenceSegment : ReadOnlySequenceSegment + { + internal static readonly SequenceSegment Empty = new SequenceSegment(); + + /// + /// A value indicating whether the element may contain references (and thus must be cleared). + /// + private static readonly bool MayContainReferences = !typeof(T).GetTypeInfo().IsPrimitive; + +#pragma warning disable SA1011 // Closing square brackets should be spaced correctly + /// + /// Gets the backing array, when using an instead of a . + /// + private T[]? array; +#pragma warning restore SA1011 // Closing square brackets should be spaced correctly + + /// + /// Gets the position within where the data starts. + /// + /// This may be nonzero as a result of calling . + internal int Start { get; private set; } + + /// + /// Gets the position within where the data ends. + /// + internal int End { get; private set; } + + /// + /// Gets the tail of memory that has not yet been committed. + /// + internal Memory RemainingMemory => this.AvailableMemory.Slice(this.End); + + /// + /// Gets the tail of memory that has not yet been committed. + /// + internal Span RemainingSpan => this.AvailableMemory.Span.Slice(this.End); + + /// + /// Gets the tracker for the underlying array for this segment, which can be used to recycle the array when we're disposed of. + /// Will be if using an array pool, in which case the memory is held by . + /// + internal IMemoryOwner? MemoryOwner { get; private set; } + + /// + /// Gets the full memory owned by the . + /// + internal Memory AvailableMemory => this.array ?? this.MemoryOwner?.Memory ?? default; + + /// + /// Gets the number of elements that are committed in this segment. + /// + internal int Length => this.End - this.Start; + + /// + /// Gets the amount of writable bytes in this segment. + /// It is the amount of bytes between and . + /// + internal int WritableBytes => this.AvailableMemory.Length - this.End; + + /// + /// Gets or sets the next segment in the singly linked list of segments. + /// + internal new SequenceSegment? Next + { + get => (SequenceSegment?)base.Next; + set => base.Next = value; + } + + /// + /// Gets a value indicating whether this segment refers to memory that came from outside and that we cannot write to nor recycle. + /// + internal bool IsForeignMemory => this.array == null && this.MemoryOwner == null; + + /// + /// Assigns this (recyclable) segment a new area in memory. + /// + /// The memory and a means to recycle it. + internal void Assign(IMemoryOwner memoryOwner) + { + this.MemoryOwner = memoryOwner; + this.Memory = memoryOwner.Memory; + } + + /// + /// Assigns this (recyclable) segment a new area in memory. + /// + /// An array drawn from an . + internal void Assign(T[] array) + { + this.array = array; + this.Memory = array; + } + + /// + /// Assigns this (recyclable) segment a new area in memory. + /// + /// A memory block obtained from outside, that we do not own and should not recycle. + internal void AssignForeign(ReadOnlyMemory memory) + { + this.Memory = memory; + this.End = memory.Length; + } + + /// + /// Clears all fields in preparation to recycle this instance. + /// + internal void ResetMemory(ArrayPool? arrayPool) + { + this.ClearReferences(this.Start, this.End - this.Start); + this.Memory = default; + this.Next = null; + this.RunningIndex = 0; + this.Start = 0; + this.End = 0; + if (this.array != null) + { + arrayPool!.Return(this.array); + this.array = null; + } + else + { + this.MemoryOwner?.Dispose(); + this.MemoryOwner = null; + } + } + + /// + /// Adds a new segment after this one. + /// + /// The next segment in the linked list. + internal void SetNext(SequenceSegment segment) + { + this.Next = segment; + segment.RunningIndex = this.RunningIndex + this.Start + this.Length; + + // Trim any slack on this segment. + if (!this.IsForeignMemory) + { + // When setting Memory, we start with index 0 instead of this.Start because + // the first segment has an explicit index set anyway, + // and we don't want to double-count it here. + this.Memory = this.AvailableMemory.Slice(0, this.Start + this.Length); + } + } + + /// + /// Commits more elements as written in this segment. + /// + /// The number of elements written. + internal void Advance(int count) + { + if (!(count >= 0 && this.End + count <= this.Memory.Length)) + throw new ArgumentOutOfRangeException(nameof(count)); + + this.End += count; + } + + /// + /// Removes some elements from the start of this segment. + /// + /// The number of elements to ignore from the start of the underlying array. + internal void AdvanceTo(int offset) + { + Debug.Assert(offset >= this.Start, "Trying to rewind."); + this.ClearReferences(this.Start, offset - this.Start); + this.Start = offset; + } + + private void ClearReferences(int startIndex, int length) + { + // Clear the array to allow the objects to be GC'd. + // Reference types need to be cleared. Value types can be structs with reference type members too, so clear everything. + if (MayContainReferences) + { + this.AvailableMemory.Span.Slice(startIndex, length).Clear(); + } + } + } + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/Polyfill/NullableAttributes.cs b/src/Linux/Tmds.DBus.Protocol/Polyfill/NullableAttributes.cs new file mode 100644 index 0000000000..3cd6ed4efc --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Polyfill/NullableAttributes.cs @@ -0,0 +1,144 @@ +namespace System.Diagnostics.CodeAnalysis +{ +#if NETSTANDARD2_0 + + /// Specifies that null is allowed as an input even if the corresponding type disallows it. + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property, Inherited = false)] + internal sealed class AllowNullAttribute : Attribute { } + + /// Specifies that null is disallowed as an input even if the corresponding type allows it. + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property, Inherited = false)] + internal sealed class DisallowNullAttribute : Attribute { } + + /// Specifies that an output may be null even if the corresponding type disallows it. + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue, Inherited = false)] + internal sealed class MaybeNullAttribute : Attribute { } + + /// Specifies that an output will not be null even if the corresponding type allows it. + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue, Inherited = false)] + internal sealed class NotNullAttribute : Attribute { } + + /// Specifies that when a method returns , the parameter may be null even if the corresponding type disallows it. + [AttributeUsage(AttributeTargets.Parameter, Inherited = false)] + internal sealed class MaybeNullWhenAttribute : Attribute + { + /// Initializes the attribute with the specified return value condition. + /// + /// The return value condition. If the method returns this value, the associated parameter may be null. + /// + public MaybeNullWhenAttribute(bool returnValue) => ReturnValue = returnValue; + + /// Gets the return value condition. + public bool ReturnValue { get; } + } + + /// Specifies that when a method returns , the parameter will not be null even if the corresponding type allows it. + [AttributeUsage(AttributeTargets.Parameter, Inherited = false)] + internal sealed class NotNullWhenAttribute : Attribute + { + /// Initializes the attribute with the specified return value condition. + /// + /// The return value condition. If the method returns this value, the associated parameter will not be null. + /// + public NotNullWhenAttribute(bool returnValue) => ReturnValue = returnValue; + + /// Gets the return value condition. + public bool ReturnValue { get; } + } + + /// Specifies that the output will be non-null if the named parameter is non-null. + [AttributeUsage(AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue, AllowMultiple = true, Inherited = false)] + internal sealed class NotNullIfNotNullAttribute : Attribute + { + /// Initializes the attribute with the associated parameter name. + /// + /// The associated parameter name. The output will be non-null if the argument to the parameter specified is non-null. + /// + public NotNullIfNotNullAttribute(string parameterName) => ParameterName = parameterName; + + /// Gets the associated parameter name. + public string ParameterName { get; } + } + + /// Applied to a method that will never return under any circumstance. + [AttributeUsage(AttributeTargets.Method, Inherited = false)] + internal sealed class DoesNotReturnAttribute : Attribute { } + + /// Specifies that the method will not return if the associated Boolean parameter is passed the specified value. + [AttributeUsage(AttributeTargets.Parameter, Inherited = false)] + internal sealed class DoesNotReturnIfAttribute : Attribute + { + /// Initializes the attribute with the specified parameter value. + /// + /// The condition parameter value. Code after the method will be considered unreachable by diagnostics if the argument to + /// the associated parameter matches this value. + /// + public DoesNotReturnIfAttribute(bool parameterValue) => ParameterValue = parameterValue; + + /// Gets the condition parameter value. + public bool ParameterValue { get; } + } + +#endif + +#if !NETCOREAPP || NETCOREAPP3_1 + + /// Specifies that the method or property will ensure that the listed field and property members have not-null values. + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, Inherited = false, AllowMultiple = true)] + internal sealed class MemberNotNullAttribute : Attribute + { + /// Initializes the attribute with a field or property member. + /// + /// The field or property member that is promised to be not-null. + /// + public MemberNotNullAttribute(string member) => Members = new[] { member }; + + /// Initializes the attribute with the list of field and property members. + /// + /// The list of field and property members that are promised to be not-null. + /// + public MemberNotNullAttribute(params string[] members) => Members = members; + + /// Gets field or property member names. + public string[] Members { get; } + } + + /// Specifies that the method or property will ensure that the listed field and property members have not-null values when returning with the specified return value condition. + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, Inherited = false, AllowMultiple = true)] + internal sealed class MemberNotNullWhenAttribute : Attribute + { + /// Initializes the attribute with the specified return value condition and a field or property member. + /// + /// The return value condition. If the method returns this value, the associated parameter will not be null. + /// + /// + /// The field or property member that is promised to be not-null. + /// + public MemberNotNullWhenAttribute(bool returnValue, string member) + { + ReturnValue = returnValue; + Members = new[] { member }; + } + + /// Initializes the attribute with the specified return value condition and list of field and property members. + /// + /// The return value condition. If the method returns this value, the associated parameter will not be null. + /// + /// + /// The list of field and property members that are promised to be not-null. + /// + public MemberNotNullWhenAttribute(bool returnValue, params string[] members) + { + ReturnValue = returnValue; + Members = members; + } + + /// Gets the return value condition. + public bool ReturnValue { get; } + + /// Gets field or property member names. + public string[] Members { get; } + } + +#endif +} diff --git a/src/Linux/Tmds.DBus.Protocol/Polyfill/RequiresUnreferencedCodeAttribute.cs b/src/Linux/Tmds.DBus.Protocol/Polyfill/RequiresUnreferencedCodeAttribute.cs new file mode 100644 index 0000000000..678d9266ed --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Polyfill/RequiresUnreferencedCodeAttribute.cs @@ -0,0 +1,23 @@ +#if !NET5_0_OR_GREATER + +namespace System.Diagnostics.CodeAnalysis; + +[System.AttributeUsage( + System.AttributeTargets.Method | + System.AttributeTargets.Constructor | + System.AttributeTargets.Class, Inherited = false)] +[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] +[System.Diagnostics.Conditional("MULTI_TARGETING_SUPPORT_ATTRIBUTES")] +internal sealed class RequiresUnreferencedCodeAttribute : System.Attribute +{ + public RequiresUnreferencedCodeAttribute(string message) + { + Message = message; + } + + public string Message { get; } + + public string? Url { get; set; } +} + +#endif \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Polyfill/SequenceReader.cs b/src/Linux/Tmds.DBus.Protocol/Polyfill/SequenceReader.cs new file mode 100644 index 0000000000..16da4a4b87 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Polyfill/SequenceReader.cs @@ -0,0 +1,417 @@ +// // Copied from https://github.com/dotnet/runtime/raw/cf5b231fcbea483df3b081939b422adfb6fd486a/src/libraries/System.Memory/src/System/Buffers/SequenceReader.cs +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +#if NETSTANDARD2_0 + +using System.Diagnostics; +using System.Runtime.CompilerServices; + +namespace System.Buffers +{ + /// + /// Provides methods for reading binary and text data out of a with a focus on performance and minimal or zero heap allocations. + /// + /// The type of element stored by the . + internal ref partial struct SequenceReader where T : unmanaged, IEquatable + { + private SequencePosition _currentPosition; + private SequencePosition _nextPosition; + private bool _moreData; + private readonly long _length; + + /// + /// Create a over the given . + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public SequenceReader(ReadOnlySequence sequence) + { + CurrentSpanIndex = 0; + Consumed = 0; + Sequence = sequence; + _currentPosition = sequence.Start; + _length = -1; + + var first = sequence.First.Span; + _nextPosition = sequence.GetPosition(first.Length); + CurrentSpan = first; + + _moreData = first.Length > 0; + + if (!_moreData && !sequence.IsSingleSegment) + { + _moreData = true; + GetNextSpan(); + } + } + + /// + /// True when there is no more data in the . + /// + public readonly bool End => !_moreData; + + /// + /// The underlying for the reader. + /// + public readonly ReadOnlySequence Sequence { get; } + + /// + /// The current position in the . + /// + public readonly SequencePosition Position + => Sequence.GetPosition(CurrentSpanIndex, _currentPosition); + + /// + /// The current segment in the as a span. + /// + public ReadOnlySpan CurrentSpan { readonly get; private set; } + + /// + /// The index in the . + /// + public int CurrentSpanIndex { readonly get; private set; } + + /// + /// The unread portion of the . + /// + public readonly ReadOnlySpan UnreadSpan + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => CurrentSpan.Slice(CurrentSpanIndex); + } + + /// + /// The total number of 's processed by the reader. + /// + public long Consumed { readonly get; private set; } + + /// + /// Remaining 's in the reader's . + /// + public readonly long Remaining => Length - Consumed; + + /// + /// Count of in the reader's . + /// + public readonly long Length + { + get + { + if (_length < 0) + { + // Cast-away readonly to initialize lazy field + Unsafe.AsRef(_length) = Sequence.Length; + } + return _length; + } + } + + /// + /// Peeks at the next value without advancing the reader. + /// + /// The next value or default if at the end. + /// False if at the end of the reader. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public readonly bool TryPeek(out T value) + { + if (_moreData) + { + value = CurrentSpan[CurrentSpanIndex]; + return true; + } + else + { + value = default; + return false; + } + } + + /// + /// Read the next value and advance the reader. + /// + /// The next value or default if at the end. + /// False if at the end of the reader. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool TryRead(out T value) + { + if (End) + { + value = default; + return false; + } + + value = CurrentSpan[CurrentSpanIndex]; + CurrentSpanIndex++; + Consumed++; + + if (CurrentSpanIndex >= CurrentSpan.Length) + { + GetNextSpan(); + } + + return true; + } + + /// + /// Move the reader back the specified number of items. + /// + /// + /// Thrown if trying to rewind a negative amount or more than . + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Rewind(long count) + { + if ((ulong)count > (ulong)Consumed) + { + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(count)); + } + + Consumed -= count; + + if (CurrentSpanIndex >= count) + { + CurrentSpanIndex -= (int)count; + _moreData = true; + } + else + { + // Current segment doesn't have enough data, scan backward through segments + RetreatToPreviousSpan(Consumed); + } + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void RetreatToPreviousSpan(long consumed) + { + ResetReader(); + Advance(consumed); + } + + private void ResetReader() + { + CurrentSpanIndex = 0; + Consumed = 0; + _currentPosition = Sequence.Start; + _nextPosition = _currentPosition; + + if (Sequence.TryGet(ref _nextPosition, out ReadOnlyMemory memory, advance: true)) + { + _moreData = true; + + if (memory.Length == 0) + { + CurrentSpan = default; + // No data in the first span, move to one with data + GetNextSpan(); + } + else + { + CurrentSpan = memory.Span; + } + } + else + { + // No data in any spans and at end of sequence + _moreData = false; + CurrentSpan = default; + } + } + + /// + /// Get the next segment with available data, if any. + /// + private void GetNextSpan() + { + if (!Sequence.IsSingleSegment) + { + SequencePosition previousNextPosition = _nextPosition; + while (Sequence.TryGet(ref _nextPosition, out ReadOnlyMemory memory, advance: true)) + { + _currentPosition = previousNextPosition; + if (memory.Length > 0) + { + CurrentSpan = memory.Span; + CurrentSpanIndex = 0; + return; + } + else + { + CurrentSpan = default; + CurrentSpanIndex = 0; + previousNextPosition = _nextPosition; + } + } + } + _moreData = false; + } + + /// + /// Move the reader ahead the specified number of items. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Advance(long count) + { + const long TooBigOrNegative = unchecked((long)0xFFFFFFFF80000000); + if ((count & TooBigOrNegative) == 0 && CurrentSpan.Length - CurrentSpanIndex > (int)count) + { + CurrentSpanIndex += (int)count; + Consumed += count; + } + else + { + // Can't satisfy from the current span + AdvanceToNextSpan(count); + } + } + + /// + /// Unchecked helper to avoid unnecessary checks where you know count is valid. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void AdvanceCurrentSpan(long count) + { + Debug.Assert(count >= 0); + + Consumed += count; + CurrentSpanIndex += (int)count; + if (CurrentSpanIndex >= CurrentSpan.Length) + GetNextSpan(); + } + + /// + /// Only call this helper if you know that you are advancing in the current span + /// with valid count and there is no need to fetch the next one. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void AdvanceWithinSpan(long count) + { + Debug.Assert(count >= 0); + + Consumed += count; + CurrentSpanIndex += (int)count; + + Debug.Assert(CurrentSpanIndex < CurrentSpan.Length); + } + + private void AdvanceToNextSpan(long count) + { + if (count < 0) + { + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(count)); + } + + Consumed += count; + while (_moreData) + { + int remaining = CurrentSpan.Length - CurrentSpanIndex; + + if (remaining > count) + { + CurrentSpanIndex += (int)count; + count = 0; + break; + } + + // As there may not be any further segments we need to + // push the current index to the end of the span. + CurrentSpanIndex += remaining; + count -= remaining; + Debug.Assert(count >= 0); + + GetNextSpan(); + + if (count == 0) + { + break; + } + } + + if (count != 0) + { + // Not enough data left- adjust for where we actually ended and throw + Consumed -= count; + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(count)); + } + } + + /// + /// Copies data from the current to the given span if there + /// is enough data to fill it. + /// + /// + /// This API is used to copy a fixed amount of data out of the sequence if possible. It does not advance + /// the reader. To look ahead for a specific stream of data can be used. + /// + /// Destination span to copy to. + /// True if there is enough data to completely fill the span. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public readonly bool TryCopyTo(Span destination) + { + // This API doesn't advance to facilitate conditional advancement based on the data returned. + // We don't provide an advance option to allow easier utilizing of stack allocated destination spans. + // (Because we can make this method readonly we can guarantee that we won't capture the span.) + + ReadOnlySpan firstSpan = UnreadSpan; + if (firstSpan.Length >= destination.Length) + { + firstSpan.Slice(0, destination.Length).CopyTo(destination); + return true; + } + + // Not enough in the current span to satisfy the request, fall through to the slow path + return TryCopyMultisegment(destination); + } + + internal readonly bool TryCopyMultisegment(Span destination) + { + // If we don't have enough to fill the requested buffer, return false + if (Remaining < destination.Length) + return false; + + ReadOnlySpan firstSpan = UnreadSpan; + Debug.Assert(firstSpan.Length < destination.Length); + firstSpan.CopyTo(destination); + int copied = firstSpan.Length; + + SequencePosition next = _nextPosition; + while (Sequence.TryGet(ref next, out ReadOnlyMemory nextSegment, true)) + { + if (nextSegment.Length > 0) + { + ReadOnlySpan nextSpan = nextSegment.Span; + int toCopy = Math.Min(nextSpan.Length, destination.Length - copied); + nextSpan.Slice(0, toCopy).CopyTo(destination.Slice(copied)); + copied += toCopy; + if (copied >= destination.Length) + { + break; + } + } + } + + return true; + } + + static class ThrowHelper + { + public static void ThrowArgumentOutOfRangeException(string name) => throw new ArgumentOutOfRangeException(name); + } + } +} + +#else + +using System.Buffers; +using System.Runtime.CompilerServices; + +#pragma warning disable RS0026 +#pragma warning disable RS0016 +#pragma warning disable RS0041 +[assembly: TypeForwardedTo(typeof(SequenceReader<>))] +#pragma warning restore RS0041 +#pragma warning restore RS0016 +#pragma warning restore RS0026 + +#endif diff --git a/src/Linux/Tmds.DBus.Protocol/Polyfill/SequenceReaderExtensions.cs b/src/Linux/Tmds.DBus.Protocol/Polyfill/SequenceReaderExtensions.cs new file mode 100644 index 0000000000..1e7d3e8f35 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Polyfill/SequenceReaderExtensions.cs @@ -0,0 +1,194 @@ +// // Copied from https://raw.githubusercontent.com/dotnet/runtime/cf5b231fcbea483df3b081939b422adfb6fd486a/src/libraries/System.Memory/src/System/Buffers/SequenceReaderExtensions.Binary.cs +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +#if NETSTANDARD2_0 + +using System.Buffers.Binary; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace System.Buffers +{ + /// + /// Provides extended functionality for the class that allows reading of endian specific numeric values from binary data. + /// + internal static partial class SequenceReaderExtensions + { + /// + /// Try to read the given type out of the buffer if possible. Warning: this is dangerous to use with arbitrary + /// structs- see remarks for full details. + /// + /// + /// IMPORTANT: The read is a straight copy of bits. If a struct depends on specific state of it's members to + /// behave correctly this can lead to exceptions, etc. If reading endian specific integers, use the explicit + /// overloads such as + /// + /// + /// True if successful. will be default if failed (due to lack of space). + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static unsafe bool TryRead(ref this SequenceReader reader, out T value) where T : unmanaged + { + ReadOnlySpan span = reader.UnreadSpan; + if (span.Length < sizeof(T)) + return TryReadMultisegment(ref reader, out value); + + value = Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(span)); + reader.Advance(sizeof(T)); + return true; + } + + private static unsafe bool TryReadMultisegment(ref SequenceReader reader, out T value) where T : unmanaged + { + Debug.Assert(reader.UnreadSpan.Length < sizeof(T)); + + // Not enough data in the current segment, try to peek for the data we need. + T buffer = default; + Span tempSpan = new Span(&buffer, sizeof(T)); + + if (!reader.TryCopyTo(tempSpan)) + { + value = default; + return false; + } + + value = Unsafe.ReadUnaligned(ref MemoryMarshal.GetReference(tempSpan)); + reader.Advance(sizeof(T)); + return true; + } + + /// + /// Reads an as little endian. + /// + /// False if there wasn't enough data for an . + public static bool TryReadLittleEndian(ref this SequenceReader reader, out short value) + { + if (BitConverter.IsLittleEndian) + { + return reader.TryRead(out value); + } + + return TryReadReverseEndianness(ref reader, out value); + } + + /// + /// Reads an as big endian. + /// + /// False if there wasn't enough data for an . + public static bool TryReadBigEndian(ref this SequenceReader reader, out short value) + { + if (!BitConverter.IsLittleEndian) + { + return reader.TryRead(out value); + } + + return TryReadReverseEndianness(ref reader, out value); + } + + private static bool TryReadReverseEndianness(ref SequenceReader reader, out short value) + { + if (reader.TryRead(out value)) + { + value = BinaryPrimitives.ReverseEndianness(value); + return true; + } + + return false; + } + + /// + /// Reads an as little endian. + /// + /// False if there wasn't enough data for an . + public static bool TryReadLittleEndian(ref this SequenceReader reader, out int value) + { + if (BitConverter.IsLittleEndian) + { + return reader.TryRead(out value); + } + + return TryReadReverseEndianness(ref reader, out value); + } + + /// + /// Reads an as big endian. + /// + /// False if there wasn't enough data for an . + public static bool TryReadBigEndian(ref this SequenceReader reader, out int value) + { + if (!BitConverter.IsLittleEndian) + { + return reader.TryRead(out value); + } + + return TryReadReverseEndianness(ref reader, out value); + } + + private static bool TryReadReverseEndianness(ref SequenceReader reader, out int value) + { + if (reader.TryRead(out value)) + { + value = BinaryPrimitives.ReverseEndianness(value); + return true; + } + + return false; + } + + /// + /// Reads an as little endian. + /// + /// False if there wasn't enough data for an . + public static bool TryReadLittleEndian(ref this SequenceReader reader, out long value) + { + if (BitConverter.IsLittleEndian) + { + return reader.TryRead(out value); + } + + return TryReadReverseEndianness(ref reader, out value); + } + + /// + /// Reads an as big endian. + /// + /// False if there wasn't enough data for an . + public static bool TryReadBigEndian(ref this SequenceReader reader, out long value) + { + if (!BitConverter.IsLittleEndian) + { + return reader.TryRead(out value); + } + + return TryReadReverseEndianness(ref reader, out value); + } + + private static bool TryReadReverseEndianness(ref SequenceReader reader, out long value) + { + if (reader.TryRead(out value)) + { + value = BinaryPrimitives.ReverseEndianness(value); + return true; + } + + return false; + } + } +} + +#else + +using System.Buffers; +using System.Runtime.CompilerServices; + +#pragma warning disable RS0016 +#pragma warning disable RS0041 +[assembly: TypeForwardedTo(typeof(SequenceReaderExtensions))] +#pragma warning restore RS0041 +#pragma warning restore RS0016 + +#endif diff --git a/src/Linux/Tmds.DBus.Protocol/Polyfill/UnconditionalSuppressMessageAttribute.cs b/src/Linux/Tmds.DBus.Protocol/Polyfill/UnconditionalSuppressMessageAttribute.cs new file mode 100644 index 0000000000..e3a50e2756 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Polyfill/UnconditionalSuppressMessageAttribute.cs @@ -0,0 +1,32 @@ +#if !NET5_0_OR_GREATER + +namespace System.Diagnostics.CodeAnalysis; + +[ExcludeFromCodeCoverage] +[DebuggerNonUserCode] +[AttributeUsage( + validOn: AttributeTargets.All, + Inherited = false, + AllowMultiple = true)] +sealed class UnconditionalSuppressMessageAttribute : + Attribute +{ + public UnconditionalSuppressMessageAttribute(string category, string checkId) + { + Category = category; + CheckId = checkId; + } + + public string Category { get; } + + public string CheckId { get; } + + public string? Scope { get; set; } + + public string? Target { get; set; } + + public string? MessageId { get; set; } + + public string? Justification { get; set; } +} +#endif \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/ProtocolConstants.cs b/src/Linux/Tmds.DBus.Protocol/ProtocolConstants.cs new file mode 100644 index 0000000000..ab935f6bcd --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/ProtocolConstants.cs @@ -0,0 +1,86 @@ +namespace Tmds.DBus.Protocol; + +static class ProtocolConstants +{ + public const int MaxSignatureLength = 256; + + // note: C# compiler treats these as static data. + public static ReadOnlySpan ByteSignature => new byte[] { (byte)'y' }; + public static ReadOnlySpan BooleanSignature => new byte[] { (byte)'b' }; + public static ReadOnlySpan Int16Signature => new byte[] { (byte)'n' }; + public static ReadOnlySpan UInt16Signature => new byte[] { (byte)'q' }; + public static ReadOnlySpan Int32Signature => new byte[] { (byte)'i' }; + public static ReadOnlySpan UInt32Signature => new byte[] { (byte)'u' }; + public static ReadOnlySpan Int64Signature => new byte[] { (byte)'x' }; + public static ReadOnlySpan UInt64Signature => new byte[] { (byte)'t' }; + public static ReadOnlySpan DoubleSignature => new byte[] { (byte)'d' }; + public static ReadOnlySpan UnixFdSignature => new byte[] { (byte)'h' }; + public static ReadOnlySpan StringSignature => new byte[] { (byte)'s' }; + public static ReadOnlySpan ObjectPathSignature => new byte[] { (byte)'o' }; + public static ReadOnlySpan SignatureSignature => new byte[] { (byte)'g' }; + + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int GetTypeAlignment(DBusType type) + { + switch (type) + { + case DBusType.Byte: return 1; + case DBusType.Bool: return 4; + case DBusType.Int16: return 2; + case DBusType.UInt16: return 2; + case DBusType.Int32: return 4; + case DBusType.UInt32: return 4; + case DBusType.Int64: return 8; + case DBusType.UInt64: return 8; + case DBusType.Double: return 8; + case DBusType.String: return 4; + case DBusType.ObjectPath: return 4; + case DBusType.Signature: return 4; + case DBusType.Array: return 4; + case DBusType.Struct: return 8; + case DBusType.Variant: return 1; + case DBusType.DictEntry: return 8; + case DBusType.UnixFd: return 4; + default: return 1; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int GetFixedTypeLength(DBusType type) + { + switch (type) + { + case DBusType.Byte: return 1; + case DBusType.Bool: return 4; + case DBusType.Int16: return 2; + case DBusType.UInt16: return 2; + case DBusType.Int32: return 4; + case DBusType.UInt32: return 4; + case DBusType.Int64: return 8; + case DBusType.UInt64: return 8; + case DBusType.Double: return 8; + case DBusType.UnixFd: return 4; + default: return 0; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Align(int offset, DBusType type) + { + return offset + GetPadding(offset, type); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int GetPadding(int offset, DBusType type) + { + int alignment = GetTypeAlignment(type); + return GetPadding(offset ,alignment); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int GetPadding(int offset, int alignment) + { + return (~offset + 1) & (alignment - 1); + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/ProtocolException.cs b/src/Linux/Tmds.DBus.Protocol/ProtocolException.cs new file mode 100644 index 0000000000..745bd8bda2 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/ProtocolException.cs @@ -0,0 +1,7 @@ +namespace Tmds.DBus.Protocol; + +public class ProtocolException : Exception +{ + public ProtocolException(string message) : base(message) + { } +} diff --git a/src/Linux/Tmds.DBus.Protocol/Reader.Array.cs b/src/Linux/Tmds.DBus.Protocol/Reader.Array.cs new file mode 100644 index 0000000000..f1fffe3b57 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Reader.Array.cs @@ -0,0 +1,166 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct Reader +{ + public byte[] ReadArrayOfByte() + => ReadArrayOfNumeric(); + + public bool[] ReadArrayOfBool() + => ReadArrayOfT(); + + public short[] ReadArrayOfInt16() + => ReadArrayOfNumeric(); + + public ushort[] ReadArrayOfUInt16() + => ReadArrayOfNumeric(); + + public int[] ReadArrayOfInt32() + => ReadArrayOfNumeric(); + + public uint[] ReadArrayOfUInt32() + => ReadArrayOfNumeric(); + + public long[] ReadArrayOfInt64() + => ReadArrayOfNumeric(); + + public ulong[] ReadArrayOfUInt64() + => ReadArrayOfNumeric(); + + public double[] ReadArrayOfDouble() + => ReadArrayOfNumeric(); + + public string[] ReadArrayOfString() + => ReadArrayOfT(); + + public ObjectPath[] ReadArrayOfObjectPath() + => ReadArrayOfT(); + + public Signature[] ReadArrayOfSignature() + => ReadArrayOfT(); + + public VariantValue[] ReadArrayOfVariantValue() + => ReadArrayOfT(); + + public T[] ReadArrayOfHandle<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T>() where T : SafeHandle + => ReadArrayOfT(); + + [RequiresUnreferencedCode(Strings.UseNonGenericReadArray)] + [Obsolete(Strings.UseNonGenericReadArrayObsolete)] + public T[] ReadArray<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T>() + { + if (typeof(T) == typeof(byte)) + { + return (T[])(object)ReadArrayOfNumeric(); + } + else if (typeof(T) == typeof(short)) + { + return (T[])(object)ReadArrayOfNumeric(); + } + else if (typeof(T) == typeof(ushort)) + { + return (T[])(object)ReadArrayOfNumeric(); + } + else if (typeof(T) == typeof(int)) + { + return (T[])(object)ReadArrayOfNumeric(); + } + else if (typeof(T) == typeof(uint)) + { + return (T[])(object)ReadArrayOfNumeric(); + } + else if (typeof(T) == typeof(long)) + { + return (T[])(object)ReadArrayOfNumeric(); + } + else if (typeof(T) == typeof(ulong)) + { + return (T[])(object)ReadArrayOfNumeric(); + } + else if (typeof(T) == typeof(double)) + { + return (T[])(object)ReadArrayOfNumeric(); + } + else + { + return ReadArrayOfT(); + } + } + + private T[] ReadArrayOfT<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T>() + { + List items = new(); + ArrayEnd arrayEnd = ReadArrayStart(TypeModel.GetTypeAlignment()); + while (HasNext(arrayEnd)) + { + items.Add(Read()); + } + return items.ToArray(); + } + + private unsafe T[] ReadArrayOfNumeric() where T : unmanaged + { + int length = ReadInt32(); + if (sizeof(T) > 4) + { + AlignReader(sizeof(T)); + } + T[] array = new T[length / sizeof(T)]; + bool dataRead = _reader.TryCopyTo(MemoryMarshal.AsBytes(array.AsSpan())); + if (!dataRead) + { + ThrowHelper.ThrowIndexOutOfRange(); + } + _reader.Advance(sizeof(T) * array.Length); + if (sizeof(T) > 1 && ReverseEndianness) + { +#if NET8_0_OR_GREATER + if (sizeof(T) == 2) + { + var span = MemoryMarshal.Cast(array.AsSpan()); + BinaryPrimitives.ReverseEndianness(span, span); + } + else if (sizeof(T) == 4) + { + var span = MemoryMarshal.Cast(array.AsSpan()); + BinaryPrimitives.ReverseEndianness(span, span); + } + else if (sizeof(T) == 8) + { + Span span = MemoryMarshal.Cast(array.AsSpan()); + BinaryPrimitives.ReverseEndianness(span, span); + } +#else + Span span = array.AsSpan(); + for (int i = 0; i < span.Length; i++) + { + if (sizeof(T) == 2) + { + span[i] = (T)(object)BinaryPrimitives.ReverseEndianness((short)(object)span[i]); + } + else if (sizeof(T) == 4) + { + span[i] = (T)(object)BinaryPrimitives.ReverseEndianness((int)(object)span[i]); + } + else if (typeof(T) == typeof(double)) + { + span[i] = (T)(object)ReverseDoubleEndianness((double)(object)span[i]); + } + else if (sizeof(T) == 8) + { + span[i] = (T)(object)BinaryPrimitives.ReverseEndianness((long)(object)span[i]); + } + } +#endif + } + return array; + +#if !NET8_0_OR_GREATER + static double ReverseDoubleEndianness(double d) + { + long l = *(long*)&d; + l = BinaryPrimitives.ReverseEndianness(l); + return *(double*)&d; + } +#endif + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/Reader.Basic.cs b/src/Linux/Tmds.DBus.Protocol/Reader.Basic.cs new file mode 100644 index 0000000000..c6e018c0c8 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Reader.Basic.cs @@ -0,0 +1,132 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct Reader +{ + public byte ReadByte() + { + if (!_reader.TryRead(out byte b)) + { + ThrowHelper.ThrowIndexOutOfRange(); + } + return b; + } + + public bool ReadBool() + { + return ReadInt32() != 0; + } + + public ushort ReadUInt16() + => (ushort)ReadInt16(); + + public short ReadInt16() + { + AlignReader(DBusType.Int16); + bool dataRead = _isBigEndian ? _reader.TryReadBigEndian(out short rv) : _reader.TryReadLittleEndian(out rv); + if (!dataRead) + { + ThrowHelper.ThrowIndexOutOfRange(); + } + return rv; + } + + public uint ReadUInt32() + => (uint)ReadInt32(); + + public int ReadInt32() + { + AlignReader(DBusType.Int32); + bool dataRead = _isBigEndian ? _reader.TryReadBigEndian(out int rv) : _reader.TryReadLittleEndian(out rv); + if (!dataRead) + { + ThrowHelper.ThrowIndexOutOfRange(); + } + return rv; + } + + public ulong ReadUInt64() + => (ulong)ReadInt64(); + + public long ReadInt64() + { + AlignReader(DBusType.Int64); + bool dataRead = _isBigEndian ? _reader.TryReadBigEndian(out long rv) : _reader.TryReadLittleEndian(out rv); + if (!dataRead) + { + ThrowHelper.ThrowIndexOutOfRange(); + } + return rv; + } + + public unsafe double ReadDouble() + { + double value; + *(long*)&value = ReadInt64(); + return value; + } + + public Utf8Span ReadSignature() + { + int length = ReadByte(); + return ReadSpan(length); + } + + public void ReadSignature(string expected) + { + ReadOnlySpan signature = ReadSignature().Span; + if (signature.Length != expected.Length) + { + ThrowHelper.ThrowUnexpectedSignature(signature, expected); + } + for (int i = 0; i < signature.Length; i++) + { + if (signature[i] != expected[i]) + { + ThrowHelper.ThrowUnexpectedSignature(signature, expected); + } + } + } + + public Utf8Span ReadObjectPathAsSpan() => ReadSpan(); + + public ObjectPath ReadObjectPath() => new ObjectPath(ReadString()); + + public ObjectPath ReadObjectPathAsString() => ReadString(); + + public Utf8Span ReadStringAsSpan() => ReadSpan(); + + public string ReadString() => Encoding.UTF8.GetString(ReadSpan()); + + public Signature ReadSignatureAsSignature() => new Signature(ReadSignature().ToString()); + + public string ReadSignatureAsString() => ReadSignature().ToString(); + + private ReadOnlySpan ReadSpan() + { + int length = (int)ReadUInt32(); + return ReadSpan(length); + } + + private ReadOnlySpan ReadSpan(int length) + { + var span = _reader.UnreadSpan; + if (span.Length >= length) + { + _reader.Advance(length + 1); + return span.Slice(0, length); + } + else + { + var buffer = new byte[length]; + if (!_reader.TryCopyTo(buffer)) + { + ThrowHelper.ThrowIndexOutOfRange(); + } + _reader.Advance(length + 1); + return new ReadOnlySpan(buffer); + } + } + + private bool ReverseEndianness + => BitConverter.IsLittleEndian != !_isBigEndian; +} diff --git a/src/Linux/Tmds.DBus.Protocol/Reader.Dictionary.cs b/src/Linux/Tmds.DBus.Protocol/Reader.Dictionary.cs new file mode 100644 index 0000000000..c1f04d41b0 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Reader.Dictionary.cs @@ -0,0 +1,48 @@ +namespace Tmds.DBus.Protocol; + + +// Using obsolete generic read members +#pragma warning disable CS0618 + +public ref partial struct Reader +{ + public ArrayEnd ReadDictionaryStart() + => ReadArrayStart(DBusType.Struct); + + // Read method for the common 'a{sv}' type. + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // It's safe to call ReadDictionary with these types. + public Dictionary ReadDictionaryOfStringToVariantValue() + => ReadDictionary(); + + [RequiresUnreferencedCode(Strings.UseNonGenericReadDictionary)] + [Obsolete(Strings.UseNonGenericReadDictionaryObsolete)] + public Dictionary ReadDictionary + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]TKey, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]TValue + > + () + where TKey : notnull + where TValue : notnull + => ReadDictionary(new Dictionary()); + + internal Dictionary ReadDictionary + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]TKey, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]TValue + > + (Dictionary dictionary) + where TKey : notnull + where TValue : notnull + { + ArrayEnd dictEnd = ReadDictionaryStart(); + while (HasNext(dictEnd)) + { + var key = Read(); + var value = Read(); + // Use the indexer to avoid throwing if the key is present multiple times. + dictionary[key] = value; + } + return dictionary; + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/Reader.Handle.cs b/src/Linux/Tmds.DBus.Protocol/Reader.Handle.cs new file mode 100644 index 0000000000..b0a6912da2 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Reader.Handle.cs @@ -0,0 +1,36 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct Reader +{ + public T? ReadHandle<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T>() where T : SafeHandle + => ReadHandleGeneric(); + + internal T? ReadHandleGeneric<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T>() + { + int idx = (int)ReadUInt32(); + if (idx >= _handleCount) + { + throw new IndexOutOfRangeException(); + } + if (_handles is not null) + { + return _handles.ReadHandleGeneric(idx); + } + return default(T); + } + + // note: The handle is still owned (i.e. Disposed) by the Message. + public IntPtr ReadHandleRaw() + { + int idx = (int)ReadUInt32(); + if (idx >= _handleCount) + { + throw new IndexOutOfRangeException(); + } + if (_handles is not null) + { + return _handles.ReadHandleRaw(idx); + } + return new IntPtr(-1); + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/Reader.ReadT.Dynamic.cs b/src/Linux/Tmds.DBus.Protocol/Reader.ReadT.Dynamic.cs new file mode 100644 index 0000000000..16568e862e --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Reader.ReadT.Dynamic.cs @@ -0,0 +1,1045 @@ +namespace Tmds.DBus.Protocol; + +// Code in this file is not trimmer friendly. +#pragma warning disable IL3050 +#pragma warning disable IL2055 +#pragma warning disable IL2091 +#pragma warning disable IL2026 +// Using obsolete generic read members +#pragma warning disable CS0618 + +public ref partial struct Reader +{ + interface ITypeReader + { } + + interface ITypeReader : ITypeReader + { + T Read(ref Reader reader); + } + + private T ReadDynamic() + { + Type type = typeof(T); + + if (type == typeof(object)) + { + Utf8Span signature = ReadSignature(); + type = DetermineVariantType(signature); + + if (type == typeof(byte)) + { + return (T)(object)ReadByte(); + } + else if (type == typeof(bool)) + { + return (T)(object)ReadBool(); + } + else if (type == typeof(short)) + { + return (T)(object)ReadInt16(); + } + else if (type == typeof(ushort)) + { + return (T)(object)ReadUInt16(); + } + else if (type == typeof(int)) + { + return (T)(object)ReadInt32(); + } + else if (type == typeof(uint)) + { + return (T)(object)ReadUInt32(); + } + else if (type == typeof(long)) + { + return (T)(object)ReadInt64(); + } + else if (type == typeof(ulong)) + { + return (T)(object)ReadUInt64(); + } + else if (type == typeof(double)) + { + return (T)(object)ReadDouble(); + } + else if (type == typeof(string)) + { + return (T)(object)ReadString(); + } + else if (type == typeof(ObjectPath)) + { + return (T)(object)ReadObjectPath(); + } + else if (type == typeof(Signature)) + { + return (T)(object)ReadSignatureAsSignature(); + } + else if (type == typeof(SafeHandle)) + { + return (T)(object)ReadHandle()!; + } + else if (type == typeof(VariantValue)) + { + return (T)(object)ReadVariantValue(); + } + } + + var typeReader = (ITypeReader)TypeReaders.GetTypeReader(type); + return typeReader.Read(ref this); + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL3050")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2055")] + private static Type DetermineVariantType(Utf8Span signature) + { + Func map = (dbusType, innerTypes) => + { + switch (dbusType) + { + case DBusType.Byte: return typeof(byte); + case DBusType.Bool: return typeof(bool); + case DBusType.Int16: return typeof(short); + case DBusType.UInt16: return typeof(ushort); + case DBusType.Int32: return typeof(int); + case DBusType.UInt32: return typeof(uint); + case DBusType.Int64: return typeof(long); + case DBusType.UInt64: return typeof(ulong); + case DBusType.Double: return typeof(double); + case DBusType.String: return typeof(string); + case DBusType.ObjectPath: return typeof(ObjectPath); + case DBusType.Signature: return typeof(Signature); + case DBusType.UnixFd: return typeof(SafeHandle); + case DBusType.Array: return innerTypes[0].MakeArrayType(); + case DBusType.DictEntry: return typeof(Dictionary<,>).MakeGenericType(innerTypes); + case DBusType.Struct: + switch (innerTypes.Length) + { + case 1: return typeof(ValueTuple<>).MakeGenericType(innerTypes); + case 2: return typeof(ValueTuple<,>).MakeGenericType(innerTypes); + case 3: return typeof(ValueTuple<,,>).MakeGenericType(innerTypes); + case 4: return typeof(ValueTuple<,,,>).MakeGenericType(innerTypes); + case 5: return typeof(ValueTuple<,,,,>).MakeGenericType(innerTypes); + case 6: return typeof(ValueTuple<,,,,,>).MakeGenericType(innerTypes); + case 7: return typeof(ValueTuple<,,,,,,>).MakeGenericType(innerTypes); + case 8: + case 9: + case 10: + var types = new Type[8]; + innerTypes.AsSpan(0, 7).CopyTo(types); + types[7] = innerTypes.Length switch + { + 8 => typeof(ValueTuple<>).MakeGenericType(new[] { innerTypes[7] }), + 9 => typeof(ValueTuple<,>).MakeGenericType(new[] { innerTypes[7], innerTypes[8] }), + 10 => typeof(ValueTuple<,,>).MakeGenericType(new[] { innerTypes[7], innerTypes[8], innerTypes[9] }), + _ => null! + }; + return typeof(ValueTuple<,,,,,,,>).MakeGenericType(types); + } + break; + } + return typeof(object); + }; + + // TODO (perf): add caching. + return SignatureReader.Transform(signature, map); + } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddArrayTypeReader() + where T : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddKeyValueArrayTypeReader() + where TKey : notnull + where TValue : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddDictionaryTypeReader() + where TKey : notnull + where TValue : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddValueTupleTypeReader() + where T1 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddTupleTypeReader() + where T1 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddValueTupleTypeReader() + where T1 : notnull + where T2 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddTupleTypeReader() + where T1 : notnull + where T2 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddValueTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddValueTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddValueTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddValueTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddValueTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddValueTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddValueTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddValueTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + where T10 : notnull + { } + + [Obsolete(Strings.AddTypeReaderMethodObsolete)] + public static void AddTupleTypeReader() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + where T10 : notnull + { } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL3050")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2091")] + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] + static class TypeReaders + { + private static readonly Dictionary _typeReaders = new(); + + public static ITypeReader GetTypeReader(Type type) + { + lock (_typeReaders) + { + if (_typeReaders.TryGetValue(type, out ITypeReader? reader)) + { + return reader; + } + reader = CreateReaderForType(type); + _typeReaders.Add(type, reader); + return reader; + } + } + + private static ITypeReader CreateReaderForType(Type type) + { + // Array + if (type.IsArray) + { + return CreateArrayTypeReader(type.GetElementType()!); + } + + // Dictionary<.> + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Dictionary<,>)) + { + Type keyType = type.GenericTypeArguments[0]; + Type valueType = type.GenericTypeArguments[1]; + return CreateDictionaryTypeReader(keyType, valueType); + } + + // Struct (ValueTuple) + if (type.IsGenericType && type.FullName!.StartsWith("System.ValueTuple")) + { + switch (type.GenericTypeArguments.Length) + { + case 1: return CreateValueTupleTypeReader(type.GenericTypeArguments[0]); + case 2: + return CreateValueTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1]); + case 3: + return CreateValueTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2]); + case 4: + return CreateValueTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3]); + case 5: + return CreateValueTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4]); + + case 6: + return CreateValueTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5]); + case 7: + return CreateValueTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6]); + case 8: + switch (type.GenericTypeArguments[7].GenericTypeArguments.Length) + { + case 1: + return CreateValueTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6], + type.GenericTypeArguments[7].GenericTypeArguments[0]); + case 2: + return CreateValueTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6], + type.GenericTypeArguments[7].GenericTypeArguments[0], + type.GenericTypeArguments[7].GenericTypeArguments[1]); + case 3: + return CreateValueTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6], + type.GenericTypeArguments[7].GenericTypeArguments[0], + type.GenericTypeArguments[7].GenericTypeArguments[1], + type.GenericTypeArguments[7].GenericTypeArguments[2]); + } + break; + } + } + // Struct (ValueTuple) + if (type.IsGenericType && type.FullName!.StartsWith("System.Tuple")) + { + switch (type.GenericTypeArguments.Length) + { + case 1: return CreateTupleTypeReader(type.GenericTypeArguments[0]); + case 2: + return CreateTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1]); + case 3: + return CreateTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2]); + case 4: + return CreateTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3]); + case 5: + return CreateTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4]); + case 6: + return CreateTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5]); + case 7: + return CreateTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6]); + case 8: + switch (type.GenericTypeArguments[7].GenericTypeArguments.Length) + { + case 1: + return CreateTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6], + type.GenericTypeArguments[7].GenericTypeArguments[0]); + case 2: + return CreateTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6], + type.GenericTypeArguments[7].GenericTypeArguments[0], + type.GenericTypeArguments[7].GenericTypeArguments[1]); + case 3: + return CreateTupleTypeReader(type.GenericTypeArguments[0], + type.GenericTypeArguments[1], + type.GenericTypeArguments[2], + type.GenericTypeArguments[3], + type.GenericTypeArguments[4], + type.GenericTypeArguments[5], + type.GenericTypeArguments[6], + type.GenericTypeArguments[7].GenericTypeArguments[0], + type.GenericTypeArguments[7].GenericTypeArguments[1], + type.GenericTypeArguments[7].GenericTypeArguments[2]); + } + break; + } + } + + ThrowNotSupportedType(type); + return default!; + } + + sealed class KeyValueArrayTypeReader : ITypeReader[]>, ITypeReader + where TKey : notnull + where TValue : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public KeyValuePair[] Read(ref Reader reader) + { + List> items = new(); + ArrayEnd arrayEnd = reader.ReadArrayStart(DBusType.Struct); + while (reader.HasNext(arrayEnd)) + { + TKey key = reader.Read(); + TValue value = reader.Read(); + items.Add(new KeyValuePair(key, value)); + } + return items.ToArray(); + } + } + + sealed class ArrayTypeReader : ITypeReader, ITypeReader + where T : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public T[] Read(ref Reader reader) + { + return reader.ReadArray(); + } + } + + private static ITypeReader CreateArrayTypeReader(Type elementType) + { + Type readerType; + if (elementType.IsGenericType && elementType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>)) + { + Type keyType = elementType.GenericTypeArguments[0]; + Type valueType = elementType.GenericTypeArguments[1]; + readerType = typeof(KeyValueArrayTypeReader<,>).MakeGenericType(new[] { keyType, valueType }); + } + else + { + readerType = typeof(ArrayTypeReader<>).MakeGenericType(new[] { elementType }); + } + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + sealed class DictionaryTypeReader : ITypeReader>, ITypeReader + where TKey : notnull + where TValue : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public Dictionary Read(ref Reader reader) + { + return reader.ReadDictionary(); + } + } + + private static ITypeReader CreateDictionaryTypeReader(Type keyType, Type valueType) + { + Type readerType = typeof(DictionaryTypeReader<,>).MakeGenericType(new[] { keyType, valueType }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + sealed class ValueTupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public ValueTuple Read(ref Reader reader) + { + return reader.ReadStruct(); + } + } + + sealed class TupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public Tuple Read(ref Reader reader) + { + return reader.ReadStructAsTuple(); + } + } + + private static ITypeReader CreateValueTupleTypeReader(Type type1) + { + Type readerType = typeof(ValueTupleTypeReader<>).MakeGenericType(new[] { type1 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + private static ITypeReader CreateTupleTypeReader(Type type1) + { + Type readerType = typeof(TupleTypeReader<>).MakeGenericType(new[] { type1 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + sealed class ValueTupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + where T2 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public ValueTuple Read(ref Reader reader) + { + return reader.ReadStruct(); + } + } + + sealed class TupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + where T2 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public Tuple Read(ref Reader reader) + { + return reader.ReadStructAsTuple(); + } + } + + private static ITypeReader CreateValueTupleTypeReader(Type type1, Type type2) + { + Type readerType = typeof(ValueTupleTypeReader<,>).MakeGenericType(new[] { type1, type2 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + private static ITypeReader CreateTupleTypeReader(Type type1, Type type2) + { + Type readerType = typeof(TupleTypeReader<,>).MakeGenericType(new[] { type1, type2 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + sealed class ValueTupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public ValueTuple Read(ref Reader reader) + { + return reader.ReadStruct(); + } + } + + sealed class TupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public Tuple Read(ref Reader reader) + { + return reader.ReadStructAsTuple(); + } + } + + private static ITypeReader CreateValueTupleTypeReader(Type type1, Type type2, Type type3) + { + Type readerType = typeof(ValueTupleTypeReader<,,>).MakeGenericType(new[] { type1, type2, type3 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + private static ITypeReader CreateTupleTypeReader(Type type1, Type type2, Type type3) + { + Type readerType = typeof(TupleTypeReader<,,>).MakeGenericType(new[] { type1, type2, type3 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + sealed class ValueTupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public ValueTuple Read(ref Reader reader) + { + return reader.ReadStruct(); + } + } + + sealed class TupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public Tuple Read(ref Reader reader) + { + return reader.ReadStructAsTuple(); + } + } + + private static ITypeReader CreateValueTupleTypeReader(Type type1, Type type2, Type type3, Type type4) + { + Type readerType = typeof(ValueTupleTypeReader<,,,>).MakeGenericType(new[] { type1, type2, type3, type4 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + private static ITypeReader CreateTupleTypeReader(Type type1, Type type2, Type type3, Type type4) + { + Type readerType = typeof(TupleTypeReader<,,,>).MakeGenericType(new[] { type1, type2, type3, type4 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + sealed class ValueTupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public ValueTuple Read(ref Reader reader) + { + return reader.ReadStruct(); + } + } + + sealed class TupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public Tuple Read(ref Reader reader) + { + return reader.ReadStructAsTuple(); + } + } + + private static ITypeReader CreateValueTupleTypeReader(Type type1, Type type2, Type type3, Type type4, Type type5) + { + Type readerType = typeof(ValueTupleTypeReader<,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + private static ITypeReader CreateTupleTypeReader(Type type1, Type type2, Type type3, Type type4, Type type5) + { + Type readerType = typeof(TupleTypeReader<,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type5 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + sealed class ValueTupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public ValueTuple Read(ref Reader reader) + { + return reader.ReadStruct(); + } + } + + sealed class TupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public Tuple Read(ref Reader reader) + { + return reader.ReadStructAsTuple(); + } + } + + private static ITypeReader CreateValueTupleTypeReader(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6) + { + Type readerType = typeof(ValueTupleTypeReader<,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + private static ITypeReader CreateTupleTypeReader(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6) + { + Type readerType = typeof(TupleTypeReader<,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + sealed class ValueTupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public ValueTuple Read(ref Reader reader) + { + return reader.ReadStruct(); + } + } + + sealed class TupleTypeReader : ITypeReader>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public Tuple Read(ref Reader reader) + { + return reader.ReadStructAsTuple(); + } + } + + private static ITypeReader CreateValueTupleTypeReader(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7) + { + Type readerType = typeof(ValueTupleTypeReader<,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + private static ITypeReader CreateTupleTypeReader(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7) + { + Type readerType = typeof(TupleTypeReader<,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + sealed class ValueTupleTypeReader : ITypeReader>>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public ValueTuple> Read(ref Reader reader) + { + return reader.ReadStruct(); + } + } + + sealed class TupleTypeReader : ITypeReader>>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public Tuple> Read(ref Reader reader) + { + return reader.ReadStructAsTuple(); + } + } + + private static ITypeReader CreateValueTupleTypeReader(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7, Type type8) + { + Type readerType = typeof(ValueTupleTypeReader<,,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7, type8 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + private static ITypeReader CreateTupleTypeReader(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7, Type type8) + { + Type readerType = typeof(TupleTypeReader<,,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7, type8 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + sealed class ValueTupleTypeReader : ITypeReader>>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public ValueTuple> Read(ref Reader reader) + { + return reader.ReadStruct(); + } + } + + sealed class TupleTypeReader : ITypeReader>>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public Tuple> Read(ref Reader reader) + { + return reader.ReadStructAsTuple(); + } + } + + private static ITypeReader CreateValueTupleTypeReader(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7, Type type8, Type type9) + { + Type readerType = typeof(ValueTupleTypeReader<,,,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7, type8, type9 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + private static ITypeReader CreateTupleTypeReader(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7, Type type8, Type type9) + { + Type readerType = typeof(TupleTypeReader<,,,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7, type8, type9 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + sealed class ValueTupleTypeReader : ITypeReader>>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + where T10 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public ValueTuple> Read(ref Reader reader) + { + return reader.ReadStruct(); + } + } + + sealed class TupleTypeReader : ITypeReader>>, ITypeReader + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + where T10 : notnull + { + object ITypeReader.Read(ref Reader reader) => Read(ref reader); + + public Tuple> Read(ref Reader reader) + { + return reader.ReadStructAsTuple(); + } + } + + private static ITypeReader CreateValueTupleTypeReader(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7, Type type8, Type type9, Type type10) + { + Type readerType = typeof(ValueTupleTypeReader<,,,,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7, type8, type9, type10 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + + private static ITypeReader CreateTupleTypeReader(Type type1, Type type2, Type type3, Type type4, Type type5, Type type6, Type type7, Type type8, Type type9, Type type10) + { + Type readerType = typeof(TupleTypeReader<,,,,,,,,,>).MakeGenericType(new[] { type1, type2, type3, type4, type5, type6, type7, type8, type9, type10 }); + return (ITypeReader)Activator.CreateInstance(readerType)!; + } + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/Reader.ReadT.cs b/src/Linux/Tmds.DBus.Protocol/Reader.ReadT.cs new file mode 100644 index 0000000000..0ca8de4ffc --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Reader.ReadT.cs @@ -0,0 +1,77 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct Reader +{ + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal T Read<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T>() + { + if (typeof(T) == typeof(byte)) + { + return (T)(object)ReadByte(); + } + else if (typeof(T) == typeof(bool)) + { + return (T)(object)ReadBool(); + } + else if (typeof(T) == typeof(short)) + { + return (T)(object)ReadInt16(); + } + else if (typeof(T) == typeof(ushort)) + { + return (T)(object)ReadUInt16(); + } + else if (typeof(T) == typeof(int)) + { + return (T)(object)ReadInt32(); + } + else if (typeof(T) == typeof(uint)) + { + return (T)(object)ReadUInt32(); + } + else if (typeof(T) == typeof(long)) + { + return (T)(object)ReadInt64(); + } + else if (typeof(T) == typeof(ulong)) + { + return (T)(object)ReadUInt64(); + } + else if (typeof(T) == typeof(double)) + { + return (T)(object)ReadDouble(); + } + else if (typeof(T) == typeof(string)) + { + return (T)(object)ReadString(); + } + else if (typeof(T) == typeof(ObjectPath)) + { + return (T)(object)ReadObjectPath(); + } + else if (typeof(T) == typeof(Signature)) + { + return (T)(object)ReadSignatureAsSignature(); + } + else if (typeof(T).IsAssignableTo(typeof(SafeHandle))) + { + return (T)(object)ReadHandleGeneric()!; + } + else if (typeof(T) == typeof(VariantValue)) + { + return (T)(object)ReadVariantValue(); + } + else if (Feature.IsDynamicCodeEnabled) + { + return ReadDynamic(); + } + + ThrowNotSupportedType(typeof(T)); + return default!; + } + + private static void ThrowNotSupportedType(Type type) + { + throw new NotSupportedException($"Cannot read type {type.FullName}"); + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/Reader.Struct.cs b/src/Linux/Tmds.DBus.Protocol/Reader.Struct.cs new file mode 100644 index 0000000000..412bc00e09 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Reader.Struct.cs @@ -0,0 +1,309 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct Reader +{ + [RequiresUnreferencedCode(Strings.UseNonGenericReadStruct)] + [Obsolete(Strings.UseNonGenericReadStructObsolete)] + public ValueTuple ReadStruct + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1 + >() + { + AlignStruct(); + return ValueTuple.Create(Read()); + } + + private Tuple ReadStructAsTuple + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1 + >() + { + AlignStruct(); + return Tuple.Create(Read()); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericReadStruct)] + [Obsolete(Strings.UseNonGenericReadStructObsolete)] + public ValueTuple ReadStruct + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2 + >() + where T1 : notnull + where T2 : notnull + { + AlignStruct(); + return ValueTuple.Create(Read(), Read()); + } + + private Tuple ReadStructAsTuple + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2 + >() + where T1 : notnull + where T2 : notnull + { + AlignStruct(); + return Tuple.Create(Read(), Read()); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericReadStruct)] + [Obsolete(Strings.UseNonGenericReadStructObsolete)] + public ValueTuple ReadStruct + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3 + >() + where T1 : notnull + where T2 : notnull + where T3 : notnull + { + AlignStruct(); + return ValueTuple.Create(Read(), Read(), Read()); + } + + private Tuple ReadStructAsTuple + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3 + >() + { + AlignStruct(); + return Tuple.Create(Read(), Read(), Read()); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericReadStruct)] + [Obsolete(Strings.UseNonGenericReadStructObsolete)] + public ValueTuple ReadStruct + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4 + >() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + { + AlignStruct(); + return ValueTuple.Create(Read(), Read(), Read(), Read()); + } + + private Tuple ReadStructAsTuple + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4 + >() + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + { + AlignStruct(); + return Tuple.Create(Read(), Read(), Read(), Read()); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericReadStruct)] + [Obsolete(Strings.UseNonGenericReadStructObsolete)] + public ValueTuple ReadStruct + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T5 + >() + { + AlignStruct(); + return ValueTuple.Create(Read(), Read(), Read(), Read(), Read()); + } + + private Tuple ReadStructAsTuple + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T5 + >() + { + AlignStruct(); + return Tuple.Create(Read(), Read(), Read(), Read(), Read()); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericReadStruct)] + [Obsolete(Strings.UseNonGenericReadStructObsolete)] + public ValueTuple ReadStruct + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T5, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T6 + >() + { + AlignStruct(); + return ValueTuple.Create(Read(), Read(), Read(), Read(), Read(), Read()); + } + + private Tuple ReadStructAsTuple + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T5, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T6 + >() + { + AlignStruct(); + return Tuple.Create(Read(), Read(), Read(), Read(), Read(), Read()); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericReadStruct)] + [Obsolete(Strings.UseNonGenericReadStructObsolete)] + public ValueTuple ReadStruct + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T5, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T6, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T7 + >() + { + AlignStruct(); + return ValueTuple.Create(Read(), Read(), Read(), Read(), Read(), Read(), Read()); + } + + private Tuple ReadStructAsTuple + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T5, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T6, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T7 + >() + { + AlignStruct(); + return Tuple.Create(Read(), Read(), Read(), Read(), Read(), Read(), Read()); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericReadStruct)] + [Obsolete(Strings.UseNonGenericReadStructObsolete)] + public ValueTuple> ReadStruct + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T5, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T6, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T7, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T8 + >() + { + AlignStruct(); + return ValueTuple.Create(Read(), Read(), Read(), Read(), Read(), Read(), Read(), Read()); + } + + private Tuple> ReadStructAsTuple + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T5, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T6, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T7, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T8 + >() + { + AlignStruct(); + return Tuple.Create(Read(), Read(), Read(), Read(), Read(), Read(), Read(), Read()); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericReadStruct)] + [Obsolete(Strings.UseNonGenericReadStructObsolete)] + public ValueTuple> ReadStruct + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T5, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T6, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T7, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T8, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T9 + >() + { + AlignStruct(); + return (Read(), Read(), Read(), Read(), Read(), Read(), Read(), Read(), Read()); + } + + private Tuple> ReadStructAsTuple + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T5, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T6, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T7, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T8, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T9 + >() + { + AlignStruct(); + return new Tuple>(Read(), Read(), Read(), Read(), Read(), Read(), Read(), Tuple.Create(Read(), Read())); + } + + [RequiresUnreferencedCode(Strings.UseNonGenericReadStruct)] + [Obsolete(Strings.UseNonGenericReadStructObsolete)] + public ValueTuple> ReadStruct + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T5, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T6, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T7, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T8, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T9, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T10 + >() + { + AlignStruct(); + return (Read(), Read(), Read(), Read(), Read(), Read(), Read(), Read(), Read(), Read()); + } + + private Tuple> ReadStructAsTuple + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T1, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T2, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T3, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T4, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T5, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T6, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T7, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T8, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T9, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T10 + >() + { + AlignStruct(); + return new Tuple>(Read(), Read(), Read(), Read(), Read(), Read(), Read(), Tuple.Create(Read(), Read(), Read())); + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/Reader.Variant.Dynamic.cs b/src/Linux/Tmds.DBus.Protocol/Reader.Variant.Dynamic.cs new file mode 100644 index 0000000000..c7be7ff044 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Reader.Variant.Dynamic.cs @@ -0,0 +1,8 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct Reader +{ + [RequiresUnreferencedCode(Strings.UseNonObjectReadVariantValue)] + [Obsolete(Strings.UseNonObjectReadVariantValueObsolete)] + public object ReadVariant() => Read(); +} diff --git a/src/Linux/Tmds.DBus.Protocol/Reader.Variant.cs b/src/Linux/Tmds.DBus.Protocol/Reader.Variant.cs new file mode 100644 index 0000000000..bb3e6ad151 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Reader.Variant.cs @@ -0,0 +1,161 @@ +namespace Tmds.DBus.Protocol; + +public ref partial struct Reader +{ + public VariantValue ReadVariantValue() + { + Utf8Span signature = ReadSignature(); + SignatureReader sigReader = new(signature); + if (!sigReader.TryRead(out DBusType type, out ReadOnlySpan innerSignature)) + { + ThrowInvalidSignature($"Invalid variant signature: {signature.ToString()}"); + } + return ReadTypeAsVariantValue(type, innerSignature); + } + + private VariantValue ReadTypeAsVariantValue(DBusType type, ReadOnlySpan innerSignature) + { + SignatureReader sigReader; + switch (type) + { + case DBusType.Byte: + return new VariantValue(ReadByte()); + case DBusType.Bool: + return new VariantValue(ReadBool()); + case DBusType.Int16: + return new VariantValue(ReadInt16()); + case DBusType.UInt16: + return new VariantValue(ReadUInt16()); + case DBusType.Int32: + return new VariantValue(ReadInt32()); + case DBusType.UInt32: + return new VariantValue(ReadUInt32()); + case DBusType.Int64: + return new VariantValue(ReadInt64()); + case DBusType.UInt64: + return new VariantValue(ReadUInt64()); + case DBusType.Double: + return new VariantValue(ReadDouble()); + case DBusType.String: + return new VariantValue(ReadString()); + case DBusType.ObjectPath: + return new VariantValue(ReadObjectPath()); + case DBusType.Signature: + return new VariantValue(ReadSignatureAsSignature()); + case DBusType.UnixFd: + int idx = (int)ReadUInt32(); + return new VariantValue(_handles, idx); + case DBusType.Variant: + return ReadVariantValue(); + case DBusType.Array: + sigReader = new(innerSignature); + if (!sigReader.TryRead(out type, out innerSignature)) + { + ThrowInvalidSignature("Signature is missing array item type."); + } + bool isDictionary = type == DBusType.DictEntry; + if (isDictionary) + { + sigReader = new(innerSignature); + DBusType valueType = default; + ReadOnlySpan valueInnerSignature = default; + if (!sigReader.TryRead(out DBusType keyType, out ReadOnlySpan keyInnerSignature) || + !sigReader.TryRead(out valueType, out valueInnerSignature)) + { + ThrowInvalidSignature("Signature is missing dict entry types."); + } + List> items = new(); + ArrayEnd arrayEnd = ReadArrayStart(type); + while (HasNext(arrayEnd)) + { + AlignStruct(); + VariantValue key = ReadTypeAsVariantValue(keyType, keyInnerSignature); + VariantValue value = ReadTypeAsVariantValue(valueType, valueInnerSignature); + items.Add(new KeyValuePair(key, value)); + } + return new VariantValue(ToVariantValueType(keyType), ToVariantValueType(valueType), items.ToArray()); + } + else + { + if (type == DBusType.Byte) + { + return new VariantValue(ReadArrayOfByte()); + } + else if (type == DBusType.Int16) + { + return new VariantValue(ReadArrayOfInt16()); + } + else if (type == DBusType.UInt16) + { + return new VariantValue(ReadArrayOfUInt16()); + } + else if (type == DBusType.Int32) + { + return new VariantValue(ReadArrayOfInt32()); + } + else if (type == DBusType.UInt32) + { + return new VariantValue(ReadArrayOfUInt32()); + } + else if (type == DBusType.Int64) + { + return new VariantValue(ReadArrayOfInt64()); + } + else if (type == DBusType.UInt64) + { + return new VariantValue(ReadArrayOfUInt64()); + } + else if (type == DBusType.Double) + { + return new VariantValue(ReadArrayOfDouble()); + } + else if (type == DBusType.String || + type == DBusType.ObjectPath) + { + return new VariantValue(ToVariantValueType(type), ReadArrayOfString()); + } + else + { + List items = new(); + ArrayEnd arrayEnd = ReadArrayStart(type); + while (HasNext(arrayEnd)) + { + VariantValue value = ReadTypeAsVariantValue(type, innerSignature); + items.Add(value); + } + return new VariantValue(ToVariantValueType(type), items.ToArray()); + } + } + case DBusType.Struct: + { + AlignStruct(); + sigReader = new(innerSignature); + List items = new(); + while (sigReader.TryRead(out type, out innerSignature)) + { + VariantValue value = ReadTypeAsVariantValue(type, innerSignature); + items.Add(value); + } + return new VariantValue(items.ToArray()); + } + case DBusType.DictEntry: // Already handled under DBusType.Array. + default: + // note: the SignatureReader maps all unknown types to DBusType.Invalid + // so we won't see the actual character that caused it to fail. + ThrowInvalidSignature($"Unexpected type in signature: {type}."); + return default; + } + } + + private void ThrowInvalidSignature(string message) + { + throw new ProtocolException(message); + } + + private static VariantValueType ToVariantValueType(DBusType type) + => type switch + { + DBusType.Variant => VariantValueType.VariantValue, + _ => (VariantValueType)type + }; +} diff --git a/src/Linux/Tmds.DBus.Protocol/Reader.cs b/src/Linux/Tmds.DBus.Protocol/Reader.cs new file mode 100644 index 0000000000..622f481b5a --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Reader.cs @@ -0,0 +1,90 @@ +[assembly: InternalsVisibleTo("Tmds.DBus.Protocol.Tests, PublicKey=002400000480000094000000060200000024000052534131000400000100010071a8770f460cce31df0feb6f94b328aebd55bffeb5c69504593df097fdd9b29586dbd155419031834411c8919516cc565dee6b813c033676218496edcbe7939c0dd1f919f3d1a228ebe83b05a3bbdbae53ce11bcf4c04a42d8df1a83c2d06cb4ebb0b447e3963f48a1ca968996f3f0db8ab0e840a89d0a5d5a237e2f09189ed3")] + +namespace Tmds.DBus.Protocol; + +public ref partial struct Reader +{ + private delegate object ValueReader(ref Reader reader); + + private readonly bool _isBigEndian; + private readonly UnixFdCollection? _handles; + private readonly int _handleCount; + private SequenceReader _reader; + + internal ReadOnlySequence UnreadSequence => _reader.Sequence.Slice(_reader.Position); + + internal void Advance(long count) => _reader.Advance(count); + + internal Reader(bool isBigEndian, ReadOnlySequence sequence) : this(isBigEndian, sequence, handles: null, 0) { } + + internal Reader(bool isBigEndian, ReadOnlySequence sequence, UnixFdCollection? handles, int handleCount) + { + _reader = new(sequence); + + _isBigEndian = isBigEndian; + _handles = handles; + _handleCount = handleCount; + } + + public void AlignStruct() => AlignReader(DBusType.Struct); + + private void AlignReader(DBusType type) + { + long pad = ProtocolConstants.GetPadding((int)_reader.Consumed, type); + if (pad != 0) + { + _reader.Advance(pad); + } + } + + private void AlignReader(int alignment) + { + long pad = ProtocolConstants.GetPadding((int)_reader.Consumed, alignment); + if (pad != 0) + { + _reader.Advance(pad); + } + } + + public ArrayEnd ReadArrayStart(DBusType elementType) + { + uint arrayLength = ReadUInt32(); + AlignReader(elementType); + int endOfArray = (int)(_reader.Consumed + arrayLength); + return new ArrayEnd(elementType, endOfArray); + } + + public bool HasNext(ArrayEnd iterator) + { + int consumed = (int)_reader.Consumed; + int nextElement = ProtocolConstants.Align(consumed, iterator.Type); + if (nextElement >= iterator.EndOfArray) + { + return false; + } + int advance = nextElement - consumed; + if (advance != 0) + { + _reader.Advance(advance); + } + return true; + } + + public void SkipTo(ArrayEnd end) + { + int advance = end.EndOfArray - (int)_reader.Consumed; + _reader.Advance(advance); + } +} + +public ref struct ArrayEnd +{ + internal readonly DBusType Type; + internal readonly int EndOfArray; + + internal ArrayEnd(DBusType type, int endOfArray) + { + Type = type; + EndOfArray = endOfArray; + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Signature.cs b/src/Linux/Tmds.DBus.Protocol/Signature.cs new file mode 100644 index 0000000000..0ca248671b --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Signature.cs @@ -0,0 +1,12 @@ +namespace Tmds.DBus.Protocol; + +public struct Signature +{ + private string _value; + + public Signature(string value) => _value = value; + + public override string ToString() => _value ?? ""; + + public Variant AsVariant() => new Variant(this); +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/SignatureReader.cs b/src/Linux/Tmds.DBus.Protocol/SignatureReader.cs new file mode 100644 index 0000000000..5300a277bf --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/SignatureReader.cs @@ -0,0 +1,271 @@ +[assembly: InternalsVisibleTo("dotnet-dbus, PublicKey=002400000480000094000000060200000024000052534131000400000100010071a8770f460cce31df0feb6f94b328aebd55bffeb5c69504593df097fdd9b29586dbd155419031834411c8919516cc565dee6b813c033676218496edcbe7939c0dd1f919f3d1a228ebe83b05a3bbdbae53ce11bcf4c04a42d8df1a83c2d06cb4ebb0b447e3963f48a1ca968996f3f0db8ab0e840a89d0a5d5a237e2f09189ed3")] + +namespace Tmds.DBus.Protocol; + +public ref struct SignatureReader +{ + private ReadOnlySpan _signature; + + public ReadOnlySpan Signature => _signature; + + public SignatureReader(ReadOnlySpan signature) + { + _signature = signature; + } + + public bool TryRead(out DBusType type, out ReadOnlySpan innerSignature) + { + innerSignature = default; + + if (_signature.IsEmpty) + { + type = DBusType.Invalid; + return false; + } + + type = ReadSingleType(_signature, out int length); + + if (length > 1) + { + switch (type) + { + case DBusType.Array: + innerSignature = _signature.Slice(1, length - 1); + break; + case DBusType.Struct: + case DBusType.DictEntry: + innerSignature = _signature.Slice(1, length - 2); + break; + } + } + + _signature = _signature.Slice(length); + + return true; + } + + private static DBusType ReadSingleType(ReadOnlySpan signature, out int length) + { + length = 0; + + if (signature.IsEmpty) + { + return DBusType.Invalid; + } + + DBusType type = (DBusType)signature[0]; + + if (IsBasicType(type)) + { + length = 1; + } + else if (type == DBusType.Variant) + { + length = 1; + } + else if (type == DBusType.Array) + { + if (ReadSingleType(signature.Slice(1), out int elementLength) != DBusType.Invalid) + { + type = DBusType.Array; + length = elementLength + 1; + } + else + { + type = DBusType.Invalid; + } + } + else if (type == DBusType.Struct) + { + length = DetermineLength(signature.Slice(1), (byte)'(', (byte)')'); + if (length == 0) + { + type = DBusType.Invalid; + } + } + else if (type == DBusType.DictEntry) + { + length = DetermineLength(signature.Slice(1), (byte)'{', (byte)'}'); + if (length < 4 || + !IsBasicType((DBusType)signature[1]) || + ReadSingleType(signature.Slice(2), out int valueTypeLength) == DBusType.Invalid || + length != valueTypeLength + 3) + { + type = DBusType.Invalid; + } + } + else + { + type = DBusType.Invalid; + } + + return type; + } + + private static int DetermineLength(ReadOnlySpan span, byte startChar, byte endChar) + { + int length = 1; + int count = 1; + do + { + int offset = span.IndexOfAny(startChar, endChar); + if (offset == -1) + { + return 0; + } + + if (span[offset] == startChar) + { + count++; + } + else + { + count--; + } + + length += offset + 1; + span = span.Slice(offset + 1); + + } while (count > 0); + + return length; + } + + private static bool IsBasicType(DBusType type) + { + return BasicTypes.IndexOf((byte)type) != -1; + } + + private static ReadOnlySpan BasicTypes => new byte[] { + (byte)DBusType.Byte, + (byte)DBusType.Bool, + (byte)DBusType.Int16, + (byte)DBusType.UInt16, + (byte)DBusType.Int32, + (byte)DBusType.UInt32, + (byte)DBusType.Int64, + (byte)DBusType.UInt64, + (byte)DBusType.Double, + (byte)DBusType.String, + (byte)DBusType.ObjectPath, + (byte)DBusType.Signature, + (byte)DBusType.UnixFd }; + + private static ReadOnlySpan ReadSingleType(ref ReadOnlySpan signature) + { + if (signature.Length == 0) + { + return default; + } + + int length; + DBusType type = (DBusType)signature[0]; + if (type == DBusType.Struct) + { + length = DetermineLength(signature.Slice(1), (byte)'(', (byte)')'); + } + else if (type == DBusType.DictEntry) + { + length = DetermineLength(signature.Slice(1), (byte)'{', (byte)'}'); + } + else if (type == DBusType.Array) + { + ReadOnlySpan remainder = signature.Slice(1); + length = 1 + ReadSingleType(ref remainder).Length; + } + else + { + length = 1; + } + + ReadOnlySpan rv = signature.Slice(0, length); + signature = signature.Slice(length); + return rv; + } + + internal static T Transform(ReadOnlySpan signature, Func map) + { + DBusType dbusType = signature.Length == 0 ? DBusType.Invalid : (DBusType)signature[0]; + + if (dbusType == DBusType.Array) + { + if ((DBusType)signature[1] == DBusType.DictEntry) + { + signature = signature.Slice(2); + ReadOnlySpan keySignature = ReadSingleType(ref signature); + ReadOnlySpan valueSignature = ReadSingleType(ref signature); + signature = signature.Slice(1); + T keyType = Transform(keySignature, map); + T valueType = Transform(valueSignature, map); + return map(DBusType.DictEntry, new[] { keyType, valueType }); + } + else + { + signature = signature.Slice(1); + T elementType = Transform(signature, map); + signature = signature.Slice(1); + return map(DBusType.Array, new[] { elementType }); + } + } + else if (dbusType == DBusType.Struct) + { + signature = signature.Slice(1, signature.Length - 2); + int typeCount = CountTypes(signature); + T[] innerTypes = new T[typeCount]; + for (int i = 0; i < innerTypes.Length; i++) + { + ReadOnlySpan innerTypeSignature = ReadSingleType(ref signature); + innerTypes[i] = Transform(innerTypeSignature, map); + } + return map(DBusType.Struct, innerTypes); + } + + return map(dbusType, Array.Empty()); + } + + // Counts the number of single types in a signature. + private static int CountTypes(ReadOnlySpan signature) + { + if (signature.Length == 0) + { + return 0; + } + + if (signature.Length == 1) + { + return 1; + } + + DBusType type = (DBusType)signature[0]; + signature = signature.Slice(1); + + if (type == DBusType.Struct) + { + ReadToEnd(ref signature, (byte)'(', (byte)')'); + } + else if (type == DBusType.DictEntry) + { + ReadToEnd(ref signature, (byte)'{', (byte)'}'); + } + + return (type == DBusType.Array ? 0 : 1) + CountTypes(signature); + + static void ReadToEnd(ref ReadOnlySpan span, byte startChar, byte endChar) + { + int count = 1; + do + { + int offset = span.IndexOfAny(startChar, endChar); + if (span[offset] == startChar) + { + count++; + } + else + { + count--; + } + span = span.Slice(offset + 1); + } while (count > 0); + } + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/SocketExtensions.cs b/src/Linux/Tmds.DBus.Protocol/SocketExtensions.cs new file mode 100644 index 0000000000..a786a5e1ac --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/SocketExtensions.cs @@ -0,0 +1,232 @@ +using System.Net.Sockets; + +namespace Tmds.DBus.Protocol; + +using SizeT = System.UIntPtr; +using SSizeT = System.IntPtr; + +static class SocketExtensions +{ + public static ValueTask ReceiveAsync(this Socket socket, Memory memory, UnixFdCollection? fdCollection) + { + if (fdCollection is null) + { + return socket.ReceiveAsync(memory, SocketFlags.None); + } + else + { + return socket.ReceiveWithHandlesAsync(memory, fdCollection); + } + } + + private async static ValueTask ReceiveWithHandlesAsync(this Socket socket, Memory memory, UnixFdCollection fdCollection) + { + while (true) + { + await socket.ReceiveAsync(new Memory(), SocketFlags.None).ConfigureAwait(false); + + int rv = recvmsg(socket, memory, fdCollection); + + if (rv >= 0) + { + return rv; + } + else + { + int errno = Marshal.GetLastWin32Error(); + if (errno == EAGAIN || errno == EINTR) + { + continue; + } + + throw new SocketException(errno); + } + } + } + + public static ValueTask SendAsync(this Socket socket, ReadOnlyMemory buffer, IReadOnlyList? handles) + { + if (handles is null || handles.Count == 0) + { + return SendAsync(socket, buffer); + } + else + { + return socket.SendAsyncWithHandlesAsync(buffer, handles); + } + } + + private static async ValueTask SendAsync(this Socket socket, ReadOnlyMemory buffer) + { + while (buffer.Length > 0) + { + int sent = await socket.SendAsync(buffer, SocketFlags.None).ConfigureAwait(false); + buffer = buffer.Slice(sent); + } + } + + private static ValueTask SendAsyncWithHandlesAsync(this Socket socket, ReadOnlyMemory buffer, IReadOnlyList handles) + { + socket.Blocking = false; + do + { + int rv = sendmsg(socket, buffer, handles); + if (rv > 0) + { + if (buffer.Length == rv) + { + return default; + } + return SendAsync(socket, buffer.Slice(rv)); + } + else + { + int errno = Marshal.GetLastWin32Error(); + if (errno == EINTR) + { + continue; + } + // TODO (low prio): handle EAGAIN. + return new ValueTask(Task.FromException(new SocketException(errno))); + } + } while (true); + } + + private static unsafe int sendmsg(Socket socket, ReadOnlyMemory buffer, IReadOnlyList handles) + { + fixed (byte* ptr = buffer.Span) + { + IOVector* iovs = stackalloc IOVector[1]; + iovs[0].Base = ptr; + iovs[0].Length = (SizeT)buffer.Length; + + Msghdr msg = new Msghdr(); + msg.msg_iov = iovs; + msg.msg_iovlen = (SizeT)1; + + var fdm = new cmsg_fd(); + int size = sizeof(Cmsghdr) + 4 * handles.Count; + msg.msg_control = &fdm; + msg.msg_controllen = (SizeT)size; + fdm.hdr.cmsg_len = (SizeT)size; + fdm.hdr.cmsg_level = SOL_SOCKET; + fdm.hdr.cmsg_type = SCM_RIGHTS; + + SafeHandle handle = socket.GetSafeHandle(); + int handleRefsAdded = 0; + bool refAdded = false; + try + { + handle.DangerousAddRef(ref refAdded); + for (int i = 0, j = 0; i < handles.Count; i++) + { + bool added = false; + SafeHandle h = handles[i]; + h.DangerousAddRef(ref added); + handleRefsAdded++; + fdm.fds[j++] = h.DangerousGetHandle().ToInt32(); + } + + return (int)sendmsg(handle.DangerousGetHandle().ToInt32(), new IntPtr(&msg), 0); + } + finally + { + for (int i = 0; i < handleRefsAdded; i++) + { + SafeHandle h = handles[i]; + h.DangerousRelease(); + } + + if (refAdded) + handle.DangerousRelease(); + } + } + } + + private static unsafe int recvmsg(Socket socket, Memory buffer, UnixFdCollection handles) + { + fixed (byte* buf = buffer.Span) + { + IOVector iov = new IOVector(); + iov.Base = buf; + iov.Length = (SizeT)buffer.Length; + + Msghdr msg = new Msghdr(); + msg.msg_iov = &iov; + msg.msg_iovlen = (SizeT)1; + + cmsg_fd cm = new cmsg_fd(); + msg.msg_control = &cm; + msg.msg_controllen = (SizeT)sizeof(cmsg_fd); + + var handle = socket.GetSafeHandle(); + bool refAdded = false; + try + { + handle.DangerousAddRef(ref refAdded); + + int rv = (int)recvmsg(handle.DangerousGetHandle().ToInt32(), new IntPtr(&msg), 0); + + if (rv >= 0) + { + if (cm.hdr.cmsg_level == SOL_SOCKET && cm.hdr.cmsg_type == SCM_RIGHTS) + { + int msgFdCount = ((int)cm.hdr.cmsg_len - sizeof(Cmsghdr)) / sizeof(int); + for (int i = 0; i < msgFdCount; i++) + { + handles.AddHandle(new IntPtr(cm.fds[i])); + } + } + } + return rv; + } + finally + { + if (refAdded) + handle.DangerousRelease(); + } + } + } + + const int SOL_SOCKET = 1; + const int EINTR = 4; + //const int EBADF = 9; + static readonly int EAGAIN = RuntimeInformation.IsOSPlatform(OSPlatform.OSX) ? 35 : 11; + const int SCM_RIGHTS = 1; + + private unsafe struct Msghdr + { + public IntPtr msg_name; //optional address + public uint msg_namelen; //size of address + public IOVector* msg_iov; //scatter/gather array + public SizeT msg_iovlen; //# elements in msg_iov + public void* msg_control; //ancillary data, see below + public SizeT msg_controllen; //ancillary data buffer len + public int msg_flags; //flags on received message + } + + private unsafe struct IOVector + { + public void* Base; + public SizeT Length; + } + + private struct Cmsghdr + { + public SizeT cmsg_len; //data byte count, including header + public int cmsg_level; //originating protocol + public int cmsg_type; //protocol-specific type + } + + private unsafe struct cmsg_fd + { + public Cmsghdr hdr; + public fixed int fds[64]; + } + + [DllImport("libc", SetLastError = true)] + public static extern SSizeT sendmsg(int sockfd, IntPtr msg, int flags); + + [DllImport("libc", SetLastError = true)] + public static extern SSizeT recvmsg(int sockfd, IntPtr msg, int flags); +} diff --git a/src/Linux/Tmds.DBus.Protocol/StringBuilderExtensions.cs b/src/Linux/Tmds.DBus.Protocol/StringBuilderExtensions.cs new file mode 100644 index 0000000000..e7a20e3fca --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/StringBuilderExtensions.cs @@ -0,0 +1,24 @@ +namespace Tmds.DBus.Protocol; + +static class StringBuilderExtensions +{ + public static void AppendUTF8(this StringBuilder sb, ReadOnlySpan value) + { + char[]? valueArray = null; + + int length = Encoding.UTF8.GetCharCount(value); + + Span charBuffer = length <= Constants.StackAllocCharThreshold ? + stackalloc char[length] : + (valueArray = ArrayPool.Shared.Rent(length)); + + int charsWritten = Encoding.UTF8.GetChars(value, charBuffer); + + sb.Append(charBuffer.Slice(0, charsWritten)); + + if (valueArray is not null) + { + ArrayPool.Shared.Return(valueArray); + } + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Strings.cs b/src/Linux/Tmds.DBus.Protocol/Strings.cs new file mode 100644 index 0000000000..94a1d5194d --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Strings.cs @@ -0,0 +1,29 @@ +namespace Tmds.DBus.Protocol; + +static class Strings +{ + public const string AddTypeReaderMethodObsolete = "AddTypeReader methods are obsolete. Remove the call to this method."; + public const string AddTypeWriterMethodObsolete = "AddTypeWriter methods are obsolete. Remove the call to this method."; + + public const string UseNonGenericWriteArray = $"Use a non-generic overload of '{nameof(MessageWriter.WriteArray)}' if it exists for the item type, and otherwise write out the elements separately surrounded by a call to '{nameof(MessageWriter.WriteArrayStart)}' and '{nameof(MessageWriter.WriteArrayEnd)}'."; + public const string UseNonGenericReadArray = $"Use a '{nameof(Reader.ReadArray)}Of*' method if it exists for the item type, and otherwise read out the elements in a while loop using '{nameof(Reader.ReadArrayStart)}' and '{nameof(Reader.HasNext)}'."; + public const string UseNonGenericReadDictionary = $"Read the dictionary by calling '{nameof(Reader.ReadDictionaryStart)} and reading the key-value pairs in a while loop using '{nameof(Reader.HasNext)}'."; + public const string UseNonGenericWriteDictionary = $"Write the dictionary by calling '{nameof(MessageWriter.WriteDictionaryStart)}', for each element call '{nameof(MessageWriter.WriteDictionaryEntryStart)}', write the key and value. Complete the dictionary writing by calling '{nameof(MessageWriter.WriteDictionaryEnd)}'."; + public const string UseNonGenericWriteVariantDictionary = $"Write the signature using '{nameof(MessageWriter.WriteSignature)}', then write the dictionary by calling '{nameof(MessageWriter.WriteDictionaryStart)}', for each element call '{nameof(MessageWriter.WriteDictionaryEntryStart)}', write the key and value. Complete the dictionary writing by calling '{nameof(MessageWriter.WriteDictionaryEnd)}'."; + public const string UseNonGenericReadStruct = $"Read the struct by calling '{nameof(Reader.AlignStruct)}' and then reading all the struct fields."; + public const string UseNonGenericWriteStruct = $"Write the struct by calling '{nameof(MessageWriter.WriteStructureStart)}' and then writing all the struct fields."; + public const string UseNonObjectWriteVariant = $"Use the overload of '{nameof(MessageWriter.WriteVariant)}' that accepts a '{nameof(Variant)}' instead."; + public const string UseNonObjectReadVariantValue = $"Use '{nameof(Reader.ReadVariantValue)}' instead."; + + private const string MethodIsNotCompatibleWithTrimmingNativeAot = "Method is not compatible with trimming/NativeAOT."; + + public const string UseNonGenericWriteArrayObsolete = $"{MethodIsNotCompatibleWithTrimmingNativeAot} {UseNonGenericWriteArray}"; + public const string UseNonGenericReadArrayObsolete = $"{MethodIsNotCompatibleWithTrimmingNativeAot} {UseNonGenericReadArray}"; + public const string UseNonGenericReadDictionaryObsolete = $"{MethodIsNotCompatibleWithTrimmingNativeAot} {UseNonGenericReadDictionary}"; + public const string UseNonGenericWriteDictionaryObsolete = $"{MethodIsNotCompatibleWithTrimmingNativeAot} {UseNonGenericWriteDictionary}"; + public const string UseNonGenericWriteVariantDictionaryObsolete = $"{MethodIsNotCompatibleWithTrimmingNativeAot} {UseNonGenericWriteVariantDictionary}"; + public const string UseNonGenericReadStructObsolete = $"{MethodIsNotCompatibleWithTrimmingNativeAot} {UseNonGenericReadStruct}"; + public const string UseNonGenericWriteStructObsolete = $"{MethodIsNotCompatibleWithTrimmingNativeAot} {UseNonGenericWriteStruct}"; + public const string UseNonObjectWriteVariantObsolete = $"{MethodIsNotCompatibleWithTrimmingNativeAot} {UseNonObjectWriteVariant}"; + public const string UseNonObjectReadVariantValueObsolete = $"{MethodIsNotCompatibleWithTrimmingNativeAot} {UseNonObjectReadVariantValue}"; +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Struct.cs b/src/Linux/Tmds.DBus.Protocol/Struct.cs new file mode 100644 index 0000000000..6c23afc8e1 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Struct.cs @@ -0,0 +1,469 @@ +namespace Tmds.DBus.Protocol; + +// Using obsolete generic write members +#pragma warning disable CS0618 + +public static class Struct +{ + public static Struct Create(T1 item1) + where T1 : notnull + => new Struct(item1); + + public static Struct Create(T1 item1, T2 item2) + where T1 : notnull + where T2 : notnull + => new Struct(item1, item2); + + public static Struct Create(T1 item1, T2 item2, T3 item3) + where T1 : notnull + where T2 : notnull + where T3 : notnull + => new Struct(item1, item2, item3); + + public static Struct Create(T1 item1, T2 item2, T3 item3, T4 item4) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + => new Struct(item1, item2, item3, item4); + + public static Struct Create(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + => new Struct(item1, item2, item3, item4, item5); + + public static Struct Create(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, T6 item6) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + => new Struct(item1, item2, item3, item4, item5, item6); + + public static Struct Create(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, T6 item6, T7 item7) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + => new Struct(item1, item2, item3, item4, item5, item6, item7); + + public static Struct Create(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, T6 item6, T7 item7, T8 item8) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + => new Struct(item1, item2, item3, item4, item5, item6, item7, item8); + + public static Struct Create(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, T6 item6, T7 item7, T8 item8, T9 item9) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + => new Struct(item1, item2, item3, item4, item5, item6, item7, item8, item9); + + public static Struct Create(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, T6 item6, T7 item7, T8 item8, T9 item9, T10 item10) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + where T10 : notnull + => new Struct(item1, item2, item3, item4, item5, item6, item7, item8, item9, item10); +} + +public sealed class Struct : IDBusWritable + where T1 : notnull +{ + public T1 Item1; + + public Struct(T1 item1) + { + TypeModel.EnsureSupportedVariantType(); + Item1 = item1; + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // this is a supported variant type. + void IDBusWritable.WriteTo(ref MessageWriter writer) + => writer.WriteStruct(ToValueTuple()); + + private ValueTuple ToValueTuple() + => new ValueTuple(Item1); + + public Variant AsVariant() + => Variant.FromStruct(this); + + public static implicit operator Variant(Struct value) + => value.AsVariant(); +} + +public sealed class Struct : IDBusWritable + where T1 : notnull + where T2 : notnull +{ + public T1 Item1; + public T2 Item2; + + public Struct(T1 item1, T2 item2) + { + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + (Item1, Item2) = (item1, item2); + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // this is a supported variant type. + void IDBusWritable.WriteTo(ref MessageWriter writer) + => writer.WriteStruct(ToValueTuple()); + + private (T1, T2) ToValueTuple() + => (Item1, Item2); + + public Variant AsVariant() + => Variant.FromStruct(this); + + public static implicit operator Variant(Struct value) + => value.AsVariant(); +} +public sealed class Struct : IDBusWritable + where T1 : notnull + where T2 : notnull + where T3 : notnull +{ + public T1 Item1; + public T2 Item2; + public T3 Item3; + + public Struct(T1 item1, T2 item2, T3 item3) + { + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + (Item1, Item2, Item3) = (item1, item2, item3); + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // this is a supported variant type. + void IDBusWritable.WriteTo(ref MessageWriter writer) + => writer.WriteStruct(ToValueTuple()); + + private (T1, T2, T3) ToValueTuple() + => (Item1, Item2, Item3); + + public Variant AsVariant() + => Variant.FromStruct(this); + + public static implicit operator Variant(Struct value) + => value.AsVariant(); +} +public sealed class Struct : IDBusWritable + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull +{ + public T1 Item1; + public T2 Item2; + public T3 Item3; + public T4 Item4; + + public Struct(T1 item1, T2 item2, T3 item3, T4 item4) + { + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + (Item1, Item2, Item3, Item4) = (item1, item2, item3, item4); + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // this is a supported variant type. + void IDBusWritable.WriteTo(ref MessageWriter writer) + => writer.WriteStruct(ToValueTuple()); + + private (T1, T2, T3, T4) ToValueTuple() + => (Item1, Item2, Item3, Item4); + + public Variant AsVariant() + => Variant.FromStruct(this); + + public static implicit operator Variant(Struct value) + => value.AsVariant(); +} +public sealed class Struct : IDBusWritable + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull +{ + public T1 Item1; + public T2 Item2; + public T3 Item3; + public T4 Item4; + public T5 Item5; + + public Struct(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5) + { + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + (Item1, Item2, Item3, Item4, Item5) = (item1, item2, item3, item4, item5); + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // this is a supported variant type. + void IDBusWritable.WriteTo(ref MessageWriter writer) + => writer.WriteStruct(ToValueTuple()); + + private (T1, T2, T3, T4, T5) ToValueTuple() + => (Item1, Item2, Item3, Item4, Item5); + + public Variant AsVariant() + => Variant.FromStruct(this); + + public static implicit operator Variant(Struct value) + => value.AsVariant(); +} +public sealed class Struct : IDBusWritable + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull +{ + public T1 Item1; + public T2 Item2; + public T3 Item3; + public T4 Item4; + public T5 Item5; + public T6 Item6; + + public Struct(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, T6 item6) + { + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + (Item1, Item2, Item3, Item4, Item5, Item6) = (item1, item2, item3, item4, item5, item6); + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // this is a supported variant type. + void IDBusWritable.WriteTo(ref MessageWriter writer) + => writer.WriteStruct(ToValueTuple()); + + private (T1, T2, T3, T4, T5, T6) ToValueTuple() + => (Item1, Item2, Item3, Item4, Item5, Item6); + + public Variant AsVariant() + => Variant.FromStruct(this); + + public static implicit operator Variant(Struct value) + => value.AsVariant(); +} +public sealed class Struct : IDBusWritable + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull +{ + public T1 Item1; + public T2 Item2; + public T3 Item3; + public T4 Item4; + public T5 Item5; + public T6 Item6; + public T7 Item7; + + public Struct(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, T6 item6, T7 item7) + { + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + (Item1, Item2, Item3, Item4, Item5, Item6, Item7) = (item1, item2, item3, item4, item5, item6, item7); + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // this is a supported variant type. + void IDBusWritable.WriteTo(ref MessageWriter writer) + => writer.WriteStruct(ToValueTuple()); + + private (T1, T2, T3, T4, T5, T6, T7) ToValueTuple() + => (Item1, Item2, Item3, Item4, Item5, Item6, Item7); + + public Variant AsVariant() + => Variant.FromStruct(this); + + public static implicit operator Variant(Struct value) + => value.AsVariant(); +} +public sealed class Struct : IDBusWritable + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull +{ + public T1 Item1; + public T2 Item2; + public T3 Item3; + public T4 Item4; + public T5 Item5; + public T6 Item6; + public T7 Item7; + public T8 Item8; + + public Struct(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, T6 item6, T7 item7, T8 item8) + { + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + (Item1, Item2, Item3, Item4, Item5, Item6, Item7, Item8) = (item1, item2, item3, item4, item5, item6, item7, item8); + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // this is a supported variant type. + void IDBusWritable.WriteTo(ref MessageWriter writer) + => writer.WriteStruct(ToValueTuple()); + + private (T1, T2, T3, T4, T5, T6, T7, T8) ToValueTuple() + => (Item1, Item2, Item3, Item4, Item5, Item6, Item7, Item8); + + public Variant AsVariant() + => Variant.FromStruct(this); + + public static implicit operator Variant(Struct value) + => value.AsVariant(); +} +public sealed class Struct : IDBusWritable + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull +{ + public T1 Item1; + public T2 Item2; + public T3 Item3; + public T4 Item4; + public T5 Item5; + public T6 Item6; + public T7 Item7; + public T8 Item8; + public T9 Item9; + + public Struct(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, T6 item6, T7 item7, T8 item8, T9 item9) + { + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + (Item1, Item2, Item3, Item4, Item5, Item6, Item7, Item8, Item9) = (item1, item2, item3, item4, item5, item6, item7, item8, item9); + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // this is a supported variant type. + void IDBusWritable.WriteTo(ref MessageWriter writer) + => writer.WriteStruct(ToValueTuple()); + + private (T1, T2, T3, T4, T5, T6, T7, T8, T9) ToValueTuple() + => (Item1, Item2, Item3, Item4, Item5, Item6, Item7, Item8, Item9); + + public Variant AsVariant() + => Variant.FromStruct(this); + + public static implicit operator Variant(Struct value) + => value.AsVariant(); +} +public sealed class Struct : IDBusWritable + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + where T10 : notnull +{ + public T1 Item1; + public T2 Item2; + public T3 Item3; + public T4 Item4; + public T5 Item5; + public T6 Item6; + public T7 Item7; + public T8 Item8; + public T9 Item9; + public T10 Item10; + + public Struct(T1 item1, T2 item2, T3 item3, T4 item4, T5 item5, T6 item6, T7 item7, T8 item8, T9 item9, T10 item10) + { + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + TypeModel.EnsureSupportedVariantType(); + (Item1, Item2, Item3, Item4, Item5, Item6, Item7, Item8, Item9, Item10) = (item1, item2, item3, item4, item5, item6, item7, item8, item9, item10); + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026")] // this is a supported variant type. + void IDBusWritable.WriteTo(ref MessageWriter writer) + => writer.WriteStruct(ToValueTuple()); + + private (T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) ToValueTuple() + => (Item1, Item2, Item3, Item4, Item5, Item6, Item7, Item8, Item9, Item10); + + public Variant AsVariant() + => Variant.FromStruct(this); + + public static implicit operator Variant(Struct value) + => value.AsVariant(); +} diff --git a/src/Linux/Tmds.DBus.Protocol/ThrowHelper.cs b/src/Linux/Tmds.DBus.Protocol/ThrowHelper.cs new file mode 100644 index 0000000000..f3373f4b56 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/ThrowHelper.cs @@ -0,0 +1,32 @@ +namespace Tmds.DBus.Protocol; + +static class ThrowHelper +{ + public static void ThrowIfDisposed(bool condition, object instance) + { + if (condition) + { + ThrowObjectDisposedException(instance); + } + } + + private static void ThrowObjectDisposedException(object instance) + { + throw new ObjectDisposedException(instance?.GetType().FullName); + } + + public static void ThrowIndexOutOfRange() + { + throw new IndexOutOfRangeException(); + } + + public static void ThrowNotSupportedException() + { + throw new NotSupportedException(); + } + + internal static void ThrowUnexpectedSignature(ReadOnlySpan signature, string expected) + { + throw new ProtocolException($"Expected signature '{expected}' does not match actual signature '{Encoding.UTF8.GetString(signature)}'."); + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Tmds.DBus.Protocol.csproj b/src/Linux/Tmds.DBus.Protocol/Tmds.DBus.Protocol.csproj new file mode 100644 index 0000000000..5bd486a39d --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Tmds.DBus.Protocol.csproj @@ -0,0 +1,22 @@ + + + + $(AvsCurrentTargetFramework);$(AvsLegacyTargetFrameworks);netstandard2.0 + enable + true + + + + + + + + + + + + + + + + diff --git a/src/Linux/Tmds.DBus.Protocol/TypeModel.Dynamic.cs b/src/Linux/Tmds.DBus.Protocol/TypeModel.Dynamic.cs new file mode 100644 index 0000000000..5aeb2e0941 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/TypeModel.Dynamic.cs @@ -0,0 +1,134 @@ +namespace Tmds.DBus.Protocol; + +// Code in this file is not trimmer friendly. +#pragma warning disable IL3050 +#pragma warning disable IL2070 + +static partial class TypeModel +{ + private static DBusType GetTypeAlignmentDynamic() + { + if (typeof(T).IsArray) + { + return DBusType.Array; + } + else if (ExtractGenericInterface(typeof(T), typeof(System.Collections.Generic.IEnumerable<>)) != null) + { + return DBusType.Array; + } + else + { + return DBusType.Struct; + } + } + + private static int AppendTypeSignatureDynamic(Type type, Span signature) + { + Type? extractedType; + if (type == typeof(object)) + { + signature[0] = (byte)DBusType.Variant; + return 1; + } + else if (type.IsArray) + { + int bytesWritten = 0; + signature[bytesWritten++] = (byte)DBusType.Array; + bytesWritten += AppendTypeSignature(type.GetElementType()!, signature.Slice(bytesWritten)); + return bytesWritten; + } + else if (type.FullName!.StartsWith("System.ValueTuple")) + { + int bytesWritten = 0; + signature[bytesWritten++] = (byte)'('; + Type[] typeArguments = type.GenericTypeArguments; + do + { + for (int i = 0; i < typeArguments.Length; i++) + { + if (i == 7) + { + break; + } + bytesWritten += AppendTypeSignature(typeArguments[i], signature.Slice(bytesWritten)); + } + if (typeArguments.Length == 8) + { + typeArguments = typeArguments[7].GenericTypeArguments; + } + else + { + break; + } + } while (true); + signature[bytesWritten++] = (byte)')'; + return bytesWritten; + } + else if ((extractedType = TypeModel.ExtractGenericInterface(type, typeof(IDictionary<,>))) != null) + { + int bytesWritten = 0; + signature[bytesWritten++] = (byte)'a'; + signature[bytesWritten++] = (byte)'{'; + bytesWritten += AppendTypeSignature(extractedType.GenericTypeArguments[0], signature.Slice(bytesWritten)); + bytesWritten += AppendTypeSignature(extractedType.GenericTypeArguments[1], signature.Slice(bytesWritten)); + signature[bytesWritten++] = (byte)'}'; + return bytesWritten; + } + + ThrowNotSupportedType(type); + return 0; + } + + public static Type? ExtractGenericInterface(Type queryType, Type interfaceType) + { + if (IsGenericInstantiation(queryType, interfaceType)) + { + return queryType; + } + + return GetGenericInstantiation(queryType, interfaceType); + } + + private static bool IsGenericInstantiation(Type candidate, Type interfaceType) + { + return + candidate.IsGenericType && + candidate.GetGenericTypeDefinition() == interfaceType; + } + + [UnconditionalSuppressMessage("ReflectionAnalysis", "IL2070")] + private static Type? GetGenericInstantiation(Type queryType, Type interfaceType) + { + Type? bestMatch = null; + var interfaces = queryType.GetInterfaces(); + foreach (var @interface in interfaces) + { + if (IsGenericInstantiation(@interface, interfaceType)) + { + if (bestMatch == null) + { + bestMatch = @interface; + } + else if (StringComparer.Ordinal.Compare(@interface.FullName, bestMatch.FullName) < 0) + { + bestMatch = @interface; + } + } + } + + if (bestMatch != null) + { + return bestMatch; + } + + var baseType = queryType?.BaseType; + if (baseType == null) + { + return null; + } + else + { + return GetGenericInstantiation(baseType, interfaceType); + } + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/TypeModel.cs b/src/Linux/Tmds.DBus.Protocol/TypeModel.cs new file mode 100644 index 0000000000..6b842fb3db --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/TypeModel.cs @@ -0,0 +1,359 @@ +namespace Tmds.DBus.Protocol; + +static partial class TypeModel +{ + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static DBusType GetTypeAlignment() + { + // TODO (perf): add caching. + if (typeof(T) == typeof(byte)) + { + return DBusType.Byte; + } + else if (typeof(T) == typeof(bool)) + { + return DBusType.Bool; + } + else if (typeof(T) == typeof(short)) + { + return DBusType.Int16; + } + else if (typeof(T) == typeof(ushort)) + { + return DBusType.UInt16; + } + else if (typeof(T) == typeof(int)) + { + return DBusType.Int32; + } + else if (typeof(T) == typeof(uint)) + { + return DBusType.UInt32; + } + else if (typeof(T) == typeof(long)) + { + return DBusType.Int64; + } + else if (typeof(T) == typeof(ulong)) + { + return DBusType.UInt64; + } + else if (typeof(T) == typeof(double)) + { + return DBusType.Double; + } + else if (typeof(T) == typeof(string)) + { + return DBusType.String; + } + else if (typeof(T) == typeof(ObjectPath)) + { + return DBusType.ObjectPath; + } + else if (typeof(T) == typeof(Signature)) + { + return DBusType.Signature; + } + else if (typeof(T) == typeof(Variant)) + { + return DBusType.Variant; + } + else if (typeof(T).IsConstructedGenericType) + { + Type type = typeof(T).GetGenericTypeDefinition(); + if (type == typeof(Dict<,>)) + { + return DBusType.Array; + } + else if (type == typeof(Array<>)) + { + return DBusType.Array; + } + else if (type == typeof(Struct<>) || + type == typeof(Struct<,>) || + type == typeof(Struct<,,>) || + type == typeof(Struct<,,,>) || + type == typeof(Struct<,,,,>) || + type == typeof(Struct<,,,,,>) || + type == typeof(Struct<,,,,,,>) || + type == typeof(Struct<,,,,,,,>) || + type == typeof(Struct<,,,,,,,,>) || + type == typeof(Struct<,,,,,,,,,>)) + { + return DBusType.Struct; + } + } + else if (typeof(T).IsAssignableTo(typeof(SafeHandle))) + { + return DBusType.UnixFd; + } + else if (Feature.IsDynamicCodeEnabled) + { + return GetTypeAlignmentDynamic(); + } + + ThrowNotSupportedType(typeof(T)); + return default; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void EnsureSupportedVariantType() + { + if (typeof(T) == typeof(byte)) + { } + else if (typeof(T) == typeof(bool)) + { } + else if (typeof(T) == typeof(short)) + { } + else if (typeof(T) == typeof(ushort)) + { } + else if (typeof(T) == typeof(int)) + { } + else if (typeof(T) == typeof(uint)) + { } + else if (typeof(T) == typeof(long)) + { } + else if (typeof(T) == typeof(ulong)) + { } + else if (typeof(T) == typeof(double)) + { } + else if (typeof(T) == typeof(string)) + { } + else if (typeof(T) == typeof(ObjectPath)) + { } + else if (typeof(T) == typeof(Signature)) + { } + else if (typeof(T) == typeof(Variant)) + { } + else if (typeof(T).IsConstructedGenericType) + { + Type type = typeof(T).GetGenericTypeDefinition(); + if (type == typeof(Dict<,>) || + type == typeof(Array<>) || + type == typeof(Struct<>) || + type == typeof(Struct<,>) || + type == typeof(Struct<,,>) || + type == typeof(Struct<,,,>) || + type == typeof(Struct<,,,,>) || + type == typeof(Struct<,,,,,>) || + type == typeof(Struct<,,,,,,>) || + type == typeof(Struct<,,,,,,,>) || + type == typeof(Struct<,,,,,,,,>) || + type == typeof(Struct<,,,,,,,,,>)) + { + foreach (var innerType in type.GenericTypeArguments) + { + EnsureSupportedVariantType(innerType); + } + } + else + { + ThrowNotSupportedType(typeof(T)); + } + } + else if (typeof(T).IsAssignableTo(typeof(SafeHandle))) + { } + else + { + ThrowNotSupportedType(typeof(T)); + } + } + + private static void EnsureSupportedVariantType(Type type) + { + if (type == typeof(byte)) + { } + else if (type == typeof(bool)) + { } + else if (type == typeof(short)) + { } + else if (type == typeof(ushort)) + { } + else if (type == typeof(int)) + { } + else if (type == typeof(uint)) + { } + else if (type == typeof(long)) + { } + else if (type == typeof(ulong)) + { } + else if (type == typeof(double)) + { } + else if (type == typeof(string)) + { } + else if (type == typeof(ObjectPath)) + { } + else if (type == typeof(Signature)) + { } + else if (type == typeof(Variant)) + { } + else if (type.IsConstructedGenericType) + { + Type typeDefinition = type.GetGenericTypeDefinition(); + if (typeDefinition == typeof(Dict<,>) || + typeDefinition == typeof(Array<>) || + typeDefinition == typeof(Struct<>) || + typeDefinition == typeof(Struct<,>) || + typeDefinition == typeof(Struct<,,>) || + typeDefinition == typeof(Struct<,,,>) || + typeDefinition == typeof(Struct<,,,,>) || + typeDefinition == typeof(Struct<,,,,,>) || + typeDefinition == typeof(Struct<,,,,,,>) || + typeDefinition == typeof(Struct<,,,,,,,>) || + typeDefinition == typeof(Struct<,,,,,,,,>) || + typeDefinition == typeof(Struct<,,,,,,,,,>)) + { + foreach (var innerType in typeDefinition.GenericTypeArguments) + { + EnsureSupportedVariantType(innerType); + } + } + else + { + ThrowNotSupportedType(type); + } + } + else if (type.IsAssignableTo(typeof(SafeHandle))) + { } + else + { + ThrowNotSupportedType(type); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Utf8Span GetSignature(scoped Span buffer) + { + Debug.Assert(buffer.Length >= ProtocolConstants.MaxSignatureLength); + + int bytesWritten = AppendTypeSignature(typeof(T), buffer); + return new Utf8Span(buffer.Slice(0, bytesWritten).ToArray()); + } + + private static int AppendTypeSignature(Type type, Span signature) + { + if (type == typeof(byte)) + { + signature[0] = (byte)DBusType.Byte; + return 1; + } + else if (type == typeof(bool)) + { + signature[0] = (byte)DBusType.Bool; + return 1; + } + else if (type == typeof(short)) + { + signature[0] = (byte)DBusType.Int16; + return 1; + } + else if (type == typeof(ushort)) + { + signature[0] = (byte)DBusType.UInt16; + return 1; + } + else if (type == typeof(int)) + { + signature[0] = (byte)DBusType.Int32; + return 1; + } + else if (type == typeof(uint)) + { + signature[0] = (byte)DBusType.UInt32; + return 1; + } + else if (type == typeof(long)) + { + signature[0] = (byte)DBusType.Int64; + return 1; + } + else if (type == typeof(ulong)) + { + signature[0] = (byte)DBusType.UInt64; + return 1; + } + else if (type == typeof(double)) + { + signature[0] = (byte)DBusType.Double; + return 1; + } + else if (type == typeof(string)) + { + signature[0] = (byte)DBusType.String; + return 1; + } + else if (type == typeof(ObjectPath)) + { + signature[0] = (byte)DBusType.ObjectPath; + return 1; + } + else if (type == typeof(Signature)) + { + signature[0] = (byte)DBusType.Signature; + return 1; + } + else if (type == typeof(Variant)) + { + signature[0] = (byte)DBusType.Variant; + return 1; + } + else if (type.IsConstructedGenericType) + { + Type genericTypeDefinition = type.GetGenericTypeDefinition(); + if (genericTypeDefinition == typeof(Dict<,>)) + { + int length = 0; + signature[length++] = (byte)'a'; + signature[length++] = (byte)'{'; + length += AppendTypeSignature(type.GenericTypeArguments[0], signature.Slice(length)); + length += AppendTypeSignature(type.GenericTypeArguments[1], signature.Slice(length)); + signature[length++] = (byte)'}'; + return length; + } + else if (genericTypeDefinition == typeof(Array<>)) + { + int length = 0; + signature[length++] = (byte)'a'; + length += AppendTypeSignature(type.GenericTypeArguments[0], signature.Slice(length)); + return length; + } + else if (genericTypeDefinition == typeof(Struct<>) || + genericTypeDefinition == typeof(Struct<,>) || + genericTypeDefinition == typeof(Struct<,,>) || + genericTypeDefinition == typeof(Struct<,,,>) || + genericTypeDefinition == typeof(Struct<,,,,>) || + genericTypeDefinition == typeof(Struct<,,,,,>) || + genericTypeDefinition == typeof(Struct<,,,,,,>) || + genericTypeDefinition == typeof(Struct<,,,,,,,>) || + genericTypeDefinition == typeof(Struct<,,,,,,,,>) || + genericTypeDefinition == typeof(Struct<,,,,,,,,,>)) + { + int length = 0; + signature[length++] = (byte)'('; + foreach (var innerType in type.GenericTypeArguments) + { + length += AppendTypeSignature(innerType, signature.Slice(length)); + } + signature[length++] = (byte)')'; + return length; + } + } + else if (type.IsAssignableTo(typeof(SafeHandle))) + { + signature[0] = (byte)DBusType.UnixFd; + return 1; + } + else if (Feature.IsDynamicCodeEnabled) + { + return AppendTypeSignatureDynamic(type, signature); + } + + ThrowNotSupportedType(type); + return 0; + } + + private static void ThrowNotSupportedType(Type type) + { + throw new NotSupportedException($"Unsupported type {type.FullName}"); + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/UnixFdCollection.cs b/src/Linux/Tmds.DBus.Protocol/UnixFdCollection.cs new file mode 100644 index 0000000000..a6ea755ba0 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/UnixFdCollection.cs @@ -0,0 +1,244 @@ +using System.Collections; + +namespace Tmds.DBus.Protocol; + +sealed class UnixFdCollection : IReadOnlyList, IDisposable +{ + private IntPtr InvalidRawHandle => new IntPtr(-1); + + private readonly List<(SafeHandle? Handle, bool CanRead)>? _handles; + private readonly List<(IntPtr RawHandle, bool CanRead)>? _rawHandles; + + // The gate guards someone removing handles while the UnixFdCollection gets disposed by the message. + // We don't need to lock it while adding handles, or reading them to send them. + private readonly object _gate; + private bool _disposed; + + internal bool IsRawHandleCollection => _rawHandles is not null; + + internal UnixFdCollection(bool isRawHandleCollection = true) + { + if (isRawHandleCollection) + { + _gate = _rawHandles = new(); + } + else + { + _gate = _handles = new(); + } + } + + internal int AddHandle(IntPtr handle) + { + _rawHandles!.Add((handle, true)); + return _rawHandles.Count - 1; + } + + internal void AddHandle(SafeHandle handle) + { + if (handle is null) + { + throw new ArgumentNullException(nameof(handle)); + } + _handles!.Add((handle, true)); + } + + public int Count => _rawHandles is not null ? _rawHandles.Count : _handles!.Count; + + // Used to get the file descriptors to send them over the socket. + public SafeHandle this[int index] => _handles![index].Handle!; + + // We remain responsible for disposing the handle. + public IntPtr ReadHandleRaw(int index) + { + lock (_gate) + { + if (_disposed) + { + ThrowDisposed(); + } + if (_rawHandles is not null) + { + (IntPtr rawHandle, bool CanRead) = _rawHandles[index]; + if (!CanRead) + { + ThrowHandleAlreadyRead(); + } + // Handle can no longer be read, but we are still responible for disposing it. + _rawHandles[index] = (rawHandle, false); + return rawHandle; + } + else + { + Debug.Assert(_handles is not null); + (SafeHandle? handle, bool CanRead) = _handles![index]; + if (!CanRead) + { + ThrowHandleAlreadyRead(); + } + // Handle can no longer be read, but we are still responible for disposing it. + _handles[index] = (handle, false); + return handle!.DangerousGetHandle(); + } + } + } + + private void ThrowHandleAlreadyRead() + { + throw new InvalidOperationException("The handle was already read."); + } + + private void ThrowDisposed() + { + throw new ObjectDisposedException(typeof(UnixFdCollection).FullName); + } + + public T? ReadHandle<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T>(int index) where T : SafeHandle + => ReadHandleGeneric(index); + + // The caller of this method owns the handle and is responsible for Disposing it. + internal T? ReadHandleGeneric<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T>(int index) + { + lock (_gate) + { + if (_disposed) + { + ThrowDisposed(); + } + if (_rawHandles is not null) + { + (IntPtr rawHandle, bool CanRead) = _rawHandles[index]; + if (!CanRead) + { + ThrowHandleAlreadyRead(); + } + #if NET6_0_OR_GREATER + SafeHandle handle = (Activator.CreateInstance() as SafeHandle)!; + Marshal.InitHandle(handle, rawHandle); + #else + SafeHandle? handle = (SafeHandle?)Activator.CreateInstance(typeof(T), new object[] { rawHandle, true }); + #endif + _rawHandles[index] = (InvalidRawHandle, false); + return (T?)(object?)handle; + } + else + { + Debug.Assert(_handles is not null); + (SafeHandle? handle, bool CanRead) = _handles![index]; + if (!CanRead) + { + ThrowHandleAlreadyRead(); + } + if (handle is not T) + { + throw new ArgumentException($"Requested handle type {typeof(T).FullName} does not matched stored type {handle?.GetType().FullName}."); + } + _handles[index] = (null, false); + return (T)(object)handle; + } + } + } + + public IEnumerator GetEnumerator() + { + throw new NotSupportedException(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + throw new NotSupportedException(); + } + + public void DisposeHandles(int count = -1) + { + if (count != 0) + { + DisposeHandles(true, count); + } + } + + public void Dispose() + { + lock (_gate) + { + if (_disposed) + { + return; + } + _disposed = true; + DisposeHandles(true); + } + } + + ~UnixFdCollection() + { + DisposeHandles(false); + } + + private void DisposeHandles(bool disposing, int count = -1) + { + if (count == -1) + { + count = Count; + } + + if (disposing) + { + if (_handles is not null) + { + for (int i = 0; i < count; i++) + { + var handle = _handles[i]; + if (handle.Handle is not null) + { + handle.Handle.Dispose(); + } + } + _handles.RemoveRange(0, count); + } + } + else + { + if (_rawHandles is not null) + { + for (int i = 0; i < count; i++) + { + var handle = _rawHandles[i]; + + if (handle.RawHandle != InvalidRawHandle) + { + close(handle.RawHandle.ToInt32()); + } + } + _rawHandles.RemoveRange(0, count); + } + } + } + + [DllImport("libc")] + private static extern void close(int fd); + + internal void MoveTo(UnixFdCollection handles, int count) + { + if (handles.IsRawHandleCollection != IsRawHandleCollection) + { + throw new ArgumentException("Handle collections are not compatible."); + } + if (handles.IsRawHandleCollection) + { + for (int i = 0; i < count; i++) + { + handles._rawHandles!.Add(_rawHandles![i]); + } + _rawHandles!.RemoveRange(0, count); + } + else + { + for (int i = 0; i < count; i++) + { + handles._handles!.Add(_handles![i]); + } + _handles!.RemoveRange(0, count); + } + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Utf8Span.cs b/src/Linux/Tmds.DBus.Protocol/Utf8Span.cs new file mode 100644 index 0000000000..02c73805e8 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Utf8Span.cs @@ -0,0 +1,20 @@ +namespace Tmds.DBus.Protocol; + +public ref struct Utf8Span +{ + private ReadOnlySpan _buffer; + + public ReadOnlySpan Span => _buffer; + + public bool IsEmpty => _buffer.IsEmpty; + + public Utf8Span(ReadOnlySpan value) => _buffer = value; + + public static implicit operator Utf8Span(ReadOnlySpan value) => new Utf8Span(value); + + public static implicit operator Utf8Span(Span value) => new Utf8Span(value); + + public static implicit operator ReadOnlySpan(Utf8Span value) => value._buffer; + + public override string ToString() => Encoding.UTF8.GetString(_buffer); +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/Variant.cs b/src/Linux/Tmds.DBus.Protocol/Variant.cs new file mode 100644 index 0000000000..2c55c77ef6 --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/Variant.cs @@ -0,0 +1,451 @@ +namespace Tmds.DBus.Protocol; + +// This type is for writing so we don't need to add +// DynamicallyAccessedMemberTypes.PublicParameterlessConstructor. +#pragma warning disable IL2091 + +public readonly struct Variant +{ + private static readonly object Int64Type = DBusType.Int64; + private static readonly object UInt64Type = DBusType.UInt64; + private static readonly object DoubleType = DBusType.Double; + private readonly object? _o; + private readonly long _l; + + private const int TypeShift = 8 * 7; + //private const int SignatureFirstShift = 8 * 6; + private const long StripTypeMask = ~(0xffL << TypeShift); + + private DBusType Type + => DetermineType(); + + public Variant(byte value) + { + _l = value | ((long)DBusType.Byte << TypeShift); + _o = null; + } + public Variant(bool value) + { + _l = (value ? 1L : 0) | ((long)DBusType.Bool << TypeShift); + _o = null; + } + public Variant(short value) + { + _l = (ushort)value | ((long)DBusType.Int16 << TypeShift); + _o = null; + } + public Variant(ushort value) + { + _l = value | ((long)DBusType.UInt16 << TypeShift); + _o = null; + } + public Variant(int value) + { + _l = (uint)value | ((long)DBusType.Int32 << TypeShift); + _o = null; + } + public Variant(uint value) + { + _l = value | ((long)DBusType.UInt32 << TypeShift); + _o = null; + } + public Variant(long value) + { + _l = value; + _o = Int64Type; + } + public Variant(ulong value) + { + _l = (long)value; + _o = UInt64Type; + } + internal unsafe Variant(double value) + { + _l = *(long*)&value; + _o = DoubleType; + } + public Variant(string value) + { + _l = (long)DBusType.String << TypeShift; + _o = value ?? throw new ArgumentNullException(nameof(value)); + } + public Variant(ObjectPath value) + { + _l = (long)DBusType.ObjectPath << TypeShift; + string s = value.ToString(); + if (s.Length == 0) + { + throw new ArgumentException(nameof(value)); + } + _o = s; + } + public Variant(Signature value) + { + _l = (long)DBusType.Signature << TypeShift; + string s = value.ToString(); + if (s.Length == 0) + { + throw new ArgumentException(nameof(value)); + } + _o = s; + } + public Variant(SafeHandle value) + { + _l = (long)DBusType.UnixFd << TypeShift; + _o = value ?? throw new ArgumentNullException(nameof(value)); + } + + public static implicit operator Variant(byte value) + => new Variant(value); + public static implicit operator Variant(bool value) + => new Variant(value); + public static implicit operator Variant(short value) + => new Variant(value); + public static implicit operator Variant(ushort value) + => new Variant(value); + public static implicit operator Variant(int value) + => new Variant(value); + public static implicit operator Variant(uint value) + => new Variant(value); + public static implicit operator Variant(long value) + => new Variant(value); + public static implicit operator Variant(ulong value) + => new Variant(value); + public static implicit operator Variant(double value) + => new Variant(value); + public static implicit operator Variant(string value) + => new Variant(value); + public static implicit operator Variant(ObjectPath value) + => new Variant(value); + public static implicit operator Variant(Signature value) + => new Variant(value); + public static implicit operator Variant(SafeHandle value) + => new Variant(value); + + public static Variant FromArray(Array value) where T : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + return new Variant(TypeModel.GetSignature>(buffer), value); + } + + public static Variant FromDict(Dict value) + where TKey : notnull + where TValue : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + return new Variant(TypeModel.GetSignature>(buffer), value); + } + + public static Variant FromStruct(Struct value) + where T1 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + return new Variant(TypeModel.GetSignature>(buffer), value); + } + + public static Variant FromStruct(Struct value) + where T1 : notnull + where T2 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + return new Variant(TypeModel.GetSignature>(buffer), value); + } + + public static Variant FromStruct(Struct value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + return new Variant(TypeModel.GetSignature>(buffer), value); + } + + public static Variant FromStruct(Struct value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + return new Variant(TypeModel.GetSignature>(buffer), value); + } + + public static Variant FromStruct(Struct value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + return new Variant(TypeModel.GetSignature>(buffer), value); + } + + public static Variant FromStruct(Struct value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + return new Variant(TypeModel.GetSignature>(buffer), value); + } + + public static Variant FromStruct(Struct value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + return new Variant(TypeModel.GetSignature>(buffer), value); + } + + public static Variant FromStruct(Struct value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + return new Variant(TypeModel.GetSignature>(buffer), value); + } + + public static Variant FromStruct(Struct value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + return new Variant(TypeModel.GetSignature>(buffer), value); + } + + public static Variant FromStruct(Struct value) + where T1 : notnull + where T2 : notnull + where T3 : notnull + where T4 : notnull + where T5 : notnull + where T6 : notnull + where T7 : notnull + where T8 : notnull + where T9 : notnull + where T10 : notnull + { + Span buffer = stackalloc byte[ProtocolConstants.MaxSignatureLength]; + return new Variant(TypeModel.GetSignature>(buffer), value); + } + + // Dictionary, Struct, Array. + private unsafe Variant(Utf8Span signature, IDBusWritable value) + { + if (value is null) + { + throw new ArgumentNullException(nameof(value)); + } + // Store the signature in the long if it is large enough. + if (signature.Span.Length <= 8) + { + long l = 0; + Span span = new Span(&l, 8); + signature.Span.CopyTo(span); + if (BitConverter.IsLittleEndian) + { + l = BinaryPrimitives.ReverseEndianness(l); + } + + _l = l; + _o = value; + } + else + { + _l = (long)signature.Span[0] << TypeShift; + _o = new ValueTuple(signature.Span.ToArray(), value); + } + } + + private byte GetByte() + { + DebugAssertTypeIs(DBusType.Byte); + return (byte)(_l & StripTypeMask); + } + private bool GetBool() + { + DebugAssertTypeIs(DBusType.Bool); + return (_l & StripTypeMask) != 0; + } + private short GetInt16() + { + DebugAssertTypeIs(DBusType.Int16); + return (short)(_l & StripTypeMask); + } + private ushort GetUInt16() + { + DebugAssertTypeIs(DBusType.UInt16); + return (ushort)(_l & StripTypeMask); + } + private int GetInt32() + { + DebugAssertTypeIs(DBusType.Int32); + return (int)(_l & StripTypeMask); + } + private uint GetUInt32() + { + DebugAssertTypeIs(DBusType.UInt32); + return (uint)(_l & StripTypeMask); + } + private long GetInt64() + { + DebugAssertTypeIs(DBusType.Int64); + return _l; + } + private ulong GetUInt64() + { + DebugAssertTypeIs(DBusType.UInt64); + return (ulong)(_l); + } + private unsafe double GetDouble() + { + DebugAssertTypeIs(DBusType.Double); + double value; + *(long*)&value = _l; + return value; + } + private string GetString() + { + DebugAssertTypeIs(DBusType.String); + return (_o as string)!; + } + private string GetObjectPath() + { + DebugAssertTypeIs(DBusType.ObjectPath); + return (_o as string)!; + } + private string GetSignature() + { + DebugAssertTypeIs(DBusType.Signature); + return (_o as string)!; + } + private SafeHandle GetUnixFd() + { + DebugAssertTypeIs(DBusType.UnixFd); + return (_o as SafeHandle)!; + } + + private void DebugAssertTypeIs(DBusType expected) + { + Debug.Assert(Type == expected); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private DBusType DetermineType() + { + // For most types, we store the DBusType in the highest byte of the long. + // Except for some types, like Int64, for which we store the value allocation free + // in the long, and use the object field to store the type. + DBusType type = (DBusType)(_l >> TypeShift); + if (_o is not null) + { + if (_o.GetType() == typeof(DBusType)) + { + type = (DBusType)_o; + } + } + return type; + } + + internal unsafe void WriteTo(ref MessageWriter writer) + { + switch (Type) + { + case DBusType.Byte: + writer.WriteVariantByte(GetByte()); + break; + case DBusType.Bool: + writer.WriteVariantBool(GetBool()); + break; + case DBusType.Int16: + writer.WriteVariantInt16(GetInt16()); + break; + case DBusType.UInt16: + writer.WriteVariantUInt16(GetUInt16()); + break; + case DBusType.Int32: + writer.WriteVariantInt32(GetInt32()); + break; + case DBusType.UInt32: + writer.WriteVariantUInt32(GetUInt32()); + break; + case DBusType.Int64: + writer.WriteVariantInt64(GetInt64()); + break; + case DBusType.UInt64: + writer.WriteVariantUInt64(GetUInt64()); + break; + case DBusType.Double: + writer.WriteVariantDouble(GetDouble()); + break; + case DBusType.String: + writer.WriteVariantString(GetString()); + break; + case DBusType.ObjectPath: + writer.WriteVariantObjectPath(GetObjectPath()); + break; + case DBusType.Signature: + writer.WriteVariantSignature(GetSignature()); + break; + case DBusType.UnixFd: + writer.WriteVariantHandle(GetUnixFd()); + break; + + case DBusType.Array: + case DBusType.Struct: + Utf8Span signature; + IDBusWritable writable; + if ((_l << 8) == 0) + { + // The signature is stored in the object. + var o = (ValueTuple)_o!; + signature = new Utf8Span(o.Item1); + writable = o.Item2; + } + else + { + // The signature is stored in _l. + long l = _l; + if (BitConverter.IsLittleEndian) + { + l = BinaryPrimitives.ReverseEndianness(l); + } + Span span = new Span(&l, 8); + int length = span.IndexOf((byte)0); + if (length == -1) + { + length = 8; + } + signature = new Utf8Span(span.Slice(0, length)); + writable = (_o as IDBusWritable)!; + } + writer.WriteSignature(signature); + writable.WriteTo(ref writer); + break; + default: + throw new InvalidOperationException($"Cannot write Variant of type {Type}."); + } + } +} diff --git a/src/Linux/Tmds.DBus.Protocol/VariantExtensions.cs b/src/Linux/Tmds.DBus.Protocol/VariantExtensions.cs new file mode 100644 index 0000000000..da37491efa --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/VariantExtensions.cs @@ -0,0 +1,41 @@ +namespace Tmds.DBus.Protocol; + +// This type is for writing so we don't need to add +// DynamicallyAccessedMemberTypes.PublicParameterlessConstructor. +#pragma warning disable IL2091 + +public static class VariantExtensions +{ + public static Variant AsVariant(this byte value) + => new Variant(value); + + public static Variant AsVariant(this bool value) + => new Variant(value); + + public static Variant AsVariant(this short value) + => new Variant(value); + + public static Variant AsVariant(this ushort value) + => new Variant(value); + + public static Variant AsVariant(this int value) + => new Variant(value); + + public static Variant AsVariant(this uint value) + => new Variant(value); + + public static Variant AsVariant(this long value) + => new Variant(value); + + public static Variant AsVariant(this ulong value) + => new Variant(value); + + public static Variant AsVariant(this double value) + => new Variant(value); + + public static Variant AsVariant(this string value) + => new Variant(value); + + public static Variant AsVariant(this SafeHandle value) + => new Variant(value); +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/VariantValue.cs b/src/Linux/Tmds.DBus.Protocol/VariantValue.cs new file mode 100644 index 0000000000..57af9f87dc --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/VariantValue.cs @@ -0,0 +1,795 @@ +namespace Tmds.DBus.Protocol; + +public readonly struct VariantValue : IEquatable +{ + private static readonly object Int64Type = VariantValueType.Int64; + private static readonly object UInt64Type = VariantValueType.UInt64; + private static readonly object DoubleType = VariantValueType.Double; + private readonly object? _o; + private readonly long _l; + + private const int TypeShift = 8 * 7; + private const int ArrayItemTypeShift = 8 * 0; + private const int DictionaryKeyTypeShift = 8 * 0; + private const int DictionaryValueTypeShift = 8 * 1; + private const long StripTypeMask = ~(0xffL << TypeShift); + + private const long ArrayOfByte = ((long)VariantValueType.Array << TypeShift) | ((long)VariantValueType.Byte << ArrayItemTypeShift); + private const long ArrayOfInt16 = ((long)VariantValueType.Array << TypeShift) | ((long)VariantValueType.Int16 << ArrayItemTypeShift); + private const long ArrayOfUInt16 = ((long)VariantValueType.Array << TypeShift) | ((long)VariantValueType.UInt16 << ArrayItemTypeShift); + private const long ArrayOfInt32 = ((long)VariantValueType.Array << TypeShift) | ((long)VariantValueType.Int32 << ArrayItemTypeShift); + private const long ArrayOfUInt32 = ((long)VariantValueType.Array << TypeShift) | ((long)VariantValueType.UInt32 << ArrayItemTypeShift); + private const long ArrayOfInt64 = ((long)VariantValueType.Array << TypeShift) | ((long)VariantValueType.Int64 << ArrayItemTypeShift); + private const long ArrayOfUInt64 = ((long)VariantValueType.Array << TypeShift) | ((long)VariantValueType.UInt64 << ArrayItemTypeShift); + private const long ArrayOfDouble = ((long)VariantValueType.Array << TypeShift) | ((long)VariantValueType.Double << ArrayItemTypeShift); + private const long ArrayOfString = ((long)VariantValueType.Array << TypeShift) | ((long)VariantValueType.String << ArrayItemTypeShift); + private const long ArrayOfObjectPath = ((long)VariantValueType.Array << TypeShift) | ((long)VariantValueType.ObjectPath << ArrayItemTypeShift); + + public VariantValueType Type + => DetermineType(); + + internal VariantValue(byte value) + { + _l = value | ((long)VariantValueType.Byte << TypeShift); + _o = null; + } + internal VariantValue(bool value) + { + _l = (value ? 1L : 0) | ((long)VariantValueType.Bool << TypeShift); + _o = null; + } + internal VariantValue(short value) + { + _l = (ushort)value | ((long)VariantValueType.Int16 << TypeShift); + _o = null; + } + internal VariantValue(ushort value) + { + _l = value | ((long)VariantValueType.UInt16 << TypeShift); + _o = null; + } + internal VariantValue(int value) + { + _l = (uint)value | ((long)VariantValueType.Int32 << TypeShift); + _o = null; + } + internal VariantValue(uint value) + { + _l = value | ((long)VariantValueType.UInt32 << TypeShift); + _o = null; + } + internal VariantValue(long value) + { + _l = value; + _o = Int64Type; + } + internal VariantValue(ulong value) + { + _l = (long)value; + _o = UInt64Type; + } + internal unsafe VariantValue(double value) + { + _l = *(long*)&value; + _o = DoubleType; + } + internal VariantValue(string value) + { + _l = (long)VariantValueType.String << TypeShift; + _o = value ?? throw new ArgumentNullException(nameof(value)); + } + internal VariantValue(ObjectPath value) + { + _l = (long)VariantValueType.ObjectPath << TypeShift; + string s = value.ToString(); + if (s.Length == 0) + { + throw new ArgumentException(nameof(value)); + } + _o = s; + } + internal VariantValue(Signature value) + { + _l = (long)VariantValueType.Signature << TypeShift; + string s = value.ToString(); + if (s.Length == 0) + { + throw new ArgumentException(nameof(value)); + } + _o = s; + } + // Array + internal VariantValue(VariantValueType itemType, VariantValue[] items) + { + Debug.Assert( + itemType != VariantValueType.Byte && + itemType != VariantValueType.Int16 && + itemType != VariantValueType.UInt16 && + itemType != VariantValueType.Int32 && + itemType != VariantValueType.UInt32 && + itemType != VariantValueType.Int64 && + itemType != VariantValueType.UInt64 && + itemType != VariantValueType.Double + ); + _l = ((long)VariantValueType.Array << TypeShift) | + ((long)itemType << ArrayItemTypeShift); + _o = items; + } + internal VariantValue(VariantValueType itemType, string[] items) + { + Debug.Assert(itemType == VariantValueType.String || itemType == VariantValueType.ObjectPath); + _l = ((long)VariantValueType.Array << TypeShift) | + ((long)itemType << ArrayItemTypeShift); + _o = items; + } + internal VariantValue(byte[] items) + { + _l = ArrayOfByte; + _o = items; + } + internal VariantValue(short[] items) + { + _l = ArrayOfInt16; + _o = items; + } + internal VariantValue(ushort[] items) + { + _l = ArrayOfUInt16; + _o = items; + } + internal VariantValue(int[] items) + { + _l = ArrayOfInt32; + _o = items; + } + internal VariantValue(uint[] items) + { + _l = ArrayOfUInt32; + _o = items; + } + internal VariantValue(long[] items) + { + _l = ArrayOfInt64; + _o = items; + } + internal VariantValue(ulong[] items) + { + _l = ArrayOfUInt64; + _o = items; + } + internal VariantValue(double[] items) + { + _l = ArrayOfDouble; + _o = items; + } + // Dictionary + internal VariantValue(VariantValueType keyType, VariantValueType valueType, KeyValuePair[] pairs) + { + _l = ((long)VariantValueType.Dictionary << TypeShift) | + ((long)keyType << DictionaryKeyTypeShift) | + ((long)valueType << DictionaryValueTypeShift); + _o = pairs; + } + // Struct + internal VariantValue(VariantValue[] fields) + { + _l = ((long)VariantValueType.Struct << TypeShift); + _o = fields; + } + // UnixFd + internal VariantValue(UnixFdCollection? fdCollection, int index) + { + _l = (long)index | ((long)VariantValueType.UnixFd << TypeShift); + _o = fdCollection; + } + + public byte GetByte() + { + EnsureTypeIs(VariantValueType.Byte); + return UnsafeGetByte(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private byte UnsafeGetByte() + { + return (byte)(_l & StripTypeMask); + } + + public bool GetBool() + { + EnsureTypeIs(VariantValueType.Bool); + return UnsafeGetBool(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private bool UnsafeGetBool() + { + return (_l & StripTypeMask) != 0; + } + + public short GetInt16() + { + EnsureTypeIs(VariantValueType.Int16); + return UnsafeGetInt16(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private short UnsafeGetInt16() + { + return (short)(_l & StripTypeMask); + } + + public ushort GetUInt16() + { + EnsureTypeIs(VariantValueType.UInt16); + return UnsafeGetUInt16(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private ushort UnsafeGetUInt16() + { + return (ushort)(_l & StripTypeMask); + } + + public int GetInt32() + { + EnsureTypeIs(VariantValueType.Int32); + return UnsafeGetInt32(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private int UnsafeGetInt32() + { + return (int)(_l & StripTypeMask); + } + + public uint GetUInt32() + { + EnsureTypeIs(VariantValueType.UInt32); + return UnsafeGetUInt32(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private uint UnsafeGetUInt32() + { + return (uint)(_l & StripTypeMask); + } + + public long GetInt64() + { + EnsureTypeIs(VariantValueType.Int64); + return UnsafeGetInt64(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private long UnsafeGetInt64() + { + return _l; + } + + public ulong GetUInt64() + { + EnsureTypeIs(VariantValueType.UInt64); + return UnsafeGetUInt64(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private ulong UnsafeGetUInt64() + { + return (ulong)(_l); + } + + public string GetString() + { + EnsureTypeIs(VariantValueType.String); + return UnsafeGetString(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private string UnsafeGetString() + { + return (_o as string)!; + } + + public string GetObjectPath() + { + EnsureTypeIs(VariantValueType.ObjectPath); + return UnsafeGetString(); + } + + public string GetSignature() + { + EnsureTypeIs(VariantValueType.Signature); + return UnsafeGetString(); + } + + public double GetDouble() + { + EnsureTypeIs(VariantValueType.Double); + return UnsafeGetDouble(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private unsafe double UnsafeGetDouble() + { + double value; + *(long*)&value = _l; + return value; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private T UnsafeGet<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T>() + { + if (typeof(T) == typeof(byte)) + { + return (T)(object)UnsafeGetByte(); + } + else if (typeof(T) == typeof(bool)) + { + return (T)(object)UnsafeGetBool(); + } + else if (typeof(T) == typeof(short)) + { + return (T)(object)UnsafeGetInt16(); + } + else if (typeof(T) == typeof(ushort)) + { + return (T)(object)UnsafeGetUInt16(); + } + else if (typeof(T) == typeof(int)) + { + return (T)(object)UnsafeGetInt32(); + } + else if (typeof(T) == typeof(uint)) + { + return (T)(object)UnsafeGetUInt32(); + } + else if (typeof(T) == typeof(long)) + { + return (T)(object)UnsafeGetInt64(); + } + else if (typeof(T) == typeof(ulong)) + { + return (T)(object)UnsafeGetUInt64(); + } + else if (typeof(T) == typeof(double)) + { + return (T)(object)UnsafeGetDouble(); + } + else if (typeof(T) == typeof(string)) + { + return (T)(object)UnsafeGetString(); + } + else if (typeof(T) == typeof(VariantValue)) + { + return (T)(object)this; + } + else if (typeof(T).IsAssignableTo(typeof(SafeHandle))) + { + return (T)(object)UnsafeReadHandle()!; + } + + ThrowCannotRetrieveAs(Type, typeof(T)); + return default!; + } + + public Dictionary GetDictionary + < + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]TKey, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]TValue + > + () + where TKey : notnull + where TValue : notnull + { + EnsureTypeIs(VariantValueType.Dictionary); + EnsureCanUnsafeGet(KeyType); + EnsureCanUnsafeGet(ValueType); + + Dictionary dict = new(); + var pairs = (_o as KeyValuePair[])!.AsSpan(); + foreach (var pair in pairs) + { + dict[pair.Key.UnsafeGet()] = pair.Value.UnsafeGet(); + } + return dict; + } + + public T[] GetArray<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T>() + where T : notnull + { + EnsureTypeIs(VariantValueType.Array); + EnsureCanUnsafeGet(ItemType); + + // Return the array by reference when we can. + // Don't bother to make a copy in case the caller mutates the data and + // calls GetArray again to retrieve the original data. It's an unlikely scenario. + if (typeof(T) == typeof(byte)) + { + return (T[])(object)(_o as byte[])!; + } + else if (typeof(T) == typeof(short)) + { + return (T[])(object)(_o as short[])!; + } + else if (typeof(T) == typeof(int)) + { + return (T[])(object)(_o as int[])!; + } + else if (typeof(T) == typeof(long)) + { + return (T[])(object)(_o as long[])!; + } + else if (typeof(T) == typeof(ushort)) + { + return (T[])(object)(_o as ushort[])!; + } + else if (typeof(T) == typeof(uint)) + { + return (T[])(object)(_o as uint[])!; + } + else if (typeof(T) == typeof(ulong)) + { + return (T[])(object)(_o as ulong[])!; + } + else if (typeof(T) == typeof(double)) + { + return (T[])(object)(_o as double[])!; + } + else if (typeof(T) == typeof(string)) + { + return (T[])(object)(_o as string[])!; + } + else + { + var items = (_o as VariantValue[])!.AsSpan(); + T[] array = new T[items.Length]; + int i = 0; + foreach (var item in items) + { + array[i++] = item.UnsafeGet(); + } + return array; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void EnsureCanUnsafeGet(VariantValueType type) + { + if (typeof(T) == typeof(byte)) + { + EnsureTypeIs(type, VariantValueType.Byte); + } + else if (typeof(T) == typeof(bool)) + { + EnsureTypeIs(type, VariantValueType.Bool); + } + else if (typeof(T) == typeof(short)) + { + EnsureTypeIs(type, VariantValueType.Int16); + } + else if (typeof(T) == typeof(ushort)) + { + EnsureTypeIs(type, VariantValueType.UInt16); + } + else if (typeof(T) == typeof(int)) + { + EnsureTypeIs(type, VariantValueType.Int32); + } + else if (typeof(T) == typeof(uint)) + { + EnsureTypeIs(type, VariantValueType.UInt32); + } + else if (typeof(T) == typeof(long)) + { + EnsureTypeIs(type, VariantValueType.Int64); + } + else if (typeof(T) == typeof(ulong)) + { + EnsureTypeIs(type, VariantValueType.UInt64); + } + else if (typeof(T) == typeof(double)) + { + EnsureTypeIs(type, VariantValueType.Double); + } + else if (typeof(T) == typeof(string)) + { + EnsureTypeIs(type, [ VariantValueType.String, VariantValueType.Signature, VariantValueType.ObjectPath ]); + } + else if (typeof(T) == typeof(VariantValue)) + { } + else if (typeof(T).IsAssignableTo(typeof(SafeHandle))) + { + EnsureTypeIs(type, VariantValueType.UnixFd); + } + else + { + ThrowCannotRetrieveAs(type, typeof(T)); + } + } + + public T? ReadHandle< +#if NET6_0_OR_GREATER + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] +#endif + T>() where T : SafeHandle + { + EnsureTypeIs(VariantValueType.UnixFd); + return UnsafeReadHandle(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private T? UnsafeReadHandle<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]T>() + { + var handles = (UnixFdCollection?)_o; + if (handles is null) + { + return default; + } + int index = (int)_l; + return handles.ReadHandleGeneric(index); + } + + // Use for Array, Struct and Dictionary. + public int Count + { + get + { + Array? array = _o as Array; + return array?.Length ?? -1; + } + } + + // Valid for Array, Struct. + public VariantValue GetItem(int i) + { + if (Type == VariantValueType.Array) + { + switch (_l) + { + case ArrayOfByte: + return new VariantValue((_o as byte[])![i]); + case ArrayOfInt16: + return new VariantValue((_o as short[])![i]); + case ArrayOfUInt16: + return new VariantValue((_o as ushort[])![i]); + case ArrayOfInt32: + return new VariantValue((_o as int[])![i]); + case ArrayOfUInt32: + return new VariantValue((_o as uint[])![i]); + case ArrayOfInt64: + return new VariantValue((_o as long[])![i]); + case ArrayOfUInt64: + return new VariantValue((_o as ulong[])![i]); + case ArrayOfDouble: + return new VariantValue((_o as double[])![i]); + case ArrayOfString: + case ArrayOfObjectPath: + return new VariantValue((_o as string[])![i]); + } + } + var values = _o as VariantValue[]; + if (_o is null) + { + ThrowCannotRetrieveAs(Type, [VariantValueType.Array, VariantValueType.Struct]); + } + return values![i]; + } + + // Valid for Dictionary. + public KeyValuePair GetDictionaryEntry(int i) + { + var values = _o as KeyValuePair[]; + if (_o is null) + { + ThrowCannotRetrieveAs(Type, VariantValueType.Dictionary); + } + return values![i]; + } + + // implicit conversion to VariantValue for basic D-Bus types (except Unix_FD). + public static implicit operator VariantValue(byte value) + => new VariantValue(value); + public static implicit operator VariantValue(bool value) + => new VariantValue(value); + public static implicit operator VariantValue(short value) + => new VariantValue(value); + public static implicit operator VariantValue(ushort value) + => new VariantValue(value); + public static implicit operator VariantValue(int value) + => new VariantValue(value); + public static implicit operator VariantValue(uint value) + => new VariantValue(value); + public static implicit operator VariantValue(long value) + => new VariantValue(value); + public static implicit operator VariantValue(ulong value) + => new VariantValue(value); + public static implicit operator VariantValue(double value) + => new VariantValue(value); + public static implicit operator VariantValue(string value) + => new VariantValue(value); + public static implicit operator VariantValue(ObjectPath value) + => new VariantValue(value); + public static implicit operator VariantValue(Signature value) + => new VariantValue(value); + + public VariantValueType ItemType + => DetermineInnerType(VariantValueType.Array, ArrayItemTypeShift); + + public VariantValueType KeyType + => DetermineInnerType(VariantValueType.Dictionary, DictionaryKeyTypeShift); + + public VariantValueType ValueType + => DetermineInnerType(VariantValueType.Dictionary, DictionaryValueTypeShift); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void EnsureTypeIs(VariantValueType expected) + => EnsureTypeIs(Type, expected); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void EnsureTypeIs(VariantValueType actual, VariantValueType expected) + { + if (actual != expected) + { + ThrowCannotRetrieveAs(actual, expected); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void EnsureTypeIs(VariantValueType actual, VariantValueType[] expected) + { + if (Array.IndexOf(expected, actual) == -1) + { + ThrowCannotRetrieveAs(actual, expected); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private VariantValueType DetermineInnerType(VariantValueType outer, int typeShift) + { + VariantValueType type = DetermineType(); + return type == outer ? (VariantValueType)((_l >> typeShift) & 0xff) : VariantValueType.Invalid; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private VariantValueType DetermineType() + { + // For most types, we store the VariantValueType in the highest byte of the long. + // Except for some types, like Int64, for which we store the value allocation free + // in the long, and use the object field to store the type. + VariantValueType type = (VariantValueType)(_l >> TypeShift); + if (_o is not null) + { + if (_o.GetType() == typeof(VariantValueType)) + { + type = (VariantValueType)_o; + } + } + return type; + } + + private static void ThrowCannotRetrieveAs(VariantValueType from, VariantValueType to) + => ThrowCannotRetrieveAs(from.ToString(), [ to.ToString() ]); + + private static void ThrowCannotRetrieveAs(VariantValueType from, VariantValueType[] to) + => ThrowCannotRetrieveAs(from.ToString(), to.Select(expected => expected.ToString())); + + private static void ThrowCannotRetrieveAs(string from, string to) + => ThrowCannotRetrieveAs(from, [ to ]); + + private static void ThrowCannotRetrieveAs(VariantValueType from, Type to) + => ThrowCannotRetrieveAs(from.ToString(), to.FullName ?? "??"); + + private static void ThrowCannotRetrieveAs(string from, IEnumerable to) + { + throw new InvalidOperationException($"Type {from} can not be retrieved as {string.Join("/", to)}."); + } + + public override string ToString() + => ToString(includeTypeSuffix: true); + + public string ToString(bool includeTypeSuffix) + { + // This is implemented so something user-friendly shows in the debugger. + // By overriding the ToString method, it will also affect generic types like KeyValueType that call ToString. + VariantValueType type = Type; + switch (type) + { + case VariantValueType.Byte: + return $"{GetByte()}{TypeSuffix(includeTypeSuffix, type)}"; + case VariantValueType.Bool: + return $"{GetBool()}{TypeSuffix(includeTypeSuffix, type)}"; + case VariantValueType.Int16: + return $"{GetInt16()}{TypeSuffix(includeTypeSuffix, type)}"; + case VariantValueType.UInt16: + return $"{GetUInt16()}{TypeSuffix(includeTypeSuffix, type)}"; + case VariantValueType.Int32: + return $"{GetInt32()}{TypeSuffix(includeTypeSuffix, type)}"; + case VariantValueType.UInt32: + return $"{GetUInt32()}{TypeSuffix(includeTypeSuffix, type)}"; + case VariantValueType.Int64: + return $"{GetInt64()}{TypeSuffix(includeTypeSuffix, type)}"; + case VariantValueType.UInt64: + return $"{GetUInt64()}{TypeSuffix(includeTypeSuffix, type)}"; + case VariantValueType.Double: + return $"{GetDouble()}{TypeSuffix(includeTypeSuffix, type)}"; + case VariantValueType.String: + return $"{GetString()}{TypeSuffix(includeTypeSuffix, type)}"; + case VariantValueType.ObjectPath: + return $"{GetObjectPath()}{TypeSuffix(includeTypeSuffix, type)}"; + case VariantValueType.Signature: + return $"{GetSignature()}{TypeSuffix(includeTypeSuffix, type)}"; + + case VariantValueType.Array: + return $"[{nameof(VariantValueType.Array)}<{ItemType}>, Count={Count}]"; + case VariantValueType.Struct: + var values = (_o as VariantValue[]) ?? Array.Empty(); + return $"({ + string.Join(", ", values.Select(v => v.ToString(includeTypeSuffix: false))) + }){( + !includeTypeSuffix ? "" + : $" [{nameof(VariantValueType.Struct)}]<{ + string.Join(", ", values.Select(v => v.Type)) + }>]")})"; + case VariantValueType.Dictionary: + return $"[{nameof(VariantValueType.Dictionary)}<{KeyType}, {ValueType}>, Count={Count}]"; + case VariantValueType.UnixFd: + return $"[{nameof(VariantValueType.UnixFd)}]"; + + case VariantValueType.Invalid: + return $"[{nameof(VariantValueType.Invalid)}]"; + case VariantValueType.VariantValue: // note: No VariantValue returns this as its Type. + default: + return $"[?{Type}?]"; + } + } + + static string TypeSuffix(bool includeTypeSuffix, VariantValueType type) + => includeTypeSuffix ? $" [{type}]" : ""; + + public static bool operator==(VariantValue lhs, VariantValue rhs) + => lhs.Equals(rhs); + + public static bool operator!=(VariantValue lhs, VariantValue rhs) + => !lhs.Equals(rhs); + + public override bool Equals(object? obj) + { + if (obj is not null && obj.GetType() == typeof(VariantValue)) + { + return ((VariantValue)obj).Equals(this); + } + return false; + } + + public override int GetHashCode() + { +#if NETSTANDARD2_0 + return _l.GetHashCode() + 17 * (_o?.GetHashCode() ?? 0); +#else + return HashCode.Combine(_l, _o); +#endif + } + + public bool Equals(VariantValue other) + { + if (_l == other._l && object.ReferenceEquals(_o, other._o)) + { + return true; + } + VariantValueType type = Type; + if (type != other.Type) + { + return false; + } + switch (type) + { + case VariantValueType.String: + case VariantValueType.ObjectPath: + case VariantValueType.Signature: + return (_o as string)!.Equals(other._o as string, StringComparison.Ordinal); + } + // Always return false for composite types and handles. + return false; + } +} \ No newline at end of file diff --git a/src/Linux/Tmds.DBus.Protocol/VariantValueType.cs b/src/Linux/Tmds.DBus.Protocol/VariantValueType.cs new file mode 100644 index 0000000000..544f7e366b --- /dev/null +++ b/src/Linux/Tmds.DBus.Protocol/VariantValueType.cs @@ -0,0 +1,30 @@ +namespace Tmds.DBus.Protocol; + +public enum VariantValueType +{ + Invalid = 0, + + // VariantValue is used for a variant for which we read the value + // and no longer track its signature. + VariantValue = 1, + + // Match the DBusType values for easy conversion. + Byte = DBusType.Byte, + Bool = DBusType.Bool, + Int16 = DBusType.Int16, + UInt16 = DBusType.UInt16, + Int32 = DBusType.Int32, + UInt32 = DBusType.UInt32, + Int64 = DBusType.Int64, + UInt64 = DBusType.UInt64, + Double = DBusType.Double, + String = DBusType.String, + ObjectPath = DBusType.ObjectPath, + Signature = DBusType.Signature, + Array = DBusType.Array, + Struct = DBusType.Struct, + Dictionary = DBusType.DictEntry, + UnixFd = DBusType.UnixFd, + // We don't need this : variants are resolved into the VariantValue. + // Variant = DBusType.Variant, +} \ No newline at end of file