Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 61 additions & 6 deletions src/Orleans.CodeGenerator/ActivatorGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using System.Collections.Generic;
using System.Linq;

namespace Orleans.CodeGenerator
{
Expand All @@ -15,6 +16,7 @@ private struct ConstructorArgument
public TypeSyntax Type { get; set; }
public string FieldName { get; set; }
public string ParameterName { get; set; }
public bool IsPool { get; set; }
}

public ActivatorGenerator(CodeGenerator codeGenerator)
Expand All @@ -34,7 +36,9 @@ public ClassDeclarationSyntax GenerateActivator(ISerializableTypeDescription typ
{
foreach (var arg in parameters)
{
orderedFields.Add(new ConstructorArgument { Type = arg, FieldName = $"_arg{index}", ParameterName = $"arg{index}" });
// Detect if this is an InvokablePool<T> parameter
var isPool = arg is GenericNameSyntax gns && gns.Identifier.Text == "InvokablePool";
orderedFields.Add(new ConstructorArgument { Type = arg, FieldName = $"_arg{index}", ParameterName = $"arg{index}", IsPool = isPool });
index++;
}
}
Expand Down Expand Up @@ -80,11 +84,23 @@ private ConstructorDeclarationSyntax GenerateConstructor(
{
parameters.Add(Parameter(field.ParameterName.ToIdentifier()).WithType(field.Type));

body.Add(ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
field.FieldName.ToIdentifierName(),
Unwrapped(field.ParameterName.ToIdentifierName()))));
// Pool fields are not wrapped services, assign directly
if (field.IsPool)
{
body.Add(ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
field.FieldName.ToIdentifierName(),
field.ParameterName.ToIdentifierName())));
}
else
{
body.Add(ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
field.FieldName.ToIdentifierName(),
Unwrapped(field.ParameterName.ToIdentifierName()))));
}
}

var constructorDeclaration = ConstructorDeclaration(simpleClassName)
Expand All @@ -104,6 +120,45 @@ static ExpressionSyntax Unwrapped(ExpressionSyntax expr)

private MemberDeclarationSyntax GenerateCreateMethod(ISerializableTypeDescription type, List<ConstructorArgument> orderedFields)
{
// Check if this is a poolable invokable (has InvokablePool<T> as first constructor argument)
var poolField = orderedFields.FirstOrDefault(f => f.IsPool);

if (poolField.IsPool)
{
// Generate: _pool.TryGet(out var item) ? item : new T(_pool, ...otherArgs)
var argList = new List<ArgumentSyntax>();
foreach (var field in orderedFields)
{
argList.Add(Argument(field.FieldName.ToIdentifierName()));
}

var newExpression = ObjectCreationExpression(type.TypeSyntax)
.WithArgumentList(ArgumentList(SeparatedList(argList)));

// _pool.TryGet(out var item)
var tryGetCall = InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
poolField.FieldName.ToIdentifierName(),
IdentifierName("TryGet")),
ArgumentList(SingletonSeparatedList(
Argument(DeclarationExpression(
IdentifierName("var"),
SingleVariableDesignation(Identifier("item"))))
.WithRefKindKeyword(Token(SyntaxKind.OutKeyword)))));

// Conditional: tryGet ? item : new T(...)
var conditionalExpression = ConditionalExpression(
tryGetCall,
IdentifierName("item"),
newExpression);

return MethodDeclaration(type.TypeSyntax, "Create")
.WithExpressionBody(ArrowExpressionClause(conditionalExpression))
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
.AddModifiers(Token(SyntaxKind.PublicKeyword));
}

ExpressionSyntax createObject;
if (type.ActivatorConstructorParameters is { Count: > 0 })
{
Expand Down
2 changes: 1 addition & 1 deletion src/Orleans.CodeGenerator/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ bool ShouldIncludePrimaryConstructorParameters(INamedTypeSymbol t)
}));
}

var usings = List(new[] { UsingDirective(ParseName("global::Orleans.Serialization.Codecs")), UsingDirective(ParseName("global::Orleans.Serialization.GeneratedCodeHelpers")) });
var usings = List(new[] { UsingDirective(ParseName("global::Orleans.Serialization.Codecs")), UsingDirective(ParseName("global::Orleans.Serialization.GeneratedCodeHelpers")), UsingDirective(ParseName("global::Orleans.Serialization.Invocation")) });
var namespaces = new List<MemberDeclarationSyntax>(_namespacedMembers.Count);
foreach (var pair in _namespacedMembers)
{
Expand Down
96 changes: 92 additions & 4 deletions src/Orleans.CodeGenerator/InvokableGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ public GeneratedInvokableDescription Generate(InvokableMethodDescription invokab

var baseClassType = GetBaseClassType(invokableMethodInfo);
var fieldDescriptions = GetFieldDescriptions(invokableMethodInfo);
var fields = GetFieldDeclarations(invokableMethodInfo, fieldDescriptions);
var (ctor, ctorArgs) = GenerateConstructor(generatedClassName, invokableMethodInfo, baseClassType);

// Create the invokable type syntax for use in field declarations and constructor
var invokableTypeSyntax = CreateInvokableTypeSyntax(generatedClassName, invokableMethodInfo);

var fields = GetFieldDeclarations(invokableMethodInfo, fieldDescriptions, invokableTypeSyntax);
var (ctor, ctorArgs) = GenerateConstructor(generatedClassName, invokableMethodInfo, baseClassType, fieldDescriptions, invokableTypeSyntax);
var accessibility = GetAccessibility(method);
var compoundTypeAliases = GetCompoundTypeAliasAttributeArguments(invokableMethodInfo, invokableMethodInfo.Key);

Expand Down Expand Up @@ -588,6 +592,8 @@ private MemberDeclarationSyntax GenerateDisposeMethod(
INamedTypeSymbol baseClassType)
{
var body = new List<StatementSyntax>();
PoolFieldDescription poolField = null;

foreach (var field in fields)
{
if (field is CancellationTokenSourceFieldDescription ctsField)
Expand All @@ -602,6 +608,11 @@ private MemberDeclarationSyntax GenerateDisposeMethod(
MemberBindingExpression(IdentifierName("Dispose"))))));
}

if (field is PoolFieldDescription pf)
{
poolField = pf;
}

if (field.IsInstanceField)
{
body.Add(
Expand All @@ -621,6 +632,19 @@ private MemberDeclarationSyntax GenerateDisposeMethod(
body.Add(ExpressionStatement(InvocationExpression(BaseExpression().Member("Dispose")).WithArgumentList(ArgumentList())));
}

// C# _pool.Return(this); - return to pool at the very end
if (poolField != null)
{
body.Add(
ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(poolField.FieldName),
IdentifierName("Return")),
ArgumentList(SingletonSeparatedList(Argument(ThisExpression()))))));
}

return MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "Dispose")
.WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)))
.WithBody(Block(body));
Expand Down Expand Up @@ -698,9 +722,28 @@ public static string GetSimpleClassName(InvokableMethodDescription method)
return $"Invokable_{method.ContainingInterface.Name}_{proxyKey}_{method.GeneratedMethodId}{typeArgs}";
}

private static TypeSyntax CreateInvokableTypeSyntax(string generatedClassName, InvokableMethodDescription method)
{
if (method.AllTypeParameters.Count > 0)
{
// Generic invokable: ClassName<TArg0, TArg1, ...>
var typeArguments = method.AllTypeParameters.Select(p =>
(TypeSyntax)IdentifierName(method.TypeParameterSubstitutions[p.Parameter]));
return GenericName(
Identifier(generatedClassName),
TypeArgumentList(SeparatedList(typeArguments)));
}
else
{
// Non-generic invokable: ClassName
return IdentifierName(generatedClassName);
}
}

private MemberDeclarationSyntax[] GetFieldDeclarations(
InvokableMethodDescription method,
List<InvokerFieldDescription> fieldDescriptions)
List<InvokerFieldDescription> fieldDescriptions,
TypeSyntax invokableTypeSyntax)
{
return fieldDescriptions.Select(GetFieldDeclaration).ToArray();

Expand Down Expand Up @@ -728,6 +771,18 @@ MemberDeclarationSyntax GetFieldDeclaration(InvokerFieldDescription description)
}))))))))
.AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword));
}
else if (description is PoolFieldDescription)
{
// Pool field: InvokablePool<ThisInvokableType>
var poolType = GenericName(
Identifier("InvokablePool"),
TypeArgumentList(SingletonSeparatedList(invokableTypeSyntax)));
field = FieldDeclaration(
VariableDeclaration(
poolType,
SingletonSeparatedList(VariableDeclarator(description.FieldName))))
.AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword));
}
else
{
field = FieldDeclaration(
Expand Down Expand Up @@ -758,14 +813,33 @@ private ExpressionSyntax GetTypesArray(InvokableMethodDescription method, IEnume
private (ConstructorDeclarationSyntax Constructor, List<TypeSyntax> ConstructorArguments) GenerateConstructor(
string simpleClassName,
InvokableMethodDescription method,
INamedTypeSymbol baseClassType)
INamedTypeSymbol baseClassType,
List<InvokerFieldDescription> fieldDescriptions,
TypeSyntax invokableTypeSyntax)
{
var parameters = new List<ParameterSyntax>();

var body = new List<StatementSyntax>();

List<TypeSyntax> constructorArgumentTypes = new();
List<ArgumentSyntax> baseConstructorArguments = new();

// For non-generic methods, add pool parameter first
var poolField = fieldDescriptions.OfType<PoolFieldDescription>().FirstOrDefault();
if (poolField != null)
{
var poolType = GenericName(
Identifier("InvokablePool"),
TypeArgumentList(SingletonSeparatedList(invokableTypeSyntax)));
constructorArgumentTypes.Add(poolType);
parameters.Add(Parameter(Identifier("pool")).WithType(poolType));
body.Add(ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(poolField.FieldName),
IdentifierName("pool"))));
}

foreach (var constructor in baseClassType.GetAllMembers<IMethodSymbol>())
{
if (constructor.MethodKind != MethodKind.Constructor || constructor.DeclaredAccessibility == Accessibility.Private || constructor.IsImplicitlyDeclared)
Expand Down Expand Up @@ -831,6 +905,12 @@ private List<InvokerFieldDescription> GetFieldDescriptions(InvokableMethodDescri
fields.Add(new CancellationTokenSourceFieldDescription(LibraryTypes));
}

// Add pool field for non-generic methods (generic methods can't use pooling)
if (method.MethodTypeParameters.Count == 0)
{
fields.Add(new PoolFieldDescription(LibraryTypes));
}

return fields;
}

Expand Down Expand Up @@ -940,5 +1020,13 @@ public MethodInfoFieldDescription(ITypeSymbol fieldType, string fieldName) : bas
public override bool IsSerializable => false;
public override bool IsInstanceField => false;
}

internal sealed class PoolFieldDescription : InvokerFieldDescription
{
public PoolFieldDescription(LibraryTypes libraryTypes) : base(libraryTypes.InvokablePool_1, "_pool") { }

public override bool IsSerializable => false;
public override bool IsInstanceField => false; // Assigned via constructor, not reset on dispose
}
}
}
2 changes: 2 additions & 0 deletions src/Orleans.CodeGenerator/LibraryTypes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options)
GenerateSerializerAttribute = Type("Orleans.GenerateSerializerAttribute");
SerializationCallbacksAttribute = Type("Orleans.SerializationCallbacksAttribute");
IActivator_1 = Type("Orleans.Serialization.Activators.IActivator`1");
InvokablePool_1 = Type("Orleans.Serialization.Invocation.InvokablePool`1");
IBufferWriter = Type("System.Buffers.IBufferWriter`1");
IdAttributeType = Type(CodeGeneratorOptions.IdAttribute);
ConstructorAttributeTypes = CodeGeneratorOptions.ConstructorAttributes.Select(Type).ToArray();
Expand Down Expand Up @@ -215,6 +216,7 @@ INamedTypeSymbol Type(string metadataName)
public INamedTypeSymbol GenerateMethodSerializersAttribute { get; private set; }
public INamedTypeSymbol GenerateSerializerAttribute { get; private set; }
public INamedTypeSymbol IActivator_1 { get; private set; }
public INamedTypeSymbol InvokablePool_1 { get; private set; }
public INamedTypeSymbol IBufferWriter { get; private set; }
public INamedTypeSymbol IInvokable { get; private set; }
public INamedTypeSymbol ITargetHolder { get; private set; }
Expand Down
Loading
Loading