diff --git a/src/Orleans.Core.Abstractions/Core/IGrainContext.cs b/src/Orleans.Core.Abstractions/Core/IGrainContext.cs
index 3e78f7435f2..441ef98e7ee 100644
--- a/src/Orleans.Core.Abstractions/Core/IGrainContext.cs
+++ b/src/Orleans.Core.Abstractions/Core/IGrainContext.cs
@@ -204,7 +204,7 @@ public interface IWorkItemScheduler
///
/// The work item.
/// The state passed when invoking the item.
- void QueueAction(Action action, object state);
+ void QueueAction(Action action, object? state);
}
///
diff --git a/src/Orleans.Runtime/Activation/ActivationDataActivatorProvider.cs b/src/Orleans.Runtime/Activation/ActivationDataActivatorProvider.cs
index a25e1df70a1..b7c2dc794e4 100644
--- a/src/Orleans.Runtime/Activation/ActivationDataActivatorProvider.cs
+++ b/src/Orleans.Runtime/Activation/ActivationDataActivatorProvider.cs
@@ -1,4 +1,5 @@
using System.Diagnostics.CodeAnalysis;
+using System.Threading;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Orleans.Configuration;
@@ -61,7 +62,7 @@ private partial class ActivationDataActivator : IGrainContextActivator
private readonly IServiceProvider _serviceProvider;
private readonly GrainTypeSharedContext _sharedComponents;
private readonly Func _createWorkItemGroup;
- private readonly Action _startActivation;
+ private readonly SendOrPostCallback _startActivation;
public ActivationDataActivator(
IGrainActivator grainActivator,
@@ -82,7 +83,7 @@ public ActivationDataActivator(
_workItemGroupLogger,
_activationTaskSchedulerLogger,
_schedulingOptions);
- _startActivation = state => ((ActivationData)state!).Start(_grainActivator);
+ _startActivation = state => ((ActivationData)state!).Start(_grainActivator);
}
public IGrainContext CreateContext(GrainAddress activationAddress)
@@ -94,12 +95,7 @@ public IGrainContext CreateContext(GrainAddress activationAddress)
_sharedComponents);
using var ecSuppressor = ExecutionContext.SuppressFlow();
- _ = Task.Factory.StartNew(
- _startActivation,
- context,
- CancellationToken.None,
- TaskCreationOptions.DenyChildAttach,
- context.ActivationTaskScheduler);
+ context.WorkItemGroup.Post(_startActivation, context);
return context;
}
}
diff --git a/src/Orleans.Runtime/Catalog/ActivationData.cs b/src/Orleans.Runtime/Catalog/ActivationData.cs
index 02e7fa5a40b..f29223b0e00 100644
--- a/src/Orleans.Runtime/Catalog/ActivationData.cs
+++ b/src/Orleans.Runtime/Catalog/ActivationData.cs
@@ -45,7 +45,7 @@ internal sealed partial class ActivationData :
private readonly WorkItemGroup _workItemGroup;
private readonly List<(Message Message, CoarseStopwatch QueuedTime)> _waitingRequests = new();
private readonly Dictionary _runningRequests = new();
- private readonly SingleWaiterAutoResetEvent _workSignal = new() { RunContinuationsAsynchronously = true };
+ private readonly WorkItemGroupWaiter _workSignal;
private GrainLifecycle? _lifecycle;
private Queue? _pendingOperations;
private Message? _blockingRequest;
@@ -99,6 +99,7 @@ public ActivationData(
Debug.Assert(_serviceScope != null, "_serviceScope must not be null.");
_workItemGroup = createWorkItemGroup(this);
Debug.Assert(_workItemGroup != null, "_workItemGroup must not be null.");
+ _workSignal = new WorkItemGroupWaiter(_workItemGroup);
}
internal void SetActivationActivity(Activity activity)
@@ -140,6 +141,8 @@ public void Start(IGrainActivator grainActivator)
}
}
+ public WorkItemGroup WorkItemGroup => _workItemGroup;
+
public ActivationTaskScheduler ActivationTaskScheduler => _workItemGroup.TaskScheduler;
public IGrainRuntime GrainRuntime => _shared.Runtime;
public object? GrainInstance { get; private set; }
diff --git a/src/Orleans.Runtime/Scheduler/IWorkItem.cs b/src/Orleans.Runtime/Scheduler/IWorkItem.cs
index 4a3f7117c81..fae2b7a0f38 100644
--- a/src/Orleans.Runtime/Scheduler/IWorkItem.cs
+++ b/src/Orleans.Runtime/Scheduler/IWorkItem.cs
@@ -8,6 +8,6 @@ internal interface IWorkItem
IGrainContext GrainContext { get; }
void Execute();
- internal static readonly Action ExecuteWorkItem = state => ((IWorkItem)state).Execute();
+ internal static readonly Action ExecuteWorkItem = state => ((IWorkItem)state!).Execute();
}
}
diff --git a/src/Orleans.Runtime/Scheduler/TaskSchedulerUtils.cs b/src/Orleans.Runtime/Scheduler/TaskSchedulerUtils.cs
index eccab3a23f6..a6a6a58f831 100644
--- a/src/Orleans.Runtime/Scheduler/TaskSchedulerUtils.cs
+++ b/src/Orleans.Runtime/Scheduler/TaskSchedulerUtils.cs
@@ -1,9 +1,9 @@
+#nullable enable
using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
using Orleans.Runtime.Internal;
-#nullable disable
namespace Orleans.Runtime.Scheduler
{
internal static class TaskSchedulerUtils
@@ -18,7 +18,7 @@ public static void QueueAction(this ActivationTaskScheduler taskScheduler, Actio
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void QueueAction(this ActivationTaskScheduler taskScheduler, Action action, object state)
+ public static void QueueAction(this ActivationTaskScheduler taskScheduler, Action action, object? state)
{
using var suppressExecutionContext = new ExecutionContextSuppressor();
@@ -29,7 +29,7 @@ public static void QueueAction(this ActivationTaskScheduler taskScheduler, Actio
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void QueueWorkItem(this WorkItemGroup scheduler, IWorkItem workItem)
{
- QueueAction(scheduler.TaskScheduler, IWorkItem.ExecuteWorkItem, workItem);
+ scheduler.QueueAction(IWorkItem.ExecuteWorkItem, workItem);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
diff --git a/src/Orleans.Runtime/Scheduler/WorkItemGroup.cs b/src/Orleans.Runtime/Scheduler/WorkItemGroup.cs
index d8e9bda4da1..5de342d1a85 100644
--- a/src/Orleans.Runtime/Scheduler/WorkItemGroup.cs
+++ b/src/Orleans.Runtime/Scheduler/WorkItemGroup.cs
@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
+using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
@@ -12,8 +13,46 @@
namespace Orleans.Runtime.Scheduler;
+internal readonly struct WorkItem
+{
+ public enum WorkItemType : byte
+ {
+ Task = 0,
+ SendOrPostCallback = 1,
+ ActionOfObject = 2
+ }
+
+ public readonly object Callback;
+ public readonly object? State;
+ public readonly WorkItemType Type;
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public WorkItem(Task task)
+ {
+ Callback = task;
+ State = null;
+ Type = WorkItemType.Task;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public WorkItem(SendOrPostCallback callback, object? state)
+ {
+ Callback = callback;
+ State = state;
+ Type = WorkItemType.SendOrPostCallback;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public WorkItem(Action callback, object? state)
+ {
+ Callback = callback;
+ State = state;
+ Type = WorkItemType.ActionOfObject;
+ }
+}
+
[DebuggerDisplay("WorkItemGroup Context={GrainContext} State={_state}")]
-internal sealed partial class WorkItemGroup : IThreadPoolWorkItem, IWorkItemScheduler
+internal sealed partial class WorkItemGroup : SynchronizationContext, IThreadPoolWorkItem, IWorkItemScheduler
{
private enum WorkGroupStatus : byte
{
@@ -28,7 +67,7 @@ private enum WorkGroupStatus : byte
#else
private readonly object _lockObj = new();
#endif
- private readonly Queue _workItems = new();
+ private readonly Queue _workItems = new();
private readonly SchedulingOptions _schedulingOptions;
private long _totalItemsEnqueued;
@@ -39,6 +78,9 @@ private enum WorkGroupStatus : byte
private Task? _currentTask;
private long _currentTaskStarted;
+ // Dummy task used to make TaskScheduler.Current return our scheduler
+ private readonly Task _schedulerTask;
+
internal ActivationTaskScheduler TaskScheduler { get; }
public IGrainContext GrainContext { get; set; }
@@ -60,6 +102,11 @@ public WorkItemGroup(
_state = WorkGroupStatus.Waiting;
_log = logger;
TaskScheduler = new ActivationTaskScheduler(this, activationTaskSchedulerLogger);
+
+ // Create a dummy task associated with our scheduler (never actually runs)
+ // We set m_taskScheduler directly so TaskScheduler.Current returns our scheduler
+ _schedulerTask = new Task(() => { }, TaskCreationOptions.None);
+ GetTaskSchedulerRef(_schedulerTask) = TaskScheduler;
}
///
@@ -76,12 +123,17 @@ public void EnqueueTask(Task task)
}
#endif
+ EnqueueWorkItem(new WorkItem(task));
+ }
+
+ private void EnqueueWorkItem(WorkItem workItem)
+ {
lock (_lockObj)
{
long thisSequenceNumber = _totalItemsEnqueued++;
int count = _workItems.Count;
- _workItems.Enqueue(task);
+ _workItems.Enqueue(workItem);
int maxPendingItemsLimit = _schedulingOptions.MaxPendingWorkItemsSoftLimit;
if (maxPendingItemsLimit > 0 && count > maxPendingItemsLimit)
{
@@ -103,7 +155,10 @@ public void EnqueueTask(Task task)
#if DEBUG
if (_log.IsEnabled(LogLevel.Trace))
{
- LogTraceAddToRunQueue(_log, task, thisSequenceNumber, GrainContext);
+ _log.LogTrace(
+ "Add to RunQueue #{SequenceNumber}, onto {GrainContext}",
+ thisSequenceNumber,
+ GrainContext);
}
#endif
ScheduleExecution(this);
@@ -121,9 +176,12 @@ private void LogTooManyTasksInQueue(int count, int maxPendingItemsLimit)
///
internal IEnumerable GetScheduledTasks()
{
- foreach (var task in _workItems)
+ lock (_lockObj)
{
- yield return task;
+ var tasks = _workItems
+ .Where(item => item.Type == WorkItem.WorkItemType.Task)
+ .Select(item => Unsafe.As(item.Callback));
+ return [.. tasks];
}
}
@@ -133,6 +191,11 @@ internal IEnumerable GetScheduledTasks()
public void Execute()
{
RuntimeContext.SetExecutionContext(GrainContext, out var originalContext);
+
+ // Set t_currentTask so TaskScheduler.Current returns our ActivationTaskScheduler
+ var previousTask = GetCurrentTask();
+ SetCurrentTask(_schedulerTask);
+
var turnWarningDurationMs = (long)Math.Ceiling(_schedulingOptions.TurnWarningLengthThreshold.TotalMilliseconds);
var activationSchedulingQuantumMs = (long)_schedulingOptions.ActivationSchedulingQuantum.TotalMilliseconds;
try
@@ -143,7 +206,7 @@ public void Execute()
loopStart = taskStart = taskEnd = Environment.TickCount64;
do
{
- Task task;
+ WorkItem workItem;
lock (_lockObj)
{
_state = WorkGroupStatus.Running;
@@ -151,7 +214,7 @@ public void Execute()
// Get the first Work Item on the list
if (_workItems.Count > 0)
{
- _currentTask = task = _workItems.Dequeue();
+ workItem = _workItems.Dequeue();
_currentTaskStarted = taskStart;
}
else
@@ -161,12 +224,27 @@ public void Execute()
}
}
-#if DEBUG
- LogTaskStart(task);
-#endif
try
{
- TaskScheduler.RunTaskFromWorkItemGroup(task);
+ switch (workItem.Type)
+ {
+ case WorkItem.WorkItemType.Task:
+ {
+ var task = Unsafe.As(workItem.Callback);
+ _currentTask = task;
+#if DEBUG
+ LogTaskStart(task);
+#endif
+ TaskScheduler.RunTaskFromWorkItemGroup(task);
+ }
+ break;
+ case WorkItem.WorkItemType.SendOrPostCallback:
+ Unsafe.As(workItem.Callback)(workItem.State);
+ break;
+ case WorkItem.WorkItemType.ActionOfObject:
+ Unsafe.As>(workItem.Callback)(workItem.State);
+ break;
+ }
}
finally
{
@@ -177,7 +255,10 @@ public void Execute()
if (taskDurationMs > turnWarningDurationMs)
{
SchedulerInstruments.LongRunningTurnsCounter.Add(1);
- LogLongRunningTurn(task, taskDurationMs);
+ if (workItem.Type == WorkItem.WorkItemType.Task)
+ {
+ LogLongRunningTurn(Unsafe.As(workItem.Callback), taskDurationMs);
+ }
}
_currentTask = null;
@@ -206,7 +287,7 @@ public void Execute()
_state = WorkGroupStatus.Waiting;
}
}
-
+ SetCurrentTask(previousTask);
RuntimeContext.ResetExecutionContext(originalContext);
}
}
@@ -281,8 +362,8 @@ public string DumpStatus()
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void ScheduleExecution(WorkItemGroup workItem) => ThreadPool.UnsafeQueueUserWorkItem(workItem, preferLocal: true);
- public void QueueAction(Action action) => TaskScheduler.QueueAction(action);
- public void QueueAction(Action action, object state) => TaskScheduler.QueueAction(action, state);
+ public void QueueAction(Action action) => EnqueueWorkItem(new WorkItem((Action)(static state => ((Action)state!)()), action));
+ public void QueueAction(Action action, object? state) => EnqueueWorkItem(new WorkItem(action, state));
public void QueueTask(Task task) => task.Start(TaskScheduler);
[LoggerMessage(
@@ -323,4 +404,46 @@ public string DumpStatus()
Message = "Task {Task} in WorkGroup {GrainContext} took elapsed time {Duration} for execution, which is longer than {TurnWarningLengthThreshold}. Running on thread {Thread}"
)]
private static partial void LogWarningLongRunningTurn(ILogger logger, object task, string grainContext, string duration, TimeSpan turnWarningLengthThreshold, string thread);
+
+ #region SynchronizationContext overrides
+
+ ///
+ /// Asynchronously posts a callback to be executed on this WorkItemGroup.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public override void Post(SendOrPostCallback d, object? state) => EnqueueWorkItem(new WorkItem(d, state));
+
+ ///
+ /// Synchronously sends a callback. Not supported - throws.
+ ///
+ public override void Send(SendOrPostCallback d, object? state) => throw new NotSupportedException();
+
+ ///
+ /// Creates a copy (returns same instance for single-threaded behavior).
+ ///
+ public override SynchronizationContext CreateCopy() => this;
+
+ #endregion
+
+ #region UnsafeAccessor methods for Task internals
+
+ ///
+ /// Gets a reference to the thread-static Task.t_currentTask field.
+ ///
+ [UnsafeAccessor(UnsafeAccessorKind.StaticField, Name = "t_currentTask")]
+ private static extern ref Task? GetCurrentTaskRef(Task? _);
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static Task? GetCurrentTask() => GetCurrentTaskRef(null);
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static void SetCurrentTask(Task? task) => GetCurrentTaskRef(null) = task;
+
+ ///
+ /// Sets the internal m_taskScheduler field on a Task.
+ ///
+ [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "m_taskScheduler")]
+ private static extern ref TaskScheduler? GetTaskSchedulerRef(Task task);
+
+ #endregion
}
diff --git a/src/Orleans.Runtime/Scheduler/WorkItemGroupWaiter.cs b/src/Orleans.Runtime/Scheduler/WorkItemGroupWaiter.cs
new file mode 100644
index 00000000000..d48d1b8aa81
--- /dev/null
+++ b/src/Orleans.Runtime/Scheduler/WorkItemGroupWaiter.cs
@@ -0,0 +1,181 @@
+#nullable enable
+
+using System;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+using System.Threading.Tasks.Sources;
+
+namespace Orleans.Runtime.Scheduler;
+
+///
+/// Represents a synchronization event that, when signaled, resets automatically after releasing a single waiter.
+/// This type supports concurrent signalers but only a single waiter.
+/// Continuations are always scheduled on the provided .
+///
+internal sealed class WorkItemGroupWaiter(WorkItemGroup workItemGroup) : IValueTaskSource
+{
+ // Signaled indicates that the event has been signaled and not yet reset.
+ private const uint SignaledFlag = 1;
+
+ // Waiting indicates that a waiter is present and waiting for the event to be signaled.
+ private const uint WaitingFlag = 1 << 1;
+
+ // ResetMask is used to clear both status flags.
+ private const uint ResetMask = ~SignaledFlag & ~WaitingFlag;
+
+ private static readonly Action Sentinel = static _ => Debug.Fail("The sentinel delegate should never be invoked.");
+
+ private readonly WorkItemGroup _workItemGroup = workItemGroup;
+
+ private Action? _continuation;
+ private object? _continuationState;
+ private volatile uint _status;
+
+ ValueTaskSourceStatus IValueTaskSource.GetStatus(short token)
+ {
+ // We only support success completion (no exception/cancellation paths)
+ return Volatile.Read(ref _continuation) is null ? ValueTaskSourceStatus.Pending : ValueTaskSourceStatus.Succeeded;
+ }
+
+ void IValueTaskSource.OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags)
+ {
+ if (continuation is null)
+ {
+ ThrowArgumentNullException();
+ }
+
+ // We ignore flags (FlowExecutionContext, UseSchedulingContext) because we always schedule on WorkItemGroup
+
+ // We need to set the continuation state before we swap in the delegate, so that
+ // if there's a race between this and Signal() and Signal() sees the _continuation
+ // as non-null, it'll be able to invoke it with the state stored here.
+ object? storedContinuation = _continuation;
+ if (storedContinuation is null)
+ {
+ _continuationState = state;
+ storedContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null);
+ if (storedContinuation is null)
+ {
+ // Operation hadn't already completed, so we're done. The continuation will be
+ // invoked when Signal is called at some later point.
+ return;
+ }
+ }
+
+ // Operation already completed, so we need to queue the supplied callback.
+ // At this point the storedContinuation should be the sentinel; if it's not, the instance was misused.
+ Debug.Assert(storedContinuation is not null);
+ Debug.Assert(ReferenceEquals(storedContinuation, Sentinel));
+
+ // Schedule the continuation on the WorkItemGroup
+ _workItemGroup.QueueAction(continuation, state);
+
+ [DoesNotReturn]
+ static void ThrowArgumentNullException() => throw new ArgumentNullException(nameof(continuation));
+ }
+
+ void IValueTaskSource.GetResult(short token)
+ {
+ // Reset the wait source.
+ Reset();
+
+ // Reset the status.
+ ResetStatus();
+ }
+
+ ///
+ /// Signal the waiter.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public void Signal()
+ {
+ if ((_status & SignaledFlag) == SignaledFlag)
+ {
+ // The event is already signaled.
+ return;
+ }
+
+ // Set the signaled flag.
+ var status = Interlocked.Or(ref _status, SignaledFlag);
+
+ // If there was a waiter and the signaled flag was unset, wake the waiter now.
+ if ((status & SignaledFlag) != SignaledFlag && (status & WaitingFlag) == WaitingFlag)
+ {
+ // Note that in this assert we are checking the volatile _status field.
+ // This is a sanity check to ensure that the signaling conditions are true:
+ // that "Signaled" and "Waiting" flags are both set.
+ Debug.Assert((_status & (SignaledFlag | WaitingFlag)) == (SignaledFlag | WaitingFlag));
+ SignalCompletion();
+ }
+ }
+
+ ///
+ /// Wait for the event to be signaled.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public ValueTask WaitAsync()
+ {
+ // Indicate that there is a waiter.
+ var status = Interlocked.Or(ref _status, WaitingFlag);
+
+ // If there was already a waiter, that is an error since this class is designed for use with a single waiter.
+ if ((status & WaitingFlag) == WaitingFlag)
+ {
+ ThrowConcurrentWaitersNotSupported();
+ }
+
+ // If the event was already signaled, immediately wake the waiter.
+ if ((status & SignaledFlag) == SignaledFlag)
+ {
+ // Reset just the status because the _continuation has not been set.
+ // We know that _continuation has not been set because it is only set when
+ // Signal() observes that the "Waiting" flag had been set but not the "Signaled" flag.
+ ResetStatus();
+ return default;
+ }
+
+ return new(this, 0);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private void Reset()
+ {
+ _continuation = null;
+ _continuationState = null;
+ }
+
+ private void SignalCompletion()
+ {
+ Action? continuation =
+ Volatile.Read(ref _continuation) ??
+ Interlocked.CompareExchange(ref _continuation, Sentinel, null);
+
+ if (continuation is not null)
+ {
+ Debug.Assert(continuation is not null);
+
+ // Always schedule on the WorkItemGroup
+ _workItemGroup.QueueAction(continuation, _continuationState);
+ }
+ }
+
+ ///
+ /// Called when a waiter handles the event signal.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private void ResetStatus()
+ {
+ // The event is being handled, so clear the "Signaled" flag now.
+ // The waiter is no longer waiting, so clear the "Waiting" flag, too.
+ var status = Interlocked.And(ref _status, ResetMask);
+
+ // If both the "Waiting" and "Signaled" flags were not already set, something has gone catastrophically wrong.
+ Debug.Assert((status & (WaitingFlag | SignaledFlag)) == (WaitingFlag | SignaledFlag));
+ }
+
+ [DoesNotReturn]
+ private static void ThrowConcurrentWaitersNotSupported() => throw new InvalidOperationException("Concurrent waiters are not supported");
+}