Browse Source

Capture ExecutionContext for Dispatcher.InvokeAsync (#19163)

* Capture ExecutionContext for Dispatcher.InvokeAsync

(cherry picked from commit e41c9272ac45afc31007a8fe26f9f56291b063ef)

* Implement CulturePreservingExecutionContext

* Add IsFlowSuppressed checking

* Add NET6_0_OR_GREATER because only the Restore need it.

* Use `ExecutionContext.Run` instead of `ExecutionContext.Restore`.

* Pass this to avoid lambda capture.

* Use ExecutionContext directly on NET6_0_OR_GREATER

* on NET6_0_OR_GREATER, use Restore so we can get a simple stack trace.

* Add unit tests.

* All test code must run inside Task.Run to avoid interfering with the test

* First, test Task.Run to ensure that the preceding validation always passes, serving as a baseline for the subsequent Invoke/InvokeAsync tests.
This way, if a later test fails, we have the .NET framework's baseline behavior for reference.
pull/19469/head
walterlv 6 months ago
committed by GitHub
parent
commit
46d5a693f1
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 156
      src/Avalonia.Base/Threading/CulturePreservingExecutionContext.cs
  2. 48
      src/Avalonia.Base/Threading/DispatcherOperation.cs
  3. 251
      tests/Avalonia.Base.UnitTests/DispatcherTests.cs

156
src/Avalonia.Base/Threading/CulturePreservingExecutionContext.cs

@ -0,0 +1,156 @@
#if NET6_0_OR_GREATER
// In .NET Core, the security context and call context are not supported, however,
// the impersonation context and culture would typically flow with the execution context.
// See: https://learn.microsoft.com/en-us/dotnet/api/system.threading.executioncontext
//
// So we can safely use ExecutionContext without worrying about culture flowing issues.
#else
using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Threading;
namespace Avalonia.Threading;
/// <summary>
/// An ExecutionContext that preserves culture information across async operations.
/// This is a modernized version that removes legacy compatibility switches and
/// includes nullable reference type annotations.
/// </summary>
internal sealed class CulturePreservingExecutionContext
{
private readonly ExecutionContext _context;
private CultureAndContext? _cultureAndContext;
private CulturePreservingExecutionContext(ExecutionContext context)
{
_context = context;
}
/// <summary>
/// Captures the current ExecutionContext and culture information.
/// </summary>
/// <returns>A new CulturePreservingExecutionContext instance, or null if no context needs to be captured.</returns>
public static CulturePreservingExecutionContext? Capture()
{
// ExecutionContext.SuppressFlow had been called.
// We expect ExecutionContext.Capture() to return null, so match that behavior and return null.
if (ExecutionContext.IsFlowSuppressed())
{
return null;
}
var context = ExecutionContext.Capture();
if (context == null)
return null;
return new CulturePreservingExecutionContext(context);
}
/// <summary>
/// Runs the specified callback in the captured execution context while preserving culture information.
/// This method is used for .NET Framework and earlier .NET versions.
/// </summary>
/// <param name="executionContext">The execution context to run in.</param>
/// <param name="callback">The callback to execute.</param>
/// <param name="state">The state to pass to the callback.</param>
public static void Run(CulturePreservingExecutionContext executionContext, ContextCallback callback, object? state)
{
// ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
if (callback == null)
return;
// ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
if (executionContext == null)
ThrowNullContext();
// Save culture information - we will need this to restore just before
// the callback is actually invoked from CallbackWrapper.
executionContext._cultureAndContext = CultureAndContext.Initialize(callback, state);
try
{
ExecutionContext.Run(
executionContext._context,
s_callbackWrapperDelegate,
executionContext._cultureAndContext);
}
finally
{
// Restore culture information - it might have been modified during callback execution.
executionContext._cultureAndContext.RestoreCultureInfos();
}
}
[DoesNotReturn]
private static void ThrowNullContext()
{
throw new InvalidOperationException("ExecutionContext cannot be null.");
}
private static readonly ContextCallback s_callbackWrapperDelegate = CallbackWrapper;
/// <summary>
/// Executes the callback and saves culture values immediately afterwards.
/// </summary>
/// <param name="obj">Contains the actual callback and state.</param>
private static void CallbackWrapper(object? obj)
{
var cultureAndContext = (CultureAndContext)obj!;
// Restore culture information saved during Run()
cultureAndContext.RestoreCultureInfos();
try
{
// Execute the actual callback
cultureAndContext.Callback(cultureAndContext.State);
}
finally
{
// Save any culture changes that might have occurred during callback execution
cultureAndContext.CaptureCultureInfos();
}
}
/// <summary>
/// Helper class to manage culture information across execution contexts.
/// </summary>
private sealed class CultureAndContext
{
public ContextCallback Callback { get; }
public object? State { get; }
private CultureInfo? _culture;
private CultureInfo? _uiCulture;
private CultureAndContext(ContextCallback callback, object? state)
{
Callback = callback;
State = state;
CaptureCultureInfos();
}
public static CultureAndContext Initialize(ContextCallback callback, object? state)
{
return new CultureAndContext(callback, state);
}
public void CaptureCultureInfos()
{
_culture = Thread.CurrentThread.CurrentCulture;
_uiCulture = Thread.CurrentThread.CurrentUICulture;
}
public void RestoreCultureInfos()
{
if (_culture != null)
Thread.CurrentThread.CurrentCulture = _culture;
if (_uiCulture != null)
Thread.CurrentThread.CurrentUICulture = _uiCulture;
}
}
}
#endif

48
src/Avalonia.Base/Threading/DispatcherOperation.cs

@ -5,6 +5,12 @@ using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
#if NET6_0_OR_GREATER
using ExecutionContext = System.Threading.ExecutionContext;
#else
using ExecutionContext = Avalonia.Threading.CulturePreservingExecutionContext;
#endif
namespace Avalonia.Threading;
[DebuggerDisplay("{DebugDisplay}")]
@ -28,18 +34,19 @@ public class DispatcherOperation
protected internal object? Callback;
protected object? TaskSource;
internal DispatcherOperation? SequentialPrev { get; set; }
internal DispatcherOperation? SequentialNext { get; set; }
internal DispatcherOperation? PriorityPrev { get; set; }
internal DispatcherOperation? PriorityNext { get; set; }
internal PriorityChain? Chain { get; set; }
internal bool IsQueued => Chain != null;
private EventHandler? _aborted;
private EventHandler? _completed;
private DispatcherPriority _priority;
private readonly ExecutionContext? _executionContext;
internal DispatcherOperation(Dispatcher dispatcher, DispatcherPriority priority, Action callback, bool throwOnUiThread) :
this(dispatcher, priority, throwOnUiThread)
@ -52,6 +59,7 @@ public class DispatcherOperation
ThrowOnUiThread = throwOnUiThread;
Priority = priority;
Dispatcher = dispatcher;
_executionContext = ExecutionContext.Capture();
}
internal string DebugDisplay
@ -103,7 +111,7 @@ public class DispatcherOperation
_completed += value;
}
}
remove
{
lock(Dispatcher.InstanceLock)
@ -112,7 +120,7 @@ public class DispatcherOperation
}
}
}
public bool Abort()
{
if (Dispatcher.Abort(this))
@ -155,7 +163,7 @@ public class DispatcherOperation
// we throw an exception instead.
throw new InvalidOperationException("A thread cannot wait on operations already running on the same thread.");
}
var cts = new CancellationTokenSource();
EventHandler finishedHandler = delegate
{
@ -241,7 +249,7 @@ public class DispatcherOperation
}
public Task GetTask() => GetTaskCore();
/// <summary>
/// Returns an awaiter for awaiting the completion of the operation.
/// </summary>
@ -259,21 +267,35 @@ public class DispatcherOperation
AbortTask();
_aborted?.Invoke(this, EventArgs.Empty);
}
internal void Execute()
{
Debug.Assert(Status == DispatcherOperationStatus.Executing);
try
{
using (AvaloniaSynchronizationContext.Ensure(Dispatcher, Priority))
InvokeCore();
{
if (_executionContext is { } executionContext)
{
#if NET6_0_OR_GREATER
ExecutionContext.Restore(executionContext);
InvokeCore();
#else
ExecutionContext.Run(executionContext, static s => ((DispatcherOperation)s!).InvokeCore(), this);
#endif
}
else
{
InvokeCore();
}
}
}
finally
{
_completed?.Invoke(this, EventArgs.Empty);
}
}
protected virtual void InvokeCore()
{
try
@ -305,7 +327,7 @@ public class DispatcherOperation
}
internal virtual object? GetResult() => null;
protected virtual void AbortTask()
{
object? taskSource;
@ -401,14 +423,14 @@ internal sealed class SendOrPostCallbackDispatcherOperation : DispatcherOperatio
{
private readonly object? _arg;
internal SendOrPostCallbackDispatcherOperation(Dispatcher dispatcher, DispatcherPriority priority,
SendOrPostCallback callback, object? arg, bool throwOnUiThread)
internal SendOrPostCallbackDispatcherOperation(Dispatcher dispatcher, DispatcherPriority priority,
SendOrPostCallback callback, object? arg, bool throwOnUiThread)
: base(dispatcher, priority, throwOnUiThread)
{
Callback = callback;
_arg = arg;
}
protected override void InvokeCore()
{
try

251
tests/Avalonia.Base.UnitTests/DispatcherTests.cs

@ -2,6 +2,7 @@ using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
@ -28,7 +29,7 @@ public partial class DispatcherTests
public event Action Timer;
public long? NextTimer { get; private set; }
public bool AskedForSignal { get; private set; }
public void UpdateTimer(long? dueTimeInTicks)
{
NextTimer = dueTimeInTicks;
@ -79,16 +80,16 @@ public partial class DispatcherTests
ReadyForBackgroundProcessing?.Invoke();
}
}
class SimpleControlledDispatcherImpl : SimpleDispatcherWithBackgroundProcessingImpl, IControlledDispatcherImpl
{
private readonly bool _useTestTimeout = true;
private readonly CancellationToken? _cancel;
public int RunLoopCount { get; private set; }
public SimpleControlledDispatcherImpl()
{
}
public SimpleControlledDispatcherImpl(CancellationToken cancel, bool useTestTimeout = false)
@ -96,7 +97,7 @@ public partial class DispatcherTests
_useTestTimeout = useTestTimeout;
_cancel = cancel;
}
public void RunLoop(CancellationToken token)
{
RunLoopCount++;
@ -114,8 +115,8 @@ public partial class DispatcherTests
}
[Fact]
public void DispatcherExecutesJobsAccordingToPriority()
{
@ -129,7 +130,7 @@ public partial class DispatcherTests
impl.ExecuteSignal();
Assert.Equal(new[] { "Render", "Input", "Background" }, actions);
}
[Fact]
public void DispatcherPreservesOrderWhenChangingPriority()
{
@ -139,13 +140,13 @@ public partial class DispatcherTests
var toPromote = disp.InvokeAsync(()=>actions.Add("PromotedRender"), DispatcherPriority.Background);
var toPromote2 = disp.InvokeAsync(()=>actions.Add("PromotedRender2"), DispatcherPriority.Input);
disp.Post(() => actions.Add("Render"), DispatcherPriority.Render);
toPromote.Priority = DispatcherPriority.Render;
toPromote2.Priority = DispatcherPriority.Render;
Assert.True(impl.AskedForSignal);
impl.ExecuteSignal();
Assert.Equal(new[] { "PromotedRender", "PromotedRender2", "Render" }, actions);
}
@ -178,7 +179,7 @@ public partial class DispatcherTests
var expectedCount = (c + 1) * 3;
if (c == 3)
expectedCount = 10;
Assert.Equal(Enumerable.Range(0, expectedCount), actions);
Assert.False(impl.AskedForSignal);
if (c < 3)
@ -189,8 +190,8 @@ public partial class DispatcherTests
Assert.Null(impl.NextTimer);
}
}
[Fact]
public void DispatcherStopsItemProcessingWhenInputIsPending()
{
@ -225,7 +226,7 @@ public partial class DispatcherTests
3 => 10,
_ => throw new InvalidOperationException($"Unexpected value {c}")
};
Assert.Equal(Enumerable.Range(0, expectedCount), actions);
Assert.False(impl.AskedForSignal);
if (c < 3)
@ -255,7 +256,7 @@ public partial class DispatcherTests
foreground ? DispatcherPriority.Default : DispatcherPriority.Background).Wait();
Assert.True(finished);
if (controlled)
if (controlled)
Assert.Equal(foreground ? 0 : 1, ((SimpleControlledDispatcherImpl)impl).RunLoopCount);
}
@ -271,7 +272,7 @@ public partial class DispatcherTests
Dispatcher.ResetForUnitTests();
SynchronizationContext.SetSynchronizationContext(null);
}
public void Dispose()
{
Dispatcher.ResetForUnitTests();
@ -279,7 +280,7 @@ public partial class DispatcherTests
SynchronizationContext.SetSynchronizationContext(null);
}
}
[Fact]
public void ExitAllFramesShouldExitAllFramesAndBeAbleToContinue()
{
@ -301,10 +302,10 @@ public partial class DispatcherTests
disp.MainLoop(CancellationToken.None);
Assert.Equal(new[] { "Nested frame", "ExitAllFrames", "Nested frame exited" }, actions);
actions.Clear();
var secondLoop = new CancellationTokenSource();
disp.Post(() =>
{
@ -315,8 +316,8 @@ public partial class DispatcherTests
Assert.Equal(new[] { "Callback after exit" }, actions);
}
}
[Fact]
public void ShutdownShouldExitAllFramesAndNotAllowNewFrames()
{
@ -335,7 +336,7 @@ public partial class DispatcherTests
actions.Add("Shutdown");
disp.BeginInvokeShutdown(DispatcherPriority.Normal);
});
disp.Post(() =>
{
actions.Add("Nested frame after shutdown");
@ -343,12 +344,12 @@ public partial class DispatcherTests
Dispatcher.UIThread.MainLoop(CancellationToken.None);
actions.Add("Nested frame after shutdown exited");
});
var criticalFrameAfterShutdown = new DispatcherFrame(false);
disp.Post(() =>
{
actions.Add("Critical frame after shutdown");
Dispatcher.UIThread.PushFrame(criticalFrameAfterShutdown);
actions.Add("Critical frame after shutdown exited");
});
@ -362,7 +363,7 @@ public partial class DispatcherTests
Assert.Equal(new[]
{
"Nested frame",
"Nested frame",
"Shutdown",
// Normal nested frames are supposed to exit immediately
"Nested frame after shutdown", "Nested frame after shutdown exited",
@ -372,7 +373,7 @@ public partial class DispatcherTests
"Nested frame exited"
}, actions);
actions.Clear();
disp.Post(()=>actions.Add("Frame after shutdown finished"));
Assert.Throws<InvalidOperationException>(() => disp.MainLoop(CancellationToken.None));
Assert.Empty(actions);
@ -388,7 +389,7 @@ public partial class DispatcherTests
return base.Wait(waitHandles, waitAll, millisecondsTimeout);
}
}
[Fact]
public void DisableProcessingShouldStopProcessing()
{
@ -407,7 +408,7 @@ public partial class DispatcherTests
SynchronizationContext.SetSynchronizationContext(avaloniaContext);
var waitHandle = new ManualResetEvent(true);
helper.WaitCount = 0;
waitHandle.WaitOne(100);
Assert.Equal(0, helper.WaitCount);
@ -431,8 +432,8 @@ public partial class DispatcherTests
void DumpCurrentPriority() =>
priorities.Add(((AvaloniaSynchronizationContext)SynchronizationContext.Current!).Priority);
disp.Post(DumpCurrentPriority, DispatcherPriority.Normal);
disp.Post(DumpCurrentPriority, DispatcherPriority.Loaded);
disp.Post(DumpCurrentPriority, DispatcherPriority.Input);
@ -467,34 +468,34 @@ public partial class DispatcherTests
public void DispatcherInvokeAsyncUnwrapsTasks()
{
int asyncMethodStage = 0;
async Task AsyncMethod()
{
asyncMethodStage = 1;
await Task.Delay(200);
asyncMethodStage = 2;
}
async Task<int> AsyncMethodWithResult()
{
await Task.Delay(100);
return 1;
}
async Task Test()
{
await Dispatcher.UIThread.InvokeAsync(AsyncMethod);
Assert.Equal(2, asyncMethodStage);
Assert.Equal(1, await Dispatcher.UIThread.InvokeAsync(AsyncMethodWithResult));
asyncMethodStage = 0;
await Dispatcher.UIThread.InvokeAsync(AsyncMethod, DispatcherPriority.Default);
Assert.Equal(2, asyncMethodStage);
Assert.Equal(1, await Dispatcher.UIThread.InvokeAsync(AsyncMethodWithResult, DispatcherPriority.Default));
Dispatcher.UIThread.ExitAllFrames();
}
using (new DispatcherServices(new ManagedDispatcherImpl(null)))
{
var t = Test();
@ -505,8 +506,8 @@ public partial class DispatcherTests
t.GetAwaiter().GetResult();
}
}
[Fact]
public async Task DispatcherResumeContinuesOnUIThread()
{
@ -605,4 +606,176 @@ public partial class DispatcherTests
Dispatcher.UIThread.MainLoop(tokenSource.Token);
}
#nullable enable
private class AsyncLocalTestClass
{
public AsyncLocal<string?> AsyncLocalField { get; set; } = new AsyncLocal<string?>();
}
[Fact]
public async Task ExecutionContextIsPreservedInDispatcherInvokeAsync()
{
using var services = new DispatcherServices(new SimpleControlledDispatcherImpl());
var tokenSource = new CancellationTokenSource();
string? test1 = null;
string? test2 = null;
string? test3 = null;
// All test code must run inside Task.Run to avoid interfering with the test:
// 1. Prevent the execution context from being captured by MainLoop.
// 2. Prevent the execution context from remaining effective when set on the same thread.
var task = Task.Run(() =>
{
var testObject = new AsyncLocalTestClass();
// Test 1: Verify Task.Run preserves the execution context.
// First, test Task.Run to ensure that the preceding validation always passes, serving as a baseline for the subsequent Invoke/InvokeAsync tests.
// This way, if a later test fails, we have the .NET framework's baseline behavior for reference.
testObject.AsyncLocalField.Value = "Initial Value";
var task1 = Task.Run(() =>
{
test1 = testObject.AsyncLocalField.Value;
});
// Test 2: Verify Invoke preserves the execution context.
testObject.AsyncLocalField.Value = "Initial Value";
Dispatcher.UIThread.Invoke(() =>
{
test2 = testObject.AsyncLocalField.Value;
});
// Test 3: Verify InvokeAsync preserves the execution context.
testObject.AsyncLocalField.Value = "Initial Value";
_ = Dispatcher.UIThread.InvokeAsync(() =>
{
test3 = testObject.AsyncLocalField.Value;
});
_ = Dispatcher.UIThread.InvokeAsync(async () =>
{
await Task.WhenAll(task1);
tokenSource.Cancel();
});
});
Dispatcher.UIThread.MainLoop(tokenSource.Token);
await Task.WhenAll(task);
// Assertions
// Invoke(): Always passes because the context is not changed.
Assert.Equal("Initial Value", test1);
// Task.Run: Always passes (guaranteed by the .NET runtime).
Assert.Equal("Initial Value", test2);
// InvokeAsync: See https://github.com/AvaloniaUI/Avalonia/pull/19163
Assert.Equal("Initial Value", test3);
}
[Fact]
public async Task ExecutionContextIsNotPreservedAmongDispatcherInvokeAsync()
{
using var services = new DispatcherServices(new SimpleControlledDispatcherImpl());
var tokenSource = new CancellationTokenSource();
string? test = null;
// All test code must run inside Task.Run to avoid interfering with the test:
// 1. Prevent the execution context from being captured by MainLoop.
// 2. Prevent the execution context from remaining effective when set on the same thread.
var task = Task.Run(() =>
{
var testObject = new AsyncLocalTestClass();
// Test: Verify that InvokeAsync calls do not share execution context between each other.
_ = Dispatcher.UIThread.InvokeAsync(() =>
{
testObject.AsyncLocalField.Value = "Initial Value";
});
_ = Dispatcher.UIThread.InvokeAsync(() =>
{
test = testObject.AsyncLocalField.Value;
});
_ = Dispatcher.UIThread.InvokeAsync(() =>
{
tokenSource.Cancel();
});
});
Dispatcher.UIThread.MainLoop(tokenSource.Token);
await Task.WhenAll(task);
// Assertions
// The value should NOT flow between different InvokeAsync execution contexts.
Assert.Null(test);
}
[Fact]
public async Task ExecutionContextCultureInfoIsPreservedInDispatcherInvokeAsync()
{
using var services = new DispatcherServices(new SimpleControlledDispatcherImpl());
var tokenSource = new CancellationTokenSource();
string? test1 = null;
string? test2 = null;
string? test3 = null;
var oldCulture = Thread.CurrentThread.CurrentCulture;
// All test code must run inside Task.Run to avoid interfering with the test:
// 1. Prevent the execution context from being captured by MainLoop.
// 2. Prevent the execution context from remaining effective when set on the same thread.
var task = Task.Run(() =>
{
// This culture tag is Sumerian and is extremely unlikely to be set as the default on any device,
// ensuring that this test will not be affected by the user's environment.
Thread.CurrentThread.CurrentCulture = CultureInfo.GetCultureInfo("sux-Shaw-UM");
// Test 1: Verify Task.Run preserves the culture in the execution context.
// First, test Task.Run to ensure that the preceding validation always passes, serving as a baseline for the subsequent Invoke/InvokeAsync tests.
// This way, if a later test fails, we have the .NET framework's baseline behavior for reference.
var task1 = Task.Run(() =>
{
test1 = Thread.CurrentThread.CurrentCulture.Name;
});
// Test 2: Verify Invoke preserves the execution context.
Dispatcher.UIThread.Invoke(() =>
{
test2 = Thread.CurrentThread.CurrentCulture.Name;
});
// Test 3: Verify InvokeAsync preserves the culture in the execution context.
_ = Dispatcher.UIThread.InvokeAsync(() =>
{
test3 = Thread.CurrentThread.CurrentCulture.Name;
});
_ = Dispatcher.UIThread.InvokeAsync(async () =>
{
await Task.WhenAll(task1);
tokenSource.Cancel();
});
});
try
{
Dispatcher.UIThread.MainLoop(tokenSource.Token);
await Task.WhenAll(task);
// Assertions
// Invoke(): Always passes because the context is not changed.
Assert.Equal("sux-Shaw-UM", test1);
// Task.Run: Always passes (guaranteed by the .NET runtime).
Assert.Equal("sux-Shaw-UM", test2);
// InvokeAsync: See https://github.com/AvaloniaUI/Avalonia/pull/19163
Assert.Equal("sux-Shaw-UM", test3);
}
finally
{
Thread.CurrentThread.CurrentCulture = oldCulture;
// Ensure that this test does not have a negative impact on other tests.
Assert.NotEqual("sux-Shaw-UM", oldCulture.Name);
}
}
#nullable restore
}

Loading…
Cancel
Save