diff --git a/Directory.Build.targets b/Directory.Build.targets
index dbcbf3e886f..c0cbba7f573 100644
--- a/Directory.Build.targets
+++ b/Directory.Build.targets
@@ -14,6 +14,7 @@
diff --git a/src/Orleans.CodeGenerator/ActivatorGenerator.cs b/src/Orleans.CodeGenerator/ActivatorGenerator.cs
index 7acb8d8757d..22180717402 100644
--- a/src/Orleans.CodeGenerator/ActivatorGenerator.cs
+++ b/src/Orleans.CodeGenerator/ActivatorGenerator.cs
@@ -2,128 +2,137 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
-using System.Collections.Generic;
-namespace Orleans.CodeGenerator
+namespace Orleans.CodeGenerator;
+
+internal class ActivatorGenerator(IGeneratorServices generatorServices)
{
- internal class ActivatorGenerator
+ private readonly IGeneratorServices _generatorServices = generatorServices;
+
+ private struct ConstructorArgument
+ {
+ public TypeSyntax Type { get; set; }
+ public string FieldName { get; set; }
+ public string ParameterName { get; set; }
+ }
+
+ public ClassDeclarationSyntax GenerateActivator(ISerializableTypeDescription type)
{
- private readonly CodeGenerator _codeGenerator;
+ var simpleClassName = GetSimpleClassName(type);
+
+ var baseInterface = _generatorServices.LibraryTypes.IActivator_1.ToTypeSyntax(type.TypeSyntax);
- private struct ConstructorArgument
+ var orderedFields = new List();
+ var index = 0;
+ if (type.ActivatorConstructorParameters is { Count: > 0 } parameters)
{
- public TypeSyntax Type { get; set; }
- public string FieldName { get; set; }
- public string ParameterName { get; set; }
+ foreach (var arg in parameters)
+ {
+ orderedFields.Add(new ConstructorArgument { Type = arg, FieldName = $"_arg{index}", ParameterName = $"arg{index}" });
+ index++;
+ }
}
- public ActivatorGenerator(CodeGenerator codeGenerator)
+ var members = new List();
+ foreach (var field in orderedFields)
{
- _codeGenerator = codeGenerator;
+ members.Add(
+ FieldDeclaration(VariableDeclaration(field.Type, SingletonSeparatedList(VariableDeclarator(field.FieldName))))
+ .AddModifiers(
+ Token(SyntaxKind.PrivateKeyword),
+ Token(SyntaxKind.ReadOnlyKeyword)));
}
- public ClassDeclarationSyntax GenerateActivator(ISerializableTypeDescription type)
- {
- var simpleClassName = GetSimpleClassName(type);
+ if (orderedFields.Count > 0)
+ members.Add(GenerateConstructor(simpleClassName, orderedFields));
- var baseInterface = _codeGenerator.LibraryTypes.IActivator_1.ToTypeSyntax(type.TypeSyntax);
+ members.Add(GenerateCreateMethod(type, orderedFields));
- var orderedFields = new List();
- var index = 0;
- if (type.ActivatorConstructorParameters is { Count: > 0 } parameters)
- {
- foreach (var arg in parameters)
- {
- orderedFields.Add(new ConstructorArgument { Type = arg, FieldName = $"_arg{index}", ParameterName = $"arg{index}" });
- index++;
- }
- }
+ var classDeclaration = ClassDeclaration(simpleClassName)
+ .AddBaseListTypes(SimpleBaseType(baseInterface))
+ .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.SealedKeyword))
+ .AddAttributeLists(GeneratedCodeUtilities.GetGeneratedCodeAttributes())
+ .AddMembers([.. members]);
- var members = new List();
- foreach (var field in orderedFields)
- {
- members.Add(
- FieldDeclaration(VariableDeclaration(field.Type, SingletonSeparatedList(VariableDeclarator(field.FieldName))))
- .AddModifiers(
- Token(SyntaxKind.PrivateKeyword),
- Token(SyntaxKind.ReadOnlyKeyword)));
- }
-
- if (orderedFields.Count > 0)
- members.Add(GenerateConstructor(simpleClassName, orderedFields));
+ if (type.IsGenericType)
+ {
+ classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, type.TypeParameters);
+ }
- members.Add(GenerateCreateMethod(type, orderedFields));
+ return classDeclaration;
+ }
- var classDeclaration = ClassDeclaration(simpleClassName)
- .AddBaseListTypes(SimpleBaseType(baseInterface))
- .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.SealedKeyword))
- .AddAttributeLists(CodeGenerator.GetGeneratedCodeAttributes())
- .AddMembers(members.ToArray());
+ public static string GetSimpleClassName(ISerializableTypeDescription serializableType) => GetSimpleClassName(serializableType.Name);
- if (type.IsGenericType)
- {
- classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, type.TypeParameters);
- }
+ public static string GetSimpleClassName(string name) => $"Activator_{name}";
- return classDeclaration;
- }
-
- public static string GetSimpleClassName(ISerializableTypeDescription serializableType) => $"Activator_{serializableType.Name}";
+ ///
+ /// Determines whether an activator should be generated for the specified type.
+ ///
+ internal static bool ShouldGenerateActivator(ISerializableTypeDescription type)
+ {
+ return !type.IsAbstractType
+ && !type.IsEnumType
+ && (!type.IsValueType
+ && type.IsEmptyConstructable
+ && !type.UseActivator
+ && type is not GeneratedInvokableDescription
+ || type.HasActivatorConstructor);
+ }
- private ConstructorDeclarationSyntax GenerateConstructor(
- string simpleClassName,
- List orderedFields)
+ private static ConstructorDeclarationSyntax GenerateConstructor(
+ string simpleClassName,
+ List orderedFields)
+ {
+ var parameters = new List();
+ var body = new List();
+ foreach (var field in orderedFields)
{
- var parameters = new List();
- var body = new List();
- foreach (var field in orderedFields)
- {
- parameters.Add(Parameter(field.ParameterName.ToIdentifier()).WithType(field.Type));
+ parameters.Add(Parameter(field.ParameterName.ToIdentifier()).WithType(field.Type));
- body.Add(ExpressionStatement(
- AssignmentExpression(
- SyntaxKind.SimpleAssignmentExpression,
- field.FieldName.ToIdentifierName(),
- Unwrapped(field.ParameterName.ToIdentifierName()))));
- }
+ body.Add(ExpressionStatement(
+ AssignmentExpression(
+ SyntaxKind.SimpleAssignmentExpression,
+ field.FieldName.ToIdentifierName(),
+ Unwrapped(field.ParameterName.ToIdentifierName()))));
+ }
- var constructorDeclaration = ConstructorDeclaration(simpleClassName)
- .AddModifiers(Token(SyntaxKind.PublicKeyword))
- .AddParameterListParameters(parameters.ToArray())
- .AddBodyStatements(body.ToArray());
+ var constructorDeclaration = ConstructorDeclaration(simpleClassName)
+ .AddModifiers(Token(SyntaxKind.PublicKeyword))
+ .AddParameterListParameters([.. parameters])
+ .AddBodyStatements([.. body]);
- return constructorDeclaration;
+ return constructorDeclaration;
- static ExpressionSyntax Unwrapped(ExpressionSyntax expr)
- {
- return InvocationExpression(
- MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("OrleansGeneratedCodeHelper"), IdentifierName("UnwrapService")),
- ArgumentList(SeparatedList(new[] { Argument(ThisExpression()), Argument(expr) })));
- }
+ static ExpressionSyntax Unwrapped(ExpressionSyntax expr)
+ {
+ return InvocationExpression(
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("OrleansGeneratedCodeHelper"), IdentifierName("UnwrapService")),
+ ArgumentList(SeparatedList([Argument(ThisExpression()), Argument(expr)])));
}
+ }
- private MemberDeclarationSyntax GenerateCreateMethod(ISerializableTypeDescription type, List orderedFields)
+ private static MemberDeclarationSyntax GenerateCreateMethod(ISerializableTypeDescription type, List orderedFields)
+ {
+ ExpressionSyntax createObject;
+ if (type.ActivatorConstructorParameters is { Count: > 0 })
{
- ExpressionSyntax createObject;
- if (type.ActivatorConstructorParameters is { Count: > 0 })
- {
- var argList = new List();
- foreach (var field in orderedFields)
- {
- argList.Add(Argument(field.FieldName.ToIdentifierName()));
- }
-
- createObject = ObjectCreationExpression(type.TypeSyntax).WithArgumentList(ArgumentList(SeparatedList(argList)));
- }
- else
+ var argList = new List();
+ foreach (var field in orderedFields)
{
- createObject = type.GetObjectCreationExpression();
+ argList.Add(Argument(field.FieldName.ToIdentifierName()));
}
- return MethodDeclaration(type.TypeSyntax, "Create")
- .WithExpressionBody(ArrowExpressionClause(createObject))
- .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
- .AddModifiers(Token(SyntaxKind.PublicKeyword));
+ createObject = ObjectCreationExpression(type.TypeSyntax).WithArgumentList(ArgumentList(SeparatedList(argList)));
}
+ else
+ {
+ createObject = type.GetObjectCreationExpression();
+ }
+
+ return MethodDeclaration(type.TypeSyntax, "Create")
+ .WithExpressionBody(ArrowExpressionClause(createObject))
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
+ .AddModifiers(Token(SyntaxKind.PublicKeyword));
}
-}
\ No newline at end of file
+}
diff --git a/src/Orleans.CodeGenerator/ApplicationPartAttributeGenerator.cs b/src/Orleans.CodeGenerator/ApplicationPartAttributeGenerator.cs
index 5b40dc120fd..1a8be26ef38 100644
--- a/src/Orleans.CodeGenerator/ApplicationPartAttributeGenerator.cs
+++ b/src/Orleans.CodeGenerator/ApplicationPartAttributeGenerator.cs
@@ -1,29 +1,27 @@
-using System.Collections.Generic;
-using Orleans.CodeGenerator.SyntaxGeneration;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Orleans.CodeGenerator.SyntaxGeneration;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
-namespace Orleans.CodeGenerator
+namespace Orleans.CodeGenerator;
+
+internal static class ApplicationPartAttributeGenerator
{
- internal static class ApplicationPartAttributeGenerator
+ public static List GenerateSyntax(NameSyntax applicationPartAttribute, IEnumerable applicationParts)
{
- public static List GenerateSyntax(LibraryTypes wellKnownTypes, MetadataModel model)
- {
- var attributes = new List();
+ var attributes = new List();
- foreach (var assemblyName in model.ApplicationParts)
- {
- // Generate an assembly-level attribute with an instance of that class.
- var attribute = AttributeList(
- AttributeTargetSpecifier(Token(SyntaxKind.AssemblyKeyword)),
- SingletonSeparatedList(
- Attribute(wellKnownTypes.ApplicationPartAttribute.ToNameSyntax())
- .AddArgumentListArguments(AttributeArgument(assemblyName.GetLiteralExpression()))));
- attributes.Add(attribute);
- }
-
- return attributes;
+ foreach (var assemblyName in applicationParts)
+ {
+ // Generate an assembly-level attribute with an instance of that class.
+ var attribute = AttributeList(
+ AttributeTargetSpecifier(Token(SyntaxKind.AssemblyKeyword)),
+ SingletonSeparatedList(
+ Attribute(applicationPartAttribute)
+ .AddArgumentListArguments(AttributeArgument(assemblyName.GetLiteralExpression()))));
+ attributes.Add(attribute);
}
+
+ return attributes;
}
}
diff --git a/src/Orleans.CodeGenerator/CodeGenerator.cs b/src/Orleans.CodeGenerator/CodeGenerator.cs
deleted file mode 100644
index 93091f3839b..00000000000
--- a/src/Orleans.CodeGenerator/CodeGenerator.cs
+++ /dev/null
@@ -1,829 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Collections.Immutable;
-using System.Linq;
-using System.Text;
-using System.Threading;
-using Microsoft.CodeAnalysis;
-using Microsoft.CodeAnalysis.CSharp;
-using Microsoft.CodeAnalysis.CSharp.Syntax;
-using Orleans.CodeGenerator.Diagnostics;
-using Orleans.CodeGenerator.Hashing;
-using Orleans.CodeGenerator.Model;
-using Orleans.CodeGenerator.SyntaxGeneration;
-using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
-using static Orleans.CodeGenerator.SyntaxGeneration.SymbolExtensions;
-
-#nullable disable
-namespace Orleans.CodeGenerator
-{
- public class CodeGeneratorOptions
- {
- public const string IdAttribute = "Orleans.IdAttribute";
- public const string AliasAttribute = "Orleans.AliasAttribute";
- public const string ImmutableAttribute = "Orleans.ImmutableAttribute";
- public static readonly IReadOnlyList ConstructorAttributes = ["Orleans.OrleansConstructorAttribute", "Microsoft.Extensions.DependencyInjection.ActivatorUtilitiesConstructorAttribute"];
- public GenerateFieldIds GenerateFieldIds { get; set; }
- public bool GenerateCompatibilityInvokers { get; set; }
- }
-
- public class CodeGenerator
- {
- internal const string CodeGeneratorName = "OrleansCodeGen";
- private readonly Dictionary> _namespacedMembers = new();
- private readonly Dictionary _invokableMethodDescriptions = new();
- private readonly HashSet _visitedInterfaces = new(SymbolEqualityComparer.Default);
- private readonly List DisabledWarnings = new() { "CS1591", "RS0016", "RS0041" };
-
- public CodeGenerator(Compilation compilation, CodeGeneratorOptions options)
- {
- Compilation = compilation;
- Options = options;
- LibraryTypes = LibraryTypes.FromCompilation(compilation, options);
- MetadataModel = new MetadataModel();
- CopierGenerator = new CopierGenerator(this);
- SerializerGenerator = new SerializerGenerator(this);
- ProxyGenerator = new ProxyGenerator(this);
- InvokableGenerator = new InvokableGenerator(this);
- MetadataGenerator = new MetadataGenerator(this);
- ActivatorGenerator = new ActivatorGenerator(this);
- }
-
- public Compilation Compilation { get; }
- public CodeGeneratorOptions Options { get; }
- internal LibraryTypes LibraryTypes { get; }
- internal MetadataModel MetadataModel { get; }
- internal CopierGenerator CopierGenerator { get; }
- internal SerializerGenerator SerializerGenerator { get; }
- internal ProxyGenerator ProxyGenerator { get; }
- internal InvokableGenerator InvokableGenerator { get; }
- internal MetadataGenerator MetadataGenerator { get; }
- internal ActivatorGenerator ActivatorGenerator { get; }
-
- public CompilationUnitSyntax GenerateCode(CancellationToken cancellationToken)
- {
- var assembliesToExamine = new HashSet(SymbolEqualityComparer.Default);
- var compilationAsm = LibraryTypes.Compilation.Assembly;
- ComputeAssembliesToExamine(compilationAsm, assembliesToExamine);
-
- // Expand the set of referenced assemblies
- MetadataModel.ApplicationParts.Add(compilationAsm.MetadataName);
- foreach (var reference in LibraryTypes.Compilation.References)
- {
- if (LibraryTypes.Compilation.GetAssemblyOrModuleSymbol(reference) is not IAssemblySymbol asm)
- {
- continue;
- }
-
- if (asm.GetAttributes(LibraryTypes.ApplicationPartAttribute, out var attrs))
- {
- MetadataModel.ApplicationParts.Add(asm.MetadataName);
- foreach (var attr in attrs)
- {
- MetadataModel.ApplicationParts.Add((string)attr.ConstructorArguments.First().Value);
- }
- }
- }
-
- // The mapping of proxy base types to a mapping of return types to invokable base types. Used to set default invokable base types for each proxy base type.
- var proxyBaseTypeInvokableBaseTypes = new Dictionary>(SymbolEqualityComparer.Default);
-
- foreach (var asm in assembliesToExamine)
- {
- var containingAssemblyAttributes = asm.GetAttributes();
-
- foreach (var symbol in asm.GetDeclaredTypes())
- {
- if (GetWellKnownTypeId(symbol) is uint wellKnownTypeId)
- {
- MetadataModel.WellKnownTypeIds.Add((symbol.ToOpenTypeSyntax(), wellKnownTypeId));
- }
-
- if (GetAlias(symbol) is string typeAlias)
- {
- MetadataModel.TypeAliases.Add((symbol.ToOpenTypeSyntax(), typeAlias));
- }
-
- if (GetCompoundTypeAlias(symbol) is CompoundTypeAliasComponent[] compoundTypeAlias)
- {
- MetadataModel.CompoundTypeAliases.Add(compoundTypeAlias, symbol.ToOpenTypeSyntax());
- }
-
- if (FSharpUtilities.IsUnionCase(LibraryTypes, symbol, out var sumType) && ShouldGenerateSerializer(sumType))
- {
- if (!Compilation.IsSymbolAccessibleWithin(sumType, Compilation.Assembly))
- {
- throw new OrleansGeneratorDiagnosticAnalysisException(InaccessibleSerializableTypeDiagnostic.CreateDiagnostic(sumType));
- }
-
- var typeDescription = new FSharpUtilities.FSharpUnionCaseTypeDescription(Compilation, symbol, LibraryTypes);
- MetadataModel.SerializableTypes.Add(typeDescription);
- }
- else if (ShouldGenerateSerializer(symbol))
- {
- // https://learn.microsoft.com/dotnet/api/system.runtime.compilerservices.referenceassemblyattribute
- if (containingAssemblyAttributes.Any(attributeData => attributeData.AttributeClass is
- {
- Name: "ReferenceAssemblyAttribute",
- ContainingNamespace:
- {
- Name: "CompilerServices",
- ContainingNamespace:
- {
- Name: "Runtime",
- ContainingNamespace:
- {
- Name: "System",
- ContainingNamespace.IsGlobalNamespace: true
- }
- }
- }
- }))
- {
- // not ALWAYS will be properly processed, therefore emit a warning
- throw new OrleansGeneratorDiagnosticAnalysisException(ReferenceAssemblyWithGenerateSerializerDiagnostic.CreateDiagnostic(symbol));
- }
-
- if (!Compilation.IsSymbolAccessibleWithin(symbol, Compilation.Assembly))
- {
- throw new OrleansGeneratorDiagnosticAnalysisException(InaccessibleSerializableTypeDiagnostic.CreateDiagnostic(symbol));
- }
-
- if (FSharpUtilities.IsRecord(LibraryTypes, symbol))
- {
- var typeDescription = new FSharpUtilities.FSharpRecordTypeDescription(Compilation, symbol, LibraryTypes);
- MetadataModel.SerializableTypes.Add(typeDescription);
- }
- else
- {
- // Regular type
- var includePrimaryConstructorParameters = ShouldIncludePrimaryConstructorParameters(symbol);
- var constructorParameters = ImmutableArray.Empty;
- if (includePrimaryConstructorParameters)
- {
- if (symbol.IsRecord)
- {
- // If there is a primary constructor then that will be declared before the copy constructor
- // A record always generates a copy constructor and marks it as compiler generated
- // todo: find an alternative to this magic
- var potentialPrimaryConstructor = symbol.Constructors[0];
- if (!potentialPrimaryConstructor.IsImplicitlyDeclared && !potentialPrimaryConstructor.IsCompilerGenerated())
- {
- constructorParameters = potentialPrimaryConstructor.Parameters;
- }
- }
- else
- {
- var annotatedConstructors = symbol.Constructors.Where(ctor => ctor.HasAnyAttribute(LibraryTypes.ConstructorAttributeTypes)).ToList();
- if (annotatedConstructors.Count == 1)
- {
- constructorParameters = annotatedConstructors[0].Parameters;
- }
- else
- {
- // record structs from referenced assemblies do not return IsRecord=true
- // above. See https://github.com/dotnet/roslyn/issues/69326
- // So we implement the same heuristics from ShouldIncludePrimaryConstructorParameters
- // to detect a primary constructor.
- var properties = symbol.GetMembers().OfType().ToImmutableArray();
- var primaryConstructor = symbol.GetMembers()
- .OfType()
- .Where(m => m.MethodKind == MethodKind.Constructor && m.Parameters.Length > 0)
- // Check for a ctor where all parameters have a corresponding compiler-generated prop.
- .FirstOrDefault(ctor => ctor.Parameters.All(prm =>
- properties.Any(prop => prop.Name.Equals(prm.Name, StringComparison.Ordinal) && prop.IsCompilerGenerated())));
-
- if (primaryConstructor != null)
- constructorParameters = primaryConstructor.Parameters;
- }
- }
- }
-
- var implicitMemberSelectionStrategy = (Options.GenerateFieldIds, GetGenerateFieldIdsOptionFromType(symbol)) switch
- {
- (_, GenerateFieldIds.PublicProperties) => GenerateFieldIds.PublicProperties,
- (GenerateFieldIds.PublicProperties, _) => GenerateFieldIds.PublicProperties,
- _ => GenerateFieldIds.None
- };
- var fieldIdAssignmentHelper = new FieldIdAssignmentHelper(symbol, constructorParameters, implicitMemberSelectionStrategy, LibraryTypes);
- if (!fieldIdAssignmentHelper.IsValidForSerialization)
- {
- throw new OrleansGeneratorDiagnosticAnalysisException(CanNotGenerateImplicitFieldIdsDiagnostic.CreateDiagnostic(symbol, fieldIdAssignmentHelper.FailureReason));
- }
-
- var typeDescription = new SerializableTypeDescription(Compilation, symbol, includePrimaryConstructorParameters, GetDataMembers(fieldIdAssignmentHelper), LibraryTypes);
- MetadataModel.SerializableTypes.Add(typeDescription);
- }
- }
-
- if (symbol.TypeKind == TypeKind.Interface)
- {
- VisitInterface(symbol.OriginalDefinition);
- }
-
- if ((symbol.TypeKind == TypeKind.Class || symbol.TypeKind == TypeKind.Struct)
- && !symbol.IsAbstract
- && (symbol.DeclaredAccessibility == Accessibility.Public || symbol.DeclaredAccessibility == Accessibility.Internal))
- {
- if (symbol.HasAttribute(LibraryTypes.RegisterSerializerAttribute))
- {
- MetadataModel.DetectedSerializers.Add(symbol);
- }
-
- if (symbol.HasAttribute(LibraryTypes.RegisterActivatorAttribute))
- {
- MetadataModel.DetectedActivators.Add(symbol);
- }
-
- if (symbol.HasAttribute(LibraryTypes.RegisterCopierAttribute))
- {
- MetadataModel.DetectedCopiers.Add(symbol);
- }
-
- if (symbol.HasAttribute(LibraryTypes.RegisterConverterAttribute))
- {
- MetadataModel.DetectedConverters.Add(symbol);
- }
-
- // Find all implementations of invokable interfaces
- foreach (var iface in symbol.AllInterfaces)
- {
- var attribute = iface.GetAttribute(
- LibraryTypes.GenerateMethodSerializersAttribute,
- inherited: true);
- if (attribute != null)
- {
- MetadataModel.InvokableInterfaceImplementations.Add(symbol);
- break;
- }
- }
- }
-
- GenerateFieldIds GetGenerateFieldIdsOptionFromType(INamedTypeSymbol t)
- {
- var attribute = t.GetAttribute(LibraryTypes.GenerateSerializerAttribute);
- if (attribute is null)
- return GenerateFieldIds.None;
-
- foreach (var namedArgument in attribute.NamedArguments)
- {
- if (namedArgument.Key == "GenerateFieldIds")
- {
- var value = namedArgument.Value.Value;
- return value == null ? GenerateFieldIds.None : (GenerateFieldIds)(int)value;
- }
- }
- return GenerateFieldIds.None;
- }
-
- bool ShouldGenerateSerializer(INamedTypeSymbol t) => t.HasAttribute(LibraryTypes.GenerateSerializerAttribute);
-
- bool ShouldIncludePrimaryConstructorParameters(INamedTypeSymbol t)
- {
- static bool? TestGenerateSerializerAttribute(INamedTypeSymbol t, INamedTypeSymbol at)
- {
- var attribute = t.GetAttribute(at);
- if (attribute != null)
- {
- foreach (var namedArgument in attribute.NamedArguments)
- {
- if (namedArgument.Key == "IncludePrimaryConstructorParameters")
- {
- if (namedArgument.Value.Kind == TypedConstantKind.Primitive && namedArgument.Value.Value is bool b)
- {
- return b;
- }
- }
- }
- }
-
- // If there is no such named argument, return null so that other attributes have a chance to apply and defaults can be applied.
- return null;
- }
-
- if (TestGenerateSerializerAttribute(t, LibraryTypes.GenerateSerializerAttribute) is bool res)
- {
- return res;
- }
-
- // Default to true for records.
- if (t.IsRecord)
- return true;
-
- var properties = t.GetMembers().OfType().ToImmutableArray();
-
- return t.GetMembers()
- .OfType()
- .Where(m => m.MethodKind == MethodKind.Constructor && m.Parameters.Length > 0)
- // Check for a ctor where all parameters have a corresponding compiler-generated prop.
- .Any(ctor => ctor.Parameters.All(prm =>
- properties.Any(prop => prop.Name.Equals(prm.Name, StringComparison.Ordinal) && prop.IsCompilerGenerated())));
- }
- }
- }
-
- // Generate serializers.
- foreach (var type in MetadataModel.SerializableTypes)
- {
- string ns = type.GeneratedNamespace;
-
- // Generate a partial serializer class for each serializable type.
- var serializer = SerializerGenerator.Generate(type);
- AddMember(ns, serializer);
-
- // Generate a copier for each serializable type.
- if (CopierGenerator.GenerateCopier(type, MetadataModel.DefaultCopiers) is { } copier)
- AddMember(ns, copier);
-
- if (!type.IsAbstractType && !type.IsEnumType && (!type.IsValueType && type.IsEmptyConstructable && !type.UseActivator && type is not GeneratedInvokableDescription || type.HasActivatorConstructor))
- {
- MetadataModel.ActivatableTypes.Add(type);
-
- // Generate an activator class for types with default constructor or activator constructor.
- var activator = ActivatorGenerator.GenerateActivator(type);
- AddMember(ns, activator);
- }
- }
-
- // Generate metadata.
- var metadataClassNamespace = CodeGeneratorName + "." + SyntaxGeneration.Identifier.SanitizeIdentifierName(Compilation.AssemblyName);
- var metadataClass = MetadataGenerator.GenerateMetadata();
- AddMember(ns: metadataClassNamespace, member: metadataClass);
- var metadataAttribute = AttributeList()
- .WithTarget(AttributeTargetSpecifier(Token(SyntaxKind.AssemblyKeyword)))
- .WithAttributes(
- SingletonSeparatedList(
- Attribute(LibraryTypes.TypeManifestProviderAttribute.ToNameSyntax())
- .AddArgumentListArguments(AttributeArgument(TypeOfExpression(QualifiedName(IdentifierName(metadataClassNamespace), IdentifierName(metadataClass.Identifier.Text)))))));
-
- var assemblyAttributes = ApplicationPartAttributeGenerator.GenerateSyntax(LibraryTypes, MetadataModel);
- assemblyAttributes.Add(metadataAttribute);
-
- if (assemblyAttributes.Count > 0)
- {
- assemblyAttributes[0] = assemblyAttributes[0]
- .WithLeadingTrivia(
- SyntaxFactory.TriviaList(
- new List
- {
- Trivia(
- PragmaWarningDirectiveTrivia(
- Token(SyntaxKind.DisableKeyword),
- SeparatedList(DisabledWarnings.Select(str =>
- {
- var syntaxToken = SyntaxFactory.Literal(
- SyntaxFactory.TriviaList(),
- str,
- str,
- SyntaxFactory.TriviaList());
-
- return (ExpressionSyntax)SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, syntaxToken);
- })),
- isActive: true)),
- }));
- }
-
- var usings = List(new[] { UsingDirective(ParseName("global::Orleans.Serialization.Codecs")), UsingDirective(ParseName("global::Orleans.Serialization.GeneratedCodeHelpers")) });
- var namespaces = new List(_namespacedMembers.Count);
- foreach (var pair in _namespacedMembers)
- {
- var ns = pair.Key;
- var member = pair.Value;
-
- namespaces.Add(NamespaceDeclaration(ParseName(ns)).WithMembers(List(member)).WithUsings(usings));
- }
-
- if (namespaces.Count > 0)
- {
- namespaces[namespaces.Count - 1] = namespaces[namespaces.Count - 1]
- .WithTrailingTrivia(
- SyntaxFactory.TriviaList(
- new List
- {
- Trivia(
- PragmaWarningDirectiveTrivia(
- Token(SyntaxKind.RestoreKeyword),
- SeparatedList(DisabledWarnings.Select(str =>
- {
- var syntaxToken = SyntaxFactory.Literal(
- SyntaxFactory.TriviaList(),
- str,
- str,
- SyntaxFactory.TriviaList());
-
- return (ExpressionSyntax)SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, syntaxToken);
- })),
- isActive: true)),
- }));
- }
-
- return CompilationUnit()
- .WithAttributeLists(List(assemblyAttributes))
- .WithMembers(List(namespaces));
- }
-
- public static string GetGeneratedNamespaceName(ITypeSymbol type) => type.GetNamespaceAndNesting() switch
- {
- { Length: > 0 } ns => $"{CodeGeneratorName}.{ns}",
- _ => CodeGeneratorName
- };
-
- public void AddMember(string ns, MemberDeclarationSyntax member)
- {
- if (!_namespacedMembers.TryGetValue(ns, out var existing))
- {
- existing = _namespacedMembers[ns] = new List();
- }
-
- existing.Add(member);
- }
-
- private void ComputeAssembliesToExamine(IAssemblySymbol asm, HashSet expandedAssemblies)
- {
- if (!expandedAssemblies.Add(asm))
- {
- return;
- }
-
- if (!asm.GetAttributes(LibraryTypes.GenerateCodeForDeclaringAssemblyAttribute, out var attrs)) return;
-
- foreach (var attr in attrs)
- {
- var param = attr.ConstructorArguments.First();
- if (param.Kind != TypedConstantKind.Type)
- {
- throw new ArgumentException($"Unrecognized argument type in attribute [{attr.AttributeClass.Name}({param.ToCSharpString()})]");
- }
-
- var type = (ITypeSymbol)param.Value;
-
- // Recurse on the assemblies which the type was declared in.
- var declaringAsm = type.OriginalDefinition.ContainingAssembly;
- if (declaringAsm is null)
- {
- var diagnostic = GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly_Diagnostic.CreateDiagnostic(attr, type);
- throw new OrleansGeneratorDiagnosticAnalysisException(diagnostic);
- }
- else
- {
- ComputeAssembliesToExamine(declaringAsm, expandedAssemblies);
- }
- }
- }
-
- // Returns descriptions of all data members (fields and properties)
- private static IEnumerable GetDataMembers(FieldIdAssignmentHelper fieldIdAssignmentHelper)
- {
- var members = new Dictionary<(uint, bool), IMemberDescription>();
-
- foreach (var member in fieldIdAssignmentHelper.Members)
- {
- if (!fieldIdAssignmentHelper.TryGetSymbolKey(member, out var key))
- continue;
- var (id, isConstructorParameter) = key;
-
- // FieldDescription takes precedence over PropertyDescription (never replace)
- if (member is IPropertySymbol property && !members.TryGetValue((id, isConstructorParameter), out _))
- {
- members[(id, isConstructorParameter)] = new PropertyDescription(id, isConstructorParameter, property);
- }
-
- if (member is IFieldSymbol field)
- {
- // FieldDescription takes precedence over PropertyDescription (add or replace)
- if (!members.TryGetValue((id, isConstructorParameter), out var existing) || existing is PropertyDescription)
- {
- members[(id, isConstructorParameter)] = new FieldDescription(id, isConstructorParameter, field);
- }
- }
- }
- return members.Values;
- }
-
- public uint? GetId(ISymbol memberSymbol) => GetId(LibraryTypes, memberSymbol);
-
- internal static uint? GetId(LibraryTypes libraryTypes, ISymbol memberSymbol)
- {
- return memberSymbol.GetAttribute(libraryTypes.IdAttributeType) is { } attr
- ? (uint)attr.ConstructorArguments.First().Value
- : null;
- }
-
- internal static string CreateHashedMethodId(IMethodSymbol methodSymbol)
- {
- var methodSignature = Format(methodSymbol);
- var hash = XxHash32.Hash(Encoding.UTF8.GetBytes(methodSignature));
- return $"{HexConverter.ToString(hash)}";
-
- static string Format(IMethodSymbol methodInfo)
- {
- var result = new StringBuilder();
- result.Append(methodInfo.ContainingType.ToDisplayName());
- result.Append('.');
- result.Append(methodInfo.Name);
-
- if (methodInfo.IsGenericMethod)
- {
- result.Append('<');
- var first = true;
- foreach (var typeArgument in methodInfo.TypeArguments)
- {
- if (!first) result.Append(',');
- else first = false;
- result.Append(typeArgument.Name);
- }
-
- result.Append('>');
- }
-
- {
- result.Append('(');
- var parameters = methodInfo.Parameters;
- var first = true;
- foreach (var parameter in parameters)
- {
- if (!first)
- {
- result.Append(',');
- }
-
- var parameterType = parameter.Type;
- switch (parameterType)
- {
- case ITypeParameterSymbol _:
- result.Append(parameterType.Name);
- break;
- default:
- result.Append(parameterType.ToDisplayName());
- break;
- }
-
- first = false;
- }
- }
-
- result.Append(')');
- return result.ToString();
- }
- }
-
- private uint? GetWellKnownTypeId(ISymbol symbol) => GetId(symbol);
-
- public string GetAlias(ISymbol symbol) => (string)symbol.GetAttribute(LibraryTypes.AliasAttribute)?.ConstructorArguments.First().Value;
-
- private CompoundTypeAliasComponent[] GetCompoundTypeAlias(ISymbol symbol)
- {
- var attr = symbol.GetAttribute(LibraryTypes.CompoundTypeAliasAttribute);
- if (attr is null)
- {
- return null;
- }
-
- var allArgs = attr.ConstructorArguments;
- if (allArgs.Length != 1 || allArgs[0].Values.Length == 0)
- {
- throw new ArgumentException($"Unsupported arguments in attribute [{attr.AttributeClass.Name}({string.Join(", ", allArgs.Select(a => a.ToCSharpString()))})]");
- }
-
- var args = allArgs[0].Values;
- var result = new CompoundTypeAliasComponent[args.Length];
- for (var i = 0; i < args.Length; i++)
- {
- var arg = args[i];
- if (arg.IsNull)
- {
- throw new ArgumentNullException($"Unsupported null argument in attribute [{attr.AttributeClass.Name}({string.Join(", ", allArgs.Select(a => a.ToCSharpString()))})]");
- }
-
- result[i] = arg.Value switch
- {
- ITypeSymbol type => new CompoundTypeAliasComponent(type),
- string str => new CompoundTypeAliasComponent(str),
- _ => throw new ArgumentException($"Unrecognized argument type for argument {arg.ToCSharpString()} in attribute [{attr.AttributeClass.Name}({string.Join(", ", allArgs.Select(a => a.ToCSharpString()))})]"),
- };
- }
-
- return result;
- }
-
- internal static AttributeListSyntax GetGeneratedCodeAttributes() => GeneratedCodeAttributeSyntax;
- private static readonly AttributeListSyntax GeneratedCodeAttributeSyntax =
- AttributeList().AddAttributes(
- Attribute(ParseName("global::System.CodeDom.Compiler.GeneratedCodeAttribute"))
- .AddArgumentListArguments(
- AttributeArgument(CodeGeneratorName.GetLiteralExpression()),
- AttributeArgument(typeof(CodeGenerator).Assembly.GetName().Version.ToString().GetLiteralExpression())),
- Attribute(ParseName("global::System.ComponentModel.EditorBrowsableAttribute"))
- .AddArgumentListArguments(
- AttributeArgument(ParseName("global::System.ComponentModel.EditorBrowsableState").Member("Never"))),
- Attribute(ParseName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute"))
- );
-
- internal static AttributeSyntax GetMethodImplAttributeSyntax() => MethodImplAttributeSyntax;
- private static readonly AttributeSyntax MethodImplAttributeSyntax =
- Attribute(ParseName("global::System.Runtime.CompilerServices.MethodImplAttribute"))
- .AddArgumentListArguments(AttributeArgument(ParseName("global::System.Runtime.CompilerServices.MethodImplOptions").Member("AggressiveInlining")));
-
- internal void VisitInterface(INamedTypeSymbol interfaceType)
- {
- // Get or generate an invokable for the original method definition.
- if (!SymbolEqualityComparer.Default.Equals(interfaceType, interfaceType.OriginalDefinition))
- {
- interfaceType = interfaceType.OriginalDefinition;
- }
-
- if (!_visitedInterfaces.Add(interfaceType))
- {
- return;
- }
-
- foreach (var proxyBase in GetProxyBases(interfaceType))
- {
- _ = GetInvokableInterfaceDescription(proxyBase.ProxyBaseType, interfaceType);
- }
-
- /*
- foreach (var baseInterface in interfaceType.AllInterfaces)
- {
- VisitInterface(baseInterface);
- }
- */
- }
-
- internal bool TryGetInvokableInterfaceDescription(INamedTypeSymbol interfaceType, out ProxyInterfaceDescription result)
- {
- if (!TryGetProxyBaseDescription(interfaceType, out var description))
- {
- result = null;
- return false;
- }
-
- result = GetInvokableInterfaceDescription(description.ProxyBaseType, interfaceType);
- return true;
- }
-
- private readonly Dictionary> _interfaceProxyBases = new(SymbolEqualityComparer.Default);
- internal List GetProxyBases(INamedTypeSymbol interfaceType)
- {
- if (_interfaceProxyBases.TryGetValue(interfaceType, out var result))
- {
- return result;
- }
-
- result = new List();
- if (interfaceType.GetAttributes(LibraryTypes.GenerateMethodSerializersAttribute, out var attributes, inherited: true))
- {
- foreach (var attribute in attributes)
- {
- var proxyBase = GetProxyBaseDescription(attribute);
- if (!result.Contains(proxyBase))
- {
- result.Add(proxyBase);
- }
- }
- }
-
- return result;
- }
-
- internal bool TryGetProxyBaseDescription(INamedTypeSymbol interfaceType, out InvokableMethodProxyBase result)
- {
- var attribute = interfaceType.GetAttribute(LibraryTypes.GenerateMethodSerializersAttribute, inherited: true);
- if (attribute == null)
- {
- result = null;
- return false;
- }
-
- result = GetProxyBaseDescription(attribute);
- return true;
- }
-
- private InvokableMethodProxyBase GetProxyBaseDescription(AttributeData attribute)
- {
- var proxyBaseType = ((INamedTypeSymbol)attribute.ConstructorArguments[0].Value).OriginalDefinition;
- var isExtension = (bool)attribute.ConstructorArguments[1].Value;
- var invokableBaseTypes = GetInvokableBaseTypes(proxyBaseType);
- var descriptor = new InvokableMethodProxyBaseId(proxyBaseType, isExtension);
- var description = new InvokableMethodProxyBase(this, descriptor, invokableBaseTypes);
- return description;
-
- Dictionary GetInvokableBaseTypes(INamedTypeSymbol baseClass)
- {
- // Set the base invokable types which are used if attributes on individual methods do not override them.
- if (!MetadataModel.ProxyBaseTypeInvokableBaseTypes.TryGetValue(baseClass, out var invokableBaseTypes))
- {
- invokableBaseTypes = new Dictionary(SymbolEqualityComparer.Default);
- if (baseClass.GetAttributes(LibraryTypes.DefaultInvokableBaseTypeAttribute, out var invokableBaseTypeAttributes))
- {
- foreach (var attr in invokableBaseTypeAttributes)
- {
- var ctorArgs = attr.ConstructorArguments;
- var returnType = (INamedTypeSymbol)ctorArgs[0].Value;
- var invokableBaseType = (INamedTypeSymbol)ctorArgs[1].Value;
- invokableBaseTypes[returnType] = invokableBaseType;
- }
- }
-
- MetadataModel.ProxyBaseTypeInvokableBaseTypes[baseClass] = invokableBaseTypes;
- }
-
- return invokableBaseTypes;
- }
- }
-
- internal InvokableMethodProxyBase GetProxyBase(INamedTypeSymbol interfaceType)
- {
- if (!TryGetProxyBaseDescription(interfaceType, out var result))
- {
- throw new InvalidOperationException($"Cannot get proxy base description for a type which does not have or inherit [{nameof(LibraryTypes.GenerateMethodSerializersAttribute)}]");
- }
-
- return result;
- }
-
- private ProxyInterfaceDescription GetInvokableInterfaceDescription(INamedTypeSymbol proxyBaseType, INamedTypeSymbol interfaceType)
- {
- var originalInterface = interfaceType.OriginalDefinition;
- if (MetadataModel.InvokableInterfaces.TryGetValue(originalInterface, out var description))
- {
- return description;
- }
-
- description = new ProxyInterfaceDescription(this, proxyBaseType, originalInterface);
- MetadataModel.InvokableInterfaces.Add(originalInterface, description);
-
- // Generate a proxy.
- var (generatedClass, proxyDescription) = ProxyGenerator.Generate(description);
-
- // Emit the generated proxy
- if (Compilation.GetTypeByMetadataName(proxyDescription.MetadataName) == null)
- {
- AddMember(proxyDescription.InterfaceDescription.GeneratedNamespace, generatedClass);
- }
-
- MetadataModel.GeneratedProxies.Add(proxyDescription);
-
- return description;
- }
-
- internal ProxyMethodDescription GetProxyMethodDescription(INamedTypeSymbol interfaceType, IMethodSymbol method)
- {
- var originalMethod = method.OriginalDefinition;
- var proxyBaseInfo = GetProxyBase(interfaceType);
-
- // For extensions, we want to ensure that the containing type is always the extension.
- // This ensures that we will always know which 'component' to get in our SetTarget method.
- // If the type is not an extension, use the original method definition's containing type.
- // This is the interface where the type was originally defined.
- var containingType = proxyBaseInfo.IsExtension ? interfaceType : originalMethod.ContainingType;
-
- var invokableId = new InvokableMethodId(proxyBaseInfo, containingType, originalMethod);
- var interfaceDescription = GetInvokableInterfaceDescription(invokableId.ProxyBase.ProxyBaseType, interfaceType);
-
- // Get or generate an invokable for the original method definition.
- if (!MetadataModel.GeneratedInvokables.TryGetValue(invokableId, out var generatedInvokable))
- {
- if (!_invokableMethodDescriptions.TryGetValue(invokableId, out var methodDescription))
- {
- methodDescription = _invokableMethodDescriptions[invokableId] = InvokableMethodDescription.Create(invokableId, containingType);
- }
-
- generatedInvokable = MetadataModel.GeneratedInvokables[invokableId] = InvokableGenerator.Generate(methodDescription);
-
- if (Compilation.GetTypeByMetadataName(generatedInvokable.MetadataName) == null)
- {
- // Emit the generated code on-demand.
- AddMember(generatedInvokable.GeneratedNamespace, generatedInvokable.ClassDeclarationSyntax);
-
- // Ensure the type will have a serializer generated for it.
- MetadataModel.SerializableTypes.Add(generatedInvokable);
-
- foreach (var alias in generatedInvokable.CompoundTypeAliases)
- {
- MetadataModel.CompoundTypeAliases.Add(alias, generatedInvokable.OpenTypeSyntax);
- }
- }
- }
-
- var proxyMethodDescription = ProxyMethodDescription.Create(interfaceDescription, generatedInvokable, method);
-
- // For backwards compatibility, generate invokers for the specific implementation types as well, where they differ.
- if (Options.GenerateCompatibilityInvokers && !SymbolEqualityComparer.Default.Equals(method.OriginalDefinition.ContainingType, interfaceType))
- {
- var compatInvokableId = new InvokableMethodId(proxyBaseInfo, interfaceType, method);
- var compatMethodDescription = InvokableMethodDescription.Create(compatInvokableId, interfaceType);
- var compatInvokable = InvokableGenerator.Generate(compatMethodDescription);
- AddMember(compatInvokable.GeneratedNamespace, compatInvokable.ClassDeclarationSyntax);
- var alias =
- InvokableGenerator.GetCompoundTypeAliasComponents(
- compatInvokableId,
- interfaceType,
- compatMethodDescription.GeneratedMethodId);
- MetadataModel.CompoundTypeAliases.Add(alias, compatInvokable.OpenTypeSyntax);
- }
-
- return proxyMethodDescription;
- }
- }
-}
diff --git a/src/Orleans.CodeGenerator/CodeGeneratorOptions.cs b/src/Orleans.CodeGenerator/CodeGeneratorOptions.cs
new file mode 100644
index 00000000000..c7be35dbbe5
--- /dev/null
+++ b/src/Orleans.CodeGenerator/CodeGeneratorOptions.cs
@@ -0,0 +1,13 @@
+using Orleans.CodeGenerator.Model;
+
+namespace Orleans.CodeGenerator;
+
+public class CodeGeneratorOptions
+{
+ public const string IdAttribute = "Orleans.IdAttribute";
+ public const string AliasAttribute = "Orleans.AliasAttribute";
+ public const string ImmutableAttribute = "Orleans.ImmutableAttribute";
+ public static readonly IReadOnlyList ConstructorAttributes = ["Orleans.OrleansConstructorAttribute", "Microsoft.Extensions.DependencyInjection.ActivatorUtilitiesConstructorAttribute"];
+ public GenerateFieldIds GenerateFieldIds { get; set; }
+ public bool GenerateCompatibilityInvokers { get; set; }
+}
diff --git a/src/Orleans.CodeGenerator/CopierGenerator.cs b/src/Orleans.CodeGenerator/CopierGenerator.cs
index e06554bd8c4..5a18d9cb176 100644
--- a/src/Orleans.CodeGenerator/CopierGenerator.cs
+++ b/src/Orleans.CodeGenerator/CopierGenerator.cs
@@ -1,697 +1,688 @@
-using System.Collections.Generic;
-using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
+using System.Diagnostics;
using Orleans.CodeGenerator.SyntaxGeneration;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Orleans.CodeGenerator.InvokableGenerator;
using static Orleans.CodeGenerator.SerializerGenerator;
-#nullable disable
-namespace Orleans.CodeGenerator
+namespace Orleans.CodeGenerator;
+
+internal class CopierGenerator(IGeneratorServices generatorServices)
{
- internal class CopierGenerator
- {
- private const string BaseTypeCopierFieldName = "_baseTypeCopier";
- private const string ActivatorFieldName = "_activator";
- private const string DeepCopyMethodName = "DeepCopy";
- private readonly CodeGenerator _codeGenerator;
+ private const string BaseTypeCopierFieldName = "_baseTypeCopier";
+ private const string ActivatorFieldName = "_activator";
+ private const string DeepCopyMethodName = "DeepCopy";
+ private readonly IGeneratorServices _generatorServices = generatorServices;
- public CopierGenerator(CodeGenerator codeGenerator)
+ private LibraryTypes LibraryTypes => _generatorServices.LibraryTypes;
+
+ public ClassDeclarationSyntax? GenerateCopier(
+ ISerializableTypeDescription type,
+ Dictionary defaultCopiers)
+ {
+ var isShallowCopyable = type.IsShallowCopyable;
+ if (isShallowCopyable && !type.IsGenericType)
{
- _codeGenerator = codeGenerator;
+ defaultCopiers.Add(type, LibraryTypes.ShallowCopier.ToTypeSyntax(type.TypeSyntax));
+ return null;
}
- private LibraryTypes LibraryTypes => _codeGenerator.LibraryTypes;
+ var simpleClassName = GetSimpleClassName(type);
- public ClassDeclarationSyntax GenerateCopier(
- ISerializableTypeDescription type,
- Dictionary defaultCopiers)
+ var members = new List();
+ foreach (var member in type.Members)
{
- var isShallowCopyable = type.IsShallowCopyable;
- if (isShallowCopyable && !type.IsGenericType)
+ if (!member.IsCopyable)
{
- defaultCopiers.Add(type, LibraryTypes.ShallowCopier.ToTypeSyntax(type.TypeSyntax));
- return null;
+ continue;
}
- var simpleClassName = GetSimpleClassName(type);
-
- var members = new List();
- foreach (var member in type.Members)
+ if (member is ISerializableMember serializable)
{
- if (!member.IsCopyable)
- {
- continue;
- }
-
- if (member is ISerializableMember serializable)
- {
- members.Add(serializable);
- }
- else if (member is IFieldDescription or IPropertyDescription)
- {
- members.Add(new SerializableMember(_codeGenerator, member, members.Count));
- }
- else if (member is MethodParameterFieldDescription methodParameter)
- {
- members.Add(new SerializableMethodMember(methodParameter));
- }
+ members.Add(serializable);
}
-
- var accessibility = type.Accessibility switch
+ else if (member is IFieldDescription or IPropertyDescription)
{
- Accessibility.Public => SyntaxKind.PublicKeyword,
- _ => SyntaxKind.InternalKeyword,
- };
-
- var isExceptionType = type.IsExceptionType && type.SerializationHooks.Count == 0;
-
- var baseType = isExceptionType ? QualifiedName(AliasQualifiedName("global", IdentifierName("Orleans.Serialization.GeneratedCodeHelpers.OrleansGeneratedCodeHelper")), GenericName(Identifier("ExceptionCopier"), TypeArgumentList(SeparatedList(new[] { type.TypeSyntax, type.BaseType.ToTypeSyntax() }))))
- : (isShallowCopyable ? LibraryTypes.ShallowCopier : LibraryTypes.DeepCopier_1).ToTypeSyntax(type.TypeSyntax);
+ members.Add(new SerializableMember(_generatorServices, member, members.Count));
+ }
+ else if (member is MethodParameterFieldDescription methodParameter)
+ {
+ members.Add(new SerializableMethodMember(methodParameter));
+ }
+ }
- var classDeclaration = ClassDeclaration(simpleClassName)
- .AddBaseListTypes(SimpleBaseType(baseType))
- .AddModifiers(Token(accessibility), Token(SyntaxKind.SealedKeyword))
- .AddAttributeLists(CodeGenerator.GetGeneratedCodeAttributes());
+ var accessibility = type.Accessibility switch
+ {
+ Accessibility.Public => SyntaxKind.PublicKeyword,
+ _ => SyntaxKind.InternalKeyword,
+ };
- if (!isShallowCopyable)
- {
- var fieldDescriptions = GetFieldDescriptions(type, members, isExceptionType, out var onlyDeepFields);
- var fieldDeclarations = GetFieldDeclarations(fieldDescriptions);
- var ctor = GenerateConstructor(simpleClassName, fieldDescriptions, isExceptionType);
+ var isExceptionType = type.IsExceptionType && type.SerializationHooks.Count == 0;
- classDeclaration = classDeclaration.AddMembers(fieldDeclarations);
+ var baseType = isExceptionType ? QualifiedName(AliasQualifiedName("global", IdentifierName("Orleans.Serialization.GeneratedCodeHelpers.OrleansGeneratedCodeHelper")), GenericName(Identifier("ExceptionCopier"), TypeArgumentList(SeparatedList([type.TypeSyntax, type.BaseType.ToTypeSyntax()]))))
+ : (isShallowCopyable ? LibraryTypes.ShallowCopier : LibraryTypes.DeepCopier_1).ToTypeSyntax(type.TypeSyntax);
- if (!isExceptionType)
- {
- var copyMethod = GenerateMemberwiseDeepCopyMethod(type, fieldDescriptions, members, onlyDeepFields);
- classDeclaration = classDeclaration.AddMembers(copyMethod);
- }
+ var classDeclaration = ClassDeclaration(simpleClassName)
+ .AddBaseListTypes(SimpleBaseType(baseType))
+ .AddModifiers(Token(accessibility), Token(SyntaxKind.SealedKeyword))
+ .AddAttributeLists(GeneratedCodeUtilities.GetGeneratedCodeAttributes());
- if (ctor != null)
- classDeclaration = classDeclaration.AddMembers(ctor);
+ if (!isShallowCopyable)
+ {
+ var fieldDescriptions = GetFieldDescriptions(type, members, isExceptionType, out var onlyDeepFields);
+ var fieldDeclarations = GetFieldDeclarations(fieldDescriptions);
+ var ctor = GenerateConstructor(simpleClassName, fieldDescriptions, isExceptionType);
- if (isExceptionType || !type.IsSealedType)
- {
- if (GenerateBaseCopierDeepCopyMethod(type, fieldDescriptions, members, isExceptionType) is { } baseCopier)
- classDeclaration = classDeclaration.AddMembers(baseCopier);
+ classDeclaration = classDeclaration.AddMembers(fieldDeclarations);
- if (!isExceptionType)
- classDeclaration = classDeclaration.AddBaseListTypes(SimpleBaseType(LibraryTypes.BaseCopier_1.ToTypeSyntax(type.TypeSyntax)));
- }
+ if (!isExceptionType)
+ {
+ var copyMethod = GenerateMemberwiseDeepCopyMethod(type, fieldDescriptions, members, onlyDeepFields);
+ classDeclaration = classDeclaration.AddMembers(copyMethod);
}
- if (type.IsGenericType)
+ if (ctor != null)
+ classDeclaration = classDeclaration.AddMembers(ctor);
+
+ if (isExceptionType || !type.IsSealedType)
{
- classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, type.TypeParameters);
+ if (GenerateBaseCopierDeepCopyMethod(type, fieldDescriptions, members, isExceptionType) is { } baseCopier)
+ classDeclaration = classDeclaration.AddMembers(baseCopier);
+
+ if (!isExceptionType)
+ classDeclaration = classDeclaration.AddBaseListTypes(SimpleBaseType(LibraryTypes.BaseCopier_1.ToTypeSyntax(type.TypeSyntax)));
}
+ }
- return classDeclaration;
+ if (type.IsGenericType)
+ {
+ classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, type.TypeParameters);
}
- public static string GetSimpleClassName(ISerializableTypeDescription serializableType) => GetSimpleClassName(serializableType.Name);
+ return classDeclaration;
+ }
- public static string GetSimpleClassName(string name) => $"Copier_{name}";
+ public static string GetSimpleClassName(ISerializableTypeDescription serializableType) => GetSimpleClassName(serializableType.Name);
- private MemberDeclarationSyntax[] GetFieldDeclarations(List fieldDescriptions)
- {
- return fieldDescriptions.Select(GetFieldDeclaration).ToArray();
+ public static string GetSimpleClassName(string name) => $"Copier_{name}";
- static MemberDeclarationSyntax GetFieldDeclaration(GeneratedFieldDescription description)
+ private static MemberDeclarationSyntax[] GetFieldDeclarations(List fieldDescriptions)
+ {
+ return [.. fieldDescriptions.Select(GetFieldDeclaration)];
+
+ static MemberDeclarationSyntax GetFieldDeclaration(GeneratedFieldDescription description)
+ {
+ switch (description)
{
- switch (description)
- {
- case FieldAccessorDescription accessor when accessor.InitializationSyntax != null:
- return
- FieldDeclaration(VariableDeclaration(accessor.FieldType,
- SingletonSeparatedList(VariableDeclarator(accessor.FieldName).WithInitializer(EqualsValueClause(accessor.InitializationSyntax)))))
- .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword));
- case FieldAccessorDescription accessor when accessor.InitializationSyntax == null:
- //[UnsafeAccessor(UnsafeAccessorKind.Method, Name = "set_Amount")]
- //extern static void SetAmount(External instance, int value);
- return
- MethodDeclaration(
- PredefinedType(Token(SyntaxKind.VoidKeyword)),
- accessor.AccessorName)
- .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ExternKeyword), Token(SyntaxKind.StaticKeyword))
- .AddAttributeLists(AttributeList(SingletonSeparatedList(
- Attribute(IdentifierName("System.Runtime.CompilerServices.UnsafeAccessor"))
- .AddArgumentListArguments(
- AttributeArgument(
- MemberAccessExpression(
- SyntaxKind.SimpleMemberAccessExpression,
- IdentifierName("System.Runtime.CompilerServices.UnsafeAccessorKind"),
- IdentifierName("Method"))),
- AttributeArgument(
- LiteralExpression(
- SyntaxKind.StringLiteralExpression,
- Literal($"set_{accessor.FieldName}")))
- .WithNameEquals(NameEquals("Name"))))))
- .WithParameterList(
- ParameterList(SeparatedList(new[]
- {
- Parameter(Identifier("instance")).WithType(accessor.ContainingType),
- Parameter(Identifier("value")).WithType(description.FieldType)
- })))
- .WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
- default:
- return FieldDeclaration(VariableDeclaration(description.FieldType, SingletonSeparatedList(VariableDeclarator(description.FieldName))))
- .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword));
- }
+ case FieldAccessorDescription accessor when accessor.InitializationSyntax != null:
+ return
+ FieldDeclaration(VariableDeclaration(accessor.FieldType,
+ SingletonSeparatedList(VariableDeclarator(accessor.FieldName).WithInitializer(EqualsValueClause(accessor.InitializationSyntax)))))
+ .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword));
+ case FieldAccessorDescription accessor when accessor.InitializationSyntax == null:
+ //[UnsafeAccessor(UnsafeAccessorKind.Method, Name = "set_Amount")]
+ //extern static void SetAmount(External instance, int value);
+ return
+ MethodDeclaration(
+ PredefinedType(Token(SyntaxKind.VoidKeyword)),
+ accessor.AccessorName)
+ .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ExternKeyword), Token(SyntaxKind.StaticKeyword))
+ .AddAttributeLists(AttributeList(SingletonSeparatedList(
+ Attribute(IdentifierName("System.Runtime.CompilerServices.UnsafeAccessor"))
+ .AddArgumentListArguments(
+ AttributeArgument(
+ MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("System.Runtime.CompilerServices.UnsafeAccessorKind"),
+ IdentifierName("Method"))),
+ AttributeArgument(
+ LiteralExpression(
+ SyntaxKind.StringLiteralExpression,
+ Literal($"set_{accessor.FieldName}")))
+ .WithNameEquals(NameEquals("Name"))))))
+ .WithParameterList(
+ ParameterList(SeparatedList(
+ [
+ Parameter(Identifier("instance")).WithType(accessor.ContainingType),
+ Parameter(Identifier("value")).WithType(description.FieldType)
+ ])))
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
+ default:
+ return FieldDeclaration(VariableDeclaration(description.FieldType, SingletonSeparatedList(VariableDeclarator(description.FieldName))))
+ .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword));
}
}
+ }
- private ConstructorDeclarationSyntax GenerateConstructor(string simpleClassName, List fieldDescriptions, bool isExceptionType)
- {
- var codecProviderAdded = false;
- var parameters = new List();
- var statements = new List();
+ private ConstructorDeclarationSyntax? GenerateConstructor(string simpleClassName, List fieldDescriptions, bool isExceptionType)
+ {
+ var codecProviderAdded = false;
+ var parameters = new List();
+ var statements = new List();
- if (isExceptionType)
- {
- parameters.Add(Parameter(Identifier("codecProvider")).WithType(LibraryTypes.ICodecProvider.ToTypeSyntax()));
- codecProviderAdded = true;
- }
+ if (isExceptionType)
+ {
+ parameters.Add(Parameter(Identifier("codecProvider")).WithType(LibraryTypes.ICodecProvider.ToTypeSyntax()));
+ codecProviderAdded = true;
+ }
- foreach (var field in fieldDescriptions)
+ foreach (var field in fieldDescriptions)
+ {
+ switch (field)
{
- switch (field)
- {
- case GeneratedFieldDescription _ when field.IsInjected:
- parameters.Add(Parameter(field.FieldName.ToIdentifier()).WithType(field.FieldType));
- statements.Add(ExpressionStatement(
- AssignmentExpression(
- SyntaxKind.SimpleAssignmentExpression,
- ThisExpression().Member(field.FieldName.ToIdentifierName()),
- Unwrapped(field.FieldName.ToIdentifierName()))));
- break;
- case CopierFieldDescription or BaseCopierFieldDescription when !field.IsInjected:
- if (!codecProviderAdded)
- {
- parameters.Add(Parameter(Identifier("codecProvider")).WithType(LibraryTypes.ICodecProvider.ToTypeSyntax()));
- codecProviderAdded = true;
- }
-
- var copier = InvocationExpression(
- IdentifierName("OrleansGeneratedCodeHelper").Member(GenericName(Identifier("GetService"), TypeArgumentList(SingletonSeparatedList(field.FieldType)))),
- ArgumentList(SeparatedList(new[] { Argument(ThisExpression()), Argument(IdentifierName("codecProvider")) })));
-
- statements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, field.FieldName.ToIdentifierName(), copier)));
- break;
- }
- }
+ case GeneratedFieldDescription _ when field.IsInjected:
+ parameters.Add(Parameter(field.FieldName.ToIdentifier()).WithType(field.FieldType));
+ statements.Add(ExpressionStatement(
+ AssignmentExpression(
+ SyntaxKind.SimpleAssignmentExpression,
+ ThisExpression().Member(field.FieldName.ToIdentifierName()),
+ Unwrapped(field.FieldName.ToIdentifierName()))));
+ break;
+ case CopierFieldDescription or BaseCopierFieldDescription when !field.IsInjected:
+ if (!codecProviderAdded)
+ {
+ parameters.Add(Parameter(Identifier("codecProvider")).WithType(LibraryTypes.ICodecProvider.ToTypeSyntax()));
+ codecProviderAdded = true;
+ }
- return statements.Count == 0 && !isExceptionType ? null : ConstructorDeclaration(simpleClassName)
- .AddModifiers(Token(SyntaxKind.PublicKeyword))
- .AddParameterListParameters(parameters.ToArray())
- .AddBodyStatements(statements.ToArray())
- .WithInitializer(isExceptionType ? ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, ArgumentList(SingletonSeparatedList(Argument(IdentifierName("codecProvider"))))) : null);
+ var copier = InvocationExpression(
+ IdentifierName("OrleansGeneratedCodeHelper").Member(GenericName(Identifier("GetService"), TypeArgumentList(SingletonSeparatedList(field.FieldType)))),
+ ArgumentList(SeparatedList([Argument(ThisExpression()), Argument(IdentifierName("codecProvider"))])));
- static ExpressionSyntax Unwrapped(ExpressionSyntax expr)
- {
- return InvocationExpression(
- IdentifierName("OrleansGeneratedCodeHelper").Member("UnwrapService"),
- ArgumentList(SeparatedList(new[] { Argument(ThisExpression()), Argument(expr) })));
+ statements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, field.FieldName.ToIdentifierName(), copier)));
+ break;
}
}
- private List GetFieldDescriptions(
- ISerializableTypeDescription serializableTypeDescription,
- List members,
- bool isExceptionType,
- out bool onlyDeepFields)
+ return statements.Count == 0 && !isExceptionType ? null : ConstructorDeclaration(simpleClassName)
+ .AddModifiers(Token(SyntaxKind.PublicKeyword))
+ .AddParameterListParameters([.. parameters])
+ .AddBodyStatements([.. statements])
+ .WithInitializer(isExceptionType ? ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, ArgumentList(SingletonSeparatedList(Argument(IdentifierName("codecProvider"))))) : null);
+
+ static ExpressionSyntax Unwrapped(ExpressionSyntax expr)
{
- var serializationHooks = serializableTypeDescription.SerializationHooks;
- onlyDeepFields = serializableTypeDescription.IsValueType && serializationHooks.Count == 0;
+ return InvocationExpression(
+ IdentifierName("OrleansGeneratedCodeHelper").Member("UnwrapService"),
+ ArgumentList(SeparatedList([Argument(ThisExpression()), Argument(expr)])));
+ }
+ }
- var fields = new List();
+ private List GetFieldDescriptions(
+ ISerializableTypeDescription serializableTypeDescription,
+ List members,
+ bool isExceptionType,
+ out bool onlyDeepFields)
+ {
+ var serializationHooks = serializableTypeDescription.SerializationHooks;
+ onlyDeepFields = serializableTypeDescription.IsValueType && serializationHooks.Count == 0;
- if (!isExceptionType && serializableTypeDescription.HasComplexBaseType)
- {
- fields.Add(GetBaseTypeField(serializableTypeDescription));
- }
+ var fields = new List();
- if (!serializableTypeDescription.IsImmutable)
- {
- if (!isExceptionType && serializableTypeDescription.UseActivator && !serializableTypeDescription.IsAbstractType)
- {
- onlyDeepFields = false;
- fields.Add(new ActivatorFieldDescription(LibraryTypes.IActivator_1.ToTypeSyntax(serializableTypeDescription.TypeSyntax), ActivatorFieldName));
- }
+ if (!isExceptionType && serializableTypeDescription.HasComplexBaseType)
+ {
+ fields.Add(GetBaseTypeField(serializableTypeDescription));
+ }
- // Add a copier field for any field in the target which does not have a static copier.
- GetCopierFieldDescriptions(serializableTypeDescription.Members, fields);
+ if (!serializableTypeDescription.IsImmutable)
+ {
+ if (!isExceptionType && serializableTypeDescription.UseActivator && !serializableTypeDescription.IsAbstractType)
+ {
+ onlyDeepFields = false;
+ fields.Add(new ActivatorFieldDescription(LibraryTypes.IActivator_1.ToTypeSyntax(serializableTypeDescription.TypeSyntax), ActivatorFieldName));
}
- foreach (var member in members)
- {
- if (onlyDeepFields && member.IsShallowCopyable) continue;
+ // Add a copier field for any field in the target which does not have a static copier.
+ GetCopierFieldDescriptions(serializableTypeDescription.Members, fields);
+ }
- if (member.GetGetterFieldDescription() is { } getterFieldDescription)
- {
- fields.Add(getterFieldDescription);
- }
+ foreach (var member in members)
+ {
+ if (onlyDeepFields && member.IsShallowCopyable) continue;
- if (member.GetSetterFieldDescription() is { } setterFieldDescription)
- {
- fields.Add(setterFieldDescription);
- }
+ if (member.GetGetterFieldDescription() is { } getterFieldDescription)
+ {
+ fields.Add(getterFieldDescription);
}
- for (var hookIndex = 0; hookIndex < serializationHooks.Count; ++hookIndex)
+ if (member.GetSetterFieldDescription() is { } setterFieldDescription)
{
- var hookType = serializationHooks[hookIndex];
- fields.Add(new SerializationHookFieldDescription(hookType.ToTypeSyntax(), $"_hook{hookIndex}"));
+ fields.Add(setterFieldDescription);
}
-
- return fields;
}
- private BaseCopierFieldDescription GetBaseTypeField(ISerializableTypeDescription serializableTypeDescription)
+ for (var hookIndex = 0; hookIndex < serializationHooks.Count; ++hookIndex)
{
- var baseType = serializableTypeDescription.BaseType;
- if (baseType.HasAttribute(LibraryTypes.GenerateSerializerAttribute)
- && (SymbolEqualityComparer.Default.Equals(baseType.ContainingAssembly, LibraryTypes.Compilation.Assembly) || baseType.ContainingAssembly.HasAttribute(LibraryTypes.TypeManifestProviderAttribute))
- && baseType is not INamedTypeSymbol { IsGenericType: true })
- {
- // Use the concrete generated type and avoid expensive interface dispatch (except for generic types that will fall back to IBaseCopier)
- return new(QualifiedName(ParseName(GetGeneratedNamespaceName(baseType)), IdentifierName(GetSimpleClassName(baseType.Name))), true);
- }
+ var hookType = serializationHooks[hookIndex];
+ fields.Add(new SerializationHookFieldDescription(hookType.ToTypeSyntax(), $"_hook{hookIndex}"));
+ }
+
+ return fields;
+ }
- return new(LibraryTypes.BaseCopier_1.ToTypeSyntax(serializableTypeDescription.BaseTypeSyntax));
+ private BaseCopierFieldDescription GetBaseTypeField(ISerializableTypeDescription serializableTypeDescription)
+ {
+ var baseType = serializableTypeDescription.BaseType;
+ if (baseType.HasAttribute(LibraryTypes.GenerateSerializerAttribute)
+ && (SymbolEqualityComparer.Default.Equals(baseType.ContainingAssembly, LibraryTypes.Compilation.Assembly) || baseType.ContainingAssembly.HasAttribute(LibraryTypes.TypeManifestProviderAttribute))
+ && baseType is not INamedTypeSymbol { IsGenericType: true })
+ {
+ // Use the concrete generated type and avoid expensive interface dispatch (except for generic types that will fall back to IBaseCopier)
+ return new(QualifiedName(ParseName(GetGeneratedNamespaceName(baseType)), IdentifierName(GetSimpleClassName(baseType.Name))), true);
}
- public void GetCopierFieldDescriptions(IEnumerable members, List fields)
+ return new(LibraryTypes.BaseCopier_1.ToTypeSyntax(serializableTypeDescription.BaseTypeSyntax));
+ }
+
+ public void GetCopierFieldDescriptions(IEnumerable members, List fields)
+ {
+ var fieldIndex = 0;
+ var uniqueTypes = new HashSet(MemberDescriptionTypeComparer.Default);
+ foreach (var member in members)
{
- var fieldIndex = 0;
- var uniqueTypes = new HashSet(MemberDescriptionTypeComparer.Default);
- foreach (var member in members)
+ if (!member.IsCopyable)
{
- if (!member.IsCopyable)
- {
- continue;
- }
+ continue;
+ }
- var t = member.Type;
+ var t = member.Type;
- if (LibraryTypes.IsShallowCopyable(t))
- continue;
+ if (LibraryTypes.IsShallowCopyable(t))
+ continue;
- foreach (var c in LibraryTypes.StaticCopiers)
- if (SymbolEqualityComparer.Default.Equals(c.UnderlyingType, t))
- goto skip;
+ foreach (var c in LibraryTypes.StaticCopiers)
+ if (SymbolEqualityComparer.Default.Equals(c.UnderlyingType, t))
+ goto skip;
- if (member.Symbol.HasAttribute(LibraryTypes.ImmutableAttribute))
- continue;
+ if (member.Symbol.HasAttribute(LibraryTypes.ImmutableAttribute))
+ continue;
- if (!uniqueTypes.Add(member))
- continue;
+ if (!uniqueTypes.Add(member))
+ continue;
- TypeSyntax copierType;
- if (t.HasAttribute(LibraryTypes.GenerateSerializerAttribute)
- && (SymbolEqualityComparer.Default.Equals(t.ContainingAssembly, LibraryTypes.Compilation.Assembly) || t.ContainingAssembly.HasAttribute(LibraryTypes.TypeManifestProviderAttribute))
- && t is not INamedTypeSymbol { IsGenericType: true, TypeArguments.Length: 0 })
- {
- // Use the concrete generated type and avoid expensive interface dispatch (except for complex nested cases that will fall back to IDeepCopier)
- SimpleNameSyntax name;
- if (t is INamedTypeSymbol namedTypeSymbol && namedTypeSymbol.IsGenericType)
- {
- // Construct the full generic type name
- name = GenericName(Identifier(GetSimpleClassName(t.Name)), TypeArgumentList(SeparatedList(namedTypeSymbol.TypeArguments.Select(arg => arg.ToTypeSyntax()))));
- }
- else
- {
- name = IdentifierName(GetSimpleClassName(t.Name));
- }
- copierType = QualifiedName(ParseName(GetGeneratedNamespaceName(t)), name);
- }
- else if (t is IArrayTypeSymbol { IsSZArray: true } array)
- {
- copierType = LibraryTypes.ArrayCopier.Construct(array.ElementType).ToTypeSyntax();
- }
- else if (LibraryTypes.WellKnownCopiers.FindByUnderlyingType(t) is { } copier)
- {
- // The copier is not a static copier and is also not a generic copiers.
- copierType = copier.CopierType.ToTypeSyntax();
- }
- else if (t is INamedTypeSymbol { ConstructedFrom: ISymbol unboundFieldType } named && LibraryTypes.WellKnownCopiers.FindByUnderlyingType(unboundFieldType) is { } genericCopier)
+ TypeSyntax copierType;
+ if (t.HasAttribute(LibraryTypes.GenerateSerializerAttribute)
+ && (SymbolEqualityComparer.Default.Equals(t.ContainingAssembly, LibraryTypes.Compilation.Assembly) || t.ContainingAssembly.HasAttribute(LibraryTypes.TypeManifestProviderAttribute))
+ && t is not INamedTypeSymbol { IsGenericType: true, TypeArguments.Length: 0 })
+ {
+ // Use the concrete generated type and avoid expensive interface dispatch (except for complex nested cases that will fall back to IDeepCopier)
+ SimpleNameSyntax name;
+ if (t is INamedTypeSymbol namedTypeSymbol && namedTypeSymbol.IsGenericType)
{
- // Construct the generic copier type using the field's type arguments.
- copierType = genericCopier.CopierType.Construct(named.TypeArguments.ToArray()).ToTypeSyntax();
+ // Construct the full generic type name
+ name = GenericName(Identifier(GetSimpleClassName(t.Name)), TypeArgumentList(SeparatedList(namedTypeSymbol.TypeArguments.Select(arg => arg.ToTypeSyntax()))));
}
else
{
- // Use the IDeepCopier interface
- copierType = LibraryTypes.DeepCopier_1.ToTypeSyntax(member.TypeSyntax);
+ name = IdentifierName(GetSimpleClassName(t.Name));
}
-
- fields.Add(new CopierFieldDescription(copierType, $"_copier{fieldIndex++}", t));
-skip:;
+ copierType = QualifiedName(ParseName(GetGeneratedNamespaceName(t)), name);
}
- }
-
- private MemberDeclarationSyntax GenerateMemberwiseDeepCopyMethod(
- ISerializableTypeDescription type,
- List copierFields,
- List members,
- bool onlyDeepFields)
- {
- var returnType = type.TypeSyntax;
-
- var originalParam = "original".ToIdentifierName();
- var contextParam = "context".ToIdentifierName();
- var resultVar = "result".ToIdentifierName();
-
- var body = new List();
-
- var membersCopied = false;
- if (type.IsAbstractType)
+ else if (t is IArrayTypeSymbol { IsSZArray: true } array)
{
- // C#: return context.DeepCopy(original)
- body.Add(ReturnStatement(InvocationExpression(contextParam.Member("DeepCopy"), ArgumentList(SingletonSeparatedList(Argument(originalParam))))));
- membersCopied = true;
+ copierType = LibraryTypes.ArrayCopier.Construct(array.ElementType).ToTypeSyntax();
}
- else if (type.IsUnsealedImmutable)
+ else if (LibraryTypes.WellKnownCopiers.FindByUnderlyingType(t) is { } copier)
{
- // C#: return original is null || original.GetType() == typeof(T) ? original : context.DeepCopy(original);
- var exactTypeMatch = BinaryExpression(SyntaxKind.EqualsExpression, InvocationExpression(originalParam.Member("GetType")), TypeOfExpression(type.TypeSyntax));
- var nullOrTypeMatch = BinaryExpression(SyntaxKind.LogicalOrExpression, BinaryExpression(SyntaxKind.IsExpression, originalParam, LiteralExpression(SyntaxKind.NullLiteralExpression)), exactTypeMatch);
- var contextCopy = InvocationExpression(contextParam.Member("DeepCopy"), ArgumentList(SingletonSeparatedList(Argument(originalParam))));
- body.Add(ReturnStatement(ConditionalExpression(nullOrTypeMatch, originalParam, contextCopy)));
- membersCopied = true;
+ // The copier is not a static copier and is also not a generic copiers.
+ copierType = copier.CopierType.ToTypeSyntax();
}
- else if (!type.IsValueType)
+ else if (t is INamedTypeSymbol { ConstructedFrom: ISymbol unboundFieldType } named && LibraryTypes.WellKnownCopiers.FindByUnderlyingType(unboundFieldType) is { } genericCopier)
{
- if (type.TrackReferences)
- {
- // C#: if (context.TryGetCopy(original, out T existing)) return existing;
- var tryGetCopy = InvocationExpression(
- contextParam.Member("TryGetCopy"),
- ArgumentList(SeparatedList(new[]
- {
- Argument(originalParam),
- Argument(DeclarationExpression(
- type.TypeSyntax,
- SingleVariableDesignation(Identifier("existing"))))
- .WithRefKindKeyword(Token(SyntaxKind.OutKeyword))
- })));
- body.Add(IfStatement(tryGetCopy, ReturnStatement("existing".ToIdentifierName())));
- }
- else
- {
- // C#: if (original is null) return null;
- body.Add(IfStatement(BinaryExpression(SyntaxKind.IsExpression, originalParam, LiteralExpression(SyntaxKind.NullLiteralExpression)), ReturnStatement(LiteralExpression(SyntaxKind.NullLiteralExpression))));
- }
+ // Construct the generic copier type using the field's type arguments.
+ copierType = genericCopier.CopierType.Construct([.. named.TypeArguments]).ToTypeSyntax();
+ }
+ else
+ {
+ // Use the IDeepCopier interface
+ copierType = LibraryTypes.DeepCopier_1.ToTypeSyntax(member.TypeSyntax);
+ }
- if (!type.IsSealedType)
- {
- // C#: if (original.GetType() != typeof(T)) { return context.DeepCopy(original); }
- var exactTypeMatch = BinaryExpression(SyntaxKind.NotEqualsExpression, InvocationExpression(originalParam.Member("GetType")), TypeOfExpression(type.TypeSyntax));
- var contextCopy = InvocationExpression(contextParam.Member("DeepCopy"), ArgumentList(SingletonSeparatedList(Argument(originalParam))));
- body.Add(IfStatement(exactTypeMatch, ReturnStatement(contextCopy)));
- }
+ fields.Add(new CopierFieldDescription(copierType, $"_copier{fieldIndex++}", t));
+skip:;
+ }
+ }
- // C#: var result = _activator.Create();
- body.Add(LocalDeclarationStatement(
- VariableDeclaration(
- IdentifierName("var"),
- SingletonSeparatedList(VariableDeclarator(
- resultVar.Identifier,
- argumentList: null,
- initializer: EqualsValueClause(GetCreateValueExpression(type, copierFields)))))));
+ private MemberDeclarationSyntax GenerateMemberwiseDeepCopyMethod(
+ ISerializableTypeDescription type,
+ List copierFields,
+ List members,
+ bool onlyDeepFields)
+ {
+ var returnType = type.TypeSyntax;
- if (type.TrackReferences)
- {
- // C#: context.RecordCopy(original, result);
- body.Add(ExpressionStatement(InvocationExpression(contextParam.Member("RecordCopy"), ArgumentList(SeparatedList(new[]
- {
- Argument(originalParam),
- Argument(resultVar)
- })))));
- }
+ var originalParam = "original".ToIdentifierName();
+ var contextParam = "context".ToIdentifierName();
+ var resultVar = "result".ToIdentifierName();
- if (!type.IsSealedType)
- {
- // C#: DeepCopy(original, result, context);
- body.Add(ExpressionStatement(InvocationExpression(IdentifierName("DeepCopy"),
- ArgumentList(SeparatedList(new[] { Argument(originalParam), Argument(resultVar), Argument(contextParam) })))));
- body.Add(ReturnStatement(resultVar));
- membersCopied = true;
- }
- else if (type.HasComplexBaseType)
- {
- // C#: _baseTypeCopier.DeepCopy(original, result, context);
- body.Add(
- ExpressionStatement(
- InvocationExpression(
- BaseTypeCopierFieldName.ToIdentifierName().Member(DeepCopyMethodName),
- ArgumentList(SeparatedList(new[]
- {
- Argument(originalParam),
- Argument(resultVar),
- Argument(contextParam)
- })))));
- }
- }
- else if (!onlyDeepFields)
+ var body = new List();
+
+ var membersCopied = false;
+ if (type.IsAbstractType)
+ {
+ // C#: return context.DeepCopy(original)
+ body.Add(ReturnStatement(InvocationExpression(contextParam.Member("DeepCopy"), ArgumentList(SingletonSeparatedList(Argument(originalParam))))));
+ membersCopied = true;
+ }
+ else if (type.IsUnsealedImmutable)
+ {
+ // C#: return original is null || original.GetType() == typeof(T) ? original : context.DeepCopy(original);
+ var exactTypeMatch = BinaryExpression(SyntaxKind.EqualsExpression, InvocationExpression(originalParam.Member("GetType")), TypeOfExpression(type.TypeSyntax));
+ var nullOrTypeMatch = BinaryExpression(SyntaxKind.LogicalOrExpression, BinaryExpression(SyntaxKind.IsExpression, originalParam, LiteralExpression(SyntaxKind.NullLiteralExpression)), exactTypeMatch);
+ var contextCopy = InvocationExpression(contextParam.Member("DeepCopy"), ArgumentList(SingletonSeparatedList(Argument(originalParam))));
+ body.Add(ReturnStatement(ConditionalExpression(nullOrTypeMatch, originalParam, contextCopy)));
+ membersCopied = true;
+ }
+ else if (!type.IsValueType)
+ {
+ if (type.TrackReferences)
{
- // C#: var result = _activator.Create();
- // or C#: var result = new TField();
- // or C#: var result = default(TField);
- body.Add(LocalDeclarationStatement(
- VariableDeclaration(
- IdentifierName("var"),
- SingletonSeparatedList(VariableDeclarator(resultVar.Identifier)
- .WithInitializer(EqualsValueClause(GetCreateValueExpression(type, copierFields)))))));
+ // C#: if (context.TryGetCopy(original, out T existing)) return existing;
+ var tryGetCopy = InvocationExpression(
+ contextParam.Member("TryGetCopy"),
+ ArgumentList(SeparatedList(
+ [
+ Argument(originalParam),
+ Argument(DeclarationExpression(
+ type.TypeSyntax,
+ SingleVariableDesignation(Identifier("existing"))))
+ .WithRefKindKeyword(Token(SyntaxKind.OutKeyword))
+ ])));
+ body.Add(IfStatement(tryGetCopy, ReturnStatement("existing".ToIdentifierName())));
}
else
{
- originalParam = resultVar;
+ // C#: if (original is null) return null;
+ body.Add(IfStatement(BinaryExpression(SyntaxKind.IsExpression, originalParam, LiteralExpression(SyntaxKind.NullLiteralExpression)), ReturnStatement(LiteralExpression(SyntaxKind.NullLiteralExpression))));
}
- if (!membersCopied)
+ if (!type.IsSealedType)
{
- GenerateMemberwiseCopy(type, copierFields, members, originalParam, contextParam, resultVar, body, onlyDeepFields);
- body.Add(ReturnStatement(resultVar));
+ // C#: if (original.GetType() != typeof(T)) { return context.DeepCopy(original); }
+ var exactTypeMatch = BinaryExpression(SyntaxKind.NotEqualsExpression, InvocationExpression(originalParam.Member("GetType")), TypeOfExpression(type.TypeSyntax));
+ var contextCopy = InvocationExpression(contextParam.Member("DeepCopy"), ArgumentList(SingletonSeparatedList(Argument(originalParam))));
+ body.Add(IfStatement(exactTypeMatch, ReturnStatement(contextCopy)));
}
- var parameters = new[]
- {
- Parameter(originalParam.Identifier).WithType(type.TypeSyntax),
- Parameter(contextParam.Identifier).WithType(LibraryTypes.CopyContext.ToTypeSyntax())
- };
-
- return MethodDeclaration(returnType, DeepCopyMethodName)
- .AddModifiers(Token(SyntaxKind.PublicKeyword))
- .AddParameterListParameters(parameters)
- .AddAttributeLists(AttributeList(SingletonSeparatedList(CodeGenerator.GetMethodImplAttributeSyntax())))
- .AddBodyStatements(body.ToArray());
- }
+ // C#: var result = _activator.Create();
+ body.Add(LocalDeclarationStatement(
+ VariableDeclaration(
+ IdentifierName("var"),
+ SingletonSeparatedList(VariableDeclarator(
+ resultVar.Identifier,
+ argumentList: null,
+ initializer: EqualsValueClause(GetCreateValueExpression(type, copierFields)))))));
- private ExpressionSyntax GetCreateValueExpression(ISerializableTypeDescription type, List copierFields)
- {
- return type.UseActivator switch
+ if (type.TrackReferences)
{
- true => InvocationExpression(copierFields.Find(f => f is ActivatorFieldDescription).FieldName.ToIdentifierName().Member("Create")),
- false => type.GetObjectCreationExpression()
- };
- }
-
- private MemberDeclarationSyntax GenerateBaseCopierDeepCopyMethod(
- ISerializableTypeDescription type,
- List copierFields,
- List members,
- bool isExceptionType)
- {
- var inputParam = "input".ToIdentifierName();
- var resultParam = "output".ToIdentifierName();
- var contextParam = "context".ToIdentifierName();
-
- var body = new List();
+ // C#: context.RecordCopy(original, result);
+ body.Add(ExpressionStatement(InvocationExpression(contextParam.Member("RecordCopy"), ArgumentList(SeparatedList(
+ [
+ Argument(originalParam),
+ Argument(resultVar)
+ ])))));
+ }
- if (type.HasComplexBaseType)
+ if (!type.IsSealedType)
+ {
+ // C#: DeepCopy(original, result, context);
+ body.Add(ExpressionStatement(InvocationExpression(IdentifierName("DeepCopy"),
+ ArgumentList(SeparatedList([Argument(originalParam), Argument(resultVar), Argument(contextParam)])))));
+ body.Add(ReturnStatement(resultVar));
+ membersCopied = true;
+ }
+ else if (type.HasComplexBaseType)
{
// C#: _baseTypeCopier.DeepCopy(original, result, context);
body.Add(
ExpressionStatement(
InvocationExpression(
- (isExceptionType ? (ExpressionSyntax)BaseExpression() : IdentifierName(BaseTypeCopierFieldName)).Member(DeepCopyMethodName),
- ArgumentList(SeparatedList(new[]
- {
- Argument(inputParam),
- Argument(resultParam),
+ BaseTypeCopierFieldName.ToIdentifierName().Member(DeepCopyMethodName),
+ ArgumentList(SeparatedList(
+ [
+ Argument(originalParam),
+ Argument(resultVar),
Argument(contextParam)
- })))));
+ ])))));
}
+ }
+ else if (!onlyDeepFields)
+ {
+ // C#: var result = _activator.Create();
+ // or C#: var result = new TField();
+ // or C#: var result = default(TField);
+ body.Add(LocalDeclarationStatement(
+ VariableDeclaration(
+ IdentifierName("var"),
+ SingletonSeparatedList(VariableDeclarator(resultVar.Identifier)
+ .WithInitializer(EqualsValueClause(GetCreateValueExpression(type, copierFields)))))));
+ }
+ else
+ {
+ originalParam = resultVar;
+ }
- var emptyBodyCount = body.Count;
+ if (!membersCopied)
+ {
+ GenerateMemberwiseCopy(type, copierFields, members, originalParam, contextParam, resultVar, body, onlyDeepFields);
+ body.Add(ReturnStatement(resultVar));
+ }
- GenerateMemberwiseCopy(
- type,
- copierFields,
- members,
- inputParam,
- contextParam,
- resultParam,
- body);
+ var parameters = new[]
+ {
+ Parameter(originalParam.Identifier).WithType(type.TypeSyntax),
+ Parameter(contextParam.Identifier).WithType(LibraryTypes.CopyContext.ToTypeSyntax())
+ };
+
+ return MethodDeclaration(returnType, DeepCopyMethodName)
+ .AddModifiers(Token(SyntaxKind.PublicKeyword))
+ .AddParameterListParameters(parameters)
+ .AddAttributeLists(AttributeList(SingletonSeparatedList(GeneratedCodeUtilities.GetMethodImplAttributeSyntax())))
+ .AddBodyStatements([.. body]);
+ }
- if (isExceptionType && body.Count == emptyBodyCount)
- return null;
+ private static ExpressionSyntax GetCreateValueExpression(ISerializableTypeDescription type, List copierFields)
+ {
+ return type.UseActivator switch
+ {
+ true => InvocationExpression(GetActivatorField(copierFields).FieldName.ToIdentifierName().Member("Create")),
+ false => type.GetObjectCreationExpression()
+ };
- var parameters = new[]
- {
- Parameter(inputParam.Identifier).WithType(type.TypeSyntax),
- Parameter(resultParam.Identifier).WithType(type.TypeSyntax),
- Parameter(contextParam.Identifier).WithType(LibraryTypes.CopyContext.ToTypeSyntax())
- };
+ static GeneratedFieldDescription GetActivatorField(List copierFields)
+ {
+ var result = copierFields.Find(f => f is ActivatorFieldDescription);
+ Debug.Assert(result is not null);
+ return result!;
+ }
+ }
- var method = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), DeepCopyMethodName)
- .AddModifiers(Token(SyntaxKind.PublicKeyword))
- .AddParameterListParameters(parameters)
- .AddAttributeLists(AttributeList(SingletonSeparatedList(CodeGenerator.GetMethodImplAttributeSyntax())))
- .AddBodyStatements(body.ToArray());
+ private MemberDeclarationSyntax? GenerateBaseCopierDeepCopyMethod(
+ ISerializableTypeDescription type,
+ List copierFields,
+ List members,
+ bool isExceptionType)
+ {
+ var inputParam = "input".ToIdentifierName();
+ var resultParam = "output".ToIdentifierName();
+ var contextParam = "context".ToIdentifierName();
- if (isExceptionType)
- method = method.AddModifiers(Token(SyntaxKind.OverrideKeyword));
+ var body = new List();
- return method;
+ if (type.HasComplexBaseType)
+ {
+ // C#: _baseTypeCopier.DeepCopy(original, result, context);
+ body.Add(
+ ExpressionStatement(
+ InvocationExpression(
+ (isExceptionType ? (ExpressionSyntax)BaseExpression() : IdentifierName(BaseTypeCopierFieldName)).Member(DeepCopyMethodName),
+ ArgumentList(SeparatedList(
+ [
+ Argument(inputParam),
+ Argument(resultParam),
+ Argument(contextParam)
+ ])))));
}
- private void GenerateMemberwiseCopy(
- ISerializableTypeDescription type,
- List copierFields,
- List members,
- IdentifierNameSyntax sourceVar,
- IdentifierNameSyntax contextVar,
- IdentifierNameSyntax destinationVar,
- List body,
- bool onlyDeepFields = false)
+ var emptyBodyCount = body.Count;
+
+ GenerateMemberwiseCopy(
+ type,
+ copierFields,
+ members,
+ inputParam,
+ contextParam,
+ resultParam,
+ body);
+
+ if (isExceptionType && body.Count == emptyBodyCount)
+ return null;
+
+ var parameters = new[]
{
- AddSerializationCallbacks(type, sourceVar, destinationVar, "OnCopying", body);
+ Parameter(inputParam.Identifier).WithType(type.TypeSyntax),
+ Parameter(resultParam.Identifier).WithType(type.TypeSyntax),
+ Parameter(contextParam.Identifier).WithType(LibraryTypes.CopyContext.ToTypeSyntax())
+ };
- var copiers = type.IsUnsealedImmutable ? null : copierFields.OfType()
- .Concat(LibraryTypes.StaticCopiers)
- .ToList();
+ var method = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), DeepCopyMethodName)
+ .AddModifiers(Token(SyntaxKind.PublicKeyword))
+ .AddParameterListParameters(parameters)
+ .AddAttributeLists(AttributeList(SingletonSeparatedList(GeneratedCodeUtilities.GetMethodImplAttributeSyntax())))
+ .AddBodyStatements([.. body]);
- var orderedMembers = members.OrderBy(m => m.Member.FieldId);
- foreach (var member in orderedMembers)
- {
- if (onlyDeepFields && member.IsShallowCopyable) continue;
-
- var getValueExpression = GenerateMemberCopy(
- copierFields,
- inputValue: member.GetGetter(sourceVar),
- contextVar,
- copiers,
- member);
- var memberAssignment = ExpressionStatement(member.GetSetter(destinationVar, getValueExpression));
- body.Add(memberAssignment);
- }
+ if (isExceptionType)
+ method = method.AddModifiers(Token(SyntaxKind.OverrideKeyword));
- AddSerializationCallbacks(type, sourceVar, destinationVar, "OnCopied", body);
- }
+ return method;
+ }
+
+ private void GenerateMemberwiseCopy(
+ ISerializableTypeDescription type,
+ List copierFields,
+ List members,
+ IdentifierNameSyntax sourceVar,
+ IdentifierNameSyntax contextVar,
+ IdentifierNameSyntax destinationVar,
+ List body,
+ bool onlyDeepFields = false)
+ {
+ AddSerializationCallbacks(type, sourceVar, destinationVar, "OnCopying", body);
+
+ var copiers = type.IsUnsealedImmutable ? null : copierFields.OfType()
+ .Concat(LibraryTypes.StaticCopiers)
+ .ToList();
- public ExpressionSyntax GenerateMemberCopy(
- List copierFields,
- ExpressionSyntax inputValue,
- ExpressionSyntax copyContextVar,
- List copiers,
- ISerializableMember member)
+ var orderedMembers = members.OrderBy(m => m.Member.FieldId);
+ foreach (var member in orderedMembers)
{
- if (copiers is null || member.IsShallowCopyable)
- return inputValue;
+ if (onlyDeepFields && member.IsShallowCopyable) continue;
- var description = member.Member;
+ var getValueExpression = GenerateMemberCopy(
+ copierFields,
+ inputValue: member.GetGetter(sourceVar),
+ contextVar,
+ copiers,
+ member);
+ var memberAssignment = ExpressionStatement(member.GetSetter(destinationVar, getValueExpression));
+ body.Add(memberAssignment);
+ }
- // Copiers can either be static classes or injected into the constructor.
- // Either way, the member signatures are the same.
- var memberType = description.Type;
- var copier = copiers.Find(f => SymbolEqualityComparer.Default.Equals(f.UnderlyingType, memberType));
- ExpressionSyntax getValueExpression;
+ AddSerializationCallbacks(type, sourceVar, destinationVar, "OnCopied", body);
+ }
- if (copier is null)
+ public ExpressionSyntax GenerateMemberCopy(
+ List copierFields,
+ ExpressionSyntax inputValue,
+ ExpressionSyntax copyContextVar,
+ List? copiers,
+ ISerializableMember member)
+ {
+ if (copiers is null || member.IsShallowCopyable)
+ return inputValue;
+
+ var description = member.Member;
+
+ // Copiers can either be static classes or injected into the constructor.
+ // Either way, the member signatures are the same.
+ var memberType = description.Type;
+ var copier = copiers.Find(f => SymbolEqualityComparer.Default.Equals(f.UnderlyingType, memberType));
+ ExpressionSyntax getValueExpression;
+
+ if (copier is null)
+ {
+ getValueExpression = InvocationExpression(
+ copyContextVar.Member(DeepCopyMethodName),
+ ArgumentList(SeparatedList([Argument(inputValue)])));
+ }
+ else
+ {
+ ExpressionSyntax copierExpression;
+ var staticCopier = LibraryTypes.StaticCopiers.FindByUnderlyingType(memberType);
+ if (staticCopier != null)
{
- getValueExpression = InvocationExpression(
- copyContextVar.Member(DeepCopyMethodName),
- ArgumentList(SeparatedList(new[] { Argument(inputValue) })));
+ copierExpression = staticCopier.CopierType.ToNameSyntax();
}
else
{
- ExpressionSyntax copierExpression;
- var staticCopier = LibraryTypes.StaticCopiers.FindByUnderlyingType(memberType);
- if (staticCopier != null)
- {
- copierExpression = staticCopier.CopierType.ToNameSyntax();
- }
- else
- {
- var instanceCopier = copierFields.First(f => f is CopierFieldDescription cf && SymbolEqualityComparer.Default.Equals(cf.UnderlyingType, memberType));
- copierExpression = IdentifierName(instanceCopier.FieldName);
- }
-
- getValueExpression = InvocationExpression(
- copierExpression.Member(DeepCopyMethodName),
- ArgumentList(SeparatedList(new[] { Argument(inputValue), Argument(copyContextVar) })));
- if (!SymbolEqualityComparer.Default.Equals(copier.UnderlyingType, member.Member.Type))
- {
- // If the member type type differs from the copier type (eg because the member is an array), cast the result.
- getValueExpression = CastExpression(description.TypeSyntax, getValueExpression);
- }
+ var instanceCopier = copierFields.OfType().First(f => SymbolEqualityComparer.Default.Equals(f.UnderlyingType, memberType));
+ copierExpression = IdentifierName(instanceCopier.FieldName);
}
- return getValueExpression;
- }
-
- private void AddSerializationCallbacks(ISerializableTypeDescription type, IdentifierNameSyntax originalInstance, IdentifierNameSyntax resultInstance, string callbackMethodName, List body)
- {
- var serializationHooks = type.SerializationHooks;
- for (var hookIndex = 0; hookIndex < serializationHooks.Count; ++hookIndex)
+ getValueExpression = InvocationExpression(
+ copierExpression.Member(DeepCopyMethodName),
+ ArgumentList(SeparatedList([Argument(inputValue), Argument(copyContextVar)])));
+ if (!SymbolEqualityComparer.Default.Equals(copier.UnderlyingType, member.Member.Type))
{
- var hookType = serializationHooks[hookIndex];
- var member = hookType.GetAllMembers(callbackMethodName, Accessibility.Public).FirstOrDefault();
- if (member is null || member.Parameters.Length != 2)
- {
- continue;
- }
-
- var originalArgument = Argument(originalInstance);
- if (member.Parameters[0].RefKind == RefKind.Ref)
- {
- originalArgument = originalArgument.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword));
- }
-
- var resultArgument = Argument(resultInstance);
- if (member.Parameters[1].RefKind == RefKind.Ref)
- {
- resultArgument = resultArgument.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword));
- }
-
- body.Add(ExpressionStatement(InvocationExpression(
- IdentifierName($"_hook{hookIndex}").Member(callbackMethodName),
- ArgumentList(SeparatedList(new[] { originalArgument, resultArgument })))));
+ // If the member type type differs from the copier type (eg because the member is an array), cast the result.
+ getValueExpression = CastExpression(description.TypeSyntax, getValueExpression);
}
}
- internal sealed class BaseCopierFieldDescription : GeneratedFieldDescription
+ return getValueExpression;
+ }
+
+ private static void AddSerializationCallbacks(ISerializableTypeDescription type, IdentifierNameSyntax originalInstance, IdentifierNameSyntax resultInstance, string callbackMethodName, List body)
+ {
+ var serializationHooks = type.SerializationHooks;
+ for (var hookIndex = 0; hookIndex < serializationHooks.Count; ++hookIndex)
{
- public BaseCopierFieldDescription(TypeSyntax fieldType, bool concreteType = false) : base(fieldType, BaseTypeCopierFieldName)
- => IsInjected = !concreteType;
+ var hookType = serializationHooks[hookIndex];
+ var member = hookType.GetAllMembers(callbackMethodName, Accessibility.Public).FirstOrDefault();
+ if (member is null || member.Parameters.Length != 2)
+ {
+ continue;
+ }
- public override bool IsInjected { get; }
- }
+ var originalArgument = Argument(originalInstance);
+ if (member.Parameters[0].RefKind == RefKind.Ref)
+ {
+ originalArgument = originalArgument.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword));
+ }
- internal sealed class CopierFieldDescription : GeneratedFieldDescription, ICopierDescription
- {
- public CopierFieldDescription(TypeSyntax fieldType, string fieldName, ITypeSymbol underlyingType) : base(fieldType, fieldName)
+ var resultArgument = Argument(resultInstance);
+ if (member.Parameters[1].RefKind == RefKind.Ref)
{
- UnderlyingType = underlyingType;
+ resultArgument = resultArgument.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword));
}
- public ITypeSymbol UnderlyingType { get; }
- public override bool IsInjected => false;
+ body.Add(ExpressionStatement(InvocationExpression(
+ IdentifierName($"_hook{hookIndex}").Member(callbackMethodName),
+ ArgumentList(SeparatedList([originalArgument, resultArgument])))));
}
+ }
+ internal sealed class BaseCopierFieldDescription(TypeSyntax fieldType, bool concreteType = false) : GeneratedFieldDescription(fieldType, BaseTypeCopierFieldName)
+ {
+ public override bool IsInjected { get; } = !concreteType;
}
+
+ internal sealed class CopierFieldDescription(TypeSyntax fieldType, string fieldName, ITypeSymbol underlyingType) : GeneratedFieldDescription(fieldType, fieldName), ICopierDescription
+ {
+ public ITypeSymbol UnderlyingType { get; } = underlyingType;
+ public override bool IsInjected => false;
+ }
+
}
diff --git a/src/Orleans.CodeGenerator/Diagnostics/CanNotGenerateImplicitFieldIdsDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/CanNotGenerateImplicitFieldIdsDiagnostic.cs
index 377fdf8cd5a..621faa3b215 100644
--- a/src/Orleans.CodeGenerator/Diagnostics/CanNotGenerateImplicitFieldIdsDiagnostic.cs
+++ b/src/Orleans.CodeGenerator/Diagnostics/CanNotGenerateImplicitFieldIdsDiagnostic.cs
@@ -1,4 +1,3 @@
-using System.Linq;
using Microsoft.CodeAnalysis;
namespace Orleans.CodeGenerator.Diagnostics;
@@ -7,10 +6,10 @@ public static class CanNotGenerateImplicitFieldIdsDiagnostic
{
public const string DiagnosticId = DiagnosticRuleId.CanNotGenerateImplicitFieldIds;
public const string Title = "Implicit field identifiers could not be generated";
- public const string MessageFormat = "Could not generate implicit field identifiers for the type {0}: {reason}";
+ public const string MessageFormat = "Could not generate implicit field identifiers for the type {0}: {1}";
public const string Category = "Usage";
private static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(DiagnosticId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true);
- internal static Diagnostic CreateDiagnostic(ISymbol symbol, string reason) => Diagnostic.Create(Rule, symbol.Locations.First(), symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), reason);
+ internal static Diagnostic CreateDiagnostic(ISymbol symbol, string reason, Location? location = null) => Diagnostic.Create(Rule, location ?? symbol.Locations.First(), symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), reason);
}
diff --git a/src/Orleans.CodeGenerator/Diagnostics/GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly_Diagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly_Diagnostic.cs
index bb0007121e9..36aea5dba92 100644
--- a/src/Orleans.CodeGenerator/Diagnostics/GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly_Diagnostic.cs
+++ b/src/Orleans.CodeGenerator/Diagnostics/GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly_Diagnostic.cs
@@ -2,7 +2,6 @@
namespace Orleans.CodeGenerator.Diagnostics;
-#nullable disable
public static class GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly_Diagnostic
{
public const string DiagnosticId = DiagnosticRuleId.GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly;
@@ -12,5 +11,5 @@ public static class GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembl
private static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(DiagnosticId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true);
- internal static Diagnostic CreateDiagnostic(AttributeData attribute, ITypeSymbol type) => Diagnostic.Create(Rule, attribute.ApplicationSyntaxReference.SyntaxTree.GetLocation(attribute.ApplicationSyntaxReference.Span), type.ToDisplayString(), attribute.ToString());
+ internal static Diagnostic CreateDiagnostic(AttributeData attribute, ITypeSymbol type) => Diagnostic.Create(Rule, attribute.ApplicationSyntaxReference!.SyntaxTree.GetLocation(attribute.ApplicationSyntaxReference.Span), type.ToDisplayString(), attribute.ToString());
}
diff --git a/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSerializableTypeDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSerializableTypeDiagnostic.cs
index 97c77d9e35d..902348bc2e8 100644
--- a/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSerializableTypeDiagnostic.cs
+++ b/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSerializableTypeDiagnostic.cs
@@ -1,4 +1,3 @@
-using System.Linq;
using Microsoft.CodeAnalysis;
namespace Orleans.CodeGenerator.Diagnostics;
@@ -13,5 +12,5 @@ public static class InaccessibleSerializableTypeDiagnostic
private static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(RuleId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true, description: Descsription);
- internal static Diagnostic CreateDiagnostic(ISymbol symbol) => Diagnostic.Create(Rule, symbol.Locations.First(), symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));
+ internal static Diagnostic CreateDiagnostic(ISymbol symbol, Location? location = null) => Diagnostic.Create(Rule, location ?? symbol.Locations.First(), symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));
}
diff --git a/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSetterDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSetterDiagnostic.cs
index 214a40f2e04..d6d68eea209 100644
--- a/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSetterDiagnostic.cs
+++ b/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSetterDiagnostic.cs
@@ -12,5 +12,5 @@ public static class InaccessibleSetterDiagnostic
internal static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(RuleId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true, description: Description);
- public static Diagnostic CreateDiagnostic(Location location, string identifier) => Diagnostic.Create(Rule, location, identifier);
+ public static Diagnostic CreateDiagnostic(Location? location, string identifier) => Diagnostic.Create(Rule, location, identifier);
}
diff --git a/src/Orleans.CodeGenerator/Diagnostics/InvalidRpcMethodReturnTypeDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/InvalidRpcMethodReturnTypeDiagnostic.cs
index 1937b0ef2ab..3b260282105 100644
--- a/src/Orleans.CodeGenerator/Diagnostics/InvalidRpcMethodReturnTypeDiagnostic.cs
+++ b/src/Orleans.CodeGenerator/Diagnostics/InvalidRpcMethodReturnTypeDiagnostic.cs
@@ -1,9 +1,7 @@
-using System.Linq;
using Microsoft.CodeAnalysis;
namespace Orleans.CodeGenerator.Diagnostics;
-#nullable disable
public static class InvalidRpcMethodReturnTypeDiagnostic
{
public const string RuleId = DiagnosticRuleId.InvalidRpcMethodReturnType;
@@ -14,7 +12,7 @@ public static class InvalidRpcMethodReturnTypeDiagnostic
internal static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(RuleId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true, description: Description);
- public static Diagnostic CreateDiagnostic(Location location, string returnType, string methodIdentifier, string supportedReturnTypeList) => Diagnostic.Create(Rule, location, returnType, methodIdentifier, supportedReturnTypeList);
+ public static Diagnostic CreateDiagnostic(Location? location, string returnType, string methodIdentifier, string supportedReturnTypeList) => Diagnostic.Create(Rule, location, returnType, methodIdentifier, supportedReturnTypeList);
internal static Diagnostic CreateDiagnostic(InvokableMethodDescription methodDescription)
{
diff --git a/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs
index 0a18a7f18ce..f6acde06951 100644
--- a/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs
+++ b/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs
@@ -1,4 +1,3 @@
-using System.Linq;
using Microsoft.CodeAnalysis;
namespace Orleans.CodeGenerator.Diagnostics;
diff --git a/src/Orleans.CodeGenerator/Diagnostics/ReferenceAssemblyWithGenerateSerializerDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/ReferenceAssemblyWithGenerateSerializerDiagnostic.cs
index cc8a5d878ef..0d18b2a3c23 100644
--- a/src/Orleans.CodeGenerator/Diagnostics/ReferenceAssemblyWithGenerateSerializerDiagnostic.cs
+++ b/src/Orleans.CodeGenerator/Diagnostics/ReferenceAssemblyWithGenerateSerializerDiagnostic.cs
@@ -1,4 +1,3 @@
-using System.Linq;
using Microsoft.CodeAnalysis;
namespace Orleans.CodeGenerator.Diagnostics;
@@ -13,5 +12,5 @@ public static class ReferenceAssemblyWithGenerateSerializerDiagnostic
private static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(DiagnosticId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true, description: Description);
- internal static Diagnostic CreateDiagnostic(ISymbol symbol) => Diagnostic.Create(Rule, symbol.Locations.First(), symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));
+ internal static Diagnostic CreateDiagnostic(ISymbol symbol, Location? location = null) => Diagnostic.Create(Rule, location ?? symbol.Locations.First(), symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));
}
diff --git a/src/Orleans.CodeGenerator/Diagnostics/RpcInterfacePropertyDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/RpcInterfacePropertyDiagnostic.cs
index 0cd35584d0e..7d083e116a4 100644
--- a/src/Orleans.CodeGenerator/Diagnostics/RpcInterfacePropertyDiagnostic.cs
+++ b/src/Orleans.CodeGenerator/Diagnostics/RpcInterfacePropertyDiagnostic.cs
@@ -1,4 +1,3 @@
-using System.Linq;
using Microsoft.CodeAnalysis;
namespace Orleans.CodeGenerator.Diagnostics;
diff --git a/src/Orleans.CodeGenerator/Diagnostics/UnhandledCodeGenerationExceptionDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/UnhandledCodeGenerationExceptionDiagnostic.cs
index c8ea4fa0228..bfb4da6f531 100644
--- a/src/Orleans.CodeGenerator/Diagnostics/UnhandledCodeGenerationExceptionDiagnostic.cs
+++ b/src/Orleans.CodeGenerator/Diagnostics/UnhandledCodeGenerationExceptionDiagnostic.cs
@@ -1,4 +1,3 @@
-using System;
using Microsoft.CodeAnalysis;
namespace Orleans.CodeGenerator.Diagnostics;
@@ -13,5 +12,5 @@ public static class UnhandledCodeGenerationExceptionDiagnostic
internal static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(RuleId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true, description: Description);
- internal static Diagnostic CreateDiagnostic(Exception exception) => Diagnostic.Create(Rule, location: null, messageArgs: new[] { exception.ToString(), exception.StackTrace });
+ internal static Diagnostic CreateDiagnostic(Exception exception) => Diagnostic.Create(Rule, location: null, messageArgs: [exception.ToString(), exception.StackTrace]);
}
\ No newline at end of file
diff --git a/src/Orleans.CodeGenerator/FieldIdAssignmentHelper.cs b/src/Orleans.CodeGenerator/FieldIdAssignmentHelper.cs
index dc63ee718b8..75a1769378b 100644
--- a/src/Orleans.CodeGenerator/FieldIdAssignmentHelper.cs
+++ b/src/Orleans.CodeGenerator/FieldIdAssignmentHelper.cs
@@ -1,8 +1,5 @@
-using System;
using System.Buffers.Binary;
-using System.Collections.Generic;
using System.Collections.Immutable;
-using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.CodeAnalysis;
@@ -14,6 +11,7 @@ namespace Orleans.CodeGenerator;
internal class FieldIdAssignmentHelper
{
+ private const string NoAssignedFieldIdsFailureReason = "no field ids were assigned to any candidate serializable members";
private readonly GenerateFieldIds _implicitMemberSelectionStrategy;
private readonly ImmutableArray _constructorParameters;
private readonly LibraryTypes _libraryTypes;
@@ -30,15 +28,46 @@ public FieldIdAssignmentHelper(INamedTypeSymbol typeSymbol, ImmutableArray _symbols.TryGetValue(symbol, out key);
private bool HasMemberWithIdAnnotation() => Array.Exists(_memberSymbols, member => member.HasAttribute(_libraryTypes.IdAttributeType));
+ private bool HasCandidateSerializableMembers()
+ {
+ foreach (var member in _memberSymbols)
+ {
+ if (member is IFieldSymbol)
+ {
+ return true;
+ }
+
+ if (member is IPropertySymbol property
+ && (_implicitMemberSelectionStrategy != GenerateFieldIds.None
+ || property.HasAttribute(_libraryTypes.IdAttributeType)
+ || PropertyUtility.GetMatchingPrimaryConstructorParameter(property, _constructorParameters) is not null
+ || PropertyUtility.GetMatchingField(property, _memberSymbols) is not null))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
private IEnumerable GetMembers(INamedTypeSymbol symbol)
{
foreach (var member in symbol.GetMembers().OrderBy(m => m.MetadataName))
@@ -68,48 +97,54 @@ private bool ExtractFieldIdAnnotations()
{
if (member is IPropertySymbol prop)
{
- var id = CodeGenerator.GetId(_libraryTypes, prop);
+ var constructorParameter = PropertyUtility.GetMatchingPrimaryConstructorParameter(prop, _constructorParameters);
+ var id = GeneratedCodeUtilities.GetId(_libraryTypes, prop);
if (id.HasValue)
{
_symbols[member] = (id.Value, false);
}
- else if (PropertyUtility.GetMatchingPrimaryConstructorParameter(prop, _constructorParameters) is { } prm)
+ else if (constructorParameter is not null)
{
- id = CodeGenerator.GetId(_libraryTypes, prop);
+ var matchingField = PropertyUtility.GetMatchingField(prop, _memberSymbols);
+ if (matchingField is not null && GeneratedCodeUtilities.GetId(_libraryTypes, matchingField).HasValue)
+ {
+ continue;
+ }
+
+ id = GeneratedCodeUtilities.GetId(_libraryTypes, constructorParameter);
if (id.HasValue)
{
_symbols[member] = (id.Value, true);
}
else
{
- _symbols[member] = ((uint)_constructorParameters.IndexOf(prm), true);
+ _symbols[member] = ((uint)_constructorParameters.IndexOf(constructorParameter), true);
}
}
}
if (member is IFieldSymbol field)
{
- var id = CodeGenerator.GetId(_libraryTypes, field);
+ var id = GeneratedCodeUtilities.GetId(_libraryTypes, field);
var isConstructorParameter = false;
+ IPropertySymbol? property = null;
+ IParameterSymbol? constructorParameter = null;
- if (!id.HasValue)
+ property = PropertyUtility.GetMatchingProperty(field, _memberSymbols);
+ if (property is not null)
{
- var property = PropertyUtility.GetMatchingProperty(field, _memberSymbols);
- if (property is null)
- {
- continue;
- }
+ constructorParameter = PropertyUtility.GetMatchingPrimaryConstructorParameter(property, _constructorParameters);
+ }
- id = CodeGenerator.GetId(_libraryTypes, property);
- if (!id.HasValue)
+ if (!id.HasValue && property is not null)
+ {
+ id = GeneratedCodeUtilities.GetId(_libraryTypes, property);
+ if (!id.HasValue && constructorParameter is not null)
{
- var constructorParameter = _constructorParameters.FirstOrDefault(x => x.Name.Equals(property.Name, StringComparison.OrdinalIgnoreCase));
- if (constructorParameter is not null)
- {
- id = (uint)_constructorParameters.IndexOf(constructorParameter);
- isConstructorParameter = true;
- }
+ id = GeneratedCodeUtilities.GetId(_libraryTypes, constructorParameter)
+ ?? (uint)_constructorParameters.IndexOf(constructorParameter);
+ isConstructorParameter = true;
}
}
diff --git a/src/Orleans.CodeGenerator/GeneratedCodeUtilities.cs b/src/Orleans.CodeGenerator/GeneratedCodeUtilities.cs
new file mode 100644
index 00000000000..a7660c72e5b
--- /dev/null
+++ b/src/Orleans.CodeGenerator/GeneratedCodeUtilities.cs
@@ -0,0 +1,107 @@
+using System.Text;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Orleans.CodeGenerator.Hashing;
+using Orleans.CodeGenerator.SyntaxGeneration;
+using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+using static Orleans.CodeGenerator.SyntaxGeneration.SymbolExtensions;
+
+namespace Orleans.CodeGenerator;
+
+internal static class GeneratedCodeUtilities
+{
+ internal const string CodeGeneratorName = "OrleansCodeGen";
+
+ internal static string GetGeneratedNamespaceName(ITypeSymbol type) => type.GetNamespaceAndNesting() switch
+ {
+ { Length: > 0 } ns => $"{CodeGeneratorName}.{ns}",
+ _ => CodeGeneratorName
+ };
+
+ internal static uint? GetId(LibraryTypes libraryTypes, ISymbol memberSymbol)
+ {
+ return memberSymbol.GetAttribute(libraryTypes.IdAttributeType) is { } attr
+ ? (uint)attr.ConstructorArguments.First().Value!
+ : null;
+ }
+
+ internal static string CreateHashedMethodId(IMethodSymbol methodSymbol)
+ {
+ var methodSignature = Format(methodSymbol);
+ var hash = XxHash32.Hash(Encoding.UTF8.GetBytes(methodSignature));
+ return $"{HexConverter.ToString(hash)}";
+
+ static string Format(IMethodSymbol methodInfo)
+ {
+ var result = new StringBuilder();
+ result.Append(methodInfo.ContainingType.ToDisplayName());
+ result.Append('.');
+ result.Append(methodInfo.Name);
+
+ if (methodInfo.IsGenericMethod)
+ {
+ result.Append('<');
+ var first = true;
+ foreach (var typeArgument in methodInfo.TypeArguments)
+ {
+ if (!first) result.Append(',');
+ else first = false;
+ result.Append(typeArgument.Name);
+ }
+
+ result.Append('>');
+ }
+
+ {
+ result.Append('(');
+ var parameters = methodInfo.Parameters;
+ var first = true;
+ foreach (var parameter in parameters)
+ {
+ if (!first)
+ {
+ result.Append(',');
+ }
+
+ var parameterType = parameter.Type;
+ switch (parameterType)
+ {
+ case ITypeParameterSymbol _:
+ result.Append(parameterType.Name);
+ break;
+ default:
+ result.Append(parameterType.ToDisplayName());
+ break;
+ }
+
+ first = false;
+ }
+ }
+
+ result.Append(')');
+ return result.ToString();
+ }
+ }
+
+ internal static string? GetAlias(LibraryTypes libraryTypes, ISymbol symbol) => (string?)symbol.GetAttribute(libraryTypes.AliasAttribute)?.ConstructorArguments.First().Value;
+
+ internal static AttributeListSyntax GetGeneratedCodeAttributes() => GeneratedCodeAttributeSyntax;
+
+ private static readonly AttributeListSyntax GeneratedCodeAttributeSyntax =
+ AttributeList().AddAttributes(
+ Attribute(ParseName("global::System.CodeDom.Compiler.GeneratedCodeAttribute"))
+ .AddArgumentListArguments(
+ AttributeArgument(CodeGeneratorName.GetLiteralExpression()),
+ AttributeArgument(typeof(GeneratedCodeUtilities).Assembly.GetName().Version.ToString().GetLiteralExpression())),
+ Attribute(ParseName("global::System.ComponentModel.EditorBrowsableAttribute"))
+ .AddArgumentListArguments(
+ AttributeArgument(ParseName("global::System.ComponentModel.EditorBrowsableState").Member("Never"))),
+ Attribute(ParseName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute"))
+ );
+
+ internal static AttributeSyntax GetMethodImplAttributeSyntax() => MethodImplAttributeSyntax;
+
+ private static readonly AttributeSyntax MethodImplAttributeSyntax =
+ Attribute(ParseName("global::System.Runtime.CompilerServices.MethodImplAttribute"))
+ .AddArgumentListArguments(AttributeArgument(ParseName("global::System.Runtime.CompilerServices.MethodImplOptions").Member("AggressiveInlining")));
+}
diff --git a/src/Orleans.CodeGenerator/GeneratedSourceOutput.cs b/src/Orleans.CodeGenerator/GeneratedSourceOutput.cs
new file mode 100644
index 00000000000..4fa4c1d269f
--- /dev/null
+++ b/src/Orleans.CodeGenerator/GeneratedSourceOutput.cs
@@ -0,0 +1,383 @@
+using System.Collections.Immutable;
+using System.Globalization;
+using System.Text;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Orleans.CodeGenerator.Hashing;
+using Orleans.CodeGenerator.Model;
+
+namespace Orleans.CodeGenerator;
+
+internal static class GeneratedSourceOutput
+{
+ private const string GeneratedCodeWarningDisable = "#pragma warning disable CS1591, RS0016, RS0041";
+ private const string GeneratedCodeWarningRestore = "#pragma warning restore CS1591, RS0016, RS0041";
+
+ internal static void EmitSourceOutputResult(SourceProductionContext context, SourceOutputResult result)
+ {
+ if (result.Diagnostic is { } diagnostic)
+ {
+ context.ReportDiagnostic(diagnostic);
+ return;
+ }
+
+ if (result.SourceEntry is { } sourceEntry)
+ {
+ context.AddSource(sourceEntry.HintName, sourceEntry.SourceText);
+ }
+ }
+
+ internal static ImmutableArray DeduplicateSerializableTypeResults(
+ ImmutableArray results)
+ {
+ if (results.IsDefaultOrEmpty)
+ {
+ return [];
+ }
+
+ var models = new Dictionary(StringComparer.Ordinal);
+ var diagnostics = new Dictionary(StringComparer.Ordinal);
+ foreach (var result in OrderSerializableTypeResultsForCanonicalSelection(results))
+ {
+ if (result.Model is not null)
+ {
+ var key = CreateSerializableTypeDedupeKey(result);
+ if (!models.ContainsKey(key))
+ {
+ models.Add(key, result);
+ }
+ }
+ else if (result.Diagnostic is { } diagnostic)
+ {
+ var key = $"{CreateSerializableTypeDedupeKey(result)}|{diagnostic.Id}";
+ if (!diagnostics.ContainsKey(key))
+ {
+ diagnostics.Add(key, result);
+ }
+ }
+ }
+
+ return [.. OrderSerializableTypeResultsForEmission(models.Values.Concat(diagnostics.Values))];
+ }
+
+ internal static ImmutableArray GetSerializableTypeModels(
+ ImmutableArray results)
+ {
+ if (results.IsDefaultOrEmpty)
+ {
+ return [];
+ }
+
+ var builder = ImmutableArray.CreateBuilder();
+ foreach (var result in results)
+ {
+ if (result.Model is { } model)
+ {
+ builder.Add(model);
+ }
+ }
+
+ return ModelExtractor.DeduplicateSerializableTypes(builder.ToImmutable());
+ }
+
+ internal static IOrderedEnumerable OrderSerializableTypeResultsForCanonicalSelection(
+ IEnumerable results)
+ => results
+ .Where(static result => result.Model is not null || result.Diagnostic is not null)
+ .OrderBy(static result => result.SourceLocation.SourceOrderGroup)
+ .ThenBy(static result => result.SourceLocation.FilePath, StringComparer.Ordinal)
+ .ThenBy(static result => result.SourceLocation.Position)
+ .ThenBy(static result => result.MetadataIdentity.MetadataName, StringComparer.Ordinal)
+ .ThenBy(static result => result.MetadataIdentity.AssemblyIdentity, StringComparer.Ordinal)
+ .ThenBy(static result => result.MetadataIdentity.AssemblyName, StringComparer.Ordinal)
+ .ThenBy(static result => result.TypeSyntax, StringComparer.Ordinal)
+ .ThenBy(static result => result.Diagnostic?.Id ?? string.Empty, StringComparer.Ordinal);
+
+ internal static IOrderedEnumerable OrderSerializableTypeResultsForEmission(
+ IEnumerable results)
+ => results
+ .OrderBy(static result => result.Model is null ? 1 : 0)
+ .ThenBy(static result => result.MetadataIdentity.MetadataName, StringComparer.Ordinal)
+ .ThenBy(static result => result.MetadataIdentity.AssemblyIdentity, StringComparer.Ordinal)
+ .ThenBy(static result => result.MetadataIdentity.AssemblyName, StringComparer.Ordinal)
+ .ThenBy(static result => result.TypeSyntax, StringComparer.Ordinal)
+ .ThenBy(static result => result.Diagnostic?.Id ?? string.Empty, StringComparer.Ordinal)
+ .ThenBy(static result => result.SourceLocation.SourceOrderGroup)
+ .ThenBy(static result => result.SourceLocation.FilePath, StringComparer.Ordinal)
+ .ThenBy(static result => result.SourceLocation.Position);
+
+ internal static string CreateSerializableTypeDedupeKey(SerializableTypeResult result)
+ => CreateTypeDedupeKey(result.MetadataIdentity, result.TypeSyntax);
+
+ internal static string CreateTypeDedupeKey(TypeMetadataIdentity metadataIdentity, string typeSyntax)
+ {
+ if (!metadataIdentity.IsEmpty)
+ {
+ return string.Join(
+ "|",
+ "M",
+ metadataIdentity.AssemblyIdentity ?? string.Empty,
+ metadataIdentity.AssemblyName ?? string.Empty,
+ metadataIdentity.MetadataName ?? string.Empty);
+ }
+
+ return string.Join("|", "S", typeSyntax ?? string.Empty);
+ }
+
+ internal static ImmutableArray DeduplicateSourceOutputs(
+ ImmutableArray.Builder sourceEntries)
+ => DeduplicateSourceOutputs(sourceEntries.ToImmutable());
+
+ internal static ImmutableArray DeduplicateSourceOutputs(
+ ImmutableArray sourceEntries)
+ {
+ var emittedSourcesByOriginalHintName = new Dictionary>(StringComparer.Ordinal);
+ var emittedSourceByHintName = new Dictionary(StringComparer.Ordinal);
+ var result = ImmutableArray.CreateBuilder();
+ foreach (var sourceOutput in sourceEntries)
+ {
+ if (sourceOutput.SourceEntry is not { } entry)
+ {
+ result.Add(sourceOutput);
+ continue;
+ }
+
+ var source = entry.Source ?? string.Empty;
+ if (!emittedSourcesByOriginalHintName.TryGetValue(entry.HintName, out var emittedSources))
+ {
+ emittedSources = new Dictionary(StringComparer.Ordinal);
+ emittedSourcesByOriginalHintName.Add(entry.HintName, emittedSources);
+ }
+
+ if (emittedSources.ContainsKey(source))
+ {
+ continue;
+ }
+
+ if (!emittedSourceByHintName.TryGetValue(entry.HintName, out var emittedSource))
+ {
+ emittedSources.Add(source, entry.HintName);
+ emittedSourceByHintName.Add(entry.HintName, source);
+ result.Add(sourceOutput);
+ continue;
+ }
+
+ if (string.Equals(emittedSource, source, StringComparison.Ordinal))
+ {
+ emittedSources.Add(source, entry.HintName);
+ continue;
+ }
+
+ var uniqueHintName = CreateDistinctSourceHintName(entry.HintName, source, emittedSourceByHintName);
+ emittedSources.Add(source, uniqueHintName);
+ emittedSourceByHintName.Add(uniqueHintName, source);
+ result.Add(SourceOutputResult.FromSource(new GeneratedSourceEntry(uniqueHintName, source)));
+ }
+
+ return NormalizeSourceOutputs(result.ToImmutable());
+ }
+
+ internal static ImmutableArray NormalizeSourceOutputs(ImmutableArray sourceOutputs)
+ => StructuralEquality.Normalize(sourceOutputs);
+
+ internal static GeneratedSourceEntry CreateSerializableSourceEntry(
+ string assemblyName,
+ string typeName,
+ TypeMetadataIdentity metadataIdentity,
+ string hintGeneratedNamespace,
+ int genericArity,
+ ClassDeclarationSyntax serializer,
+ ClassDeclarationSyntax? copier,
+ ClassDeclarationSyntax? activator,
+ string generatedNamespace)
+ {
+ var namespacedMembers = new Dictionary>(StringComparer.Ordinal);
+ AddMember(namespacedMembers, generatedNamespace, serializer);
+ if (copier is not null)
+ {
+ AddMember(namespacedMembers, generatedNamespace, copier);
+ }
+
+ if (activator is not null)
+ {
+ AddMember(namespacedMembers, generatedNamespace, activator);
+ }
+
+ return new GeneratedSourceEntry(
+ CreateSerializableHintName(assemblyName, typeName, metadataIdentity, hintGeneratedNamespace, genericArity),
+ CreateSourceString(CreateCompilationUnit(namespacedMembers)));
+ }
+
+ internal static void AddMember(
+ Dictionary> namespacedMembers,
+ string ns,
+ MemberDeclarationSyntax member)
+ {
+ var namespaceName = ns ?? string.Empty;
+ if (!namespacedMembers.TryGetValue(namespaceName, out var members))
+ {
+ members = [];
+ namespacedMembers[namespaceName] = members;
+ }
+
+ members.Add(member);
+ }
+
+ internal static string CreateSourceString(CompilationUnitSyntax unit)
+ {
+ return $"{GeneratedCodeWarningDisable}\r\n{unit.NormalizeWhitespace().ToFullString()}\r\n{GeneratedCodeWarningRestore}";
+ }
+
+ internal static CompilationUnitSyntax CreateCompilationUnit(
+ Dictionary> namespacedMembers,
+ SyntaxList attributeLists = default)
+ {
+ var unit = SyntaxFactory.CompilationUnit().WithAttributeLists(attributeLists);
+ var usingDirectives = SyntaxFactory.List(
+ [
+ SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("global::Orleans.Serialization.Codecs")),
+ SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("global::Orleans.Serialization.GeneratedCodeHelpers")),
+ ]);
+ var members = new List(namespacedMembers.Count);
+
+ foreach (var pair in namespacedMembers.OrderBy(static pair => pair.Key, StringComparer.Ordinal))
+ {
+ if (string.IsNullOrWhiteSpace(pair.Key))
+ {
+ members.AddRange(pair.Value);
+ continue;
+ }
+
+ members.Add(
+ SyntaxFactory.NamespaceDeclaration(SyntaxFactory.ParseName(pair.Key))
+ .WithUsings(usingDirectives)
+ .WithMembers(SyntaxFactory.List(pair.Value)));
+ }
+
+ return unit.WithMembers(SyntaxFactory.List(members));
+ }
+
+ internal static string CreateSerializableHintName(
+ string assemblyName,
+ string typeName,
+ TypeMetadataIdentity metadataIdentity,
+ string generatedNamespace,
+ int genericArity)
+ {
+ var hash = CreateHintNameHash(metadataIdentity, generatedNamespace, typeName, genericArity);
+
+ return $"{assemblyName}.orleans.ser.{SanitizeHintComponent(typeName)}.{hash}.g.cs";
+ }
+
+ internal static string CreateProxyHintName(string assemblyName, ProxyInterfaceDescription interfaceDescription)
+ {
+ var interfaceName = interfaceDescription.InterfaceType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
+ var hash = CreateHintNameHash(
+ TypeMetadataIdentity.Create(interfaceDescription.InterfaceType),
+ interfaceDescription.GeneratedNamespace,
+ interfaceName,
+ interfaceDescription.TypeParameters.Count);
+
+ return $"{assemblyName}.orleans.proxy.{SanitizeHintComponent(interfaceName)}.{hash}.g.cs";
+ }
+
+ internal static string CreateMetadataHintName(string assemblyName)
+ => $"{assemblyName}.orleans.metadata.g.cs";
+
+ internal static string CreateHintNameHash(
+ TypeMetadataIdentity metadataIdentity,
+ string generatedNamespace,
+ string syntaxString,
+ int genericArity)
+ {
+ var builder = new StringBuilder();
+ AppendHashComponent(builder, metadataIdentity.AssemblyIdentity);
+ AppendHashComponent(builder, metadataIdentity.AssemblyName);
+ AppendHashComponent(builder, metadataIdentity.MetadataName);
+ AppendHashComponent(builder, generatedNamespace);
+ AppendHashComponent(builder, syntaxString);
+ AppendHashComponent(builder, genericArity.ToString(CultureInfo.InvariantCulture));
+
+ return CreateStableHash(builder.ToString());
+ }
+
+ internal static string CreateStableHash(string value)
+ => HexConverter.ToString(XxHash32.Hash(Encoding.UTF8.GetBytes(value ?? string.Empty)));
+
+ internal static void AppendHashComponent(StringBuilder builder, string value)
+ {
+ builder.Append(value?.Length ?? 0);
+ builder.Append(':');
+ builder.Append(value ?? string.Empty);
+ builder.Append('|');
+ }
+
+ internal static string CreateDistinctSourceHintName(
+ string hintName,
+ string source,
+ Dictionary emittedSourceByHintName)
+ {
+ var sourceHash = CreateStableHash(source);
+ var candidate = InsertHintNameComponent(hintName, $"collision.{sourceHash}");
+ if (!emittedSourceByHintName.ContainsKey(candidate))
+ {
+ return candidate;
+ }
+
+ for (var index = 1; ; index++)
+ {
+ candidate = InsertHintNameComponent(hintName, $"collision.{sourceHash}.{index}");
+ if (!emittedSourceByHintName.ContainsKey(candidate))
+ {
+ return candidate;
+ }
+ }
+ }
+
+ internal static string InsertHintNameComponent(string hintName, string component)
+ {
+ const string GeneratedSourceSuffix = ".g.cs";
+ if (hintName.EndsWith(GeneratedSourceSuffix, StringComparison.Ordinal))
+ {
+ return $"{hintName.Substring(0, hintName.Length - GeneratedSourceSuffix.Length)}.{component}{GeneratedSourceSuffix}";
+ }
+
+ const string SourceSuffix = ".cs";
+ if (hintName.EndsWith(SourceSuffix, StringComparison.Ordinal))
+ {
+ return $"{hintName.Substring(0, hintName.Length - SourceSuffix.Length)}.{component}{SourceSuffix}";
+ }
+
+ return $"{hintName}.{component}";
+ }
+
+ internal static string SanitizeHintComponent(string value)
+ {
+ if (string.IsNullOrWhiteSpace(value))
+ {
+ return "generated";
+ }
+
+ var builder = new StringBuilder(value.Length);
+ var previousCharacterWasUnderscore = false;
+ foreach (var character in value)
+ {
+ if (char.IsLetterOrDigit(character) || character is '_' or '.')
+ {
+ builder.Append(character);
+ previousCharacterWasUnderscore = false;
+ }
+ else if (!previousCharacterWasUnderscore)
+ {
+ builder.Append('_');
+ previousCharacterWasUnderscore = true;
+ }
+ }
+
+ var result = builder.ToString().Trim('_', '.');
+ return result.Length > 0 ? result : "generated";
+ }
+}
+
+
diff --git a/src/Orleans.CodeGenerator/GeneratorServices.cs b/src/Orleans.CodeGenerator/GeneratorServices.cs
new file mode 100644
index 00000000000..7935ff28b9a
--- /dev/null
+++ b/src/Orleans.CodeGenerator/GeneratorServices.cs
@@ -0,0 +1,22 @@
+using Microsoft.CodeAnalysis;
+
+namespace Orleans.CodeGenerator;
+
+internal interface IGeneratorServices
+{
+ Compilation Compilation { get; }
+ CodeGeneratorOptions Options { get; }
+ LibraryTypes LibraryTypes { get; }
+}
+
+internal sealed class GeneratorServices(Compilation compilation, CodeGeneratorOptions options, LibraryTypes libraryTypes) : IGeneratorServices
+{
+ public GeneratorServices(Compilation compilation, CodeGeneratorOptions options)
+ : this(compilation, options, LibraryTypes.FromCompilation(compilation, options))
+ {
+ }
+
+ public Compilation Compilation { get; } = compilation ?? throw new ArgumentNullException(nameof(compilation));
+ public CodeGeneratorOptions Options { get; } = options ?? throw new ArgumentNullException(nameof(options));
+ public LibraryTypes LibraryTypes { get; } = libraryTypes ?? throw new ArgumentNullException(nameof(libraryTypes));
+}
diff --git a/src/Orleans.CodeGenerator/Hashing/BitOperations.cs b/src/Orleans.CodeGenerator/Hashing/BitOperations.cs
index e890109944f..bf3aa3e7181 100644
--- a/src/Orleans.CodeGenerator/Hashing/BitOperations.cs
+++ b/src/Orleans.CodeGenerator/Hashing/BitOperations.cs
@@ -3,19 +3,18 @@
using System.Runtime.CompilerServices;
-namespace Orleans.CodeGenerator.Hashing
+namespace Orleans.CodeGenerator.Hashing;
+
+internal static class BitOperations
{
- internal static class BitOperations
- {
- ///
- /// Rotates the specified value left by the specified number of bits.
- /// Similar in behavior to the x86 instruction ROL.
- ///
- /// The value to rotate.
- /// The number of bits to rotate by.
- /// Any value outside the range [0..31] is treated as congruent mod 32.
- /// The rotated value.
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static uint RotateLeft(uint value, int offset) => (value << offset) | (value >> (32 - offset));
- }
+ ///
+ /// Rotates the specified value left by the specified number of bits.
+ /// Similar in behavior to the x86 instruction ROL.
+ ///
+ /// The value to rotate.
+ /// The number of bits to rotate by.
+ /// Any value outside the range [0..31] is treated as congruent mod 32.
+ /// The rotated value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static uint RotateLeft(uint value, int offset) => (value << offset) | (value >> (32 - offset));
}
\ No newline at end of file
diff --git a/src/Orleans.CodeGenerator/Hashing/HexConverter.cs b/src/Orleans.CodeGenerator/Hashing/HexConverter.cs
index fbade005f0e..169f24dcd48 100644
--- a/src/Orleans.CodeGenerator/Hashing/HexConverter.cs
+++ b/src/Orleans.CodeGenerator/Hashing/HexConverter.cs
@@ -1,39 +1,37 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
-using System;
using System.Runtime.CompilerServices;
-namespace Orleans.CodeGenerator.Hashing
+namespace Orleans.CodeGenerator.Hashing;
+
+internal static class HexConverter
{
- internal static class HexConverter
+ public static unsafe string ToString(ReadOnlySpan bytes)
{
- public static unsafe string ToString(ReadOnlySpan bytes)
- {
- // Adapted from: https://github.com/dotnet/runtime/blob/f156fb9dcf121e536b93ae90bcc5e8e6d5336062/src/libraries/Common/src/System/HexConverter.cs#L196
-
- Span result = bytes.Length > 16 ?
- new char[bytes.Length * 2].AsSpan() :
- stackalloc char[bytes.Length * 2];
+ // Adapted from: https://github.com/dotnet/runtime/blob/f156fb9dcf121e536b93ae90bcc5e8e6d5336062/src/libraries/Common/src/System/HexConverter.cs#L196
+
+ Span result = bytes.Length > 16 ?
+ new char[bytes.Length * 2].AsSpan() :
+ stackalloc char[bytes.Length * 2];
- int pos = 0;
- foreach (byte b in bytes)
- {
- ToCharsBuffer(b, result, pos);
- pos += 2;
- }
+ int pos = 0;
+ foreach (byte b in bytes)
+ {
+ ToCharsBuffer(b, result, pos);
+ pos += 2;
+ }
- return result.ToString();
+ return result.ToString();
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- static void ToCharsBuffer(byte value, Span buffer, int startingIndex = 0)
- {
- var difference = ((value & 0xF0U) << 4) + (value & 0x0FU) - 0x8989U;
- var packedResult = (((uint)-(int)difference & 0x7070U) >> 4) + difference + 0xB9B9U;
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ static void ToCharsBuffer(byte value, Span buffer, int startingIndex = 0)
+ {
+ var difference = ((value & 0xF0U) << 4) + (value & 0x0FU) - 0x8989U;
+ var packedResult = (((uint)-(int)difference & 0x7070U) >> 4) + difference + 0xB9B9U;
- buffer[startingIndex + 1] = (char)(packedResult & 0xFF);
- buffer[startingIndex] = (char)(packedResult >> 8);
- }
+ buffer[startingIndex + 1] = (char)(packedResult & 0xFF);
+ buffer[startingIndex] = (char)(packedResult >> 8);
}
}
}
\ No newline at end of file
diff --git a/src/Orleans.CodeGenerator/Hashing/NonCryptographicHashAlgorithm.cs b/src/Orleans.CodeGenerator/Hashing/NonCryptographicHashAlgorithm.cs
index 34e1f18a3a1..6802ab670e3 100644
--- a/src/Orleans.CodeGenerator/Hashing/NonCryptographicHashAlgorithm.cs
+++ b/src/Orleans.CodeGenerator/Hashing/NonCryptographicHashAlgorithm.cs
@@ -1,347 +1,342 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
-using System;
using System.Buffers;
using System.ComponentModel;
using System.Diagnostics;
-using System.IO;
-using System.Threading;
-using System.Threading.Tasks;
-namespace Orleans.CodeGenerator.Hashing
+namespace Orleans.CodeGenerator.Hashing;
+
+///
+/// Represents a non-cryptographic hash algorithm.
+///
+internal abstract class NonCryptographicHashAlgorithm
{
///
- /// Represents a non-cryptographic hash algorithm.
+ /// Gets the number of bytes produced from this hash algorithm.
///
- internal abstract class NonCryptographicHashAlgorithm
- {
- ///
- /// Gets the number of bytes produced from this hash algorithm.
- ///
- /// The number of bytes produced from this hash algorithm.
- public int HashLengthInBytes { get; }
-
- ///
- /// Called from constructors in derived classes to initialize the
- /// class.
- ///
- ///
- /// The number of bytes produced from this hash algorithm.
- ///
- ///
- /// is less than 1.
- ///
- protected NonCryptographicHashAlgorithm(int hashLengthInBytes)
- {
- if (hashLengthInBytes < 1)
- throw new ArgumentOutOfRangeException(nameof(hashLengthInBytes));
+ /// The number of bytes produced from this hash algorithm.
+ public int HashLengthInBytes { get; }
- HashLengthInBytes = hashLengthInBytes;
- }
-
- ///
- /// When overridden in a derived class,
- /// appends the contents of to the data already
- /// processed for the current hash computation.
- ///
- /// The data to process.
- public abstract void Append(ReadOnlySpan source);
-
- ///
- /// When overridden in a derived class,
- /// resets the hash computation to the initial state.
- ///
- public abstract void Reset();
-
- ///
- /// When overridden in a derived class,
- /// writes the computed hash value to
- /// without modifying accumulated state.
- ///
- /// The buffer that receives the computed hash value.
- ///
- ///
- /// Implementations of this method must write exactly
- /// bytes to .
- /// Do not assume that the buffer was zero-initialized.
- ///
- ///
- /// The class validates the
- /// size of the buffer before calling this method, and slices the span
- /// down to be exactly in length.
- ///
- ///
- protected abstract void GetCurrentHashCore(Span destination);
-
- ///
- /// Appends the contents of to the data already
- /// processed for the current hash computation.
- ///
- /// The data to process.
- ///
- /// is .
- ///
- public void Append(byte[] source)
- {
- if (source is null)
- {
- throw new ArgumentNullException(nameof(source));
- }
+ ///
+ /// Called from constructors in derived classes to initialize the
+ /// class.
+ ///
+ ///
+ /// The number of bytes produced from this hash algorithm.
+ ///
+ ///
+ /// is less than 1.
+ ///
+ protected NonCryptographicHashAlgorithm(int hashLengthInBytes)
+ {
+ if (hashLengthInBytes < 1)
+ throw new ArgumentOutOfRangeException(nameof(hashLengthInBytes));
- Append(new ReadOnlySpan(source));
- }
+ HashLengthInBytes = hashLengthInBytes;
+ }
- ///
- /// Appends the contents of to the data already
- /// processed for the current hash computation.
- ///
- /// The data to process.
- ///
- /// is .
- ///
- ///
- public void Append(Stream stream)
- {
- if (stream is null)
- {
- throw new ArgumentNullException(nameof(stream));
- }
+ ///
+ /// When overridden in a derived class,
+ /// appends the contents of to the data already
+ /// processed for the current hash computation.
+ ///
+ /// The data to process.
+ public abstract void Append(ReadOnlySpan source);
- byte[] buffer = ArrayPool.Shared.Rent(4096);
+ ///
+ /// When overridden in a derived class,
+ /// resets the hash computation to the initial state.
+ ///
+ public abstract void Reset();
- while (true)
- {
- int read = stream.Read(buffer, 0, buffer.Length);
+ ///
+ /// When overridden in a derived class,
+ /// writes the computed hash value to
+ /// without modifying accumulated state.
+ ///
+ /// The buffer that receives the computed hash value.
+ ///
+ ///
+ /// Implementations of this method must write exactly
+ /// bytes to .
+ /// Do not assume that the buffer was zero-initialized.
+ ///
+ ///
+ /// The class validates the
+ /// size of the buffer before calling this method, and slices the span
+ /// down to be exactly in length.
+ ///
+ ///
+ protected abstract void GetCurrentHashCore(Span destination);
- if (read == 0)
- {
- break;
- }
+ ///
+ /// Appends the contents of to the data already
+ /// processed for the current hash computation.
+ ///
+ /// The data to process.
+ ///
+ /// is .
+ ///
+ public void Append(byte[] source)
+ {
+ if (source is null)
+ {
+ throw new ArgumentNullException(nameof(source));
+ }
- Append(new ReadOnlySpan(buffer, 0, read));
- }
+ Append(new ReadOnlySpan(source));
+ }
- ArrayPool.Shared.Return(buffer);
+ ///
+ /// Appends the contents of to the data already
+ /// processed for the current hash computation.
+ ///
+ /// The data to process.
+ ///
+ /// is .
+ ///
+ ///
+ public void Append(Stream stream)
+ {
+ if (stream is null)
+ {
+ throw new ArgumentNullException(nameof(stream));
}
- ///
- /// Asychronously reads the contents of
- /// and appends them to the data already
- /// processed for the current hash computation.
- ///
- /// The data to process.
- ///
- /// The token to monitor for cancellation requests.
- /// The default value is .
- ///
- ///
- /// A task that represents the asynchronous append operation.
- ///
- ///
- /// is .
- ///
- public Task AppendAsync(Stream stream, CancellationToken cancellationToken = default)
+ byte[] buffer = ArrayPool.Shared.Rent(4096);
+
+ while (true)
{
- if (stream is null)
+ int read = stream.Read(buffer, 0, buffer.Length);
+
+ if (read == 0)
{
- throw new ArgumentNullException(nameof(stream));
+ break;
}
- return AppendAsyncCore(stream, cancellationToken);
+ Append(new ReadOnlySpan(buffer, 0, read));
}
- private async Task AppendAsyncCore(Stream stream, CancellationToken cancellationToken)
+ ArrayPool.Shared.Return(buffer);
+ }
+
+ ///
+ /// Asychronously reads the contents of
+ /// and appends them to the data already
+ /// processed for the current hash computation.
+ ///
+ /// The data to process.
+ ///
+ /// The token to monitor for cancellation requests.
+ /// The default value is .
+ ///
+ ///
+ /// A task that represents the asynchronous append operation.
+ ///
+ ///
+ /// is .
+ ///
+ public Task AppendAsync(Stream stream, CancellationToken cancellationToken = default)
+ {
+ if (stream is null)
{
- byte[] buffer = ArrayPool.Shared.Rent(4096);
+ throw new ArgumentNullException(nameof(stream));
+ }
- while (true)
- {
+ return AppendAsyncCore(stream, cancellationToken);
+ }
+
+ private async Task AppendAsyncCore(Stream stream, CancellationToken cancellationToken)
+ {
+ byte[] buffer = ArrayPool.Shared.Rent(4096);
+
+ while (true)
+ {
#if NETCOREAPP
- int read = await stream.ReadAsync(buffer.AsMemory(), cancellationToken).ConfigureAwait(false);
+ int read = await stream.ReadAsync(buffer.AsMemory(), cancellationToken).ConfigureAwait(false);
#else
- int read = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
+ int read = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
#endif
- if (read == 0)
- {
- break;
- }
-
- Append(new ReadOnlySpan(buffer, 0, read));
+ if (read == 0)
+ {
+ break;
}
- ArrayPool.Shared.Return(buffer);
+ Append(new ReadOnlySpan(buffer, 0, read));
}
- ///
- /// Gets the current computed hash value without modifying accumulated state.
- ///
- ///
- /// The hash value for the data already provided.
- ///
- public byte[] GetCurrentHash()
+ ArrayPool.Shared.Return(buffer);
+ }
+
+ ///
+ /// Gets the current computed hash value without modifying accumulated state.
+ ///
+ ///
+ /// The hash value for the data already provided.
+ ///
+ public byte[] GetCurrentHash()
+ {
+ byte[] ret = new byte[HashLengthInBytes];
+ GetCurrentHashCore(ret);
+ return ret;
+ }
+
+ ///
+ /// Attempts to write the computed hash value to
+ /// without modifying accumulated state.
+ ///
+ /// The buffer that receives the computed hash value.
+ ///
+ /// On success, receives the number of bytes written to .
+ ///
+ ///
+ /// if is long enough to receive
+ /// the computed hash value; otherwise, .
+ ///
+ public bool TryGetCurrentHash(Span destination, out int bytesWritten)
+ {
+ if (destination.Length < HashLengthInBytes)
{
- byte[] ret = new byte[HashLengthInBytes];
- GetCurrentHashCore(ret);
- return ret;
+ bytesWritten = 0;
+ return false;
}
- ///
- /// Attempts to write the computed hash value to
- /// without modifying accumulated state.
- ///
- /// The buffer that receives the computed hash value.
- ///
- /// On success, receives the number of bytes written to .
- ///
- ///
- /// if is long enough to receive
- /// the computed hash value; otherwise, .
- ///
- public bool TryGetCurrentHash(Span destination, out int bytesWritten)
- {
- if (destination.Length < HashLengthInBytes)
- {
- bytesWritten = 0;
- return false;
- }
+ GetCurrentHashCore(destination.Slice(0, HashLengthInBytes));
+ bytesWritten = HashLengthInBytes;
+ return true;
+ }
- GetCurrentHashCore(destination.Slice(0, HashLengthInBytes));
- bytesWritten = HashLengthInBytes;
- return true;
+ ///
+ /// Writes the computed hash value to
+ /// without modifying accumulated state.
+ ///
+ /// The buffer that receives the computed hash value.
+ ///
+ /// The number of bytes written to ,
+ /// which is always .
+ ///
+ ///
+ /// is shorter than .
+ ///
+ public int GetCurrentHash(Span destination)
+ {
+ if (destination.Length < HashLengthInBytes)
+ {
+ throw new ArgumentException(@"destination too short", nameof(destination));
}
- ///
- /// Writes the computed hash value to
- /// without modifying accumulated state.
- ///
- /// The buffer that receives the computed hash value.
- ///
- /// The number of bytes written to ,
- /// which is always .
- ///
- ///
- /// is shorter than .
- ///
- public int GetCurrentHash(Span destination)
- {
- if (destination.Length < HashLengthInBytes)
- {
- throw new ArgumentException(@"destination too short", nameof(destination));
- }
+ GetCurrentHashCore(destination.Slice(0, HashLengthInBytes));
+ return HashLengthInBytes;
+ }
- GetCurrentHashCore(destination.Slice(0, HashLengthInBytes));
- return HashLengthInBytes;
- }
+ ///
+ /// Gets the current computed hash value and clears the accumulated state.
+ ///
+ ///
+ /// The hash value for the data already provided.
+ ///
+ public byte[] GetHashAndReset()
+ {
+ byte[] ret = new byte[HashLengthInBytes];
+ GetHashAndResetCore(ret);
+ return ret;
+ }
- ///
- /// Gets the current computed hash value and clears the accumulated state.
- ///
- ///
- /// The hash value for the data already provided.
- ///
- public byte[] GetHashAndReset()
+ ///
+ /// Attempts to write the computed hash value to .
+ /// If successful, clears the accumulated state.
+ ///
+ /// The buffer that receives the computed hash value.
+ ///
+ /// On success, receives the number of bytes written to .
+ ///
+ ///
+ /// and clears the accumulated state
+ /// if is long enough to receive
+ /// the computed hash value; otherwise, .
+ ///
+ public bool TryGetHashAndReset(Span destination, out int bytesWritten)
+ {
+ if (destination.Length < HashLengthInBytes)
{
- byte[] ret = new byte[HashLengthInBytes];
- GetHashAndResetCore(ret);
- return ret;
+ bytesWritten = 0;
+ return false;
}
- ///
- /// Attempts to write the computed hash value to .
- /// If successful, clears the accumulated state.
- ///
- /// The buffer that receives the computed hash value.
- ///
- /// On success, receives the number of bytes written to .
- ///
- ///
- /// and clears the accumulated state
- /// if is long enough to receive
- /// the computed hash value; otherwise, .
- ///
- public bool TryGetHashAndReset(Span destination, out int bytesWritten)
- {
- if (destination.Length < HashLengthInBytes)
- {
- bytesWritten = 0;
- return false;
- }
-
- GetHashAndResetCore(destination.Slice(0, HashLengthInBytes));
- bytesWritten = HashLengthInBytes;
- return true;
- }
+ GetHashAndResetCore(destination.Slice(0, HashLengthInBytes));
+ bytesWritten = HashLengthInBytes;
+ return true;
+ }
- ///
- /// Writes the computed hash value to
- /// then clears the accumulated state.
- ///
- /// The buffer that receives the computed hash value.
- ///
- /// The number of bytes written to ,
- /// which is always .
- ///
- ///
- /// is shorter than .
- ///
- public int GetHashAndReset(Span destination)
+ ///
+ /// Writes the computed hash value to
+ /// then clears the accumulated state.
+ ///
+ /// The buffer that receives the computed hash value.
+ ///
+ /// The number of bytes written to ,
+ /// which is always .
+ ///
+ ///
+ /// is shorter than .
+ ///
+ public int GetHashAndReset(Span destination)
+ {
+ if (destination.Length < HashLengthInBytes)
{
- if (destination.Length < HashLengthInBytes)
- {
- throw new ArgumentException(@"destination too short", nameof(destination));
- }
-
- GetHashAndResetCore(destination.Slice(0, HashLengthInBytes));
- return HashLengthInBytes;
+ throw new ArgumentException(@"destination too short", nameof(destination));
}
- ///
- /// Writes the computed hash value to
- /// then clears the accumulated state.
- ///
- /// The buffer that receives the computed hash value.
- ///
- ///
- /// Implementations of this method must write exactly
- /// bytes to .
- /// Do not assume that the buffer was zero-initialized.
- ///
- ///
- /// The class validates the
- /// size of the buffer before calling this method, and slices the span
- /// down to be exactly in length.
- ///
- ///
- /// The default implementation of this method calls
- /// followed by .
- /// Overrides of this method do not need to call either of those methods,
- /// but must ensure that the caller cannot observe a difference in behavior.
- ///
- ///
- protected virtual void GetHashAndResetCore(Span destination)
- {
- Debug.Assert(destination.Length == HashLengthInBytes);
+ GetHashAndResetCore(destination.Slice(0, HashLengthInBytes));
+ return HashLengthInBytes;
+ }
- GetCurrentHashCore(destination);
- Reset();
- }
+ ///
+ /// Writes the computed hash value to
+ /// then clears the accumulated state.
+ ///
+ /// The buffer that receives the computed hash value.
+ ///
+ ///
+ /// Implementations of this method must write exactly
+ /// bytes to .
+ /// Do not assume that the buffer was zero-initialized.
+ ///
+ ///
+ /// The class validates the
+ /// size of the buffer before calling this method, and slices the span
+ /// down to be exactly in length.
+ ///
+ ///
+ /// The default implementation of this method calls
+ /// followed by .
+ /// Overrides of this method do not need to call either of those methods,
+ /// but must ensure that the caller cannot observe a difference in behavior.
+ ///
+ ///
+ protected virtual void GetHashAndResetCore(Span destination)
+ {
+ Debug.Assert(destination.Length == HashLengthInBytes);
- ///
- /// This method is not supported and should not be called.
- /// Call or
- /// instead.
- ///
- /// This method will always throw a .
- /// In all cases.
- [EditorBrowsable(EditorBrowsableState.Never)]
- [Obsolete("Use GetCurrentHash() to retrieve the computed hash code.", true)]
+ GetCurrentHashCore(destination);
+ Reset();
+ }
+
+ ///
+ /// This method is not supported and should not be called.
+ /// Call or
+ /// instead.
+ ///
+ /// This method will always throw a .
+ /// In all cases.
+ [EditorBrowsable(EditorBrowsableState.Never)]
+ [Obsolete("Use GetCurrentHash() to retrieve the computed hash code.", true)]
#pragma warning disable CS0809 // Obsolete member overrides non-obsolete member
- public override int GetHashCode()
+ public override int GetHashCode()
#pragma warning restore CS0809 // Obsolete member overrides non-obsolete member
- {
- throw new NotSupportedException("GetHashCode");
- }
+ {
+ throw new NotSupportedException("GetHashCode");
}
}
\ No newline at end of file
diff --git a/src/Orleans.CodeGenerator/Hashing/XxHash32.State.cs b/src/Orleans.CodeGenerator/Hashing/XxHash32.State.cs
index 19a84375ac7..69cbf4714d8 100644
--- a/src/Orleans.CodeGenerator/Hashing/XxHash32.State.cs
+++ b/src/Orleans.CodeGenerator/Hashing/XxHash32.State.cs
@@ -1,7 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
-using System;
using System.Buffers.Binary;
using System.Diagnostics;
using System.Runtime.CompilerServices;
@@ -9,100 +8,99 @@
// Implemented from the specification at
// https://github.com/Cyan4973/xxHash/blob/f9155bd4c57e2270a4ffbb176485e5d713de1c9b/doc/xxhash_spec.md
-namespace Orleans.CodeGenerator.Hashing
+namespace Orleans.CodeGenerator.Hashing;
+
+internal sealed partial class XxHash32
{
- internal sealed partial class XxHash32
+ private struct State
{
- private struct State
+ private const uint Prime32_1 = 0x9E3779B1;
+ private const uint Prime32_2 = 0x85EBCA77;
+ private const uint Prime32_3 = 0xC2B2AE3D;
+ private const uint Prime32_4 = 0x27D4EB2F;
+ private const uint Prime32_5 = 0x165667B1;
+
+ private uint _acc1;
+ private uint _acc2;
+ private uint _acc3;
+ private uint _acc4;
+ private readonly uint _smallAcc;
+ private bool _hadFullStripe;
+
+ internal State(uint seed)
{
- private const uint Prime32_1 = 0x9E3779B1;
- private const uint Prime32_2 = 0x85EBCA77;
- private const uint Prime32_3 = 0xC2B2AE3D;
- private const uint Prime32_4 = 0x27D4EB2F;
- private const uint Prime32_5 = 0x165667B1;
-
- private uint _acc1;
- private uint _acc2;
- private uint _acc3;
- private uint _acc4;
- private readonly uint _smallAcc;
- private bool _hadFullStripe;
-
- internal State(uint seed)
- {
- _acc1 = seed + unchecked(Prime32_1 + Prime32_2);
- _acc2 = seed + Prime32_2;
- _acc3 = seed;
- _acc4 = seed - Prime32_1;
+ _acc1 = seed + unchecked(Prime32_1 + Prime32_2);
+ _acc2 = seed + Prime32_2;
+ _acc3 = seed;
+ _acc4 = seed - Prime32_1;
- _smallAcc = seed + Prime32_5;
- _hadFullStripe = false;
- }
+ _smallAcc = seed + Prime32_5;
+ _hadFullStripe = false;
+ }
- internal void ProcessStripe(ReadOnlySpan source)
- {
- Debug.Assert(source.Length >= StripeSize);
- source = source.Slice(0, StripeSize);
+ internal void ProcessStripe(ReadOnlySpan source)
+ {
+ Debug.Assert(source.Length >= StripeSize);
+ source = source.Slice(0, StripeSize);
- _acc1 = ApplyRound(_acc1, source);
- _acc2 = ApplyRound(_acc2, source.Slice(sizeof(uint)));
- _acc3 = ApplyRound(_acc3, source.Slice(2 * sizeof(uint)));
- _acc4 = ApplyRound(_acc4, source.Slice(3 * sizeof(uint)));
+ _acc1 = ApplyRound(_acc1, source);
+ _acc2 = ApplyRound(_acc2, source.Slice(sizeof(uint)));
+ _acc3 = ApplyRound(_acc3, source.Slice(2 * sizeof(uint)));
+ _acc4 = ApplyRound(_acc4, source.Slice(3 * sizeof(uint)));
- _hadFullStripe = true;
- }
+ _hadFullStripe = true;
+ }
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- private readonly uint Converge()
- {
- return
- BitOperations.RotateLeft(_acc1, 1) +
- BitOperations.RotateLeft(_acc2, 7) +
- BitOperations.RotateLeft(_acc3, 12) +
- BitOperations.RotateLeft(_acc4, 18);
- }
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private readonly uint Converge()
+ {
+ return
+ BitOperations.RotateLeft(_acc1, 1) +
+ BitOperations.RotateLeft(_acc2, 7) +
+ BitOperations.RotateLeft(_acc3, 12) +
+ BitOperations.RotateLeft(_acc4, 18);
+ }
- private static uint ApplyRound(uint acc, ReadOnlySpan lane)
+ private static uint ApplyRound(uint acc, ReadOnlySpan lane)
+ {
+ acc += BinaryPrimitives.ReadUInt32LittleEndian(lane) * Prime32_2;
+ acc = BitOperations.RotateLeft(acc, 13);
+ acc *= Prime32_1;
+
+ return acc;
+ }
+
+ internal readonly uint Complete(int length, ReadOnlySpan remaining)
+ {
+ uint acc = _hadFullStripe ? Converge() : _smallAcc;
+
+ acc += (uint)length;
+
+ while (remaining.Length >= sizeof(uint))
{
- acc += BinaryPrimitives.ReadUInt32LittleEndian(lane) * Prime32_2;
- acc = BitOperations.RotateLeft(acc, 13);
- acc *= Prime32_1;
+ uint lane = BinaryPrimitives.ReadUInt32LittleEndian(remaining);
+ acc += lane * Prime32_3;
+ acc = BitOperations.RotateLeft(acc, 17);
+ acc *= Prime32_4;
- return acc;
+ remaining = remaining.Slice(sizeof(uint));
}
- internal readonly uint Complete(int length, ReadOnlySpan remaining)
+ for (int i = 0; i < remaining.Length; i++)
{
- uint acc = _hadFullStripe ? Converge() : _smallAcc;
-
- acc += (uint)length;
-
- while (remaining.Length >= sizeof(uint))
- {
- uint lane = BinaryPrimitives.ReadUInt32LittleEndian(remaining);
- acc += lane * Prime32_3;
- acc = BitOperations.RotateLeft(acc, 17);
- acc *= Prime32_4;
-
- remaining = remaining.Slice(sizeof(uint));
- }
-
- for (int i = 0; i < remaining.Length; i++)
- {
- uint lane = remaining[i];
- acc += lane * Prime32_5;
- acc = BitOperations.RotateLeft(acc, 11);
- acc *= Prime32_1;
- }
-
- acc ^= (acc >> 15);
- acc *= Prime32_2;
- acc ^= (acc >> 13);
- acc *= Prime32_3;
- acc ^= (acc >> 16);
-
- return acc;
+ uint lane = remaining[i];
+ acc += lane * Prime32_5;
+ acc = BitOperations.RotateLeft(acc, 11);
+ acc *= Prime32_1;
}
+
+ acc ^= acc >> 15;
+ acc *= Prime32_2;
+ acc ^= acc >> 13;
+ acc *= Prime32_3;
+ acc ^= acc >> 16;
+
+ return acc;
}
}
}
\ No newline at end of file
diff --git a/src/Orleans.CodeGenerator/Hashing/XxHash32.cs b/src/Orleans.CodeGenerator/Hashing/XxHash32.cs
index 39b304681b8..4a7702e194d 100644
--- a/src/Orleans.CodeGenerator/Hashing/XxHash32.cs
+++ b/src/Orleans.CodeGenerator/Hashing/XxHash32.cs
@@ -1,234 +1,232 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
-using System;
using System.Buffers.Binary;
// Implemented from the specification at
// https://github.com/Cyan4973/xxHash/blob/f9155bd4c57e2270a4ffbb176485e5d713de1c9b/doc/xxhash_spec.md
-namespace Orleans.CodeGenerator.Hashing
+namespace Orleans.CodeGenerator.Hashing;
+
+///
+/// Provides an implementation of the XxHash32 algorithm.
+///
+internal sealed partial class XxHash32 : NonCryptographicHashAlgorithm
{
+ private const int HashSize = sizeof(uint);
+ private const int StripeSize = 4 * sizeof(uint);
+
+ private readonly uint _seed;
+ private State _state;
+ private byte[]? _holdback;
+ private int _length;
+
///
- /// Provides an implementation of the XxHash32 algorithm.
+ /// Initializes a new instance of the class.
///
- internal sealed partial class XxHash32 : NonCryptographicHashAlgorithm
+ ///
+ /// The XxHash32 algorithm supports an optional seed value.
+ /// Instances created with this constructor use the default seed, zero.
+ ///
+ public XxHash32()
+ : this(0)
{
- private const int HashSize = sizeof(uint);
- private const int StripeSize = 4 * sizeof(uint);
-
- private readonly uint _seed;
- private State _state;
- private byte[]? _holdback;
- private int _length;
-
- ///
- /// Initializes a new instance of the class.
- ///
- ///
- /// The XxHash32 algorithm supports an optional seed value.
- /// Instances created with this constructor use the default seed, zero.
- ///
- public XxHash32()
- : this(0)
- {
- }
+ }
- ///
- /// Initializes a new instance of the class with
- /// a specified seed.
- ///
- ///
- /// The hash seed value for computations from this instance.
- ///
- public XxHash32(int seed)
- : base(HashSize)
- {
- _seed = (uint)seed;
- Reset();
- }
+ ///
+ /// Initializes a new instance of the class with
+ /// a specified seed.
+ ///
+ ///
+ /// The hash seed value for computations from this instance.
+ ///
+ public XxHash32(int seed)
+ : base(HashSize)
+ {
+ _seed = (uint)seed;
+ Reset();
+ }
- ///
- /// Resets the hash computation to the initial state.
- ///
- public override void Reset()
- {
- _state = new State(_seed);
- _length = 0;
- }
+ ///
+ /// Resets the hash computation to the initial state.
+ ///
+ public override void Reset()
+ {
+ _state = new State(_seed);
+ _length = 0;
+ }
- ///
- /// Appends the contents of to the data already
- /// processed for the current hash computation.
- ///
- /// The data to process.
- public override void Append(ReadOnlySpan source)
- {
- // Every time we've read 16 bytes, process the stripe.
- // Data that isn't perfectly mod-16 gets stored in a holdback
- // buffer.
+ ///
+ /// Appends the contents of to the data already
+ /// processed for the current hash computation.
+ ///
+ /// The data to process.
+ public override void Append(ReadOnlySpan source)
+ {
+ // Every time we've read 16 bytes, process the stripe.
+ // Data that isn't perfectly mod-16 gets stored in a holdback
+ // buffer.
- int held = _length & 0x0F;
+ int held = _length & 0x0F;
- if (held != 0)
- {
- int remain = StripeSize - held;
-
- if (source.Length >= remain)
- {
- source.Slice(0, remain).CopyTo(_holdback.AsSpan(held));
- _state.ProcessStripe(_holdback);
-
- source = source.Slice(remain);
- _length += remain;
- }
- else
- {
- source.CopyTo(_holdback.AsSpan(held));
- _length += source.Length;
- return;
- }
- }
+ if (held != 0)
+ {
+ int remain = StripeSize - held;
- while (source.Length >= StripeSize)
+ if (source.Length >= remain)
{
- _state.ProcessStripe(source);
- source = source.Slice(StripeSize);
- _length += StripeSize;
- }
+ source.Slice(0, remain).CopyTo(_holdback.AsSpan(held));
+ _state.ProcessStripe(_holdback);
- if (source.Length > 0)
+ source = source.Slice(remain);
+ _length += remain;
+ }
+ else
{
- _holdback ??= new byte[StripeSize];
- source.CopyTo(_holdback);
+ source.CopyTo(_holdback.AsSpan(held));
_length += source.Length;
+ return;
}
}
- ///
- /// Writes the computed hash value to
- /// without modifying accumulated state.
- ///
- protected override void GetCurrentHashCore(Span destination)
+ while (source.Length >= StripeSize)
{
- int remainingLength = _length & 0x0F;
- ReadOnlySpan remaining = ReadOnlySpan.Empty;
-
- if (remainingLength > 0)
- {
- remaining = new ReadOnlySpan(_holdback, 0, remainingLength);
- }
-
- uint acc = _state.Complete(_length, remaining);
- BinaryPrimitives.WriteUInt32BigEndian(destination, acc);
+ _state.ProcessStripe(source);
+ source = source.Slice(StripeSize);
+ _length += StripeSize;
}
- ///
- /// Computes the XxHash32 hash of the provided data.
- ///
- /// The data to hash.
- /// The XxHash32 hash of the provided data.
- ///
- /// is .
- ///
- public static byte[] Hash(byte[] source)
+ if (source.Length > 0)
{
- if (source is null)
- {
- throw new ArgumentNullException(nameof(source));
- }
-
- return Hash(new ReadOnlySpan(source));
+ _holdback ??= new byte[StripeSize];
+ source.CopyTo(_holdback);
+ _length += source.Length;
}
+ }
- ///
- /// Computes the XxHash32 hash of the provided data using the provided seed.
- ///
- /// The data to hash.
- /// The seed value for this hash computation.
- /// The XxHash32 hash of the provided data.
- ///
- /// is .
- ///
- public static byte[] Hash(byte[] source, int seed)
- {
- if (source is null)
- {
- throw new ArgumentNullException(nameof(source));
- }
+ ///
+ /// Writes the computed hash value to
+ /// without modifying accumulated state.
+ ///
+ protected override void GetCurrentHashCore(Span destination)
+ {
+ int remainingLength = _length & 0x0F;
+ ReadOnlySpan remaining = [];
- return Hash(new ReadOnlySpan(source), seed);
+ if (remainingLength > 0)
+ {
+ remaining = new ReadOnlySpan(_holdback, 0, remainingLength);
}
- ///
- /// Computes the XxHash32 hash of the provided data.
- ///
- /// The data to hash.
- /// The seed value for this hash computation. The default is zero.
- /// The XxHash32 hash of the provided data.
- public static byte[] Hash(ReadOnlySpan source, int seed = 0)
+ uint acc = _state.Complete(_length, remaining);
+ BinaryPrimitives.WriteUInt32BigEndian(destination, acc);
+ }
+
+ ///
+ /// Computes the XxHash32 hash of the provided data.
+ ///
+ /// The data to hash.
+ /// The XxHash32 hash of the provided data.
+ ///
+ /// is .
+ ///
+ public static byte[] Hash(byte[] source)
+ {
+ if (source is null)
{
- byte[] ret = new byte[HashSize];
- StaticHash(source, ret, seed);
- return ret;
+ throw new ArgumentNullException(nameof(source));
}
- ///
- /// Attempts to compute the XxHash32 hash of the provided data into the provided destination.
- ///
- /// The data to hash.
- /// The buffer that receives the computed hash value.
- ///
- /// On success, receives the number of bytes written to .
- ///
- /// The seed value for this hash computation. The default is zero.
- ///
- /// if is long enough to receive
- /// the computed hash value (4 bytes); otherwise, .
- ///
- public static bool TryHash(ReadOnlySpan source, Span destination, out int bytesWritten, int seed = 0)
- {
- if (destination.Length < HashSize)
- {
- bytesWritten = 0;
- return false;
- }
+ return Hash(new ReadOnlySpan(source));
+ }
- bytesWritten = StaticHash(source, destination, seed);
- return true;
+ ///
+ /// Computes the XxHash32 hash of the provided data using the provided seed.
+ ///
+ /// The data to hash.
+ /// The seed value for this hash computation.
+ /// The XxHash32 hash of the provided data.
+ ///
+ /// is .
+ ///
+ public static byte[] Hash(byte[] source, int seed)
+ {
+ if (source is null)
+ {
+ throw new ArgumentNullException(nameof(source));
}
- ///
- /// Computes the XxHash32 hash of the provided data into the provided destination.
- ///
- /// The data to hash.
- /// The buffer that receives the computed hash value.
- /// The seed value for this hash computation. The default is zero.
- ///
- /// The number of bytes written to .
- ///
- public static int Hash(ReadOnlySpan source, Span destination, int seed = 0)
- {
- if (destination.Length < HashSize)
- throw new ArgumentException(@"destination too short", nameof(destination));
+ return Hash(new ReadOnlySpan(source), seed);
+ }
- return StaticHash(source, destination, seed);
- }
+ ///
+ /// Computes the XxHash32 hash of the provided data.
+ ///
+ /// The data to hash.
+ /// The seed value for this hash computation. The default is zero.
+ /// The XxHash32 hash of the provided data.
+ public static byte[] Hash(ReadOnlySpan source, int seed = 0)
+ {
+ byte[] ret = new byte[HashSize];
+ StaticHash(source, ret, seed);
+ return ret;
+ }
- private static int StaticHash(ReadOnlySpan source, Span destination, int seed)
+ ///
+ /// Attempts to compute the XxHash32 hash of the provided data into the provided destination.
+ ///
+ /// The data to hash.
+ /// The buffer that receives the computed hash value.
+ ///
+ /// On success, receives the number of bytes written to .
+ ///
+ /// The seed value for this hash computation. The default is zero.
+ ///
+ /// if is long enough to receive
+ /// the computed hash value (4 bytes); otherwise, .
+ ///
+ public static bool TryHash(ReadOnlySpan source, Span destination, out int bytesWritten, int seed = 0)
+ {
+ if (destination.Length < HashSize)
{
- int totalLength = source.Length;
- State state = new State((uint)seed);
+ bytesWritten = 0;
+ return false;
+ }
- while (source.Length >= StripeSize)
- {
- state.ProcessStripe(source);
- source = source.Slice(StripeSize);
- }
+ bytesWritten = StaticHash(source, destination, seed);
+ return true;
+ }
- uint val = state.Complete(totalLength, source);
- BinaryPrimitives.WriteUInt32BigEndian(destination, val);
- return HashSize;
+ ///
+ /// Computes the XxHash32 hash of the provided data into the provided destination.
+ ///
+ /// The data to hash.
+ /// The buffer that receives the computed hash value.
+ /// The seed value for this hash computation. The default is zero.
+ ///
+ /// The number of bytes written to .
+ ///
+ public static int Hash(ReadOnlySpan source, Span destination, int seed = 0)
+ {
+ if (destination.Length < HashSize)
+ throw new ArgumentException(@"destination too short", nameof(destination));
+
+ return StaticHash(source, destination, seed);
+ }
+
+ private static int StaticHash(ReadOnlySpan source, Span destination, int seed)
+ {
+ int totalLength = source.Length;
+ State state = new State((uint)seed);
+
+ while (source.Length >= StripeSize)
+ {
+ state.ProcessStripe(source);
+ source = source.Slice(StripeSize);
}
+
+ uint val = state.Complete(totalLength, source);
+ BinaryPrimitives.WriteUInt32BigEndian(destination, val);
+ return HashSize;
}
}
\ No newline at end of file
diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs
index cd4db8407bf..e3994e96741 100644
--- a/src/Orleans.CodeGenerator/InvokableGenerator.cs
+++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs
@@ -2,943 +2,917 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
-using System;
-using System.Collections.Generic;
-using System.Linq;
using Orleans.CodeGenerator.Diagnostics;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
-using System.Linq.Expressions;
-#nullable disable
-namespace Orleans.CodeGenerator
+namespace Orleans.CodeGenerator;
+
+///
+/// Generates RPC stub objects called invokers.
+///
+internal class InvokableGenerator(ProxyGenerationContext generationContext)
{
- ///
- /// Generates RPC stub objects called invokers.
- ///
- internal class InvokableGenerator
- {
- private readonly CodeGenerator _codeGenerator;
+ private readonly ProxyGenerationContext _generationContext = generationContext;
+
+ private LibraryTypes LibraryTypes => _generationContext.LibraryTypes;
- public InvokableGenerator(CodeGenerator codeGenerator)
+ public GeneratedInvokableDescription Generate(InvokableMethodDescription invokableMethodInfo)
+ {
+ var method = invokableMethodInfo.Method;
+ var generatedClassName = GetSimpleClassName(invokableMethodInfo);
+
+ var baseClassType = GetBaseClassType(invokableMethodInfo);
+ var fieldDescriptions = GetFieldDescriptions(invokableMethodInfo);
+ var fields = GetFieldDeclarations(invokableMethodInfo, fieldDescriptions);
+ var (ctor, ctorArgs) = GenerateConstructor(generatedClassName, invokableMethodInfo, baseClassType);
+ var accessibility = GetAccessibility(method);
+ var compoundTypeAliases = GetCompoundTypeAliasAttributeArguments(invokableMethodInfo, invokableMethodInfo.Key);
+
+ List serializationHooks = new();
+ if (baseClassType.GetAttributes(LibraryTypes.SerializationCallbacksAttribute, out var hookAttributes))
{
- _codeGenerator = codeGenerator;
+ foreach (var hookAttribute in hookAttributes)
+ {
+ var hookType = (INamedTypeSymbol)hookAttribute.ConstructorArguments[0].Value!;
+ serializationHooks.Add(hookType);
+ }
}
- private LibraryTypes LibraryTypes => _codeGenerator.LibraryTypes;
+ var targetField = fieldDescriptions.OfType().Single();
- public GeneratedInvokableDescription Generate(InvokableMethodDescription invokableMethodInfo)
+ var accessibilityKind = accessibility switch
+ {
+ Accessibility.Public => SyntaxKind.PublicKeyword,
+ _ => SyntaxKind.InternalKeyword,
+ };
+
+ var classDeclaration = GetClassDeclarationSyntax(
+ invokableMethodInfo,
+ generatedClassName,
+ baseClassType,
+ fieldDescriptions,
+ fields,
+ ctor,
+ compoundTypeAliases,
+ targetField,
+ accessibilityKind);
+
+ string? returnValueInitializerMethod = null;
+ if (baseClassType.GetAttribute(LibraryTypes.ReturnValueProxyAttribute) is { ConstructorArguments: { Length: > 0 } attrArgs })
{
- var method = invokableMethodInfo.Method;
- var generatedClassName = GetSimpleClassName(invokableMethodInfo);
+ returnValueInitializerMethod = (string?)attrArgs[0].Value;
+ }
- var baseClassType = GetBaseClassType(invokableMethodInfo);
- var fieldDescriptions = GetFieldDescriptions(invokableMethodInfo);
- var fields = GetFieldDeclarations(invokableMethodInfo, fieldDescriptions);
- var (ctor, ctorArgs) = GenerateConstructor(generatedClassName, invokableMethodInfo, baseClassType);
- var accessibility = GetAccessibility(method);
- var compoundTypeAliases = GetCompoundTypeAliasAttributeArguments(invokableMethodInfo, invokableMethodInfo.Key);
+ while (baseClassType.HasAttribute(LibraryTypes.SerializerTransparentAttribute))
+ {
+ baseClassType = baseClassType.BaseType!;
+ }
- List serializationHooks = new();
- if (baseClassType.GetAttributes(LibraryTypes.SerializationCallbacksAttribute, out var hookAttributes))
+ var invokerDescription = new GeneratedInvokableDescription(
+ invokableMethodInfo,
+ accessibility,
+ generatedClassName,
+ GeneratedCodeUtilities.GetGeneratedNamespaceName(invokableMethodInfo.ContainingInterface),
+ [.. fieldDescriptions.OfType()],
+ serializationHooks,
+ baseClassType,
+ ctorArgs,
+ compoundTypeAliases,
+ returnValueInitializerMethod,
+ classDeclaration);
+ return invokerDescription;
+
+ static Accessibility GetAccessibility(IMethodSymbol methodSymbol)
+ {
+ Accessibility accessibility = methodSymbol.DeclaredAccessibility;
+ var t = methodSymbol.ContainingType;
+ while (t is not null)
{
- foreach (var hookAttribute in hookAttributes)
+ if ((int)t.DeclaredAccessibility < (int)accessibility)
{
- var hookType = (INamedTypeSymbol)hookAttribute.ConstructorArguments[0].Value;
- serializationHooks.Add(hookType);
+ accessibility = t.DeclaredAccessibility;
}
- }
- var targetField = fieldDescriptions.OfType().Single();
+ t = t.ContainingType;
+ }
- var accessibilityKind = accessibility switch
- {
- Accessibility.Public => SyntaxKind.PublicKeyword,
- _ => SyntaxKind.InternalKeyword,
- };
+ return accessibility;
+ }
+ }
- var classDeclaration = GetClassDeclarationSyntax(
- invokableMethodInfo,
- generatedClassName,
- baseClassType,
- fieldDescriptions,
- fields,
- ctor,
- compoundTypeAliases,
- targetField,
- accessibilityKind);
-
- string returnValueInitializerMethod = null;
- if (baseClassType.GetAttribute(LibraryTypes.ReturnValueProxyAttribute) is { ConstructorArguments: { Length: > 0 } attrArgs })
- {
- returnValueInitializerMethod = (string)attrArgs[0].Value;
- }
+ private ClassDeclarationSyntax GetClassDeclarationSyntax(
+ InvokableMethodDescription method,
+ string generatedClassName,
+ INamedTypeSymbol baseClassType,
+ List fieldDescriptions,
+ MemberDeclarationSyntax[] fields,
+ ConstructorDeclarationSyntax? ctor,
+ List compoundTypeAliases,
+ TargetFieldDescription targetField,
+ SyntaxKind accessibilityKind)
+ {
+ var classDeclaration = ClassDeclaration(generatedClassName)
+ .AddBaseListTypes(SimpleBaseType(baseClassType.ToTypeSyntax(method.TypeParameterSubstitutions)))
+ .AddModifiers(Token(accessibilityKind), Token(SyntaxKind.SealedKeyword))
+ .AddAttributeLists(GeneratedCodeUtilities.GetGeneratedCodeAttributes())
+ .AddMembers(fields);
- while (baseClassType.HasAttribute(LibraryTypes.SerializerTransparentAttribute))
- {
- baseClassType = baseClassType.BaseType;
- }
+ foreach (var alias in compoundTypeAliases)
+ {
+ classDeclaration = classDeclaration.AddAttributeLists(
+ AttributeList(SingletonSeparatedList(GetCompoundTypeAliasAttribute(alias))));
+ }
- var invokerDescription = new GeneratedInvokableDescription(
- invokableMethodInfo,
- accessibility,
- generatedClassName,
- CodeGenerator.GetGeneratedNamespaceName(invokableMethodInfo.ContainingInterface),
- fieldDescriptions.OfType().ToList(),
- serializationHooks,
- baseClassType,
- ctorArgs,
- compoundTypeAliases,
- returnValueInitializerMethod,
- classDeclaration);
- return invokerDescription;
-
- static Accessibility GetAccessibility(IMethodSymbol methodSymbol)
- {
- Accessibility accessibility = methodSymbol.DeclaredAccessibility;
- var t = methodSymbol.ContainingType;
- while (t is not null)
- {
- if ((int)t.DeclaredAccessibility < (int)accessibility)
- {
- accessibility = t.DeclaredAccessibility;
- }
+ if (ctor != null)
+ {
+ classDeclaration = classDeclaration.AddMembers(ctor);
+ }
- t = t.ContainingType;
- }
+ if (method.ResponseTimeoutTicks.HasValue)
+ {
+ classDeclaration = classDeclaration.AddMembers(GenerateResponseTimeoutPropertyMembers(method.ResponseTimeoutTicks.Value));
+ }
- return accessibility;
- }
+ classDeclaration = AddOptionalMembers(classDeclaration,
+ GenerateGetArgumentCount(method),
+ GenerateGetMethodName(method),
+ GenerateGetInterfaceName(method),
+ GenerateGetActivityName(method),
+ GenerateGetInterfaceType(method),
+ GenerateGetMethod(),
+ GenerateSetTargetMethod(method, targetField, fieldDescriptions),
+ GenerateGetTargetMethod(targetField),
+ GenerateDisposeMethod(fieldDescriptions, baseClassType),
+ GenerateGetArgumentMethod(method, fieldDescriptions),
+ GenerateSetArgumentMethod(method, fieldDescriptions),
+ GenerateInvokeInnerMethod(method, fieldDescriptions, targetField),
+ GenerateGetCancellationTokenMethod(method, fieldDescriptions),
+ GenerateTryCancelMethod(method, fieldDescriptions),
+ GenerateIsCancellableProperty(method));
+
+ if (method.AllTypeParameters.Count > 0)
+ {
+ classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, method.AllTypeParameters);
}
- private ClassDeclarationSyntax GetClassDeclarationSyntax(
- InvokableMethodDescription method,
- string generatedClassName,
- INamedTypeSymbol baseClassType,
- List fieldDescriptions,
- MemberDeclarationSyntax[] fields,
- ConstructorDeclarationSyntax ctor,
- List compoundTypeAliases,
- TargetFieldDescription targetField,
- SyntaxKind accessibilityKind)
- {
- var classDeclaration = ClassDeclaration(generatedClassName)
- .AddBaseListTypes(SimpleBaseType(baseClassType.ToTypeSyntax(method.TypeParameterSubstitutions)))
- .AddModifiers(Token(accessibilityKind), Token(SyntaxKind.SealedKeyword))
- .AddAttributeLists(CodeGenerator.GetGeneratedCodeAttributes())
- .AddMembers(fields);
-
- foreach (var alias in compoundTypeAliases)
- {
- classDeclaration = classDeclaration.AddAttributeLists(
- AttributeList(SingletonSeparatedList(GetCompoundTypeAliasAttribute(alias))));
- }
+ return classDeclaration;
+ }
- if (ctor != null)
- {
- classDeclaration = classDeclaration.AddMembers(ctor);
- }
+ private MemberDeclarationSyntax[] GenerateResponseTimeoutPropertyMembers(long value)
+ {
+ var timespanField = FieldDeclaration(
+ VariableDeclaration(
+ LibraryTypes.TimeSpan.ToTypeSyntax(),
+ SingletonSeparatedList(VariableDeclarator("_responseTimeoutValue")
+ .WithInitializer(EqualsValueClause(
+ InvocationExpression(
+ IdentifierName("global::System.TimeSpan").Member("FromTicks"),
+ ArgumentList(SeparatedList(
+ [
+ Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(value)))
+ ]))))))))
+ .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword));
+
+ var responseTimeoutProperty = MethodDeclaration(NullableType(LibraryTypes.TimeSpan.ToTypeSyntax()), "GetDefaultResponseTimeout")
+ .WithExpressionBody(ArrowExpressionClause(IdentifierName("_responseTimeoutValue")))
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
+ .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword));
+ return [timespanField, responseTimeoutProperty];
+ }
- if (method.ResponseTimeoutTicks.HasValue)
- {
- classDeclaration = classDeclaration.AddMembers(GenerateResponseTimeoutPropertyMembers(method.ResponseTimeoutTicks.Value));
- }
+ private static ClassDeclarationSyntax AddOptionalMembers(ClassDeclarationSyntax decl, params MemberDeclarationSyntax?[] items)
+ => decl.WithMembers(decl.Members.AddRange(items.OfType()));
- classDeclaration = AddOptionalMembers(classDeclaration,
- GenerateGetArgumentCount(method),
- GenerateGetMethodName(method),
- GenerateGetInterfaceName(method),
- GenerateGetActivityName(method),
- GenerateGetInterfaceType(method),
- GenerateGetMethod(),
- GenerateSetTargetMethod(method, targetField, fieldDescriptions),
- GenerateGetTargetMethod(targetField),
- GenerateDisposeMethod(fieldDescriptions, baseClassType),
- GenerateGetArgumentMethod(method, fieldDescriptions),
- GenerateSetArgumentMethod(method, fieldDescriptions),
- GenerateInvokeInnerMethod(method, fieldDescriptions, targetField),
- GenerateGetCancellationTokenMethod(method, fieldDescriptions),
- GenerateTryCancelMethod(method, fieldDescriptions),
- GenerateIsCancellableProperty(method));
-
- if (method.AllTypeParameters.Count > 0)
+ internal AttributeSyntax GetCompoundTypeAliasAttribute(CompoundTypeAliasComponent[] argValues)
+ {
+ var args = new AttributeArgumentSyntax[argValues.Length];
+ for (var i = 0; i < argValues.Length; i++)
+ {
+ ExpressionSyntax value;
+ value = argValues[i].Value switch
{
- classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, method.AllTypeParameters);
- }
+ string stringValue => LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(stringValue)),
+ ITypeSymbol typeValue => TypeOfExpression(typeValue.ToOpenTypeSyntax()),
+ _ => throw new InvalidOperationException($"Unsupported value")
+ };
- return classDeclaration;
+ args[i] = AttributeArgument(value);
}
- private MemberDeclarationSyntax[] GenerateResponseTimeoutPropertyMembers(long value)
- {
- var timespanField = FieldDeclaration(
- VariableDeclaration(
- LibraryTypes.TimeSpan.ToTypeSyntax(),
- SingletonSeparatedList(VariableDeclarator("_responseTimeoutValue")
- .WithInitializer(EqualsValueClause(
- InvocationExpression(
- IdentifierName("global::System.TimeSpan").Member("FromTicks"),
- ArgumentList(SeparatedList(new[]
- {
- Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(value)))
- }))))))))
- .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword));
+ return Attribute(LibraryTypes.CompoundTypeAliasAttribute.ToNameSyntax()).AddArgumentListArguments(args);
+ }
- var responseTimeoutProperty = MethodDeclaration(NullableType(LibraryTypes.TimeSpan.ToTypeSyntax()), "GetDefaultResponseTimeout")
- .WithExpressionBody(ArrowExpressionClause(IdentifierName("_responseTimeoutValue")))
- .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
- .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword));
- return new MemberDeclarationSyntax[] { timespanField, responseTimeoutProperty };
+ internal static List GetCompoundTypeAliasAttributeArguments(InvokableMethodDescription methodDescription, InvokableMethodId invokableId)
+ {
+ var result = new List(2);
+ var containingInterface = methodDescription.ContainingInterface;
+ if (methodDescription.HasAlias)
+ {
+ result.Add(GetCompoundTypeAliasComponents(invokableId, containingInterface, methodDescription.MethodId));
}
- private ClassDeclarationSyntax AddOptionalMembers(ClassDeclarationSyntax decl, params MemberDeclarationSyntax[] items)
- => decl.WithMembers(decl.Members.AddRange(items.Where(i => i != null)));
+ result.Add(GetCompoundTypeAliasComponents(invokableId, containingInterface, methodDescription.GeneratedMethodId));
+ return result;
+ }
- internal AttributeSyntax GetCompoundTypeAliasAttribute(CompoundTypeAliasComponent[] argValues)
+ public static CompoundTypeAliasComponent[] GetCompoundTypeAliasComponents(
+ InvokableMethodId invokableId,
+ INamedTypeSymbol containingInterface,
+ string methodId)
+ {
+ var proxyBase = invokableId.ProxyBase;
+ var proxyBaseComponents = proxyBase.CompositeAliasComponents;
+ var extensionArgCount = proxyBase.IsExtension ? 1 : 0;
+ var alias = new CompoundTypeAliasComponent[1 + proxyBaseComponents.Length + extensionArgCount + 2];
+ alias[0] = new("inv");
+ for (var i = 0; i < proxyBaseComponents.Length; i++)
{
- var args = new AttributeArgumentSyntax[argValues.Length];
- for (var i = 0; i < argValues.Length; i++)
- {
- ExpressionSyntax value;
- value = argValues[i].Value switch
- {
- string stringValue => LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(stringValue)),
- ITypeSymbol typeValue => TypeOfExpression(typeValue.ToOpenTypeSyntax()),
- _ => throw new InvalidOperationException($"Unsupported value")
- };
+ alias[i + 1] = proxyBaseComponents[i];
+ }
- args[i] = AttributeArgument(value);
- }
+ alias[1 + proxyBaseComponents.Length] = new(containingInterface);
- return Attribute(LibraryTypes.CompoundTypeAliasAttribute.ToNameSyntax()).AddArgumentListArguments(args);
+ // For grain extensions, also explicitly include the method's containing type.
+ // This is to distinguish between different extension methods with the same id (eg, alias) but different containing types.
+ if (proxyBase.IsExtension)
+ {
+ alias[1 + proxyBaseComponents.Length + 1] = new(invokableId.Method.ContainingType);
}
- internal static List GetCompoundTypeAliasAttributeArguments(InvokableMethodDescription methodDescription, InvokableMethodId invokableId)
+ alias[1 + proxyBaseComponents.Length + extensionArgCount + 1] = new(methodId);
+ return alias;
+ }
+
+ private static INamedTypeSymbol GetBaseClassType(InvokableMethodDescription method)
+ {
+ var methodReturnType = method.Method.ReturnType;
+ if (methodReturnType is not INamedTypeSymbol namedMethodReturnType)
{
- var result = new List(2);
- var containingInterface = methodDescription.ContainingInterface;
- if (methodDescription.HasAlias)
- {
- result.Add(GetCompoundTypeAliasComponents(invokableId, containingInterface, methodDescription.MethodId));
- }
+ throw new OrleansGeneratorDiagnosticAnalysisException(InvalidRpcMethodReturnTypeDiagnostic.CreateDiagnostic(method));
+ }
- result.Add(GetCompoundTypeAliasComponents(invokableId, containingInterface, methodDescription.GeneratedMethodId));
- return result;
+ if (method.InvokableBaseTypes.TryGetValue(namedMethodReturnType, out var baseClassType))
+ {
+ return baseClassType;
}
- public static CompoundTypeAliasComponent[] GetCompoundTypeAliasComponents(
- InvokableMethodId invokableId,
- INamedTypeSymbol containingInterface,
- string methodId)
+ if (namedMethodReturnType.ConstructedFrom is { IsGenericType: true, IsUnboundGenericType: false } constructedFrom)
{
- var proxyBase = invokableId.ProxyBase;
- var proxyBaseComponents = proxyBase.CompositeAliasComponents;
- var extensionArgCount = proxyBase.IsExtension ? 1 : 0;
- var alias = new CompoundTypeAliasComponent[1 + proxyBaseComponents.Length + extensionArgCount + 2];
- alias[0] = new("inv");
- for (var i = 0; i < proxyBaseComponents.Length; i++)
+ var unbound = constructedFrom.ConstructUnboundGenericType();
+ if (method.InvokableBaseTypes.TryGetValue(unbound, out baseClassType))
{
- alias[i + 1] = proxyBaseComponents[i];
+ return baseClassType.ConstructedFrom.Construct([.. namedMethodReturnType.TypeArguments]);
}
+ }
- alias[1 + proxyBaseComponents.Length] = new(containingInterface);
-
- // For grain extensions, also explicitly include the method's containing type.
- // This is to distinguish between different extension methods with the same id (eg, alias) but different containing types.
- if (proxyBase.IsExtension)
- {
- alias[1 + proxyBaseComponents.Length + 1] = new(invokableId.Method.ContainingType);
- }
+ throw new OrleansGeneratorDiagnosticAnalysisException(InvalidRpcMethodReturnTypeDiagnostic.CreateDiagnostic(method));
+ }
- alias[1 + proxyBaseComponents.Length + extensionArgCount + 1] = new(methodId);
- return alias;
- }
+ private MemberDeclarationSyntax GenerateSetTargetMethod(
+ InvokableMethodDescription methodDescription,
+ TargetFieldDescription targetField,
+ List fieldDescriptions)
+ {
+ var holder = IdentifierName("holder");
+ var holderParameter = holder.Identifier;
- private INamedTypeSymbol GetBaseClassType(InvokableMethodDescription method)
+ var containingInterface = methodDescription.ContainingInterface;
+ var targetType = containingInterface.ToTypeSyntax();
+ var isExtension = methodDescription.Key.ProxyBase.IsExtension;
+ var (name, args) = isExtension switch
{
- var methodReturnType = method.Method.ReturnType;
- if (methodReturnType is not INamedTypeSymbol namedMethodReturnType)
- {
- throw new OrleansGeneratorDiagnosticAnalysisException(InvalidRpcMethodReturnTypeDiagnostic.CreateDiagnostic(method));
- }
+ true => ("GetComponent", SingletonSeparatedList(Argument(TypeOfExpression(targetType)))),
+ _ => ("GetTarget", SeparatedList())
+ };
+ var getTarget = CastExpression(
+ targetType,
+ InvocationExpression(
+ MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ holder,
+ IdentifierName(name)),
+ ArgumentList(args)));
- if (method.InvokableBaseTypes.TryGetValue(namedMethodReturnType, out var baseClassType))
- {
- return baseClassType;
- }
+ var member =
+ MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "SetTarget")
+ .WithParameterList(ParameterList(SingletonSeparatedList(Parameter(holderParameter).WithType(LibraryTypes.ITargetHolder.ToTypeSyntax()))))
+ .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)));
- if (namedMethodReturnType.ConstructedFrom is { IsGenericType: true, IsUnboundGenericType: false } constructedFrom)
- {
- var unbound = constructedFrom.ConstructUnboundGenericType();
- if (method.InvokableBaseTypes.TryGetValue(unbound, out baseClassType))
+ var assignmentExpression = AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(targetField.FieldName), getTarget);
+ if (!methodDescription.IsCancellable)
+ {
+ return member.WithExpressionBody(ArrowExpressionClause(assignmentExpression)).WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
+ }
+ else
+ {
+ var ctsField = fieldDescriptions.First(f => f is CancellationTokenSourceFieldDescription);
+ var cancellationTokenType = LibraryTypes.CancellationToken.ToTypeSyntax();
+ var ctField = fieldDescriptions.First(f => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, f.FieldType));
+ return member.WithBody(Block(
+ new List()
{
- return baseClassType.ConstructedFrom.Construct(namedMethodReturnType.TypeArguments.ToArray());
- }
- }
-
- throw new OrleansGeneratorDiagnosticAnalysisException(InvalidRpcMethodReturnTypeDiagnostic.CreateDiagnostic(method));
+ ExpressionStatement(assignmentExpression),
+ ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(ctsField.FieldName), ImplicitObjectCreationExpression())),
+ ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(ctField.FieldName), IdentifierName(ctsField.FieldName).Member("Token")))
+ }));
}
+ }
- private MemberDeclarationSyntax GenerateSetTargetMethod(
- InvokableMethodDescription methodDescription,
- TargetFieldDescription targetField,
- List fieldDescriptions)
- {
- var holder = IdentifierName("holder");
- var holderParameter = holder.Identifier;
+ private static MethodDeclarationSyntax GenerateGetTargetMethod(TargetFieldDescription targetField)
+ {
+ return MethodDeclaration(PredefinedType(Token(SyntaxKind.ObjectKeyword)), "GetTarget")
+ .WithParameterList(ParameterList())
+ .WithExpressionBody(ArrowExpressionClause(IdentifierName(targetField.FieldName)))
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
+ .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)));
+ }
- var containingInterface = methodDescription.ContainingInterface;
- var targetType = containingInterface.ToTypeSyntax();
- var isExtension = methodDescription.Key.ProxyBase.IsExtension;
- var (name, args) = isExtension switch
- {
- true => ("GetComponent", SingletonSeparatedList(Argument(TypeOfExpression(targetType)))),
- _ => ("GetTarget", SeparatedList())
- };
- var getTarget = CastExpression(
- targetType,
- InvocationExpression(
- MemberAccessExpression(
- SyntaxKind.SimpleMemberAccessExpression,
- holder,
- IdentifierName(name)),
- ArgumentList(args)));
-
- var member =
- MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "SetTarget")
- .WithParameterList(ParameterList(SingletonSeparatedList(Parameter(holderParameter).WithType(LibraryTypes.ITargetHolder.ToTypeSyntax()))))
- .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)));
-
- var assignmentExpression = AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(targetField.FieldName), getTarget);
- if (!methodDescription.IsCancellable)
- {
- return member.WithExpressionBody(ArrowExpressionClause(assignmentExpression)).WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
- }
- else
- {
- var ctsField = fieldDescriptions.First(f => f is CancellationTokenSourceFieldDescription);
- var cancellationTokenType = LibraryTypes.CancellationToken.ToTypeSyntax();
- var ctField = fieldDescriptions.First(f => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, f.FieldType));
- return member.WithBody(Block(
- new List()
- {
- ExpressionStatement(assignmentExpression),
- ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(ctsField.FieldName), ImplicitObjectCreationExpression())),
- ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(ctField.FieldName), IdentifierName(ctsField.FieldName).Member("Token")))
- }));
- }
+ private MemberDeclarationSyntax? GenerateGetCancellationTokenMethod(InvokableMethodDescription method, List fields)
+ {
+ if (!method.IsCancellable)
+ {
+ return null;
}
- private static MethodDeclarationSyntax GenerateGetTargetMethod(TargetFieldDescription targetField)
+ // Method to get the cancellationToken argument
+ // C#: CancellationToken GetCancellationToken() =>
+ var cancellationTokenType = LibraryTypes.CancellationToken.ToTypeSyntax();
+ var cancellationTokenField = fields.First(f => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, f.FieldType));
+ var member = MethodDeclaration(cancellationTokenType, "GetCancellationToken")
+ .WithExpressionBody(ArrowExpressionClause(cancellationTokenField.FieldName.ToIdentifierName()))
+ .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
+ return member;
+ }
+
+ private static MemberDeclarationSyntax? GenerateIsCancellableProperty(InvokableMethodDescription method)
+ {
+ if (!method.IsCancellable)
{
- return MethodDeclaration(PredefinedType(Token(SyntaxKind.ObjectKeyword)), "GetTarget")
- .WithParameterList(ParameterList())
- .WithExpressionBody(ArrowExpressionClause(IdentifierName(targetField.FieldName)))
- .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
- .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)));
+ return null;
}
- private MemberDeclarationSyntax GenerateGetCancellationTokenMethod(InvokableMethodDescription method, List fields)
- {
- if (!method.IsCancellable)
- {
- return null;
- }
+ // Property to indicate if the invokable is cancellable
+ // C#: public override bool IsCancellable => true;
+ var member = PropertyDeclaration(PredefinedType(Token(SyntaxKind.BoolKeyword)), "IsCancellable")
+ .WithExpressionBody(ArrowExpressionClause(LiteralExpression(SyntaxKind.TrueLiteralExpression)))
+ .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
+ .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword));
+ return member;
+ }
- // Method to get the cancellationToken argument
- // C#: CancellationToken GetCancellationToken() =>
- var cancellationTokenType = LibraryTypes.CancellationToken.ToTypeSyntax();
- var cancellationTokenField = fields.First(f => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, f.FieldType));
- var member = MethodDeclaration(cancellationTokenType, "GetCancellationToken")
- .WithExpressionBody(ArrowExpressionClause(cancellationTokenField.FieldName.ToIdentifierName()))
- .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))
- .WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
- return member;
+ private static MemberDeclarationSyntax? GenerateTryCancelMethod(InvokableMethodDescription method, List fields)
+ {
+ if (!method.IsCancellable)
+ {
+ return null;
}
- private MemberDeclarationSyntax GenerateIsCancellableProperty(InvokableMethodDescription method)
+ // Method to set the CancellationToken argument.
+ // C#:
+ // TryCancel()
+ // {
+ // if (_cts is { } cts)
+ // {
+ // cts.Cancel(false);
+ // return true;
+ // }
+ // return false;
+ // }
+ var cancellationTokenField = fields.First(f => f is CancellationTokenSourceFieldDescription);
+ var member = MethodDeclaration(PredefinedType(Token(SyntaxKind.BoolKeyword)), "TryCancel")
+ .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))
+ .WithBody(Block(
+ IfStatement(
+ IsPatternExpression(
+ IdentifierName(cancellationTokenField.FieldName),
+ RecursivePattern()
+ .WithPropertyPatternClause(PropertyPatternClause())
+ .WithDesignation(SingleVariableDesignation(Identifier("cts")))),
+ Block(
+ ExpressionStatement(InvocationExpression(
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("cts"), IdentifierName("Cancel")))
+ .WithArgumentList(ArgumentList(SeparatedList([Argument(LiteralExpression(SyntaxKind.FalseLiteralExpression))])))),
+ ReturnStatement(LiteralExpression(SyntaxKind.TrueLiteralExpression)))),
+ ReturnStatement(LiteralExpression(SyntaxKind.FalseLiteralExpression))));
+ return member;
+ }
+
+ private static MemberDeclarationSyntax? GenerateGetArgumentMethod(
+ InvokableMethodDescription methodDescription,
+ List fields)
+ {
+ if (methodDescription.Method.Parameters.Length == 0)
+ return null;
+
+ var index = IdentifierName("index");
+
+ var cases = new List();
+ foreach (var field in fields)
{
- if (!method.IsCancellable)
+ if (field is not MethodParameterFieldDescription parameter)
{
- return null;
+ continue;
}
- // Property to indicate if the invokable is cancellable
- // C#: public override bool IsCancellable => true;
- var member = PropertyDeclaration(PredefinedType(Token(SyntaxKind.BoolKeyword)), "IsCancellable")
- .WithExpressionBody(ArrowExpressionClause(LiteralExpression(SyntaxKind.TrueLiteralExpression)))
- .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
- .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword));
- return member;
+ // C#: case {index}: return {field}
+ var label = CaseSwitchLabel(
+ LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(parameter.ParameterOrdinal)));
+ cases.Add(
+ SwitchSection(
+ SingletonList(label),
+ new SyntaxList(
+ ReturnStatement(
+ IdentifierName(parameter.FieldName)))));
}
- private MemberDeclarationSyntax GenerateTryCancelMethod(InvokableMethodDescription method, List fields)
- {
- if (!method.IsCancellable)
- {
- return null;
- }
+ // C#: default: return OrleansGeneratedCodeHelper.InvokableThrowArgumentOutOfRange(index, {maxArgs})
+ var throwHelperMethod = MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("OrleansGeneratedCodeHelper"),
+ IdentifierName("InvokableThrowArgumentOutOfRange"));
+ cases.Add(
+ SwitchSection(
+ SingletonList(DefaultSwitchLabel()),
+ new SyntaxList(
+ ReturnStatement(
+ InvocationExpression(
+ throwHelperMethod,
+ ArgumentList(
+ SeparatedList(
+ [
+ Argument(index),
+ Argument(
+ LiteralExpression(
+ SyntaxKind.NumericLiteralExpression,
+ Literal(
+ Math.Max(0, methodDescription.Method.Parameters.Length - 1))))
+ ])))))));
+ var body = SwitchStatement(index, new SyntaxList(cases));
+ return MethodDeclaration(PredefinedType(Token(SyntaxKind.ObjectKeyword)), "GetArgument")
+ .WithParameterList(
+ ParameterList(
+ SingletonSeparatedList(
+ Parameter(Identifier("index")).WithType(PredefinedType(Token(SyntaxKind.IntKeyword))))))
+ .WithBody(Block(body))
+ .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)));
+ }
- // Method to set the CancellationToken argument.
- // C#:
- // TryCancel()
- // {
- // if (_cts is { } cts)
- // {
- // cts.Cancel(false);
- // return true;
- // }
- // return false;
- // }
- var cancellationTokenField = fields.First(f => f is CancellationTokenSourceFieldDescription);
- var member = MethodDeclaration(PredefinedType(Token(SyntaxKind.BoolKeyword)), "TryCancel")
- .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))
- .WithBody(Block(
- IfStatement(
- IsPatternExpression(
- IdentifierName(cancellationTokenField.FieldName),
- RecursivePattern()
- .WithPropertyPatternClause(PropertyPatternClause())
- .WithDesignation(SingleVariableDesignation(Identifier("cts")))),
- Block(
- ExpressionStatement(InvocationExpression(
- MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("cts"), IdentifierName("Cancel")))
- .WithArgumentList(ArgumentList(SeparatedList([Argument(LiteralExpression(SyntaxKind.FalseLiteralExpression))])))),
- ReturnStatement(LiteralExpression(SyntaxKind.TrueLiteralExpression)))),
- ReturnStatement(LiteralExpression(SyntaxKind.FalseLiteralExpression))));
- return member;
- }
-
- private MemberDeclarationSyntax GenerateGetArgumentMethod(
- InvokableMethodDescription methodDescription,
- List fields)
- {
- if (methodDescription.Method.Parameters.Length == 0)
- return null;
-
- var index = IdentifierName("index");
-
- var cases = new List();
- foreach (var field in fields)
- {
- if (field is not MethodParameterFieldDescription parameter)
- {
- continue;
- }
+ private static MemberDeclarationSyntax? GenerateSetArgumentMethod(
+ InvokableMethodDescription methodDescription,
+ List fields)
+ {
+ if (methodDescription.Method.Parameters.Length == 0)
+ return null;
+
+ var index = IdentifierName("index");
+ var value = IdentifierName("value");
- // C#: case {index}: return {field}
- var label = CaseSwitchLabel(
- LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(parameter.ParameterOrdinal)));
- cases.Add(
- SwitchSection(
- SingletonList(label),
- new SyntaxList(
- ReturnStatement(
- IdentifierName(parameter.FieldName)))));
+ var cases = new List();
+ foreach (var field in fields)
+ {
+ if (field is not MethodParameterFieldDescription parameter)
+ {
+ continue;
}
- // C#: default: return OrleansGeneratedCodeHelper.InvokableThrowArgumentOutOfRange(index, {maxArgs})
- var throwHelperMethod = MemberAccessExpression(
- SyntaxKind.SimpleMemberAccessExpression,
- IdentifierName("OrleansGeneratedCodeHelper"),
- IdentifierName("InvokableThrowArgumentOutOfRange"));
+ // C#: case {index}: {field} = (TField)value; return;
+ var label = CaseSwitchLabel(
+ LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(parameter.ParameterOrdinal)));
cases.Add(
SwitchSection(
- SingletonList(DefaultSwitchLabel()),
+ SingletonList(label),
new SyntaxList(
- ReturnStatement(
+ [
+ ExpressionStatement(
+ AssignmentExpression(
+ SyntaxKind.SimpleAssignmentExpression,
+ IdentifierName(parameter.FieldName),
+ CastExpression(parameter.FieldType.ToTypeSyntax(methodDescription.TypeParameterSubstitutions), value))),
+ ReturnStatement()
+ ])));
+ }
+
+ // C#: default: OrleansGeneratedCodeHelper.InvokableThrowArgumentOutOfRange(index, {maxArgs})
+ var maxArgs = Math.Max(0, methodDescription.Method.Parameters.Length - 1);
+ var throwHelperMethod = MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName("OrleansGeneratedCodeHelper"),
+ IdentifierName("InvokableThrowArgumentOutOfRange"));
+ cases.Add(
+ SwitchSection(
+ SingletonList(DefaultSwitchLabel()),
+ new SyntaxList(
+ [
+ ExpressionStatement(
InvocationExpression(
throwHelperMethod,
ArgumentList(
SeparatedList(
- new[]
- {
+ [
Argument(index),
Argument(
LiteralExpression(
SyntaxKind.NumericLiteralExpression,
- Literal(
- Math.Max(0, methodDescription.Method.Parameters.Length - 1))))
- })))))));
- var body = SwitchStatement(index, new SyntaxList(cases));
- return MethodDeclaration(PredefinedType(Token(SyntaxKind.ObjectKeyword)), "GetArgument")
- .WithParameterList(
- ParameterList(
- SingletonSeparatedList(
- Parameter(Identifier("index")).WithType(PredefinedType(Token(SyntaxKind.IntKeyword))))))
- .WithBody(Block(body))
- .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)));
- }
-
- private MemberDeclarationSyntax GenerateSetArgumentMethod(
- InvokableMethodDescription methodDescription,
- List fields)
- {
- if (methodDescription.Method.Parameters.Length == 0)
- return null;
-
- var index = IdentifierName("index");
- var value = IdentifierName("value");
-
- var cases = new List();
- foreach (var field in fields)
- {
- if (field is not MethodParameterFieldDescription parameter)
- {
- continue;
- }
-
- // C#: case {index}: {field} = (TField)value; return;
- var label = CaseSwitchLabel(
- LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(parameter.ParameterOrdinal)));
- cases.Add(
- SwitchSection(
- SingletonList(label),
- new SyntaxList(
- new StatementSyntax[]
- {
- ExpressionStatement(
- AssignmentExpression(
- SyntaxKind.SimpleAssignmentExpression,
- IdentifierName(parameter.FieldName),
- CastExpression(parameter.FieldType.ToTypeSyntax(methodDescription.TypeParameterSubstitutions), value))),
- ReturnStatement()
- })));
- }
+ Literal(maxArgs)))
+ ])))),
+ ReturnStatement()
+ ])));
+ var body = SwitchStatement(index, new SyntaxList(cases));
+ return MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "SetArgument")
+ .WithParameterList(
+ ParameterList(
+ SeparatedList(
+ [
+ Parameter(Identifier("index")).WithType(PredefinedType(Token(SyntaxKind.IntKeyword))),
+ Parameter(Identifier("value")).WithType(PredefinedType(Token(SyntaxKind.ObjectKeyword)))
+ ]
+ )))
+ .WithBody(Block(body))
+ .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)));
+ }
- // C#: default: OrleansGeneratedCodeHelper.InvokableThrowArgumentOutOfRange(index, {maxArgs})
- var maxArgs = Math.Max(0, methodDescription.Method.Parameters.Length - 1);
- var throwHelperMethod = MemberAccessExpression(
+ private static MemberDeclarationSyntax GenerateInvokeInnerMethod(
+ InvokableMethodDescription method,
+ List fields,
+ TargetFieldDescription target)
+ {
+ var resultTask = IdentifierName("resultTask");
+
+ // C# var resultTask = this.target.{Method}({params});
+ var args = SeparatedList(
+ fields.OfType