diff --git a/src/Headless/Avalonia.Headless/HeadlessUnitTestSession.cs b/src/Headless/Avalonia.Headless/HeadlessUnitTestSession.cs index 8dc9629b4c..1610f6796c 100644 --- a/src/Headless/Avalonia.Headless/HeadlessUnitTestSession.cs +++ b/src/Headless/Avalonia.Headless/HeadlessUnitTestSession.cs @@ -83,13 +83,14 @@ public sealed class HeadlessUnitTestSession : IDisposable using var application = EnsureApplication(); var cts = new CancellationTokenSource(); - using var globalCts = token.Register(s => ((CancellationTokenSource)s!).Cancel(), cts); - using var localCts = cancellationToken.Register(s => ((CancellationTokenSource)s!).Cancel(), cts); + using var globalCts = token.Register(s => ((CancellationTokenSource)s!).Cancel(), cts, true); + using var localCts = cancellationToken.Register(s => ((CancellationTokenSource)s!).Cancel(), cts, true); try { var task = action(); - task.ContinueWith((_, s) => ((CancellationTokenSource)s!).Cancel(), cts); + task.ContinueWith((_, s) => ((CancellationTokenSource)s!).Cancel(), cts, + TaskScheduler.FromCurrentSynchronizationContext()); if (cts.IsCancellationRequested) { @@ -97,7 +98,7 @@ public sealed class HeadlessUnitTestSession : IDisposable } var frame = new DispatcherFrame(); - using var innerCts = cts.Token.Register(() => frame.Continue = false); + using var innerCts = cts.Token.Register(() => frame.Continue = false, true); Dispatcher.UIThread.PushFrame(frame); var result = task.GetAwaiter().GetResult();