diff --git a/src/Orleans.Serialization/Invocation/Pools/ConcurrentObjectPool.cs b/src/Orleans.Serialization/Invocation/Pools/ConcurrentObjectPool.cs index 3ef1f621c28..032c47e3ea0 100644 --- a/src/Orleans.Serialization/Invocation/Pools/ConcurrentObjectPool.cs +++ b/src/Orleans.Serialization/Invocation/Pools/ConcurrentObjectPool.cs @@ -1,4 +1,6 @@ +using System; using System.Collections.Generic; +using System.Runtime.CompilerServices; using System.Threading; using Microsoft.Extensions.ObjectPool; @@ -14,17 +16,36 @@ public ConcurrentObjectPool() : base(new()) internal class ConcurrentObjectPool : ObjectPool where T : class where TPoolPolicy : IPooledObjectPolicy { - private readonly ThreadLocal> _objects = new(() => new()); + private static int NextPoolId = -1; + private readonly int _poolId = Interlocked.Increment(ref NextPoolId); private readonly TPoolPolicy _policy; public ConcurrentObjectPool(TPoolPolicy policy) => _policy = policy; public int MaxPoolSize { get; set; } = int.MaxValue; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private Stack GetStack() + { + var poolId = _poolId; + var stacks = PerThreadStack.Stacks; + if (stacks is null) + { + stacks = PerThreadStack.Stacks = new Stack[poolId + 1]; + } + else if ((uint)poolId >= (uint)stacks.Length) + { + Array.Resize(ref stacks, Math.Max(poolId + 1, stacks.Length * 2)); + PerThreadStack.Stacks = stacks; + } + + return stacks[poolId] ??= new(); + } + public override T Get() { - var stack = _objects.Value; + var stack = GetStack(); if (stack.TryPop(out var result)) { return result; @@ -37,12 +58,19 @@ public override void Return(T obj) { if (_policy.Return(obj)) { - var stack = _objects.Value; + var stack = GetStack(); if (stack.Count < MaxPoolSize) { stack.Push(obj); } } } + + // Thread-static stacks are indexed by pool instance to avoid sharing state between pools with different policies. + private static class PerThreadStack + { + [ThreadStatic] + internal static Stack[] Stacks; + } } }