Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Orleans.Core/Messaging/CorrelationId.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Orleans.Runtime

public static CorrelationId GetNext() => new(System.Threading.Interlocked.Increment(ref lastUsed));

public override int GetHashCode() => id.GetHashCode();
public override int GetHashCode() => HashCode.Combine(id);

public override bool Equals(object? obj) => obj is CorrelationId correlationId && Equals(correlationId);

Expand Down
3 changes: 2 additions & 1 deletion src/Orleans.Core/Messaging/MessageFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ public Message CreateMessage(object body, InvokeMethodOptions options)
private CorrelationId GetNextCorrelationId()
{
var id = _seed ^ Interlocked.Increment(ref _nextId);
return new CorrelationId(unchecked((long)id));
var stripeIndex = StripedCallbackDictionary<object>.GetCurrentThreadStripeIndex();
return StripedCallbackDictionary<object>.CreateCorrelationId(unchecked((long)id), stripeIndex);
}
Comment on lines 50 to 55

public Message CreateResponseMessage(Message request)
Expand Down
255 changes: 255 additions & 0 deletions src/Orleans.Core/Messaging/StripedCallbackDictionary.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
#nullable enable
using System;
using System.Collections;
using System.Collections.Generic;
using System.Runtime.CompilerServices;

namespace Orleans.Runtime;

/// <summary>
/// A striped dictionary that distributes entries across multiple internal dictionaries
/// to reduce lock contention. The stripe is determined by bits embedded in the CorrelationId.
/// </summary>
/// <typeparam name="TValue">The type of values stored in the dictionary.</typeparam>
internal sealed class StripedCallbackDictionary<TValue> : IEnumerable<KeyValuePair<CorrelationId, TValue>>
{
/// <summary>
/// The number of bits used to identify the stripe (stored in the upper bits of the CorrelationId).
/// </summary>
public const int StripeBits = 7;

/// <summary>
/// The number of stripes (must be a power of 2).
/// </summary>
public const int StripeCount = 1 << StripeBits; // 128 stripes

/// <summary>
/// Mask to extract the stripe index from the upper bits.
/// </summary>
private const long StripeMask = (long)(StripeCount - 1) << (64 - StripeBits);

/// <summary>
/// The shift amount to move the stripe bits to the lowest position.
/// </summary>
private const int StripeShift = 64 - StripeBits;

private readonly Stripe[] _stripes;

public StripedCallbackDictionary()
{
_stripes = new Stripe[StripeCount];
for (int i = 0; i < StripeCount; i++)
{
_stripes[i] = new Stripe();
}
}

/// <summary>
/// Encodes a stripe index into the upper bits of a base value to create a CorrelationId.
/// </summary>
/// <param name="baseValue">The base value (e.g., from an incrementing counter XORed with a seed).</param>
/// <param name="stripeIndex">The stripe index (typically derived from thread id).</param>
/// <returns>A CorrelationId with the stripe encoded in the upper bits.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static CorrelationId CreateCorrelationId(long baseValue, int stripeIndex)
{
// Clear the upper StripeBits of the base value and set the stripe index there
long maskedBase = baseValue & ~StripeMask;
long stripeValue = (long)(stripeIndex & (StripeCount - 1)) << StripeShift;
return new CorrelationId(maskedBase | stripeValue);
}

/// <summary>
/// Extracts the stripe index from a CorrelationId.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int GetStripeIndex(CorrelationId correlationId)
{
return (int)((correlationId.ToInt64() & StripeMask) >>> StripeShift);
}

/// <summary>
/// Gets the stripe index for the current thread. Use this when creating new CorrelationIds.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int GetCurrentThreadStripeIndex()
{
return Environment.CurrentManagedThreadId & (StripeCount - 1);
}

/// <summary>
/// Gets the stripe for the given correlation id.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private Stripe GetStripe(CorrelationId correlationId)
{
return _stripes[GetStripeIndex(correlationId)];
}

/// <summary>
/// Attempts to add the specified key and value to the dictionary.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool TryAdd(CorrelationId key, TValue value)
{
var stripe = GetStripe(key);
lock (stripe.Lock)
{
return stripe.Dictionary.TryAdd(key, value);
}
}

/// <summary>
/// Attempts to get the value associated with the specified key.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool TryGetValue(CorrelationId key, out TValue? value)
{
var stripe = GetStripe(key);
lock (stripe.Lock)
{
return stripe.Dictionary.TryGetValue(key, out value);
}
}

/// <summary>
/// Attempts to remove the value with the specified key.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool TryRemove(CorrelationId key, out TValue? value)
{
var stripe = GetStripe(key);
lock (stripe.Lock)
{
return stripe.Dictionary.Remove(key, out value);
}
}

/// <summary>
/// Gets the approximate total count of items across all stripes.
/// </summary>
public int Count
{
get
{
int count = 0;
foreach (var stripe in _stripes)
{
lock (stripe.Lock)
{
count += stripe.Dictionary.Count;
}
}
return count;
}
}

/// <summary>
/// Counts items matching a predicate across all stripes.
/// </summary>
public int CountWhere(Func<KeyValuePair<CorrelationId, TValue>, bool> predicate)
{
int count = 0;
foreach (var stripe in _stripes)
{
lock (stripe.Lock)
{
foreach (var kvp in stripe.Dictionary)
{
if (predicate(kvp))
{
count++;
}
}
}
}
Comment on lines +155 to +165
return count;
}

/// <summary>
/// Returns an enumerator that iterates through all items in all stripes.
/// Note: This takes a snapshot of each stripe under its lock.
/// </summary>
public Enumerator GetEnumerator() => new(this);

IEnumerator<KeyValuePair<CorrelationId, TValue>> IEnumerable<KeyValuePair<CorrelationId, TValue>>.GetEnumerator() => GetEnumerator();

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

private sealed class Stripe
{
public readonly object Lock = new();
public readonly Dictionary<CorrelationId, TValue> Dictionary = new();
}

public struct Enumerator : IEnumerator<KeyValuePair<CorrelationId, TValue>>
{
private readonly StripedCallbackDictionary<TValue> _dictionary;
private int _stripeIndex;
private List<KeyValuePair<CorrelationId, TValue>>? _currentSnapshot;
private int _snapshotIndex;

internal Enumerator(StripedCallbackDictionary<TValue> dictionary)
{
_dictionary = dictionary;
_stripeIndex = -1;
_currentSnapshot = null;
_snapshotIndex = -1;
}

public KeyValuePair<CorrelationId, TValue> Current => _currentSnapshot![_snapshotIndex];

object IEnumerator.Current => Current;

public bool MoveNext()
{
while (true)
{
// Try to advance within current snapshot
if (_currentSnapshot != null)
{
_snapshotIndex++;
if (_snapshotIndex < _currentSnapshot.Count)
{
return true;
}
}

// Move to next stripe
_stripeIndex++;
if (_stripeIndex >= _dictionary._stripes.Length)
{
_currentSnapshot = null;
return false;
}

// Take a snapshot of the next stripe
var stripe = _dictionary._stripes[_stripeIndex];
lock (stripe.Lock)
{
if (stripe.Dictionary.Count > 0)
{
_currentSnapshot = new List<KeyValuePair<CorrelationId, TValue>>(stripe.Dictionary);
_snapshotIndex = -1;
}
else
{
_currentSnapshot = null;
}
}
}
}

public void Reset()
{
_stripeIndex = -1;
_currentSnapshot = null;
_snapshotIndex = -1;
}

public void Dispose()
{
_currentSnapshot = null;
}
}
}
6 changes: 3 additions & 3 deletions src/Orleans.Core/Runtime/OutsideRuntimeClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ internal partial class OutsideRuntimeClient : IRuntimeClient, IDisposable, IClus
private readonly ILogger logger;
private readonly ClientMessagingOptions clientMessagingOptions;

private readonly ConcurrentDictionary<CorrelationId, CallbackData> callbacks;
private readonly StripedCallbackDictionary<CallbackData> callbacks;
private InvokableObjectManager localObjects;
Comment on lines 26 to 30
private bool disposing;
private bool disposed;
Expand Down Expand Up @@ -84,7 +84,7 @@ public OutsideRuntimeClient(
this.loggerFactory = loggerFactory;
this.messagingTrace = messagingTrace;
this.logger = loggerFactory.CreateLogger<OutsideRuntimeClient>();
callbacks = new ConcurrentDictionary<CorrelationId, CallbackData>();
callbacks = new StripedCallbackDictionary<CallbackData>();
this.clientMessagingOptions = clientMessagingOptions.Value;
var period = Max(
TimeSpan.FromMilliseconds(1),
Expand Down Expand Up @@ -434,7 +434,7 @@ public void BreakOutstandingMessagesToSilo(SiloAddress deadSilo)
}

public int GetRunningRequestsCount(GrainInterfaceType grainInterfaceType)
=> this.callbacks.Count(c => c.Value.Message.InterfaceType == grainInterfaceType);
=> this.callbacks.CountWhere(c => c.Value.Message.InterfaceType == grainInterfaceType);

/// <inheritdoc />
public void NotifyClusterConnectionLost()
Expand Down
22 changes: 10 additions & 12 deletions src/Orleans.Runtime/Core/InsideRuntimeClient.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -32,7 +30,7 @@ internal sealed partial class InsideRuntimeClient : IRuntimeClient, ILifecyclePa
private readonly ILogger invokeExceptionLogger;
private readonly ILoggerFactory loggerFactory;
private readonly SiloMessagingOptions messagingOptions;
private readonly ConcurrentDictionary<(GrainId, CorrelationId), CallbackData> callbacks;
private readonly StripedCallbackDictionary<CallbackData> callbacks;
private readonly InterfaceToImplementationMappingCache interfaceToImplementationMapping;
private readonly SharedCallbackData sharedCallbackData;
private readonly SharedCallbackData systemSharedCallbackData;
Expand Down Expand Up @@ -74,7 +72,7 @@ public InsideRuntimeClient(
this._applicationRequestInstruments = new(orleansInstruments);
this.ServiceProvider = serviceProvider;
this.MySilo = siloDetails.SiloAddress;
this.callbacks = new ConcurrentDictionary<(GrainId, CorrelationId), CallbackData>();
this.callbacks = new StripedCallbackDictionary<CallbackData>();
this.messageFactory = messageFactory;
this.ConcreteGrainFactory = new GrainFactory(this, referenceActivator, interfaceIdResolver, interfaceToTypeResolver);
this.logger = loggerFactory.CreateLogger<InsideRuntimeClient>();
Expand All @@ -88,15 +86,15 @@ public InsideRuntimeClient(

var callbackDataLogger = loggerFactory.CreateLogger<CallbackData>();
this.sharedCallbackData = new SharedCallbackData(
msg => this.UnregisterCallback(msg.SendingGrain, msg.Id),
msg => this.UnregisterCallback(msg.Id),
callbackDataLogger,
this.messagingOptions.ResponseTimeout,
this.messagingOptions.CancelRequestOnTimeout,
this.messagingOptions.WaitForCancellationAcknowledgement,
cancellationManager: null);

this.systemSharedCallbackData = new SharedCallbackData(
msg => this.UnregisterCallback(msg.SendingGrain, msg.Id),
msg => this.UnregisterCallback(msg.Id),
callbackDataLogger,
this.messagingOptions.SystemResponseTimeout,
cancelOnTimeout: false,
Expand Down Expand Up @@ -178,7 +176,7 @@ public void SendRequest(

// Register a callback for the request.
var callbackData = new CallbackData(sharedData, context, message, _applicationRequestInstruments);
callbacks.TryAdd((message.SendingGrain, message.Id), callbackData);
callbacks.TryAdd(message.Id, callbackData);
callbackData.SubscribeForCancellation(cancellationToken);
}
else
Expand Down Expand Up @@ -207,9 +205,9 @@ public void SendResponse(Message request, Response response)
/// <summary>
/// UnRegister a callback.
/// </summary>
private void UnregisterCallback(GrainId grainId, CorrelationId correlationId)
private void UnregisterCallback(CorrelationId correlationId)
{
callbacks.TryRemove((grainId, correlationId), out _);
callbacks.TryRemove(correlationId, out _);
}

public void SniffIncomingMessage(Message message)
Expand Down Expand Up @@ -422,7 +420,7 @@ public void ReceiveResponse(Message message)
else if (message.Result == Message.ResponseTypes.Status)
{
var status = (StatusResponse)message.BodyObject;
callbacks.TryGetValue((message.TargetGrain, message.Id), out var callback);
callbacks.TryGetValue(message.Id, out var callback);
var request = callback?.Message;
if (request is not null)
{
Expand Down Expand Up @@ -456,7 +454,7 @@ public void ReceiveResponse(Message message)
}

CallbackData callbackData;
bool found = callbacks.TryRemove((message.TargetGrain, message.Id), out callbackData);
bool found = callbacks.TryRemove(message.Id, out callbackData);
if (found)
{
// IMPORTANT: we do not schedule the response callback via the scheduler, since the only thing it does
Expand Down Expand Up @@ -545,7 +543,7 @@ public void Participate(ISiloLifecycle lifecycle)
}

public int GetRunningRequestsCount(GrainInterfaceType grainInterfaceType)
=> this.callbacks.Count(c => c.Value.Message.InterfaceType == grainInterfaceType);
=> this.callbacks.CountWhere(c => c.Value.Message.InterfaceType == grainInterfaceType);

private async Task MonitorCallbackExpiry()
{
Expand Down
Loading