diff --git a/src/Orleans.Core.Abstractions/Core/Grain.cs b/src/Orleans.Core.Abstractions/Core/Grain.cs index 7fcb8deb06d..71b37594b07 100644 --- a/src/Orleans.Core.Abstractions/Core/Grain.cs +++ b/src/Orleans.Core.Abstractions/Core/Grain.cs @@ -14,6 +14,8 @@ namespace Orleans; /// public abstract partial class Grain : IGrainBase, IAddressable { + private IGrainRuntime? _grainRuntime; + // Do not use this directly because we currently don't provide a way to inject it; // any interaction with it will result in non unit-testable code. Any behavior that can be accessed // from within client code (including subclasses of this class), should be exposed through IGrainRuntime. @@ -23,7 +25,7 @@ public abstract partial class Grain : IGrainBase, IAddressable public GrainReference GrainReference { get { return GrainContext.GrainReference; } } - internal IGrainRuntime Runtime { get; } + internal IGrainRuntime Runtime => _grainRuntime ??= GetGrainRuntime()!; /// /// Gets an object which can be used to access other grains. Null if this grain is not associated with a Runtime, such as when created directly for unit testing. @@ -54,9 +56,7 @@ protected Grain() : this(RuntimeContext.Current!, grainRuntime: null) protected Grain(IGrainContext grainContext, IGrainRuntime? grainRuntime = null) { GrainContext = grainContext; - - // ! The runtime ensures that this is not null and Unit testing frameworks must make sure that this is not null. - Runtime = grainRuntime ?? grainContext?.ActivationServices.GetService()!; + _grainRuntime = grainRuntime; } /// @@ -163,6 +163,22 @@ protected void DelayDeactivation(TimeSpan timeSpan) /// A cancellation token which signals when deactivation should complete promptly. public virtual Task OnDeactivateAsync(DeactivationReason reason, CancellationToken cancellationToken) => Task.CompletedTask; + private IGrainRuntime? GetGrainRuntime() + { + var grainContext = GrainContext; + if (grainContext is null) + { + return null; + } + + if (grainContext.GetComponent(typeof(IGrainRuntime)) is IGrainRuntime grainRuntime) + { + return grainRuntime; + } + + return grainContext.ActivationServices.GetService(); + } + internal void EnsureRuntime() { if (Runtime == null) diff --git a/src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs b/src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs index 75f5a021175..55cfce7e0fd 100644 --- a/src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs +++ b/src/Orleans.Runtime/Catalog/GrainTypeSharedContext.cs @@ -109,6 +109,11 @@ private static TimeSpan GetCollectionAgeLimit(GrainType grainType, Type grainCla return (TComponent)Logger; } + if (typeof(TComponent) == typeof(IGrainRuntime) && Runtime is TComponent runtime) + { + return runtime; + } + if (_components is null) return default; _components.TryGetValue(typeof(TComponent), out var resultObj); return (TComponent?)resultObj; @@ -131,6 +136,11 @@ private static TimeSpan GetCollectionAgeLimit(GrainType grainType, Type grainCla return Logger; } + if (componentType == typeof(IGrainRuntime)) + { + return Runtime; + } + if (_components is null) return default; _components.TryGetValue(componentType, out var resultObj); return resultObj;