csharpc-sharpdotnetxamlavaloniauicross-platformcross-platform-xamlavaloniaguimulti-platformuser-interfacedotnetcore
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1290 lines
42 KiB
1290 lines
42 KiB
using System.Net;
|
|
using System.Net.Sockets;
|
|
using System.Threading.Tasks.Sources;
|
|
|
|
#pragma warning disable VSTHRD100 // Avoid "async void" methods
|
|
|
|
namespace Tmds.DBus.Protocol;
|
|
|
|
class DBusConnection : IDisposable
|
|
{
|
|
// Reply callback used by the private CallMethodAsync(MessageBuffer, MessageReceivedHandler, object?)
// overload: receives either a non-null exception (the message is then not meaningful) or the reply
// message, plus the caller-supplied state object.
private delegate void MessageReceivedHandler(Exception? exception, Message message, object? state);
|
|
|
|
// Minimal IValueTaskSource used to await D-Bus replies. Every ValueTask in this file is created
// over an instance with token 0 and the core is never Reset, so each instance completes once.
sealed class MyValueTaskSource<T> : IValueTaskSource<T>, IValueTaskSource
{
    private ManualResetValueTaskSourceCore<T> _core;
    // Becomes true after an awaiter registered its continuation (see OnCompleted). volatile because
    // SetResult polls it from the completing thread.
    private volatile bool _continuationSet;

    // Completes successfully. Spins until a continuation has been registered, so the awaiter is
    // resumed by this call (the core's RunContinuationsAsynchronously is left at its default) instead
    // of observing an already-completed source on its own thread.
    // NOTE(review): if nobody ever awaits the source this spins forever; the callers in this file
    // await the source right after sending the request.
    public void SetResult(T result)
    {
        // Ensure we complete the Task from the read loop.
        SpinWait wait = new();
        while (!_continuationSet)
        {
            wait.SpinOnce();
        }
        _core.SetResult(result);
    }

    // Completes with an exception. Unlike SetResult, this does not wait for a continuation.
    public void SetException(Exception exception) => _core.SetException(exception);

    public ValueTaskSourceStatus GetStatus(short token) => _core.GetStatus(token);

    // Registers the awaiter's continuation, then publishes that fact for SetResult.
    public void OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
    {
        _core.OnCompleted(continuation, state, token, flags);
        _continuationSet = true;
    }

    T IValueTaskSource<T>.GetResult(short token) => _core.GetResult(token);

    // Non-generic view so the same source can back a plain ValueTask (result discarded).
    void IValueTaskSource.GetResult(short token) => _core.GetResult(token);
}
|
|
|
|
// Connection lifecycle. A new instance starts as Created (the default value); ConnectAsync moves it
// through Connecting to Connected (Connect jumps straight to Connected); Dispose ends in Disconnected.
enum ConnectionState { Created, Connecting, Connected, Disconnected }
|
|
|
|
// Reply/exception callback that receives up to three opaque state values, so callers can use static
// lambdas and pass their context through state instead of capturing it.
delegate void MessageHandlerDelegate(Exception? exception, Message message, object? state1, object? state2, object? state3);

// A MessageHandlerDelegate bundled with the state values it will be invoked with.
// default(MessageHandler) carries no delegate and reports HasValue == false.
readonly struct MessageHandler
{
    private readonly MessageHandlerDelegate _delegate;
    private readonly object? _state1;
    private readonly object? _state2;
    private readonly object? _state3;

    public MessageHandler(MessageHandlerDelegate handler, object? state1 = null, object? state2 = null, object? state3 = null)
    {
        (_delegate, _state1, _state2, _state3) = (handler, state1, state2, state3);
    }

    // True when a delegate was supplied (i.e. this is not default(MessageHandler)).
    public bool HasValue => _delegate is not null;

    // Forwards the exception/message together with the stored state values.
    public void Invoke(Exception? exception, Message message)
        => _delegate(exception, message, _state1, _state2, _state3);
}
|
|
|
|
// Four-state variant of MessageHandlerDelegate, used for match-rule observers.
delegate void MessageHandlerDelegate4(Exception? exception, Message message, object? state1, object? state2, object? state3, object? state4);

// A MessageHandlerDelegate4 bundled with the four state values it will be invoked with.
// default(MessageHandler4) carries no delegate and reports HasValue == false.
readonly struct MessageHandler4
{
    private readonly MessageHandlerDelegate4 _delegate;
    private readonly object? _state1;
    private readonly object? _state2;
    private readonly object? _state3;
    private readonly object? _state4;

    public MessageHandler4(MessageHandlerDelegate4 handler, object? state1 = null, object? state2 = null, object? state3 = null, object? state4 = null)
    {
        (_delegate, _state1, _state2, _state3, _state4) = (handler, state1, state2, state3, state4);
    }

    // True when a delegate was supplied (i.e. this is not default(MessageHandler4)).
    public bool HasValue => _delegate is not null;

    // Forwards the exception/message together with the stored state values.
    public void Invoke(Exception? exception, Message message)
        => _delegate(exception, message, _state1, _state2, _state3, _state4);
}
|
|
|
|
// Guards _state, _pendingCalls, _matchMakers (and their Observers lists) and _pathNodes.
private readonly object _gate = new object();
// Owning Connection: receives Disconnect notifications and supplies message writers and MethodContexts.
private readonly Connection _parentConnection;
// Reply handlers of in-flight method calls, keyed by the serial of the request they wait for.
private readonly Dictionary<uint, MessageHandler> _pendingCalls;
// NOTE(review): not referenced by any code visible in this class; confirm whether it is still needed.
private readonly CancellationTokenSource _connectCts;
// One MatchMaker per distinct match-rule string; each holds the observers registered for that rule.
private readonly Dictionary<string, MatchMaker> _matchMakers;
// Scratch list reused by HandleMessages to collect the observers matching the current message.
private readonly List<Observer> _matchedObservers;
// Registered object paths and their method handlers, consulted for incoming method calls.
private readonly PathNodeDictionary _pathNodes;
// Returned for org.freedesktop.DBus.Peer.GetMachineId requests.
private readonly string _machineId;

// Transport; assigned by Connect / ConnectAsync.
private IMessageStream? _messageStream;
private ConnectionState _state;
// First recorded disconnect reason; see the DisconnectReason property.
private Exception? _disconnectReason;
// Unique name from the Hello reply; stays null when Hello did not produce a method return.
private string? _localName;
// Arguments parked for the static SynchronizationContext.Send callback in
// EmitOnSynchronizationContextHelper (presumably to avoid a per-emit closure allocation).
private Message? _currentMessage;
private Observer? _currentObserver;
private SynchronizationContext? _currentSynchronizationContext;
// Created lazily by DisconnectedAsync and completed by Dispose.
private TaskCompletionSource<Exception?>? _disconnectedTcs;
// Cancelled by Dispose; its token is handed to every MethodContext.
private CancellationTokenSource _abortedCts;
// Set by BecomeMonitorAsync; blocks further method calls and subscriptions.
private bool _isMonitor;
// Receives every incoming message once the connection has become a monitor.
private Action<Exception?, DisposableMessage>? _monitorHandler;
|
|
|
|
// Unique bus name assigned by the Hello call; null when it is not known.
public string? UniqueName
{
    get { return _localName; }
}

// Why the connection went down. When no reason was recorded, reading yields a fresh
// ObjectDisposedException; writing only takes effect for the first reason recorded.
public Exception DisconnectReason
{
    get
    {
        Exception? recorded = _disconnectReason;
        return recorded is not null ? recorded : new ObjectDisposedException(GetType().FullName);
    }
    set
    {
        // Publish atomically, and only if nothing has been stored yet.
        _ = Interlocked.CompareExchange(ref _disconnectReason, value, null);
    }
}

// True when Hello returned a unique name, i.e. the peer behaves as a message bus.
public bool RemoteIsBus
{
    get { return _localName is not null; }
}
|
|
|
|
// Creates an unconnected instance owned by `parent`; Connect or ConnectAsync establishes the transport.
public DBusConnection(Connection parent, string machineId)
{
    _parentConnection = parent;
    _machineId = machineId;

    _pendingCalls = new Dictionary<uint, MessageHandler>();
    _matchMakers = new Dictionary<string, MatchMaker>();
    _matchedObservers = new List<Observer>();
    _pathNodes = new PathNodeDictionary();

    _connectCts = new CancellationTokenSource();
    _abortedCts = new CancellationTokenSource();
}
|
|
|
|
// For tests: attaches an already-established stream, starts its receive loop and marks the connection
// Connected. No authentication and no Hello call take place, so UniqueName stays null.
internal void Connect(IMessageStream stream)
{
    _messageStream = stream;
    stream.ReceiveMessages(
        static (Exception? error, Message received, DBusConnection self) => self.HandleMessages(error, received),
        this);
    _state = ConnectionState.Connected;
}
|
|
|
|
// Connects to the first usable entry of a D-Bus address string (only "unix" and "tcp" entries are
// tried; others are skipped). For the chosen entry it connects the socket, authenticates, starts the
// receive loop, marks the connection Connected and performs the Hello call.
// Entries that fail are disposed and skipped. If none succeeds, the first failure is rethrown with
// its original stack trace.
// Throws ArgumentException when a tcp entry has no port or when no entry could be tried at all.
// Throws DisconnectedException when the connection was disposed while connecting.
public async ValueTask ConnectAsync(string address, string? userId, bool supportsFdPassing, CancellationToken cancellationToken)
{
    _state = ConnectionState.Connecting;
    // Captured (rather than storing the bare exception) so the rethrow below keeps the original
    // stack trace; `throw ex;` would reset it to this method.
    System.Runtime.ExceptionServices.ExceptionDispatchInfo? firstFailure = null;

    AddressParser.AddressEntry addr = default;
    while (AddressParser.TryGetNextEntry(address, ref addr))
    {
        Socket? socket = null;
        EndPoint? endpoint = null;
        Guid guid = default;

        if (AddressParser.IsType(addr, "unix"))
        {
            AddressParser.ParseUnixProperties(addr, out string path, out guid);
            socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified);
            endpoint = new UnixDomainSocketEndPoint(path);
        }
        else if (AddressParser.IsType(addr, "tcp"))
        {
            AddressParser.ParseTcpProperties(addr, out string host, out int? port, out guid);
            if (!port.HasValue)
            {
                throw new ArgumentException("port");
            }
            socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
            endpoint = new DnsEndPoint(host, port.Value);
        }

        if (socket is null)
        {
            // Unsupported transport: try the next entry.
            continue;
        }

        try
        {
            await socket.ConnectAsync(endpoint!, cancellationToken).ConfigureAwait(false);

            MessageStream stream;
            lock (_gate)
            {
                // Dispose may have run while the socket was connecting.
                if (_state != ConnectionState.Connecting)
                {
                    throw new DisconnectedException(DisconnectReason);
                }
                _messageStream = stream = new MessageStream(socket);
            }

            await stream.DoClientAuthAsync(guid, userId, supportsFdPassing).ConfigureAwait(false);

            stream.ReceiveMessages(
                static (Exception? exception, Message message, DBusConnection connection) =>
                    connection.HandleMessages(exception, message), this);

            lock (_gate)
            {
                if (_state != ConnectionState.Connecting)
                {
                    throw new DisconnectedException(DisconnectReason);
                }
                _state = ConnectionState.Connected;
            }

            // Hello: learn our unique name (stays null when Hello does not return a method return).
            _localName = await GetLocalNameAsync().ConfigureAwait(false);

            return;
        }
        catch (Exception exception)
        {
            socket.Dispose();
            firstFailure ??= System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(exception);
        }
    }

    // Rethrow the first failure with its original stack trace intact.
    firstFailure?.Throw();

    throw new ArgumentException("No addresses were found", nameof(address));
}
|
|
|
|
// Sends org.freedesktop.DBus.Hello and returns the unique name from its reply.
// Returns null when the reply is not a method return (e.g. an error reply); a transport exception
// delivered to the reply handler faults the returned task.
private async Task<string?> GetLocalNameAsync()
{
    var completion = new MyValueTaskSource<string?>();

    MessageBuffer hello = BuildHello();
    await CallMethodAsync(
        hello,
        static (Exception? error, Message reply, object? state) =>
        {
            var source = (MyValueTaskSource<string?>)state!;
            if (error is not null)
            {
                source.SetException(error);
                return;
            }

            string? name = reply.MessageType == MessageType.MethodReturn
                ? reply.GetBodyReader().ReadString().ToString()
                : null;
            source.SetResult(name);
        },
        completion).ConfigureAwait(false);

    return await new ValueTask<string?>(completion, token: 0).ConfigureAwait(false);

    // Builds the parameterless Hello call addressed to the bus itself.
    MessageBuffer BuildHello()
    {
        using var writer = GetMessageWriter();
        writer.WriteMethodCallHeader(
            destination: "org.freedesktop.DBus",
            path: "/org/freedesktop/DBus",
            @interface: "org.freedesktop.DBus",
            member: "Hello");
        return writer.CreateMessage();
    }
}
|
|
|
|
// Receive-loop callback, invoked once per incoming message, or with a non-null exception when
// receiving failed (forwarded to the parent connection's Disconnect).
// Under _gate it snapshots everything needed to dispatch the message: the monitor handler, the
// pending call this message answers, the matching observers, and, for method calls, the method
// handler. The dispatch itself runs outside the lock. Any exception thrown while dispatching is
// forwarded to Disconnect.
// async void because a method handler may be awaited inline (RunMethodHandlerSynchronously).
private async void HandleMessages(Exception? exception, Message message)
{
    if (exception is not null)
    {
        _parentConnection.Disconnect(exception, this);
    }
    else
    {
        try
        {
            // Cleared when ownership of the message passes elsewhere (the monitor handler, or a method
            // handler running asynchronously via RunMethodHandler, which returns it to the pool itself).
            bool returnMessageToPool = true;
            MessageHandler pendingCall = default;
            IMethodHandler? methodHandler = null;
            Action<Exception?, DisposableMessage>? monitor = null;
            bool isMethodCall = message.MessageType == MessageType.MethodCall;
            MethodContext? methodContext = null;

            lock (_gate)
            {
                if (_state == ConnectionState.Disconnected)
                {
                    return;
                }

                monitor = _monitorHandler;

                // On a monitor connection every message goes to the monitor handler only.
                if (monitor is null)
                {
                    // A reply: take (and unregister) the handler of the call it answers.
                    if (message.ReplySerial.HasValue)
                    {
                        _pendingCalls.Remove(message.ReplySerial.Value, out pendingCall);
                    }

                    // Collect the observers of every matching rule; they are emitted below, outside the lock.
                    foreach (var matchMaker in _matchMakers.Values)
                    {
                        if (matchMaker.Matches(message))
                        {
                            _matchedObservers.AddRange(matchMaker.Observers);
                        }
                    }

                    if (isMethodCall)
                    {
                        methodContext = new MethodContext(_parentConnection, message, _abortedCts.Token); // TODO: pool.

                        if (message.PathIsSet)
                        {
                            if (_pathNodes.TryGetValue(message.PathAsString!, out PathNode? node))
                            {
                                methodHandler = node.MethodHandler;

                                bool isDBusIntrospect = message.Member.SequenceEqual("Introspect"u8) &&
                                                        message.Interface.SequenceEqual("org.freedesktop.DBus.Introspectable"u8);
                                methodContext.IsDBusIntrospectRequest = isDBusIntrospect;
                                if (isDBusIntrospect)
                                {
                                    // Copy the child node names into the context (done under _gate).
                                    node.CopyChildNamesTo(methodContext);
                                }
                            }
                        }
                    }
                }
            }

            if (monitor is not null)
            {
                lock (monitor)
                {
                    // Re-check under the monitor lock: Dispose clears _monitorHandler while holding the
                    // same lock before it sends the final disconnect notification.
                    if (_monitorHandler is not null)
                    {
                        // The message is handed over wrapped in a DisposableMessage, so it is not
                        // returned to the pool here.
                        returnMessageToPool = false;
                        monitor(null, new DisposableMessage(message));
                    }
                }
            }
            else
            {
                if (_matchedObservers.Count != 0)
                {
                    foreach (var observer in _matchedObservers)
                    {
                        observer.Emit(message);
                    }
                    _matchedObservers.Clear();
                }

                if (pendingCall.HasValue)
                {
                    pendingCall.Invoke(null, message);
                }

                if (isMethodCall)
                {
                    Debug.Assert(methodContext is not null);
                    if (methodHandler is not null)
                    {
                        // Suppress methodContext nullability warnings.
#if NETSTANDARD2_0
#pragma warning disable CS8604
#endif
                        bool runHandlerSynchronously = methodHandler.RunMethodHandlerSynchronously(message);
                        if (runHandlerSynchronously)
                        {
                            // Awaited inline: the read loop does not move on to the next message
                            // until this handler completes.
                            await methodHandler.HandleMethodAsync(methodContext).ConfigureAwait(false);
                            HandleNoReplySent(methodContext);
                        }
                        else
                        {
                            // Fire-and-forget; RunMethodHandler returns the message to the pool.
                            returnMessageToPool = false;
                            RunMethodHandler(methodHandler, methodContext);
                        }
                    }
                    else
                    {
                        // No object registered at this path: reply with a default/error response.
                        HandleNoReplySent(methodContext);
                    }
#if NETSTANDARD2_0
#pragma warning restore CS8604
#endif
                }
            }

            if (returnMessageToPool)
            {
                message.ReturnToPool();
            }
        }
        catch (Exception ex)
        {
            _parentConnection.Disconnect(ex, this);
        }
    }
}
|
|
|
|
// Supplies a reply for a method call that is still unanswered, unless the caller asked for no reply:
// an empty introspection document for Introspect requests; built-in replies for
// org.freedesktop.DBus.Peer Ping and GetMachineId; otherwise an UnknownMethod error.
private void HandleNoReplySent(MethodContext context)
{
    bool nothingToDo = context.ReplySent || context.NoReplyExpected;
    if (nothingToDo)
    {
        return;
    }

    if (context.IsDBusIntrospectRequest)
    {
        context.ReplyIntrospectXml(interfaceXmls: []);
        return;
    }

    var request = context.Request;
    bool isPeerInterface = request.Interface.SequenceEqual("org.freedesktop.DBus.Peer"u8);

    if (isPeerInterface && request.Member.SequenceEqual("Ping"u8))
    {
        using var pingWriter = context.CreateReplyWriter(null);
        context.Reply(pingWriter.CreateMessage());
        return;
    }

    if (isPeerInterface && request.Member.SequenceEqual("GetMachineId"u8))
    {
        using var machineIdWriter = context.CreateReplyWriter("s");
        machineIdWriter.WriteString(_machineId);
        context.Reply(machineIdWriter.CreateMessage());
        return;
    }

    context.ReplyError("org.freedesktop.DBus.Error.UnknownMethod",
        $"Method \"{request.MemberAsString}\" with signature \"{request.SignatureAsString}\" on interface \"{request.InterfaceAsString}\" doesn't exist");
}
|
|
|
|
// Runs a method handler off the read loop (fire-and-forget). When the handler finishes, any missing
// reply is supplied and the request message is returned to the pool. A failure is reported to the
// parent connection via Disconnect.
private async void RunMethodHandler(IMethodHandler methodHandler, MethodContext context)
{
    Exception? failure = null;
    try
    {
        await methodHandler.HandleMethodAsync(context).ConfigureAwait(false);
        HandleNoReplySent(context);
        context.Request.ReturnToPool();
    }
    catch (Exception ex)
    {
        failure = ex;
    }

    if (failure is not null)
    {
        _parentConnection.Disconnect(failure, this);
    }
}
|
|
|
|
// Invokes observer's handler for message on synchronizationContext, blocking until it has run (Send).
// The arguments are parked in fields so the callback can be a static lambda that receives only `this`.
// While the handler runs, synchronizationContext is installed as SynchronizationContext.Current.
private void EmitOnSynchronizationContextHelper(Observer observer, SynchronizationContext synchronizationContext, Message message)
{
    _currentSynchronizationContext = synchronizationContext;
    _currentObserver = observer;
    _currentMessage = message;

#pragma warning disable VSTHRD001 // Await JoinableTaskFactory.SwitchToMainThreadAsync() to switch to the UI thread instead of APIs that can deadlock or require specifying a priority.
    // note: Send blocks the current thread until the SynchronizationContext ran the delegate.
    synchronizationContext.Send(static state =>
    {
        SynchronizationContext? saved = SynchronizationContext.Current;
        try
        {
            var self = (DBusConnection)state!;
            SynchronizationContext.SetSynchronizationContext(self._currentSynchronizationContext);
            self._currentObserver!.InvokeHandler(self._currentMessage!);
        }
        finally
        {
            SynchronizationContext.SetSynchronizationContext(saved);
        }
    }, this);

    _currentMessage = null;
    _currentObserver = null;
    _currentSynchronizationContext = null;
}
|
|
|
|
// Lets the caller mutate the path/method-handler table; `update` runs while _gate is held, so it is
// serialized with message dispatch in HandleMessages.
public void UpdateMethodHandlers<T>(Action<IMethodHandlerDictionary, T> update, T state)
{
    IMethodHandlerDictionary handlers = _pathNodes;
    lock (_gate)
    {
        update(handlers, state);
    }
}
|
|
|
|
// Tears the connection down. Only the first call does any work (_state is switched to Disconnected
// under _gate). Then, in order: close the stream with DisconnectReason, cancel _abortedCts (the
// token handed to MethodContexts), fail every pending call, dispose every observer, notify the
// monitor handler, and complete the task returned by DisconnectedAsync.
public void Dispose()
{
    Action<Exception?, DisposableMessage>? monitor = null;

    lock (_gate)
    {
        if (_state == ConnectionState.Disconnected)
        {
            return;
        }
        _state = ConnectionState.Disconnected;
        monitor = _monitorHandler;
    }

    // Read once: the getter creates a new ObjectDisposedException on each read when no reason was
    // recorded, so this keeps every notification below on the same instance.
    Exception disconnectReason = DisconnectReason;

    _messageStream?.Close(disconnectReason);

    _abortedCts.Cancel();

    // _pendingCalls and _matchMakers are enumerated without _gate. The code paths that mutate them
    // check _state under _gate first and bail out once it is Disconnected.
    if (_pendingCalls is not null)
    {
        foreach (var pendingCall in _pendingCalls.Values)
        {
            // Handlers receive the exception and a null message.
            pendingCall.Invoke(new DisconnectedException(disconnectReason), null!);
        }
        _pendingCalls.Clear();
    }

    foreach (var matchMaker in _matchMakers.Values)
    {
        foreach (var observer in matchMaker.Observers)
        {
            // A plain dispose (Connection.DisposedException) only reaches observers that opted in
            // via EmitOnConnectionDispose; any other reason is always emitted.
            bool emitException = !object.ReferenceEquals(disconnectReason, Connection.DisposedException) ||
                                 observer.EmitOnConnectionDispose;
            Exception? exception = emitException ? new DisconnectedException(disconnectReason) : null;
            // removeObserver: false — the whole _matchMakers table is cleared below.
            observer.Dispose(exception, removeObserver: false);
        }
    }
    _matchMakers.Clear();

    if (monitor is not null)
    {
        // Same lock HandleMessages takes before delivering to the monitor: once the field is
        // cleared, no further messages reach it; then send the final notification.
        lock (monitor)
        {
            _monitorHandler = null;
            monitor(new DisconnectedException(disconnectReason), new DisposableMessage(null));
        }
    }

    _disconnectedTcs?.SetResult(GetWaitForDisconnectException());
}
|
|
|
|
// Adapts a MessageReceivedHandler plus its state to the MessageHandler form used by the core
// overload: the handler travels in state1 and the caller's state in state2 (state3 is unused).
private ValueTask CallMethodAsync(MessageBuffer message, MessageReceivedHandler returnHandler, object? state)
{
    var adapter = new MessageHandler(
        static (Exception? exception, Message reply, object? callback, object? callbackState, object? unused) =>
            ((MessageReceivedHandler)callback!)(exception, reply, callbackState),
        returnHandler,
        state);
    return CallMethodAsync(message, adapter);
}
|
|
|
|
// Registers `handler` for the reply (unless the message carries NoReplyExpected) and sends `message`.
// Throws DisconnectedException when not connected and InvalidOperationException on a monitor
// connection. The buffer is returned to the pool whenever the stream did not take it, including
// when one of those checks or the registration throws.
private async ValueTask CallMethodAsync(MessageBuffer message, MessageHandler handler)
{
    bool handedToStream = false;
    try
    {
        lock (_gate)
        {
            if (_state != ConnectionState.Connected)
            {
                throw new DisconnectedException(DisconnectReason!);
            }
            if (_isMonitor)
            {
                throw new InvalidOperationException("Cannot send messages on monitor connection.");
            }

            bool expectsReply = (message.MessageFlags & MessageFlags.NoReplyExpected) == 0;
            if (expectsReply)
            {
                _pendingCalls.Add(message.Serial, handler);
            }
        }

        handedToStream = await _messageStream!.TrySendMessageAsync(message).ConfigureAwait(false);
    }
    finally
    {
        if (!handedToStream)
        {
            message.ReturnToPool();
        }
    }
}
|
|
|
|
// Sends `message` and returns the reply body decoded by `valueReader`, which also receives `state`.
// Error replies fault the task with DBusException and other reply types with ProtocolException.
// An exception thrown by the reader also faults the task.
public async Task<T> CallMethodAsync<T>(MessageBuffer message, MessageValueReader<T> valueReader, object? state = null)
{
    var completion = new MyValueTaskSource<T>();
    var replyHandler = new MessageHandler(
        static (Exception? exception, Message reply, object? readerObj, object? sourceObj, object? readerState) =>
        {
            var reader = (MessageValueReader<T>)readerObj!;
            var source = (MyValueTaskSource<T>)sourceObj!;

            if (exception is not null)
            {
                source.SetException(exception);
                return;
            }

            switch (reply.MessageType)
            {
                case MessageType.MethodReturn:
                    try
                    {
                        source.SetResult(reader(reply, readerState));
                    }
                    catch (Exception ex)
                    {
                        source.SetException(ex);
                    }
                    break;
                case MessageType.Error:
                    source.SetException(CreateDBusExceptionForErrorMessage(reply));
                    break;
                default:
                    source.SetException(new ProtocolException($"Unexpected reply type: {reply.MessageType}."));
                    break;
            }
        },
        valueReader, completion, state);

    await CallMethodAsync(message, replyHandler).ConfigureAwait(false);

    return await new ValueTask<T>(completion, 0).ConfigureAwait(false);
}
|
|
|
|
// Sends `message` and completes once the reply arrives; the reply body is ignored. Error replies
// fault the task with DBusException and unexpected reply types with ProtocolException.
public async Task CallMethodAsync(MessageBuffer message)
{
    var completion = new MyValueTaskSource<object?>();
    MessageReceivedHandler onReply = CompleteCallValueTaskSource;

    await CallMethodAsync(message, onReply, completion).ConfigureAwait(false);

    await new ValueTask(completion, 0).ConfigureAwait(false);
}
|
|
|
|
// Completes the MyValueTaskSource<object?> passed as `vts` from a reply: with the transport exception
// if there is one, with null for a method return, with DBusException for an error reply, and with
// ProtocolException for any other reply type.
private static void CompleteCallValueTaskSource(Exception? exception, Message message, object? vts)
{
    var source = (MyValueTaskSource<object?>)vts!;

    if (exception is not null)
    {
        source.SetException(exception);
        return;
    }

    switch (message.MessageType)
    {
        case MessageType.MethodReturn:
            source.SetResult(null);
            break;
        case MessageType.Error:
            source.SetException(CreateDBusExceptionForErrorMessage(message));
            break;
        default:
            source.SetException(new ProtocolException($"Unexpected reply type: {message.MessageType}."));
            break;
    }
}
|
|
|
|
// Builds a DBusException from an error reply. The message text is the first body argument when the
// body starts with a string; otherwise the error name is used for both name and message.
private static DBusException CreateDBusExceptionForErrorMessage(Message message)
{
    string errorName = message.ErrorNameAsString ?? "<<No ErrorName>>.";

    bool bodyStartsWithString = message.SignatureIsSet
        && message.Signature.Length > 0
        && (DBusType)message.Signature[0] == DBusType.String;

    string errMessage = bodyStartsWithString ? message.GetBodyReader().ReadString() : errorName;

    return new DBusException(errorName, errMessage);
}
|
|
|
|
// Turns this connection into a monitor via org.freedesktop.DBus.Monitoring.BecomeMonitor.
// `rules` are sent as match-rule strings; when null an empty array is sent (per the D-Bus spec that
// means "all messages" — confirm). Requires a connected bus connection with no observers and no
// pending calls. Afterwards `handler` receives every incoming message, and is invoked with a
// DisconnectedException when the connection is disposed.
public async Task BecomeMonitorAsync(Action<Exception?, DisposableMessage> handler, IEnumerable<MatchRule>? rules)
{
    Task reply;

    lock (_gate)
    {
        if (_state != ConnectionState.Connected)
        {
            throw new DisconnectedException(DisconnectReason!);
        }
        if (!RemoteIsBus)
        {
            throw new InvalidOperationException("The remote is not a bus.");
        }
        if (_matchMakers.Count != 0)
        {
            throw new InvalidOperationException("The connection has observers.");
        }
        if (_pendingCalls.Count != 0)
        {
            throw new InvalidOperationException("The connection has pending method calls.");
        }

        // HashSet removes duplicate rule strings.
        HashSet<string>? ruleStrings = null;
        if (rules is not null)
        {
            ruleStrings = new();
            foreach (var rule in rules)
            {
                ruleStrings.Add(rule.ToString());
            }
        }

        // The request is issued while _gate is held (the lock is re-entrant), before _isMonitor is
        // set, since CallMethodAsync rejects sends on monitor connections. Setting _isMonitor in the
        // same critical section keeps other calls and subscriptions from slipping in.
        reply = CallMethodAsync(CreateMessage(ruleStrings));
        _isMonitor = true;
    }

    try
    {
        await reply.ConfigureAwait(false);
        lock (_gate)
        {
            _messageStream!.BecomeMonitor();
            _monitorHandler = handler;
        }
    }
    catch
    {
        // BecomeMonitor failed: allow normal use of the connection again.
        lock (_gate)
        {
            _isMonitor = false;
        }

        throw;
    }

    // BecomeMonitor(as rules, u flags): flags is always 0.
    MessageBuffer CreateMessage(IEnumerable<string>? rules)
    {
        using var writer = GetMessageWriter();
        writer.WriteMethodCallHeader(
            destination: Connection.DBusServiceName,
            path: Connection.DBusObjectPath,
            @interface: "org.freedesktop.DBus.Monitoring",
            signature: "asu",
            member: "BecomeMonitor");
        writer.WriteArray(rules ?? Array.Empty<string>());
        writer.WriteUInt32(0);
        return writer.CreateMessage();
    }
}
|
|
|
|
// Subscribes to messages matching `rule`. Each match is decoded with valueReader(message, readerState)
// and passed to valueHandler(null, value, readerState, handlerState). Terminal exceptions arrive as
// valueHandler(exception, default, readerState, handlerState). Disposing the returned object ends
// the subscription.
public ValueTask<IDisposable> AddMatchAsync<T>(SynchronizationContext? synchronizationContext, MatchRule rule, MessageValueReader<T> valueReader, Action<Exception?, T, object?, object?> valueHandler, object? readerState, object? handlerState, ObserverFlags flags)
{
    var messageHandler = new MessageHandler4(
        static (Exception? exception, Message message, object? readerObj, object? handlerObj, object? rs, object? hs) =>
        {
            var onValue = (Action<Exception?, T, object?, object?>)handlerObj!;
            if (exception is null)
            {
                T value = ((MessageValueReader<T>)readerObj!)(message, rs);
                onValue(null, value, rs, hs);
            }
            else
            {
                onValue(exception, default(T)!, rs, hs);
            }
        },
        valueReader, valueHandler, readerState, handlerState);

    return AddMatchAsync(synchronizationContext, rule, messageHandler, flags);
}
|
|
|
|
// Registers an observer for `rule` and returns it as the subscription's IDisposable.
// Observers of the same rule string share one MatchMaker. The first subscribing observer sends
// AddMatch to the bus, and every subscribing observer awaits that call's outcome. NoSubscribe is
// forced when the remote is not a bus, and then nothing is sent.
// If AddMatch fails, the new observer is disposed without an emitted exception and the failure is
// rethrown. Throws DisconnectedException when not connected and InvalidOperationException on a
// monitor connection.
private async ValueTask<IDisposable> AddMatchAsync(SynchronizationContext? synchronizationContext, MatchRule rule, MessageHandler4 handler, ObserverFlags flags)
{
    MatchRuleData data = rule.Data;
    MatchMaker? matchMaker;
    string ruleString;
    Observer observer;
    MessageBuffer? addMatchMessage = null;
    bool subscribe;

    lock (_gate)
    {
        if (_state != ConnectionState.Connected)
        {
            throw new DisconnectedException(DisconnectReason!);
        }
        if (!RemoteIsBus)
        {
            flags |= ObserverFlags.NoSubscribe;
        }
        if (_isMonitor)
        {
            throw new InvalidOperationException("Cannot add subscriptions on a monitor connection.");
        }

        ruleString = data.GetRuleString();

        if (!_matchMakers.TryGetValue(ruleString, out matchMaker))
        {
            matchMaker = new MatchMaker(this, ruleString, data);
            _matchMakers.Add(ruleString, matchMaker);
        }

        observer = new Observer(synchronizationContext, matchMaker, handler, flags);
        matchMaker.Observers.Add(observer);

        subscribe = observer.Subscribes;
        // Only the first subscribing observer of a rule sends AddMatch; later ones share its result.
        bool sendMessage = subscribe && matchMaker.AddMatchTcs is null;
        if (sendMessage)
        {
            addMatchMessage = CreateAddMatchMessage(matchMaker.RuleString);
            matchMaker.AddMatchTcs = new();

            MessageHandlerDelegate fn = static (Exception? exception, Message message, object? state1, object? state2, object? state3) =>
            {
                var mm = (MatchMaker)state1!;
                // Dispose fails pending calls with an exception and a null message, so the message
                // may only be inspected when there is no exception.
                if (exception is null && message.MessageType == MessageType.MethodReturn)
                {
                    mm.HasSubscribed = true;
                }
                CompleteCallValueTaskSource(exception, message, mm.AddMatchTcs!);
            };

            _pendingCalls.Add(addMatchMessage.Serial, new(fn, matchMaker));
        }
    }

    if (subscribe)
    {
        if (addMatchMessage is not null)
        {
            if (!await _messageStream!.TrySendMessageAsync(addMatchMessage).ConfigureAwait(false))
            {
                addMatchMessage.ReturnToPool();
            }
        }

        try
        {
            await matchMaker.AddMatchTask!.ConfigureAwait(false);
        }
        catch
        {
            observer.Dispose(exception: null);

            throw;
        }
    }

    return observer;

    // AddMatch(s rule) addressed to the bus.
    MessageBuffer CreateAddMatchMessage(string ruleString)
    {
        using var writer = GetMessageWriter();

        writer.WriteMethodCallHeader(
            destination: "org.freedesktop.DBus",
            path: "/org/freedesktop/DBus",
            @interface: "org.freedesktop.DBus",
            member: "AddMatch",
            signature: "s");

        writer.WriteString(ruleString);

        return writer.CreateMessage();
    }
}
|
|
|
|
// Shared exception delivered to observers created with EmitOnObserverDispose when they are disposed.
internal static readonly ObjectDisposedException ObserverDisposedException = new(typeof(Observer).FullName);
|
|
|
|
// One subscription to a MatchMaker's rule; AddMatchAsync returns it as the subscription's IDisposable.
// Messages are delivered inline when there is no SynchronizationContext, otherwise synchronously on
// the captured context through DBusConnection.EmitOnSynchronizationContextHelper. _gate serializes
// handler invocation with Dispose: once Dispose has marked the observer disposed, InvokeHandler
// delivers nothing more.
sealed class Observer : IDisposable
{
    private readonly object _gate = new object();
    private readonly SynchronizationContext? _synchronizationContext;
    private readonly MatchMaker _matchMaker;
    private readonly MessageHandler4 _messageHandler;
    private readonly ObserverFlags _flags;
    private bool _disposed;

    // False for NoSubscribe observers, which never cause an AddMatch call to the bus.
    public bool Subscribes => (_flags & ObserverFlags.NoSubscribe) == 0;
    // Emit a DisconnectedException even when the connection is disposed normally.
    public bool EmitOnConnectionDispose => (_flags & ObserverFlags.EmitOnConnectionDispose) != 0;
    // Emit ObserverDisposedException to the handler when this observer itself is disposed.
    public bool EmitOnObserverDispose => (_flags & ObserverFlags.EmitOnObserverDispose) != 0;

    public Observer(SynchronizationContext? synchronizationContext, MatchMaker matchMaker, in MessageHandler4 messageHandler, ObserverFlags flags)
    {
        _synchronizationContext = synchronizationContext;
        _matchMaker = matchMaker;
        _messageHandler = messageHandler;
        _flags = flags;
    }

    // Ends the subscription (emitting ObserverDisposedException if requested) and unregisters it.
    public void Dispose() =>
        Dispose(EmitOnObserverDispose ? ObserverDisposedException : null);

    // Idempotent. A non-null `exception` is delivered to the handler once, with a null message.
    // removeObserver is false when the connection's Dispose is tearing down all observers itself.
    public void Dispose(Exception? exception, bool removeObserver = true)
    {
        lock (_gate)
        {
            if (_disposed)
            {
                return;
            }
            _disposed = true;
        }

        if (exception is not null)
        {
            Emit(exception);
        }

        if (removeObserver)
        {
            _matchMaker.Connection.RemoveObserver(_matchMaker, this);
        }
    }

    // Delivers a matched message, on the captured SynchronizationContext when there is one.
    public void Emit(Message message)
    {
        if (_synchronizationContext is null)
        {
            InvokeHandler(message);
        }
        else
        {
            _matchMaker.Connection.EmitOnSynchronizationContextHelper(this, _synchronizationContext, message);
        }
    }

    // Delivers a terminal exception (with a null message): directly when there is no captured context
    // or we are already on it, otherwise synchronously via Send. Does not take _gate.
    private void Emit(Exception exception)
    {
        if (_synchronizationContext is null ||
            SynchronizationContext.Current == _synchronizationContext)
        {
            _messageHandler.Invoke(exception, null!);
        }
        else
        {
            _synchronizationContext.Send(
                delegate
                {
                    SynchronizationContext? previousContext = SynchronizationContext.Current;
                    try
                    {
                        SynchronizationContext.SetSynchronizationContext(_synchronizationContext);
                        _messageHandler.Invoke(exception, null!);
                    }
                    finally
                    {
                        SynchronizationContext.SetSynchronizationContext(previousContext);
                    }
                }, null);
        }
    }

    // Runs the handler for one message. Subscribing observers drop messages until the bus has
    // acknowledged AddMatch (HasSubscribed). The handler runs while _gate is held, so a Dispose on
    // another thread waits for it to finish.
    internal void InvokeHandler(Message message)
    {
        if (Subscribes && !_matchMaker.HasSubscribed)
        {
            return;
        }

        lock (_gate)
        {
            if (_disposed)
            {
                return;
            }

            _messageHandler.Invoke(null, message);
        }
    }
}
|
|
|
|
// Unregisters `observer` from its MatchMaker. When no subscribing observer is left for a rule that was
// subscribed (AddMatch issued), the MatchMaker is dropped and RemoveMatch is sent to the bus
// (fire-and-forget, hence async void). A MatchMaker left with no observers at all is dropped too.
// Does nothing once the connection is disconnected.
private async void RemoveObserver(MatchMaker matchMaker, Observer observer)
{
    string ruleString = matchMaker.RuleString;
    bool sendMessage = false;

    lock (_gate)
    {
        if (_state == ConnectionState.Disconnected)
        {
            return;
        }

        // The table may by now hold a newer MatchMaker for the same rule string (created after this
        // one was dropped). Only act on the instance the observer actually belongs to, so a stale
        // observer cannot unsubscribe the newer one.
        if (_matchMakers.TryGetValue(ruleString, out MatchMaker? current) && ReferenceEquals(current, matchMaker))
        {
            matchMaker.Observers.Remove(observer);

            // Unsubscribe from the bus only once the last subscribing observer is gone.
            sendMessage = matchMaker.AddMatchTcs is not null && !matchMaker.HasSubscribers;

            // Also drop MatchMakers that have no observers left, so HandleMessages stops scanning them.
            if (sendMessage || matchMaker.Observers.Count == 0)
            {
                _matchMakers.Remove(ruleString);
            }
        }
    }

    if (sendMessage)
    {
        var message = CreateRemoveMatchMessage();
        if (!await _messageStream!.TrySendMessageAsync(message).ConfigureAwait(false))
        {
            message.ReturnToPool();
        }
    }

    // RemoveMatch(s rule), sent without expecting a reply.
    MessageBuffer CreateRemoveMatchMessage()
    {
        using var writer = GetMessageWriter();

        writer.WriteMethodCallHeader(
            destination: "org.freedesktop.DBus",
            path: "/org/freedesktop/DBus",
            @interface: "org.freedesktop.DBus",
            member: "RemoveMatch",
            signature: "s",
            flags: MessageFlags.NoReplyExpected);

        writer.WriteString(ruleString);

        return writer.CreateMessage();
    }
}
|
|
|
|
// Local matcher for one D-Bus match rule (RuleString), plus the observers registered for it.
// The rule's components are pre-encoded as UTF-8 so Matches can compare them with the raw header
// and body spans of incoming messages.
sealed class MatchMaker
{
    private readonly MessageType? _type;
    private readonly byte[]? _sender;
    private readonly byte[]? _interface;
    private readonly byte[]? _member;
    private readonly byte[]? _path;
    private readonly byte[]? _pathNamespace;
    private readonly byte[]? _destination;
    private readonly byte[]? _arg0;
    private readonly byte[]? _arg0Path;
    private readonly byte[]? _arg0Namespace;
    private readonly string _rule;

    private MyValueTaskSource<object?>? _vts;

    // Observers registered for this rule; mutated under DBusConnection._gate.
    public List<Observer> Observers { get; } = new();

    // Completion source of the AddMatch call for this rule; null until the first subscribing observer
    // triggers that call. Assigning a non-null value also creates AddMatchTask over it.
    public MyValueTaskSource<object?>? AddMatchTcs
    {
        get => _vts;
        set
        {
            _vts = value;
            if (value != null)
            {
                // AsTask registers a continuation right away, which MyValueTaskSource.SetResult waits for.
                AddMatchTask = new ValueTask<object?>(value, token: 0).AsTask();
            }
        }
    }

    // Task view of AddMatchTcs, awaited by every subscribing observer of this rule.
    public Task<object?>? AddMatchTask { get; private set; }

    // Set once the bus answered AddMatch with a method return.
    public bool HasSubscribed { get; set; }

    public DBusConnection Connection { get; }

    public MatchMaker(DBusConnection connection, string rule, in MatchRuleData data)
    {
        Connection = connection;
        _rule = rule;

        _type = data.MessageType;

        // Only unique names (leading ':') are compared locally; presumably the sender header always
        // carries a unique name, so a well-known sender is left for the bus to filter. Ordinal, so
        // culture-specific comparison rules cannot affect the check.
        if (data.Sender is not null && data.Sender.StartsWith(":", StringComparison.Ordinal))
        {
            _sender = Encoding.UTF8.GetBytes(data.Sender);
        }
        _interface = ToUtf8OrNull(data.Interface);
        _member = ToUtf8OrNull(data.Member);
        _path = ToUtf8OrNull(data.Path);
        _pathNamespace = ToUtf8OrNull(data.PathNamespace);
        _destination = ToUtf8OrNull(data.Destination);
        _arg0 = ToUtf8OrNull(data.Arg0);
        _arg0Path = ToUtf8OrNull(data.Arg0Path);
        _arg0Namespace = ToUtf8OrNull(data.Arg0Namespace);
    }

    public string RuleString => _rule;

    // True when at least one registered observer subscribes (i.e. is not NoSubscribe).
    public bool HasSubscribers
    {
        get
        {
            if (Observers.Count == 0)
            {
                return false;
            }
            foreach (var observer in Observers)
            {
                if (observer.Subscribes)
                {
                    return true;
                }
            }
            return false;
        }
    }

    public override string ToString() => _rule;

    // True when `message` satisfies every component of the rule that is set.
    internal bool Matches(Message message)
    {
        if (_type.HasValue && _type != message.MessageType)
        {
            return false;
        }

        if (_sender is not null && !IsEqual(_sender, message.Sender))
        {
            return false;
        }

        if (_interface is not null && !IsEqual(_interface, message.Interface))
        {
            return false;
        }

        if (_member is not null && !IsEqual(_member, message.Member))
        {
            return false;
        }

        if (_path is not null && !IsEqual(_path, message.Path))
        {
            return false;
        }

        if (_destination is not null && !IsEqual(_destination, message.Destination))
        {
            return false;
        }

        if (_pathNamespace is not null && (!message.PathIsSet || !IsEqualOrChildOfPath(message.Path, _pathNamespace)))
        {
            return false;
        }

        if (_arg0Namespace is not null ||
            _arg0 is not null ||
            _arg0Path is not null)
        {
            if (message.Signature.Length == 0)
            {
                return false;
            }

            DBusType arg0Type = (DBusType)message.Signature[0];

            if (arg0Type != DBusType.String &&
                arg0Type != DBusType.ObjectPath)
            {
                return false;
            }

            // STRING and OBJECT_PATH share the same wire encoding.
            ReadOnlySpan<byte> arg0 = message.GetBodyReader().ReadStringAsSpan();

            if (_arg0Path is not null && !IsEqualParentOrChildOfPath(arg0, _arg0Path))
            {
                return false;
            }

            // arg0path also accepts OBJECT_PATH arguments, but arg0 and arg0namespace only apply to
            // STRING arguments.
            if (arg0Type != DBusType.String)
            {
                return _arg0 is null && _arg0Namespace is null;
            }

            if (_arg0 is not null && !IsEqual(_arg0, arg0))
            {
                return false;
            }

            if (_arg0Namespace is not null && !IsEqualOrChildOfName(arg0, _arg0Namespace))
            {
                return false;
            }
        }

        return true;
    }

    // UTF-8 encoding of `value`, or null when the rule component is absent.
    private static byte[]? ToUtf8OrNull(string? value)
        => value is null ? null : Encoding.UTF8.GetBytes(value);

    // arg0namespace test: `lhs` equals the dotted name `rhs` or lies beneath it ("a.b" / "a.b.c").
    private static bool IsEqualOrChildOfName(ReadOnlySpan<byte> lhs, ReadOnlySpan<byte> rhs)
    {
        return lhs.StartsWith(rhs) && (lhs.Length == rhs.Length || lhs[rhs.Length] == '.');
    }

    // path_namespace test: `lhs` equals `rhs` or is a descendant of it. As in the reference bus
    // implementation, the root namespace "/" matches every path.
    private static bool IsEqualOrChildOfPath(ReadOnlySpan<byte> lhs, ReadOnlySpan<byte> rhs)
    {
        if (!lhs.StartsWith(rhs))
        {
            return false;
        }
        if (rhs.Length == 1 && rhs[0] == '/')
        {
            return true;
        }
        return lhs.Length == rhs.Length || lhs[rhs.Length] == '/';
    }

    // arg0path test: the values are equal, or the shorter one ends in '/' and is a prefix of the
    // longer one. An empty value is never treated as such a prefix (and is never indexed).
    private static bool IsEqualParentOrChildOfPath(ReadOnlySpan<byte> lhs, ReadOnlySpan<byte> rhs)
    {
        if (lhs.Length == rhs.Length)
        {
            return IsEqual(lhs, rhs);
        }

        ReadOnlySpan<byte> shorter = lhs.Length < rhs.Length ? lhs : rhs;
        ReadOnlySpan<byte> longer = lhs.Length < rhs.Length ? rhs : lhs;

        return shorter.Length > 0 &&
               shorter[shorter.Length - 1] == '/' &&
               longer.StartsWith(shorter);
    }

    private static bool IsEqual(ReadOnlySpan<byte> lhs, ReadOnlySpan<byte> rhs)
    {
        return lhs.SequenceEqual(rhs);
    }
}
|
|
|
|
// Message writers are supplied by the owning Connection.
public MessageWriter GetMessageWriter()
{
    return _parentConnection.GetMessageWriter();
}
|
|
|
|
// Fire-and-forget send. If the stream does not take the buffer, it is returned to the pool here.
public async void SendMessage(MessageBuffer message)
{
    if (await _messageStream!.TrySendMessageAsync(message).ConfigureAwait(false))
    {
        return;
    }

    message.ReturnToPool();
}
|
|
|
|
// Returns a task that completes once the connection is disconnected. Its result comes from
// GetWaitForDisconnectException: null for an ObjectDisposedException reason, otherwise the reason.
// Every caller shares one task, created lazily; it is already complete if the connection is down.
public Task<Exception?> DisconnectedAsync()
{
    lock (_gate)
    {
        if (_disconnectedTcs is not null)
        {
            return _disconnectedTcs.Task;
        }

        if (_state == ConnectionState.Disconnected)
        {
            return Task.FromResult(GetWaitForDisconnectException());
        }

        _disconnectedTcs = new TaskCompletionSource<Exception?>(TaskCreationOptions.RunContinuationsAsynchronously);
        return _disconnectedTcs.Task;
    }
}
|
|
|
|
// Result reported by DisconnectedAsync: an ObjectDisposedException reason maps to null; any other
// recorded reason (or null when none was recorded) is returned unchanged.
private Exception? GetWaitForDisconnectException()
{
    if (_disconnectReason is ObjectDisposedException)
    {
        return null;
    }
    return _disconnectReason;
}
|
|
|
|
// Sends an error reply to `methodCall`'s sender, correlated by its serial (fire-and-forget via SendMessage).
private void SendErrorReplyMessage(Message methodCall, string errorName, string errorMsg)
{
    MessageBuffer reply = BuildErrorReply();
    SendMessage(reply);

    MessageBuffer BuildErrorReply()
    {
        using var writer = GetMessageWriter();
        writer.WriteError(
            replySerial: methodCall.Serial,
            destination: methodCall.Sender,
            errorName: errorName,
            errorMsg: errorMsg);
        return writer.CreateMessage();
    }
}
|
|
}
|