diff --git a/src/Orleans.Core/Messaging/CorrelationId.cs b/src/Orleans.Core/Messaging/CorrelationId.cs index b40167a4fca..c3a7b055dd3 100644 --- a/src/Orleans.Core/Messaging/CorrelationId.cs +++ b/src/Orleans.Core/Messaging/CorrelationId.cs @@ -16,7 +16,7 @@ namespace Orleans.Runtime public static CorrelationId GetNext() => new(System.Threading.Interlocked.Increment(ref lastUsed)); - public override int GetHashCode() => id.GetHashCode(); + public override int GetHashCode() => HashCode.Combine(id); public override bool Equals(object? obj) => obj is CorrelationId correlationId && Equals(correlationId); diff --git a/src/Orleans.Core/Messaging/MessageFactory.cs b/src/Orleans.Core/Messaging/MessageFactory.cs index e41bc632815..e1534ce394c 100644 --- a/src/Orleans.Core/Messaging/MessageFactory.cs +++ b/src/Orleans.Core/Messaging/MessageFactory.cs @@ -50,7 +50,8 @@ public Message CreateMessage(object body, InvokeMethodOptions options) private CorrelationId GetNextCorrelationId() { var id = _seed ^ Interlocked.Increment(ref _nextId); - return new CorrelationId(unchecked((long)id)); + var stripeIndex = StripedCallbackDictionary.GetCurrentThreadStripeIndex(); + return StripedCallbackDictionary.CreateCorrelationId(unchecked((long)id), stripeIndex); } public Message CreateResponseMessage(Message request) diff --git a/src/Orleans.Core/Messaging/StripedCallbackDictionary.cs b/src/Orleans.Core/Messaging/StripedCallbackDictionary.cs new file mode 100644 index 00000000000..3e131d8dda9 --- /dev/null +++ b/src/Orleans.Core/Messaging/StripedCallbackDictionary.cs @@ -0,0 +1,255 @@ +#nullable enable +using System; +using System.Collections; +using System.Collections.Generic; +using System.Runtime.CompilerServices; + +namespace Orleans.Runtime; + +/// +/// A striped dictionary that distributes entries across multiple internal dictionaries +/// to reduce lock contention. The stripe is determined by bits embedded in the CorrelationId. +/// +/// The type of values stored in the dictionary. +internal sealed class StripedCallbackDictionary : IEnumerable> +{ + /// + /// The number of bits used to identify the stripe (stored in the upper bits of the CorrelationId). + /// + public const int StripeBits = 7; + + /// + /// The number of stripes (must be a power of 2). + /// + public const int StripeCount = 1 << StripeBits; // 128 stripes + + /// + /// Mask to extract the stripe index from the upper bits. + /// + private const long StripeMask = (long)(StripeCount - 1) << (64 - StripeBits); + + /// + /// The shift amount to move the stripe bits to the lowest position. + /// + private const int StripeShift = 64 - StripeBits; + + private readonly Stripe[] _stripes; + + public StripedCallbackDictionary() + { + _stripes = new Stripe[StripeCount]; + for (int i = 0; i < StripeCount; i++) + { + _stripes[i] = new Stripe(); + } + } + + /// + /// Encodes a stripe index into the upper bits of a base value to create a CorrelationId. + /// + /// The base value (e.g., from an incrementing counter XORed with a seed). + /// The stripe index (typically derived from thread id). + /// A CorrelationId with the stripe encoded in the upper bits. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static CorrelationId CreateCorrelationId(long baseValue, int stripeIndex) + { + // Clear the upper StripeBits of the base value and set the stripe index there + long maskedBase = baseValue & ~StripeMask; + long stripeValue = (long)(stripeIndex & (StripeCount - 1)) << StripeShift; + return new CorrelationId(maskedBase | stripeValue); + } + + /// + /// Extracts the stripe index from a CorrelationId. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int GetStripeIndex(CorrelationId correlationId) + { + return (int)((correlationId.ToInt64() & StripeMask) >>> StripeShift); + } + + /// + /// Gets the stripe index for the current thread. Use this when creating new CorrelationIds. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int GetCurrentThreadStripeIndex() + { + return Environment.CurrentManagedThreadId & (StripeCount - 1); + } + + /// + /// Gets the stripe for the given correlation id. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private Stripe GetStripe(CorrelationId correlationId) + { + return _stripes[GetStripeIndex(correlationId)]; + } + + /// + /// Attempts to add the specified key and value to the dictionary. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool TryAdd(CorrelationId key, TValue value) + { + var stripe = GetStripe(key); + lock (stripe.Lock) + { + return stripe.Dictionary.TryAdd(key, value); + } + } + + /// + /// Attempts to get the value associated with the specified key. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool TryGetValue(CorrelationId key, out TValue? value) + { + var stripe = GetStripe(key); + lock (stripe.Lock) + { + return stripe.Dictionary.TryGetValue(key, out value); + } + } + + /// + /// Attempts to remove the value with the specified key. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool TryRemove(CorrelationId key, out TValue? value) + { + var stripe = GetStripe(key); + lock (stripe.Lock) + { + return stripe.Dictionary.Remove(key, out value); + } + } + + /// + /// Gets the approximate total count of items across all stripes. + /// + public int Count + { + get + { + int count = 0; + foreach (var stripe in _stripes) + { + lock (stripe.Lock) + { + count += stripe.Dictionary.Count; + } + } + return count; + } + } + + /// + /// Counts items matching a predicate across all stripes. + /// + public int CountWhere(Func, bool> predicate) + { + int count = 0; + foreach (var stripe in _stripes) + { + lock (stripe.Lock) + { + foreach (var kvp in stripe.Dictionary) + { + if (predicate(kvp)) + { + count++; + } + } + } + } + return count; + } + + /// + /// Returns an enumerator that iterates through all items in all stripes. + /// Note: This takes a snapshot of each stripe under its lock. + /// + public Enumerator GetEnumerator() => new(this); + + IEnumerator> IEnumerable>.GetEnumerator() => GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + private sealed class Stripe + { + public readonly object Lock = new(); + public readonly Dictionary Dictionary = new(); + } + + public struct Enumerator : IEnumerator> + { + private readonly StripedCallbackDictionary _dictionary; + private int _stripeIndex; + private List>? _currentSnapshot; + private int _snapshotIndex; + + internal Enumerator(StripedCallbackDictionary dictionary) + { + _dictionary = dictionary; + _stripeIndex = -1; + _currentSnapshot = null; + _snapshotIndex = -1; + } + + public KeyValuePair Current => _currentSnapshot![_snapshotIndex]; + + object IEnumerator.Current => Current; + + public bool MoveNext() + { + while (true) + { + // Try to advance within current snapshot + if (_currentSnapshot != null) + { + _snapshotIndex++; + if (_snapshotIndex < _currentSnapshot.Count) + { + return true; + } + } + + // Move to next stripe + _stripeIndex++; + if (_stripeIndex >= _dictionary._stripes.Length) + { + _currentSnapshot = null; + return false; + } + + // Take a snapshot of the next stripe + var stripe = _dictionary._stripes[_stripeIndex]; + lock (stripe.Lock) + { + if (stripe.Dictionary.Count > 0) + { + _currentSnapshot = new List>(stripe.Dictionary); + _snapshotIndex = -1; + } + else + { + _currentSnapshot = null; + } + } + } + } + + public void Reset() + { + _stripeIndex = -1; + _currentSnapshot = null; + _snapshotIndex = -1; + } + + public void Dispose() + { + _currentSnapshot = null; + } + } +} diff --git a/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs b/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs index c6d6a0cc4a1..38f7a2a1c30 100644 --- a/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs +++ b/src/Orleans.Core/Runtime/OutsideRuntimeClient.cs @@ -26,7 +26,7 @@ internal partial class OutsideRuntimeClient : IRuntimeClient, IDisposable, IClus private readonly ILogger logger; private readonly ClientMessagingOptions clientMessagingOptions; - private readonly ConcurrentDictionary callbacks; + private readonly StripedCallbackDictionary callbacks; private InvokableObjectManager localObjects; private bool disposing; private bool disposed; @@ -84,7 +84,7 @@ public OutsideRuntimeClient( this.loggerFactory = loggerFactory; this.messagingTrace = messagingTrace; this.logger = loggerFactory.CreateLogger(); - callbacks = new ConcurrentDictionary(); + callbacks = new StripedCallbackDictionary(); this.clientMessagingOptions = clientMessagingOptions.Value; var period = Max( TimeSpan.FromMilliseconds(1), @@ -434,7 +434,7 @@ public void BreakOutstandingMessagesToSilo(SiloAddress deadSilo) } public int GetRunningRequestsCount(GrainInterfaceType grainInterfaceType) - => this.callbacks.Count(c => c.Value.Message.InterfaceType == grainInterfaceType); + => this.callbacks.CountWhere(c => c.Value.Message.InterfaceType == grainInterfaceType); /// public void NotifyClusterConnectionLost() diff --git a/src/Orleans.Runtime/Core/InsideRuntimeClient.cs b/src/Orleans.Runtime/Core/InsideRuntimeClient.cs index 29e4abaac67..d163eda1082 100644 --- a/src/Orleans.Runtime/Core/InsideRuntimeClient.cs +++ b/src/Orleans.Runtime/Core/InsideRuntimeClient.cs @@ -1,8 +1,6 @@ using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; -using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; @@ -32,7 +30,7 @@ internal sealed partial class InsideRuntimeClient : IRuntimeClient, ILifecyclePa private readonly ILogger invokeExceptionLogger; private readonly ILoggerFactory loggerFactory; private readonly SiloMessagingOptions messagingOptions; - private readonly ConcurrentDictionary<(GrainId, CorrelationId), CallbackData> callbacks; + private readonly StripedCallbackDictionary callbacks; private readonly InterfaceToImplementationMappingCache interfaceToImplementationMapping; private readonly SharedCallbackData sharedCallbackData; private readonly SharedCallbackData systemSharedCallbackData; @@ -74,7 +72,7 @@ public InsideRuntimeClient( this._applicationRequestInstruments = new(orleansInstruments); this.ServiceProvider = serviceProvider; this.MySilo = siloDetails.SiloAddress; - this.callbacks = new ConcurrentDictionary<(GrainId, CorrelationId), CallbackData>(); + this.callbacks = new StripedCallbackDictionary(); this.messageFactory = messageFactory; this.ConcreteGrainFactory = new GrainFactory(this, referenceActivator, interfaceIdResolver, interfaceToTypeResolver); this.logger = loggerFactory.CreateLogger(); @@ -88,7 +86,7 @@ public InsideRuntimeClient( var callbackDataLogger = loggerFactory.CreateLogger(); this.sharedCallbackData = new SharedCallbackData( - msg => this.UnregisterCallback(msg.SendingGrain, msg.Id), + msg => this.UnregisterCallback(msg.Id), callbackDataLogger, this.messagingOptions.ResponseTimeout, this.messagingOptions.CancelRequestOnTimeout, @@ -96,7 +94,7 @@ public InsideRuntimeClient( cancellationManager: null); this.systemSharedCallbackData = new SharedCallbackData( - msg => this.UnregisterCallback(msg.SendingGrain, msg.Id), + msg => this.UnregisterCallback(msg.Id), callbackDataLogger, this.messagingOptions.SystemResponseTimeout, cancelOnTimeout: false, @@ -178,7 +176,7 @@ public void SendRequest( // Register a callback for the request. var callbackData = new CallbackData(sharedData, context, message, _applicationRequestInstruments); - callbacks.TryAdd((message.SendingGrain, message.Id), callbackData); + callbacks.TryAdd(message.Id, callbackData); callbackData.SubscribeForCancellation(cancellationToken); } else @@ -207,9 +205,9 @@ public void SendResponse(Message request, Response response) /// /// UnRegister a callback. /// - private void UnregisterCallback(GrainId grainId, CorrelationId correlationId) + private void UnregisterCallback(CorrelationId correlationId) { - callbacks.TryRemove((grainId, correlationId), out _); + callbacks.TryRemove(correlationId, out _); } public void SniffIncomingMessage(Message message) @@ -422,7 +420,7 @@ public void ReceiveResponse(Message message) else if (message.Result == Message.ResponseTypes.Status) { var status = (StatusResponse)message.BodyObject; - callbacks.TryGetValue((message.TargetGrain, message.Id), out var callback); + callbacks.TryGetValue(message.Id, out var callback); var request = callback?.Message; if (request is not null) { @@ -456,7 +454,7 @@ public void ReceiveResponse(Message message) } CallbackData callbackData; - bool found = callbacks.TryRemove((message.TargetGrain, message.Id), out callbackData); + bool found = callbacks.TryRemove(message.Id, out callbackData); if (found) { // IMPORTANT: we do not schedule the response callback via the scheduler, since the only thing it does @@ -545,7 +543,7 @@ public void Participate(ISiloLifecycle lifecycle) } public int GetRunningRequestsCount(GrainInterfaceType grainInterfaceType) - => this.callbacks.Count(c => c.Value.Message.InterfaceType == grainInterfaceType); + => this.callbacks.CountWhere(c => c.Value.Message.InterfaceType == grainInterfaceType); private async Task MonitorCallbackExpiry() {