From 700a4e907e3444132314ac1f6120bb0a97aba0dc Mon Sep 17 00:00:00 2001 From: Reuben Bond Date: Wed, 29 Apr 2026 17:49:05 -0700 Subject: [PATCH 1/2] Add message ownership tracking with ref-counted pooling Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../Messaging/ClientMessageCenter.cs | 2 + src/Orleans.Core/Messaging/Message.cs | 105 +++++++++- src/Orleans.Core/Messaging/MessageFactory.cs | 46 ++--- src/Orleans.Core/Messaging/MessagePool.cs | 120 ++++++++++++ .../Messaging/MessageSerializer.cs | 3 +- src/Orleans.Core/Networking/Connection.cs | 17 ++ src/Orleans.Core/Runtime/CallbackData.cs | 12 ++ .../Runtime/InvokableObjectManager.cs | 5 + .../Runtime/OutsideRuntimeClient.cs | 5 + src/Orleans.Runtime/Catalog/ActivationData.cs | 8 + .../Catalog/StatelessWorkerGrainContext.cs | 1 + src/Orleans.Runtime/Core/HostedClient.cs | 1 + .../Core/InsideRuntimeClient.cs | 11 ++ src/Orleans.Runtime/Messaging/Gateway.cs | 1 + .../Messaging/MessageCenter.cs | 2 + .../Networking/GatewayInboundConnection.cs | 5 + .../Networking/SiloConnection.cs | 15 ++ .../Messaging/MessagePoolTests.cs | 181 ++++++++++++++++++ 18 files changed, 512 insertions(+), 28 deletions(-) create mode 100644 src/Orleans.Core/Messaging/MessagePool.cs create mode 100644 test/Orleans.Core.Tests/Messaging/MessagePoolTests.cs diff --git a/src/Orleans.Core/Messaging/ClientMessageCenter.cs b/src/Orleans.Core/Messaging/ClientMessageCenter.cs index 589017ef06b..7f132d56c69 100644 --- a/src/Orleans.Core/Messaging/ClientMessageCenter.cs +++ b/src/Orleans.Core/Messaging/ClientMessageCenter.cs @@ -370,6 +370,7 @@ public void RejectMessage(Message msg, string reason, Exception exc = null) if (msg.Direction != Message.Directions.Request) { LogDroppingMessage(msg, reason); + msg.ReleaseDropped("DroppedNonRequest"); } else { @@ -377,6 +378,7 @@ public void RejectMessage(Message msg, string reason, Exception exc = null) MessagingInstruments.OnRejectedMessage(msg); var error = this.messageFactory.CreateRejectionResponse(msg, Message.RejectionTypes.Unrecoverable, reason, exc); DispatchLocalMessage(error); + msg.ReleaseDropped("RejectedRequest"); } } diff --git a/src/Orleans.Core/Messaging/Message.cs b/src/Orleans.Core/Messaging/Message.cs index 16e7f67bc41..2dc4e06c7b9 100644 --- a/src/Orleans.Core/Messaging/Message.cs +++ b/src/Orleans.Core/Messaging/Message.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Threading; @@ -15,6 +16,14 @@ internal sealed class Message : ISpanFormattable [NonSerialized] private short _retryCount; + [NonSerialized] + private int _refCount; + +#if DEBUG + [NonSerialized] + private string? _lastTransferTag; +#endif + public CoarseStopwatch _timeToExpiry; public object? BodyObject { get; set; } @@ -303,6 +312,73 @@ internal void AddToCacheInvalidationHeader(GrainAddress invalidAddress, GrainAdd } } + internal void InitializeRefCount() + { + // Messages are acquired once when checked out from the pool. + // Additional owners must call Acquire() and Release(). + _refCount = 1; +#if DEBUG + _lastTransferTag = null; +#endif + } + + internal void Acquire() + { + var newRefCount = Interlocked.Increment(ref _refCount); + Debug.Assert(newRefCount > 1); + } + + internal void Release() + { + var newRefCount = Interlocked.Decrement(ref _refCount); + if (newRefCount == 0) + { + MessagePool.ReturnCore(this); + } + else if (newRefCount < 0) + { + // Ref count should never go negative - indicates a double release. +#if DEBUG + Debug.Fail($"Message ref count went negative. Last transfer tag: '{_lastTransferTag}'"); +#else + Debug.Fail("Message ref count went negative."); +#endif + } + } + + [Conditional("DEBUG")] + internal void MarkTransferred(string tag) + { +#if DEBUG + _lastTransferTag = tag; +#endif + } + + /// + /// Releases this message after it has been dropped (expired, rejected, blocked, etc). + /// Marks the transfer for debugging and releases the reference. + /// + /// A short description of why the message was dropped. + internal void ReleaseDropped(string reason) + { + MarkTransferred($"Dropped:{reason}"); + Release(); + } + + /// + /// Asserts that this message has not been released (refcount > 0). + /// Only executes in DEBUG builds. + /// + [Conditional("DEBUG")] + internal void AssertNotReleased([System.Runtime.CompilerServices.CallerMemberName] string? caller = null) + { +#if DEBUG + var currentRefCount = Volatile.Read(ref _refCount); + Debug.Assert(currentRefCount > 0, + $"Message used after release. Caller: {caller}, RefCount: {currentRefCount}, LastTransfer: {_lastTransferTag}"); +#endif + } + public override string ToString() => $"{this}"; string IFormattable.ToString(string? format, IFormatProvider? formatProvider) => ToString(); @@ -371,6 +447,31 @@ static bool Append(ref Span dst, ReadOnlySpan value) internal bool IsPing() => _requestContextData?.TryGetValue(RequestContext.PING_APPLICATION_HEADER, out var value) == true && value is bool isPing && isPing; + /// + /// Resets the message to its default state for reuse. + /// + internal void Reset() + { + _retryCount = 0; + _timeToExpiry = default; + BodyObject = null; + _headers = default; + _id = default; + _refCount = 0; +#if DEBUG + _lastTransferTag = null; +#endif + + _requestContextData = null; + _targetSilo = null; + _targetGrain = default; + _sendingSilo = null; + _sendingGrain = default; + _interfaceVersion = 0; + _interfaceType = default; + _cacheInvalidationHeader = null; + } + [Flags] internal enum MessageFlags : ushort { @@ -386,10 +487,10 @@ internal enum MessageFlags : ushort HasTimeToLive = 1 << 8, // Message cannot be forwarded to another activation. - IsLocalOnly = 1 << 9, + IsLocalOnly = 1 << 9, // Message must not trigger grain activation or extend an activation's lifetime. - SuppressKeepAlive = 1 << 10, + SuppressKeepAlive = 1 << 10, // The most significant bit is reserved, possibly for use to indicate more data follows. Reserved = 1 << 15, diff --git a/src/Orleans.Core/Messaging/MessageFactory.cs b/src/Orleans.Core/Messaging/MessageFactory.cs index e41bc632815..ec4d8755849 100644 --- a/src/Orleans.Core/Messaging/MessageFactory.cs +++ b/src/Orleans.Core/Messaging/MessageFactory.cs @@ -32,16 +32,14 @@ public MessageFactory(DeepCopier deepCopier, ILogger logger, Mes public Message CreateMessage(object body, InvokeMethodOptions options) { - var message = new Message - { - Direction = (options & InvokeMethodOptions.OneWay) != 0 ? Message.Directions.OneWay : Message.Directions.Request, - Id = GetNextCorrelationId(), - IsReadOnly = (options & InvokeMethodOptions.ReadOnly) != 0, - IsUnordered = (options & InvokeMethodOptions.Unordered) != 0, - IsAlwaysInterleave = (options & InvokeMethodOptions.AlwaysInterleave) != 0, - BodyObject = body, - RequestContextData = RequestContextExtensions.Export(_deepCopier), - }; + var message = MessagePool.Get(); + message.Direction = (options & InvokeMethodOptions.OneWay) != 0 ? Message.Directions.OneWay : Message.Directions.Request; + message.Id = GetNextCorrelationId(); + message.IsReadOnly = (options & InvokeMethodOptions.ReadOnly) != 0; + message.IsUnordered = (options & InvokeMethodOptions.Unordered) != 0; + message.IsAlwaysInterleave = (options & InvokeMethodOptions.AlwaysInterleave) != 0; + message.BodyObject = body; + message.RequestContextData = RequestContextExtensions.Export(_deepCopier); _messagingTrace.OnCreateMessage(message); return message; @@ -55,21 +53,19 @@ private CorrelationId GetNextCorrelationId() public Message CreateResponseMessage(Message request) { - var response = new Message - { - IsSystemMessage = request.IsSystemMessage, - Direction = Message.Directions.Response, - Id = request.Id, - IsReadOnly = request.IsReadOnly, - IsAlwaysInterleave = request.IsAlwaysInterleave, - TargetSilo = request.SendingSilo, - TargetGrain = request.SendingGrain, - SendingSilo = request.TargetSilo, - SendingGrain = request.TargetGrain, - CacheInvalidationHeader = request.CacheInvalidationHeader, - TimeToLive = request.TimeToLive, - RequestContextData = RequestContextExtensions.Export(_deepCopier), - }; + var response = MessagePool.Get(); + response.IsSystemMessage = request.IsSystemMessage; + response.Direction = Message.Directions.Response; + response.Id = request.Id; + response.IsReadOnly = request.IsReadOnly; + response.IsAlwaysInterleave = request.IsAlwaysInterleave; + response.TargetSilo = request.SendingSilo; + response.TargetGrain = request.SendingGrain; + response.SendingSilo = request.TargetSilo; + response.SendingGrain = request.TargetGrain; + response.CacheInvalidationHeader = request.CacheInvalidationHeader; + response.TimeToLive = request.TimeToLive; + response.RequestContextData = RequestContextExtensions.Export(_deepCopier); _messagingTrace.OnCreateMessage(response); return response; diff --git a/src/Orleans.Core/Messaging/MessagePool.cs b/src/Orleans.Core/Messaging/MessagePool.cs new file mode 100644 index 00000000000..6a75141a1f5 --- /dev/null +++ b/src/Orleans.Core/Messaging/MessagePool.cs @@ -0,0 +1,120 @@ +#nullable enable +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; + +namespace Orleans.Runtime +{ + /// + /// A thread-local object pool for instances. + /// + internal static class MessagePool + { + private static readonly ThreadLocal> _messages = new(() => new()); + +#if DEBUG + /// + /// Tracks all messages that have been allocated but not returned to the pool. + /// Only available in DEBUG builds. Must be enabled via . + /// + private static readonly ConcurrentDictionary _outstandingMessages = new(); + + /// + /// When true, tracks all message allocations for leak detection. + /// Only available in DEBUG builds. + /// + public static bool EnableLeakTracking { get; set; } + + /// + /// Gets all messages that have been allocated but not returned to the pool. + /// Only available in DEBUG builds and when is true. + /// + public static IReadOnlyCollection GetOutstandingMessages() + { + return _outstandingMessages.Values.ToArray(); + } + + /// + /// Clears the outstanding messages tracking. Call this at the start of a test. + /// + public static void ClearLeakTracking() + { + _outstandingMessages.Clear(); + } + + /// + /// Information about a message allocation for leak tracking. + /// + public sealed class MessageAllocationInfo + { + public Message Message { get; } + public string AllocationStack { get; } + public DateTime AllocationTime { get; } + + public MessageAllocationInfo(Message message, string allocationStack) + { + Message = message; + AllocationStack = allocationStack; + AllocationTime = DateTime.UtcNow; + } + + public override string ToString() => + $"Message allocated at {AllocationTime:HH:mm:ss.fff}, Direction={Message.Direction}, Id={Message.Id}\nStack:\n{AllocationStack}"; + } +#endif + + /// + /// The maximum number of messages to keep per thread. + /// + public static int MaxPoolSizePerThread { get; set; } = 128; + + /// + /// Gets a message from the pool, or creates a new one if the pool is empty. + /// + public static Message Get() + { + var stack = _messages.Value!; + if (!stack.TryPop(out var message)) + { + message = new Message(); + } + + message.InitializeRefCount(); + +#if DEBUG + if (EnableLeakTracking) + { + var info = new MessageAllocationInfo(message, Environment.StackTrace); + _outstandingMessages[message] = info; + } +#endif + + return message; + } + + /// + /// Returns a message to the pool after resetting it. + /// + public static void Return(Message message) => message.Release(); + + internal static void ReturnCore(Message message) + { +#if DEBUG + if (EnableLeakTracking) + { + _outstandingMessages.TryRemove(message, out _); + } +#endif + + message.Reset(); + + var stack = _messages.Value!; + if (stack.Count < MaxPoolSizePerThread) + { + stack.Push(message); + } + } + } +} diff --git a/src/Orleans.Core/Messaging/MessageSerializer.cs b/src/Orleans.Core/Messaging/MessageSerializer.cs index 76a41537569..04f24a11df2 100644 --- a/src/Orleans.Core/Messaging/MessageSerializer.cs +++ b/src/Orleans.Core/Messaging/MessageSerializer.cs @@ -85,7 +85,8 @@ public MessageSerializer( var body = input.Slice(bodyOffset, bodyLength); // Build message - message = new(); + message = MessagePool.Get(); + if (header.IsSingleSegment) { var headersReader = Reader.Create(header.First.Span, _deserializationSession); diff --git a/src/Orleans.Core/Networking/Connection.cs b/src/Orleans.Core/Networking/Connection.cs index f2f0bc728de..e6a6e2b6603 100644 --- a/src/Orleans.Core/Networking/Connection.cs +++ b/src/Orleans.Core/Networking/Connection.cs @@ -374,6 +374,12 @@ private async Task ProcessOutgoing() { throw; } + + if (message is not null) + { + inflight.Remove(message); + message = null; + } } var flushResult = await output.FlushAsync(); @@ -382,6 +388,13 @@ private async Task ProcessOutgoing() break; } + // Release the send pipeline's reference after bytes have been flushed. + foreach (var msg in inflight) + { + msg.MarkTransferred("Connection.ProcessOutgoing:Sent"); + msg.Release(); + } + inflight.Clear(); } } @@ -490,6 +503,8 @@ private bool HandleSendMessageFailure(Message message, Exception exception) response.BodyObject = Response.FromException(exception); this.MessageCenter.DispatchLocalMessage(response); + message.MarkTransferred("Connection.HandleSendMessageFailure:RequestFailed"); + message.Release(); } else if (message.Direction == Message.Directions.Response && message.RetryCount < MessagingOptions.DEFAULT_MAX_MESSAGE_SEND_RETRIES) { @@ -509,6 +524,8 @@ private bool HandleSendMessageFailure(Message message, Exception exception) message); MessagingInstruments.OnDroppedSentMessage(message); + message.MarkTransferred("Connection.HandleSendMessageFailure:Dropped"); + message.Release(); } return true; diff --git a/src/Orleans.Core/Runtime/CallbackData.cs b/src/Orleans.Core/Runtime/CallbackData.cs index 98536d69ce7..338ed0011a9 100644 --- a/src/Orleans.Core/Runtime/CallbackData.cs +++ b/src/Orleans.Core/Runtime/CallbackData.cs @@ -24,6 +24,8 @@ public CallbackData( { this.shared = shared; this.context = ctx; + // CallbackData holds a reference to the request message while awaiting completion. + msg.Acquire(); this.Message = msg; _applicationRequestInstruments = applicationRequestInstruments; this.stopwatch = ValueStopwatch.StartNew(); @@ -108,6 +110,7 @@ private void OnCancellation() OrleansCallBackDataEvent.Instance.OnCanceled(Message); context.Complete(Response.FromException(new OperationCanceledException(_cancellationTokenRegistration.Token))); _cancellationTokenRegistration.Dispose(); + ReleaseRequest("CallbackData.OnCancellation"); } public void OnTimeout() @@ -138,6 +141,7 @@ public void OnTimeout() var exception = new TimeoutException($"Response did not arrive on time in {timeout} for message: {msg}. {statusMessage}"); context.Complete(Response.FromException(exception)); + ReleaseRequest("CallbackData.OnTimeout"); } public void OnTargetSiloFail() @@ -158,6 +162,7 @@ public void OnTargetSiloFail() LogTargetSiloFail(this.shared.Logger, msg, statusMessage, Constants.TroubleshootingHelpLink); var exception = new SiloUnavailableException($"The target silo became unavailable for message: {msg}. {statusMessage}See {Constants.TroubleshootingHelpLink} for troubleshooting help."); this.context.Complete(Response.FromException(exception)); + ReleaseRequest("CallbackData.OnTargetSiloFail"); } public void DoCallback(Message response) @@ -175,6 +180,13 @@ public void DoCallback(Message response) // do callback outside the CallbackData lock. Just not a good practice to hold a lock for this unrelated operation. ResponseCallback(response, this.context); + ReleaseRequest("CallbackData.DoCallback"); + } + + private void ReleaseRequest(string tag) + { + Message.MarkTransferred($"{tag}:ReleaseRequest"); + Message.Release(); } private static void ResponseCallback(Message message, IResponseCompletionSource context) diff --git a/src/Orleans.Core/Runtime/InvokableObjectManager.cs b/src/Orleans.Core/Runtime/InvokableObjectManager.cs index 3bed78f414b..06c42f261c9 100644 --- a/src/Orleans.Core/Runtime/InvokableObjectManager.cs +++ b/src/Orleans.Core/Runtime/InvokableObjectManager.cs @@ -63,6 +63,7 @@ public void Dispatch(Message message) if (!ObserverGrainId.TryParse(message.TargetGrain, out var observerId)) { LogNotAddressedToAnObserver(logger, message); + message.ReleaseDropped("NotAddressedToObserver"); return; } @@ -73,6 +74,7 @@ public void Dispatch(Message message) else { LogUnexpectedTargetInRequest(logger, message.TargetGrain, message); + message.ReleaseDropped("ObserverNotFound"); } } @@ -159,6 +161,7 @@ public void ReceiveMessage(object msg) LogObserverGarbageCollected(_manager.logger, this.ObserverId, message); // Try to remove. If it's not there, we don't care. _manager.TryDeregister(this.ObserverId); + message.ReleaseDropped("ObserverGarbageCollected"); return; } @@ -258,6 +261,7 @@ private async Task ProcessMessageAsync(Message message) if (message.IsExpired) { _manager.messagingTrace.OnDropExpiredMessage(message, MessagingInstruments.Phase.Invoke); + message.ReleaseDropped("ExpiredAtInvoke"); return; } @@ -333,6 +337,7 @@ private void SendResponseAsync(Message message, Response resultObject) if (message.IsExpired) { _manager.messagingTrace.OnDropExpiredMessage(message, MessagingInstruments.Phase.Respond); + message.ReleaseDropped("ExpiredAtRespond"); return; } diff --git a/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs b/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs index c6d6a0cc4a1..888f330fba1 100644 --- a/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs +++ b/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs @@ -336,6 +336,8 @@ public void ReceiveResponse(Message response) } } + // Release the status response message - it's been fully processed + response.ReleaseDropped("StatusResponseHandled"); return; } @@ -347,10 +349,13 @@ public void ReceiveResponse(Message response) // Unfortunately, it is not enough, since CallContext.LogicalGetData will not flow "up" from task completion source into the resolved task. // RequestContextExtensions.Import(response.RequestContextData); callbackData.DoCallback(response); + response.MarkTransferred("OutsideRuntimeClient.ReceiveResponse:AfterDoCallback"); + response.Release(); } else { LogDebugNoCallbackForResponseMessage(logger, response); + response.ReleaseDropped("NoCallbackNotFound"); } } diff --git a/src/Orleans.Runtime/Catalog/ActivationData.cs b/src/Orleans.Runtime/Catalog/ActivationData.cs index 02e7fa5a40b..066b462bdcd 100644 --- a/src/Orleans.Runtime/Catalog/ActivationData.cs +++ b/src/Orleans.Runtime/Catalog/ActivationData.cs @@ -1088,6 +1088,7 @@ void ProcessPendingRequests() _shared.InternalRuntime.MessageCenter.RejectMessage(message, Message.RejectionTypes.Transient, exception); } + message.ReleaseDropped("SchedulingException"); _waitingRequests.RemoveAt(i); continue; } @@ -1496,6 +1497,11 @@ private void OnCompletedRequest(Message message) // Signal the message pump to see if there is another request which can be processed now that this one has completed _workSignal.Signal(); + + // Release the message - for local messages, CallbackData still holds a ref so it won't return to pool yet. + // For remote messages, this is the terminal owner so it returns to pool. + message.MarkTransferred("ActivationData.OnCompletedRequest"); + message.Release(); } public void ReceiveMessage(object message) => ReceiveMessage((Message)message); @@ -1508,6 +1514,7 @@ public void ReceiveMessage(Message message) { MessagingProcessingInstruments.OnDispatcherMessageProcessedError(message); _shared.InternalRuntime.MessagingTrace.OnDropExpiredMessage(message, MessagingInstruments.Phase.Dispatch); + message.ReleaseDropped("ExpiredAtDispatch"); return; } @@ -1546,6 +1553,7 @@ private void ReceiveRequest(Message message) { MessagingProcessingInstruments.OnDispatcherMessageProcessedError(message); _shared.InternalRuntime.MessageCenter.RejectMessage(message, Message.RejectionTypes.Overloaded, overloadException, "Target activation is overloaded " + this); + message.ReleaseDropped("RejectedOverload"); return; } diff --git a/src/Orleans.Runtime/Catalog/StatelessWorkerGrainContext.cs b/src/Orleans.Runtime/Catalog/StatelessWorkerGrainContext.cs index 197167b84d4..35606a6e3cd 100644 --- a/src/Orleans.Runtime/Catalog/StatelessWorkerGrainContext.cs +++ b/src/Orleans.Runtime/Catalog/StatelessWorkerGrainContext.cs @@ -298,6 +298,7 @@ private void ReceiveMessageInternal(object message) Message.RejectionTypes.Transient, exception, "Exception while creating grain context"); + msg.ReleaseDropped("ExceptionCreatingContext"); } } diff --git a/src/Orleans.Runtime/Core/HostedClient.cs b/src/Orleans.Runtime/Core/HostedClient.cs index a9d0e6a90e4..a4cee84dd59 100644 --- a/src/Orleans.Runtime/Core/HostedClient.cs +++ b/src/Orleans.Runtime/Core/HostedClient.cs @@ -194,6 +194,7 @@ public bool TryDispatchToClient(Message message) if (message.IsExpired) { this.messagingTrace.OnDropExpiredMessage(message, MessagingInstruments.Phase.Receive); + message.ReleaseDropped("ExpiredAtDispatch"); return true; } diff --git a/src/Orleans.Runtime/Core/InsideRuntimeClient.cs b/src/Orleans.Runtime/Core/InsideRuntimeClient.cs index 29e4abaac67..4bc60d5a07a 100644 --- a/src/Orleans.Runtime/Core/InsideRuntimeClient.cs +++ b/src/Orleans.Runtime/Core/InsideRuntimeClient.cs @@ -198,6 +198,8 @@ public void SendResponse(Message request, Response response) if (request.IsExpired) { this.messagingTrace.OnDropExpiredMessage(request, MessagingInstruments.Phase.Respond); + // Note: We don't release here because the request message is still owned by the activation. + // It will be released in ActivationData.OnCompletedRequest when invoke completes. return; } @@ -264,6 +266,8 @@ public async Task Invoke(IGrainContext target, Message message) if (message.IsExpired) { this.messagingTrace.OnDropExpiredMessage(message, MessagingInstruments.Phase.Invoke); + // Note: We don't release here because the message is still owned by the activation. + // It will be released in ActivationData.OnCompletedRequest when this method returns. return; } @@ -413,6 +417,7 @@ public void ReceiveResponse(Message message) break; case Message.RejectionTypes.CacheInvalidation when message.HasCacheInvalidationHeader: // The message targeted an invalid (eg, defunct) activation and this response serves only to invalidate this silo's activation cache. + message.ReleaseDropped("CacheInvalidationResponse"); return; default: LogErrorUnsupportedRejectionType(this.logger, rejection.RejectionType); @@ -452,6 +457,7 @@ public void ReceiveResponse(Message message) } } + message.ReleaseDropped("StatusResponseHandled"); return; } @@ -462,11 +468,16 @@ public void ReceiveResponse(Message message) // IMPORTANT: we do not schedule the response callback via the scheduler, since the only thing it does // is to resolve/break the resolver. The continuations/waits that are based on this resolution will be scheduled as work items. callbackData.DoCallback(message); + message.MarkTransferred("InsideRuntimeClient.ReceiveResponse:AfterDoCallback"); + message.Release(); } else { LogDebugNoCallbackForResponse(this.logger, message); + message.MarkTransferred("InsideRuntimeClient.ReceiveResponse:NoCallbackNotFound"); + message.Release(); } + } public string CurrentActivationIdentity => RuntimeContext.Current?.Address.ToString() ?? this.HostedClient.ToString(); diff --git a/src/Orleans.Runtime/Messaging/Gateway.cs b/src/Orleans.Runtime/Messaging/Gateway.cs index c0b0176b8f6..e3d34eebbcf 100644 --- a/src/Orleans.Runtime/Messaging/Gateway.cs +++ b/src/Orleans.Runtime/Messaging/Gateway.cs @@ -410,6 +410,7 @@ private void RejectDroppedClientMessages() { exception ??= new ClientNotAvailableException(Id.GrainId); _gateway.messageCenter.RejectMessage(message, Message.RejectionTypes.Transient, exc: exception, rejectInfo: "Client dropped"); + message.ReleaseDropped("ClientDropped"); } } diff --git a/src/Orleans.Runtime/Messaging/MessageCenter.cs b/src/Orleans.Runtime/Messaging/MessageCenter.cs index cc3a7a0b65b..105a90f305f 100644 --- a/src/Orleans.Runtime/Messaging/MessageCenter.cs +++ b/src/Orleans.Runtime/Messaging/MessageCenter.cs @@ -143,6 +143,7 @@ public void SendMessage(Message msg) { // Drop the message on the floor if it's an application message that isn't a rejection this.messagingTrace.OnDropBlockedApplicationMessage(msg); + msg.ReleaseDropped("BlockedApplicationMessage"); } else { @@ -159,6 +160,7 @@ public void SendMessage(Message msg) if (msg.IsExpired) { this.messagingTrace.OnDropExpiredMessage(msg, MessagingInstruments.Phase.Send); + msg.ReleaseDropped("ExpiredAtSend"); return; } diff --git a/src/Orleans.Runtime/Networking/GatewayInboundConnection.cs b/src/Orleans.Runtime/Networking/GatewayInboundConnection.cs index 5a4166aaac4..679a0f1793e 100644 --- a/src/Orleans.Runtime/Networking/GatewayInboundConnection.cs +++ b/src/Orleans.Runtime/Networking/GatewayInboundConnection.cs @@ -62,6 +62,7 @@ protected override void OnReceivedMessage(Message msg) if (msg.IsExpired) { this.MessagingTrace.OnDropExpiredMessage(msg, MessagingInstruments.Phase.Receive); + msg.ReleaseDropped("ExpiredAtReceive"); return; } @@ -73,6 +74,7 @@ protected override void OnReceivedMessage(Message msg) this.messageCenter.TryDeliverToProxy(rejection); LogRejectingRequestDueToOverloading(this.Log, msg); GatewayInstruments.GatewayLoadShedding.Add(1); + msg.ReleaseDropped("RejectedGatewayOverload"); return; } @@ -147,6 +149,7 @@ protected override bool PrepareMessageForSend(Message msg) if (msg.IsExpired) { this.MessagingTrace.OnDropExpiredMessage(msg, MessagingInstruments.Phase.Send); + msg.ReleaseDropped("ExpiredAtSend"); return false; } @@ -169,11 +172,13 @@ public void FailMessage(Message msg, string reason) Message.RejectionTypes.Transient, $"Silo {this.myAddress} is rejecting message: {msg}. Reason = {reason}", new SiloUnavailableException()); + msg.ReleaseDropped("FailedSendRequest"); } else { LogSiloDroppingMessage(this.Log, this.myAddress, msg, reason); MessagingInstruments.OnDroppedSentMessage(msg); + msg.ReleaseDropped("FailedSendNonRequest"); } } diff --git a/src/Orleans.Runtime/Networking/SiloConnection.cs b/src/Orleans.Runtime/Networking/SiloConnection.cs index 448d348d182..a5cf1ca6fe3 100644 --- a/src/Orleans.Runtime/Networking/SiloConnection.cs +++ b/src/Orleans.Runtime/Networking/SiloConnection.cs @@ -81,6 +81,7 @@ protected override void OnReceivedMessage(Message msg) if (msg.IsExpired) { this.MessagingTrace.OnDropExpiredMessage(msg, MessagingInstruments.Phase.Receive); + msg.ReleaseDropped("ExpiredAtReceive"); return; } @@ -92,12 +93,14 @@ protected override void OnReceivedMessage(Message msg) if (msg.Direction != Message.Directions.Request) { this.MessagingTrace.OnDropBlockedApplicationMessage(msg); + msg.ReleaseDropped("BlockedApplicationMessage"); return; } MessagingInstruments.OnRejectedMessage(msg); var rejection = this.MessageFactory.CreateRejectionResponse(msg, Message.RejectionTypes.Unrecoverable, "Silo stopping", new SiloUnavailableException()); this.Send(rejection); + msg.ReleaseDropped("RejectedSiloStopping"); return; } @@ -134,9 +137,15 @@ protected override void OnReceivedMessage(Message msg) } this.Send(rejection); + msg.ReleaseDropped("RejectedObsoleteEpoch"); LogDebugRejectingObsoleteRequest(this.Log, msg.TargetSilo?.ToString() ?? "null", this.LocalSiloAddress.ToString(), msg); } + else + { + // Response or OneWay to obsolete epoch - drop it + msg.ReleaseDropped("DroppedObsoleteEpoch"); + } } private void HandlePingMessage(Message msg) @@ -153,6 +162,7 @@ private void HandlePingMessage(Message msg) Message rejection = this.MessageFactory.CreateRejectionResponse(msg, Message.RejectionTypes.Unrecoverable, $"The target silo is no longer active: target was {msg.TargetSilo}, but this silo is {LocalSiloAddress}. The rejected ping message is {msg}."); this.Send(rejection); + msg.ReleaseDropped("RejectedPingObsoleteEpoch"); } else { @@ -160,6 +170,8 @@ private void HandlePingMessage(Message msg) var response = this.MessageFactory.CreateResponseMessage(msg); response.BodyObject = PingResponse; this.Send(response); + msg.MarkTransferred("SiloConnection.HandlePingMessage"); + msg.Release(); } } @@ -234,6 +246,7 @@ protected override bool PrepareMessageForSend(Message msg) if (msg.IsExpired) { this.MessagingTrace.OnDropExpiredMessage(msg, MessagingInstruments.Phase.Send); + msg.ReleaseDropped("ExpiredAtSend"); if (msg.IsPing()) { @@ -277,10 +290,12 @@ public void FailMessage(Message msg, string reason) Message.RejectionTypes.Transient, $"Silo {this.LocalSiloAddress} is rejecting message: {msg}. Reason = {reason}", new SiloUnavailableException()); + msg.ReleaseDropped("FailedSendRequest"); } else { this.MessagingTrace.OnSiloDropSendingMessage(this.LocalSiloAddress, msg, reason); + msg.ReleaseDropped("FailedSendNonRequest"); } } diff --git a/test/Orleans.Core.Tests/Messaging/MessagePoolTests.cs b/test/Orleans.Core.Tests/Messaging/MessagePoolTests.cs new file mode 100644 index 00000000000..9a3424f1ed7 --- /dev/null +++ b/test/Orleans.Core.Tests/Messaging/MessagePoolTests.cs @@ -0,0 +1,181 @@ +using Microsoft.Extensions.DependencyInjection; +using Orleans.CodeGeneration; +using Orleans.Runtime; +using Orleans.Runtime.Messaging; +using TestExtensions; +using Xunit; + +namespace UnitTests.Messaging +{ + /// + /// Tests for Message pooling and ownership tracking. + /// + [Collection(TestEnvironmentFixture.DefaultCollection)] + public class MessagePoolTests + { + private readonly MessageFactory _messageFactory; + + public MessagePoolTests(TestEnvironmentFixture fixture) + { + _messageFactory = fixture.Services.GetRequiredService(); + } + + [Fact, TestCategory("BVT"), TestCategory("Messaging")] + public void Message_RefCount_InitializedToOne() + { + var message = MessagePool.Get(); + + Assert.NotNull(message); + + message.Release(); + } + + [Fact, TestCategory("BVT"), TestCategory("Messaging")] + public void Message_Acquire_IncrementsRefCount() + { + var message = MessagePool.Get(); + + message.Acquire(); + + message.Release(); + message.Release(); + } + + [Fact, TestCategory("BVT"), TestCategory("Messaging")] + public void Message_ReleaseDropped_ReleasesMessage() + { + var message = MessagePool.Get(); + + message.ReleaseDropped("TestReason"); + } + + [Fact, TestCategory("BVT"), TestCategory("Messaging")] + public void Message_MultipleAcquireRelease_WorksCorrectly() + { + var message = MessagePool.Get(); + + message.Acquire(); + message.Acquire(); + + message.Release(); + message.Release(); + message.Release(); + } + + [Fact, TestCategory("BVT"), TestCategory("Messaging")] + public void MessageFactory_CreateMessage_ReturnsPooledMessage() + { + var message = _messageFactory.CreateMessage(null, InvokeMethodOptions.None); + + Assert.NotNull(message); + Assert.Equal(Message.Directions.Request, message.Direction); + + message.Release(); + } + + [Fact, TestCategory("BVT"), TestCategory("Messaging")] + public void Message_MarkTransferred_DoesNotThrow() + { + var message = MessagePool.Get(); + + message.MarkTransferred("TestTransfer"); + message.MarkTransferred("AnotherTransfer"); + + message.Release(); + } + +#if DEBUG + [Fact, TestCategory("BVT"), TestCategory("Messaging")] + public void MessagePool_LeakTracking_TracksOutstandingMessages() + { + MessagePool.ClearLeakTracking(); + MessagePool.EnableLeakTracking = true; + + try + { + var message1 = MessagePool.Get(); + var message2 = MessagePool.Get(); + + var outstanding = MessagePool.GetOutstandingMessages(); + Assert.Equal(2, outstanding.Count); + + message1.Release(); + outstanding = MessagePool.GetOutstandingMessages(); + Assert.Single(outstanding); + + message2.Release(); + outstanding = MessagePool.GetOutstandingMessages(); + Assert.Empty(outstanding); + } + finally + { + MessagePool.EnableLeakTracking = false; + MessagePool.ClearLeakTracking(); + } + } + + [Fact, TestCategory("BVT"), TestCategory("Messaging")] + public void MessagePool_LeakTracking_CapturesAllocationInfo() + { + MessagePool.ClearLeakTracking(); + MessagePool.EnableLeakTracking = true; + + try + { + var message = MessagePool.Get(); + + var outstanding = MessagePool.GetOutstandingMessages(); + Assert.Single(outstanding); + + var info = outstanding.First(); + + Assert.Same(message, info.Message); + Assert.NotNull(info.AllocationStack); + Assert.True(info.AllocationTime <= DateTime.UtcNow); + + message.Release(); + } + finally + { + MessagePool.EnableLeakTracking = false; + MessagePool.ClearLeakTracking(); + } + } + + [Fact, TestCategory("BVT"), TestCategory("Messaging")] + public void MessagePool_LeakTracking_DisabledByDefault() + { + MessagePool.EnableLeakTracking = false; + MessagePool.ClearLeakTracking(); + + var message = MessagePool.Get(); + + var outstanding = MessagePool.GetOutstandingMessages(); + Assert.Empty(outstanding); + + message.Release(); + } +#endif + + [Fact, TestCategory("BVT"), TestCategory("Messaging")] + public void Message_Reset_ClearsAllFields() + { + var message = MessagePool.Get(); + message.Direction = Message.Directions.Request; + message.TargetGrain = GrainId.Create("test", "key"); + message.SendingGrain = GrainId.Create("sender", "key"); + message.BodyObject = "test body"; + + message.Release(); + + var newMessage = MessagePool.Get(); + + Assert.Equal(Message.Directions.None, newMessage.Direction); + Assert.True(newMessage.TargetGrain.IsDefault); + Assert.True(newMessage.SendingGrain.IsDefault); + Assert.Null(newMessage.BodyObject); + + newMessage.Release(); + } + } +} From 11233d302f900d90c3729c20332362d349c2a964 Mon Sep 17 00:00:00 2001 From: Reuben Bond Date: Wed, 29 Apr 2026 21:23:59 -0700 Subject: [PATCH 2/2] Fix repartitioner message sampling with pooled messages Snapshot sampled message addressing before asynchronous processing so activation repartitioning does not observe reset pooled messages. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/Orleans.Runtime/Messaging/MessageCenter.cs | 2 +- .../ActivationRepartitioner.MessageSink.cs | 16 ++++++++++++---- .../Repartitioning/ActivationRepartitioner.cs | 4 ++-- .../Repartitioning/RepartitionerMessageFilter.cs | 10 +++++----- .../TestMessageFilter.cs | 8 ++++---- 5 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/Orleans.Runtime/Messaging/MessageCenter.cs b/src/Orleans.Runtime/Messaging/MessageCenter.cs index 105a90f305f..e19ee4b9b36 100644 --- a/src/Orleans.Runtime/Messaging/MessageCenter.cs +++ b/src/Orleans.Runtime/Messaging/MessageCenter.cs @@ -534,8 +534,8 @@ public void ReceiveMessage(Message msg) return; } - targetActivation.ReceiveMessage(msg); _messageObserver?.Invoke(msg); + targetActivation.ReceiveMessage(msg); } } catch (Exception ex) diff --git a/src/Orleans.Runtime/Placement/Repartitioning/ActivationRepartitioner.MessageSink.cs b/src/Orleans.Runtime/Placement/Repartitioning/ActivationRepartitioner.MessageSink.cs index cbb4f6007eb..6f8b7237c86 100644 --- a/src/Orleans.Runtime/Placement/Repartitioning/ActivationRepartitioner.MessageSink.cs +++ b/src/Orleans.Runtime/Placement/Repartitioning/ActivationRepartitioner.MessageSink.cs @@ -41,7 +41,7 @@ private async Task ProcessPendingEdges(CancellationToken cancellationToken) { await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding | ConfigureAwaitOptions.ContinueOnCapturedContext); - var drainBuffer = new Message[128]; + var drainBuffer = new RecordedMessage[128]; var iteration = 0; const int MaxIterationsPerYield = 128; while (!cancellationToken.IsCancellationRequested) @@ -51,8 +51,7 @@ private async Task ProcessPendingEdges(CancellationToken cancellationToken) { foreach (var message in drainBuffer[..count]) { - if (!IsFullyAddressed(message) || // The silo addresses (likely the target) is set null some time later (after the message is recorded), this can lead to a NRE - !_messageFilter.IsAcceptable(message, out var isSenderMigratable, out var isTargetMigratable)) + if (!_messageFilter.IsAcceptable(message.SendingGrain, message.TargetGrain, out var isSenderMigratable, out var isTargetMigratable)) { continue; } @@ -122,7 +121,8 @@ private void RecordMessage(Message message) return; } - if (_pendingMessages.TryAdd(message) == Utilities.BufferStatus.Success) + var recordedMessage = new RecordedMessage(message.SendingGrain, message.SendingSilo!, message.TargetGrain, message.TargetSilo!); + if (_pendingMessages.TryAdd(recordedMessage) == Utilities.BufferStatus.Success) { _pendingMessageEvent.Signal(); } @@ -132,6 +132,14 @@ private void RecordMessage(Message message) private static bool IsFullyAddressed(Message message) => message.IsSenderFullyAddressed && message.IsTargetFullyAddressed; + private sealed class RecordedMessage(GrainId sendingGrain, SiloAddress sendingSilo, GrainId targetGrain, SiloAddress targetSilo) + { + public GrainId SendingGrain { get; } = sendingGrain; + public SiloAddress SendingSilo { get; } = sendingSilo; + public GrainId TargetGrain { get; } = targetGrain; + public SiloAddress TargetSilo { get; } = targetSilo; + } + async ValueTask IActivationRepartitionerSystemTarget.FlushBuffers() { while (_pendingMessages.Count > 0) diff --git a/src/Orleans.Runtime/Placement/Repartitioning/ActivationRepartitioner.cs b/src/Orleans.Runtime/Placement/Repartitioning/ActivationRepartitioner.cs index 19256b305b6..19a118347ea 100644 --- a/src/Orleans.Runtime/Placement/Repartitioning/ActivationRepartitioner.cs +++ b/src/Orleans.Runtime/Placement/Repartitioning/ActivationRepartitioner.cs @@ -30,7 +30,7 @@ internal sealed partial class ActivationRepartitioner : SystemTarget, IActivatio private readonly ActivationDirectory _activationDirectory; private readonly TimeProvider _timeProvider; private readonly ActivationRepartitionerOptions _options; - private readonly StripedMpscBuffer _pendingMessages; + private readonly StripedMpscBuffer _pendingMessages; private readonly SingleWaiterAutoResetEvent _pendingMessageEvent = new() { RunContinuationsAsynchronously = true }; private readonly FrequentEdgeCounter _edgeWeights; private readonly IGrainTimer _timer; @@ -63,7 +63,7 @@ public ActivationRepartitioner( _timeProvider = timeProvider; _edgeWeights = new(options.Value.MaxEdgeCount); _lastExchangedStopwatch = CoarseStopwatch.StartNew(); - _pendingMessages = new StripedMpscBuffer(Environment.ProcessorCount, options.Value.MaxUnprocessedEdges / Environment.ProcessorCount); + _pendingMessages = new StripedMpscBuffer(Environment.ProcessorCount, options.Value.MaxUnprocessedEdges / Environment.ProcessorCount); shared.ActivationDirectory.RecordNewTarget(this); _siloStatusOracle.SubscribeToSiloStatusEvents(this); _timer = RegisterTimer(_ => TriggerExchangeRequest().AsTask(), null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan); diff --git a/src/Orleans.Runtime/Placement/Repartitioning/RepartitionerMessageFilter.cs b/src/Orleans.Runtime/Placement/Repartitioning/RepartitionerMessageFilter.cs index 0e3d0aae1c3..0edb016aadf 100644 --- a/src/Orleans.Runtime/Placement/Repartitioning/RepartitionerMessageFilter.cs +++ b/src/Orleans.Runtime/Placement/Repartitioning/RepartitionerMessageFilter.cs @@ -5,12 +5,12 @@ namespace Orleans.Runtime.Placement.Repartitioning; internal interface IRepartitionerMessageFilter { - bool IsAcceptable(Message message, out bool isSenderMigratable, out bool isTargetMigratable); + bool IsAcceptable(GrainId sendingGrain, GrainId targetGrain, out bool isSenderMigratable, out bool isTargetMigratable); } internal sealed class RepartitionerMessageFilter(GrainMigratabilityChecker checker) : IRepartitionerMessageFilter { - public bool IsAcceptable(Message message, out bool isSenderMigratable, out bool isTargetMigratable) + public bool IsAcceptable(GrainId sendingGrain, GrainId targetGrain, out bool isSenderMigratable, out bool isTargetMigratable) { isSenderMigratable = false; isTargetMigratable = false; @@ -18,13 +18,13 @@ public bool IsAcceptable(Message message, out bool isSenderMigratable, out bool // There are some edge cases when this can happen i.e. a grain invoking another one of its methods via AsReference<>, but we still exclude it // as wherever this grain would be located in the cluster, it would always be a local call (since it targets itself), this would add negative transfer cost // which would skew a potential relocation of this grain, while it shouldn't, because whenever this grain is located, it would still make local calls to itself. - if (message.SendingGrain == message.TargetGrain) + if (sendingGrain == targetGrain) { return false; } - isSenderMigratable = checker.IsMigratable(message.SendingGrain.Type, ImmovableKind.Repartitioner); - isTargetMigratable = checker.IsMigratable(message.TargetGrain.Type, ImmovableKind.Repartitioner); + isSenderMigratable = checker.IsMigratable(sendingGrain.Type, ImmovableKind.Repartitioner); + isTargetMigratable = checker.IsMigratable(targetGrain.Type, ImmovableKind.Repartitioner); // If both are not migratable types we ignore this. But if one of them is not, then we allow passing, as we wish to move grains closer to them, as with any type of grain. return isSenderMigratable || isTargetMigratable; diff --git a/test/Orleans.Placement.Tests/ActivationRepartitioningTests/TestMessageFilter.cs b/test/Orleans.Placement.Tests/ActivationRepartitioningTests/TestMessageFilter.cs index fa3206d2fb3..a3d78ac157d 100644 --- a/test/Orleans.Placement.Tests/ActivationRepartitioningTests/TestMessageFilter.cs +++ b/test/Orleans.Placement.Tests/ActivationRepartitioningTests/TestMessageFilter.cs @@ -10,7 +10,7 @@ internal sealed class TestMessageFilter(GrainMigratabilityChecker checker) : IRe { private readonly RepartitionerMessageFilter _messageFilter = new(checker); - public bool IsAcceptable(Message message, out bool isSenderMigratable, out bool isTargetMigratable) => - _messageFilter.IsAcceptable(message, out isSenderMigratable, out isTargetMigratable) && - !message.SendingGrain.IsClient() && !message.TargetGrain.IsClient(); -} \ No newline at end of file + public bool IsAcceptable(GrainId sendingGrain, GrainId targetGrain, out bool isSenderMigratable, out bool isTargetMigratable) => + _messageFilter.IsAcceptable(sendingGrain, targetGrain, out isSenderMigratable, out isTargetMigratable) && + !sendingGrain.IsClient() && !targetGrain.IsClient(); +}