diff --git a/Directory.Build.targets b/Directory.Build.targets index dbcbf3e886f..c0cbba7f573 100644 --- a/Directory.Build.targets +++ b/Directory.Build.targets @@ -14,6 +14,7 @@ diff --git a/src/Orleans.CodeGenerator/ActivatorGenerator.cs b/src/Orleans.CodeGenerator/ActivatorGenerator.cs index 7acb8d8757d..22180717402 100644 --- a/src/Orleans.CodeGenerator/ActivatorGenerator.cs +++ b/src/Orleans.CodeGenerator/ActivatorGenerator.cs @@ -2,128 +2,137 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -using System.Collections.Generic; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal class ActivatorGenerator(IGeneratorServices generatorServices) { - internal class ActivatorGenerator + private readonly IGeneratorServices _generatorServices = generatorServices; + + private struct ConstructorArgument + { + public TypeSyntax Type { get; set; } + public string FieldName { get; set; } + public string ParameterName { get; set; } + } + + public ClassDeclarationSyntax GenerateActivator(ISerializableTypeDescription type) { - private readonly CodeGenerator _codeGenerator; + var simpleClassName = GetSimpleClassName(type); + + var baseInterface = _generatorServices.LibraryTypes.IActivator_1.ToTypeSyntax(type.TypeSyntax); - private struct ConstructorArgument + var orderedFields = new List(); + var index = 0; + if (type.ActivatorConstructorParameters is { Count: > 0 } parameters) { - public TypeSyntax Type { get; set; } - public string FieldName { get; set; } - public string ParameterName { get; set; } + foreach (var arg in parameters) + { + orderedFields.Add(new ConstructorArgument { Type = arg, FieldName = $"_arg{index}", ParameterName = $"arg{index}" }); + index++; + } } - public ActivatorGenerator(CodeGenerator codeGenerator) + var members = new List(); + foreach (var field in orderedFields) { - _codeGenerator = codeGenerator; + members.Add( + FieldDeclaration(VariableDeclaration(field.Type, SingletonSeparatedList(VariableDeclarator(field.FieldName)))) + .AddModifiers( + Token(SyntaxKind.PrivateKeyword), + Token(SyntaxKind.ReadOnlyKeyword))); } - public ClassDeclarationSyntax GenerateActivator(ISerializableTypeDescription type) - { - var simpleClassName = GetSimpleClassName(type); + if (orderedFields.Count > 0) + members.Add(GenerateConstructor(simpleClassName, orderedFields)); - var baseInterface = _codeGenerator.LibraryTypes.IActivator_1.ToTypeSyntax(type.TypeSyntax); + members.Add(GenerateCreateMethod(type, orderedFields)); - var orderedFields = new List(); - var index = 0; - if (type.ActivatorConstructorParameters is { Count: > 0 } parameters) - { - foreach (var arg in parameters) - { - orderedFields.Add(new ConstructorArgument { Type = arg, FieldName = $"_arg{index}", ParameterName = $"arg{index}" }); - index++; - } - } + var classDeclaration = ClassDeclaration(simpleClassName) + .AddBaseListTypes(SimpleBaseType(baseInterface)) + .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.SealedKeyword)) + .AddAttributeLists(GeneratedCodeUtilities.GetGeneratedCodeAttributes()) + .AddMembers([.. members]); - var members = new List(); - foreach (var field in orderedFields) - { - members.Add( - FieldDeclaration(VariableDeclaration(field.Type, SingletonSeparatedList(VariableDeclarator(field.FieldName)))) - .AddModifiers( - Token(SyntaxKind.PrivateKeyword), - Token(SyntaxKind.ReadOnlyKeyword))); - } - - if (orderedFields.Count > 0) - members.Add(GenerateConstructor(simpleClassName, orderedFields)); + if (type.IsGenericType) + { + classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, type.TypeParameters); + } - members.Add(GenerateCreateMethod(type, orderedFields)); + return classDeclaration; + } - var classDeclaration = ClassDeclaration(simpleClassName) - .AddBaseListTypes(SimpleBaseType(baseInterface)) - .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.SealedKeyword)) - .AddAttributeLists(CodeGenerator.GetGeneratedCodeAttributes()) - .AddMembers(members.ToArray()); + public static string GetSimpleClassName(ISerializableTypeDescription serializableType) => GetSimpleClassName(serializableType.Name); - if (type.IsGenericType) - { - classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, type.TypeParameters); - } + public static string GetSimpleClassName(string name) => $"Activator_{name}"; - return classDeclaration; - } - - public static string GetSimpleClassName(ISerializableTypeDescription serializableType) => $"Activator_{serializableType.Name}"; + /// + /// Determines whether an activator should be generated for the specified type. + /// + internal static bool ShouldGenerateActivator(ISerializableTypeDescription type) + { + return !type.IsAbstractType + && !type.IsEnumType + && (!type.IsValueType + && type.IsEmptyConstructable + && !type.UseActivator + && type is not GeneratedInvokableDescription + || type.HasActivatorConstructor); + } - private ConstructorDeclarationSyntax GenerateConstructor( - string simpleClassName, - List orderedFields) + private static ConstructorDeclarationSyntax GenerateConstructor( + string simpleClassName, + List orderedFields) + { + var parameters = new List(); + var body = new List(); + foreach (var field in orderedFields) { - var parameters = new List(); - var body = new List(); - foreach (var field in orderedFields) - { - parameters.Add(Parameter(field.ParameterName.ToIdentifier()).WithType(field.Type)); + parameters.Add(Parameter(field.ParameterName.ToIdentifier()).WithType(field.Type)); - body.Add(ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - field.FieldName.ToIdentifierName(), - Unwrapped(field.ParameterName.ToIdentifierName())))); - } + body.Add(ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + field.FieldName.ToIdentifierName(), + Unwrapped(field.ParameterName.ToIdentifierName())))); + } - var constructorDeclaration = ConstructorDeclaration(simpleClassName) - .AddModifiers(Token(SyntaxKind.PublicKeyword)) - .AddParameterListParameters(parameters.ToArray()) - .AddBodyStatements(body.ToArray()); + var constructorDeclaration = ConstructorDeclaration(simpleClassName) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddParameterListParameters([.. parameters]) + .AddBodyStatements([.. body]); - return constructorDeclaration; + return constructorDeclaration; - static ExpressionSyntax Unwrapped(ExpressionSyntax expr) - { - return InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("OrleansGeneratedCodeHelper"), IdentifierName("UnwrapService")), - ArgumentList(SeparatedList(new[] { Argument(ThisExpression()), Argument(expr) }))); - } + static ExpressionSyntax Unwrapped(ExpressionSyntax expr) + { + return InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("OrleansGeneratedCodeHelper"), IdentifierName("UnwrapService")), + ArgumentList(SeparatedList([Argument(ThisExpression()), Argument(expr)]))); } + } - private MemberDeclarationSyntax GenerateCreateMethod(ISerializableTypeDescription type, List orderedFields) + private static MemberDeclarationSyntax GenerateCreateMethod(ISerializableTypeDescription type, List orderedFields) + { + ExpressionSyntax createObject; + if (type.ActivatorConstructorParameters is { Count: > 0 }) { - ExpressionSyntax createObject; - if (type.ActivatorConstructorParameters is { Count: > 0 }) - { - var argList = new List(); - foreach (var field in orderedFields) - { - argList.Add(Argument(field.FieldName.ToIdentifierName())); - } - - createObject = ObjectCreationExpression(type.TypeSyntax).WithArgumentList(ArgumentList(SeparatedList(argList))); - } - else + var argList = new List(); + foreach (var field in orderedFields) { - createObject = type.GetObjectCreationExpression(); + argList.Add(Argument(field.FieldName.ToIdentifierName())); } - return MethodDeclaration(type.TypeSyntax, "Create") - .WithExpressionBody(ArrowExpressionClause(createObject)) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) - .AddModifiers(Token(SyntaxKind.PublicKeyword)); + createObject = ObjectCreationExpression(type.TypeSyntax).WithArgumentList(ArgumentList(SeparatedList(argList))); } + else + { + createObject = type.GetObjectCreationExpression(); + } + + return MethodDeclaration(type.TypeSyntax, "Create") + .WithExpressionBody(ArrowExpressionClause(createObject)) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) + .AddModifiers(Token(SyntaxKind.PublicKeyword)); } -} \ No newline at end of file +} diff --git a/src/Orleans.CodeGenerator/ApplicationPartAttributeGenerator.cs b/src/Orleans.CodeGenerator/ApplicationPartAttributeGenerator.cs index 5b40dc120fd..1a8be26ef38 100644 --- a/src/Orleans.CodeGenerator/ApplicationPartAttributeGenerator.cs +++ b/src/Orleans.CodeGenerator/ApplicationPartAttributeGenerator.cs @@ -1,29 +1,27 @@ -using System.Collections.Generic; -using Orleans.CodeGenerator.SyntaxGeneration; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Orleans.CodeGenerator.SyntaxGeneration; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal static class ApplicationPartAttributeGenerator { - internal static class ApplicationPartAttributeGenerator + public static List GenerateSyntax(NameSyntax applicationPartAttribute, IEnumerable applicationParts) { - public static List GenerateSyntax(LibraryTypes wellKnownTypes, MetadataModel model) - { - var attributes = new List(); + var attributes = new List(); - foreach (var assemblyName in model.ApplicationParts) - { - // Generate an assembly-level attribute with an instance of that class. - var attribute = AttributeList( - AttributeTargetSpecifier(Token(SyntaxKind.AssemblyKeyword)), - SingletonSeparatedList( - Attribute(wellKnownTypes.ApplicationPartAttribute.ToNameSyntax()) - .AddArgumentListArguments(AttributeArgument(assemblyName.GetLiteralExpression())))); - attributes.Add(attribute); - } - - return attributes; + foreach (var assemblyName in applicationParts) + { + // Generate an assembly-level attribute with an instance of that class. + var attribute = AttributeList( + AttributeTargetSpecifier(Token(SyntaxKind.AssemblyKeyword)), + SingletonSeparatedList( + Attribute(applicationPartAttribute) + .AddArgumentListArguments(AttributeArgument(assemblyName.GetLiteralExpression())))); + attributes.Add(attribute); } + + return attributes; } } diff --git a/src/Orleans.CodeGenerator/CodeGenerator.cs b/src/Orleans.CodeGenerator/CodeGenerator.cs deleted file mode 100644 index 93091f3839b..00000000000 --- a/src/Orleans.CodeGenerator/CodeGenerator.cs +++ /dev/null @@ -1,829 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Linq; -using System.Text; -using System.Threading; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Orleans.CodeGenerator.Diagnostics; -using Orleans.CodeGenerator.Hashing; -using Orleans.CodeGenerator.Model; -using Orleans.CodeGenerator.SyntaxGeneration; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -using static Orleans.CodeGenerator.SyntaxGeneration.SymbolExtensions; - -#nullable disable -namespace Orleans.CodeGenerator -{ - public class CodeGeneratorOptions - { - public const string IdAttribute = "Orleans.IdAttribute"; - public const string AliasAttribute = "Orleans.AliasAttribute"; - public const string ImmutableAttribute = "Orleans.ImmutableAttribute"; - public static readonly IReadOnlyList ConstructorAttributes = ["Orleans.OrleansConstructorAttribute", "Microsoft.Extensions.DependencyInjection.ActivatorUtilitiesConstructorAttribute"]; - public GenerateFieldIds GenerateFieldIds { get; set; } - public bool GenerateCompatibilityInvokers { get; set; } - } - - public class CodeGenerator - { - internal const string CodeGeneratorName = "OrleansCodeGen"; - private readonly Dictionary> _namespacedMembers = new(); - private readonly Dictionary _invokableMethodDescriptions = new(); - private readonly HashSet _visitedInterfaces = new(SymbolEqualityComparer.Default); - private readonly List DisabledWarnings = new() { "CS1591", "RS0016", "RS0041" }; - - public CodeGenerator(Compilation compilation, CodeGeneratorOptions options) - { - Compilation = compilation; - Options = options; - LibraryTypes = LibraryTypes.FromCompilation(compilation, options); - MetadataModel = new MetadataModel(); - CopierGenerator = new CopierGenerator(this); - SerializerGenerator = new SerializerGenerator(this); - ProxyGenerator = new ProxyGenerator(this); - InvokableGenerator = new InvokableGenerator(this); - MetadataGenerator = new MetadataGenerator(this); - ActivatorGenerator = new ActivatorGenerator(this); - } - - public Compilation Compilation { get; } - public CodeGeneratorOptions Options { get; } - internal LibraryTypes LibraryTypes { get; } - internal MetadataModel MetadataModel { get; } - internal CopierGenerator CopierGenerator { get; } - internal SerializerGenerator SerializerGenerator { get; } - internal ProxyGenerator ProxyGenerator { get; } - internal InvokableGenerator InvokableGenerator { get; } - internal MetadataGenerator MetadataGenerator { get; } - internal ActivatorGenerator ActivatorGenerator { get; } - - public CompilationUnitSyntax GenerateCode(CancellationToken cancellationToken) - { - var assembliesToExamine = new HashSet(SymbolEqualityComparer.Default); - var compilationAsm = LibraryTypes.Compilation.Assembly; - ComputeAssembliesToExamine(compilationAsm, assembliesToExamine); - - // Expand the set of referenced assemblies - MetadataModel.ApplicationParts.Add(compilationAsm.MetadataName); - foreach (var reference in LibraryTypes.Compilation.References) - { - if (LibraryTypes.Compilation.GetAssemblyOrModuleSymbol(reference) is not IAssemblySymbol asm) - { - continue; - } - - if (asm.GetAttributes(LibraryTypes.ApplicationPartAttribute, out var attrs)) - { - MetadataModel.ApplicationParts.Add(asm.MetadataName); - foreach (var attr in attrs) - { - MetadataModel.ApplicationParts.Add((string)attr.ConstructorArguments.First().Value); - } - } - } - - // The mapping of proxy base types to a mapping of return types to invokable base types. Used to set default invokable base types for each proxy base type. - var proxyBaseTypeInvokableBaseTypes = new Dictionary>(SymbolEqualityComparer.Default); - - foreach (var asm in assembliesToExamine) - { - var containingAssemblyAttributes = asm.GetAttributes(); - - foreach (var symbol in asm.GetDeclaredTypes()) - { - if (GetWellKnownTypeId(symbol) is uint wellKnownTypeId) - { - MetadataModel.WellKnownTypeIds.Add((symbol.ToOpenTypeSyntax(), wellKnownTypeId)); - } - - if (GetAlias(symbol) is string typeAlias) - { - MetadataModel.TypeAliases.Add((symbol.ToOpenTypeSyntax(), typeAlias)); - } - - if (GetCompoundTypeAlias(symbol) is CompoundTypeAliasComponent[] compoundTypeAlias) - { - MetadataModel.CompoundTypeAliases.Add(compoundTypeAlias, symbol.ToOpenTypeSyntax()); - } - - if (FSharpUtilities.IsUnionCase(LibraryTypes, symbol, out var sumType) && ShouldGenerateSerializer(sumType)) - { - if (!Compilation.IsSymbolAccessibleWithin(sumType, Compilation.Assembly)) - { - throw new OrleansGeneratorDiagnosticAnalysisException(InaccessibleSerializableTypeDiagnostic.CreateDiagnostic(sumType)); - } - - var typeDescription = new FSharpUtilities.FSharpUnionCaseTypeDescription(Compilation, symbol, LibraryTypes); - MetadataModel.SerializableTypes.Add(typeDescription); - } - else if (ShouldGenerateSerializer(symbol)) - { - // https://learn.microsoft.com/dotnet/api/system.runtime.compilerservices.referenceassemblyattribute - if (containingAssemblyAttributes.Any(attributeData => attributeData.AttributeClass is - { - Name: "ReferenceAssemblyAttribute", - ContainingNamespace: - { - Name: "CompilerServices", - ContainingNamespace: - { - Name: "Runtime", - ContainingNamespace: - { - Name: "System", - ContainingNamespace.IsGlobalNamespace: true - } - } - } - })) - { - // not ALWAYS will be properly processed, therefore emit a warning - throw new OrleansGeneratorDiagnosticAnalysisException(ReferenceAssemblyWithGenerateSerializerDiagnostic.CreateDiagnostic(symbol)); - } - - if (!Compilation.IsSymbolAccessibleWithin(symbol, Compilation.Assembly)) - { - throw new OrleansGeneratorDiagnosticAnalysisException(InaccessibleSerializableTypeDiagnostic.CreateDiagnostic(symbol)); - } - - if (FSharpUtilities.IsRecord(LibraryTypes, symbol)) - { - var typeDescription = new FSharpUtilities.FSharpRecordTypeDescription(Compilation, symbol, LibraryTypes); - MetadataModel.SerializableTypes.Add(typeDescription); - } - else - { - // Regular type - var includePrimaryConstructorParameters = ShouldIncludePrimaryConstructorParameters(symbol); - var constructorParameters = ImmutableArray.Empty; - if (includePrimaryConstructorParameters) - { - if (symbol.IsRecord) - { - // If there is a primary constructor then that will be declared before the copy constructor - // A record always generates a copy constructor and marks it as compiler generated - // todo: find an alternative to this magic - var potentialPrimaryConstructor = symbol.Constructors[0]; - if (!potentialPrimaryConstructor.IsImplicitlyDeclared && !potentialPrimaryConstructor.IsCompilerGenerated()) - { - constructorParameters = potentialPrimaryConstructor.Parameters; - } - } - else - { - var annotatedConstructors = symbol.Constructors.Where(ctor => ctor.HasAnyAttribute(LibraryTypes.ConstructorAttributeTypes)).ToList(); - if (annotatedConstructors.Count == 1) - { - constructorParameters = annotatedConstructors[0].Parameters; - } - else - { - // record structs from referenced assemblies do not return IsRecord=true - // above. See https://github.com/dotnet/roslyn/issues/69326 - // So we implement the same heuristics from ShouldIncludePrimaryConstructorParameters - // to detect a primary constructor. - var properties = symbol.GetMembers().OfType().ToImmutableArray(); - var primaryConstructor = symbol.GetMembers() - .OfType() - .Where(m => m.MethodKind == MethodKind.Constructor && m.Parameters.Length > 0) - // Check for a ctor where all parameters have a corresponding compiler-generated prop. - .FirstOrDefault(ctor => ctor.Parameters.All(prm => - properties.Any(prop => prop.Name.Equals(prm.Name, StringComparison.Ordinal) && prop.IsCompilerGenerated()))); - - if (primaryConstructor != null) - constructorParameters = primaryConstructor.Parameters; - } - } - } - - var implicitMemberSelectionStrategy = (Options.GenerateFieldIds, GetGenerateFieldIdsOptionFromType(symbol)) switch - { - (_, GenerateFieldIds.PublicProperties) => GenerateFieldIds.PublicProperties, - (GenerateFieldIds.PublicProperties, _) => GenerateFieldIds.PublicProperties, - _ => GenerateFieldIds.None - }; - var fieldIdAssignmentHelper = new FieldIdAssignmentHelper(symbol, constructorParameters, implicitMemberSelectionStrategy, LibraryTypes); - if (!fieldIdAssignmentHelper.IsValidForSerialization) - { - throw new OrleansGeneratorDiagnosticAnalysisException(CanNotGenerateImplicitFieldIdsDiagnostic.CreateDiagnostic(symbol, fieldIdAssignmentHelper.FailureReason)); - } - - var typeDescription = new SerializableTypeDescription(Compilation, symbol, includePrimaryConstructorParameters, GetDataMembers(fieldIdAssignmentHelper), LibraryTypes); - MetadataModel.SerializableTypes.Add(typeDescription); - } - } - - if (symbol.TypeKind == TypeKind.Interface) - { - VisitInterface(symbol.OriginalDefinition); - } - - if ((symbol.TypeKind == TypeKind.Class || symbol.TypeKind == TypeKind.Struct) - && !symbol.IsAbstract - && (symbol.DeclaredAccessibility == Accessibility.Public || symbol.DeclaredAccessibility == Accessibility.Internal)) - { - if (symbol.HasAttribute(LibraryTypes.RegisterSerializerAttribute)) - { - MetadataModel.DetectedSerializers.Add(symbol); - } - - if (symbol.HasAttribute(LibraryTypes.RegisterActivatorAttribute)) - { - MetadataModel.DetectedActivators.Add(symbol); - } - - if (symbol.HasAttribute(LibraryTypes.RegisterCopierAttribute)) - { - MetadataModel.DetectedCopiers.Add(symbol); - } - - if (symbol.HasAttribute(LibraryTypes.RegisterConverterAttribute)) - { - MetadataModel.DetectedConverters.Add(symbol); - } - - // Find all implementations of invokable interfaces - foreach (var iface in symbol.AllInterfaces) - { - var attribute = iface.GetAttribute( - LibraryTypes.GenerateMethodSerializersAttribute, - inherited: true); - if (attribute != null) - { - MetadataModel.InvokableInterfaceImplementations.Add(symbol); - break; - } - } - } - - GenerateFieldIds GetGenerateFieldIdsOptionFromType(INamedTypeSymbol t) - { - var attribute = t.GetAttribute(LibraryTypes.GenerateSerializerAttribute); - if (attribute is null) - return GenerateFieldIds.None; - - foreach (var namedArgument in attribute.NamedArguments) - { - if (namedArgument.Key == "GenerateFieldIds") - { - var value = namedArgument.Value.Value; - return value == null ? GenerateFieldIds.None : (GenerateFieldIds)(int)value; - } - } - return GenerateFieldIds.None; - } - - bool ShouldGenerateSerializer(INamedTypeSymbol t) => t.HasAttribute(LibraryTypes.GenerateSerializerAttribute); - - bool ShouldIncludePrimaryConstructorParameters(INamedTypeSymbol t) - { - static bool? TestGenerateSerializerAttribute(INamedTypeSymbol t, INamedTypeSymbol at) - { - var attribute = t.GetAttribute(at); - if (attribute != null) - { - foreach (var namedArgument in attribute.NamedArguments) - { - if (namedArgument.Key == "IncludePrimaryConstructorParameters") - { - if (namedArgument.Value.Kind == TypedConstantKind.Primitive && namedArgument.Value.Value is bool b) - { - return b; - } - } - } - } - - // If there is no such named argument, return null so that other attributes have a chance to apply and defaults can be applied. - return null; - } - - if (TestGenerateSerializerAttribute(t, LibraryTypes.GenerateSerializerAttribute) is bool res) - { - return res; - } - - // Default to true for records. - if (t.IsRecord) - return true; - - var properties = t.GetMembers().OfType().ToImmutableArray(); - - return t.GetMembers() - .OfType() - .Where(m => m.MethodKind == MethodKind.Constructor && m.Parameters.Length > 0) - // Check for a ctor where all parameters have a corresponding compiler-generated prop. - .Any(ctor => ctor.Parameters.All(prm => - properties.Any(prop => prop.Name.Equals(prm.Name, StringComparison.Ordinal) && prop.IsCompilerGenerated()))); - } - } - } - - // Generate serializers. - foreach (var type in MetadataModel.SerializableTypes) - { - string ns = type.GeneratedNamespace; - - // Generate a partial serializer class for each serializable type. - var serializer = SerializerGenerator.Generate(type); - AddMember(ns, serializer); - - // Generate a copier for each serializable type. - if (CopierGenerator.GenerateCopier(type, MetadataModel.DefaultCopiers) is { } copier) - AddMember(ns, copier); - - if (!type.IsAbstractType && !type.IsEnumType && (!type.IsValueType && type.IsEmptyConstructable && !type.UseActivator && type is not GeneratedInvokableDescription || type.HasActivatorConstructor)) - { - MetadataModel.ActivatableTypes.Add(type); - - // Generate an activator class for types with default constructor or activator constructor. - var activator = ActivatorGenerator.GenerateActivator(type); - AddMember(ns, activator); - } - } - - // Generate metadata. - var metadataClassNamespace = CodeGeneratorName + "." + SyntaxGeneration.Identifier.SanitizeIdentifierName(Compilation.AssemblyName); - var metadataClass = MetadataGenerator.GenerateMetadata(); - AddMember(ns: metadataClassNamespace, member: metadataClass); - var metadataAttribute = AttributeList() - .WithTarget(AttributeTargetSpecifier(Token(SyntaxKind.AssemblyKeyword))) - .WithAttributes( - SingletonSeparatedList( - Attribute(LibraryTypes.TypeManifestProviderAttribute.ToNameSyntax()) - .AddArgumentListArguments(AttributeArgument(TypeOfExpression(QualifiedName(IdentifierName(metadataClassNamespace), IdentifierName(metadataClass.Identifier.Text))))))); - - var assemblyAttributes = ApplicationPartAttributeGenerator.GenerateSyntax(LibraryTypes, MetadataModel); - assemblyAttributes.Add(metadataAttribute); - - if (assemblyAttributes.Count > 0) - { - assemblyAttributes[0] = assemblyAttributes[0] - .WithLeadingTrivia( - SyntaxFactory.TriviaList( - new List - { - Trivia( - PragmaWarningDirectiveTrivia( - Token(SyntaxKind.DisableKeyword), - SeparatedList(DisabledWarnings.Select(str => - { - var syntaxToken = SyntaxFactory.Literal( - SyntaxFactory.TriviaList(), - str, - str, - SyntaxFactory.TriviaList()); - - return (ExpressionSyntax)SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, syntaxToken); - })), - isActive: true)), - })); - } - - var usings = List(new[] { UsingDirective(ParseName("global::Orleans.Serialization.Codecs")), UsingDirective(ParseName("global::Orleans.Serialization.GeneratedCodeHelpers")) }); - var namespaces = new List(_namespacedMembers.Count); - foreach (var pair in _namespacedMembers) - { - var ns = pair.Key; - var member = pair.Value; - - namespaces.Add(NamespaceDeclaration(ParseName(ns)).WithMembers(List(member)).WithUsings(usings)); - } - - if (namespaces.Count > 0) - { - namespaces[namespaces.Count - 1] = namespaces[namespaces.Count - 1] - .WithTrailingTrivia( - SyntaxFactory.TriviaList( - new List - { - Trivia( - PragmaWarningDirectiveTrivia( - Token(SyntaxKind.RestoreKeyword), - SeparatedList(DisabledWarnings.Select(str => - { - var syntaxToken = SyntaxFactory.Literal( - SyntaxFactory.TriviaList(), - str, - str, - SyntaxFactory.TriviaList()); - - return (ExpressionSyntax)SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, syntaxToken); - })), - isActive: true)), - })); - } - - return CompilationUnit() - .WithAttributeLists(List(assemblyAttributes)) - .WithMembers(List(namespaces)); - } - - public static string GetGeneratedNamespaceName(ITypeSymbol type) => type.GetNamespaceAndNesting() switch - { - { Length: > 0 } ns => $"{CodeGeneratorName}.{ns}", - _ => CodeGeneratorName - }; - - public void AddMember(string ns, MemberDeclarationSyntax member) - { - if (!_namespacedMembers.TryGetValue(ns, out var existing)) - { - existing = _namespacedMembers[ns] = new List(); - } - - existing.Add(member); - } - - private void ComputeAssembliesToExamine(IAssemblySymbol asm, HashSet expandedAssemblies) - { - if (!expandedAssemblies.Add(asm)) - { - return; - } - - if (!asm.GetAttributes(LibraryTypes.GenerateCodeForDeclaringAssemblyAttribute, out var attrs)) return; - - foreach (var attr in attrs) - { - var param = attr.ConstructorArguments.First(); - if (param.Kind != TypedConstantKind.Type) - { - throw new ArgumentException($"Unrecognized argument type in attribute [{attr.AttributeClass.Name}({param.ToCSharpString()})]"); - } - - var type = (ITypeSymbol)param.Value; - - // Recurse on the assemblies which the type was declared in. - var declaringAsm = type.OriginalDefinition.ContainingAssembly; - if (declaringAsm is null) - { - var diagnostic = GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly_Diagnostic.CreateDiagnostic(attr, type); - throw new OrleansGeneratorDiagnosticAnalysisException(diagnostic); - } - else - { - ComputeAssembliesToExamine(declaringAsm, expandedAssemblies); - } - } - } - - // Returns descriptions of all data members (fields and properties) - private static IEnumerable GetDataMembers(FieldIdAssignmentHelper fieldIdAssignmentHelper) - { - var members = new Dictionary<(uint, bool), IMemberDescription>(); - - foreach (var member in fieldIdAssignmentHelper.Members) - { - if (!fieldIdAssignmentHelper.TryGetSymbolKey(member, out var key)) - continue; - var (id, isConstructorParameter) = key; - - // FieldDescription takes precedence over PropertyDescription (never replace) - if (member is IPropertySymbol property && !members.TryGetValue((id, isConstructorParameter), out _)) - { - members[(id, isConstructorParameter)] = new PropertyDescription(id, isConstructorParameter, property); - } - - if (member is IFieldSymbol field) - { - // FieldDescription takes precedence over PropertyDescription (add or replace) - if (!members.TryGetValue((id, isConstructorParameter), out var existing) || existing is PropertyDescription) - { - members[(id, isConstructorParameter)] = new FieldDescription(id, isConstructorParameter, field); - } - } - } - return members.Values; - } - - public uint? GetId(ISymbol memberSymbol) => GetId(LibraryTypes, memberSymbol); - - internal static uint? GetId(LibraryTypes libraryTypes, ISymbol memberSymbol) - { - return memberSymbol.GetAttribute(libraryTypes.IdAttributeType) is { } attr - ? (uint)attr.ConstructorArguments.First().Value - : null; - } - - internal static string CreateHashedMethodId(IMethodSymbol methodSymbol) - { - var methodSignature = Format(methodSymbol); - var hash = XxHash32.Hash(Encoding.UTF8.GetBytes(methodSignature)); - return $"{HexConverter.ToString(hash)}"; - - static string Format(IMethodSymbol methodInfo) - { - var result = new StringBuilder(); - result.Append(methodInfo.ContainingType.ToDisplayName()); - result.Append('.'); - result.Append(methodInfo.Name); - - if (methodInfo.IsGenericMethod) - { - result.Append('<'); - var first = true; - foreach (var typeArgument in methodInfo.TypeArguments) - { - if (!first) result.Append(','); - else first = false; - result.Append(typeArgument.Name); - } - - result.Append('>'); - } - - { - result.Append('('); - var parameters = methodInfo.Parameters; - var first = true; - foreach (var parameter in parameters) - { - if (!first) - { - result.Append(','); - } - - var parameterType = parameter.Type; - switch (parameterType) - { - case ITypeParameterSymbol _: - result.Append(parameterType.Name); - break; - default: - result.Append(parameterType.ToDisplayName()); - break; - } - - first = false; - } - } - - result.Append(')'); - return result.ToString(); - } - } - - private uint? GetWellKnownTypeId(ISymbol symbol) => GetId(symbol); - - public string GetAlias(ISymbol symbol) => (string)symbol.GetAttribute(LibraryTypes.AliasAttribute)?.ConstructorArguments.First().Value; - - private CompoundTypeAliasComponent[] GetCompoundTypeAlias(ISymbol symbol) - { - var attr = symbol.GetAttribute(LibraryTypes.CompoundTypeAliasAttribute); - if (attr is null) - { - return null; - } - - var allArgs = attr.ConstructorArguments; - if (allArgs.Length != 1 || allArgs[0].Values.Length == 0) - { - throw new ArgumentException($"Unsupported arguments in attribute [{attr.AttributeClass.Name}({string.Join(", ", allArgs.Select(a => a.ToCSharpString()))})]"); - } - - var args = allArgs[0].Values; - var result = new CompoundTypeAliasComponent[args.Length]; - for (var i = 0; i < args.Length; i++) - { - var arg = args[i]; - if (arg.IsNull) - { - throw new ArgumentNullException($"Unsupported null argument in attribute [{attr.AttributeClass.Name}({string.Join(", ", allArgs.Select(a => a.ToCSharpString()))})]"); - } - - result[i] = arg.Value switch - { - ITypeSymbol type => new CompoundTypeAliasComponent(type), - string str => new CompoundTypeAliasComponent(str), - _ => throw new ArgumentException($"Unrecognized argument type for argument {arg.ToCSharpString()} in attribute [{attr.AttributeClass.Name}({string.Join(", ", allArgs.Select(a => a.ToCSharpString()))})]"), - }; - } - - return result; - } - - internal static AttributeListSyntax GetGeneratedCodeAttributes() => GeneratedCodeAttributeSyntax; - private static readonly AttributeListSyntax GeneratedCodeAttributeSyntax = - AttributeList().AddAttributes( - Attribute(ParseName("global::System.CodeDom.Compiler.GeneratedCodeAttribute")) - .AddArgumentListArguments( - AttributeArgument(CodeGeneratorName.GetLiteralExpression()), - AttributeArgument(typeof(CodeGenerator).Assembly.GetName().Version.ToString().GetLiteralExpression())), - Attribute(ParseName("global::System.ComponentModel.EditorBrowsableAttribute")) - .AddArgumentListArguments( - AttributeArgument(ParseName("global::System.ComponentModel.EditorBrowsableState").Member("Never"))), - Attribute(ParseName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute")) - ); - - internal static AttributeSyntax GetMethodImplAttributeSyntax() => MethodImplAttributeSyntax; - private static readonly AttributeSyntax MethodImplAttributeSyntax = - Attribute(ParseName("global::System.Runtime.CompilerServices.MethodImplAttribute")) - .AddArgumentListArguments(AttributeArgument(ParseName("global::System.Runtime.CompilerServices.MethodImplOptions").Member("AggressiveInlining"))); - - internal void VisitInterface(INamedTypeSymbol interfaceType) - { - // Get or generate an invokable for the original method definition. - if (!SymbolEqualityComparer.Default.Equals(interfaceType, interfaceType.OriginalDefinition)) - { - interfaceType = interfaceType.OriginalDefinition; - } - - if (!_visitedInterfaces.Add(interfaceType)) - { - return; - } - - foreach (var proxyBase in GetProxyBases(interfaceType)) - { - _ = GetInvokableInterfaceDescription(proxyBase.ProxyBaseType, interfaceType); - } - - /* - foreach (var baseInterface in interfaceType.AllInterfaces) - { - VisitInterface(baseInterface); - } - */ - } - - internal bool TryGetInvokableInterfaceDescription(INamedTypeSymbol interfaceType, out ProxyInterfaceDescription result) - { - if (!TryGetProxyBaseDescription(interfaceType, out var description)) - { - result = null; - return false; - } - - result = GetInvokableInterfaceDescription(description.ProxyBaseType, interfaceType); - return true; - } - - private readonly Dictionary> _interfaceProxyBases = new(SymbolEqualityComparer.Default); - internal List GetProxyBases(INamedTypeSymbol interfaceType) - { - if (_interfaceProxyBases.TryGetValue(interfaceType, out var result)) - { - return result; - } - - result = new List(); - if (interfaceType.GetAttributes(LibraryTypes.GenerateMethodSerializersAttribute, out var attributes, inherited: true)) - { - foreach (var attribute in attributes) - { - var proxyBase = GetProxyBaseDescription(attribute); - if (!result.Contains(proxyBase)) - { - result.Add(proxyBase); - } - } - } - - return result; - } - - internal bool TryGetProxyBaseDescription(INamedTypeSymbol interfaceType, out InvokableMethodProxyBase result) - { - var attribute = interfaceType.GetAttribute(LibraryTypes.GenerateMethodSerializersAttribute, inherited: true); - if (attribute == null) - { - result = null; - return false; - } - - result = GetProxyBaseDescription(attribute); - return true; - } - - private InvokableMethodProxyBase GetProxyBaseDescription(AttributeData attribute) - { - var proxyBaseType = ((INamedTypeSymbol)attribute.ConstructorArguments[0].Value).OriginalDefinition; - var isExtension = (bool)attribute.ConstructorArguments[1].Value; - var invokableBaseTypes = GetInvokableBaseTypes(proxyBaseType); - var descriptor = new InvokableMethodProxyBaseId(proxyBaseType, isExtension); - var description = new InvokableMethodProxyBase(this, descriptor, invokableBaseTypes); - return description; - - Dictionary GetInvokableBaseTypes(INamedTypeSymbol baseClass) - { - // Set the base invokable types which are used if attributes on individual methods do not override them. - if (!MetadataModel.ProxyBaseTypeInvokableBaseTypes.TryGetValue(baseClass, out var invokableBaseTypes)) - { - invokableBaseTypes = new Dictionary(SymbolEqualityComparer.Default); - if (baseClass.GetAttributes(LibraryTypes.DefaultInvokableBaseTypeAttribute, out var invokableBaseTypeAttributes)) - { - foreach (var attr in invokableBaseTypeAttributes) - { - var ctorArgs = attr.ConstructorArguments; - var returnType = (INamedTypeSymbol)ctorArgs[0].Value; - var invokableBaseType = (INamedTypeSymbol)ctorArgs[1].Value; - invokableBaseTypes[returnType] = invokableBaseType; - } - } - - MetadataModel.ProxyBaseTypeInvokableBaseTypes[baseClass] = invokableBaseTypes; - } - - return invokableBaseTypes; - } - } - - internal InvokableMethodProxyBase GetProxyBase(INamedTypeSymbol interfaceType) - { - if (!TryGetProxyBaseDescription(interfaceType, out var result)) - { - throw new InvalidOperationException($"Cannot get proxy base description for a type which does not have or inherit [{nameof(LibraryTypes.GenerateMethodSerializersAttribute)}]"); - } - - return result; - } - - private ProxyInterfaceDescription GetInvokableInterfaceDescription(INamedTypeSymbol proxyBaseType, INamedTypeSymbol interfaceType) - { - var originalInterface = interfaceType.OriginalDefinition; - if (MetadataModel.InvokableInterfaces.TryGetValue(originalInterface, out var description)) - { - return description; - } - - description = new ProxyInterfaceDescription(this, proxyBaseType, originalInterface); - MetadataModel.InvokableInterfaces.Add(originalInterface, description); - - // Generate a proxy. - var (generatedClass, proxyDescription) = ProxyGenerator.Generate(description); - - // Emit the generated proxy - if (Compilation.GetTypeByMetadataName(proxyDescription.MetadataName) == null) - { - AddMember(proxyDescription.InterfaceDescription.GeneratedNamespace, generatedClass); - } - - MetadataModel.GeneratedProxies.Add(proxyDescription); - - return description; - } - - internal ProxyMethodDescription GetProxyMethodDescription(INamedTypeSymbol interfaceType, IMethodSymbol method) - { - var originalMethod = method.OriginalDefinition; - var proxyBaseInfo = GetProxyBase(interfaceType); - - // For extensions, we want to ensure that the containing type is always the extension. - // This ensures that we will always know which 'component' to get in our SetTarget method. - // If the type is not an extension, use the original method definition's containing type. - // This is the interface where the type was originally defined. - var containingType = proxyBaseInfo.IsExtension ? interfaceType : originalMethod.ContainingType; - - var invokableId = new InvokableMethodId(proxyBaseInfo, containingType, originalMethod); - var interfaceDescription = GetInvokableInterfaceDescription(invokableId.ProxyBase.ProxyBaseType, interfaceType); - - // Get or generate an invokable for the original method definition. - if (!MetadataModel.GeneratedInvokables.TryGetValue(invokableId, out var generatedInvokable)) - { - if (!_invokableMethodDescriptions.TryGetValue(invokableId, out var methodDescription)) - { - methodDescription = _invokableMethodDescriptions[invokableId] = InvokableMethodDescription.Create(invokableId, containingType); - } - - generatedInvokable = MetadataModel.GeneratedInvokables[invokableId] = InvokableGenerator.Generate(methodDescription); - - if (Compilation.GetTypeByMetadataName(generatedInvokable.MetadataName) == null) - { - // Emit the generated code on-demand. - AddMember(generatedInvokable.GeneratedNamespace, generatedInvokable.ClassDeclarationSyntax); - - // Ensure the type will have a serializer generated for it. - MetadataModel.SerializableTypes.Add(generatedInvokable); - - foreach (var alias in generatedInvokable.CompoundTypeAliases) - { - MetadataModel.CompoundTypeAliases.Add(alias, generatedInvokable.OpenTypeSyntax); - } - } - } - - var proxyMethodDescription = ProxyMethodDescription.Create(interfaceDescription, generatedInvokable, method); - - // For backwards compatibility, generate invokers for the specific implementation types as well, where they differ. - if (Options.GenerateCompatibilityInvokers && !SymbolEqualityComparer.Default.Equals(method.OriginalDefinition.ContainingType, interfaceType)) - { - var compatInvokableId = new InvokableMethodId(proxyBaseInfo, interfaceType, method); - var compatMethodDescription = InvokableMethodDescription.Create(compatInvokableId, interfaceType); - var compatInvokable = InvokableGenerator.Generate(compatMethodDescription); - AddMember(compatInvokable.GeneratedNamespace, compatInvokable.ClassDeclarationSyntax); - var alias = - InvokableGenerator.GetCompoundTypeAliasComponents( - compatInvokableId, - interfaceType, - compatMethodDescription.GeneratedMethodId); - MetadataModel.CompoundTypeAliases.Add(alias, compatInvokable.OpenTypeSyntax); - } - - return proxyMethodDescription; - } - } -} diff --git a/src/Orleans.CodeGenerator/CodeGeneratorOptions.cs b/src/Orleans.CodeGenerator/CodeGeneratorOptions.cs new file mode 100644 index 00000000000..c7be35dbbe5 --- /dev/null +++ b/src/Orleans.CodeGenerator/CodeGeneratorOptions.cs @@ -0,0 +1,13 @@ +using Orleans.CodeGenerator.Model; + +namespace Orleans.CodeGenerator; + +public class CodeGeneratorOptions +{ + public const string IdAttribute = "Orleans.IdAttribute"; + public const string AliasAttribute = "Orleans.AliasAttribute"; + public const string ImmutableAttribute = "Orleans.ImmutableAttribute"; + public static readonly IReadOnlyList ConstructorAttributes = ["Orleans.OrleansConstructorAttribute", "Microsoft.Extensions.DependencyInjection.ActivatorUtilitiesConstructorAttribute"]; + public GenerateFieldIds GenerateFieldIds { get; set; } + public bool GenerateCompatibilityInvokers { get; set; } +} diff --git a/src/Orleans.CodeGenerator/CopierGenerator.cs b/src/Orleans.CodeGenerator/CopierGenerator.cs index e06554bd8c4..5a18d9cb176 100644 --- a/src/Orleans.CodeGenerator/CopierGenerator.cs +++ b/src/Orleans.CodeGenerator/CopierGenerator.cs @@ -1,697 +1,688 @@ -using System.Collections.Generic; -using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using System.Diagnostics; using Orleans.CodeGenerator.SyntaxGeneration; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; using static Orleans.CodeGenerator.InvokableGenerator; using static Orleans.CodeGenerator.SerializerGenerator; -#nullable disable -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal class CopierGenerator(IGeneratorServices generatorServices) { - internal class CopierGenerator - { - private const string BaseTypeCopierFieldName = "_baseTypeCopier"; - private const string ActivatorFieldName = "_activator"; - private const string DeepCopyMethodName = "DeepCopy"; - private readonly CodeGenerator _codeGenerator; + private const string BaseTypeCopierFieldName = "_baseTypeCopier"; + private const string ActivatorFieldName = "_activator"; + private const string DeepCopyMethodName = "DeepCopy"; + private readonly IGeneratorServices _generatorServices = generatorServices; - public CopierGenerator(CodeGenerator codeGenerator) + private LibraryTypes LibraryTypes => _generatorServices.LibraryTypes; + + public ClassDeclarationSyntax? GenerateCopier( + ISerializableTypeDescription type, + Dictionary defaultCopiers) + { + var isShallowCopyable = type.IsShallowCopyable; + if (isShallowCopyable && !type.IsGenericType) { - _codeGenerator = codeGenerator; + defaultCopiers.Add(type, LibraryTypes.ShallowCopier.ToTypeSyntax(type.TypeSyntax)); + return null; } - private LibraryTypes LibraryTypes => _codeGenerator.LibraryTypes; + var simpleClassName = GetSimpleClassName(type); - public ClassDeclarationSyntax GenerateCopier( - ISerializableTypeDescription type, - Dictionary defaultCopiers) + var members = new List(); + foreach (var member in type.Members) { - var isShallowCopyable = type.IsShallowCopyable; - if (isShallowCopyable && !type.IsGenericType) + if (!member.IsCopyable) { - defaultCopiers.Add(type, LibraryTypes.ShallowCopier.ToTypeSyntax(type.TypeSyntax)); - return null; + continue; } - var simpleClassName = GetSimpleClassName(type); - - var members = new List(); - foreach (var member in type.Members) + if (member is ISerializableMember serializable) { - if (!member.IsCopyable) - { - continue; - } - - if (member is ISerializableMember serializable) - { - members.Add(serializable); - } - else if (member is IFieldDescription or IPropertyDescription) - { - members.Add(new SerializableMember(_codeGenerator, member, members.Count)); - } - else if (member is MethodParameterFieldDescription methodParameter) - { - members.Add(new SerializableMethodMember(methodParameter)); - } + members.Add(serializable); } - - var accessibility = type.Accessibility switch + else if (member is IFieldDescription or IPropertyDescription) { - Accessibility.Public => SyntaxKind.PublicKeyword, - _ => SyntaxKind.InternalKeyword, - }; - - var isExceptionType = type.IsExceptionType && type.SerializationHooks.Count == 0; - - var baseType = isExceptionType ? QualifiedName(AliasQualifiedName("global", IdentifierName("Orleans.Serialization.GeneratedCodeHelpers.OrleansGeneratedCodeHelper")), GenericName(Identifier("ExceptionCopier"), TypeArgumentList(SeparatedList(new[] { type.TypeSyntax, type.BaseType.ToTypeSyntax() })))) - : (isShallowCopyable ? LibraryTypes.ShallowCopier : LibraryTypes.DeepCopier_1).ToTypeSyntax(type.TypeSyntax); + members.Add(new SerializableMember(_generatorServices, member, members.Count)); + } + else if (member is MethodParameterFieldDescription methodParameter) + { + members.Add(new SerializableMethodMember(methodParameter)); + } + } - var classDeclaration = ClassDeclaration(simpleClassName) - .AddBaseListTypes(SimpleBaseType(baseType)) - .AddModifiers(Token(accessibility), Token(SyntaxKind.SealedKeyword)) - .AddAttributeLists(CodeGenerator.GetGeneratedCodeAttributes()); + var accessibility = type.Accessibility switch + { + Accessibility.Public => SyntaxKind.PublicKeyword, + _ => SyntaxKind.InternalKeyword, + }; - if (!isShallowCopyable) - { - var fieldDescriptions = GetFieldDescriptions(type, members, isExceptionType, out var onlyDeepFields); - var fieldDeclarations = GetFieldDeclarations(fieldDescriptions); - var ctor = GenerateConstructor(simpleClassName, fieldDescriptions, isExceptionType); + var isExceptionType = type.IsExceptionType && type.SerializationHooks.Count == 0; - classDeclaration = classDeclaration.AddMembers(fieldDeclarations); + var baseType = isExceptionType ? QualifiedName(AliasQualifiedName("global", IdentifierName("Orleans.Serialization.GeneratedCodeHelpers.OrleansGeneratedCodeHelper")), GenericName(Identifier("ExceptionCopier"), TypeArgumentList(SeparatedList([type.TypeSyntax, type.BaseType.ToTypeSyntax()])))) + : (isShallowCopyable ? LibraryTypes.ShallowCopier : LibraryTypes.DeepCopier_1).ToTypeSyntax(type.TypeSyntax); - if (!isExceptionType) - { - var copyMethod = GenerateMemberwiseDeepCopyMethod(type, fieldDescriptions, members, onlyDeepFields); - classDeclaration = classDeclaration.AddMembers(copyMethod); - } + var classDeclaration = ClassDeclaration(simpleClassName) + .AddBaseListTypes(SimpleBaseType(baseType)) + .AddModifiers(Token(accessibility), Token(SyntaxKind.SealedKeyword)) + .AddAttributeLists(GeneratedCodeUtilities.GetGeneratedCodeAttributes()); - if (ctor != null) - classDeclaration = classDeclaration.AddMembers(ctor); + if (!isShallowCopyable) + { + var fieldDescriptions = GetFieldDescriptions(type, members, isExceptionType, out var onlyDeepFields); + var fieldDeclarations = GetFieldDeclarations(fieldDescriptions); + var ctor = GenerateConstructor(simpleClassName, fieldDescriptions, isExceptionType); - if (isExceptionType || !type.IsSealedType) - { - if (GenerateBaseCopierDeepCopyMethod(type, fieldDescriptions, members, isExceptionType) is { } baseCopier) - classDeclaration = classDeclaration.AddMembers(baseCopier); + classDeclaration = classDeclaration.AddMembers(fieldDeclarations); - if (!isExceptionType) - classDeclaration = classDeclaration.AddBaseListTypes(SimpleBaseType(LibraryTypes.BaseCopier_1.ToTypeSyntax(type.TypeSyntax))); - } + if (!isExceptionType) + { + var copyMethod = GenerateMemberwiseDeepCopyMethod(type, fieldDescriptions, members, onlyDeepFields); + classDeclaration = classDeclaration.AddMembers(copyMethod); } - if (type.IsGenericType) + if (ctor != null) + classDeclaration = classDeclaration.AddMembers(ctor); + + if (isExceptionType || !type.IsSealedType) { - classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, type.TypeParameters); + if (GenerateBaseCopierDeepCopyMethod(type, fieldDescriptions, members, isExceptionType) is { } baseCopier) + classDeclaration = classDeclaration.AddMembers(baseCopier); + + if (!isExceptionType) + classDeclaration = classDeclaration.AddBaseListTypes(SimpleBaseType(LibraryTypes.BaseCopier_1.ToTypeSyntax(type.TypeSyntax))); } + } - return classDeclaration; + if (type.IsGenericType) + { + classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, type.TypeParameters); } - public static string GetSimpleClassName(ISerializableTypeDescription serializableType) => GetSimpleClassName(serializableType.Name); + return classDeclaration; + } - public static string GetSimpleClassName(string name) => $"Copier_{name}"; + public static string GetSimpleClassName(ISerializableTypeDescription serializableType) => GetSimpleClassName(serializableType.Name); - private MemberDeclarationSyntax[] GetFieldDeclarations(List fieldDescriptions) - { - return fieldDescriptions.Select(GetFieldDeclaration).ToArray(); + public static string GetSimpleClassName(string name) => $"Copier_{name}"; - static MemberDeclarationSyntax GetFieldDeclaration(GeneratedFieldDescription description) + private static MemberDeclarationSyntax[] GetFieldDeclarations(List fieldDescriptions) + { + return [.. fieldDescriptions.Select(GetFieldDeclaration)]; + + static MemberDeclarationSyntax GetFieldDeclaration(GeneratedFieldDescription description) + { + switch (description) { - switch (description) - { - case FieldAccessorDescription accessor when accessor.InitializationSyntax != null: - return - FieldDeclaration(VariableDeclaration(accessor.FieldType, - SingletonSeparatedList(VariableDeclarator(accessor.FieldName).WithInitializer(EqualsValueClause(accessor.InitializationSyntax))))) - .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword)); - case FieldAccessorDescription accessor when accessor.InitializationSyntax == null: - //[UnsafeAccessor(UnsafeAccessorKind.Method, Name = "set_Amount")] - //extern static void SetAmount(External instance, int value); - return - MethodDeclaration( - PredefinedType(Token(SyntaxKind.VoidKeyword)), - accessor.AccessorName) - .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ExternKeyword), Token(SyntaxKind.StaticKeyword)) - .AddAttributeLists(AttributeList(SingletonSeparatedList( - Attribute(IdentifierName("System.Runtime.CompilerServices.UnsafeAccessor")) - .AddArgumentListArguments( - AttributeArgument( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("System.Runtime.CompilerServices.UnsafeAccessorKind"), - IdentifierName("Method"))), - AttributeArgument( - LiteralExpression( - SyntaxKind.StringLiteralExpression, - Literal($"set_{accessor.FieldName}"))) - .WithNameEquals(NameEquals("Name")))))) - .WithParameterList( - ParameterList(SeparatedList(new[] - { - Parameter(Identifier("instance")).WithType(accessor.ContainingType), - Parameter(Identifier("value")).WithType(description.FieldType) - }))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); - default: - return FieldDeclaration(VariableDeclaration(description.FieldType, SingletonSeparatedList(VariableDeclarator(description.FieldName)))) - .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword)); - } + case FieldAccessorDescription accessor when accessor.InitializationSyntax != null: + return + FieldDeclaration(VariableDeclaration(accessor.FieldType, + SingletonSeparatedList(VariableDeclarator(accessor.FieldName).WithInitializer(EqualsValueClause(accessor.InitializationSyntax))))) + .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword)); + case FieldAccessorDescription accessor when accessor.InitializationSyntax == null: + //[UnsafeAccessor(UnsafeAccessorKind.Method, Name = "set_Amount")] + //extern static void SetAmount(External instance, int value); + return + MethodDeclaration( + PredefinedType(Token(SyntaxKind.VoidKeyword)), + accessor.AccessorName) + .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ExternKeyword), Token(SyntaxKind.StaticKeyword)) + .AddAttributeLists(AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("System.Runtime.CompilerServices.UnsafeAccessor")) + .AddArgumentListArguments( + AttributeArgument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("System.Runtime.CompilerServices.UnsafeAccessorKind"), + IdentifierName("Method"))), + AttributeArgument( + LiteralExpression( + SyntaxKind.StringLiteralExpression, + Literal($"set_{accessor.FieldName}"))) + .WithNameEquals(NameEquals("Name")))))) + .WithParameterList( + ParameterList(SeparatedList( + [ + Parameter(Identifier("instance")).WithType(accessor.ContainingType), + Parameter(Identifier("value")).WithType(description.FieldType) + ]))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + default: + return FieldDeclaration(VariableDeclaration(description.FieldType, SingletonSeparatedList(VariableDeclarator(description.FieldName)))) + .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword)); } } + } - private ConstructorDeclarationSyntax GenerateConstructor(string simpleClassName, List fieldDescriptions, bool isExceptionType) - { - var codecProviderAdded = false; - var parameters = new List(); - var statements = new List(); + private ConstructorDeclarationSyntax? GenerateConstructor(string simpleClassName, List fieldDescriptions, bool isExceptionType) + { + var codecProviderAdded = false; + var parameters = new List(); + var statements = new List(); - if (isExceptionType) - { - parameters.Add(Parameter(Identifier("codecProvider")).WithType(LibraryTypes.ICodecProvider.ToTypeSyntax())); - codecProviderAdded = true; - } + if (isExceptionType) + { + parameters.Add(Parameter(Identifier("codecProvider")).WithType(LibraryTypes.ICodecProvider.ToTypeSyntax())); + codecProviderAdded = true; + } - foreach (var field in fieldDescriptions) + foreach (var field in fieldDescriptions) + { + switch (field) { - switch (field) - { - case GeneratedFieldDescription _ when field.IsInjected: - parameters.Add(Parameter(field.FieldName.ToIdentifier()).WithType(field.FieldType)); - statements.Add(ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - ThisExpression().Member(field.FieldName.ToIdentifierName()), - Unwrapped(field.FieldName.ToIdentifierName())))); - break; - case CopierFieldDescription or BaseCopierFieldDescription when !field.IsInjected: - if (!codecProviderAdded) - { - parameters.Add(Parameter(Identifier("codecProvider")).WithType(LibraryTypes.ICodecProvider.ToTypeSyntax())); - codecProviderAdded = true; - } - - var copier = InvocationExpression( - IdentifierName("OrleansGeneratedCodeHelper").Member(GenericName(Identifier("GetService"), TypeArgumentList(SingletonSeparatedList(field.FieldType)))), - ArgumentList(SeparatedList(new[] { Argument(ThisExpression()), Argument(IdentifierName("codecProvider")) }))); - - statements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, field.FieldName.ToIdentifierName(), copier))); - break; - } - } + case GeneratedFieldDescription _ when field.IsInjected: + parameters.Add(Parameter(field.FieldName.ToIdentifier()).WithType(field.FieldType)); + statements.Add(ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + ThisExpression().Member(field.FieldName.ToIdentifierName()), + Unwrapped(field.FieldName.ToIdentifierName())))); + break; + case CopierFieldDescription or BaseCopierFieldDescription when !field.IsInjected: + if (!codecProviderAdded) + { + parameters.Add(Parameter(Identifier("codecProvider")).WithType(LibraryTypes.ICodecProvider.ToTypeSyntax())); + codecProviderAdded = true; + } - return statements.Count == 0 && !isExceptionType ? null : ConstructorDeclaration(simpleClassName) - .AddModifiers(Token(SyntaxKind.PublicKeyword)) - .AddParameterListParameters(parameters.ToArray()) - .AddBodyStatements(statements.ToArray()) - .WithInitializer(isExceptionType ? ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, ArgumentList(SingletonSeparatedList(Argument(IdentifierName("codecProvider"))))) : null); + var copier = InvocationExpression( + IdentifierName("OrleansGeneratedCodeHelper").Member(GenericName(Identifier("GetService"), TypeArgumentList(SingletonSeparatedList(field.FieldType)))), + ArgumentList(SeparatedList([Argument(ThisExpression()), Argument(IdentifierName("codecProvider"))]))); - static ExpressionSyntax Unwrapped(ExpressionSyntax expr) - { - return InvocationExpression( - IdentifierName("OrleansGeneratedCodeHelper").Member("UnwrapService"), - ArgumentList(SeparatedList(new[] { Argument(ThisExpression()), Argument(expr) }))); + statements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, field.FieldName.ToIdentifierName(), copier))); + break; } } - private List GetFieldDescriptions( - ISerializableTypeDescription serializableTypeDescription, - List members, - bool isExceptionType, - out bool onlyDeepFields) + return statements.Count == 0 && !isExceptionType ? null : ConstructorDeclaration(simpleClassName) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddParameterListParameters([.. parameters]) + .AddBodyStatements([.. statements]) + .WithInitializer(isExceptionType ? ConstructorInitializer(SyntaxKind.BaseConstructorInitializer, ArgumentList(SingletonSeparatedList(Argument(IdentifierName("codecProvider"))))) : null); + + static ExpressionSyntax Unwrapped(ExpressionSyntax expr) { - var serializationHooks = serializableTypeDescription.SerializationHooks; - onlyDeepFields = serializableTypeDescription.IsValueType && serializationHooks.Count == 0; + return InvocationExpression( + IdentifierName("OrleansGeneratedCodeHelper").Member("UnwrapService"), + ArgumentList(SeparatedList([Argument(ThisExpression()), Argument(expr)]))); + } + } - var fields = new List(); + private List GetFieldDescriptions( + ISerializableTypeDescription serializableTypeDescription, + List members, + bool isExceptionType, + out bool onlyDeepFields) + { + var serializationHooks = serializableTypeDescription.SerializationHooks; + onlyDeepFields = serializableTypeDescription.IsValueType && serializationHooks.Count == 0; - if (!isExceptionType && serializableTypeDescription.HasComplexBaseType) - { - fields.Add(GetBaseTypeField(serializableTypeDescription)); - } + var fields = new List(); - if (!serializableTypeDescription.IsImmutable) - { - if (!isExceptionType && serializableTypeDescription.UseActivator && !serializableTypeDescription.IsAbstractType) - { - onlyDeepFields = false; - fields.Add(new ActivatorFieldDescription(LibraryTypes.IActivator_1.ToTypeSyntax(serializableTypeDescription.TypeSyntax), ActivatorFieldName)); - } + if (!isExceptionType && serializableTypeDescription.HasComplexBaseType) + { + fields.Add(GetBaseTypeField(serializableTypeDescription)); + } - // Add a copier field for any field in the target which does not have a static copier. - GetCopierFieldDescriptions(serializableTypeDescription.Members, fields); + if (!serializableTypeDescription.IsImmutable) + { + if (!isExceptionType && serializableTypeDescription.UseActivator && !serializableTypeDescription.IsAbstractType) + { + onlyDeepFields = false; + fields.Add(new ActivatorFieldDescription(LibraryTypes.IActivator_1.ToTypeSyntax(serializableTypeDescription.TypeSyntax), ActivatorFieldName)); } - foreach (var member in members) - { - if (onlyDeepFields && member.IsShallowCopyable) continue; + // Add a copier field for any field in the target which does not have a static copier. + GetCopierFieldDescriptions(serializableTypeDescription.Members, fields); + } - if (member.GetGetterFieldDescription() is { } getterFieldDescription) - { - fields.Add(getterFieldDescription); - } + foreach (var member in members) + { + if (onlyDeepFields && member.IsShallowCopyable) continue; - if (member.GetSetterFieldDescription() is { } setterFieldDescription) - { - fields.Add(setterFieldDescription); - } + if (member.GetGetterFieldDescription() is { } getterFieldDescription) + { + fields.Add(getterFieldDescription); } - for (var hookIndex = 0; hookIndex < serializationHooks.Count; ++hookIndex) + if (member.GetSetterFieldDescription() is { } setterFieldDescription) { - var hookType = serializationHooks[hookIndex]; - fields.Add(new SerializationHookFieldDescription(hookType.ToTypeSyntax(), $"_hook{hookIndex}")); + fields.Add(setterFieldDescription); } - - return fields; } - private BaseCopierFieldDescription GetBaseTypeField(ISerializableTypeDescription serializableTypeDescription) + for (var hookIndex = 0; hookIndex < serializationHooks.Count; ++hookIndex) { - var baseType = serializableTypeDescription.BaseType; - if (baseType.HasAttribute(LibraryTypes.GenerateSerializerAttribute) - && (SymbolEqualityComparer.Default.Equals(baseType.ContainingAssembly, LibraryTypes.Compilation.Assembly) || baseType.ContainingAssembly.HasAttribute(LibraryTypes.TypeManifestProviderAttribute)) - && baseType is not INamedTypeSymbol { IsGenericType: true }) - { - // Use the concrete generated type and avoid expensive interface dispatch (except for generic types that will fall back to IBaseCopier) - return new(QualifiedName(ParseName(GetGeneratedNamespaceName(baseType)), IdentifierName(GetSimpleClassName(baseType.Name))), true); - } + var hookType = serializationHooks[hookIndex]; + fields.Add(new SerializationHookFieldDescription(hookType.ToTypeSyntax(), $"_hook{hookIndex}")); + } + + return fields; + } - return new(LibraryTypes.BaseCopier_1.ToTypeSyntax(serializableTypeDescription.BaseTypeSyntax)); + private BaseCopierFieldDescription GetBaseTypeField(ISerializableTypeDescription serializableTypeDescription) + { + var baseType = serializableTypeDescription.BaseType; + if (baseType.HasAttribute(LibraryTypes.GenerateSerializerAttribute) + && (SymbolEqualityComparer.Default.Equals(baseType.ContainingAssembly, LibraryTypes.Compilation.Assembly) || baseType.ContainingAssembly.HasAttribute(LibraryTypes.TypeManifestProviderAttribute)) + && baseType is not INamedTypeSymbol { IsGenericType: true }) + { + // Use the concrete generated type and avoid expensive interface dispatch (except for generic types that will fall back to IBaseCopier) + return new(QualifiedName(ParseName(GetGeneratedNamespaceName(baseType)), IdentifierName(GetSimpleClassName(baseType.Name))), true); } - public void GetCopierFieldDescriptions(IEnumerable members, List fields) + return new(LibraryTypes.BaseCopier_1.ToTypeSyntax(serializableTypeDescription.BaseTypeSyntax)); + } + + public void GetCopierFieldDescriptions(IEnumerable members, List fields) + { + var fieldIndex = 0; + var uniqueTypes = new HashSet(MemberDescriptionTypeComparer.Default); + foreach (var member in members) { - var fieldIndex = 0; - var uniqueTypes = new HashSet(MemberDescriptionTypeComparer.Default); - foreach (var member in members) + if (!member.IsCopyable) { - if (!member.IsCopyable) - { - continue; - } + continue; + } - var t = member.Type; + var t = member.Type; - if (LibraryTypes.IsShallowCopyable(t)) - continue; + if (LibraryTypes.IsShallowCopyable(t)) + continue; - foreach (var c in LibraryTypes.StaticCopiers) - if (SymbolEqualityComparer.Default.Equals(c.UnderlyingType, t)) - goto skip; + foreach (var c in LibraryTypes.StaticCopiers) + if (SymbolEqualityComparer.Default.Equals(c.UnderlyingType, t)) + goto skip; - if (member.Symbol.HasAttribute(LibraryTypes.ImmutableAttribute)) - continue; + if (member.Symbol.HasAttribute(LibraryTypes.ImmutableAttribute)) + continue; - if (!uniqueTypes.Add(member)) - continue; + if (!uniqueTypes.Add(member)) + continue; - TypeSyntax copierType; - if (t.HasAttribute(LibraryTypes.GenerateSerializerAttribute) - && (SymbolEqualityComparer.Default.Equals(t.ContainingAssembly, LibraryTypes.Compilation.Assembly) || t.ContainingAssembly.HasAttribute(LibraryTypes.TypeManifestProviderAttribute)) - && t is not INamedTypeSymbol { IsGenericType: true, TypeArguments.Length: 0 }) - { - // Use the concrete generated type and avoid expensive interface dispatch (except for complex nested cases that will fall back to IDeepCopier) - SimpleNameSyntax name; - if (t is INamedTypeSymbol namedTypeSymbol && namedTypeSymbol.IsGenericType) - { - // Construct the full generic type name - name = GenericName(Identifier(GetSimpleClassName(t.Name)), TypeArgumentList(SeparatedList(namedTypeSymbol.TypeArguments.Select(arg => arg.ToTypeSyntax())))); - } - else - { - name = IdentifierName(GetSimpleClassName(t.Name)); - } - copierType = QualifiedName(ParseName(GetGeneratedNamespaceName(t)), name); - } - else if (t is IArrayTypeSymbol { IsSZArray: true } array) - { - copierType = LibraryTypes.ArrayCopier.Construct(array.ElementType).ToTypeSyntax(); - } - else if (LibraryTypes.WellKnownCopiers.FindByUnderlyingType(t) is { } copier) - { - // The copier is not a static copier and is also not a generic copiers. - copierType = copier.CopierType.ToTypeSyntax(); - } - else if (t is INamedTypeSymbol { ConstructedFrom: ISymbol unboundFieldType } named && LibraryTypes.WellKnownCopiers.FindByUnderlyingType(unboundFieldType) is { } genericCopier) + TypeSyntax copierType; + if (t.HasAttribute(LibraryTypes.GenerateSerializerAttribute) + && (SymbolEqualityComparer.Default.Equals(t.ContainingAssembly, LibraryTypes.Compilation.Assembly) || t.ContainingAssembly.HasAttribute(LibraryTypes.TypeManifestProviderAttribute)) + && t is not INamedTypeSymbol { IsGenericType: true, TypeArguments.Length: 0 }) + { + // Use the concrete generated type and avoid expensive interface dispatch (except for complex nested cases that will fall back to IDeepCopier) + SimpleNameSyntax name; + if (t is INamedTypeSymbol namedTypeSymbol && namedTypeSymbol.IsGenericType) { - // Construct the generic copier type using the field's type arguments. - copierType = genericCopier.CopierType.Construct(named.TypeArguments.ToArray()).ToTypeSyntax(); + // Construct the full generic type name + name = GenericName(Identifier(GetSimpleClassName(t.Name)), TypeArgumentList(SeparatedList(namedTypeSymbol.TypeArguments.Select(arg => arg.ToTypeSyntax())))); } else { - // Use the IDeepCopier interface - copierType = LibraryTypes.DeepCopier_1.ToTypeSyntax(member.TypeSyntax); + name = IdentifierName(GetSimpleClassName(t.Name)); } - - fields.Add(new CopierFieldDescription(copierType, $"_copier{fieldIndex++}", t)); -skip:; + copierType = QualifiedName(ParseName(GetGeneratedNamespaceName(t)), name); } - } - - private MemberDeclarationSyntax GenerateMemberwiseDeepCopyMethod( - ISerializableTypeDescription type, - List copierFields, - List members, - bool onlyDeepFields) - { - var returnType = type.TypeSyntax; - - var originalParam = "original".ToIdentifierName(); - var contextParam = "context".ToIdentifierName(); - var resultVar = "result".ToIdentifierName(); - - var body = new List(); - - var membersCopied = false; - if (type.IsAbstractType) + else if (t is IArrayTypeSymbol { IsSZArray: true } array) { - // C#: return context.DeepCopy(original) - body.Add(ReturnStatement(InvocationExpression(contextParam.Member("DeepCopy"), ArgumentList(SingletonSeparatedList(Argument(originalParam)))))); - membersCopied = true; + copierType = LibraryTypes.ArrayCopier.Construct(array.ElementType).ToTypeSyntax(); } - else if (type.IsUnsealedImmutable) + else if (LibraryTypes.WellKnownCopiers.FindByUnderlyingType(t) is { } copier) { - // C#: return original is null || original.GetType() == typeof(T) ? original : context.DeepCopy(original); - var exactTypeMatch = BinaryExpression(SyntaxKind.EqualsExpression, InvocationExpression(originalParam.Member("GetType")), TypeOfExpression(type.TypeSyntax)); - var nullOrTypeMatch = BinaryExpression(SyntaxKind.LogicalOrExpression, BinaryExpression(SyntaxKind.IsExpression, originalParam, LiteralExpression(SyntaxKind.NullLiteralExpression)), exactTypeMatch); - var contextCopy = InvocationExpression(contextParam.Member("DeepCopy"), ArgumentList(SingletonSeparatedList(Argument(originalParam)))); - body.Add(ReturnStatement(ConditionalExpression(nullOrTypeMatch, originalParam, contextCopy))); - membersCopied = true; + // The copier is not a static copier and is also not a generic copiers. + copierType = copier.CopierType.ToTypeSyntax(); } - else if (!type.IsValueType) + else if (t is INamedTypeSymbol { ConstructedFrom: ISymbol unboundFieldType } named && LibraryTypes.WellKnownCopiers.FindByUnderlyingType(unboundFieldType) is { } genericCopier) { - if (type.TrackReferences) - { - // C#: if (context.TryGetCopy(original, out T existing)) return existing; - var tryGetCopy = InvocationExpression( - contextParam.Member("TryGetCopy"), - ArgumentList(SeparatedList(new[] - { - Argument(originalParam), - Argument(DeclarationExpression( - type.TypeSyntax, - SingleVariableDesignation(Identifier("existing")))) - .WithRefKindKeyword(Token(SyntaxKind.OutKeyword)) - }))); - body.Add(IfStatement(tryGetCopy, ReturnStatement("existing".ToIdentifierName()))); - } - else - { - // C#: if (original is null) return null; - body.Add(IfStatement(BinaryExpression(SyntaxKind.IsExpression, originalParam, LiteralExpression(SyntaxKind.NullLiteralExpression)), ReturnStatement(LiteralExpression(SyntaxKind.NullLiteralExpression)))); - } + // Construct the generic copier type using the field's type arguments. + copierType = genericCopier.CopierType.Construct([.. named.TypeArguments]).ToTypeSyntax(); + } + else + { + // Use the IDeepCopier interface + copierType = LibraryTypes.DeepCopier_1.ToTypeSyntax(member.TypeSyntax); + } - if (!type.IsSealedType) - { - // C#: if (original.GetType() != typeof(T)) { return context.DeepCopy(original); } - var exactTypeMatch = BinaryExpression(SyntaxKind.NotEqualsExpression, InvocationExpression(originalParam.Member("GetType")), TypeOfExpression(type.TypeSyntax)); - var contextCopy = InvocationExpression(contextParam.Member("DeepCopy"), ArgumentList(SingletonSeparatedList(Argument(originalParam)))); - body.Add(IfStatement(exactTypeMatch, ReturnStatement(contextCopy))); - } + fields.Add(new CopierFieldDescription(copierType, $"_copier{fieldIndex++}", t)); +skip:; + } + } - // C#: var result = _activator.Create(); - body.Add(LocalDeclarationStatement( - VariableDeclaration( - IdentifierName("var"), - SingletonSeparatedList(VariableDeclarator( - resultVar.Identifier, - argumentList: null, - initializer: EqualsValueClause(GetCreateValueExpression(type, copierFields))))))); + private MemberDeclarationSyntax GenerateMemberwiseDeepCopyMethod( + ISerializableTypeDescription type, + List copierFields, + List members, + bool onlyDeepFields) + { + var returnType = type.TypeSyntax; - if (type.TrackReferences) - { - // C#: context.RecordCopy(original, result); - body.Add(ExpressionStatement(InvocationExpression(contextParam.Member("RecordCopy"), ArgumentList(SeparatedList(new[] - { - Argument(originalParam), - Argument(resultVar) - }))))); - } + var originalParam = "original".ToIdentifierName(); + var contextParam = "context".ToIdentifierName(); + var resultVar = "result".ToIdentifierName(); - if (!type.IsSealedType) - { - // C#: DeepCopy(original, result, context); - body.Add(ExpressionStatement(InvocationExpression(IdentifierName("DeepCopy"), - ArgumentList(SeparatedList(new[] { Argument(originalParam), Argument(resultVar), Argument(contextParam) }))))); - body.Add(ReturnStatement(resultVar)); - membersCopied = true; - } - else if (type.HasComplexBaseType) - { - // C#: _baseTypeCopier.DeepCopy(original, result, context); - body.Add( - ExpressionStatement( - InvocationExpression( - BaseTypeCopierFieldName.ToIdentifierName().Member(DeepCopyMethodName), - ArgumentList(SeparatedList(new[] - { - Argument(originalParam), - Argument(resultVar), - Argument(contextParam) - }))))); - } - } - else if (!onlyDeepFields) + var body = new List(); + + var membersCopied = false; + if (type.IsAbstractType) + { + // C#: return context.DeepCopy(original) + body.Add(ReturnStatement(InvocationExpression(contextParam.Member("DeepCopy"), ArgumentList(SingletonSeparatedList(Argument(originalParam)))))); + membersCopied = true; + } + else if (type.IsUnsealedImmutable) + { + // C#: return original is null || original.GetType() == typeof(T) ? original : context.DeepCopy(original); + var exactTypeMatch = BinaryExpression(SyntaxKind.EqualsExpression, InvocationExpression(originalParam.Member("GetType")), TypeOfExpression(type.TypeSyntax)); + var nullOrTypeMatch = BinaryExpression(SyntaxKind.LogicalOrExpression, BinaryExpression(SyntaxKind.IsExpression, originalParam, LiteralExpression(SyntaxKind.NullLiteralExpression)), exactTypeMatch); + var contextCopy = InvocationExpression(contextParam.Member("DeepCopy"), ArgumentList(SingletonSeparatedList(Argument(originalParam)))); + body.Add(ReturnStatement(ConditionalExpression(nullOrTypeMatch, originalParam, contextCopy))); + membersCopied = true; + } + else if (!type.IsValueType) + { + if (type.TrackReferences) { - // C#: var result = _activator.Create(); - // or C#: var result = new TField(); - // or C#: var result = default(TField); - body.Add(LocalDeclarationStatement( - VariableDeclaration( - IdentifierName("var"), - SingletonSeparatedList(VariableDeclarator(resultVar.Identifier) - .WithInitializer(EqualsValueClause(GetCreateValueExpression(type, copierFields))))))); + // C#: if (context.TryGetCopy(original, out T existing)) return existing; + var tryGetCopy = InvocationExpression( + contextParam.Member("TryGetCopy"), + ArgumentList(SeparatedList( + [ + Argument(originalParam), + Argument(DeclarationExpression( + type.TypeSyntax, + SingleVariableDesignation(Identifier("existing")))) + .WithRefKindKeyword(Token(SyntaxKind.OutKeyword)) + ]))); + body.Add(IfStatement(tryGetCopy, ReturnStatement("existing".ToIdentifierName()))); } else { - originalParam = resultVar; + // C#: if (original is null) return null; + body.Add(IfStatement(BinaryExpression(SyntaxKind.IsExpression, originalParam, LiteralExpression(SyntaxKind.NullLiteralExpression)), ReturnStatement(LiteralExpression(SyntaxKind.NullLiteralExpression)))); } - if (!membersCopied) + if (!type.IsSealedType) { - GenerateMemberwiseCopy(type, copierFields, members, originalParam, contextParam, resultVar, body, onlyDeepFields); - body.Add(ReturnStatement(resultVar)); + // C#: if (original.GetType() != typeof(T)) { return context.DeepCopy(original); } + var exactTypeMatch = BinaryExpression(SyntaxKind.NotEqualsExpression, InvocationExpression(originalParam.Member("GetType")), TypeOfExpression(type.TypeSyntax)); + var contextCopy = InvocationExpression(contextParam.Member("DeepCopy"), ArgumentList(SingletonSeparatedList(Argument(originalParam)))); + body.Add(IfStatement(exactTypeMatch, ReturnStatement(contextCopy))); } - var parameters = new[] - { - Parameter(originalParam.Identifier).WithType(type.TypeSyntax), - Parameter(contextParam.Identifier).WithType(LibraryTypes.CopyContext.ToTypeSyntax()) - }; - - return MethodDeclaration(returnType, DeepCopyMethodName) - .AddModifiers(Token(SyntaxKind.PublicKeyword)) - .AddParameterListParameters(parameters) - .AddAttributeLists(AttributeList(SingletonSeparatedList(CodeGenerator.GetMethodImplAttributeSyntax()))) - .AddBodyStatements(body.ToArray()); - } + // C#: var result = _activator.Create(); + body.Add(LocalDeclarationStatement( + VariableDeclaration( + IdentifierName("var"), + SingletonSeparatedList(VariableDeclarator( + resultVar.Identifier, + argumentList: null, + initializer: EqualsValueClause(GetCreateValueExpression(type, copierFields))))))); - private ExpressionSyntax GetCreateValueExpression(ISerializableTypeDescription type, List copierFields) - { - return type.UseActivator switch + if (type.TrackReferences) { - true => InvocationExpression(copierFields.Find(f => f is ActivatorFieldDescription).FieldName.ToIdentifierName().Member("Create")), - false => type.GetObjectCreationExpression() - }; - } - - private MemberDeclarationSyntax GenerateBaseCopierDeepCopyMethod( - ISerializableTypeDescription type, - List copierFields, - List members, - bool isExceptionType) - { - var inputParam = "input".ToIdentifierName(); - var resultParam = "output".ToIdentifierName(); - var contextParam = "context".ToIdentifierName(); - - var body = new List(); + // C#: context.RecordCopy(original, result); + body.Add(ExpressionStatement(InvocationExpression(contextParam.Member("RecordCopy"), ArgumentList(SeparatedList( + [ + Argument(originalParam), + Argument(resultVar) + ]))))); + } - if (type.HasComplexBaseType) + if (!type.IsSealedType) + { + // C#: DeepCopy(original, result, context); + body.Add(ExpressionStatement(InvocationExpression(IdentifierName("DeepCopy"), + ArgumentList(SeparatedList([Argument(originalParam), Argument(resultVar), Argument(contextParam)]))))); + body.Add(ReturnStatement(resultVar)); + membersCopied = true; + } + else if (type.HasComplexBaseType) { // C#: _baseTypeCopier.DeepCopy(original, result, context); body.Add( ExpressionStatement( InvocationExpression( - (isExceptionType ? (ExpressionSyntax)BaseExpression() : IdentifierName(BaseTypeCopierFieldName)).Member(DeepCopyMethodName), - ArgumentList(SeparatedList(new[] - { - Argument(inputParam), - Argument(resultParam), + BaseTypeCopierFieldName.ToIdentifierName().Member(DeepCopyMethodName), + ArgumentList(SeparatedList( + [ + Argument(originalParam), + Argument(resultVar), Argument(contextParam) - }))))); + ]))))); } + } + else if (!onlyDeepFields) + { + // C#: var result = _activator.Create(); + // or C#: var result = new TField(); + // or C#: var result = default(TField); + body.Add(LocalDeclarationStatement( + VariableDeclaration( + IdentifierName("var"), + SingletonSeparatedList(VariableDeclarator(resultVar.Identifier) + .WithInitializer(EqualsValueClause(GetCreateValueExpression(type, copierFields))))))); + } + else + { + originalParam = resultVar; + } - var emptyBodyCount = body.Count; + if (!membersCopied) + { + GenerateMemberwiseCopy(type, copierFields, members, originalParam, contextParam, resultVar, body, onlyDeepFields); + body.Add(ReturnStatement(resultVar)); + } - GenerateMemberwiseCopy( - type, - copierFields, - members, - inputParam, - contextParam, - resultParam, - body); + var parameters = new[] + { + Parameter(originalParam.Identifier).WithType(type.TypeSyntax), + Parameter(contextParam.Identifier).WithType(LibraryTypes.CopyContext.ToTypeSyntax()) + }; + + return MethodDeclaration(returnType, DeepCopyMethodName) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddParameterListParameters(parameters) + .AddAttributeLists(AttributeList(SingletonSeparatedList(GeneratedCodeUtilities.GetMethodImplAttributeSyntax()))) + .AddBodyStatements([.. body]); + } - if (isExceptionType && body.Count == emptyBodyCount) - return null; + private static ExpressionSyntax GetCreateValueExpression(ISerializableTypeDescription type, List copierFields) + { + return type.UseActivator switch + { + true => InvocationExpression(GetActivatorField(copierFields).FieldName.ToIdentifierName().Member("Create")), + false => type.GetObjectCreationExpression() + }; - var parameters = new[] - { - Parameter(inputParam.Identifier).WithType(type.TypeSyntax), - Parameter(resultParam.Identifier).WithType(type.TypeSyntax), - Parameter(contextParam.Identifier).WithType(LibraryTypes.CopyContext.ToTypeSyntax()) - }; + static GeneratedFieldDescription GetActivatorField(List copierFields) + { + var result = copierFields.Find(f => f is ActivatorFieldDescription); + Debug.Assert(result is not null); + return result!; + } + } - var method = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), DeepCopyMethodName) - .AddModifiers(Token(SyntaxKind.PublicKeyword)) - .AddParameterListParameters(parameters) - .AddAttributeLists(AttributeList(SingletonSeparatedList(CodeGenerator.GetMethodImplAttributeSyntax()))) - .AddBodyStatements(body.ToArray()); + private MemberDeclarationSyntax? GenerateBaseCopierDeepCopyMethod( + ISerializableTypeDescription type, + List copierFields, + List members, + bool isExceptionType) + { + var inputParam = "input".ToIdentifierName(); + var resultParam = "output".ToIdentifierName(); + var contextParam = "context".ToIdentifierName(); - if (isExceptionType) - method = method.AddModifiers(Token(SyntaxKind.OverrideKeyword)); + var body = new List(); - return method; + if (type.HasComplexBaseType) + { + // C#: _baseTypeCopier.DeepCopy(original, result, context); + body.Add( + ExpressionStatement( + InvocationExpression( + (isExceptionType ? (ExpressionSyntax)BaseExpression() : IdentifierName(BaseTypeCopierFieldName)).Member(DeepCopyMethodName), + ArgumentList(SeparatedList( + [ + Argument(inputParam), + Argument(resultParam), + Argument(contextParam) + ]))))); } - private void GenerateMemberwiseCopy( - ISerializableTypeDescription type, - List copierFields, - List members, - IdentifierNameSyntax sourceVar, - IdentifierNameSyntax contextVar, - IdentifierNameSyntax destinationVar, - List body, - bool onlyDeepFields = false) + var emptyBodyCount = body.Count; + + GenerateMemberwiseCopy( + type, + copierFields, + members, + inputParam, + contextParam, + resultParam, + body); + + if (isExceptionType && body.Count == emptyBodyCount) + return null; + + var parameters = new[] { - AddSerializationCallbacks(type, sourceVar, destinationVar, "OnCopying", body); + Parameter(inputParam.Identifier).WithType(type.TypeSyntax), + Parameter(resultParam.Identifier).WithType(type.TypeSyntax), + Parameter(contextParam.Identifier).WithType(LibraryTypes.CopyContext.ToTypeSyntax()) + }; - var copiers = type.IsUnsealedImmutable ? null : copierFields.OfType() - .Concat(LibraryTypes.StaticCopiers) - .ToList(); + var method = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), DeepCopyMethodName) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddParameterListParameters(parameters) + .AddAttributeLists(AttributeList(SingletonSeparatedList(GeneratedCodeUtilities.GetMethodImplAttributeSyntax()))) + .AddBodyStatements([.. body]); - var orderedMembers = members.OrderBy(m => m.Member.FieldId); - foreach (var member in orderedMembers) - { - if (onlyDeepFields && member.IsShallowCopyable) continue; - - var getValueExpression = GenerateMemberCopy( - copierFields, - inputValue: member.GetGetter(sourceVar), - contextVar, - copiers, - member); - var memberAssignment = ExpressionStatement(member.GetSetter(destinationVar, getValueExpression)); - body.Add(memberAssignment); - } + if (isExceptionType) + method = method.AddModifiers(Token(SyntaxKind.OverrideKeyword)); - AddSerializationCallbacks(type, sourceVar, destinationVar, "OnCopied", body); - } + return method; + } + + private void GenerateMemberwiseCopy( + ISerializableTypeDescription type, + List copierFields, + List members, + IdentifierNameSyntax sourceVar, + IdentifierNameSyntax contextVar, + IdentifierNameSyntax destinationVar, + List body, + bool onlyDeepFields = false) + { + AddSerializationCallbacks(type, sourceVar, destinationVar, "OnCopying", body); + + var copiers = type.IsUnsealedImmutable ? null : copierFields.OfType() + .Concat(LibraryTypes.StaticCopiers) + .ToList(); - public ExpressionSyntax GenerateMemberCopy( - List copierFields, - ExpressionSyntax inputValue, - ExpressionSyntax copyContextVar, - List copiers, - ISerializableMember member) + var orderedMembers = members.OrderBy(m => m.Member.FieldId); + foreach (var member in orderedMembers) { - if (copiers is null || member.IsShallowCopyable) - return inputValue; + if (onlyDeepFields && member.IsShallowCopyable) continue; - var description = member.Member; + var getValueExpression = GenerateMemberCopy( + copierFields, + inputValue: member.GetGetter(sourceVar), + contextVar, + copiers, + member); + var memberAssignment = ExpressionStatement(member.GetSetter(destinationVar, getValueExpression)); + body.Add(memberAssignment); + } - // Copiers can either be static classes or injected into the constructor. - // Either way, the member signatures are the same. - var memberType = description.Type; - var copier = copiers.Find(f => SymbolEqualityComparer.Default.Equals(f.UnderlyingType, memberType)); - ExpressionSyntax getValueExpression; + AddSerializationCallbacks(type, sourceVar, destinationVar, "OnCopied", body); + } - if (copier is null) + public ExpressionSyntax GenerateMemberCopy( + List copierFields, + ExpressionSyntax inputValue, + ExpressionSyntax copyContextVar, + List? copiers, + ISerializableMember member) + { + if (copiers is null || member.IsShallowCopyable) + return inputValue; + + var description = member.Member; + + // Copiers can either be static classes or injected into the constructor. + // Either way, the member signatures are the same. + var memberType = description.Type; + var copier = copiers.Find(f => SymbolEqualityComparer.Default.Equals(f.UnderlyingType, memberType)); + ExpressionSyntax getValueExpression; + + if (copier is null) + { + getValueExpression = InvocationExpression( + copyContextVar.Member(DeepCopyMethodName), + ArgumentList(SeparatedList([Argument(inputValue)]))); + } + else + { + ExpressionSyntax copierExpression; + var staticCopier = LibraryTypes.StaticCopiers.FindByUnderlyingType(memberType); + if (staticCopier != null) { - getValueExpression = InvocationExpression( - copyContextVar.Member(DeepCopyMethodName), - ArgumentList(SeparatedList(new[] { Argument(inputValue) }))); + copierExpression = staticCopier.CopierType.ToNameSyntax(); } else { - ExpressionSyntax copierExpression; - var staticCopier = LibraryTypes.StaticCopiers.FindByUnderlyingType(memberType); - if (staticCopier != null) - { - copierExpression = staticCopier.CopierType.ToNameSyntax(); - } - else - { - var instanceCopier = copierFields.First(f => f is CopierFieldDescription cf && SymbolEqualityComparer.Default.Equals(cf.UnderlyingType, memberType)); - copierExpression = IdentifierName(instanceCopier.FieldName); - } - - getValueExpression = InvocationExpression( - copierExpression.Member(DeepCopyMethodName), - ArgumentList(SeparatedList(new[] { Argument(inputValue), Argument(copyContextVar) }))); - if (!SymbolEqualityComparer.Default.Equals(copier.UnderlyingType, member.Member.Type)) - { - // If the member type type differs from the copier type (eg because the member is an array), cast the result. - getValueExpression = CastExpression(description.TypeSyntax, getValueExpression); - } + var instanceCopier = copierFields.OfType().First(f => SymbolEqualityComparer.Default.Equals(f.UnderlyingType, memberType)); + copierExpression = IdentifierName(instanceCopier.FieldName); } - return getValueExpression; - } - - private void AddSerializationCallbacks(ISerializableTypeDescription type, IdentifierNameSyntax originalInstance, IdentifierNameSyntax resultInstance, string callbackMethodName, List body) - { - var serializationHooks = type.SerializationHooks; - for (var hookIndex = 0; hookIndex < serializationHooks.Count; ++hookIndex) + getValueExpression = InvocationExpression( + copierExpression.Member(DeepCopyMethodName), + ArgumentList(SeparatedList([Argument(inputValue), Argument(copyContextVar)]))); + if (!SymbolEqualityComparer.Default.Equals(copier.UnderlyingType, member.Member.Type)) { - var hookType = serializationHooks[hookIndex]; - var member = hookType.GetAllMembers(callbackMethodName, Accessibility.Public).FirstOrDefault(); - if (member is null || member.Parameters.Length != 2) - { - continue; - } - - var originalArgument = Argument(originalInstance); - if (member.Parameters[0].RefKind == RefKind.Ref) - { - originalArgument = originalArgument.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); - } - - var resultArgument = Argument(resultInstance); - if (member.Parameters[1].RefKind == RefKind.Ref) - { - resultArgument = resultArgument.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); - } - - body.Add(ExpressionStatement(InvocationExpression( - IdentifierName($"_hook{hookIndex}").Member(callbackMethodName), - ArgumentList(SeparatedList(new[] { originalArgument, resultArgument }))))); + // If the member type type differs from the copier type (eg because the member is an array), cast the result. + getValueExpression = CastExpression(description.TypeSyntax, getValueExpression); } } - internal sealed class BaseCopierFieldDescription : GeneratedFieldDescription + return getValueExpression; + } + + private static void AddSerializationCallbacks(ISerializableTypeDescription type, IdentifierNameSyntax originalInstance, IdentifierNameSyntax resultInstance, string callbackMethodName, List body) + { + var serializationHooks = type.SerializationHooks; + for (var hookIndex = 0; hookIndex < serializationHooks.Count; ++hookIndex) { - public BaseCopierFieldDescription(TypeSyntax fieldType, bool concreteType = false) : base(fieldType, BaseTypeCopierFieldName) - => IsInjected = !concreteType; + var hookType = serializationHooks[hookIndex]; + var member = hookType.GetAllMembers(callbackMethodName, Accessibility.Public).FirstOrDefault(); + if (member is null || member.Parameters.Length != 2) + { + continue; + } - public override bool IsInjected { get; } - } + var originalArgument = Argument(originalInstance); + if (member.Parameters[0].RefKind == RefKind.Ref) + { + originalArgument = originalArgument.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); + } - internal sealed class CopierFieldDescription : GeneratedFieldDescription, ICopierDescription - { - public CopierFieldDescription(TypeSyntax fieldType, string fieldName, ITypeSymbol underlyingType) : base(fieldType, fieldName) + var resultArgument = Argument(resultInstance); + if (member.Parameters[1].RefKind == RefKind.Ref) { - UnderlyingType = underlyingType; + resultArgument = resultArgument.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); } - public ITypeSymbol UnderlyingType { get; } - public override bool IsInjected => false; + body.Add(ExpressionStatement(InvocationExpression( + IdentifierName($"_hook{hookIndex}").Member(callbackMethodName), + ArgumentList(SeparatedList([originalArgument, resultArgument]))))); } + } + internal sealed class BaseCopierFieldDescription(TypeSyntax fieldType, bool concreteType = false) : GeneratedFieldDescription(fieldType, BaseTypeCopierFieldName) + { + public override bool IsInjected { get; } = !concreteType; } + + internal sealed class CopierFieldDescription(TypeSyntax fieldType, string fieldName, ITypeSymbol underlyingType) : GeneratedFieldDescription(fieldType, fieldName), ICopierDescription + { + public ITypeSymbol UnderlyingType { get; } = underlyingType; + public override bool IsInjected => false; + } + } diff --git a/src/Orleans.CodeGenerator/Diagnostics/CanNotGenerateImplicitFieldIdsDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/CanNotGenerateImplicitFieldIdsDiagnostic.cs index 377fdf8cd5a..621faa3b215 100644 --- a/src/Orleans.CodeGenerator/Diagnostics/CanNotGenerateImplicitFieldIdsDiagnostic.cs +++ b/src/Orleans.CodeGenerator/Diagnostics/CanNotGenerateImplicitFieldIdsDiagnostic.cs @@ -1,4 +1,3 @@ -using System.Linq; using Microsoft.CodeAnalysis; namespace Orleans.CodeGenerator.Diagnostics; @@ -7,10 +6,10 @@ public static class CanNotGenerateImplicitFieldIdsDiagnostic { public const string DiagnosticId = DiagnosticRuleId.CanNotGenerateImplicitFieldIds; public const string Title = "Implicit field identifiers could not be generated"; - public const string MessageFormat = "Could not generate implicit field identifiers for the type {0}: {reason}"; + public const string MessageFormat = "Could not generate implicit field identifiers for the type {0}: {1}"; public const string Category = "Usage"; private static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(DiagnosticId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true); - internal static Diagnostic CreateDiagnostic(ISymbol symbol, string reason) => Diagnostic.Create(Rule, symbol.Locations.First(), symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), reason); + internal static Diagnostic CreateDiagnostic(ISymbol symbol, string reason, Location? location = null) => Diagnostic.Create(Rule, location ?? symbol.Locations.First(), symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), reason); } diff --git a/src/Orleans.CodeGenerator/Diagnostics/GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly_Diagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly_Diagnostic.cs index bb0007121e9..36aea5dba92 100644 --- a/src/Orleans.CodeGenerator/Diagnostics/GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly_Diagnostic.cs +++ b/src/Orleans.CodeGenerator/Diagnostics/GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly_Diagnostic.cs @@ -2,7 +2,6 @@ namespace Orleans.CodeGenerator.Diagnostics; -#nullable disable public static class GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly_Diagnostic { public const string DiagnosticId = DiagnosticRuleId.GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembly; @@ -12,5 +11,5 @@ public static class GenerateCodeForDeclaringAssemblyAttribute_NoDeclaringAssembl private static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(DiagnosticId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true); - internal static Diagnostic CreateDiagnostic(AttributeData attribute, ITypeSymbol type) => Diagnostic.Create(Rule, attribute.ApplicationSyntaxReference.SyntaxTree.GetLocation(attribute.ApplicationSyntaxReference.Span), type.ToDisplayString(), attribute.ToString()); + internal static Diagnostic CreateDiagnostic(AttributeData attribute, ITypeSymbol type) => Diagnostic.Create(Rule, attribute.ApplicationSyntaxReference!.SyntaxTree.GetLocation(attribute.ApplicationSyntaxReference.Span), type.ToDisplayString(), attribute.ToString()); } diff --git a/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSerializableTypeDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSerializableTypeDiagnostic.cs index 97c77d9e35d..902348bc2e8 100644 --- a/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSerializableTypeDiagnostic.cs +++ b/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSerializableTypeDiagnostic.cs @@ -1,4 +1,3 @@ -using System.Linq; using Microsoft.CodeAnalysis; namespace Orleans.CodeGenerator.Diagnostics; @@ -13,5 +12,5 @@ public static class InaccessibleSerializableTypeDiagnostic private static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(RuleId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true, description: Descsription); - internal static Diagnostic CreateDiagnostic(ISymbol symbol) => Diagnostic.Create(Rule, symbol.Locations.First(), symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + internal static Diagnostic CreateDiagnostic(ISymbol symbol, Location? location = null) => Diagnostic.Create(Rule, location ?? symbol.Locations.First(), symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); } diff --git a/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSetterDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSetterDiagnostic.cs index 214a40f2e04..d6d68eea209 100644 --- a/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSetterDiagnostic.cs +++ b/src/Orleans.CodeGenerator/Diagnostics/InaccessibleSetterDiagnostic.cs @@ -12,5 +12,5 @@ public static class InaccessibleSetterDiagnostic internal static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(RuleId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true, description: Description); - public static Diagnostic CreateDiagnostic(Location location, string identifier) => Diagnostic.Create(Rule, location, identifier); + public static Diagnostic CreateDiagnostic(Location? location, string identifier) => Diagnostic.Create(Rule, location, identifier); } diff --git a/src/Orleans.CodeGenerator/Diagnostics/InvalidRpcMethodReturnTypeDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/InvalidRpcMethodReturnTypeDiagnostic.cs index 1937b0ef2ab..3b260282105 100644 --- a/src/Orleans.CodeGenerator/Diagnostics/InvalidRpcMethodReturnTypeDiagnostic.cs +++ b/src/Orleans.CodeGenerator/Diagnostics/InvalidRpcMethodReturnTypeDiagnostic.cs @@ -1,9 +1,7 @@ -using System.Linq; using Microsoft.CodeAnalysis; namespace Orleans.CodeGenerator.Diagnostics; -#nullable disable public static class InvalidRpcMethodReturnTypeDiagnostic { public const string RuleId = DiagnosticRuleId.InvalidRpcMethodReturnType; @@ -14,7 +12,7 @@ public static class InvalidRpcMethodReturnTypeDiagnostic internal static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(RuleId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true, description: Description); - public static Diagnostic CreateDiagnostic(Location location, string returnType, string methodIdentifier, string supportedReturnTypeList) => Diagnostic.Create(Rule, location, returnType, methodIdentifier, supportedReturnTypeList); + public static Diagnostic CreateDiagnostic(Location? location, string returnType, string methodIdentifier, string supportedReturnTypeList) => Diagnostic.Create(Rule, location, returnType, methodIdentifier, supportedReturnTypeList); internal static Diagnostic CreateDiagnostic(InvokableMethodDescription methodDescription) { diff --git a/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs index 0a18a7f18ce..f6acde06951 100644 --- a/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs +++ b/src/Orleans.CodeGenerator/Diagnostics/MultipleCancellationTokenParametersDiagnostic.cs @@ -1,4 +1,3 @@ -using System.Linq; using Microsoft.CodeAnalysis; namespace Orleans.CodeGenerator.Diagnostics; diff --git a/src/Orleans.CodeGenerator/Diagnostics/ReferenceAssemblyWithGenerateSerializerDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/ReferenceAssemblyWithGenerateSerializerDiagnostic.cs index cc8a5d878ef..0d18b2a3c23 100644 --- a/src/Orleans.CodeGenerator/Diagnostics/ReferenceAssemblyWithGenerateSerializerDiagnostic.cs +++ b/src/Orleans.CodeGenerator/Diagnostics/ReferenceAssemblyWithGenerateSerializerDiagnostic.cs @@ -1,4 +1,3 @@ -using System.Linq; using Microsoft.CodeAnalysis; namespace Orleans.CodeGenerator.Diagnostics; @@ -13,5 +12,5 @@ public static class ReferenceAssemblyWithGenerateSerializerDiagnostic private static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(DiagnosticId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true, description: Description); - internal static Diagnostic CreateDiagnostic(ISymbol symbol) => Diagnostic.Create(Rule, symbol.Locations.First(), symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + internal static Diagnostic CreateDiagnostic(ISymbol symbol, Location? location = null) => Diagnostic.Create(Rule, location ?? symbol.Locations.First(), symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); } diff --git a/src/Orleans.CodeGenerator/Diagnostics/RpcInterfacePropertyDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/RpcInterfacePropertyDiagnostic.cs index 0cd35584d0e..7d083e116a4 100644 --- a/src/Orleans.CodeGenerator/Diagnostics/RpcInterfacePropertyDiagnostic.cs +++ b/src/Orleans.CodeGenerator/Diagnostics/RpcInterfacePropertyDiagnostic.cs @@ -1,4 +1,3 @@ -using System.Linq; using Microsoft.CodeAnalysis; namespace Orleans.CodeGenerator.Diagnostics; diff --git a/src/Orleans.CodeGenerator/Diagnostics/UnhandledCodeGenerationExceptionDiagnostic.cs b/src/Orleans.CodeGenerator/Diagnostics/UnhandledCodeGenerationExceptionDiagnostic.cs index c8ea4fa0228..bfb4da6f531 100644 --- a/src/Orleans.CodeGenerator/Diagnostics/UnhandledCodeGenerationExceptionDiagnostic.cs +++ b/src/Orleans.CodeGenerator/Diagnostics/UnhandledCodeGenerationExceptionDiagnostic.cs @@ -1,4 +1,3 @@ -using System; using Microsoft.CodeAnalysis; namespace Orleans.CodeGenerator.Diagnostics; @@ -13,5 +12,5 @@ public static class UnhandledCodeGenerationExceptionDiagnostic internal static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(RuleId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true, description: Description); - internal static Diagnostic CreateDiagnostic(Exception exception) => Diagnostic.Create(Rule, location: null, messageArgs: new[] { exception.ToString(), exception.StackTrace }); + internal static Diagnostic CreateDiagnostic(Exception exception) => Diagnostic.Create(Rule, location: null, messageArgs: [exception.ToString(), exception.StackTrace]); } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/FieldIdAssignmentHelper.cs b/src/Orleans.CodeGenerator/FieldIdAssignmentHelper.cs index dc63ee718b8..75a1769378b 100644 --- a/src/Orleans.CodeGenerator/FieldIdAssignmentHelper.cs +++ b/src/Orleans.CodeGenerator/FieldIdAssignmentHelper.cs @@ -1,8 +1,5 @@ -using System; using System.Buffers.Binary; -using System.Collections.Generic; using System.Collections.Immutable; -using System.Linq; using System.Runtime.InteropServices; using System.Text; using Microsoft.CodeAnalysis; @@ -14,6 +11,7 @@ namespace Orleans.CodeGenerator; internal class FieldIdAssignmentHelper { + private const string NoAssignedFieldIdsFailureReason = "no field ids were assigned to any candidate serializable members"; private readonly GenerateFieldIds _implicitMemberSelectionStrategy; private readonly ImmutableArray _constructorParameters; private readonly LibraryTypes _libraryTypes; @@ -30,15 +28,46 @@ public FieldIdAssignmentHelper(INamedTypeSymbol typeSymbol, ImmutableArray _symbols.TryGetValue(symbol, out key); private bool HasMemberWithIdAnnotation() => Array.Exists(_memberSymbols, member => member.HasAttribute(_libraryTypes.IdAttributeType)); + private bool HasCandidateSerializableMembers() + { + foreach (var member in _memberSymbols) + { + if (member is IFieldSymbol) + { + return true; + } + + if (member is IPropertySymbol property + && (_implicitMemberSelectionStrategy != GenerateFieldIds.None + || property.HasAttribute(_libraryTypes.IdAttributeType) + || PropertyUtility.GetMatchingPrimaryConstructorParameter(property, _constructorParameters) is not null + || PropertyUtility.GetMatchingField(property, _memberSymbols) is not null)) + { + return true; + } + } + + return false; + } + private IEnumerable GetMembers(INamedTypeSymbol symbol) { foreach (var member in symbol.GetMembers().OrderBy(m => m.MetadataName)) @@ -68,48 +97,54 @@ private bool ExtractFieldIdAnnotations() { if (member is IPropertySymbol prop) { - var id = CodeGenerator.GetId(_libraryTypes, prop); + var constructorParameter = PropertyUtility.GetMatchingPrimaryConstructorParameter(prop, _constructorParameters); + var id = GeneratedCodeUtilities.GetId(_libraryTypes, prop); if (id.HasValue) { _symbols[member] = (id.Value, false); } - else if (PropertyUtility.GetMatchingPrimaryConstructorParameter(prop, _constructorParameters) is { } prm) + else if (constructorParameter is not null) { - id = CodeGenerator.GetId(_libraryTypes, prop); + var matchingField = PropertyUtility.GetMatchingField(prop, _memberSymbols); + if (matchingField is not null && GeneratedCodeUtilities.GetId(_libraryTypes, matchingField).HasValue) + { + continue; + } + + id = GeneratedCodeUtilities.GetId(_libraryTypes, constructorParameter); if (id.HasValue) { _symbols[member] = (id.Value, true); } else { - _symbols[member] = ((uint)_constructorParameters.IndexOf(prm), true); + _symbols[member] = ((uint)_constructorParameters.IndexOf(constructorParameter), true); } } } if (member is IFieldSymbol field) { - var id = CodeGenerator.GetId(_libraryTypes, field); + var id = GeneratedCodeUtilities.GetId(_libraryTypes, field); var isConstructorParameter = false; + IPropertySymbol? property = null; + IParameterSymbol? constructorParameter = null; - if (!id.HasValue) + property = PropertyUtility.GetMatchingProperty(field, _memberSymbols); + if (property is not null) { - var property = PropertyUtility.GetMatchingProperty(field, _memberSymbols); - if (property is null) - { - continue; - } + constructorParameter = PropertyUtility.GetMatchingPrimaryConstructorParameter(property, _constructorParameters); + } - id = CodeGenerator.GetId(_libraryTypes, property); - if (!id.HasValue) + if (!id.HasValue && property is not null) + { + id = GeneratedCodeUtilities.GetId(_libraryTypes, property); + if (!id.HasValue && constructorParameter is not null) { - var constructorParameter = _constructorParameters.FirstOrDefault(x => x.Name.Equals(property.Name, StringComparison.OrdinalIgnoreCase)); - if (constructorParameter is not null) - { - id = (uint)_constructorParameters.IndexOf(constructorParameter); - isConstructorParameter = true; - } + id = GeneratedCodeUtilities.GetId(_libraryTypes, constructorParameter) + ?? (uint)_constructorParameters.IndexOf(constructorParameter); + isConstructorParameter = true; } } diff --git a/src/Orleans.CodeGenerator/GeneratedCodeUtilities.cs b/src/Orleans.CodeGenerator/GeneratedCodeUtilities.cs new file mode 100644 index 00000000000..a7660c72e5b --- /dev/null +++ b/src/Orleans.CodeGenerator/GeneratedCodeUtilities.cs @@ -0,0 +1,107 @@ +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Orleans.CodeGenerator.Hashing; +using Orleans.CodeGenerator.SyntaxGeneration; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using static Orleans.CodeGenerator.SyntaxGeneration.SymbolExtensions; + +namespace Orleans.CodeGenerator; + +internal static class GeneratedCodeUtilities +{ + internal const string CodeGeneratorName = "OrleansCodeGen"; + + internal static string GetGeneratedNamespaceName(ITypeSymbol type) => type.GetNamespaceAndNesting() switch + { + { Length: > 0 } ns => $"{CodeGeneratorName}.{ns}", + _ => CodeGeneratorName + }; + + internal static uint? GetId(LibraryTypes libraryTypes, ISymbol memberSymbol) + { + return memberSymbol.GetAttribute(libraryTypes.IdAttributeType) is { } attr + ? (uint)attr.ConstructorArguments.First().Value! + : null; + } + + internal static string CreateHashedMethodId(IMethodSymbol methodSymbol) + { + var methodSignature = Format(methodSymbol); + var hash = XxHash32.Hash(Encoding.UTF8.GetBytes(methodSignature)); + return $"{HexConverter.ToString(hash)}"; + + static string Format(IMethodSymbol methodInfo) + { + var result = new StringBuilder(); + result.Append(methodInfo.ContainingType.ToDisplayName()); + result.Append('.'); + result.Append(methodInfo.Name); + + if (methodInfo.IsGenericMethod) + { + result.Append('<'); + var first = true; + foreach (var typeArgument in methodInfo.TypeArguments) + { + if (!first) result.Append(','); + else first = false; + result.Append(typeArgument.Name); + } + + result.Append('>'); + } + + { + result.Append('('); + var parameters = methodInfo.Parameters; + var first = true; + foreach (var parameter in parameters) + { + if (!first) + { + result.Append(','); + } + + var parameterType = parameter.Type; + switch (parameterType) + { + case ITypeParameterSymbol _: + result.Append(parameterType.Name); + break; + default: + result.Append(parameterType.ToDisplayName()); + break; + } + + first = false; + } + } + + result.Append(')'); + return result.ToString(); + } + } + + internal static string? GetAlias(LibraryTypes libraryTypes, ISymbol symbol) => (string?)symbol.GetAttribute(libraryTypes.AliasAttribute)?.ConstructorArguments.First().Value; + + internal static AttributeListSyntax GetGeneratedCodeAttributes() => GeneratedCodeAttributeSyntax; + + private static readonly AttributeListSyntax GeneratedCodeAttributeSyntax = + AttributeList().AddAttributes( + Attribute(ParseName("global::System.CodeDom.Compiler.GeneratedCodeAttribute")) + .AddArgumentListArguments( + AttributeArgument(CodeGeneratorName.GetLiteralExpression()), + AttributeArgument(typeof(GeneratedCodeUtilities).Assembly.GetName().Version.ToString().GetLiteralExpression())), + Attribute(ParseName("global::System.ComponentModel.EditorBrowsableAttribute")) + .AddArgumentListArguments( + AttributeArgument(ParseName("global::System.ComponentModel.EditorBrowsableState").Member("Never"))), + Attribute(ParseName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverageAttribute")) + ); + + internal static AttributeSyntax GetMethodImplAttributeSyntax() => MethodImplAttributeSyntax; + + private static readonly AttributeSyntax MethodImplAttributeSyntax = + Attribute(ParseName("global::System.Runtime.CompilerServices.MethodImplAttribute")) + .AddArgumentListArguments(AttributeArgument(ParseName("global::System.Runtime.CompilerServices.MethodImplOptions").Member("AggressiveInlining"))); +} diff --git a/src/Orleans.CodeGenerator/GeneratedSourceOutput.cs b/src/Orleans.CodeGenerator/GeneratedSourceOutput.cs new file mode 100644 index 00000000000..4fa4c1d269f --- /dev/null +++ b/src/Orleans.CodeGenerator/GeneratedSourceOutput.cs @@ -0,0 +1,383 @@ +using System.Collections.Immutable; +using System.Globalization; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Orleans.CodeGenerator.Hashing; +using Orleans.CodeGenerator.Model; + +namespace Orleans.CodeGenerator; + +internal static class GeneratedSourceOutput +{ + private const string GeneratedCodeWarningDisable = "#pragma warning disable CS1591, RS0016, RS0041"; + private const string GeneratedCodeWarningRestore = "#pragma warning restore CS1591, RS0016, RS0041"; + + internal static void EmitSourceOutputResult(SourceProductionContext context, SourceOutputResult result) + { + if (result.Diagnostic is { } diagnostic) + { + context.ReportDiagnostic(diagnostic); + return; + } + + if (result.SourceEntry is { } sourceEntry) + { + context.AddSource(sourceEntry.HintName, sourceEntry.SourceText); + } + } + + internal static ImmutableArray DeduplicateSerializableTypeResults( + ImmutableArray results) + { + if (results.IsDefaultOrEmpty) + { + return []; + } + + var models = new Dictionary(StringComparer.Ordinal); + var diagnostics = new Dictionary(StringComparer.Ordinal); + foreach (var result in OrderSerializableTypeResultsForCanonicalSelection(results)) + { + if (result.Model is not null) + { + var key = CreateSerializableTypeDedupeKey(result); + if (!models.ContainsKey(key)) + { + models.Add(key, result); + } + } + else if (result.Diagnostic is { } diagnostic) + { + var key = $"{CreateSerializableTypeDedupeKey(result)}|{diagnostic.Id}"; + if (!diagnostics.ContainsKey(key)) + { + diagnostics.Add(key, result); + } + } + } + + return [.. OrderSerializableTypeResultsForEmission(models.Values.Concat(diagnostics.Values))]; + } + + internal static ImmutableArray GetSerializableTypeModels( + ImmutableArray results) + { + if (results.IsDefaultOrEmpty) + { + return []; + } + + var builder = ImmutableArray.CreateBuilder(); + foreach (var result in results) + { + if (result.Model is { } model) + { + builder.Add(model); + } + } + + return ModelExtractor.DeduplicateSerializableTypes(builder.ToImmutable()); + } + + internal static IOrderedEnumerable OrderSerializableTypeResultsForCanonicalSelection( + IEnumerable results) + => results + .Where(static result => result.Model is not null || result.Diagnostic is not null) + .OrderBy(static result => result.SourceLocation.SourceOrderGroup) + .ThenBy(static result => result.SourceLocation.FilePath, StringComparer.Ordinal) + .ThenBy(static result => result.SourceLocation.Position) + .ThenBy(static result => result.MetadataIdentity.MetadataName, StringComparer.Ordinal) + .ThenBy(static result => result.MetadataIdentity.AssemblyIdentity, StringComparer.Ordinal) + .ThenBy(static result => result.MetadataIdentity.AssemblyName, StringComparer.Ordinal) + .ThenBy(static result => result.TypeSyntax, StringComparer.Ordinal) + .ThenBy(static result => result.Diagnostic?.Id ?? string.Empty, StringComparer.Ordinal); + + internal static IOrderedEnumerable OrderSerializableTypeResultsForEmission( + IEnumerable results) + => results + .OrderBy(static result => result.Model is null ? 1 : 0) + .ThenBy(static result => result.MetadataIdentity.MetadataName, StringComparer.Ordinal) + .ThenBy(static result => result.MetadataIdentity.AssemblyIdentity, StringComparer.Ordinal) + .ThenBy(static result => result.MetadataIdentity.AssemblyName, StringComparer.Ordinal) + .ThenBy(static result => result.TypeSyntax, StringComparer.Ordinal) + .ThenBy(static result => result.Diagnostic?.Id ?? string.Empty, StringComparer.Ordinal) + .ThenBy(static result => result.SourceLocation.SourceOrderGroup) + .ThenBy(static result => result.SourceLocation.FilePath, StringComparer.Ordinal) + .ThenBy(static result => result.SourceLocation.Position); + + internal static string CreateSerializableTypeDedupeKey(SerializableTypeResult result) + => CreateTypeDedupeKey(result.MetadataIdentity, result.TypeSyntax); + + internal static string CreateTypeDedupeKey(TypeMetadataIdentity metadataIdentity, string typeSyntax) + { + if (!metadataIdentity.IsEmpty) + { + return string.Join( + "|", + "M", + metadataIdentity.AssemblyIdentity ?? string.Empty, + metadataIdentity.AssemblyName ?? string.Empty, + metadataIdentity.MetadataName ?? string.Empty); + } + + return string.Join("|", "S", typeSyntax ?? string.Empty); + } + + internal static ImmutableArray DeduplicateSourceOutputs( + ImmutableArray.Builder sourceEntries) + => DeduplicateSourceOutputs(sourceEntries.ToImmutable()); + + internal static ImmutableArray DeduplicateSourceOutputs( + ImmutableArray sourceEntries) + { + var emittedSourcesByOriginalHintName = new Dictionary>(StringComparer.Ordinal); + var emittedSourceByHintName = new Dictionary(StringComparer.Ordinal); + var result = ImmutableArray.CreateBuilder(); + foreach (var sourceOutput in sourceEntries) + { + if (sourceOutput.SourceEntry is not { } entry) + { + result.Add(sourceOutput); + continue; + } + + var source = entry.Source ?? string.Empty; + if (!emittedSourcesByOriginalHintName.TryGetValue(entry.HintName, out var emittedSources)) + { + emittedSources = new Dictionary(StringComparer.Ordinal); + emittedSourcesByOriginalHintName.Add(entry.HintName, emittedSources); + } + + if (emittedSources.ContainsKey(source)) + { + continue; + } + + if (!emittedSourceByHintName.TryGetValue(entry.HintName, out var emittedSource)) + { + emittedSources.Add(source, entry.HintName); + emittedSourceByHintName.Add(entry.HintName, source); + result.Add(sourceOutput); + continue; + } + + if (string.Equals(emittedSource, source, StringComparison.Ordinal)) + { + emittedSources.Add(source, entry.HintName); + continue; + } + + var uniqueHintName = CreateDistinctSourceHintName(entry.HintName, source, emittedSourceByHintName); + emittedSources.Add(source, uniqueHintName); + emittedSourceByHintName.Add(uniqueHintName, source); + result.Add(SourceOutputResult.FromSource(new GeneratedSourceEntry(uniqueHintName, source))); + } + + return NormalizeSourceOutputs(result.ToImmutable()); + } + + internal static ImmutableArray NormalizeSourceOutputs(ImmutableArray sourceOutputs) + => StructuralEquality.Normalize(sourceOutputs); + + internal static GeneratedSourceEntry CreateSerializableSourceEntry( + string assemblyName, + string typeName, + TypeMetadataIdentity metadataIdentity, + string hintGeneratedNamespace, + int genericArity, + ClassDeclarationSyntax serializer, + ClassDeclarationSyntax? copier, + ClassDeclarationSyntax? activator, + string generatedNamespace) + { + var namespacedMembers = new Dictionary>(StringComparer.Ordinal); + AddMember(namespacedMembers, generatedNamespace, serializer); + if (copier is not null) + { + AddMember(namespacedMembers, generatedNamespace, copier); + } + + if (activator is not null) + { + AddMember(namespacedMembers, generatedNamespace, activator); + } + + return new GeneratedSourceEntry( + CreateSerializableHintName(assemblyName, typeName, metadataIdentity, hintGeneratedNamespace, genericArity), + CreateSourceString(CreateCompilationUnit(namespacedMembers))); + } + + internal static void AddMember( + Dictionary> namespacedMembers, + string ns, + MemberDeclarationSyntax member) + { + var namespaceName = ns ?? string.Empty; + if (!namespacedMembers.TryGetValue(namespaceName, out var members)) + { + members = []; + namespacedMembers[namespaceName] = members; + } + + members.Add(member); + } + + internal static string CreateSourceString(CompilationUnitSyntax unit) + { + return $"{GeneratedCodeWarningDisable}\r\n{unit.NormalizeWhitespace().ToFullString()}\r\n{GeneratedCodeWarningRestore}"; + } + + internal static CompilationUnitSyntax CreateCompilationUnit( + Dictionary> namespacedMembers, + SyntaxList attributeLists = default) + { + var unit = SyntaxFactory.CompilationUnit().WithAttributeLists(attributeLists); + var usingDirectives = SyntaxFactory.List( + [ + SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("global::Orleans.Serialization.Codecs")), + SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("global::Orleans.Serialization.GeneratedCodeHelpers")), + ]); + var members = new List(namespacedMembers.Count); + + foreach (var pair in namespacedMembers.OrderBy(static pair => pair.Key, StringComparer.Ordinal)) + { + if (string.IsNullOrWhiteSpace(pair.Key)) + { + members.AddRange(pair.Value); + continue; + } + + members.Add( + SyntaxFactory.NamespaceDeclaration(SyntaxFactory.ParseName(pair.Key)) + .WithUsings(usingDirectives) + .WithMembers(SyntaxFactory.List(pair.Value))); + } + + return unit.WithMembers(SyntaxFactory.List(members)); + } + + internal static string CreateSerializableHintName( + string assemblyName, + string typeName, + TypeMetadataIdentity metadataIdentity, + string generatedNamespace, + int genericArity) + { + var hash = CreateHintNameHash(metadataIdentity, generatedNamespace, typeName, genericArity); + + return $"{assemblyName}.orleans.ser.{SanitizeHintComponent(typeName)}.{hash}.g.cs"; + } + + internal static string CreateProxyHintName(string assemblyName, ProxyInterfaceDescription interfaceDescription) + { + var interfaceName = interfaceDescription.InterfaceType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var hash = CreateHintNameHash( + TypeMetadataIdentity.Create(interfaceDescription.InterfaceType), + interfaceDescription.GeneratedNamespace, + interfaceName, + interfaceDescription.TypeParameters.Count); + + return $"{assemblyName}.orleans.proxy.{SanitizeHintComponent(interfaceName)}.{hash}.g.cs"; + } + + internal static string CreateMetadataHintName(string assemblyName) + => $"{assemblyName}.orleans.metadata.g.cs"; + + internal static string CreateHintNameHash( + TypeMetadataIdentity metadataIdentity, + string generatedNamespace, + string syntaxString, + int genericArity) + { + var builder = new StringBuilder(); + AppendHashComponent(builder, metadataIdentity.AssemblyIdentity); + AppendHashComponent(builder, metadataIdentity.AssemblyName); + AppendHashComponent(builder, metadataIdentity.MetadataName); + AppendHashComponent(builder, generatedNamespace); + AppendHashComponent(builder, syntaxString); + AppendHashComponent(builder, genericArity.ToString(CultureInfo.InvariantCulture)); + + return CreateStableHash(builder.ToString()); + } + + internal static string CreateStableHash(string value) + => HexConverter.ToString(XxHash32.Hash(Encoding.UTF8.GetBytes(value ?? string.Empty))); + + internal static void AppendHashComponent(StringBuilder builder, string value) + { + builder.Append(value?.Length ?? 0); + builder.Append(':'); + builder.Append(value ?? string.Empty); + builder.Append('|'); + } + + internal static string CreateDistinctSourceHintName( + string hintName, + string source, + Dictionary emittedSourceByHintName) + { + var sourceHash = CreateStableHash(source); + var candidate = InsertHintNameComponent(hintName, $"collision.{sourceHash}"); + if (!emittedSourceByHintName.ContainsKey(candidate)) + { + return candidate; + } + + for (var index = 1; ; index++) + { + candidate = InsertHintNameComponent(hintName, $"collision.{sourceHash}.{index}"); + if (!emittedSourceByHintName.ContainsKey(candidate)) + { + return candidate; + } + } + } + + internal static string InsertHintNameComponent(string hintName, string component) + { + const string GeneratedSourceSuffix = ".g.cs"; + if (hintName.EndsWith(GeneratedSourceSuffix, StringComparison.Ordinal)) + { + return $"{hintName.Substring(0, hintName.Length - GeneratedSourceSuffix.Length)}.{component}{GeneratedSourceSuffix}"; + } + + const string SourceSuffix = ".cs"; + if (hintName.EndsWith(SourceSuffix, StringComparison.Ordinal)) + { + return $"{hintName.Substring(0, hintName.Length - SourceSuffix.Length)}.{component}{SourceSuffix}"; + } + + return $"{hintName}.{component}"; + } + + internal static string SanitizeHintComponent(string value) + { + if (string.IsNullOrWhiteSpace(value)) + { + return "generated"; + } + + var builder = new StringBuilder(value.Length); + var previousCharacterWasUnderscore = false; + foreach (var character in value) + { + if (char.IsLetterOrDigit(character) || character is '_' or '.') + { + builder.Append(character); + previousCharacterWasUnderscore = false; + } + else if (!previousCharacterWasUnderscore) + { + builder.Append('_'); + previousCharacterWasUnderscore = true; + } + } + + var result = builder.ToString().Trim('_', '.'); + return result.Length > 0 ? result : "generated"; + } +} + + diff --git a/src/Orleans.CodeGenerator/GeneratorServices.cs b/src/Orleans.CodeGenerator/GeneratorServices.cs new file mode 100644 index 00000000000..7935ff28b9a --- /dev/null +++ b/src/Orleans.CodeGenerator/GeneratorServices.cs @@ -0,0 +1,22 @@ +using Microsoft.CodeAnalysis; + +namespace Orleans.CodeGenerator; + +internal interface IGeneratorServices +{ + Compilation Compilation { get; } + CodeGeneratorOptions Options { get; } + LibraryTypes LibraryTypes { get; } +} + +internal sealed class GeneratorServices(Compilation compilation, CodeGeneratorOptions options, LibraryTypes libraryTypes) : IGeneratorServices +{ + public GeneratorServices(Compilation compilation, CodeGeneratorOptions options) + : this(compilation, options, LibraryTypes.FromCompilation(compilation, options)) + { + } + + public Compilation Compilation { get; } = compilation ?? throw new ArgumentNullException(nameof(compilation)); + public CodeGeneratorOptions Options { get; } = options ?? throw new ArgumentNullException(nameof(options)); + public LibraryTypes LibraryTypes { get; } = libraryTypes ?? throw new ArgumentNullException(nameof(libraryTypes)); +} diff --git a/src/Orleans.CodeGenerator/Hashing/BitOperations.cs b/src/Orleans.CodeGenerator/Hashing/BitOperations.cs index e890109944f..bf3aa3e7181 100644 --- a/src/Orleans.CodeGenerator/Hashing/BitOperations.cs +++ b/src/Orleans.CodeGenerator/Hashing/BitOperations.cs @@ -3,19 +3,18 @@ using System.Runtime.CompilerServices; -namespace Orleans.CodeGenerator.Hashing +namespace Orleans.CodeGenerator.Hashing; + +internal static class BitOperations { - internal static class BitOperations - { - /// - /// Rotates the specified value left by the specified number of bits. - /// Similar in behavior to the x86 instruction ROL. - /// - /// The value to rotate. - /// The number of bits to rotate by. - /// Any value outside the range [0..31] is treated as congruent mod 32. - /// The rotated value. - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static uint RotateLeft(uint value, int offset) => (value << offset) | (value >> (32 - offset)); - } + /// + /// Rotates the specified value left by the specified number of bits. + /// Similar in behavior to the x86 instruction ROL. + /// + /// The value to rotate. + /// The number of bits to rotate by. + /// Any value outside the range [0..31] is treated as congruent mod 32. + /// The rotated value. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint RotateLeft(uint value, int offset) => (value << offset) | (value >> (32 - offset)); } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/Hashing/HexConverter.cs b/src/Orleans.CodeGenerator/Hashing/HexConverter.cs index fbade005f0e..169f24dcd48 100644 --- a/src/Orleans.CodeGenerator/Hashing/HexConverter.cs +++ b/src/Orleans.CodeGenerator/Hashing/HexConverter.cs @@ -1,39 +1,37 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Runtime.CompilerServices; -namespace Orleans.CodeGenerator.Hashing +namespace Orleans.CodeGenerator.Hashing; + +internal static class HexConverter { - internal static class HexConverter + public static unsafe string ToString(ReadOnlySpan bytes) { - public static unsafe string ToString(ReadOnlySpan bytes) - { - // Adapted from: https://github.com/dotnet/runtime/blob/f156fb9dcf121e536b93ae90bcc5e8e6d5336062/src/libraries/Common/src/System/HexConverter.cs#L196 - - Span result = bytes.Length > 16 ? - new char[bytes.Length * 2].AsSpan() : - stackalloc char[bytes.Length * 2]; + // Adapted from: https://github.com/dotnet/runtime/blob/f156fb9dcf121e536b93ae90bcc5e8e6d5336062/src/libraries/Common/src/System/HexConverter.cs#L196 + + Span result = bytes.Length > 16 ? + new char[bytes.Length * 2].AsSpan() : + stackalloc char[bytes.Length * 2]; - int pos = 0; - foreach (byte b in bytes) - { - ToCharsBuffer(b, result, pos); - pos += 2; - } + int pos = 0; + foreach (byte b in bytes) + { + ToCharsBuffer(b, result, pos); + pos += 2; + } - return result.ToString(); + return result.ToString(); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void ToCharsBuffer(byte value, Span buffer, int startingIndex = 0) - { - var difference = ((value & 0xF0U) << 4) + (value & 0x0FU) - 0x8989U; - var packedResult = (((uint)-(int)difference & 0x7070U) >> 4) + difference + 0xB9B9U; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void ToCharsBuffer(byte value, Span buffer, int startingIndex = 0) + { + var difference = ((value & 0xF0U) << 4) + (value & 0x0FU) - 0x8989U; + var packedResult = (((uint)-(int)difference & 0x7070U) >> 4) + difference + 0xB9B9U; - buffer[startingIndex + 1] = (char)(packedResult & 0xFF); - buffer[startingIndex] = (char)(packedResult >> 8); - } + buffer[startingIndex + 1] = (char)(packedResult & 0xFF); + buffer[startingIndex] = (char)(packedResult >> 8); } } } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/Hashing/NonCryptographicHashAlgorithm.cs b/src/Orleans.CodeGenerator/Hashing/NonCryptographicHashAlgorithm.cs index 34e1f18a3a1..6802ab670e3 100644 --- a/src/Orleans.CodeGenerator/Hashing/NonCryptographicHashAlgorithm.cs +++ b/src/Orleans.CodeGenerator/Hashing/NonCryptographicHashAlgorithm.cs @@ -1,347 +1,342 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers; using System.ComponentModel; using System.Diagnostics; -using System.IO; -using System.Threading; -using System.Threading.Tasks; -namespace Orleans.CodeGenerator.Hashing +namespace Orleans.CodeGenerator.Hashing; + +/// +/// Represents a non-cryptographic hash algorithm. +/// +internal abstract class NonCryptographicHashAlgorithm { /// - /// Represents a non-cryptographic hash algorithm. + /// Gets the number of bytes produced from this hash algorithm. /// - internal abstract class NonCryptographicHashAlgorithm - { - /// - /// Gets the number of bytes produced from this hash algorithm. - /// - /// The number of bytes produced from this hash algorithm. - public int HashLengthInBytes { get; } - - /// - /// Called from constructors in derived classes to initialize the - /// class. - /// - /// - /// The number of bytes produced from this hash algorithm. - /// - /// - /// is less than 1. - /// - protected NonCryptographicHashAlgorithm(int hashLengthInBytes) - { - if (hashLengthInBytes < 1) - throw new ArgumentOutOfRangeException(nameof(hashLengthInBytes)); + /// The number of bytes produced from this hash algorithm. + public int HashLengthInBytes { get; } - HashLengthInBytes = hashLengthInBytes; - } - - /// - /// When overridden in a derived class, - /// appends the contents of to the data already - /// processed for the current hash computation. - /// - /// The data to process. - public abstract void Append(ReadOnlySpan source); - - /// - /// When overridden in a derived class, - /// resets the hash computation to the initial state. - /// - public abstract void Reset(); - - /// - /// When overridden in a derived class, - /// writes the computed hash value to - /// without modifying accumulated state. - /// - /// The buffer that receives the computed hash value. - /// - /// - /// Implementations of this method must write exactly - /// bytes to . - /// Do not assume that the buffer was zero-initialized. - /// - /// - /// The class validates the - /// size of the buffer before calling this method, and slices the span - /// down to be exactly in length. - /// - /// - protected abstract void GetCurrentHashCore(Span destination); - - /// - /// Appends the contents of to the data already - /// processed for the current hash computation. - /// - /// The data to process. - /// - /// is . - /// - public void Append(byte[] source) - { - if (source is null) - { - throw new ArgumentNullException(nameof(source)); - } + /// + /// Called from constructors in derived classes to initialize the + /// class. + /// + /// + /// The number of bytes produced from this hash algorithm. + /// + /// + /// is less than 1. + /// + protected NonCryptographicHashAlgorithm(int hashLengthInBytes) + { + if (hashLengthInBytes < 1) + throw new ArgumentOutOfRangeException(nameof(hashLengthInBytes)); - Append(new ReadOnlySpan(source)); - } + HashLengthInBytes = hashLengthInBytes; + } - /// - /// Appends the contents of to the data already - /// processed for the current hash computation. - /// - /// The data to process. - /// - /// is . - /// - /// - public void Append(Stream stream) - { - if (stream is null) - { - throw new ArgumentNullException(nameof(stream)); - } + /// + /// When overridden in a derived class, + /// appends the contents of to the data already + /// processed for the current hash computation. + /// + /// The data to process. + public abstract void Append(ReadOnlySpan source); - byte[] buffer = ArrayPool.Shared.Rent(4096); + /// + /// When overridden in a derived class, + /// resets the hash computation to the initial state. + /// + public abstract void Reset(); - while (true) - { - int read = stream.Read(buffer, 0, buffer.Length); + /// + /// When overridden in a derived class, + /// writes the computed hash value to + /// without modifying accumulated state. + /// + /// The buffer that receives the computed hash value. + /// + /// + /// Implementations of this method must write exactly + /// bytes to . + /// Do not assume that the buffer was zero-initialized. + /// + /// + /// The class validates the + /// size of the buffer before calling this method, and slices the span + /// down to be exactly in length. + /// + /// + protected abstract void GetCurrentHashCore(Span destination); - if (read == 0) - { - break; - } + /// + /// Appends the contents of to the data already + /// processed for the current hash computation. + /// + /// The data to process. + /// + /// is . + /// + public void Append(byte[] source) + { + if (source is null) + { + throw new ArgumentNullException(nameof(source)); + } - Append(new ReadOnlySpan(buffer, 0, read)); - } + Append(new ReadOnlySpan(source)); + } - ArrayPool.Shared.Return(buffer); + /// + /// Appends the contents of to the data already + /// processed for the current hash computation. + /// + /// The data to process. + /// + /// is . + /// + /// + public void Append(Stream stream) + { + if (stream is null) + { + throw new ArgumentNullException(nameof(stream)); } - /// - /// Asychronously reads the contents of - /// and appends them to the data already - /// processed for the current hash computation. - /// - /// The data to process. - /// - /// The token to monitor for cancellation requests. - /// The default value is . - /// - /// - /// A task that represents the asynchronous append operation. - /// - /// - /// is . - /// - public Task AppendAsync(Stream stream, CancellationToken cancellationToken = default) + byte[] buffer = ArrayPool.Shared.Rent(4096); + + while (true) { - if (stream is null) + int read = stream.Read(buffer, 0, buffer.Length); + + if (read == 0) { - throw new ArgumentNullException(nameof(stream)); + break; } - return AppendAsyncCore(stream, cancellationToken); + Append(new ReadOnlySpan(buffer, 0, read)); } - private async Task AppendAsyncCore(Stream stream, CancellationToken cancellationToken) + ArrayPool.Shared.Return(buffer); + } + + /// + /// Asychronously reads the contents of + /// and appends them to the data already + /// processed for the current hash computation. + /// + /// The data to process. + /// + /// The token to monitor for cancellation requests. + /// The default value is . + /// + /// + /// A task that represents the asynchronous append operation. + /// + /// + /// is . + /// + public Task AppendAsync(Stream stream, CancellationToken cancellationToken = default) + { + if (stream is null) { - byte[] buffer = ArrayPool.Shared.Rent(4096); + throw new ArgumentNullException(nameof(stream)); + } - while (true) - { + return AppendAsyncCore(stream, cancellationToken); + } + + private async Task AppendAsyncCore(Stream stream, CancellationToken cancellationToken) + { + byte[] buffer = ArrayPool.Shared.Rent(4096); + + while (true) + { #if NETCOREAPP - int read = await stream.ReadAsync(buffer.AsMemory(), cancellationToken).ConfigureAwait(false); + int read = await stream.ReadAsync(buffer.AsMemory(), cancellationToken).ConfigureAwait(false); #else - int read = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); + int read = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); #endif - if (read == 0) - { - break; - } - - Append(new ReadOnlySpan(buffer, 0, read)); + if (read == 0) + { + break; } - ArrayPool.Shared.Return(buffer); + Append(new ReadOnlySpan(buffer, 0, read)); } - /// - /// Gets the current computed hash value without modifying accumulated state. - /// - /// - /// The hash value for the data already provided. - /// - public byte[] GetCurrentHash() + ArrayPool.Shared.Return(buffer); + } + + /// + /// Gets the current computed hash value without modifying accumulated state. + /// + /// + /// The hash value for the data already provided. + /// + public byte[] GetCurrentHash() + { + byte[] ret = new byte[HashLengthInBytes]; + GetCurrentHashCore(ret); + return ret; + } + + /// + /// Attempts to write the computed hash value to + /// without modifying accumulated state. + /// + /// The buffer that receives the computed hash value. + /// + /// On success, receives the number of bytes written to . + /// + /// + /// if is long enough to receive + /// the computed hash value; otherwise, . + /// + public bool TryGetCurrentHash(Span destination, out int bytesWritten) + { + if (destination.Length < HashLengthInBytes) { - byte[] ret = new byte[HashLengthInBytes]; - GetCurrentHashCore(ret); - return ret; + bytesWritten = 0; + return false; } - /// - /// Attempts to write the computed hash value to - /// without modifying accumulated state. - /// - /// The buffer that receives the computed hash value. - /// - /// On success, receives the number of bytes written to . - /// - /// - /// if is long enough to receive - /// the computed hash value; otherwise, . - /// - public bool TryGetCurrentHash(Span destination, out int bytesWritten) - { - if (destination.Length < HashLengthInBytes) - { - bytesWritten = 0; - return false; - } + GetCurrentHashCore(destination.Slice(0, HashLengthInBytes)); + bytesWritten = HashLengthInBytes; + return true; + } - GetCurrentHashCore(destination.Slice(0, HashLengthInBytes)); - bytesWritten = HashLengthInBytes; - return true; + /// + /// Writes the computed hash value to + /// without modifying accumulated state. + /// + /// The buffer that receives the computed hash value. + /// + /// The number of bytes written to , + /// which is always . + /// + /// + /// is shorter than . + /// + public int GetCurrentHash(Span destination) + { + if (destination.Length < HashLengthInBytes) + { + throw new ArgumentException(@"destination too short", nameof(destination)); } - /// - /// Writes the computed hash value to - /// without modifying accumulated state. - /// - /// The buffer that receives the computed hash value. - /// - /// The number of bytes written to , - /// which is always . - /// - /// - /// is shorter than . - /// - public int GetCurrentHash(Span destination) - { - if (destination.Length < HashLengthInBytes) - { - throw new ArgumentException(@"destination too short", nameof(destination)); - } + GetCurrentHashCore(destination.Slice(0, HashLengthInBytes)); + return HashLengthInBytes; + } - GetCurrentHashCore(destination.Slice(0, HashLengthInBytes)); - return HashLengthInBytes; - } + /// + /// Gets the current computed hash value and clears the accumulated state. + /// + /// + /// The hash value for the data already provided. + /// + public byte[] GetHashAndReset() + { + byte[] ret = new byte[HashLengthInBytes]; + GetHashAndResetCore(ret); + return ret; + } - /// - /// Gets the current computed hash value and clears the accumulated state. - /// - /// - /// The hash value for the data already provided. - /// - public byte[] GetHashAndReset() + /// + /// Attempts to write the computed hash value to . + /// If successful, clears the accumulated state. + /// + /// The buffer that receives the computed hash value. + /// + /// On success, receives the number of bytes written to . + /// + /// + /// and clears the accumulated state + /// if is long enough to receive + /// the computed hash value; otherwise, . + /// + public bool TryGetHashAndReset(Span destination, out int bytesWritten) + { + if (destination.Length < HashLengthInBytes) { - byte[] ret = new byte[HashLengthInBytes]; - GetHashAndResetCore(ret); - return ret; + bytesWritten = 0; + return false; } - /// - /// Attempts to write the computed hash value to . - /// If successful, clears the accumulated state. - /// - /// The buffer that receives the computed hash value. - /// - /// On success, receives the number of bytes written to . - /// - /// - /// and clears the accumulated state - /// if is long enough to receive - /// the computed hash value; otherwise, . - /// - public bool TryGetHashAndReset(Span destination, out int bytesWritten) - { - if (destination.Length < HashLengthInBytes) - { - bytesWritten = 0; - return false; - } - - GetHashAndResetCore(destination.Slice(0, HashLengthInBytes)); - bytesWritten = HashLengthInBytes; - return true; - } + GetHashAndResetCore(destination.Slice(0, HashLengthInBytes)); + bytesWritten = HashLengthInBytes; + return true; + } - /// - /// Writes the computed hash value to - /// then clears the accumulated state. - /// - /// The buffer that receives the computed hash value. - /// - /// The number of bytes written to , - /// which is always . - /// - /// - /// is shorter than . - /// - public int GetHashAndReset(Span destination) + /// + /// Writes the computed hash value to + /// then clears the accumulated state. + /// + /// The buffer that receives the computed hash value. + /// + /// The number of bytes written to , + /// which is always . + /// + /// + /// is shorter than . + /// + public int GetHashAndReset(Span destination) + { + if (destination.Length < HashLengthInBytes) { - if (destination.Length < HashLengthInBytes) - { - throw new ArgumentException(@"destination too short", nameof(destination)); - } - - GetHashAndResetCore(destination.Slice(0, HashLengthInBytes)); - return HashLengthInBytes; + throw new ArgumentException(@"destination too short", nameof(destination)); } - /// - /// Writes the computed hash value to - /// then clears the accumulated state. - /// - /// The buffer that receives the computed hash value. - /// - /// - /// Implementations of this method must write exactly - /// bytes to . - /// Do not assume that the buffer was zero-initialized. - /// - /// - /// The class validates the - /// size of the buffer before calling this method, and slices the span - /// down to be exactly in length. - /// - /// - /// The default implementation of this method calls - /// followed by . - /// Overrides of this method do not need to call either of those methods, - /// but must ensure that the caller cannot observe a difference in behavior. - /// - /// - protected virtual void GetHashAndResetCore(Span destination) - { - Debug.Assert(destination.Length == HashLengthInBytes); + GetHashAndResetCore(destination.Slice(0, HashLengthInBytes)); + return HashLengthInBytes; + } - GetCurrentHashCore(destination); - Reset(); - } + /// + /// Writes the computed hash value to + /// then clears the accumulated state. + /// + /// The buffer that receives the computed hash value. + /// + /// + /// Implementations of this method must write exactly + /// bytes to . + /// Do not assume that the buffer was zero-initialized. + /// + /// + /// The class validates the + /// size of the buffer before calling this method, and slices the span + /// down to be exactly in length. + /// + /// + /// The default implementation of this method calls + /// followed by . + /// Overrides of this method do not need to call either of those methods, + /// but must ensure that the caller cannot observe a difference in behavior. + /// + /// + protected virtual void GetHashAndResetCore(Span destination) + { + Debug.Assert(destination.Length == HashLengthInBytes); - /// - /// This method is not supported and should not be called. - /// Call or - /// instead. - /// - /// This method will always throw a . - /// In all cases. - [EditorBrowsable(EditorBrowsableState.Never)] - [Obsolete("Use GetCurrentHash() to retrieve the computed hash code.", true)] + GetCurrentHashCore(destination); + Reset(); + } + + /// + /// This method is not supported and should not be called. + /// Call or + /// instead. + /// + /// This method will always throw a . + /// In all cases. + [EditorBrowsable(EditorBrowsableState.Never)] + [Obsolete("Use GetCurrentHash() to retrieve the computed hash code.", true)] #pragma warning disable CS0809 // Obsolete member overrides non-obsolete member - public override int GetHashCode() + public override int GetHashCode() #pragma warning restore CS0809 // Obsolete member overrides non-obsolete member - { - throw new NotSupportedException("GetHashCode"); - } + { + throw new NotSupportedException("GetHashCode"); } } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/Hashing/XxHash32.State.cs b/src/Orleans.CodeGenerator/Hashing/XxHash32.State.cs index 19a84375ac7..69cbf4714d8 100644 --- a/src/Orleans.CodeGenerator/Hashing/XxHash32.State.cs +++ b/src/Orleans.CodeGenerator/Hashing/XxHash32.State.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers.Binary; using System.Diagnostics; using System.Runtime.CompilerServices; @@ -9,100 +8,99 @@ // Implemented from the specification at // https://github.com/Cyan4973/xxHash/blob/f9155bd4c57e2270a4ffbb176485e5d713de1c9b/doc/xxhash_spec.md -namespace Orleans.CodeGenerator.Hashing +namespace Orleans.CodeGenerator.Hashing; + +internal sealed partial class XxHash32 { - internal sealed partial class XxHash32 + private struct State { - private struct State + private const uint Prime32_1 = 0x9E3779B1; + private const uint Prime32_2 = 0x85EBCA77; + private const uint Prime32_3 = 0xC2B2AE3D; + private const uint Prime32_4 = 0x27D4EB2F; + private const uint Prime32_5 = 0x165667B1; + + private uint _acc1; + private uint _acc2; + private uint _acc3; + private uint _acc4; + private readonly uint _smallAcc; + private bool _hadFullStripe; + + internal State(uint seed) { - private const uint Prime32_1 = 0x9E3779B1; - private const uint Prime32_2 = 0x85EBCA77; - private const uint Prime32_3 = 0xC2B2AE3D; - private const uint Prime32_4 = 0x27D4EB2F; - private const uint Prime32_5 = 0x165667B1; - - private uint _acc1; - private uint _acc2; - private uint _acc3; - private uint _acc4; - private readonly uint _smallAcc; - private bool _hadFullStripe; - - internal State(uint seed) - { - _acc1 = seed + unchecked(Prime32_1 + Prime32_2); - _acc2 = seed + Prime32_2; - _acc3 = seed; - _acc4 = seed - Prime32_1; + _acc1 = seed + unchecked(Prime32_1 + Prime32_2); + _acc2 = seed + Prime32_2; + _acc3 = seed; + _acc4 = seed - Prime32_1; - _smallAcc = seed + Prime32_5; - _hadFullStripe = false; - } + _smallAcc = seed + Prime32_5; + _hadFullStripe = false; + } - internal void ProcessStripe(ReadOnlySpan source) - { - Debug.Assert(source.Length >= StripeSize); - source = source.Slice(0, StripeSize); + internal void ProcessStripe(ReadOnlySpan source) + { + Debug.Assert(source.Length >= StripeSize); + source = source.Slice(0, StripeSize); - _acc1 = ApplyRound(_acc1, source); - _acc2 = ApplyRound(_acc2, source.Slice(sizeof(uint))); - _acc3 = ApplyRound(_acc3, source.Slice(2 * sizeof(uint))); - _acc4 = ApplyRound(_acc4, source.Slice(3 * sizeof(uint))); + _acc1 = ApplyRound(_acc1, source); + _acc2 = ApplyRound(_acc2, source.Slice(sizeof(uint))); + _acc3 = ApplyRound(_acc3, source.Slice(2 * sizeof(uint))); + _acc4 = ApplyRound(_acc4, source.Slice(3 * sizeof(uint))); - _hadFullStripe = true; - } + _hadFullStripe = true; + } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private readonly uint Converge() - { - return - BitOperations.RotateLeft(_acc1, 1) + - BitOperations.RotateLeft(_acc2, 7) + - BitOperations.RotateLeft(_acc3, 12) + - BitOperations.RotateLeft(_acc4, 18); - } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private readonly uint Converge() + { + return + BitOperations.RotateLeft(_acc1, 1) + + BitOperations.RotateLeft(_acc2, 7) + + BitOperations.RotateLeft(_acc3, 12) + + BitOperations.RotateLeft(_acc4, 18); + } - private static uint ApplyRound(uint acc, ReadOnlySpan lane) + private static uint ApplyRound(uint acc, ReadOnlySpan lane) + { + acc += BinaryPrimitives.ReadUInt32LittleEndian(lane) * Prime32_2; + acc = BitOperations.RotateLeft(acc, 13); + acc *= Prime32_1; + + return acc; + } + + internal readonly uint Complete(int length, ReadOnlySpan remaining) + { + uint acc = _hadFullStripe ? Converge() : _smallAcc; + + acc += (uint)length; + + while (remaining.Length >= sizeof(uint)) { - acc += BinaryPrimitives.ReadUInt32LittleEndian(lane) * Prime32_2; - acc = BitOperations.RotateLeft(acc, 13); - acc *= Prime32_1; + uint lane = BinaryPrimitives.ReadUInt32LittleEndian(remaining); + acc += lane * Prime32_3; + acc = BitOperations.RotateLeft(acc, 17); + acc *= Prime32_4; - return acc; + remaining = remaining.Slice(sizeof(uint)); } - internal readonly uint Complete(int length, ReadOnlySpan remaining) + for (int i = 0; i < remaining.Length; i++) { - uint acc = _hadFullStripe ? Converge() : _smallAcc; - - acc += (uint)length; - - while (remaining.Length >= sizeof(uint)) - { - uint lane = BinaryPrimitives.ReadUInt32LittleEndian(remaining); - acc += lane * Prime32_3; - acc = BitOperations.RotateLeft(acc, 17); - acc *= Prime32_4; - - remaining = remaining.Slice(sizeof(uint)); - } - - for (int i = 0; i < remaining.Length; i++) - { - uint lane = remaining[i]; - acc += lane * Prime32_5; - acc = BitOperations.RotateLeft(acc, 11); - acc *= Prime32_1; - } - - acc ^= (acc >> 15); - acc *= Prime32_2; - acc ^= (acc >> 13); - acc *= Prime32_3; - acc ^= (acc >> 16); - - return acc; + uint lane = remaining[i]; + acc += lane * Prime32_5; + acc = BitOperations.RotateLeft(acc, 11); + acc *= Prime32_1; } + + acc ^= acc >> 15; + acc *= Prime32_2; + acc ^= acc >> 13; + acc *= Prime32_3; + acc ^= acc >> 16; + + return acc; } } } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/Hashing/XxHash32.cs b/src/Orleans.CodeGenerator/Hashing/XxHash32.cs index 39b304681b8..4a7702e194d 100644 --- a/src/Orleans.CodeGenerator/Hashing/XxHash32.cs +++ b/src/Orleans.CodeGenerator/Hashing/XxHash32.cs @@ -1,234 +1,232 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers.Binary; // Implemented from the specification at // https://github.com/Cyan4973/xxHash/blob/f9155bd4c57e2270a4ffbb176485e5d713de1c9b/doc/xxhash_spec.md -namespace Orleans.CodeGenerator.Hashing +namespace Orleans.CodeGenerator.Hashing; + +/// +/// Provides an implementation of the XxHash32 algorithm. +/// +internal sealed partial class XxHash32 : NonCryptographicHashAlgorithm { + private const int HashSize = sizeof(uint); + private const int StripeSize = 4 * sizeof(uint); + + private readonly uint _seed; + private State _state; + private byte[]? _holdback; + private int _length; + /// - /// Provides an implementation of the XxHash32 algorithm. + /// Initializes a new instance of the class. /// - internal sealed partial class XxHash32 : NonCryptographicHashAlgorithm + /// + /// The XxHash32 algorithm supports an optional seed value. + /// Instances created with this constructor use the default seed, zero. + /// + public XxHash32() + : this(0) { - private const int HashSize = sizeof(uint); - private const int StripeSize = 4 * sizeof(uint); - - private readonly uint _seed; - private State _state; - private byte[]? _holdback; - private int _length; - - /// - /// Initializes a new instance of the class. - /// - /// - /// The XxHash32 algorithm supports an optional seed value. - /// Instances created with this constructor use the default seed, zero. - /// - public XxHash32() - : this(0) - { - } + } - /// - /// Initializes a new instance of the class with - /// a specified seed. - /// - /// - /// The hash seed value for computations from this instance. - /// - public XxHash32(int seed) - : base(HashSize) - { - _seed = (uint)seed; - Reset(); - } + /// + /// Initializes a new instance of the class with + /// a specified seed. + /// + /// + /// The hash seed value for computations from this instance. + /// + public XxHash32(int seed) + : base(HashSize) + { + _seed = (uint)seed; + Reset(); + } - /// - /// Resets the hash computation to the initial state. - /// - public override void Reset() - { - _state = new State(_seed); - _length = 0; - } + /// + /// Resets the hash computation to the initial state. + /// + public override void Reset() + { + _state = new State(_seed); + _length = 0; + } - /// - /// Appends the contents of to the data already - /// processed for the current hash computation. - /// - /// The data to process. - public override void Append(ReadOnlySpan source) - { - // Every time we've read 16 bytes, process the stripe. - // Data that isn't perfectly mod-16 gets stored in a holdback - // buffer. + /// + /// Appends the contents of to the data already + /// processed for the current hash computation. + /// + /// The data to process. + public override void Append(ReadOnlySpan source) + { + // Every time we've read 16 bytes, process the stripe. + // Data that isn't perfectly mod-16 gets stored in a holdback + // buffer. - int held = _length & 0x0F; + int held = _length & 0x0F; - if (held != 0) - { - int remain = StripeSize - held; - - if (source.Length >= remain) - { - source.Slice(0, remain).CopyTo(_holdback.AsSpan(held)); - _state.ProcessStripe(_holdback); - - source = source.Slice(remain); - _length += remain; - } - else - { - source.CopyTo(_holdback.AsSpan(held)); - _length += source.Length; - return; - } - } + if (held != 0) + { + int remain = StripeSize - held; - while (source.Length >= StripeSize) + if (source.Length >= remain) { - _state.ProcessStripe(source); - source = source.Slice(StripeSize); - _length += StripeSize; - } + source.Slice(0, remain).CopyTo(_holdback.AsSpan(held)); + _state.ProcessStripe(_holdback); - if (source.Length > 0) + source = source.Slice(remain); + _length += remain; + } + else { - _holdback ??= new byte[StripeSize]; - source.CopyTo(_holdback); + source.CopyTo(_holdback.AsSpan(held)); _length += source.Length; + return; } } - /// - /// Writes the computed hash value to - /// without modifying accumulated state. - /// - protected override void GetCurrentHashCore(Span destination) + while (source.Length >= StripeSize) { - int remainingLength = _length & 0x0F; - ReadOnlySpan remaining = ReadOnlySpan.Empty; - - if (remainingLength > 0) - { - remaining = new ReadOnlySpan(_holdback, 0, remainingLength); - } - - uint acc = _state.Complete(_length, remaining); - BinaryPrimitives.WriteUInt32BigEndian(destination, acc); + _state.ProcessStripe(source); + source = source.Slice(StripeSize); + _length += StripeSize; } - /// - /// Computes the XxHash32 hash of the provided data. - /// - /// The data to hash. - /// The XxHash32 hash of the provided data. - /// - /// is . - /// - public static byte[] Hash(byte[] source) + if (source.Length > 0) { - if (source is null) - { - throw new ArgumentNullException(nameof(source)); - } - - return Hash(new ReadOnlySpan(source)); + _holdback ??= new byte[StripeSize]; + source.CopyTo(_holdback); + _length += source.Length; } + } - /// - /// Computes the XxHash32 hash of the provided data using the provided seed. - /// - /// The data to hash. - /// The seed value for this hash computation. - /// The XxHash32 hash of the provided data. - /// - /// is . - /// - public static byte[] Hash(byte[] source, int seed) - { - if (source is null) - { - throw new ArgumentNullException(nameof(source)); - } + /// + /// Writes the computed hash value to + /// without modifying accumulated state. + /// + protected override void GetCurrentHashCore(Span destination) + { + int remainingLength = _length & 0x0F; + ReadOnlySpan remaining = []; - return Hash(new ReadOnlySpan(source), seed); + if (remainingLength > 0) + { + remaining = new ReadOnlySpan(_holdback, 0, remainingLength); } - /// - /// Computes the XxHash32 hash of the provided data. - /// - /// The data to hash. - /// The seed value for this hash computation. The default is zero. - /// The XxHash32 hash of the provided data. - public static byte[] Hash(ReadOnlySpan source, int seed = 0) + uint acc = _state.Complete(_length, remaining); + BinaryPrimitives.WriteUInt32BigEndian(destination, acc); + } + + /// + /// Computes the XxHash32 hash of the provided data. + /// + /// The data to hash. + /// The XxHash32 hash of the provided data. + /// + /// is . + /// + public static byte[] Hash(byte[] source) + { + if (source is null) { - byte[] ret = new byte[HashSize]; - StaticHash(source, ret, seed); - return ret; + throw new ArgumentNullException(nameof(source)); } - /// - /// Attempts to compute the XxHash32 hash of the provided data into the provided destination. - /// - /// The data to hash. - /// The buffer that receives the computed hash value. - /// - /// On success, receives the number of bytes written to . - /// - /// The seed value for this hash computation. The default is zero. - /// - /// if is long enough to receive - /// the computed hash value (4 bytes); otherwise, . - /// - public static bool TryHash(ReadOnlySpan source, Span destination, out int bytesWritten, int seed = 0) - { - if (destination.Length < HashSize) - { - bytesWritten = 0; - return false; - } + return Hash(new ReadOnlySpan(source)); + } - bytesWritten = StaticHash(source, destination, seed); - return true; + /// + /// Computes the XxHash32 hash of the provided data using the provided seed. + /// + /// The data to hash. + /// The seed value for this hash computation. + /// The XxHash32 hash of the provided data. + /// + /// is . + /// + public static byte[] Hash(byte[] source, int seed) + { + if (source is null) + { + throw new ArgumentNullException(nameof(source)); } - /// - /// Computes the XxHash32 hash of the provided data into the provided destination. - /// - /// The data to hash. - /// The buffer that receives the computed hash value. - /// The seed value for this hash computation. The default is zero. - /// - /// The number of bytes written to . - /// - public static int Hash(ReadOnlySpan source, Span destination, int seed = 0) - { - if (destination.Length < HashSize) - throw new ArgumentException(@"destination too short", nameof(destination)); + return Hash(new ReadOnlySpan(source), seed); + } - return StaticHash(source, destination, seed); - } + /// + /// Computes the XxHash32 hash of the provided data. + /// + /// The data to hash. + /// The seed value for this hash computation. The default is zero. + /// The XxHash32 hash of the provided data. + public static byte[] Hash(ReadOnlySpan source, int seed = 0) + { + byte[] ret = new byte[HashSize]; + StaticHash(source, ret, seed); + return ret; + } - private static int StaticHash(ReadOnlySpan source, Span destination, int seed) + /// + /// Attempts to compute the XxHash32 hash of the provided data into the provided destination. + /// + /// The data to hash. + /// The buffer that receives the computed hash value. + /// + /// On success, receives the number of bytes written to . + /// + /// The seed value for this hash computation. The default is zero. + /// + /// if is long enough to receive + /// the computed hash value (4 bytes); otherwise, . + /// + public static bool TryHash(ReadOnlySpan source, Span destination, out int bytesWritten, int seed = 0) + { + if (destination.Length < HashSize) { - int totalLength = source.Length; - State state = new State((uint)seed); + bytesWritten = 0; + return false; + } - while (source.Length >= StripeSize) - { - state.ProcessStripe(source); - source = source.Slice(StripeSize); - } + bytesWritten = StaticHash(source, destination, seed); + return true; + } - uint val = state.Complete(totalLength, source); - BinaryPrimitives.WriteUInt32BigEndian(destination, val); - return HashSize; + /// + /// Computes the XxHash32 hash of the provided data into the provided destination. + /// + /// The data to hash. + /// The buffer that receives the computed hash value. + /// The seed value for this hash computation. The default is zero. + /// + /// The number of bytes written to . + /// + public static int Hash(ReadOnlySpan source, Span destination, int seed = 0) + { + if (destination.Length < HashSize) + throw new ArgumentException(@"destination too short", nameof(destination)); + + return StaticHash(source, destination, seed); + } + + private static int StaticHash(ReadOnlySpan source, Span destination, int seed) + { + int totalLength = source.Length; + State state = new State((uint)seed); + + while (source.Length >= StripeSize) + { + state.ProcessStripe(source); + source = source.Slice(StripeSize); } + + uint val = state.Complete(totalLength, source); + BinaryPrimitives.WriteUInt32BigEndian(destination, val); + return HashSize; } } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/InvokableGenerator.cs b/src/Orleans.CodeGenerator/InvokableGenerator.cs index cd4db8407bf..e3994e96741 100644 --- a/src/Orleans.CodeGenerator/InvokableGenerator.cs +++ b/src/Orleans.CodeGenerator/InvokableGenerator.cs @@ -2,943 +2,917 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -using System; -using System.Collections.Generic; -using System.Linq; using Orleans.CodeGenerator.Diagnostics; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -using System.Linq.Expressions; -#nullable disable -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +/// +/// Generates RPC stub objects called invokers. +/// +internal class InvokableGenerator(ProxyGenerationContext generationContext) { - /// - /// Generates RPC stub objects called invokers. - /// - internal class InvokableGenerator - { - private readonly CodeGenerator _codeGenerator; + private readonly ProxyGenerationContext _generationContext = generationContext; + + private LibraryTypes LibraryTypes => _generationContext.LibraryTypes; - public InvokableGenerator(CodeGenerator codeGenerator) + public GeneratedInvokableDescription Generate(InvokableMethodDescription invokableMethodInfo) + { + var method = invokableMethodInfo.Method; + var generatedClassName = GetSimpleClassName(invokableMethodInfo); + + var baseClassType = GetBaseClassType(invokableMethodInfo); + var fieldDescriptions = GetFieldDescriptions(invokableMethodInfo); + var fields = GetFieldDeclarations(invokableMethodInfo, fieldDescriptions); + var (ctor, ctorArgs) = GenerateConstructor(generatedClassName, invokableMethodInfo, baseClassType); + var accessibility = GetAccessibility(method); + var compoundTypeAliases = GetCompoundTypeAliasAttributeArguments(invokableMethodInfo, invokableMethodInfo.Key); + + List serializationHooks = new(); + if (baseClassType.GetAttributes(LibraryTypes.SerializationCallbacksAttribute, out var hookAttributes)) { - _codeGenerator = codeGenerator; + foreach (var hookAttribute in hookAttributes) + { + var hookType = (INamedTypeSymbol)hookAttribute.ConstructorArguments[0].Value!; + serializationHooks.Add(hookType); + } } - private LibraryTypes LibraryTypes => _codeGenerator.LibraryTypes; + var targetField = fieldDescriptions.OfType().Single(); - public GeneratedInvokableDescription Generate(InvokableMethodDescription invokableMethodInfo) + var accessibilityKind = accessibility switch + { + Accessibility.Public => SyntaxKind.PublicKeyword, + _ => SyntaxKind.InternalKeyword, + }; + + var classDeclaration = GetClassDeclarationSyntax( + invokableMethodInfo, + generatedClassName, + baseClassType, + fieldDescriptions, + fields, + ctor, + compoundTypeAliases, + targetField, + accessibilityKind); + + string? returnValueInitializerMethod = null; + if (baseClassType.GetAttribute(LibraryTypes.ReturnValueProxyAttribute) is { ConstructorArguments: { Length: > 0 } attrArgs }) { - var method = invokableMethodInfo.Method; - var generatedClassName = GetSimpleClassName(invokableMethodInfo); + returnValueInitializerMethod = (string?)attrArgs[0].Value; + } - var baseClassType = GetBaseClassType(invokableMethodInfo); - var fieldDescriptions = GetFieldDescriptions(invokableMethodInfo); - var fields = GetFieldDeclarations(invokableMethodInfo, fieldDescriptions); - var (ctor, ctorArgs) = GenerateConstructor(generatedClassName, invokableMethodInfo, baseClassType); - var accessibility = GetAccessibility(method); - var compoundTypeAliases = GetCompoundTypeAliasAttributeArguments(invokableMethodInfo, invokableMethodInfo.Key); + while (baseClassType.HasAttribute(LibraryTypes.SerializerTransparentAttribute)) + { + baseClassType = baseClassType.BaseType!; + } - List serializationHooks = new(); - if (baseClassType.GetAttributes(LibraryTypes.SerializationCallbacksAttribute, out var hookAttributes)) + var invokerDescription = new GeneratedInvokableDescription( + invokableMethodInfo, + accessibility, + generatedClassName, + GeneratedCodeUtilities.GetGeneratedNamespaceName(invokableMethodInfo.ContainingInterface), + [.. fieldDescriptions.OfType()], + serializationHooks, + baseClassType, + ctorArgs, + compoundTypeAliases, + returnValueInitializerMethod, + classDeclaration); + return invokerDescription; + + static Accessibility GetAccessibility(IMethodSymbol methodSymbol) + { + Accessibility accessibility = methodSymbol.DeclaredAccessibility; + var t = methodSymbol.ContainingType; + while (t is not null) { - foreach (var hookAttribute in hookAttributes) + if ((int)t.DeclaredAccessibility < (int)accessibility) { - var hookType = (INamedTypeSymbol)hookAttribute.ConstructorArguments[0].Value; - serializationHooks.Add(hookType); + accessibility = t.DeclaredAccessibility; } - } - var targetField = fieldDescriptions.OfType().Single(); + t = t.ContainingType; + } - var accessibilityKind = accessibility switch - { - Accessibility.Public => SyntaxKind.PublicKeyword, - _ => SyntaxKind.InternalKeyword, - }; + return accessibility; + } + } - var classDeclaration = GetClassDeclarationSyntax( - invokableMethodInfo, - generatedClassName, - baseClassType, - fieldDescriptions, - fields, - ctor, - compoundTypeAliases, - targetField, - accessibilityKind); - - string returnValueInitializerMethod = null; - if (baseClassType.GetAttribute(LibraryTypes.ReturnValueProxyAttribute) is { ConstructorArguments: { Length: > 0 } attrArgs }) - { - returnValueInitializerMethod = (string)attrArgs[0].Value; - } + private ClassDeclarationSyntax GetClassDeclarationSyntax( + InvokableMethodDescription method, + string generatedClassName, + INamedTypeSymbol baseClassType, + List fieldDescriptions, + MemberDeclarationSyntax[] fields, + ConstructorDeclarationSyntax? ctor, + List compoundTypeAliases, + TargetFieldDescription targetField, + SyntaxKind accessibilityKind) + { + var classDeclaration = ClassDeclaration(generatedClassName) + .AddBaseListTypes(SimpleBaseType(baseClassType.ToTypeSyntax(method.TypeParameterSubstitutions))) + .AddModifiers(Token(accessibilityKind), Token(SyntaxKind.SealedKeyword)) + .AddAttributeLists(GeneratedCodeUtilities.GetGeneratedCodeAttributes()) + .AddMembers(fields); - while (baseClassType.HasAttribute(LibraryTypes.SerializerTransparentAttribute)) - { - baseClassType = baseClassType.BaseType; - } + foreach (var alias in compoundTypeAliases) + { + classDeclaration = classDeclaration.AddAttributeLists( + AttributeList(SingletonSeparatedList(GetCompoundTypeAliasAttribute(alias)))); + } - var invokerDescription = new GeneratedInvokableDescription( - invokableMethodInfo, - accessibility, - generatedClassName, - CodeGenerator.GetGeneratedNamespaceName(invokableMethodInfo.ContainingInterface), - fieldDescriptions.OfType().ToList(), - serializationHooks, - baseClassType, - ctorArgs, - compoundTypeAliases, - returnValueInitializerMethod, - classDeclaration); - return invokerDescription; - - static Accessibility GetAccessibility(IMethodSymbol methodSymbol) - { - Accessibility accessibility = methodSymbol.DeclaredAccessibility; - var t = methodSymbol.ContainingType; - while (t is not null) - { - if ((int)t.DeclaredAccessibility < (int)accessibility) - { - accessibility = t.DeclaredAccessibility; - } + if (ctor != null) + { + classDeclaration = classDeclaration.AddMembers(ctor); + } - t = t.ContainingType; - } + if (method.ResponseTimeoutTicks.HasValue) + { + classDeclaration = classDeclaration.AddMembers(GenerateResponseTimeoutPropertyMembers(method.ResponseTimeoutTicks.Value)); + } - return accessibility; - } + classDeclaration = AddOptionalMembers(classDeclaration, + GenerateGetArgumentCount(method), + GenerateGetMethodName(method), + GenerateGetInterfaceName(method), + GenerateGetActivityName(method), + GenerateGetInterfaceType(method), + GenerateGetMethod(), + GenerateSetTargetMethod(method, targetField, fieldDescriptions), + GenerateGetTargetMethod(targetField), + GenerateDisposeMethod(fieldDescriptions, baseClassType), + GenerateGetArgumentMethod(method, fieldDescriptions), + GenerateSetArgumentMethod(method, fieldDescriptions), + GenerateInvokeInnerMethod(method, fieldDescriptions, targetField), + GenerateGetCancellationTokenMethod(method, fieldDescriptions), + GenerateTryCancelMethod(method, fieldDescriptions), + GenerateIsCancellableProperty(method)); + + if (method.AllTypeParameters.Count > 0) + { + classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, method.AllTypeParameters); } - private ClassDeclarationSyntax GetClassDeclarationSyntax( - InvokableMethodDescription method, - string generatedClassName, - INamedTypeSymbol baseClassType, - List fieldDescriptions, - MemberDeclarationSyntax[] fields, - ConstructorDeclarationSyntax ctor, - List compoundTypeAliases, - TargetFieldDescription targetField, - SyntaxKind accessibilityKind) - { - var classDeclaration = ClassDeclaration(generatedClassName) - .AddBaseListTypes(SimpleBaseType(baseClassType.ToTypeSyntax(method.TypeParameterSubstitutions))) - .AddModifiers(Token(accessibilityKind), Token(SyntaxKind.SealedKeyword)) - .AddAttributeLists(CodeGenerator.GetGeneratedCodeAttributes()) - .AddMembers(fields); - - foreach (var alias in compoundTypeAliases) - { - classDeclaration = classDeclaration.AddAttributeLists( - AttributeList(SingletonSeparatedList(GetCompoundTypeAliasAttribute(alias)))); - } + return classDeclaration; + } - if (ctor != null) - { - classDeclaration = classDeclaration.AddMembers(ctor); - } + private MemberDeclarationSyntax[] GenerateResponseTimeoutPropertyMembers(long value) + { + var timespanField = FieldDeclaration( + VariableDeclaration( + LibraryTypes.TimeSpan.ToTypeSyntax(), + SingletonSeparatedList(VariableDeclarator("_responseTimeoutValue") + .WithInitializer(EqualsValueClause( + InvocationExpression( + IdentifierName("global::System.TimeSpan").Member("FromTicks"), + ArgumentList(SeparatedList( + [ + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(value))) + ])))))))) + .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword)); + + var responseTimeoutProperty = MethodDeclaration(NullableType(LibraryTypes.TimeSpan.ToTypeSyntax()), "GetDefaultResponseTimeout") + .WithExpressionBody(ArrowExpressionClause(IdentifierName("_responseTimeoutValue"))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) + .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)); + return [timespanField, responseTimeoutProperty]; + } - if (method.ResponseTimeoutTicks.HasValue) - { - classDeclaration = classDeclaration.AddMembers(GenerateResponseTimeoutPropertyMembers(method.ResponseTimeoutTicks.Value)); - } + private static ClassDeclarationSyntax AddOptionalMembers(ClassDeclarationSyntax decl, params MemberDeclarationSyntax?[] items) + => decl.WithMembers(decl.Members.AddRange(items.OfType())); - classDeclaration = AddOptionalMembers(classDeclaration, - GenerateGetArgumentCount(method), - GenerateGetMethodName(method), - GenerateGetInterfaceName(method), - GenerateGetActivityName(method), - GenerateGetInterfaceType(method), - GenerateGetMethod(), - GenerateSetTargetMethod(method, targetField, fieldDescriptions), - GenerateGetTargetMethod(targetField), - GenerateDisposeMethod(fieldDescriptions, baseClassType), - GenerateGetArgumentMethod(method, fieldDescriptions), - GenerateSetArgumentMethod(method, fieldDescriptions), - GenerateInvokeInnerMethod(method, fieldDescriptions, targetField), - GenerateGetCancellationTokenMethod(method, fieldDescriptions), - GenerateTryCancelMethod(method, fieldDescriptions), - GenerateIsCancellableProperty(method)); - - if (method.AllTypeParameters.Count > 0) + internal AttributeSyntax GetCompoundTypeAliasAttribute(CompoundTypeAliasComponent[] argValues) + { + var args = new AttributeArgumentSyntax[argValues.Length]; + for (var i = 0; i < argValues.Length; i++) + { + ExpressionSyntax value; + value = argValues[i].Value switch { - classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, method.AllTypeParameters); - } + string stringValue => LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(stringValue)), + ITypeSymbol typeValue => TypeOfExpression(typeValue.ToOpenTypeSyntax()), + _ => throw new InvalidOperationException($"Unsupported value") + }; - return classDeclaration; + args[i] = AttributeArgument(value); } - private MemberDeclarationSyntax[] GenerateResponseTimeoutPropertyMembers(long value) - { - var timespanField = FieldDeclaration( - VariableDeclaration( - LibraryTypes.TimeSpan.ToTypeSyntax(), - SingletonSeparatedList(VariableDeclarator("_responseTimeoutValue") - .WithInitializer(EqualsValueClause( - InvocationExpression( - IdentifierName("global::System.TimeSpan").Member("FromTicks"), - ArgumentList(SeparatedList(new[] - { - Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(value))) - })))))))) - .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword)); + return Attribute(LibraryTypes.CompoundTypeAliasAttribute.ToNameSyntax()).AddArgumentListArguments(args); + } - var responseTimeoutProperty = MethodDeclaration(NullableType(LibraryTypes.TimeSpan.ToTypeSyntax()), "GetDefaultResponseTimeout") - .WithExpressionBody(ArrowExpressionClause(IdentifierName("_responseTimeoutValue"))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) - .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)); - return new MemberDeclarationSyntax[] { timespanField, responseTimeoutProperty }; + internal static List GetCompoundTypeAliasAttributeArguments(InvokableMethodDescription methodDescription, InvokableMethodId invokableId) + { + var result = new List(2); + var containingInterface = methodDescription.ContainingInterface; + if (methodDescription.HasAlias) + { + result.Add(GetCompoundTypeAliasComponents(invokableId, containingInterface, methodDescription.MethodId)); } - private ClassDeclarationSyntax AddOptionalMembers(ClassDeclarationSyntax decl, params MemberDeclarationSyntax[] items) - => decl.WithMembers(decl.Members.AddRange(items.Where(i => i != null))); + result.Add(GetCompoundTypeAliasComponents(invokableId, containingInterface, methodDescription.GeneratedMethodId)); + return result; + } - internal AttributeSyntax GetCompoundTypeAliasAttribute(CompoundTypeAliasComponent[] argValues) + public static CompoundTypeAliasComponent[] GetCompoundTypeAliasComponents( + InvokableMethodId invokableId, + INamedTypeSymbol containingInterface, + string methodId) + { + var proxyBase = invokableId.ProxyBase; + var proxyBaseComponents = proxyBase.CompositeAliasComponents; + var extensionArgCount = proxyBase.IsExtension ? 1 : 0; + var alias = new CompoundTypeAliasComponent[1 + proxyBaseComponents.Length + extensionArgCount + 2]; + alias[0] = new("inv"); + for (var i = 0; i < proxyBaseComponents.Length; i++) { - var args = new AttributeArgumentSyntax[argValues.Length]; - for (var i = 0; i < argValues.Length; i++) - { - ExpressionSyntax value; - value = argValues[i].Value switch - { - string stringValue => LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(stringValue)), - ITypeSymbol typeValue => TypeOfExpression(typeValue.ToOpenTypeSyntax()), - _ => throw new InvalidOperationException($"Unsupported value") - }; + alias[i + 1] = proxyBaseComponents[i]; + } - args[i] = AttributeArgument(value); - } + alias[1 + proxyBaseComponents.Length] = new(containingInterface); - return Attribute(LibraryTypes.CompoundTypeAliasAttribute.ToNameSyntax()).AddArgumentListArguments(args); + // For grain extensions, also explicitly include the method's containing type. + // This is to distinguish between different extension methods with the same id (eg, alias) but different containing types. + if (proxyBase.IsExtension) + { + alias[1 + proxyBaseComponents.Length + 1] = new(invokableId.Method.ContainingType); } - internal static List GetCompoundTypeAliasAttributeArguments(InvokableMethodDescription methodDescription, InvokableMethodId invokableId) + alias[1 + proxyBaseComponents.Length + extensionArgCount + 1] = new(methodId); + return alias; + } + + private static INamedTypeSymbol GetBaseClassType(InvokableMethodDescription method) + { + var methodReturnType = method.Method.ReturnType; + if (methodReturnType is not INamedTypeSymbol namedMethodReturnType) { - var result = new List(2); - var containingInterface = methodDescription.ContainingInterface; - if (methodDescription.HasAlias) - { - result.Add(GetCompoundTypeAliasComponents(invokableId, containingInterface, methodDescription.MethodId)); - } + throw new OrleansGeneratorDiagnosticAnalysisException(InvalidRpcMethodReturnTypeDiagnostic.CreateDiagnostic(method)); + } - result.Add(GetCompoundTypeAliasComponents(invokableId, containingInterface, methodDescription.GeneratedMethodId)); - return result; + if (method.InvokableBaseTypes.TryGetValue(namedMethodReturnType, out var baseClassType)) + { + return baseClassType; } - public static CompoundTypeAliasComponent[] GetCompoundTypeAliasComponents( - InvokableMethodId invokableId, - INamedTypeSymbol containingInterface, - string methodId) + if (namedMethodReturnType.ConstructedFrom is { IsGenericType: true, IsUnboundGenericType: false } constructedFrom) { - var proxyBase = invokableId.ProxyBase; - var proxyBaseComponents = proxyBase.CompositeAliasComponents; - var extensionArgCount = proxyBase.IsExtension ? 1 : 0; - var alias = new CompoundTypeAliasComponent[1 + proxyBaseComponents.Length + extensionArgCount + 2]; - alias[0] = new("inv"); - for (var i = 0; i < proxyBaseComponents.Length; i++) + var unbound = constructedFrom.ConstructUnboundGenericType(); + if (method.InvokableBaseTypes.TryGetValue(unbound, out baseClassType)) { - alias[i + 1] = proxyBaseComponents[i]; + return baseClassType.ConstructedFrom.Construct([.. namedMethodReturnType.TypeArguments]); } + } - alias[1 + proxyBaseComponents.Length] = new(containingInterface); - - // For grain extensions, also explicitly include the method's containing type. - // This is to distinguish between different extension methods with the same id (eg, alias) but different containing types. - if (proxyBase.IsExtension) - { - alias[1 + proxyBaseComponents.Length + 1] = new(invokableId.Method.ContainingType); - } + throw new OrleansGeneratorDiagnosticAnalysisException(InvalidRpcMethodReturnTypeDiagnostic.CreateDiagnostic(method)); + } - alias[1 + proxyBaseComponents.Length + extensionArgCount + 1] = new(methodId); - return alias; - } + private MemberDeclarationSyntax GenerateSetTargetMethod( + InvokableMethodDescription methodDescription, + TargetFieldDescription targetField, + List fieldDescriptions) + { + var holder = IdentifierName("holder"); + var holderParameter = holder.Identifier; - private INamedTypeSymbol GetBaseClassType(InvokableMethodDescription method) + var containingInterface = methodDescription.ContainingInterface; + var targetType = containingInterface.ToTypeSyntax(); + var isExtension = methodDescription.Key.ProxyBase.IsExtension; + var (name, args) = isExtension switch { - var methodReturnType = method.Method.ReturnType; - if (methodReturnType is not INamedTypeSymbol namedMethodReturnType) - { - throw new OrleansGeneratorDiagnosticAnalysisException(InvalidRpcMethodReturnTypeDiagnostic.CreateDiagnostic(method)); - } + true => ("GetComponent", SingletonSeparatedList(Argument(TypeOfExpression(targetType)))), + _ => ("GetTarget", SeparatedList()) + }; + var getTarget = CastExpression( + targetType, + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + holder, + IdentifierName(name)), + ArgumentList(args))); - if (method.InvokableBaseTypes.TryGetValue(namedMethodReturnType, out var baseClassType)) - { - return baseClassType; - } + var member = + MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "SetTarget") + .WithParameterList(ParameterList(SingletonSeparatedList(Parameter(holderParameter).WithType(LibraryTypes.ITargetHolder.ToTypeSyntax())))) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); - if (namedMethodReturnType.ConstructedFrom is { IsGenericType: true, IsUnboundGenericType: false } constructedFrom) - { - var unbound = constructedFrom.ConstructUnboundGenericType(); - if (method.InvokableBaseTypes.TryGetValue(unbound, out baseClassType)) + var assignmentExpression = AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(targetField.FieldName), getTarget); + if (!methodDescription.IsCancellable) + { + return member.WithExpressionBody(ArrowExpressionClause(assignmentExpression)).WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + } + else + { + var ctsField = fieldDescriptions.First(f => f is CancellationTokenSourceFieldDescription); + var cancellationTokenType = LibraryTypes.CancellationToken.ToTypeSyntax(); + var ctField = fieldDescriptions.First(f => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, f.FieldType)); + return member.WithBody(Block( + new List() { - return baseClassType.ConstructedFrom.Construct(namedMethodReturnType.TypeArguments.ToArray()); - } - } - - throw new OrleansGeneratorDiagnosticAnalysisException(InvalidRpcMethodReturnTypeDiagnostic.CreateDiagnostic(method)); + ExpressionStatement(assignmentExpression), + ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(ctsField.FieldName), ImplicitObjectCreationExpression())), + ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(ctField.FieldName), IdentifierName(ctsField.FieldName).Member("Token"))) + })); } + } - private MemberDeclarationSyntax GenerateSetTargetMethod( - InvokableMethodDescription methodDescription, - TargetFieldDescription targetField, - List fieldDescriptions) - { - var holder = IdentifierName("holder"); - var holderParameter = holder.Identifier; + private static MethodDeclarationSyntax GenerateGetTargetMethod(TargetFieldDescription targetField) + { + return MethodDeclaration(PredefinedType(Token(SyntaxKind.ObjectKeyword)), "GetTarget") + .WithParameterList(ParameterList()) + .WithExpressionBody(ArrowExpressionClause(IdentifierName(targetField.FieldName))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); + } - var containingInterface = methodDescription.ContainingInterface; - var targetType = containingInterface.ToTypeSyntax(); - var isExtension = methodDescription.Key.ProxyBase.IsExtension; - var (name, args) = isExtension switch - { - true => ("GetComponent", SingletonSeparatedList(Argument(TypeOfExpression(targetType)))), - _ => ("GetTarget", SeparatedList()) - }; - var getTarget = CastExpression( - targetType, - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - holder, - IdentifierName(name)), - ArgumentList(args))); - - var member = - MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "SetTarget") - .WithParameterList(ParameterList(SingletonSeparatedList(Parameter(holderParameter).WithType(LibraryTypes.ITargetHolder.ToTypeSyntax())))) - .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); - - var assignmentExpression = AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(targetField.FieldName), getTarget); - if (!methodDescription.IsCancellable) - { - return member.WithExpressionBody(ArrowExpressionClause(assignmentExpression)).WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); - } - else - { - var ctsField = fieldDescriptions.First(f => f is CancellationTokenSourceFieldDescription); - var cancellationTokenType = LibraryTypes.CancellationToken.ToTypeSyntax(); - var ctField = fieldDescriptions.First(f => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, f.FieldType)); - return member.WithBody(Block( - new List() - { - ExpressionStatement(assignmentExpression), - ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(ctsField.FieldName), ImplicitObjectCreationExpression())), - ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, IdentifierName(ctField.FieldName), IdentifierName(ctsField.FieldName).Member("Token"))) - })); - } + private MemberDeclarationSyntax? GenerateGetCancellationTokenMethod(InvokableMethodDescription method, List fields) + { + if (!method.IsCancellable) + { + return null; } - private static MethodDeclarationSyntax GenerateGetTargetMethod(TargetFieldDescription targetField) + // Method to get the cancellationToken argument + // C#: CancellationToken GetCancellationToken() => + var cancellationTokenType = LibraryTypes.CancellationToken.ToTypeSyntax(); + var cancellationTokenField = fields.First(f => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, f.FieldType)); + var member = MethodDeclaration(cancellationTokenType, "GetCancellationToken") + .WithExpressionBody(ArrowExpressionClause(cancellationTokenField.FieldName.ToIdentifierName())) + .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + return member; + } + + private static MemberDeclarationSyntax? GenerateIsCancellableProperty(InvokableMethodDescription method) + { + if (!method.IsCancellable) { - return MethodDeclaration(PredefinedType(Token(SyntaxKind.ObjectKeyword)), "GetTarget") - .WithParameterList(ParameterList()) - .WithExpressionBody(ArrowExpressionClause(IdentifierName(targetField.FieldName))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) - .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); + return null; } - private MemberDeclarationSyntax GenerateGetCancellationTokenMethod(InvokableMethodDescription method, List fields) - { - if (!method.IsCancellable) - { - return null; - } + // Property to indicate if the invokable is cancellable + // C#: public override bool IsCancellable => true; + var member = PropertyDeclaration(PredefinedType(Token(SyntaxKind.BoolKeyword)), "IsCancellable") + .WithExpressionBody(ArrowExpressionClause(LiteralExpression(SyntaxKind.TrueLiteralExpression))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) + .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)); + return member; + } - // Method to get the cancellationToken argument - // C#: CancellationToken GetCancellationToken() => - var cancellationTokenType = LibraryTypes.CancellationToken.ToTypeSyntax(); - var cancellationTokenField = fields.First(f => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, f.FieldType)); - var member = MethodDeclaration(cancellationTokenType, "GetCancellationToken") - .WithExpressionBody(ArrowExpressionClause(cancellationTokenField.FieldName.ToIdentifierName())) - .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); - return member; + private static MemberDeclarationSyntax? GenerateTryCancelMethod(InvokableMethodDescription method, List fields) + { + if (!method.IsCancellable) + { + return null; } - private MemberDeclarationSyntax GenerateIsCancellableProperty(InvokableMethodDescription method) + // Method to set the CancellationToken argument. + // C#: + // TryCancel() + // { + // if (_cts is { } cts) + // { + // cts.Cancel(false); + // return true; + // } + // return false; + // } + var cancellationTokenField = fields.First(f => f is CancellationTokenSourceFieldDescription); + var member = MethodDeclaration(PredefinedType(Token(SyntaxKind.BoolKeyword)), "TryCancel") + .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)) + .WithBody(Block( + IfStatement( + IsPatternExpression( + IdentifierName(cancellationTokenField.FieldName), + RecursivePattern() + .WithPropertyPatternClause(PropertyPatternClause()) + .WithDesignation(SingleVariableDesignation(Identifier("cts")))), + Block( + ExpressionStatement(InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("cts"), IdentifierName("Cancel"))) + .WithArgumentList(ArgumentList(SeparatedList([Argument(LiteralExpression(SyntaxKind.FalseLiteralExpression))])))), + ReturnStatement(LiteralExpression(SyntaxKind.TrueLiteralExpression)))), + ReturnStatement(LiteralExpression(SyntaxKind.FalseLiteralExpression)))); + return member; + } + + private static MemberDeclarationSyntax? GenerateGetArgumentMethod( + InvokableMethodDescription methodDescription, + List fields) + { + if (methodDescription.Method.Parameters.Length == 0) + return null; + + var index = IdentifierName("index"); + + var cases = new List(); + foreach (var field in fields) { - if (!method.IsCancellable) + if (field is not MethodParameterFieldDescription parameter) { - return null; + continue; } - // Property to indicate if the invokable is cancellable - // C#: public override bool IsCancellable => true; - var member = PropertyDeclaration(PredefinedType(Token(SyntaxKind.BoolKeyword)), "IsCancellable") - .WithExpressionBody(ArrowExpressionClause(LiteralExpression(SyntaxKind.TrueLiteralExpression))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) - .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)); - return member; + // C#: case {index}: return {field} + var label = CaseSwitchLabel( + LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(parameter.ParameterOrdinal))); + cases.Add( + SwitchSection( + SingletonList(label), + new SyntaxList( + ReturnStatement( + IdentifierName(parameter.FieldName))))); } - private MemberDeclarationSyntax GenerateTryCancelMethod(InvokableMethodDescription method, List fields) - { - if (!method.IsCancellable) - { - return null; - } + // C#: default: return OrleansGeneratedCodeHelper.InvokableThrowArgumentOutOfRange(index, {maxArgs}) + var throwHelperMethod = MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("OrleansGeneratedCodeHelper"), + IdentifierName("InvokableThrowArgumentOutOfRange")); + cases.Add( + SwitchSection( + SingletonList(DefaultSwitchLabel()), + new SyntaxList( + ReturnStatement( + InvocationExpression( + throwHelperMethod, + ArgumentList( + SeparatedList( + [ + Argument(index), + Argument( + LiteralExpression( + SyntaxKind.NumericLiteralExpression, + Literal( + Math.Max(0, methodDescription.Method.Parameters.Length - 1)))) + ]))))))); + var body = SwitchStatement(index, new SyntaxList(cases)); + return MethodDeclaration(PredefinedType(Token(SyntaxKind.ObjectKeyword)), "GetArgument") + .WithParameterList( + ParameterList( + SingletonSeparatedList( + Parameter(Identifier("index")).WithType(PredefinedType(Token(SyntaxKind.IntKeyword)))))) + .WithBody(Block(body)) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); + } - // Method to set the CancellationToken argument. - // C#: - // TryCancel() - // { - // if (_cts is { } cts) - // { - // cts.Cancel(false); - // return true; - // } - // return false; - // } - var cancellationTokenField = fields.First(f => f is CancellationTokenSourceFieldDescription); - var member = MethodDeclaration(PredefinedType(Token(SyntaxKind.BoolKeyword)), "TryCancel") - .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword)) - .WithBody(Block( - IfStatement( - IsPatternExpression( - IdentifierName(cancellationTokenField.FieldName), - RecursivePattern() - .WithPropertyPatternClause(PropertyPatternClause()) - .WithDesignation(SingleVariableDesignation(Identifier("cts")))), - Block( - ExpressionStatement(InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("cts"), IdentifierName("Cancel"))) - .WithArgumentList(ArgumentList(SeparatedList([Argument(LiteralExpression(SyntaxKind.FalseLiteralExpression))])))), - ReturnStatement(LiteralExpression(SyntaxKind.TrueLiteralExpression)))), - ReturnStatement(LiteralExpression(SyntaxKind.FalseLiteralExpression)))); - return member; - } - - private MemberDeclarationSyntax GenerateGetArgumentMethod( - InvokableMethodDescription methodDescription, - List fields) - { - if (methodDescription.Method.Parameters.Length == 0) - return null; - - var index = IdentifierName("index"); - - var cases = new List(); - foreach (var field in fields) - { - if (field is not MethodParameterFieldDescription parameter) - { - continue; - } + private static MemberDeclarationSyntax? GenerateSetArgumentMethod( + InvokableMethodDescription methodDescription, + List fields) + { + if (methodDescription.Method.Parameters.Length == 0) + return null; + + var index = IdentifierName("index"); + var value = IdentifierName("value"); - // C#: case {index}: return {field} - var label = CaseSwitchLabel( - LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(parameter.ParameterOrdinal))); - cases.Add( - SwitchSection( - SingletonList(label), - new SyntaxList( - ReturnStatement( - IdentifierName(parameter.FieldName))))); + var cases = new List(); + foreach (var field in fields) + { + if (field is not MethodParameterFieldDescription parameter) + { + continue; } - // C#: default: return OrleansGeneratedCodeHelper.InvokableThrowArgumentOutOfRange(index, {maxArgs}) - var throwHelperMethod = MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("OrleansGeneratedCodeHelper"), - IdentifierName("InvokableThrowArgumentOutOfRange")); + // C#: case {index}: {field} = (TField)value; return; + var label = CaseSwitchLabel( + LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(parameter.ParameterOrdinal))); cases.Add( SwitchSection( - SingletonList(DefaultSwitchLabel()), + SingletonList(label), new SyntaxList( - ReturnStatement( + [ + ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(parameter.FieldName), + CastExpression(parameter.FieldType.ToTypeSyntax(methodDescription.TypeParameterSubstitutions), value))), + ReturnStatement() + ]))); + } + + // C#: default: OrleansGeneratedCodeHelper.InvokableThrowArgumentOutOfRange(index, {maxArgs}) + var maxArgs = Math.Max(0, methodDescription.Method.Parameters.Length - 1); + var throwHelperMethod = MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("OrleansGeneratedCodeHelper"), + IdentifierName("InvokableThrowArgumentOutOfRange")); + cases.Add( + SwitchSection( + SingletonList(DefaultSwitchLabel()), + new SyntaxList( + [ + ExpressionStatement( InvocationExpression( throwHelperMethod, ArgumentList( SeparatedList( - new[] - { + [ Argument(index), Argument( LiteralExpression( SyntaxKind.NumericLiteralExpression, - Literal( - Math.Max(0, methodDescription.Method.Parameters.Length - 1)))) - }))))))); - var body = SwitchStatement(index, new SyntaxList(cases)); - return MethodDeclaration(PredefinedType(Token(SyntaxKind.ObjectKeyword)), "GetArgument") - .WithParameterList( - ParameterList( - SingletonSeparatedList( - Parameter(Identifier("index")).WithType(PredefinedType(Token(SyntaxKind.IntKeyword)))))) - .WithBody(Block(body)) - .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); - } - - private MemberDeclarationSyntax GenerateSetArgumentMethod( - InvokableMethodDescription methodDescription, - List fields) - { - if (methodDescription.Method.Parameters.Length == 0) - return null; - - var index = IdentifierName("index"); - var value = IdentifierName("value"); - - var cases = new List(); - foreach (var field in fields) - { - if (field is not MethodParameterFieldDescription parameter) - { - continue; - } - - // C#: case {index}: {field} = (TField)value; return; - var label = CaseSwitchLabel( - LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(parameter.ParameterOrdinal))); - cases.Add( - SwitchSection( - SingletonList(label), - new SyntaxList( - new StatementSyntax[] - { - ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(parameter.FieldName), - CastExpression(parameter.FieldType.ToTypeSyntax(methodDescription.TypeParameterSubstitutions), value))), - ReturnStatement() - }))); - } + Literal(maxArgs))) + ])))), + ReturnStatement() + ]))); + var body = SwitchStatement(index, new SyntaxList(cases)); + return MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "SetArgument") + .WithParameterList( + ParameterList( + SeparatedList( + [ + Parameter(Identifier("index")).WithType(PredefinedType(Token(SyntaxKind.IntKeyword))), + Parameter(Identifier("value")).WithType(PredefinedType(Token(SyntaxKind.ObjectKeyword))) + ] + ))) + .WithBody(Block(body)) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); + } - // C#: default: OrleansGeneratedCodeHelper.InvokableThrowArgumentOutOfRange(index, {maxArgs}) - var maxArgs = Math.Max(0, methodDescription.Method.Parameters.Length - 1); - var throwHelperMethod = MemberAccessExpression( + private static MemberDeclarationSyntax GenerateInvokeInnerMethod( + InvokableMethodDescription method, + List fields, + TargetFieldDescription target) + { + var resultTask = IdentifierName("resultTask"); + + // C# var resultTask = this.target.{Method}({params}); + var args = SeparatedList( + fields.OfType() + .OrderBy(p => p.ParameterOrdinal) + .Select(p => Argument(IdentifierName(p.FieldName)))); + ExpressionSyntax methodCall; + if (method.MethodTypeParameters.Count > 0) + { + methodCall = MemberAccessExpression( SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("OrleansGeneratedCodeHelper"), - IdentifierName("InvokableThrowArgumentOutOfRange")); - cases.Add( - SwitchSection( - SingletonList(DefaultSwitchLabel()), - new SyntaxList( - new StatementSyntax[] - { - ExpressionStatement( - InvocationExpression( - throwHelperMethod, - ArgumentList( - SeparatedList( - new[] - { - Argument(index), - Argument( - LiteralExpression( - SyntaxKind.NumericLiteralExpression, - Literal(maxArgs))) - })))), - ReturnStatement() - }))); - var body = SwitchStatement(index, new SyntaxList(cases)); - return MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "SetArgument") - .WithParameterList( - ParameterList( - SeparatedList( - new[] - { - Parameter(Identifier("index")).WithType(PredefinedType(Token(SyntaxKind.IntKeyword))), - Parameter(Identifier("value")).WithType(PredefinedType(Token(SyntaxKind.ObjectKeyword))) - } - ))) - .WithBody(Block(body)) - .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))); - } - - private MemberDeclarationSyntax GenerateInvokeInnerMethod( - InvokableMethodDescription method, - List fields, - TargetFieldDescription target) - { - var resultTask = IdentifierName("resultTask"); - - // C# var resultTask = this.target.{Method}({params}); - var args = SeparatedList( - fields.OfType() - .OrderBy(p => p.ParameterOrdinal) - .Select(p => Argument(IdentifierName(p.FieldName)))); - ExpressionSyntax methodCall; - if (method.MethodTypeParameters.Count > 0) - { - methodCall = MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(target.FieldName), - GenericName( - Identifier(method.Method.Name), - TypeArgumentList( - SeparatedList( - method.MethodTypeParameters.Select(p => IdentifierName(p.Name)))))); - } - else - { - methodCall = IdentifierName(target.FieldName).Member(method.Method.Name); - } - - return MethodDeclaration(method.Method.ReturnType.ToTypeSyntax(method.TypeParameterSubstitutions), "InvokeInner") - .WithParameterList(ParameterList()) - .WithExpressionBody(ArrowExpressionClause(InvocationExpression(methodCall, ArgumentList(args)))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) - .WithModifiers(TokenList(Token(SyntaxKind.ProtectedKeyword), Token(SyntaxKind.OverrideKeyword))); + IdentifierName(target.FieldName), + GenericName( + Identifier(method.Method.Name), + TypeArgumentList( + SeparatedList( + method.MethodTypeParameters.Select(p => IdentifierName(p.Name)))))); } - - private MemberDeclarationSyntax GenerateDisposeMethod( - List fields, - INamedTypeSymbol baseClassType) + else { - var body = new List(); - foreach (var field in fields) - { - if (field is CancellationTokenSourceFieldDescription ctsField) - { - // C# - // _cts?.Dispose(); - body.Add( - ExpressionStatement( - ConditionalAccessExpression( - ctsField.FieldName.ToIdentifierName(), - InvocationExpression( - MemberBindingExpression(IdentifierName("Dispose")))))); - } + methodCall = IdentifierName(target.FieldName).Member(method.Method.Name); + } - if (field.IsInstanceField) - { - body.Add( - ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(field.FieldName), - LiteralExpression(SyntaxKind.DefaultLiteralExpression)))); - } - } + return MethodDeclaration(method.Method.ReturnType.ToTypeSyntax(method.TypeParameterSubstitutions), "InvokeInner") + .WithParameterList(ParameterList()) + .WithExpressionBody(ArrowExpressionClause(InvocationExpression(methodCall, ArgumentList(args)))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) + .WithModifiers(TokenList(Token(SyntaxKind.ProtectedKeyword), Token(SyntaxKind.OverrideKeyword))); + } - // C# base.Dispose(); - if (baseClassType is { } - && baseClassType.AllInterfaces.Any(i => i.SpecialType == SpecialType.System_IDisposable) - && baseClassType.GetAllMembers("Dispose").FirstOrDefault(m => !m.IsAbstract && m.DeclaredAccessibility != Accessibility.Private) is { }) - { - body.Add(ExpressionStatement(InvocationExpression(BaseExpression().Member("Dispose")).WithArgumentList(ArgumentList()))); + private static MemberDeclarationSyntax GenerateDisposeMethod( + List fields, + INamedTypeSymbol baseClassType) + { + var body = new List(); + foreach (var field in fields) + { + if (field is CancellationTokenSourceFieldDescription ctsField) + { + // C# + // _cts?.Dispose(); + body.Add( + ExpressionStatement( + ConditionalAccessExpression( + ctsField.FieldName.ToIdentifierName(), + InvocationExpression( + MemberBindingExpression(IdentifierName("Dispose")))))); } - return MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "Dispose") - .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) - .WithBody(Block(body)); - } - - private MemberDeclarationSyntax GenerateGetArgumentCount(InvokableMethodDescription methodDescription) - => methodDescription.Method.Parameters.Length is var count and not 0 ? - MethodDeclaration(PredefinedType(Token(SyntaxKind.IntKeyword)), "GetArgumentCount") - .WithExpressionBody(ArrowExpressionClause(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(count)))) - .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) : null; - - private MemberDeclarationSyntax GenerateGetActivityName(InvokableMethodDescription methodDescription) - { - // This property is intended to contain a value suitable for use as an OpenTelemetry Span Name for RPC calls. - // Therefore, the interface name and method name components must not include periods or slashes. - // In order to avoid that, we omit the namespace from the interface name. - // See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/rpc.md - - var interfaceName = methodDescription.Method.ContainingType.ToDisplayName(methodDescription.TypeParameterSubstitutions, includeGlobalSpecifier: false, includeNamespace: false); - var methodName = methodDescription.Method.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var activityName = $"{interfaceName}/{methodName}"; - return MethodDeclaration(PredefinedType(Token(SyntaxKind.StringKeyword)), "GetActivityName") - .WithExpressionBody( - ArrowExpressionClause( - LiteralExpression( - SyntaxKind.NumericLiteralExpression, - Literal(activityName)))) - .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); - } - - private MemberDeclarationSyntax GenerateGetMethodName( - InvokableMethodDescription methodDescription) => - MethodDeclaration(PredefinedType(Token(SyntaxKind.StringKeyword)), "GetMethodName") - .WithExpressionBody( - ArrowExpressionClause( - LiteralExpression( - SyntaxKind.NumericLiteralExpression, - Literal(methodDescription.Method.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))))) - .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); - - private MemberDeclarationSyntax GenerateGetInterfaceName( - InvokableMethodDescription methodDescription) => - MethodDeclaration(PredefinedType(Token(SyntaxKind.StringKeyword)), "GetInterfaceName") - .WithExpressionBody( - ArrowExpressionClause( - LiteralExpression( - SyntaxKind.NumericLiteralExpression, - Literal(methodDescription.Method.ContainingType.ToDisplayName(methodDescription.TypeParameterSubstitutions, includeGlobalSpecifier: false))))) - .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); - - private MemberDeclarationSyntax GenerateGetInterfaceType( - InvokableMethodDescription methodDescription) => - MethodDeclaration(LibraryTypes.Type.ToTypeSyntax(), "GetInterfaceType") - .WithExpressionBody( - ArrowExpressionClause( - TypeOfExpression(methodDescription.Method.ContainingType.ToTypeSyntax(methodDescription.TypeParameterSubstitutions)))) - .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); - - private MemberDeclarationSyntax GenerateGetMethod() - => MethodDeclaration(LibraryTypes.MethodInfo.ToTypeSyntax(), "GetMethod") - .WithExpressionBody(ArrowExpressionClause(IdentifierName("MethodBackingField"))) - .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); - - public static string GetSimpleClassName(InvokableMethodDescription method) - { - var genericArity = method.AllTypeParameters.Count; - var typeArgs = genericArity > 0 ? "_" + genericArity : string.Empty; - var proxyKey = method.ProxyBase.Key.GeneratedClassNameComponent; - return $"Invokable_{method.ContainingInterface.Name}_{proxyKey}_{method.GeneratedMethodId}{typeArgs}"; - } - - private MemberDeclarationSyntax[] GetFieldDeclarations( - InvokableMethodDescription method, - List fieldDescriptions) - { - return fieldDescriptions.Select(GetFieldDeclaration).ToArray(); - - MemberDeclarationSyntax GetFieldDeclaration(InvokerFieldDescription description) + if (field.IsInstanceField) { - FieldDeclarationSyntax field; - if (description is MethodInfoFieldDescription methodInfo) - { - var methodTypeArguments = GetTypesArray(method, method.MethodTypeParameters.Select(p => p.Parameter)); - var parameterTypes = GetTypesArray(method, method.Method.Parameters.Select(p => p.Type)); - - field = FieldDeclaration( - VariableDeclaration( - LibraryTypes.MethodInfo.ToTypeSyntax(), - SingletonSeparatedList(VariableDeclarator(description.FieldName) - .WithInitializer(EqualsValueClause( - InvocationExpression( - IdentifierName("OrleansGeneratedCodeHelper").Member("GetMethodInfoOrDefault"), - ArgumentList(SeparatedList(new[] - { - Argument(TypeOfExpression(method.Method.ContainingType.ToTypeSyntax(method.TypeParameterSubstitutions))), - Argument(method.Method.Name.GetLiteralExpression()), - Argument(methodTypeArguments), - Argument(parameterTypes), - })))))))) - .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword)); - } - else - { - field = FieldDeclaration( - VariableDeclaration( - description.FieldType.ToTypeSyntax(method.TypeParameterSubstitutions), - SingletonSeparatedList(VariableDeclarator(description.FieldName)))); - } - - switch (description) - { - case MethodParameterFieldDescription _: - field = field.AddModifiers(Token(SyntaxKind.PublicKeyword)); - break; - } - - return field; + body.Add( + ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(field.FieldName), + LiteralExpression(SyntaxKind.DefaultLiteralExpression)))); } } - private ExpressionSyntax GetTypesArray(InvokableMethodDescription method, IEnumerable typeSymbols) + // C# base.Dispose(); + if (baseClassType is { } + && baseClassType.AllInterfaces.Any(i => i.SpecialType == SpecialType.System_IDisposable) + && baseClassType.GetAllMembers("Dispose").FirstOrDefault(m => !m.IsAbstract && m.DeclaredAccessibility != Accessibility.Private) is { }) { - var types = typeSymbols.ToArray(); - return types.Length == 0 ? LiteralExpression(SyntaxKind.NullLiteralExpression) - : ImplicitArrayCreationExpression(InitializerExpression(SyntaxKind.ArrayInitializerExpression, SeparatedList( - types.Select(t => TypeOfExpression(t.ToTypeSyntax(method.TypeParameterSubstitutions)))))); + body.Add(ExpressionStatement(InvocationExpression(BaseExpression().Member("Dispose")).WithArgumentList(ArgumentList()))); } - private (ConstructorDeclarationSyntax Constructor, List ConstructorArguments) GenerateConstructor( - string simpleClassName, - InvokableMethodDescription method, - INamedTypeSymbol baseClassType) - { - var parameters = new List(); + return MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "Dispose") + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) + .WithBody(Block(body)); + } + + private static MemberDeclarationSyntax? GenerateGetArgumentCount(InvokableMethodDescription methodDescription) + => methodDescription.Method.Parameters.Length is var count and not 0 ? + MethodDeclaration(PredefinedType(Token(SyntaxKind.IntKeyword)), "GetArgumentCount") + .WithExpressionBody(ArrowExpressionClause(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(count)))) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)) : null; + + private static MemberDeclarationSyntax GenerateGetActivityName(InvokableMethodDescription methodDescription) + { + // This property is intended to contain a value suitable for use as an OpenTelemetry Span Name for RPC calls. + // Therefore, the interface name and method name components must not include periods or slashes. + // In order to avoid that, we omit the namespace from the interface name. + // See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/trace/semantic_conventions/rpc.md + + var interfaceName = methodDescription.Method.ContainingType.ToDisplayName(methodDescription.TypeParameterSubstitutions, includeGlobalSpecifier: false, includeNamespace: false); + var methodName = methodDescription.Method.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var activityName = $"{interfaceName}/{methodName}"; + return MethodDeclaration(PredefinedType(Token(SyntaxKind.StringKeyword)), "GetActivityName") + .WithExpressionBody( + ArrowExpressionClause( + LiteralExpression( + SyntaxKind.NumericLiteralExpression, + Literal(activityName)))) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + } + + private static MemberDeclarationSyntax GenerateGetMethodName( + InvokableMethodDescription methodDescription) => + MethodDeclaration(PredefinedType(Token(SyntaxKind.StringKeyword)), "GetMethodName") + .WithExpressionBody( + ArrowExpressionClause( + LiteralExpression( + SyntaxKind.NumericLiteralExpression, + Literal(methodDescription.Method.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))))) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + + private static MemberDeclarationSyntax GenerateGetInterfaceName( + InvokableMethodDescription methodDescription) => + MethodDeclaration(PredefinedType(Token(SyntaxKind.StringKeyword)), "GetInterfaceName") + .WithExpressionBody( + ArrowExpressionClause( + LiteralExpression( + SyntaxKind.NumericLiteralExpression, + Literal(methodDescription.Method.ContainingType.ToDisplayName(methodDescription.TypeParameterSubstitutions, includeGlobalSpecifier: false))))) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + + private MemberDeclarationSyntax GenerateGetInterfaceType( + InvokableMethodDescription methodDescription) => + MethodDeclaration(LibraryTypes.Type.ToTypeSyntax(), "GetInterfaceType") + .WithExpressionBody( + ArrowExpressionClause( + TypeOfExpression(methodDescription.Method.ContainingType.ToTypeSyntax(methodDescription.TypeParameterSubstitutions)))) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + + private MemberDeclarationSyntax GenerateGetMethod() + => MethodDeclaration(LibraryTypes.MethodInfo.ToTypeSyntax(), "GetMethod") + .WithExpressionBody(ArrowExpressionClause(IdentifierName("MethodBackingField"))) + .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.OverrideKeyword))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + + public static string GetSimpleClassName(InvokableMethodDescription method) + { + var genericArity = method.AllTypeParameters.Count; + var typeArgs = genericArity > 0 ? "_" + genericArity : string.Empty; + var proxyKey = method.ProxyBase.Key.GeneratedClassNameComponent; + return $"Invokable_{method.ContainingInterface.Name}_{proxyKey}_{method.GeneratedMethodId}{typeArgs}"; + } - var body = new List(); + private MemberDeclarationSyntax[] GetFieldDeclarations( + InvokableMethodDescription method, + List fieldDescriptions) + { + return [.. fieldDescriptions.Select(GetFieldDeclaration)]; - List constructorArgumentTypes = new(); - List baseConstructorArguments = new(); - foreach (var constructor in baseClassType.GetAllMembers()) + MemberDeclarationSyntax GetFieldDeclaration(InvokerFieldDescription description) + { + FieldDeclarationSyntax field; + if (description is MethodInfoFieldDescription methodInfo) { - if (constructor.MethodKind != MethodKind.Constructor || constructor.DeclaredAccessibility == Accessibility.Private || constructor.IsImplicitlyDeclared) - { - continue; - } + var methodTypeArguments = GetTypesArray(method, method.MethodTypeParameters.Select(p => p.Parameter)); + var parameterTypes = GetTypesArray(method, method.Method.Parameters.Select(p => p.Type)); - if (constructor.HasAttribute(LibraryTypes.GeneratedActivatorConstructorAttribute)) - { - var index = 0; - foreach (var parameter in constructor.Parameters) - { - var identifier = $"base{index}"; - - var argumentType = parameter.Type.ToTypeSyntax(method.TypeParameterSubstitutions); - constructorArgumentTypes.Add(argumentType); - parameters.Add(Parameter(identifier.ToIdentifier()).WithType(argumentType)); - baseConstructorArguments.Add(Argument(identifier.ToIdentifierName())); - index++; - } - break; - } + field = FieldDeclaration( + VariableDeclaration( + LibraryTypes.MethodInfo.ToTypeSyntax(), + SingletonSeparatedList(VariableDeclarator(description.FieldName) + .WithInitializer(EqualsValueClause( + InvocationExpression( + IdentifierName("OrleansGeneratedCodeHelper").Member("GetMethodInfoOrDefault"), + ArgumentList(SeparatedList( + [ + Argument(TypeOfExpression(method.Method.ContainingType.ToTypeSyntax(method.TypeParameterSubstitutions))), + Argument(method.Method.Name.GetLiteralExpression()), + Argument(methodTypeArguments), + Argument(parameterTypes), + ])))))))) + .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword)); + } + else + { + field = FieldDeclaration( + VariableDeclaration( + description.FieldType.ToTypeSyntax(method.TypeParameterSubstitutions), + SingletonSeparatedList(VariableDeclarator(description.FieldName)))); } - foreach (var (methodName, methodArgument) in method.CustomInitializerMethods) + switch (description) { - var argumentExpression = methodArgument.ToExpression(); - body.Add(ExpressionStatement(InvocationExpression(IdentifierName(methodName), ArgumentList(SeparatedList(new[] { Argument(argumentExpression) }))))); + case MethodParameterFieldDescription _: + field = field.AddModifiers(Token(SyntaxKind.PublicKeyword)); + break; } - if (body.Count == 0 && parameters.Count == 0) - return default; + return field; + } + } + + private static ExpressionSyntax GetTypesArray(InvokableMethodDescription method, IEnumerable typeSymbols) + { + var types = typeSymbols.ToArray(); + return types.Length == 0 ? LiteralExpression(SyntaxKind.NullLiteralExpression) + : ImplicitArrayCreationExpression(InitializerExpression(SyntaxKind.ArrayInitializerExpression, SeparatedList( + types.Select(t => TypeOfExpression(t.ToTypeSyntax(method.TypeParameterSubstitutions)))))); + } - var constructorDeclaration = ConstructorDeclaration(simpleClassName) - .AddModifiers(Token(SyntaxKind.PublicKeyword)) - .AddParameterListParameters(parameters.ToArray()) - .WithInitializer( - ConstructorInitializer( - SyntaxKind.BaseConstructorInitializer, - ArgumentList(SeparatedList(baseConstructorArguments)))) - .AddBodyStatements(body.ToArray()); + private (ConstructorDeclarationSyntax? Constructor, List ConstructorArguments) GenerateConstructor( + string simpleClassName, + InvokableMethodDescription method, + INamedTypeSymbol baseClassType) + { + var parameters = new List(); - return (constructorDeclaration, constructorArgumentTypes); - } + var body = new List(); - private List GetFieldDescriptions(InvokableMethodDescription method) + List constructorArgumentTypes = new(); + List baseConstructorArguments = new(); + foreach (var constructor in baseClassType.GetAllMembers()) { - var fields = new List(); - uint fieldId = 0; - - foreach (var parameter in method.Method.Parameters) + if (constructor.MethodKind != MethodKind.Constructor || constructor.DeclaredAccessibility == Accessibility.Private || constructor.IsImplicitlyDeclared) { - var isSerializable = !SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, parameter.Type); - fields.Add(new MethodParameterFieldDescription(method.CodeGenerator, parameter, $"arg{fieldId}", fieldId, method.TypeParameterSubstitutions, isSerializable)); - fieldId++; + continue; } - fields.Add(new TargetFieldDescription(method.Method.ContainingType)); - fields.Add(new MethodInfoFieldDescription(LibraryTypes.MethodInfo, "MethodBackingField")); - - if (method.IsCancellable) + if (constructor.HasAttribute(LibraryTypes.GeneratedActivatorConstructorAttribute)) { - fields.Add(new CancellationTokenSourceFieldDescription(LibraryTypes)); - } + var index = 0; + foreach (var parameter in constructor.Parameters) + { + var identifier = $"base{index}"; - return fields; + var argumentType = parameter.Type.ToTypeSyntax(method.TypeParameterSubstitutions); + constructorArgumentTypes.Add(argumentType); + parameters.Add(Parameter(identifier.ToIdentifier()).WithType(argumentType)); + baseConstructorArguments.Add(Argument(identifier.ToIdentifierName())); + index++; + } + break; + } } - internal abstract class InvokerFieldDescription + foreach (var (methodName, methodArgument) in method.CustomInitializerMethods) { - protected InvokerFieldDescription(ITypeSymbol fieldType, string fieldName) - { - FieldType = fieldType; - FieldName = fieldName; - } - - public ITypeSymbol FieldType { get; } - public string FieldName { get; } - public abstract bool IsSerializable { get; } - public abstract bool IsInstanceField { get; } + var argumentExpression = methodArgument.ToExpression(); + body.Add(ExpressionStatement(InvocationExpression(IdentifierName(methodName), ArgumentList(SeparatedList([Argument(argumentExpression)]))))); } - internal sealed class TargetFieldDescription : InvokerFieldDescription - { - public TargetFieldDescription(ITypeSymbol fieldType) : base(fieldType, "_target") { } + if (body.Count == 0 && parameters.Count == 0) + return default; - public override bool IsSerializable => false; - public override bool IsInstanceField => true; - } + var constructorDeclaration = ConstructorDeclaration(simpleClassName) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddParameterListParameters([.. parameters]) + .WithInitializer( + ConstructorInitializer( + SyntaxKind.BaseConstructorInitializer, + ArgumentList(SeparatedList(baseConstructorArguments)))) + .AddBodyStatements([.. body]); - internal sealed class CancellationTokenSourceFieldDescription(LibraryTypes libraryTypes) : InvokerFieldDescription(libraryTypes.CancellationTokenSource, "_cts") + return (constructorDeclaration, constructorArgumentTypes); + } + + private List GetFieldDescriptions(InvokableMethodDescription method) + { + var fields = new List(); + uint fieldId = 0; + + foreach (var parameter in method.Method.Parameters) { - public override bool IsSerializable => false; - public override bool IsInstanceField => true; + var isSerializable = !SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, parameter.Type); + fields.Add(new MethodParameterFieldDescription(LibraryTypes, parameter, $"arg{fieldId}", fieldId, method.TypeParameterSubstitutions, isSerializable)); + fieldId++; } - internal sealed class CancellationTokenFieldDescription(LibraryTypes libraryTypes) : InvokerFieldDescription(libraryTypes.CancellationToken, "_ct") + fields.Add(new TargetFieldDescription(method.Method.ContainingType)); + fields.Add(new MethodInfoFieldDescription(LibraryTypes.MethodInfo, "MethodBackingField")); + + if (method.IsCancellable) { - public override bool IsSerializable => false; - public override bool IsInstanceField => true; + fields.Add(new CancellationTokenSourceFieldDescription(LibraryTypes)); } - internal class MethodParameterFieldDescription : InvokerFieldDescription, IMemberDescription + return fields; + } + + internal abstract class InvokerFieldDescription(ITypeSymbol fieldType, string fieldName) + { + public ITypeSymbol FieldType { get; } = fieldType; + public string FieldName { get; } = fieldName; + public abstract bool IsSerializable { get; } + public abstract bool IsInstanceField { get; } + } + + internal sealed class TargetFieldDescription(ITypeSymbol fieldType) : InvokerFieldDescription(fieldType, "_target") + { + public override bool IsSerializable => false; + public override bool IsInstanceField => true; + } + + internal sealed class CancellationTokenSourceFieldDescription(LibraryTypes libraryTypes) : InvokerFieldDescription(libraryTypes.CancellationTokenSource, "_cts") + { + public override bool IsSerializable => false; + public override bool IsInstanceField => true; + } + + internal sealed class CancellationTokenFieldDescription(LibraryTypes libraryTypes) : InvokerFieldDescription(libraryTypes.CancellationToken, "_ct") + { + public override bool IsSerializable => false; + public override bool IsInstanceField => true; + } + + internal class MethodParameterFieldDescription : InvokerFieldDescription, IMemberDescription + { + public MethodParameterFieldDescription( + LibraryTypes libraryTypes, + IParameterSymbol parameter, + string fieldName, + uint fieldId, + Dictionary typeParameterSubstitutions, + bool isSerializable) + : base(parameter.Type, fieldName) { - public MethodParameterFieldDescription( - CodeGenerator codeGenerator, - IParameterSymbol parameter, - string fieldName, - uint fieldId, - Dictionary typeParameterSubstitutions, - bool isSerializable) - : base(parameter.Type, fieldName) + TypeParameterSubstitutions = typeParameterSubstitutions; + FieldId = fieldId; + LibraryTypes = libraryTypes; + Parameter = parameter; + if (parameter.Type.TypeKind == TypeKind.Dynamic) { - TypeParameterSubstitutions = typeParameterSubstitutions; - FieldId = fieldId; - CodeGenerator = codeGenerator; - Parameter = parameter; - if (parameter.Type.TypeKind == TypeKind.Dynamic) - { - TypeSyntax = PredefinedType(Token(SyntaxKind.ObjectKeyword)); - TypeName = "dynamic"; - } - else - { - TypeName = Type.ToDisplayName(TypeParameterSubstitutions); - TypeSyntax = Type.ToTypeSyntax(TypeParameterSubstitutions); - } - - Symbol = parameter; - IsSerializable = isSerializable; + TypeSyntax = PredefinedType(Token(SyntaxKind.ObjectKeyword)); + TypeName = "dynamic"; } - - public CodeGenerator CodeGenerator { get; } - public ISymbol Symbol { get; } - public Dictionary TypeParameterSubstitutions { get; } - public int ParameterOrdinal => Parameter.Ordinal; - public uint FieldId { get; } - public ISymbol Member => Parameter; - public ITypeSymbol Type => FieldType; - public INamedTypeSymbol ContainingType => Parameter.ContainingType; - public TypeSyntax TypeSyntax { get; } - public IParameterSymbol Parameter { get; } - public override bool IsSerializable { get; } - public bool IsCopyable => true; - public override bool IsInstanceField => true; - - public string AssemblyName => Parameter.Type.ContainingAssembly.ToDisplayName(); - public string TypeName { get; } - - public string TypeNameIdentifier + else { - get - { - if (Type is ITypeParameterSymbol tp && TypeParameterSubstitutions.TryGetValue(tp, out var name)) - { - return name; - } - - return Type.GetValidIdentifier(); - } + TypeName = Type.ToDisplayName(TypeParameterSubstitutions); + TypeSyntax = Type.ToTypeSyntax(TypeParameterSubstitutions); } - public bool IsPrimaryConstructorParameter => false; - - public TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol) => typeSymbol.ToTypeSyntax(TypeParameterSubstitutions); + Symbol = parameter; + IsSerializable = isSerializable; } - internal sealed class MethodInfoFieldDescription : InvokerFieldDescription + public ISymbol Symbol { get; } + public LibraryTypes LibraryTypes { get; } + public Dictionary TypeParameterSubstitutions { get; } + public int ParameterOrdinal => Parameter.Ordinal; + public uint FieldId { get; } + public ISymbol Member => Parameter; + public ITypeSymbol Type => FieldType; + public INamedTypeSymbol ContainingType => Parameter.ContainingType; + public TypeSyntax TypeSyntax { get; } + public IParameterSymbol Parameter { get; } + public override bool IsSerializable { get; } + public bool IsCopyable => true; + public override bool IsInstanceField => true; + + public string AssemblyName => Parameter.Type.ContainingAssembly.ToDisplayName(); + public string TypeName { get; } + + public string TypeNameIdentifier { - public MethodInfoFieldDescription(ITypeSymbol fieldType, string fieldName) : base(fieldType, fieldName) { } + get + { + if (Type is ITypeParameterSymbol tp && TypeParameterSubstitutions.TryGetValue(tp, out var name)) + { + return name; + } - public override bool IsSerializable => false; - public override bool IsInstanceField => false; + return Type.GetValidIdentifier(); + } } + + public bool IsPrimaryConstructorParameter => false; + + public TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol) => typeSymbol.ToTypeSyntax(TypeParameterSubstitutions); + } + + internal sealed class MethodInfoFieldDescription(ITypeSymbol fieldType, string fieldName) : InvokerFieldDescription(fieldType, fieldName) + { + public override bool IsSerializable => false; + public override bool IsInstanceField => false; } } diff --git a/src/Orleans.CodeGenerator/LibraryTypes.cs b/src/Orleans.CodeGenerator/LibraryTypes.cs index 9906407935c..4ca83ae54b9 100644 --- a/src/Orleans.CodeGenerator/LibraryTypes.cs +++ b/src/Orleans.CodeGenerator/LibraryTypes.cs @@ -1,459 +1,455 @@ -using System; using System.Collections.Concurrent; -using System.Collections.Generic; using System.Collections.Immutable; -using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Orleans.CodeGenerator.SyntaxGeneration; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal sealed class LibraryTypes { - internal sealed class LibraryTypes - { - private readonly ConcurrentDictionary _shallowCopyableTypes = new(SymbolEqualityComparer.Default); + private readonly ConcurrentDictionary _shallowCopyableTypes = new(SymbolEqualityComparer.Default); - public static LibraryTypes FromCompilation(Compilation compilation, CodeGeneratorOptions options) => new LibraryTypes(compilation, options); + public static LibraryTypes FromCompilation(Compilation compilation, CodeGeneratorOptions options) => new LibraryTypes(compilation, options); - private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) - { - Compilation = compilation; - ApplicationPartAttribute = Type("Orleans.ApplicationPartAttribute"); - Action_2 = Type("System.Action`2"); - TypeManifestProviderBase = Type("Orleans.Serialization.Configuration.TypeManifestProviderBase"); - Field = Type("Orleans.Serialization.WireProtocol.Field"); - FieldCodec_1 = Type("Orleans.Serialization.Codecs.IFieldCodec`1"); - AbstractTypeSerializer = Type("Orleans.Serialization.Serializers.AbstractTypeSerializer`1"); - DeepCopier_1 = Type("Orleans.Serialization.Cloning.IDeepCopier`1"); - ShallowCopier = Type("Orleans.Serialization.Cloning.ShallowCopier`1"); - CompoundTypeAliasAttribute = Type("Orleans.CompoundTypeAliasAttribute"); - CopyContext = Type("Orleans.Serialization.Cloning.CopyContext"); - MethodInfo = Type("System.Reflection.MethodInfo"); - Func_2 = Type("System.Func`2"); - GenerateMethodSerializersAttribute = Type("Orleans.GenerateMethodSerializersAttribute"); - GenerateSerializerAttribute = Type("Orleans.GenerateSerializerAttribute"); - SerializationCallbacksAttribute = Type("Orleans.SerializationCallbacksAttribute"); - IActivator_1 = Type("Orleans.Serialization.Activators.IActivator`1"); - IBufferWriter = Type("System.Buffers.IBufferWriter`1"); - IdAttributeType = Type(CodeGeneratorOptions.IdAttribute); - ConstructorAttributeTypes = CodeGeneratorOptions.ConstructorAttributes.Select(Type).ToArray(); - AliasAttribute = Type("Orleans.AliasAttribute"); - IInvokable = Type("Orleans.Serialization.Invocation.IInvokable"); - InvokeMethodNameAttribute = Type("Orleans.InvokeMethodNameAttribute"); - RuntimeHelpers = Type("System.Runtime.CompilerServices.RuntimeHelpers"); - InvokableCustomInitializerAttribute = Type("Orleans.InvokableCustomInitializerAttribute"); - DefaultInvokableBaseTypeAttribute = Type("Orleans.DefaultInvokableBaseTypeAttribute"); - GenerateCodeForDeclaringAssemblyAttribute = Type("Orleans.GenerateCodeForDeclaringAssemblyAttribute"); - InvokableBaseTypeAttribute = Type("Orleans.InvokableBaseTypeAttribute"); - ReturnValueProxyAttribute = Type("Orleans.Invocation.ReturnValueProxyAttribute"); - RegisterSerializerAttribute = Type("Orleans.RegisterSerializerAttribute"); - ResponseTimeoutAttribute = Type("Orleans.ResponseTimeoutAttribute"); - GeneratedActivatorConstructorAttribute = Type("Orleans.GeneratedActivatorConstructorAttribute"); - SerializerTransparentAttribute = Type("Orleans.SerializerTransparentAttribute"); - RegisterActivatorAttribute = Type("Orleans.RegisterActivatorAttribute"); - RegisterConverterAttribute = Type("Orleans.RegisterConverterAttribute"); - RegisterCopierAttribute = Type("Orleans.RegisterCopierAttribute"); - UseActivatorAttribute = Type("Orleans.UseActivatorAttribute"); - SuppressReferenceTrackingAttribute = Type("Orleans.SuppressReferenceTrackingAttribute"); - OmitDefaultMemberValuesAttribute = Type("Orleans.OmitDefaultMemberValuesAttribute"); - ITargetHolder = Type("Orleans.Serialization.Invocation.ITargetHolder"); - TypeManifestProviderAttribute = Type("Orleans.Serialization.Configuration.TypeManifestProviderAttribute"); - NonSerializedAttribute = Type("System.NonSerializedAttribute"); - ObsoleteAttribute = Type("System.ObsoleteAttribute"); - BaseCodec_1 = Type("Orleans.Serialization.Serializers.IBaseCodec`1"); - BaseCopier_1 = Type("Orleans.Serialization.Cloning.IBaseCopier`1"); - ArrayCodec = Type("Orleans.Serialization.Codecs.ArrayCodec`1"); - ArrayCopier = Type("Orleans.Serialization.Codecs.ArrayCopier`1"); - Reader = Type("Orleans.Serialization.Buffers.Reader`1"); - TypeManifestOptions = Type("Orleans.Serialization.Configuration.TypeManifestOptions"); - Task = Type("System.Threading.Tasks.Task"); - Task_1 = Type("System.Threading.Tasks.Task`1"); - this.Type = Type("System.Type"); - _uri = Type("System.Uri"); - _int128 = TypeOrDefault("System.Int128"); - _uInt128 = TypeOrDefault("System.UInt128"); - _half = TypeOrDefault("System.Half"); - _dateOnly = TypeOrDefault("System.DateOnly"); - _dateTimeOffset = Type("System.DateTimeOffset"); - _bitVector32 = Type("System.Collections.Specialized.BitVector32"); - _compareInfo = Type("System.Globalization.CompareInfo"); - _cultureInfo = Type("System.Globalization.CultureInfo"); - _version = Type("System.Version"); - _timeOnly = TypeOrDefault("System.TimeOnly"); - Guid = Type("System.Guid"); - ICodecProvider = Type("Orleans.Serialization.Serializers.ICodecProvider"); - ValueSerializer = Type("Orleans.Serialization.Serializers.IValueSerializer`1"); - ValueTask = Type("System.Threading.Tasks.ValueTask"); - ValueTask_1 = Type("System.Threading.Tasks.ValueTask`1"); - ValueTypeGetter_2 = Type("Orleans.Serialization.Utilities.ValueTypeGetter`2"); - ValueTypeSetter_2 = Type("Orleans.Serialization.Utilities.ValueTypeSetter`2"); - Writer = Type("Orleans.Serialization.Buffers.Writer`1"); - FSharpSourceConstructFlagsOrDefault = TypeOrDefault("Microsoft.FSharp.Core.SourceConstructFlags"); - FSharpCompilationMappingAttributeOrDefault = TypeOrDefault("Microsoft.FSharp.Core.CompilationMappingAttribute"); - StaticCodecs = new List - { - new(compilation.GetSpecialType(SpecialType.System_Object), Type("Orleans.Serialization.Codecs.ObjectCodec")), - new(compilation.GetSpecialType(SpecialType.System_Boolean), Type("Orleans.Serialization.Codecs.BoolCodec")), - new(compilation.GetSpecialType(SpecialType.System_Char), Type("Orleans.Serialization.Codecs.CharCodec")), - new(compilation.GetSpecialType(SpecialType.System_Byte), Type("Orleans.Serialization.Codecs.ByteCodec")), - new(compilation.GetSpecialType(SpecialType.System_SByte), Type("Orleans.Serialization.Codecs.SByteCodec")), - new(compilation.GetSpecialType(SpecialType.System_Int16), Type("Orleans.Serialization.Codecs.Int16Codec")), - new(compilation.GetSpecialType(SpecialType.System_Int32), Type("Orleans.Serialization.Codecs.Int32Codec")), - new(compilation.GetSpecialType(SpecialType.System_Int64), Type("Orleans.Serialization.Codecs.Int64Codec")), - new(compilation.GetSpecialType(SpecialType.System_UInt16), Type("Orleans.Serialization.Codecs.UInt16Codec")), - new(compilation.GetSpecialType(SpecialType.System_UInt32), Type("Orleans.Serialization.Codecs.UInt32Codec")), - new(compilation.GetSpecialType(SpecialType.System_UInt64), Type("Orleans.Serialization.Codecs.UInt64Codec")), - new(compilation.GetSpecialType(SpecialType.System_String), Type("Orleans.Serialization.Codecs.StringCodec")), - new(compilation.CreateArrayTypeSymbol(compilation.GetSpecialType(SpecialType.System_Byte), 1), Type("Orleans.Serialization.Codecs.ByteArrayCodec")), - new(compilation.GetSpecialType(SpecialType.System_Single), Type("Orleans.Serialization.Codecs.FloatCodec")), - new(compilation.GetSpecialType(SpecialType.System_Double), Type("Orleans.Serialization.Codecs.DoubleCodec")), - new(compilation.GetSpecialType(SpecialType.System_Decimal), Type("Orleans.Serialization.Codecs.DecimalCodec")), - new(compilation.GetSpecialType(SpecialType.System_DateTime), Type("Orleans.Serialization.Codecs.DateTimeCodec")), - new(Type("System.TimeSpan"), Type("Orleans.Serialization.Codecs.TimeSpanCodec")), - new(Type("System.DateTimeOffset"), Type("Orleans.Serialization.Codecs.DateTimeOffsetCodec")), - new(TypeOrDefault("System.DateOnly"), TypeOrDefault("Orleans.Serialization.Codecs.DateOnlyCodec")), - new(TypeOrDefault("System.TimeOnly"), TypeOrDefault("Orleans.Serialization.Codecs.TimeOnlyCodec")), - new(Type("System.Guid"), Type("Orleans.Serialization.Codecs.GuidCodec")), - new(Type("System.Type"), Type("Orleans.Serialization.Codecs.TypeSerializerCodec")), - new(Type("System.ReadOnlyMemory`1").Construct(compilation.GetSpecialType(SpecialType.System_Byte)), Type("Orleans.Serialization.Codecs.ReadOnlyMemoryOfByteCodec")), - new(Type("System.Memory`1").Construct(compilation.GetSpecialType(SpecialType.System_Byte)), Type("Orleans.Serialization.Codecs.MemoryOfByteCodec")), - new(Type("System.Net.IPAddress"), Type("Orleans.Serialization.Codecs.IPAddressCodec")), - new(Type("System.Net.IPEndPoint"), Type("Orleans.Serialization.Codecs.IPEndPointCodec")), - new(TypeOrDefault("System.UInt128"), TypeOrDefault("Orleans.Serialization.Codecs.UInt128Codec")), - new(TypeOrDefault("System.Int128"), TypeOrDefault("Orleans.Serialization.Codecs.Int128Codec")), - new(TypeOrDefault("System.Half"), TypeOrDefault("Orleans.Serialization.Codecs.HalfCodec")), - new(Type("System.Uri"), Type("Orleans.Serialization.Codecs.UriCodec")), - }.Where(desc => desc.UnderlyingType is { } && desc.CodecType is { }).ToArray(); - WellKnownCodecs = new WellKnownCodecDescription[] - { - new(Type("System.Exception"), Type("Orleans.Serialization.ExceptionCodec")), - new(Type("System.Collections.Generic.Dictionary`2"), Type("Orleans.Serialization.Codecs.DictionaryCodec`2")), - new(Type("System.Collections.Generic.List`1"), Type("Orleans.Serialization.Codecs.ListCodec`1")), - new(Type("System.Collections.Generic.HashSet`1"), Type("Orleans.Serialization.Codecs.HashSetCodec`1")), - new(compilation.GetSpecialType(SpecialType.System_Nullable_T), Type("Orleans.Serialization.Codecs.NullableCodec`1")), - }; - StaticCopiers = new WellKnownCopierDescription[] - { - new(compilation.GetSpecialType(SpecialType.System_Object), Type("Orleans.Serialization.Codecs.ObjectCopier")), - new(compilation.CreateArrayTypeSymbol(compilation.GetSpecialType(SpecialType.System_Byte), 1), Type("Orleans.Serialization.Codecs.ByteArrayCopier")), - new(Type("System.ReadOnlyMemory`1").Construct(compilation.GetSpecialType(SpecialType.System_Byte)), Type("Orleans.Serialization.Codecs.ReadOnlyMemoryOfByteCopier")), - new(Type("System.Memory`1").Construct(compilation.GetSpecialType(SpecialType.System_Byte)), Type("Orleans.Serialization.Codecs.MemoryOfByteCopier")), - }; - WellKnownCopiers = new WellKnownCopierDescription[] - { - new(Type("System.Exception"), Type("Orleans.Serialization.ExceptionCodec")), - new(Type("System.Collections.Generic.Dictionary`2"), Type("Orleans.Serialization.Codecs.DictionaryCopier`2")), - new(Type("System.Collections.Generic.List`1"), Type("Orleans.Serialization.Codecs.ListCopier`1")), - new(Type("System.Collections.Generic.HashSet`1"), Type("Orleans.Serialization.Codecs.HashSetCopier`1")), - new(compilation.GetSpecialType(SpecialType.System_Nullable_T), Type("Orleans.Serialization.Codecs.NullableCopier`1")), - }; - Exception = Type("System.Exception"); - ImmutableAttribute = Type(CodeGeneratorOptions.ImmutableAttribute); - TimeSpan = Type("System.TimeSpan"); - _ipAddress = Type("System.Net.IPAddress"); - _ipEndPoint = Type("System.Net.IPEndPoint"); - CancellationToken = Type("System.Threading.CancellationToken"); - CancellationTokenSource = Type("System.Threading.CancellationTokenSource"); - _immutableContainerTypes = new[] + private LibraryTypes(Compilation compilation, CodeGeneratorOptions options) + { + Compilation = compilation; + ApplicationPartAttribute = Type("Orleans.ApplicationPartAttribute"); + Action_2 = Type("System.Action`2"); + TypeManifestProviderBase = Type("Orleans.Serialization.Configuration.TypeManifestProviderBase"); + Field = Type("Orleans.Serialization.WireProtocol.Field"); + FieldCodec_1 = Type("Orleans.Serialization.Codecs.IFieldCodec`1"); + AbstractTypeSerializer = Type("Orleans.Serialization.Serializers.AbstractTypeSerializer`1"); + DeepCopier_1 = Type("Orleans.Serialization.Cloning.IDeepCopier`1"); + ShallowCopier = Type("Orleans.Serialization.Cloning.ShallowCopier`1"); + CompoundTypeAliasAttribute = Type("Orleans.CompoundTypeAliasAttribute"); + CopyContext = Type("Orleans.Serialization.Cloning.CopyContext"); + MethodInfo = Type("System.Reflection.MethodInfo"); + Func_2 = Type("System.Func`2"); + GenerateMethodSerializersAttribute = Type("Orleans.GenerateMethodSerializersAttribute"); + GenerateSerializerAttribute = Type("Orleans.GenerateSerializerAttribute"); + SerializationCallbacksAttribute = Type("Orleans.SerializationCallbacksAttribute"); + IActivator_1 = Type("Orleans.Serialization.Activators.IActivator`1"); + IBufferWriter = Type("System.Buffers.IBufferWriter`1"); + IdAttributeType = Type(CodeGeneratorOptions.IdAttribute); + ConstructorAttributeTypes = [.. CodeGeneratorOptions.ConstructorAttributes.Select(Type)]; + AliasAttribute = Type("Orleans.AliasAttribute"); + IInvokable = Type("Orleans.Serialization.Invocation.IInvokable"); + InvokeMethodNameAttribute = Type("Orleans.InvokeMethodNameAttribute"); + RuntimeHelpers = Type("System.Runtime.CompilerServices.RuntimeHelpers"); + InvokableCustomInitializerAttribute = Type("Orleans.InvokableCustomInitializerAttribute"); + DefaultInvokableBaseTypeAttribute = Type("Orleans.DefaultInvokableBaseTypeAttribute"); + GenerateCodeForDeclaringAssemblyAttribute = Type("Orleans.GenerateCodeForDeclaringAssemblyAttribute"); + InvokableBaseTypeAttribute = Type("Orleans.InvokableBaseTypeAttribute"); + ReturnValueProxyAttribute = Type("Orleans.Invocation.ReturnValueProxyAttribute"); + RegisterSerializerAttribute = Type("Orleans.RegisterSerializerAttribute"); + ResponseTimeoutAttribute = Type("Orleans.ResponseTimeoutAttribute"); + GeneratedActivatorConstructorAttribute = Type("Orleans.GeneratedActivatorConstructorAttribute"); + SerializerTransparentAttribute = Type("Orleans.SerializerTransparentAttribute"); + RegisterActivatorAttribute = Type("Orleans.RegisterActivatorAttribute"); + RegisterConverterAttribute = Type("Orleans.RegisterConverterAttribute"); + RegisterCopierAttribute = Type("Orleans.RegisterCopierAttribute"); + UseActivatorAttribute = Type("Orleans.UseActivatorAttribute"); + SuppressReferenceTrackingAttribute = Type("Orleans.SuppressReferenceTrackingAttribute"); + OmitDefaultMemberValuesAttribute = Type("Orleans.OmitDefaultMemberValuesAttribute"); + ITargetHolder = Type("Orleans.Serialization.Invocation.ITargetHolder"); + TypeManifestProviderAttribute = Type("Orleans.Serialization.Configuration.TypeManifestProviderAttribute"); + NonSerializedAttribute = Type("System.NonSerializedAttribute"); + ObsoleteAttribute = Type("System.ObsoleteAttribute"); + BaseCodec_1 = Type("Orleans.Serialization.Serializers.IBaseCodec`1"); + BaseCopier_1 = Type("Orleans.Serialization.Cloning.IBaseCopier`1"); + ArrayCodec = Type("Orleans.Serialization.Codecs.ArrayCodec`1"); + ArrayCopier = Type("Orleans.Serialization.Codecs.ArrayCopier`1"); + Reader = Type("Orleans.Serialization.Buffers.Reader`1"); + TypeManifestOptions = Type("Orleans.Serialization.Configuration.TypeManifestOptions"); + Task = Type("System.Threading.Tasks.Task"); + Task_1 = Type("System.Threading.Tasks.Task`1"); + this.Type = Type("System.Type"); + _uri = Type("System.Uri"); + _int128 = TypeOrDefault("System.Int128"); + _uInt128 = TypeOrDefault("System.UInt128"); + _half = TypeOrDefault("System.Half"); + _dateOnly = TypeOrDefault("System.DateOnly"); + _dateTimeOffset = Type("System.DateTimeOffset"); + _bitVector32 = Type("System.Collections.Specialized.BitVector32"); + _compareInfo = Type("System.Globalization.CompareInfo"); + _cultureInfo = Type("System.Globalization.CultureInfo"); + _version = Type("System.Version"); + _timeOnly = TypeOrDefault("System.TimeOnly"); + Guid = Type("System.Guid"); + ICodecProvider = Type("Orleans.Serialization.Serializers.ICodecProvider"); + ValueSerializer = Type("Orleans.Serialization.Serializers.IValueSerializer`1"); + ValueTask = Type("System.Threading.Tasks.ValueTask"); + ValueTask_1 = Type("System.Threading.Tasks.ValueTask`1"); + ValueTypeGetter_2 = Type("Orleans.Serialization.Utilities.ValueTypeGetter`2"); + ValueTypeSetter_2 = Type("Orleans.Serialization.Utilities.ValueTypeSetter`2"); + Writer = Type("Orleans.Serialization.Buffers.Writer`1"); + FSharpSourceConstructFlagsOrDefault = TypeOrDefault("Microsoft.FSharp.Core.SourceConstructFlags"); + FSharpCompilationMappingAttributeOrDefault = TypeOrDefault("Microsoft.FSharp.Core.CompilationMappingAttribute"); + StaticCodecs = [.. new List { - compilation.GetSpecialType(SpecialType.System_Nullable_T), - Type("System.Tuple`1"), - Type("System.Tuple`2"), - Type("System.Tuple`3"), - Type("System.Tuple`4"), - Type("System.Tuple`5"), - Type("System.Tuple`6"), - Type("System.Tuple`7"), - Type("System.Tuple`8"), - Type("System.ValueTuple`1"), - Type("System.ValueTuple`2"), - Type("System.ValueTuple`3"), - Type("System.ValueTuple`4"), - Type("System.ValueTuple`5"), - Type("System.ValueTuple`6"), - Type("System.ValueTuple`7"), - Type("System.ValueTuple`8"), - Type("System.Collections.Immutable.ImmutableArray`1"), - Type("System.Collections.Immutable.ImmutableDictionary`2"), - Type("System.Collections.Immutable.ImmutableHashSet`1"), - Type("System.Collections.Immutable.ImmutableList`1"), - Type("System.Collections.Immutable.ImmutableQueue`1"), - Type("System.Collections.Immutable.ImmutableSortedDictionary`2"), - Type("System.Collections.Immutable.ImmutableSortedSet`1"), - Type("System.Collections.Immutable.ImmutableStack`1"), - }; - - LanguageVersion = (compilation.SyntaxTrees.FirstOrDefault()?.Options as CSharpParseOptions)?.LanguageVersion; - - INamedTypeSymbol Type(string metadataName) + new(compilation.GetSpecialType(SpecialType.System_Object), Type("Orleans.Serialization.Codecs.ObjectCodec")), + new(compilation.GetSpecialType(SpecialType.System_Boolean), Type("Orleans.Serialization.Codecs.BoolCodec")), + new(compilation.GetSpecialType(SpecialType.System_Char), Type("Orleans.Serialization.Codecs.CharCodec")), + new(compilation.GetSpecialType(SpecialType.System_Byte), Type("Orleans.Serialization.Codecs.ByteCodec")), + new(compilation.GetSpecialType(SpecialType.System_SByte), Type("Orleans.Serialization.Codecs.SByteCodec")), + new(compilation.GetSpecialType(SpecialType.System_Int16), Type("Orleans.Serialization.Codecs.Int16Codec")), + new(compilation.GetSpecialType(SpecialType.System_Int32), Type("Orleans.Serialization.Codecs.Int32Codec")), + new(compilation.GetSpecialType(SpecialType.System_Int64), Type("Orleans.Serialization.Codecs.Int64Codec")), + new(compilation.GetSpecialType(SpecialType.System_UInt16), Type("Orleans.Serialization.Codecs.UInt16Codec")), + new(compilation.GetSpecialType(SpecialType.System_UInt32), Type("Orleans.Serialization.Codecs.UInt32Codec")), + new(compilation.GetSpecialType(SpecialType.System_UInt64), Type("Orleans.Serialization.Codecs.UInt64Codec")), + new(compilation.GetSpecialType(SpecialType.System_String), Type("Orleans.Serialization.Codecs.StringCodec")), + new(compilation.CreateArrayTypeSymbol(compilation.GetSpecialType(SpecialType.System_Byte), 1), Type("Orleans.Serialization.Codecs.ByteArrayCodec")), + new(compilation.GetSpecialType(SpecialType.System_Single), Type("Orleans.Serialization.Codecs.FloatCodec")), + new(compilation.GetSpecialType(SpecialType.System_Double), Type("Orleans.Serialization.Codecs.DoubleCodec")), + new(compilation.GetSpecialType(SpecialType.System_Decimal), Type("Orleans.Serialization.Codecs.DecimalCodec")), + new(compilation.GetSpecialType(SpecialType.System_DateTime), Type("Orleans.Serialization.Codecs.DateTimeCodec")), + new(Type("System.TimeSpan"), Type("Orleans.Serialization.Codecs.TimeSpanCodec")), + new(Type("System.DateTimeOffset"), Type("Orleans.Serialization.Codecs.DateTimeOffsetCodec")), + new(TypeOrDefault("System.DateOnly"), TypeOrDefault("Orleans.Serialization.Codecs.DateOnlyCodec")), + new(TypeOrDefault("System.TimeOnly"), TypeOrDefault("Orleans.Serialization.Codecs.TimeOnlyCodec")), + new(Type("System.Guid"), Type("Orleans.Serialization.Codecs.GuidCodec")), + new(Type("System.Type"), Type("Orleans.Serialization.Codecs.TypeSerializerCodec")), + new(Type("System.ReadOnlyMemory`1").Construct(compilation.GetSpecialType(SpecialType.System_Byte)), Type("Orleans.Serialization.Codecs.ReadOnlyMemoryOfByteCodec")), + new(Type("System.Memory`1").Construct(compilation.GetSpecialType(SpecialType.System_Byte)), Type("Orleans.Serialization.Codecs.MemoryOfByteCodec")), + new(Type("System.Net.IPAddress"), Type("Orleans.Serialization.Codecs.IPAddressCodec")), + new(Type("System.Net.IPEndPoint"), Type("Orleans.Serialization.Codecs.IPEndPointCodec")), + new(TypeOrDefault("System.UInt128"), TypeOrDefault("Orleans.Serialization.Codecs.UInt128Codec")), + new(TypeOrDefault("System.Int128"), TypeOrDefault("Orleans.Serialization.Codecs.Int128Codec")), + new(TypeOrDefault("System.Half"), TypeOrDefault("Orleans.Serialization.Codecs.HalfCodec")), + new(Type("System.Uri"), Type("Orleans.Serialization.Codecs.UriCodec")), + }.Where(desc => desc.UnderlyingType is { } && desc.CodecType is { })]; + WellKnownCodecs = + [ + new(Type("System.Exception"), Type("Orleans.Serialization.ExceptionCodec")), + new(Type("System.Collections.Generic.Dictionary`2"), Type("Orleans.Serialization.Codecs.DictionaryCodec`2")), + new(Type("System.Collections.Generic.List`1"), Type("Orleans.Serialization.Codecs.ListCodec`1")), + new(Type("System.Collections.Generic.HashSet`1"), Type("Orleans.Serialization.Codecs.HashSetCodec`1")), + new(compilation.GetSpecialType(SpecialType.System_Nullable_T), Type("Orleans.Serialization.Codecs.NullableCodec`1")), + ]; + StaticCopiers = + [ + new(compilation.GetSpecialType(SpecialType.System_Object), Type("Orleans.Serialization.Codecs.ObjectCopier")), + new(compilation.CreateArrayTypeSymbol(compilation.GetSpecialType(SpecialType.System_Byte), 1), Type("Orleans.Serialization.Codecs.ByteArrayCopier")), + new(Type("System.ReadOnlyMemory`1").Construct(compilation.GetSpecialType(SpecialType.System_Byte)), Type("Orleans.Serialization.Codecs.ReadOnlyMemoryOfByteCopier")), + new(Type("System.Memory`1").Construct(compilation.GetSpecialType(SpecialType.System_Byte)), Type("Orleans.Serialization.Codecs.MemoryOfByteCopier")), + ]; + WellKnownCopiers = + [ + new(Type("System.Exception"), Type("Orleans.Serialization.ExceptionCodec")), + new(Type("System.Collections.Generic.Dictionary`2"), Type("Orleans.Serialization.Codecs.DictionaryCopier`2")), + new(Type("System.Collections.Generic.List`1"), Type("Orleans.Serialization.Codecs.ListCopier`1")), + new(Type("System.Collections.Generic.HashSet`1"), Type("Orleans.Serialization.Codecs.HashSetCopier`1")), + new(compilation.GetSpecialType(SpecialType.System_Nullable_T), Type("Orleans.Serialization.Codecs.NullableCopier`1")), + ]; + Exception = Type("System.Exception"); + ImmutableAttribute = Type(CodeGeneratorOptions.ImmutableAttribute); + TimeSpan = Type("System.TimeSpan"); + _ipAddress = Type("System.Net.IPAddress"); + _ipEndPoint = Type("System.Net.IPEndPoint"); + CancellationToken = Type("System.Threading.CancellationToken"); + CancellationTokenSource = Type("System.Threading.CancellationTokenSource"); + _immutableContainerTypes = + [ + compilation.GetSpecialType(SpecialType.System_Nullable_T), + Type("System.Tuple`1"), + Type("System.Tuple`2"), + Type("System.Tuple`3"), + Type("System.Tuple`4"), + Type("System.Tuple`5"), + Type("System.Tuple`6"), + Type("System.Tuple`7"), + Type("System.Tuple`8"), + Type("System.ValueTuple`1"), + Type("System.ValueTuple`2"), + Type("System.ValueTuple`3"), + Type("System.ValueTuple`4"), + Type("System.ValueTuple`5"), + Type("System.ValueTuple`6"), + Type("System.ValueTuple`7"), + Type("System.ValueTuple`8"), + Type("System.Collections.Immutable.ImmutableArray`1"), + Type("System.Collections.Immutable.ImmutableDictionary`2"), + Type("System.Collections.Immutable.ImmutableHashSet`1"), + Type("System.Collections.Immutable.ImmutableList`1"), + Type("System.Collections.Immutable.ImmutableQueue`1"), + Type("System.Collections.Immutable.ImmutableSortedDictionary`2"), + Type("System.Collections.Immutable.ImmutableSortedSet`1"), + Type("System.Collections.Immutable.ImmutableStack`1"), + ]; + + LanguageVersion = (compilation.SyntaxTrees.FirstOrDefault()?.Options as CSharpParseOptions)?.LanguageVersion; + + INamedTypeSymbol Type(string metadataName) + { + var result = compilation.GetTypeByMetadataName(metadataName); + if (result is null) { - var result = compilation.GetTypeByMetadataName(metadataName); - if (result is null) - { - throw new InvalidOperationException("Cannot find type with metadata name " + metadataName); - } - - return result; + throw new InvalidOperationException("Cannot find type with metadata name " + metadataName); } - INamedTypeSymbol? TypeOrDefault(string metadataName) - { - var result = compilation.GetTypeByMetadataName(metadataName); - return result; - } + return result; } - public INamedTypeSymbol Action_2 { get; private set; } - public INamedTypeSymbol TypeManifestProviderBase { get; private set; } - public INamedTypeSymbol Field { get; private set; } - public INamedTypeSymbol DeepCopier_1 { get; private set; } - public INamedTypeSymbol ShallowCopier { get; private set; } - public INamedTypeSymbol FieldCodec_1 { get; private set; } - public INamedTypeSymbol AbstractTypeSerializer { get; private set; } - public INamedTypeSymbol Func_2 { get; private set; } - public INamedTypeSymbol CompoundTypeAliasAttribute { get; private set; } - public INamedTypeSymbol GenerateMethodSerializersAttribute { get; private set; } - public INamedTypeSymbol GenerateSerializerAttribute { get; private set; } - public INamedTypeSymbol IActivator_1 { get; private set; } - public INamedTypeSymbol IBufferWriter { get; private set; } - public INamedTypeSymbol IInvokable { get; private set; } - public INamedTypeSymbol ITargetHolder { get; private set; } - public INamedTypeSymbol TypeManifestProviderAttribute { get; private set; } - public INamedTypeSymbol NonSerializedAttribute { get; private set; } - public INamedTypeSymbol ObsoleteAttribute { get; private set; } - public INamedTypeSymbol BaseCodec_1 { get; private set; } - public INamedTypeSymbol BaseCopier_1 { get; private set; } - public INamedTypeSymbol ArrayCodec { get; private set; } - public INamedTypeSymbol ArrayCopier { get; private set; } - public INamedTypeSymbol Reader { get; private set; } - public INamedTypeSymbol TypeManifestOptions { get; private set; } - public INamedTypeSymbol Task { get; private set; } - public INamedTypeSymbol Task_1 { get; private set; } - public INamedTypeSymbol Type { get; private set; } - private INamedTypeSymbol _uri; - private INamedTypeSymbol? _dateOnly; - private INamedTypeSymbol _dateTimeOffset; - private INamedTypeSymbol? _timeOnly; - public INamedTypeSymbol MethodInfo { get; private set; } - public INamedTypeSymbol ICodecProvider { get; private set; } - public INamedTypeSymbol ValueSerializer { get; private set; } - public INamedTypeSymbol ValueTask { get; private set; } - public INamedTypeSymbol ValueTask_1 { get; private set; } - public INamedTypeSymbol ValueTypeGetter_2 { get; private set; } - public INamedTypeSymbol ValueTypeSetter_2 { get; private set; } - public INamedTypeSymbol Writer { get; private set; } - public INamedTypeSymbol IdAttributeType { get; private set; } - public INamedTypeSymbol[] ConstructorAttributeTypes { get; private set; } - public INamedTypeSymbol AliasAttribute { get; private set; } - public WellKnownCodecDescription[] StaticCodecs { get; private set; } - public WellKnownCodecDescription[] WellKnownCodecs { get; private set; } - public WellKnownCopierDescription[] StaticCopiers { get; private set; } - public WellKnownCopierDescription[] WellKnownCopiers { get; private set; } - public INamedTypeSymbol RegisterCopierAttribute { get; private set; } - public INamedTypeSymbol RegisterSerializerAttribute { get; private set; } - public INamedTypeSymbol ResponseTimeoutAttribute { get; private set; } - public INamedTypeSymbol RegisterConverterAttribute { get; private set; } - public INamedTypeSymbol RegisterActivatorAttribute { get; private set; } - public INamedTypeSymbol UseActivatorAttribute { get; private set; } - public INamedTypeSymbol SuppressReferenceTrackingAttribute { get; private set; } - public INamedTypeSymbol OmitDefaultMemberValuesAttribute { get; private set; } - public INamedTypeSymbol CopyContext { get; private set; } - public INamedTypeSymbol CancellationToken { get; private set; } - public INamedTypeSymbol CancellationTokenSource { get; } - public INamedTypeSymbol Guid { get; private set; } - public Compilation Compilation { get; private set; } - public INamedTypeSymbol TimeSpan { get; private set; } - private INamedTypeSymbol _ipAddress; - private INamedTypeSymbol _ipEndPoint; - private INamedTypeSymbol[] _immutableContainerTypes; - private INamedTypeSymbol _bitVector32; - private INamedTypeSymbol _compareInfo; - private INamedTypeSymbol _cultureInfo; - private INamedTypeSymbol _version; - private INamedTypeSymbol? _int128; - private INamedTypeSymbol? _uInt128; - private INamedTypeSymbol? _half; - private INamedTypeSymbol[]? _regularShallowCopyableTypes; - private INamedTypeSymbol[] RegularShallowCopyableType => _regularShallowCopyableTypes ??= new List + INamedTypeSymbol? TypeOrDefault(string metadataName) { - TimeSpan, - _dateOnly, - _timeOnly, - _dateTimeOffset, - Guid, - _bitVector32, - _compareInfo, - _cultureInfo, - _version, - _ipAddress, - _ipEndPoint, - CancellationToken, - Type, - _uri, - _uInt128, - _int128, - _half - }.Where(t => t is {}).ToArray()!; - - public INamedTypeSymbol ImmutableAttribute { get; private set; } - public INamedTypeSymbol Exception { get; private set; } - public INamedTypeSymbol ApplicationPartAttribute { get; private set; } - public INamedTypeSymbol InvokeMethodNameAttribute { get; private set; } - public INamedTypeSymbol InvokableCustomInitializerAttribute { get; private set; } - public INamedTypeSymbol InvokableBaseTypeAttribute { get; private set; } - public INamedTypeSymbol ReturnValueProxyAttribute { get; private set; } - public INamedTypeSymbol DefaultInvokableBaseTypeAttribute { get; private set; } - public INamedTypeSymbol GenerateCodeForDeclaringAssemblyAttribute { get; private set; } - public INamedTypeSymbol SerializationCallbacksAttribute { get; private set; } - public INamedTypeSymbol GeneratedActivatorConstructorAttribute { get; private set; } - public INamedTypeSymbol SerializerTransparentAttribute { get; private set; } - public INamedTypeSymbol? FSharpCompilationMappingAttributeOrDefault { get; private set; } - public INamedTypeSymbol? FSharpSourceConstructFlagsOrDefault { get; private set; } - public INamedTypeSymbol RuntimeHelpers { get; private set; } - - public LanguageVersion? LanguageVersion { get; private set; } - - public bool IsShallowCopyable(ITypeSymbol type) + var result = compilation.GetTypeByMetadataName(metadataName); + return result; + } + } + + public INamedTypeSymbol Action_2 { get; private set; } + public INamedTypeSymbol TypeManifestProviderBase { get; private set; } + public INamedTypeSymbol Field { get; private set; } + public INamedTypeSymbol DeepCopier_1 { get; private set; } + public INamedTypeSymbol ShallowCopier { get; private set; } + public INamedTypeSymbol FieldCodec_1 { get; private set; } + public INamedTypeSymbol AbstractTypeSerializer { get; private set; } + public INamedTypeSymbol Func_2 { get; private set; } + public INamedTypeSymbol CompoundTypeAliasAttribute { get; private set; } + public INamedTypeSymbol GenerateMethodSerializersAttribute { get; private set; } + public INamedTypeSymbol GenerateSerializerAttribute { get; private set; } + public INamedTypeSymbol IActivator_1 { get; private set; } + public INamedTypeSymbol IBufferWriter { get; private set; } + public INamedTypeSymbol IInvokable { get; private set; } + public INamedTypeSymbol ITargetHolder { get; private set; } + public INamedTypeSymbol TypeManifestProviderAttribute { get; private set; } + public INamedTypeSymbol NonSerializedAttribute { get; private set; } + public INamedTypeSymbol ObsoleteAttribute { get; private set; } + public INamedTypeSymbol BaseCodec_1 { get; private set; } + public INamedTypeSymbol BaseCopier_1 { get; private set; } + public INamedTypeSymbol ArrayCodec { get; private set; } + public INamedTypeSymbol ArrayCopier { get; private set; } + public INamedTypeSymbol Reader { get; private set; } + public INamedTypeSymbol TypeManifestOptions { get; private set; } + public INamedTypeSymbol Task { get; private set; } + public INamedTypeSymbol Task_1 { get; private set; } + public INamedTypeSymbol Type { get; private set; } + private INamedTypeSymbol _uri; + private INamedTypeSymbol? _dateOnly; + private INamedTypeSymbol _dateTimeOffset; + private INamedTypeSymbol? _timeOnly; + public INamedTypeSymbol MethodInfo { get; private set; } + public INamedTypeSymbol ICodecProvider { get; private set; } + public INamedTypeSymbol ValueSerializer { get; private set; } + public INamedTypeSymbol ValueTask { get; private set; } + public INamedTypeSymbol ValueTask_1 { get; private set; } + public INamedTypeSymbol ValueTypeGetter_2 { get; private set; } + public INamedTypeSymbol ValueTypeSetter_2 { get; private set; } + public INamedTypeSymbol Writer { get; private set; } + public INamedTypeSymbol IdAttributeType { get; private set; } + public INamedTypeSymbol[] ConstructorAttributeTypes { get; private set; } + public INamedTypeSymbol AliasAttribute { get; private set; } + public WellKnownCodecDescription[] StaticCodecs { get; private set; } + public WellKnownCodecDescription[] WellKnownCodecs { get; private set; } + public WellKnownCopierDescription[] StaticCopiers { get; private set; } + public WellKnownCopierDescription[] WellKnownCopiers { get; private set; } + public INamedTypeSymbol RegisterCopierAttribute { get; private set; } + public INamedTypeSymbol RegisterSerializerAttribute { get; private set; } + public INamedTypeSymbol ResponseTimeoutAttribute { get; private set; } + public INamedTypeSymbol RegisterConverterAttribute { get; private set; } + public INamedTypeSymbol RegisterActivatorAttribute { get; private set; } + public INamedTypeSymbol UseActivatorAttribute { get; private set; } + public INamedTypeSymbol SuppressReferenceTrackingAttribute { get; private set; } + public INamedTypeSymbol OmitDefaultMemberValuesAttribute { get; private set; } + public INamedTypeSymbol CopyContext { get; private set; } + public INamedTypeSymbol CancellationToken { get; private set; } + public INamedTypeSymbol CancellationTokenSource { get; } + public INamedTypeSymbol Guid { get; private set; } + public Compilation Compilation { get; private set; } + public INamedTypeSymbol TimeSpan { get; private set; } + private INamedTypeSymbol _ipAddress; + private INamedTypeSymbol _ipEndPoint; + private INamedTypeSymbol[] _immutableContainerTypes; + private INamedTypeSymbol _bitVector32; + private INamedTypeSymbol _compareInfo; + private INamedTypeSymbol _cultureInfo; + private INamedTypeSymbol _version; + private INamedTypeSymbol? _int128; + private INamedTypeSymbol? _uInt128; + private INamedTypeSymbol? _half; + + private INamedTypeSymbol[] RegularShallowCopyableType => field ??= new List + { + TimeSpan, + _dateOnly, + _timeOnly, + _dateTimeOffset, + Guid, + _bitVector32, + _compareInfo, + _cultureInfo, + _version, + _ipAddress, + _ipEndPoint, + CancellationToken, + Type, + _uri, + _uInt128, + _int128, + _half + }.OfType().ToArray(); + + public INamedTypeSymbol ImmutableAttribute { get; private set; } + public INamedTypeSymbol Exception { get; private set; } + public INamedTypeSymbol ApplicationPartAttribute { get; private set; } + public INamedTypeSymbol InvokeMethodNameAttribute { get; private set; } + public INamedTypeSymbol InvokableCustomInitializerAttribute { get; private set; } + public INamedTypeSymbol InvokableBaseTypeAttribute { get; private set; } + public INamedTypeSymbol ReturnValueProxyAttribute { get; private set; } + public INamedTypeSymbol DefaultInvokableBaseTypeAttribute { get; private set; } + public INamedTypeSymbol GenerateCodeForDeclaringAssemblyAttribute { get; private set; } + public INamedTypeSymbol SerializationCallbacksAttribute { get; private set; } + public INamedTypeSymbol GeneratedActivatorConstructorAttribute { get; private set; } + public INamedTypeSymbol SerializerTransparentAttribute { get; private set; } + public INamedTypeSymbol? FSharpCompilationMappingAttributeOrDefault { get; private set; } + public INamedTypeSymbol? FSharpSourceConstructFlagsOrDefault { get; private set; } + public INamedTypeSymbol RuntimeHelpers { get; private set; } + + public LanguageVersion? LanguageVersion { get; private set; } + + public bool IsShallowCopyable(ITypeSymbol type) + { + switch (type.SpecialType) { - switch (type.SpecialType) - { - case SpecialType.System_Boolean: - case SpecialType.System_Char: - case SpecialType.System_SByte: - case SpecialType.System_Byte: - case SpecialType.System_Int16: - case SpecialType.System_UInt16: - case SpecialType.System_Int32: - case SpecialType.System_UInt32: - case SpecialType.System_Int64: - case SpecialType.System_UInt64: - case SpecialType.System_Decimal: - case SpecialType.System_Single: - case SpecialType.System_Double: - case SpecialType.System_String: - case SpecialType.System_DateTime: - return true; - } + case SpecialType.System_Boolean: + case SpecialType.System_Char: + case SpecialType.System_SByte: + case SpecialType.System_Byte: + case SpecialType.System_Int16: + case SpecialType.System_UInt16: + case SpecialType.System_Int32: + case SpecialType.System_UInt32: + case SpecialType.System_Int64: + case SpecialType.System_UInt64: + case SpecialType.System_Decimal: + case SpecialType.System_Single: + case SpecialType.System_Double: + case SpecialType.System_String: + case SpecialType.System_DateTime: + return true; + } - if (_shallowCopyableTypes.TryGetValue(type, out var result)) - { - return result; - } + if (_shallowCopyableTypes.TryGetValue(type, out var result)) + { + return result; + } - foreach (var shallowCopyable in RegularShallowCopyableType) + foreach (var shallowCopyable in RegularShallowCopyableType) + { + if (SymbolEqualityComparer.Default.Equals(shallowCopyable, type)) { - if (SymbolEqualityComparer.Default.Equals(shallowCopyable, type)) - { - return _shallowCopyableTypes[type] = true; - } + return _shallowCopyableTypes[type] = true; } + } + + if (type.IsSealed && type.HasAttribute(ImmutableAttribute)) + { + return _shallowCopyableTypes[type] = true; + } - if (type.IsSealed && type.HasAttribute(ImmutableAttribute)) + if (type.HasBaseType(Exception)) + { + return _shallowCopyableTypes[type] = true; + } + + if (!(type is INamedTypeSymbol namedType)) + { + return _shallowCopyableTypes[type] = false; + } + + if (namedType.IsTupleType) + { + return _shallowCopyableTypes[type] = AreShallowCopyable(namedType.TupleElements); + } + else if (namedType.IsGenericType) + { + var def = namedType.ConstructedFrom; + foreach (var t in _immutableContainerTypes) { - return _shallowCopyableTypes[type] = true; + if (SymbolEqualityComparer.Default.Equals(t, def)) + return _shallowCopyableTypes[type] = AreShallowCopyable(namedType.TypeArguments); } - - if (type.HasBaseType(Exception)) + } + else + { + if (type.TypeKind == TypeKind.Enum) { return _shallowCopyableTypes[type] = true; } - if (!(type is INamedTypeSymbol namedType)) + if (type.TypeKind == TypeKind.Struct && !namedType.IsUnboundGenericType) { - return _shallowCopyableTypes[type] = false; + return _shallowCopyableTypes[type] = IsValueTypeFieldsShallowCopyable(type); } + } - if (namedType.IsTupleType) - { - return _shallowCopyableTypes[type] = AreShallowCopyable(namedType.TupleElements); - } - else if (namedType.IsGenericType) + return _shallowCopyableTypes[type] = false; + } + + private bool IsValueTypeFieldsShallowCopyable(ITypeSymbol type) + { + foreach (var field in type.GetDeclaredInstanceMembers()) + { + if (field.Type is not INamedTypeSymbol fieldType) { - var def = namedType.ConstructedFrom; - foreach (var t in _immutableContainerTypes) - { - if (SymbolEqualityComparer.Default.Equals(t, def)) - return _shallowCopyableTypes[type] = AreShallowCopyable(namedType.TypeArguments); - } + return false; } - else + + if (SymbolEqualityComparer.Default.Equals(type, fieldType)) { - if (type.TypeKind == TypeKind.Enum) - { - return _shallowCopyableTypes[type] = true; - } - - if (type.TypeKind == TypeKind.Struct && !namedType.IsUnboundGenericType) - { - return _shallowCopyableTypes[type] = IsValueTypeFieldsShallowCopyable(type); - } + return false; } - return _shallowCopyableTypes[type] = false; - } - - private bool IsValueTypeFieldsShallowCopyable(ITypeSymbol type) - { - foreach (var field in type.GetDeclaredInstanceMembers()) + if (!IsShallowCopyable(fieldType)) { - if (field.Type is not INamedTypeSymbol fieldType) - { - return false; - } - - if (SymbolEqualityComparer.Default.Equals(type, fieldType)) - { - return false; - } - - if (!IsShallowCopyable(fieldType)) - { - return false; - } + return false; } - - return true; } - private bool AreShallowCopyable(ImmutableArray types) - { - foreach (var t in types) - if (!IsShallowCopyable(t)) - return false; - - return true; - } + return true; + } - private bool AreShallowCopyable(ImmutableArray fields) - { - foreach (var f in fields) - if (!IsShallowCopyable(f.Type)) - return false; + private bool AreShallowCopyable(ImmutableArray types) + { + foreach (var t in types) + if (!IsShallowCopyable(t)) + return false; - return true; - } + return true; } - internal static class LibraryExtensions + private bool AreShallowCopyable(ImmutableArray fields) { - public static WellKnownCodecDescription? FindByUnderlyingType(this WellKnownCodecDescription[] values, ISymbol type) - { - foreach (var c in values) - if (SymbolEqualityComparer.Default.Equals(c.UnderlyingType, type)) - return c; + foreach (var f in fields) + if (!IsShallowCopyable(f.Type)) + return false; - return null; - } + return true; + } +} - public static WellKnownCopierDescription? FindByUnderlyingType(this WellKnownCopierDescription[] values, ISymbol type) - { - foreach (var c in values) - if (SymbolEqualityComparer.Default.Equals(c.UnderlyingType, type)) - return c; +internal static class LibraryExtensions +{ + public static WellKnownCodecDescription? FindByUnderlyingType(this WellKnownCodecDescription[] values, ISymbol type) + { + foreach (var c in values) + if (SymbolEqualityComparer.Default.Equals(c.UnderlyingType, type)) + return c; - return null; - } + return null; + } + + public static WellKnownCopierDescription? FindByUnderlyingType(this WellKnownCopierDescription[] values, ISymbol type) + { + foreach (var c in values) + if (SymbolEqualityComparer.Default.Equals(c.UnderlyingType, type)) + return c; - public static bool HasScopedKeyword(this LibraryTypes libraryTypes) => libraryTypes.LanguageVersion is null or >= LanguageVersion.CSharp11; + return null; } + + public static bool HasScopedKeyword(this LibraryTypes libraryTypes) => libraryTypes.LanguageVersion is null or >= LanguageVersion.CSharp11; } diff --git a/src/Orleans.CodeGenerator/MetadataGenerator.cs b/src/Orleans.CodeGenerator/MetadataGenerator.cs index a01df87cd10..f6ea25c4574 100644 --- a/src/Orleans.CodeGenerator/MetadataGenerator.cs +++ b/src/Orleans.CodeGenerator/MetadataGenerator.cs @@ -1,242 +1,658 @@ -using System.Collections.Generic; +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Orleans.CodeGenerator.Model; using Orleans.CodeGenerator.SyntaxGeneration; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -using System; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal class MetadataGenerator(MetadataAggregateModel metadataModel, string assemblyName) { - #nullable disable - internal class MetadataGenerator - { - private readonly CodeGenerator _codeGenerator; + private static readonly TypeSyntax TypeManifestOptionsType = ParseTypeName("global::Orleans.Serialization.Configuration.TypeManifestOptions"); + private static readonly TypeSyntax TypeManifestProviderBaseType = ParseTypeName("global::Orleans.Serialization.Configuration.TypeManifestProviderBase"); + + private readonly MetadataAggregateModel _metadataModel = metadataModel; + private readonly string _assemblyName = assemblyName ?? "Assembly"; - public MetadataGenerator(CodeGenerator codeGenerator) + public ClassDeclarationSyntax GenerateMetadata() + => GenerateIncrementalMetadata(); + + private ClassDeclarationSyntax GenerateIncrementalMetadata() + { + var configParam = "config".ToIdentifierName(); + var body = new List(); + var model = _metadataModel; + var orderedProxyInterfaces = GetOrderedProxyInterfaces(model.ProxyInterfaces); + var generatedInvokableActivatorMetadataNames = new HashSet( + model.GeneratedInvokableActivatorMetadataNames, + StringComparer.Ordinal); + var generatedInvokables = GetGeneratedInvokableMetadata(orderedProxyInterfaces, generatedInvokableActivatorMetadataNames); + var serializableRegistrations = GetOrderedSerializableRegistrations(model, generatedInvokables); + + var addSerializerMethod = configParam.Member("Serializers").Member("Add"); + foreach (var registration in serializableRegistrations) { - _codeGenerator = codeGenerator; + AddRegistration(body, addSerializerMethod, registration.SerializerTypeSyntax); } - private MetadataModel MetadataModel => _codeGenerator.MetadataModel; - - public ClassDeclarationSyntax GenerateMetadata() + foreach (var type in model.RegisteredCodecs.Where(static codec => codec.Kind == RegisteredCodecKind.Serializer)) { - var configParam = "config".ToIdentifierName(); - var addSerializerMethod = configParam.Member("Serializers").Member("Add"); - var addCopierMethod = configParam.Member("Copiers").Member("Add"); - var addConverterMethod = configParam.Member("Converters").Member("Add"); - var body = new List(); + AddRegistration(body, addSerializerMethod, GetOpenTypeSyntax(type.Type)); + } - foreach (var type in MetadataModel.SerializableTypes) + var addCopierMethod = configParam.Member("Copiers").Member("Add"); + foreach (var registration in serializableRegistrations) + { + if (registration.CopierTypeSyntax is not null) { - body.Add(ExpressionStatement(InvocationExpression(addSerializerMethod, - ArgumentList(SingletonSeparatedList(Argument(TypeOfExpression(GetCodecTypeName(type)))))))); + AddRegistration(body, addCopierMethod, registration.CopierTypeSyntax); } + } - foreach (var type in MetadataModel.SerializableTypes) - { - if (type.IsEnumType) continue; + foreach (var type in model.RegisteredCodecs.Where(static codec => codec.Kind == RegisteredCodecKind.Copier)) + { + AddRegistration(body, addCopierMethod, GetOpenTypeSyntax(type.Type)); + } - if (!MetadataModel.DefaultCopiers.TryGetValue(type, out var typeName)) - typeName = GetCopierTypeName(type); + var addConverterMethod = configParam.Member("Converters").Member("Add"); + foreach (var type in model.RegisteredCodecs.Where(static codec => codec.Kind == RegisteredCodecKind.Converter)) + { + AddRegistration(body, addConverterMethod, GetOpenTypeSyntax(type.Type)); + } - body.Add(ExpressionStatement(InvocationExpression(addCopierMethod, - ArgumentList(SingletonSeparatedList(Argument(TypeOfExpression(typeName))))))); - } + var addProxyMethod = configParam.Member("InterfaceProxies").Member("Add"); + foreach (var type in orderedProxyInterfaces) + { + AddRegistration(body, addProxyMethod, GetGeneratedProxyTypeSyntax(type)); + } - foreach (var type in MetadataModel.DetectedCopiers) - { - body.Add(ExpressionStatement(InvocationExpression(addCopierMethod, - ArgumentList(SingletonSeparatedList(Argument(TypeOfExpression(type.ToOpenTypeSyntax()))))))); - } + var addInvokableInterfaceMethod = configParam.Member("Interfaces").Member("Add"); + foreach (var type in orderedProxyInterfaces.Select(static proxy => proxy.InterfaceType).Distinct()) + { + AddRegistration(body, addInvokableInterfaceMethod, GetOpenTypeSyntax(type)); + } + + var addInvokableInterfaceImplementationMethod = configParam.Member("InterfaceImplementations").Member("Add"); + foreach (var type in GetOrderedInterfaceImplementations(model.InterfaceImplementations)) + { + AddRegistration(body, addInvokableInterfaceImplementationMethod, GetOpenTypeSyntax(type.ImplementationType)); + } - foreach (var type in MetadataModel.DetectedSerializers) + var addActivatorMethod = configParam.Member("Activators").Member("Add"); + foreach (var registration in serializableRegistrations) + { + if (registration.ActivatorTypeSyntax is not null) { - body.Add(ExpressionStatement(InvocationExpression(addSerializerMethod, - ArgumentList(SingletonSeparatedList(Argument(TypeOfExpression(type.ToOpenTypeSyntax()))))))); + AddRegistration(body, addActivatorMethod, registration.ActivatorTypeSyntax); } + } + + foreach (var type in model.RegisteredCodecs.Where(static codec => codec.Kind == RegisteredCodecKind.Activator)) + { + AddRegistration(body, addActivatorMethod, GetOpenTypeSyntax(type.Type)); + } + + var addWellKnownTypeIdMethod = configParam.Member("WellKnownTypeIds").Member("Add"); + foreach (var type in model.ReferenceAssemblyData.WellKnownTypeIds) + { + body.Add(ExpressionStatement(InvocationExpression(addWellKnownTypeIdMethod, + ArgumentList(SeparatedList( + [ + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(type.Id))), + Argument(CreateTypeOfExpression(type.Type)), + ]))))); + } + + var addTypeAliasMethod = configParam.Member("WellKnownTypeAliases").Member("Add"); + foreach (var type in model.ReferenceAssemblyData.TypeAliases) + { + body.Add(ExpressionStatement(InvocationExpression(addTypeAliasMethod, + ArgumentList(SeparatedList( + [ + Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(type.Alias))), + Argument(CreateTypeOfExpression(type.Type)), + ]))))); + } - foreach (var type in MetadataModel.DetectedConverters) + AddCompoundTypeAliases(configParam, body, generatedInvokables); + return CreateMetadataClass(body, configParam); + } + + private ClassDeclarationSyntax CreateMetadataClass(List body, IdentifierNameSyntax configParam) + { + var configureMethod = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "ConfigureInner") + .AddModifiers(Token(SyntaxKind.ProtectedKeyword), Token(SyntaxKind.OverrideKeyword)) + .AddParameterListParameters( + Parameter(configParam.Identifier).WithType(TypeManifestOptionsType)) + .AddBodyStatements([.. body]); + + return ClassDeclaration("Metadata_" + SyntaxGeneration.Identifier.SanitizeIdentifierName(_assemblyName)) + .AddBaseListTypes(SimpleBaseType(TypeManifestProviderBaseType)) + .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.SealedKeyword)) + .AddAttributeLists(GeneratedCodeUtilities.GetGeneratedCodeAttributes()) + .AddMembers(configureMethod); + } + + private void AddCompoundTypeAliases( + IdentifierNameSyntax configParam, + List body, + ImmutableArray generatedInvokables) + { + var aliases = _metadataModel.ReferenceAssemblyData.CompoundTypeAliases + .OrderBy(static entry => entry.Components.Length) + .ThenBy(static entry => entry.TargetType.SyntaxString, StringComparer.Ordinal) + .ToImmutableArray(); + + var generatedAliases = generatedInvokables + .SelectMany(static invokable => invokable.Aliases) + .ToImmutableArray(); + + var aliasTree = CompoundTypeAliasEmissionTree.Create(); + foreach (var alias in aliases.Concat(generatedAliases)) + { + aliasTree.Add(alias.Components, alias.TargetType); + } + + var nodeId = 0; + AddCompoundTypeAliases(body, configParam.Member("CompoundTypeAliases"), aliasTree, ref nodeId); + } + + private void AddCompoundTypeAliases( + List body, + ExpressionSyntax tree, + CompoundTypeAliasEmissionTree aliases, + ref int nodeId) + { + ExpressionSyntax node; + if (!aliases.HasKey) + { + node = tree; + } + else + { + var nodeName = IdentifierName($"n{++nodeId}"); + node = nodeName; + var addArguments = new List(2) { CreateCompoundAliasArgument(aliases.Key) }; + if (aliases.HasValue) { - body.Add(ExpressionStatement(InvocationExpression(addConverterMethod, - ArgumentList(SingletonSeparatedList(Argument(TypeOfExpression(type.ToOpenTypeSyntax()))))))); + addArguments.Add(Argument(CreateTypeOfExpression(aliases.Value))); } - var addProxyMethod = configParam.Member("InterfaceProxies").Member("Add"); - foreach (var type in MetadataModel.GeneratedProxies) + if (aliases.Children is { Count: > 0 }) { - body.Add(ExpressionStatement(InvocationExpression(addProxyMethod, - ArgumentList(SingletonSeparatedList(Argument(TypeOfExpression(type.TypeSyntax))))))); + body.Add(LocalDeclarationStatement(VariableDeclaration( + ParseTypeName("var"), + SingletonSeparatedList(VariableDeclarator(nodeName.Identifier).WithInitializer(EqualsValueClause(InvocationExpression( + tree.Member("Add"), + ArgumentList(SeparatedList(addArguments))))))))); } - - var addInvokableInterfaceMethod = configParam.Member("Interfaces").Member("Add"); - foreach (var type in MetadataModel.InvokableInterfaces.Values) + else { - body.Add(ExpressionStatement(InvocationExpression(addInvokableInterfaceMethod, - ArgumentList(SingletonSeparatedList(Argument(TypeOfExpression(type.InterfaceType.ToOpenTypeSyntax()))))))); + body.Add(ExpressionStatement(InvocationExpression(tree.Member("Add"), ArgumentList(SeparatedList(addArguments))))); } + } - var addInvokableInterfaceImplementationMethod = configParam.Member("InterfaceImplementations").Member("Add"); - foreach (var type in MetadataModel.InvokableInterfaceImplementations) + if (aliases.Children is { Count: > 0 }) + { + foreach (var child in aliases.Children.Values) { - body.Add(ExpressionStatement(InvocationExpression(addInvokableInterfaceImplementationMethod, - ArgumentList(SingletonSeparatedList(Argument(TypeOfExpression(type.ToOpenTypeSyntax()))))))); + AddCompoundTypeAliases(body, node, child, ref nodeId); } + } + } + + private static ArgumentSyntax CreateCompoundAliasArgument(CompoundAliasComponentModel component) + => component.IsType + ? Argument(CreateTypeOfExpression(component.TypeValue)) + : Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(component.StringValue!))); + + private static ImmutableArray GetOrderedProxyInterfaces(ImmutableArray proxyInterfaces) + => OrderBySourceLocation(proxyInterfaces, static proxy => proxy.SourceLocation, static proxy => proxy.InterfaceType.SyntaxString); - var addActivatorMethod = configParam.Member("Activators").Member("Add"); - foreach (var type in MetadataModel.ActivatableTypes) + private static ImmutableArray GetOrderedInterfaceImplementations(ImmutableArray interfaceImplementations) + => OrderBySourceLocation(interfaceImplementations, static implementation => implementation.SourceLocation, static implementation => implementation.ImplementationType.SyntaxString); + + private static ImmutableArray OrderBySourceLocation( + ImmutableArray entries, + Func locationSelector, + Func sortKeySelector) + { + return [.. entries + .Select(entry => { - body.Add(ExpressionStatement(InvocationExpression(addActivatorMethod, - ArgumentList(SingletonSeparatedList(Argument(TypeOfExpression(GetActivatorTypeName(type)))))))); - } + var sourceLocation = locationSelector(entry); + return ( + Entry: entry, + sourceLocation.SourceOrderGroup, + sourceLocation.FilePath, + sourceLocation.Position, + SortKey: sortKeySelector(entry)); + }) + .OrderBy(static entry => entry.SourceOrderGroup) + .ThenBy(static entry => entry.FilePath, StringComparer.Ordinal) + .ThenBy(static entry => entry.Position) + .ThenBy(static entry => entry.SortKey, StringComparer.Ordinal) + .Select(static entry => entry.Entry)]; + } + + private ImmutableArray GetGeneratedInvokableMetadata( + ImmutableArray proxyInterfaces, + HashSet generatedInvokableActivatorMetadataNames) + { + var result = ImmutableArray.CreateBuilder(); + var seen = new HashSet(StringComparer.Ordinal); - foreach (var type in MetadataModel.DetectedActivators) + foreach (var proxy in proxyInterfaces) + { + foreach (var method in proxy.Methods) { - body.Add(ExpressionStatement(InvocationExpression(addActivatorMethod, - ArgumentList(SingletonSeparatedList(Argument(TypeOfExpression(type.ToOpenTypeSyntax()))))))); + var metadata = CreateGeneratedInvokableMetadata(proxy, method, generatedInvokableActivatorMetadataNames); + var key = metadata.TypeSyntax.ToString(); + if (seen.Add(key)) + { + result.Add(metadata); + } } + } + + return result.ToImmutable(); + } + + private static ImmutableArray GetOrderedSerializableRegistrations( + MetadataAggregateModel model, + ImmutableArray generatedInvokables) + { + var registrations = new List(model.SerializableTypes.Length + generatedInvokables.Length); - var addWellKnownTypeIdMethod = configParam.Member("WellKnownTypeIds").Member("Add"); - foreach (var type in MetadataModel.WellKnownTypeIds) + foreach (var type in model.SerializableTypes) + { + TypeSyntax? copierType = null; + if (!type.IsEnumType) { - body.Add(ExpressionStatement(InvocationExpression(addWellKnownTypeIdMethod, - ArgumentList(SeparatedList(new[] { Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(type.Id))), Argument(TypeOfExpression(type.Type)) }))))); + copierType = TryGetDefaultCopierType(type.TypeSyntax, model.DefaultCopiers) + ?? GetCopierTypeName(type.GeneratedNamespace, type.Name, type.TypeParameters.Length); } - var addTypeAliasMethod = configParam.Member("WellKnownTypeAliases").Member("Add"); - foreach (var type in MetadataModel.TypeAliases) + var activatorType = ShouldGenerateActivator(type) + ? GetActivatorTypeName(type.GeneratedNamespace, type.Name, type.TypeParameters.Length) + : null; + + registrations.Add(new SerializableMetadataRegistration( + sourceType: type.TypeSyntax, + sortKey: type.TypeSyntax.SyntaxString, + serializerTypeSyntax: GetCodecTypeName(type.GeneratedNamespace, type.Name, type.TypeParameters.Length), + copierTypeSyntax: copierType, + activatorTypeSyntax: activatorType, + sourceLocation: type.SourceLocation)); + } + + foreach (var type in generatedInvokables) + { + registrations.Add(new SerializableMetadataRegistration( + sourceType: type.SourceType, + sortKey: type.TypeSyntax.ToString(), + serializerTypeSyntax: type.CodecTypeSyntax, + copierTypeSyntax: type.CopierTypeSyntax, + activatorTypeSyntax: type.ActivatorTypeSyntax, + sourceLocation: type.SourceLocation)); + } + + return [.. registrations + .Select(registration => { - body.Add(ExpressionStatement(InvocationExpression(addTypeAliasMethod, - ArgumentList(SeparatedList(new[] { Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(type.Alias))), Argument(TypeOfExpression(type.Type)) }))))); - } + return ( + Registration: registration, + registration.SourceLocation.SourceOrderGroup, + registration.SourceLocation.FilePath, + registration.SourceLocation.Position); + }) + .OrderBy(static entry => entry.SourceOrderGroup) + .ThenBy(static entry => entry.FilePath, StringComparer.Ordinal) + .ThenBy(static entry => entry.Position) + .ThenBy(static entry => entry.Registration.SortKey, StringComparer.Ordinal) + .Select(static entry => entry.Registration)]; + } - AddCompoundTypeAliases(configParam, body); + private static GeneratedInvokableMetadata CreateGeneratedInvokableMetadata( + ProxyInterfaceModel proxy, + MethodModel method, + HashSet generatedInvokableActivatorMetadataNames) + { + var generatedNamespace = method.ContainingInterfaceGeneratedNamespace; + var name = GetGeneratedInvokableClassName(proxy, method); + var genericArity = method.ContainingInterfaceTypeParameterCount + method.TypeParameters.Length; + var typeSyntax = CreateGeneratedTypeSyntax(generatedNamespace, name, genericArity); + var targetType = new TypeRef(typeSyntax.ToString()); + var aliases = CreateGeneratedInvokableAliases(proxy, method, targetType); + var codecTypeSyntax = GetCodecTypeName(generatedNamespace, name, genericArity); + var copierTypeSyntax = GetCopierTypeName(generatedNamespace, name, genericArity); + var metadataName = GetGeneratedInvokableMetadataName(generatedNamespace, name, genericArity); + var activatorTypeSyntax = generatedInvokableActivatorMetadataNames.Contains(metadataName) + ? GetActivatorTypeName(generatedNamespace, name, genericArity) + : null; + + return new GeneratedInvokableMetadata( + typeSyntax, + method.ContainingInterfaceType, + aliases, + codecTypeSyntax, + copierTypeSyntax, + activatorTypeSyntax, + proxy.SourceLocation); + } + + private static string GetGeneratedInvokableMetadataName(string generatedNamespace, string name, int genericArity) + => genericArity == 0 + ? $"{generatedNamespace}.{name}" + : $"{generatedNamespace}.{name}`{genericArity}"; - var configType = _codeGenerator.LibraryTypes.TypeManifestOptions; - var configureMethod = MethodDeclaration(PredefinedType(Token(SyntaxKind.VoidKeyword)), "ConfigureInner") - .AddModifiers(Token(SyntaxKind.ProtectedKeyword), Token(SyntaxKind.OverrideKeyword)) - .AddParameterListParameters( - Parameter(configParam.Identifier).WithType(configType.ToTypeSyntax())) - .AddBodyStatements(body.ToArray()); + private static string GetGeneratedInvokableClassName(ProxyInterfaceModel proxy, MethodModel method) + { + var genericArity = method.ContainingInterfaceTypeParameterCount + method.TypeParameters.Length; + var typeArgs = genericArity > 0 ? "_" + genericArity : string.Empty; + return $"Invokable_{method.ContainingInterfaceName}_{proxy.ProxyBase.GeneratedClassNameComponent}_{method.GeneratedMethodId}{typeArgs}"; + } - var interfaceType = _codeGenerator.LibraryTypes.TypeManifestProviderBase; - return ClassDeclaration("Metadata_" + SyntaxGeneration.Identifier.SanitizeIdentifierName(_codeGenerator.Compilation.AssemblyName)) - .AddBaseListTypes(SimpleBaseType(interfaceType.ToTypeSyntax())) - .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.SealedKeyword)) - .AddAttributeLists(CodeGenerator.GetGeneratedCodeAttributes()) - .AddMembers(configureMethod); + private static ImmutableArray CreateGeneratedInvokableAliases( + ProxyInterfaceModel proxy, + MethodModel method, + TypeRef targetType) + { + var result = ImmutableArray.CreateBuilder(2); + if (!string.Equals(method.MethodId, method.GeneratedMethodId, StringComparison.Ordinal)) + { + result.Add(CreateGeneratedInvokableAlias(proxy, method, method.MethodId, targetType)); } - private void AddCompoundTypeAliases(IdentifierNameSyntax configParam, List body) + result.Add(CreateGeneratedInvokableAlias(proxy, method, method.GeneratedMethodId, targetType)); + return result.ToImmutable(); + } + + private static CompoundTypeAliasModel CreateGeneratedInvokableAlias( + ProxyInterfaceModel proxy, + MethodModel method, + string methodId, + TypeRef targetType) + { + var components = ImmutableArray.CreateBuilder(proxy.ProxyBase.IsExtension ? 6 : 4); + components.Add(new CompoundAliasComponentModel("inv")); + components.Add(new CompoundAliasComponentModel(proxy.ProxyBase.ProxyBaseType)); + if (proxy.ProxyBase.IsExtension) { - // The goal is to emit a tree describing all of the generated invokers in the form: - // ("inv", typeof(ProxyBaseType), typeof(ContainingInterface), "") - // The first step is to collate the invokers into tree to ease the process of generating a tree in code. - var nodeId = 0; - AddCompoundTypeAliases(body, configParam.Member("CompoundTypeAliases"), MetadataModel.CompoundTypeAliases); - void AddCompoundTypeAliases(List body, ExpressionSyntax tree, CompoundTypeAliasTree aliases) - { - ExpressionSyntax node; + components.Add(new CompoundAliasComponentModel("Ext")); + } - if (aliases.Key.IsDefault) - { - // At the root node, do not create a new node, just enumerate over the child nodes. - node = tree; - } - else - { - var nodeName = IdentifierName($"n{++nodeId}"); - node = nodeName; - var valueExpression = aliases.Value switch - { - { } type => Argument(TypeOfExpression(type)), - _ => null - }; + components.Add(new CompoundAliasComponentModel(proxy.ProxyBase.IsExtension ? proxy.InterfaceType : method.ContainingInterfaceType)); - // Get the arguments for the Add call - var addArguments = aliases.Key switch - { - { IsType: true } typeKey => valueExpression switch - { - // Call the two-argument Add overload to add a key and value. - { } argument => new[] { Argument(TypeOfExpression(typeKey.TypeValue.ToOpenTypeSyntax())), argument }, - - // Call the one-argument Add overload to add only a key. - _ => new[] { Argument(TypeOfExpression(typeKey.TypeValue.ToOpenTypeSyntax())) }, - }, - { IsString: true } stringKey => valueExpression switch - { - // Call the two-argument Add overload to add a key and value. - { } argument => new[] { Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(stringKey.StringValue))), argument }, - - // Call the one-argument Add overload to add only a key. - _ => new[] { Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(stringKey.StringValue))) }, - }, - _ => throw new InvalidOperationException("Unexpected alias key") - }; - - if (aliases.Children is { Count: > 0 }) - { - // C#: var {newTree.Identifier} = {tree}.Add({addArguments}); - body.Add(LocalDeclarationStatement(VariableDeclaration( - ParseTypeName("var"), - SingletonSeparatedList(VariableDeclarator(nodeName.Identifier).WithInitializer(EqualsValueClause(InvocationExpression( - tree.Member("Add"), - ArgumentList(SeparatedList(addArguments))))))))); - } - else - { - // Do not emit a variable. - // C#: {tree}.Add({addArguments}); - body.Add(ExpressionStatement(InvocationExpression(tree.Member("Add"), ArgumentList(SeparatedList(addArguments))))); - } - } + if (proxy.ProxyBase.IsExtension) + { + components.Add(new CompoundAliasComponentModel(method.OriginalContainingInterfaceType)); + } - if (aliases.Children is { Count: > 0 }) - { - foreach (var child in aliases.Children.Values) - { - AddCompoundTypeAliases(body, node, child); - } - } + components.Add(new CompoundAliasComponentModel(methodId)); + return new CompoundTypeAliasModel(components.MoveToImmutable(), targetType); + } + + private static TypeSyntax GetOpenTypeSyntax(TypeRef typeRef) + { + var syntax = typeRef.SyntaxString.Trim(); + if (syntax.StartsWith("global::", StringComparison.Ordinal)) + { + syntax = syntax.Substring("global::".Length); + } + + var typeSyntax = syntax switch + { + "bool" or "System.Boolean" => ParseName("bool"), + "byte" or "System.Byte" => ParseName("byte"), + "sbyte" or "System.SByte" => ParseName("sbyte"), + "short" or "System.Int16" => ParseName("short"), + "ushort" or "System.UInt16" => ParseName("ushort"), + "int" or "System.Int32" => ParseName("int"), + "uint" or "System.UInt32" => ParseName("uint"), + "long" or "System.Int64" => ParseName("long"), + "ulong" or "System.UInt64" => ParseName("ulong"), + "float" or "System.Single" => ParseName("float"), + "double" or "System.Double" => ParseName("double"), + "decimal" or "System.Decimal" => ParseName("decimal"), + "char" or "System.Char" => ParseName("char"), + "string" or "System.String" => ParseName("string"), + "object" or "System.Object" => ParseName("object"), + _ => typeRef.ToTypeSyntax(), + }; + + return (TypeSyntax)OpenGenericTypeSyntaxRewriter.Instance.Visit(typeSyntax); + } + + private static TypeOfExpressionSyntax CreateTypeOfExpression(TypeRef typeRef) + { + var result = TypeOfExpression(GetOpenTypeSyntax(typeRef)); + return IsPredefinedTypeRef(typeRef) + ? result + .WithOpenParenToken(Token(TriviaList(Space), SyntaxKind.OpenParenToken, TriviaList(Space))) + .WithCloseParenToken(Token(TriviaList(Space), SyntaxKind.CloseParenToken, TriviaList(Space))) + : result; + } + + private static bool IsPredefinedTypeRef(TypeRef typeRef) + { + var syntax = typeRef.SyntaxString.Trim(); + if (syntax.StartsWith("global::", StringComparison.Ordinal)) + { + syntax = syntax.Substring("global::".Length); + } + + return syntax is "bool" or "System.Boolean" + or "byte" or "System.Byte" + or "sbyte" or "System.SByte" + or "short" or "System.Int16" + or "ushort" or "System.UInt16" + or "int" or "System.Int32" + or "uint" or "System.UInt32" + or "long" or "System.Int64" + or "ulong" or "System.UInt64" + or "float" or "System.Single" + or "double" or "System.Double" + or "decimal" or "System.Decimal" + or "char" or "System.Char" + or "string" or "System.String" + or "object" or "System.Object"; + } + + private sealed class OpenGenericTypeSyntaxRewriter : CSharpSyntaxRewriter + { + public static readonly OpenGenericTypeSyntaxRewriter Instance = new(); + + public override SyntaxNode? VisitGenericName(GenericNameSyntax node) + { + var visited = (GenericNameSyntax)base.VisitGenericName(node)!; + var argumentCount = visited.TypeArgumentList.Arguments.Count; + return visited.WithTypeArgumentList(TypeArgumentList(SeparatedList( + Enumerable.Range(0, argumentCount).Select(static _ => OmittedTypeArgument())))); + } + } + + private static TypeSyntax GetGeneratedProxyTypeSyntax(ProxyInterfaceModel proxy) + { + var genericArity = Math.Max(proxy.TypeParameters.Length, CountGenericArguments(proxy.InterfaceType)); + return CreateGeneratedTypeSyntax(proxy.GeneratedNamespace, ProxyGenerator.GetSimpleClassName(proxy.Name), genericArity); + } + + private static int CountGenericArguments(TypeRef typeRef) + { + var typeSyntax = typeRef.ToTypeSyntax(); + var count = 0; + foreach (var genericName in typeSyntax.DescendantNodesAndSelf().OfType()) + { + count += genericName.TypeArgumentList.Arguments.Count; + } + + return count; + } + + private static TypeSyntax? TryGetDefaultCopierType(TypeRef originalType, ImmutableArray defaultCopiers) + { + foreach (var copier in defaultCopiers) + { + if (copier.OriginalType.Equals(originalType)) + { + return copier.CopierType.ToTypeSyntax(); } } - public static TypeSyntax GetCodecTypeName(ISerializableTypeDescription type) + return null; + } + + private static void AddRegistration(List body, ExpressionSyntax addMethod, TypeSyntax typeSyntax) + { + body.Add(ExpressionStatement(InvocationExpression(addMethod, + ArgumentList(SingletonSeparatedList(Argument(TypeOfExpression(typeSyntax))))))); + } + + private static bool ShouldGenerateActivator(SerializableTypeModel type) + => !type.IsAbstractType + && !type.IsEnumType + && (!type.IsValueType && type.IsEmptyConstructable && !type.UseActivator || type.HasActivatorConstructor); + + public static TypeSyntax GetCodecTypeName(ISerializableTypeDescription type) + => GetCodecTypeName(type.GeneratedNamespace, type.Name, type.TypeParameters.Count); + + public static TypeSyntax GetCodecTypeName(string generatedNamespace, string name, int genericArity) + => CreateGeneratedTypeSyntax(generatedNamespace, SerializerGenerator.GetSimpleClassName(name), genericArity); + + public static TypeSyntax GetCopierTypeName(ISerializableTypeDescription type) + => GetCopierTypeName(type.GeneratedNamespace, type.Name, type.TypeParameters.Count); + + public static TypeSyntax GetCopierTypeName(string generatedNamespace, string name, int genericArity) + => CreateGeneratedTypeSyntax(generatedNamespace, CopierGenerator.GetSimpleClassName(name), genericArity); + + public static TypeSyntax GetActivatorTypeName(ISerializableTypeDescription type) + => GetActivatorTypeName(type.GeneratedNamespace, type.Name, type.TypeParameters.Count); + + public static TypeSyntax GetActivatorTypeName(string generatedNamespace, string name, int genericArity) + => CreateGeneratedTypeSyntax(generatedNamespace, ActivatorGenerator.GetSimpleClassName(name), genericArity); + + private static TypeSyntax CreateGeneratedTypeSyntax(string generatedNamespace, string simpleName, int genericArity) + { + var name = genericArity > 0 ? $"{simpleName}<{new string(',', genericArity - 1)}>" : simpleName; + return ParseTypeName($"{generatedNamespace}.{name}"); + } + + private readonly struct GeneratedInvokableMetadata( + TypeSyntax typeSyntax, + TypeRef sourceType, + ImmutableArray aliases, + TypeSyntax codecTypeSyntax, + TypeSyntax? copierTypeSyntax, + TypeSyntax? activatorTypeSyntax, + SourceLocationModel sourceLocation) + { + public TypeSyntax TypeSyntax { get; } = typeSyntax; + public TypeRef SourceType { get; } = sourceType; + public ImmutableArray Aliases { get; } = aliases; + public TypeSyntax CodecTypeSyntax { get; } = codecTypeSyntax; + public TypeSyntax? CopierTypeSyntax { get; } = copierTypeSyntax; + public TypeSyntax? ActivatorTypeSyntax { get; } = activatorTypeSyntax; + public SourceLocationModel SourceLocation { get; } = sourceLocation; + } + + private sealed class CompoundTypeAliasEmissionTree + { + private CompoundTypeAliasEmissionTree(CompoundAliasComponentModel key, TypeRef value, bool hasKey, bool hasValue) + { + Key = key; + Value = value; + HasKey = hasKey; + HasValue = hasValue; + } + + public static CompoundTypeAliasEmissionTree Create() => new(default, TypeRef.Empty, hasKey: false, hasValue: false); + + public CompoundAliasComponentModel Key { get; } + public bool HasKey { get; } + public bool HasValue { get; private set; } + public TypeRef Value { get; private set; } + public Dictionary? Children { get; private set; } + + public void Add(ImmutableArray keys, TypeRef value) => Add(keys.AsSpan(), value); + + public void Add(ReadOnlySpan keys, TypeRef value) { - var genericArity = type.TypeParameters.Count; - var name = SerializerGenerator.GetSimpleClassName(type); - if (genericArity > 0) + if (keys.Length == 0) { - name = $"{name}<{new string(',', genericArity - 1)}>"; + throw new InvalidOperationException("No valid key specified."); } - return ParseTypeName(type.GeneratedNamespace + "." + name); + var key = keys[0]; + if (keys.Length == 1) + { + AddInternal(key, value, hasValue: true); + } + else + { + var childNode = GetChildOrDefault(key) ?? AddInternal(key, TypeRef.Empty, hasValue: false); + childNode.Add(keys.Slice(1), value); + } } - public static TypeSyntax GetCopierTypeName(ISerializableTypeDescription type) + private CompoundTypeAliasEmissionTree? GetChildOrDefault(CompoundAliasComponentModel key) { - var genericArity = type.TypeParameters.Count; - var name = CopierGenerator.GetSimpleClassName(type); - if (genericArity > 0) + TryGetChild(key, out var result); + return result; + } + + private bool TryGetChild(CompoundAliasComponentModel key, out CompoundTypeAliasEmissionTree? result) + { + if (Children is { } children) { - name = $"{name}<{new string(',', genericArity - 1)}>"; + return children.TryGetValue(key, out result); } - return ParseTypeName(type.GeneratedNamespace + "." + name); + result = default; + return false; } - public static TypeSyntax GetActivatorTypeName(ISerializableTypeDescription type) + private CompoundTypeAliasEmissionTree AddInternal(CompoundAliasComponentModel key, TypeRef value, bool hasValue) { - var genericArity = type.TypeParameters.Count; - var name = ActivatorGenerator.GetSimpleClassName(type); - if (genericArity > 0) + Children ??= new(); + + if (Children.TryGetValue(key, out var existing)) { - name = $"{name}<{new string(',', genericArity - 1)}>"; + if (hasValue) + { + if (existing.HasValue && !existing.Value.Equals(value)) + { + throw new ArgumentException($"A key with the value '{key}' already exists."); + } + + existing.Value = value; + existing.HasValue = true; + } + + return existing; } - return ParseTypeName(type.GeneratedNamespace + "." + name); + var child = new CompoundTypeAliasEmissionTree(key, value, hasKey: true, hasValue: hasValue); + Children.Add(key, child); + return child; } + } + private readonly struct SerializableMetadataRegistration( + TypeRef sourceType, + string sortKey, + TypeSyntax serializerTypeSyntax, + TypeSyntax? copierTypeSyntax, + TypeSyntax? activatorTypeSyntax, + SourceLocationModel sourceLocation) + { + public TypeRef SourceType { get; } = sourceType; + public string SortKey { get; } = sortKey; + public TypeSyntax SerializerTypeSyntax { get; } = serializerTypeSyntax; + public TypeSyntax? CopierTypeSyntax { get; } = copierTypeSyntax; + public TypeSyntax? ActivatorTypeSyntax { get; } = activatorTypeSyntax; + public SourceLocationModel SourceLocation { get; } = sourceLocation; } } diff --git a/src/Orleans.CodeGenerator/MetadataSourceOutputGenerator.cs b/src/Orleans.CodeGenerator/MetadataSourceOutputGenerator.cs new file mode 100644 index 00000000000..b25c7bc0bfd --- /dev/null +++ b/src/Orleans.CodeGenerator/MetadataSourceOutputGenerator.cs @@ -0,0 +1,64 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Orleans.CodeGenerator.Model; +using Orleans.CodeGenerator.SyntaxGeneration; + +namespace Orleans.CodeGenerator; + +internal static class MetadataSourceOutputGenerator +{ + internal static SourceOutputResult CreateMetadataSourceOutput( + MetadataAggregateModel metadataModel, + SourceGeneratorOptions options) + { + try + { + SourceGeneratorOptionsParser.AttachDebuggerIfRequested(options); + var metadataGenerator = new MetadataGenerator(metadataModel, metadataModel.AssemblyName); + var metadataClass = metadataGenerator.GenerateMetadata(); + var metadataNamespace = $"{GeneratedCodeUtilities.CodeGeneratorName}.{Identifier.SanitizeIdentifierName(metadataModel.AssemblyName ?? "Assembly")}"; + var namespacedMembers = new Dictionary>(StringComparer.Ordinal); + GeneratedSourceOutput.AddMember(namespacedMembers, metadataNamespace, metadataClass); + var assemblyAttributes = CreateAssemblyAttributes( + metadataModel.ReferenceAssemblyData.ApplicationParts, + metadataNamespace, + metadataClass.Identifier.Text); + + var assemblyName = metadataModel.AssemblyName ?? "assembly"; + return SourceOutputResult.FromSource( + new GeneratedSourceEntry( + GeneratedSourceOutput.CreateMetadataHintName(assemblyName), + GeneratedSourceOutput.CreateSourceString(GeneratedSourceOutput.CreateCompilationUnit(namespacedMembers, assemblyAttributes)))); + } + catch (OrleansGeneratorDiagnosticAnalysisException analysisException) + { + return SourceOutputResult.FromDiagnostic(analysisException.Diagnostic); + } + } + + internal static SyntaxList CreateAssemblyAttributes( + IEnumerable applicationParts, + string metadataNamespace, + string metadataClassName) + { + var assemblyAttributes = ApplicationPartAttributeGenerator.GenerateSyntax( + SyntaxFactory.ParseName("global::Orleans.ApplicationPartAttribute"), + applicationParts); + var metadataAttribute = SyntaxFactory.AttributeList() + .WithTarget(SyntaxFactory.AttributeTargetSpecifier(SyntaxFactory.Token(SyntaxKind.AssemblyKeyword))) + .WithAttributes( + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.Attribute(SyntaxFactory.ParseName("global::Orleans.Serialization.Configuration.TypeManifestProviderAttribute")) + .AddArgumentListArguments( + SyntaxFactory.AttributeArgument( + SyntaxFactory.TypeOfExpression( + SyntaxFactory.QualifiedName( + SyntaxFactory.ParseName(metadataNamespace), + SyntaxFactory.IdentifierName(metadataClassName))))))); + assemblyAttributes.Add(metadataAttribute); + + return SyntaxFactory.List(assemblyAttributes); + } +} + diff --git a/src/Orleans.CodeGenerator/Model/FieldDescription.cs b/src/Orleans.CodeGenerator/Model/FieldDescription.cs index ef977c16db0..a74ff86c5bd 100644 --- a/src/Orleans.CodeGenerator/Model/FieldDescription.cs +++ b/src/Orleans.CodeGenerator/Model/FieldDescription.cs @@ -4,47 +4,46 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal class FieldDescription : IFieldDescription { - internal class FieldDescription : IFieldDescription + public FieldDescription(uint fieldId, bool isPrimaryConstructorParameter, IFieldSymbol member) { - public FieldDescription(uint fieldId, bool isPrimaryConstructorParameter, IFieldSymbol member) - { - FieldId = fieldId; - IsPrimaryConstructorParameter = isPrimaryConstructorParameter; - Field = member; - Type = member.Type; - ContainingType = member.ContainingType; + FieldId = fieldId; + IsPrimaryConstructorParameter = isPrimaryConstructorParameter; + Field = member; + Type = member.Type; + ContainingType = member.ContainingType; - if (Type.TypeKind == TypeKind.Dynamic) - { - TypeSyntax = PredefinedType(Token(SyntaxKind.ObjectKeyword)); - } - else - { - TypeSyntax = Type.ToTypeSyntax(); - } + if (Type.TypeKind == TypeKind.Dynamic) + { + TypeSyntax = PredefinedType(Token(SyntaxKind.ObjectKeyword)); + } + else + { + TypeSyntax = Type.ToTypeSyntax(); } + } - public ISymbol Symbol => Field; - public IFieldSymbol Field { get; } - public uint FieldId { get; } - public ITypeSymbol Type { get; } - public INamedTypeSymbol ContainingType { get; } - public TypeSyntax TypeSyntax { get; } + public ISymbol Symbol => Field; + public IFieldSymbol Field { get; } + public uint FieldId { get; } + public ITypeSymbol Type { get; } + public INamedTypeSymbol ContainingType { get; } + public TypeSyntax TypeSyntax { get; } - public string AssemblyName => Type.ContainingAssembly.ToDisplayName(); - public string TypeName => Type.ToDisplayName(); - public string TypeNameIdentifier => Type.GetValidIdentifier(); - public bool IsPrimaryConstructorParameter { get; set; } - public bool IsSerializable => true; - public bool IsCopyable => true; + public string AssemblyName => Type.ContainingAssembly.ToDisplayName(); + public string TypeName => Type.ToDisplayName(); + public string TypeNameIdentifier => Type.GetValidIdentifier(); + public bool IsPrimaryConstructorParameter { get; set; } + public bool IsSerializable => true; + public bool IsCopyable => true; - public TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol) => typeSymbol.ToTypeSyntax(); - } + public TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol) => typeSymbol.ToTypeSyntax(); +} - internal interface IFieldDescription : IMemberDescription - { - IFieldSymbol Field { get; } - } +internal interface IFieldDescription : IMemberDescription +{ + IFieldSymbol Field { get; } } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/Model/GeneratedInvokableDescription.cs b/src/Orleans.CodeGenerator/Model/GeneratedInvokableDescription.cs index c335016e13a..b0eb53f81ed 100644 --- a/src/Orleans.CodeGenerator/Model/GeneratedInvokableDescription.cs +++ b/src/Orleans.CodeGenerator/Model/GeneratedInvokableDescription.cs @@ -1,117 +1,109 @@ -using System.Collections.Generic; using System.Diagnostics; -using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using Orleans.CodeGenerator.SyntaxGeneration; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +[DebuggerDisplay("{MethodDescription}")] +internal sealed class GeneratedInvokableDescription : ISerializableTypeDescription { - #nullable disable - [DebuggerDisplay("{MethodDescription}")] - internal sealed class GeneratedInvokableDescription : ISerializableTypeDescription + public GeneratedInvokableDescription( + InvokableMethodDescription methodDescription, + Accessibility accessibility, + string generatedClassName, + string generatedNamespaceName, + List members, + List serializationHooks, + INamedTypeSymbol baseType, + List constructorArguments, + List compoundTypeAliases, + string? returnValueInitializerMethod, + ClassDeclarationSyntax classDeclarationSyntax) { - private TypeSyntax _openTypeSyntax; - private TypeSyntax _typeSyntax; - private TypeSyntax _baseTypeSyntax; - - public GeneratedInvokableDescription( - InvokableMethodDescription methodDescription, - Accessibility accessibility, - string generatedClassName, - string generatedNamespaceName, - List members, - List serializationHooks, - INamedTypeSymbol baseType, - List constructorArguments, - List compoundTypeAliases, - string returnValueInitializerMethod, - ClassDeclarationSyntax classDeclarationSyntax) + if (methodDescription.AllTypeParameters.Count == 0) { - if (methodDescription.AllTypeParameters.Count == 0) - { - MetadataName = $"{generatedNamespaceName}.{generatedClassName}"; - } - else - { - MetadataName = $"{generatedNamespaceName}.{generatedClassName}`{methodDescription.AllTypeParameters.Count}"; - } - - BaseType = baseType; - Name = generatedClassName; - GeneratedNamespace = generatedNamespaceName; - Members = members; - MethodDescription = methodDescription; - Accessibility = accessibility; - SerializationHooks = serializationHooks; - ActivatorConstructorParameters = constructorArguments; - CompoundTypeAliases = compoundTypeAliases; - ReturnValueInitializerMethod = returnValueInitializerMethod; - ClassDeclarationSyntax = classDeclarationSyntax; + MetadataName = $"{generatedNamespaceName}.{generatedClassName}"; } + else + { + MetadataName = $"{generatedNamespaceName}.{generatedClassName}`{methodDescription.AllTypeParameters.Count}"; + } + + BaseType = baseType; + Name = generatedClassName; + GeneratedNamespace = generatedNamespaceName; + Members = members; + MethodDescription = methodDescription; + Accessibility = accessibility; + SerializationHooks = serializationHooks; + ActivatorConstructorParameters = constructorArguments; + CompoundTypeAliases = compoundTypeAliases; + ReturnValueInitializerMethod = returnValueInitializerMethod; + ClassDeclarationSyntax = classDeclarationSyntax; + } - public Accessibility Accessibility { get; } - public TypeSyntax TypeSyntax => _typeSyntax ??= CreateTypeSyntax(); - public TypeSyntax OpenTypeSyntax => _openTypeSyntax ??= CreateOpenTypeSyntax(); - public bool HasComplexBaseType => BaseType is { SpecialType: not SpecialType.System_Object }; - public bool IncludePrimaryConstructorParameters => false; - public INamedTypeSymbol BaseType { get; } - public TypeSyntax BaseTypeSyntax => _baseTypeSyntax ??= BaseType.ToTypeSyntax(MethodDescription.TypeParameterSubstitutions); - public string Namespace => GeneratedNamespace; - public string GeneratedNamespace { get; } - public string Name { get; } - public bool IsValueType => false; - public bool IsSealedType => true; - public bool IsAbstractType => false; - public bool IsEnumType => false; - public bool IsGenericType => TypeParameters.Count > 0; - public List Members { get; } - public Compilation Compilation => MethodDescription.CodeGenerator.Compilation; - public bool IsEmptyConstructable => ActivatorConstructorParameters is not { Count: > 0 }; - public bool UseActivator => ActivatorConstructorParameters is { Count: > 0 }; - public bool TrackReferences => false; - public bool OmitDefaultMemberValues => false; - public List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters => MethodDescription.AllTypeParameters; - public List SerializationHooks { get; } - public bool IsShallowCopyable => false; - public bool IsUnsealedImmutable => false; - public bool IsImmutable => false; - public bool IsExceptionType => false; - public List ActivatorConstructorParameters { get; } - public bool HasActivatorConstructor => UseActivator; - public List CompoundTypeAliases { get; } - public ClassDeclarationSyntax ClassDeclarationSyntax { get; } - public string ReturnValueInitializerMethod { get; } + public Accessibility Accessibility { get; } + public TypeSyntax TypeSyntax => field ??= CreateTypeSyntax(); + public TypeSyntax OpenTypeSyntax => field ??= CreateOpenTypeSyntax(); + public bool HasComplexBaseType => BaseType is { SpecialType: not SpecialType.System_Object }; + public bool IncludePrimaryConstructorParameters => false; + public INamedTypeSymbol BaseType { get; } + public TypeSyntax BaseTypeSyntax => field ??= BaseType.ToTypeSyntax(MethodDescription.TypeParameterSubstitutions); + public string Namespace => GeneratedNamespace; + public string GeneratedNamespace { get; } + public string Name { get; } + public bool IsValueType => false; + public bool IsSealedType => true; + public bool IsAbstractType => false; + public bool IsEnumType => false; + public bool IsGenericType => TypeParameters.Count > 0; + public List Members { get; } + public Compilation Compilation => MethodDescription.GenerationContext.Compilation; + public bool IsEmptyConstructable => ActivatorConstructorParameters is not { Count: > 0 }; + public bool UseActivator => ActivatorConstructorParameters is { Count: > 0 }; + public bool TrackReferences => false; + public bool OmitDefaultMemberValues => false; + public List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters => MethodDescription.AllTypeParameters; + public List SerializationHooks { get; } + public bool IsShallowCopyable => false; + public bool IsUnsealedImmutable => false; + public bool IsImmutable => false; + public bool IsExceptionType => false; + public List ActivatorConstructorParameters { get; } + public bool HasActivatorConstructor => UseActivator; + public List CompoundTypeAliases { get; } + public ClassDeclarationSyntax ClassDeclarationSyntax { get; } + public string? ReturnValueInitializerMethod { get; } - public InvokableMethodDescription MethodDescription { get; } - public string MetadataName { get; } + public InvokableMethodDescription MethodDescription { get; } + public string MetadataName { get; } - public ExpressionSyntax GetObjectCreationExpression() => ObjectCreationExpression(TypeSyntax, ArgumentList(), null); + public ExpressionSyntax GetObjectCreationExpression() => ObjectCreationExpression(TypeSyntax, ArgumentList(), null); - private TypeSyntax CreateTypeSyntax() + private TypeSyntax CreateTypeSyntax() + { + var simpleName = InvokableGenerator.GetSimpleClassName(MethodDescription); + var subs = MethodDescription.TypeParameterSubstitutions; + return (TypeParameters, Namespace) switch { - var simpleName = InvokableGenerator.GetSimpleClassName(MethodDescription); - var subs = MethodDescription.TypeParameterSubstitutions; - return (TypeParameters, Namespace) switch - { - ({ Count: > 0 }, { Length: > 0 }) => QualifiedName(ParseName(Namespace), GenericName(Identifier(simpleName), TypeArgumentList(SeparatedList(TypeParameters.Select(p => IdentifierName(subs[p.Parameter])))))), - ({ Count: > 0 }, _) => GenericName(Identifier(simpleName), TypeArgumentList(SeparatedList(TypeParameters.Select(p => IdentifierName(subs[p.Parameter]))))), - (_, { Length: > 0 }) => QualifiedName(ParseName(Namespace), IdentifierName(simpleName)), - _ => IdentifierName(simpleName), - }; - } + ({ Count: > 0 }, { Length: > 0 }) => QualifiedName(ParseName(Namespace), GenericName(Identifier(simpleName), TypeArgumentList(SeparatedList(TypeParameters.Select(p => IdentifierName(subs[p.Parameter])))))), + ({ Count: > 0 }, _) => GenericName(Identifier(simpleName), TypeArgumentList(SeparatedList(TypeParameters.Select(p => IdentifierName(subs[p.Parameter]))))), + (_, { Length: > 0 }) => QualifiedName(ParseName(Namespace), IdentifierName(simpleName)), + _ => IdentifierName(simpleName), + }; + } - private TypeSyntax CreateOpenTypeSyntax() + private TypeSyntax CreateOpenTypeSyntax() + { + var simpleName = InvokableGenerator.GetSimpleClassName(MethodDescription); + return (TypeParameters, Namespace) switch { - var simpleName = InvokableGenerator.GetSimpleClassName(MethodDescription); - return (TypeParameters, Namespace) switch - { - ({ Count: > 0 }, { Length: > 0 }) => QualifiedName(ParseName(Namespace), GenericName(Identifier(simpleName), TypeArgumentList(SeparatedList(TypeParameters.Select(p => OmittedTypeArgument()))))), - ({ Count: > 0 }, _) => GenericName(Identifier(simpleName), TypeArgumentList(SeparatedList(TypeParameters.Select(p => OmittedTypeArgument())))), - (_, { Length: > 0 }) => QualifiedName(ParseName(Namespace), IdentifierName(simpleName)), - _ => IdentifierName(simpleName), - }; - } + ({ Count: > 0 }, { Length: > 0 }) => QualifiedName(ParseName(Namespace), GenericName(Identifier(simpleName), TypeArgumentList(SeparatedList(TypeParameters.Select(p => OmittedTypeArgument()))))), + ({ Count: > 0 }, _) => GenericName(Identifier(simpleName), TypeArgumentList(SeparatedList(TypeParameters.Select(p => OmittedTypeArgument())))), + (_, { Length: > 0 }) => QualifiedName(ParseName(Namespace), IdentifierName(simpleName)), + _ => IdentifierName(simpleName), + }; } } diff --git a/src/Orleans.CodeGenerator/Model/GeneratedProxyDescription.cs b/src/Orleans.CodeGenerator/Model/GeneratedProxyDescription.cs index 4ea72a3d5e9..32e85f8b81d 100644 --- a/src/Orleans.CodeGenerator/Model/GeneratedProxyDescription.cs +++ b/src/Orleans.CodeGenerator/Model/GeneratedProxyDescription.cs @@ -1,43 +1,41 @@ using Orleans.CodeGenerator.SyntaxGeneration; using Microsoft.CodeAnalysis.CSharp.Syntax; -using System.Linq; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal class GeneratedProxyDescription { - internal class GeneratedProxyDescription + public GeneratedProxyDescription(ProxyInterfaceDescription interfaceDescription, string generatedClassName) { - public GeneratedProxyDescription(ProxyInterfaceDescription interfaceDescription, string generatedClassName) + InterfaceDescription = interfaceDescription; + GeneratedClassName = generatedClassName; + TypeSyntax = GetProxyTypeName(interfaceDescription); + if (InterfaceDescription.TypeParameters.Count == 0) + { + MetadataName = $"{InterfaceDescription.GeneratedNamespace}.{GeneratedClassName}"; + } + else { - InterfaceDescription = interfaceDescription; - GeneratedClassName = generatedClassName; - TypeSyntax = GetProxyTypeName(interfaceDescription); - if (InterfaceDescription.TypeParameters.Count == 0) - { - MetadataName = $"{InterfaceDescription.GeneratedNamespace}.{GeneratedClassName}"; - } - else - { - MetadataName = $"{InterfaceDescription.GeneratedNamespace}.{GeneratedClassName}`{InterfaceDescription.TypeParameters.Count}"; - } + MetadataName = $"{InterfaceDescription.GeneratedNamespace}.{GeneratedClassName}`{InterfaceDescription.TypeParameters.Count}"; } + } - public TypeSyntax TypeSyntax { get; } - public ProxyInterfaceDescription InterfaceDescription { get; } - public string GeneratedClassName { get; } - public string MetadataName { get; } + public TypeSyntax TypeSyntax { get; } + public ProxyInterfaceDescription InterfaceDescription { get; } + public string GeneratedClassName { get; } + public string MetadataName { get; } - private static TypeSyntax GetProxyTypeName(ProxyInterfaceDescription interfaceDescription) + private static TypeSyntax GetProxyTypeName(ProxyInterfaceDescription interfaceDescription) + { + var interfaceType = interfaceDescription.InterfaceType; + var genericArity = interfaceType.GetAllTypeParameters().Count(); + var name = ProxyGenerator.GetSimpleClassName(interfaceDescription); + if (genericArity > 0) { - var interfaceType = interfaceDescription.InterfaceType; - var genericArity = interfaceType.GetAllTypeParameters().Count(); - var name = ProxyGenerator.GetSimpleClassName(interfaceDescription); - if (genericArity > 0) - { - name += $"<{new string(',', genericArity - 1)}>"; - } - - return ParseTypeName(interfaceDescription.GeneratedNamespace + "." + name); + name += $"<{new string(',', genericArity - 1)}>"; } + + return ParseTypeName(interfaceDescription.GeneratedNamespace + "." + name); } } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/Model/ICodecDescription.cs b/src/Orleans.CodeGenerator/Model/ICodecDescription.cs index 431f3a7b8fa..2b81a5d85df 100644 --- a/src/Orleans.CodeGenerator/Model/ICodecDescription.cs +++ b/src/Orleans.CodeGenerator/Model/ICodecDescription.cs @@ -1,9 +1,8 @@ using Microsoft.CodeAnalysis; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal interface ICopierDescription { - internal interface ICopierDescription - { - ITypeSymbol UnderlyingType { get; } - } + ITypeSymbol UnderlyingType { get; } } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/Model/IMemberDescription.cs b/src/Orleans.CodeGenerator/Model/IMemberDescription.cs index fbc6b7a59c5..c0caf849658 100644 --- a/src/Orleans.CodeGenerator/Model/IMemberDescription.cs +++ b/src/Orleans.CodeGenerator/Model/IMemberDescription.cs @@ -1,51 +1,48 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; -using System; -using System.Collections.Generic; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal interface IMemberDescription { - internal interface IMemberDescription - { - uint FieldId { get; } - ISymbol Symbol { get; } - ITypeSymbol Type { get; } - INamedTypeSymbol ContainingType { get; } - string AssemblyName { get; } - string TypeName { get; } - TypeSyntax TypeSyntax { get; } - string TypeNameIdentifier { get; } - TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol); - bool IsPrimaryConstructorParameter { get; } - bool IsSerializable { get; } - bool IsCopyable { get; } - } + uint FieldId { get; } + ISymbol Symbol { get; } + ITypeSymbol Type { get; } + INamedTypeSymbol ContainingType { get; } + string AssemblyName { get; } + string TypeName { get; } + TypeSyntax TypeSyntax { get; } + string TypeNameIdentifier { get; } + TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol); + bool IsPrimaryConstructorParameter { get; } + bool IsSerializable { get; } + bool IsCopyable { get; } +} - internal sealed class MemberDescriptionTypeComparer : IEqualityComparer - { - public static MemberDescriptionTypeComparer Default { get; } = new MemberDescriptionTypeComparer(); +internal sealed class MemberDescriptionTypeComparer : IEqualityComparer +{ + public static MemberDescriptionTypeComparer Default { get; } = new MemberDescriptionTypeComparer(); - public bool Equals(IMemberDescription x, IMemberDescription y) + public bool Equals(IMemberDescription x, IMemberDescription y) + { + if (ReferenceEquals(x, y)) { - if (ReferenceEquals(x, y)) - { - return true; - } - - if (ReferenceEquals(x, null) || ReferenceEquals(y, null)) - { - return false; - } - - return string.Equals(x.TypeName, y.TypeName) && string.Equals(x.AssemblyName, y.AssemblyName); + return true; } - public int GetHashCode(IMemberDescription obj) + if (ReferenceEquals(x, null) || ReferenceEquals(y, null)) { - int hashCode = -499943048; - hashCode = hashCode * -1521134295 + StringComparer.Ordinal.GetHashCode(obj.TypeName); - hashCode = hashCode * -1521134295 + StringComparer.Ordinal.GetHashCode(obj.AssemblyName); - return hashCode; + return false; } + + return string.Equals(x.TypeName, y.TypeName) && string.Equals(x.AssemblyName, y.AssemblyName); + } + + public int GetHashCode(IMemberDescription obj) + { + int hashCode = -499943048; + hashCode = hashCode * -1521134295 + StringComparer.Ordinal.GetHashCode(obj.TypeName); + hashCode = hashCode * -1521134295 + StringComparer.Ordinal.GetHashCode(obj.AssemblyName); + return hashCode; } } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/Model/ISerializableTypeDescription.cs b/src/Orleans.CodeGenerator/Model/ISerializableTypeDescription.cs index b71a9d37890..677281e9acc 100644 --- a/src/Orleans.CodeGenerator/Model/ISerializableTypeDescription.cs +++ b/src/Orleans.CodeGenerator/Model/ISerializableTypeDescription.cs @@ -1,39 +1,37 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; -using System.Collections.Generic; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal interface ISerializableTypeDescription { - internal interface ISerializableTypeDescription - { - Accessibility Accessibility { get; } - TypeSyntax TypeSyntax { get; } - bool HasComplexBaseType { get; } - bool IncludePrimaryConstructorParameters { get; } - INamedTypeSymbol BaseType { get; } - TypeSyntax BaseTypeSyntax { get; } - string Namespace { get; } - string GeneratedNamespace { get; } - string Name { get; } - bool IsValueType { get; } - bool IsSealedType { get; } - bool IsAbstractType { get; } - bool IsEnumType { get; } - bool IsGenericType { get; } - List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters { get; } - List Members { get; } - Compilation Compilation { get; } - bool UseActivator { get; } - bool IsEmptyConstructable { get; } - bool HasActivatorConstructor { get; } - bool TrackReferences { get; } - bool OmitDefaultMemberValues { get; } - ExpressionSyntax GetObjectCreationExpression(); - List SerializationHooks { get; } - bool IsShallowCopyable { get; } - bool IsUnsealedImmutable { get; } - bool IsImmutable { get; } - bool IsExceptionType { get; } - List ActivatorConstructorParameters { get; } - } -} \ No newline at end of file + Accessibility Accessibility { get; } + TypeSyntax TypeSyntax { get; } + bool HasComplexBaseType { get; } + bool IncludePrimaryConstructorParameters { get; } + INamedTypeSymbol BaseType { get; } + TypeSyntax BaseTypeSyntax { get; } + string Namespace { get; } + string GeneratedNamespace { get; } + string Name { get; } + bool IsValueType { get; } + bool IsSealedType { get; } + bool IsAbstractType { get; } + bool IsEnumType { get; } + bool IsGenericType { get; } + List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters { get; } + List Members { get; } + Compilation Compilation { get; } + bool UseActivator { get; } + bool IsEmptyConstructable { get; } + bool HasActivatorConstructor { get; } + bool TrackReferences { get; } + bool OmitDefaultMemberValues { get; } + ExpressionSyntax GetObjectCreationExpression(); + List SerializationHooks { get; } + bool IsShallowCopyable { get; } + bool IsUnsealedImmutable { get; } + bool IsImmutable { get; } + bool IsExceptionType { get; } + List ActivatorConstructorParameters { get; } +} diff --git a/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs b/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs index 86ad724897c..4a72f3a198a 100644 --- a/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs +++ b/src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs @@ -1,221 +1,221 @@ using Microsoft.CodeAnalysis; using Orleans.CodeGenerator.SyntaxGeneration; -using System; -using System.Collections.Generic; using System.Collections.Immutable; using System.Globalization; -using System.Linq; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +/// +/// Describes an invokable method. +/// This is a method on the original interface which defined it. +/// By contrast, describes a method on an interface which a proxy is being generated for, having type argument substitutions, etc. +/// +internal sealed class InvokableMethodDescription : IEquatable { - #nullable disable - /// - /// Describes an invokable method. - /// This is a method on the original interface which defined it. - /// By contrast, describes a method on an interface which a proxy is being generated for, having type argument substitutions, etc. - /// - internal sealed class InvokableMethodDescription : IEquatable + public static InvokableMethodDescription Create(InvokableMethodId method, INamedTypeSymbol containingType) => new(method, containingType); + + private InvokableMethodDescription(InvokableMethodId invokableId, INamedTypeSymbol containingType) { - public static InvokableMethodDescription Create(InvokableMethodId method, INamedTypeSymbol containingType) => new(method, containingType); + Key = invokableId; + ContainingInterface = containingType; + GeneratedMethodId = GeneratedCodeUtilities.CreateHashedMethodId(Method); + MethodId = GenerationContext.GetId(Method)?.ToString(CultureInfo.InvariantCulture) ?? GenerationContext.GetAlias(Method) ?? GeneratedMethodId; - private InvokableMethodDescription(InvokableMethodId invokableId, INamedTypeSymbol containingType) - { - Key = invokableId; - ContainingInterface = containingType; - GeneratedMethodId = CodeGenerator.CreateHashedMethodId(Method); - MethodId = CodeGenerator.GetId(Method)?.ToString(CultureInfo.InvariantCulture) ?? CodeGenerator.GetAlias(Method) ?? GeneratedMethodId; + MethodTypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>(); - MethodTypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>(); + // Set defaults from the interface type. + var invokableBaseTypes = new Dictionary(SymbolEqualityComparer.Default); + foreach (var pair in ProxyBase.InvokableBaseTypes) + { + invokableBaseTypes[pair.Key] = pair.Value; + } - // Set defaults from the interface type. - var invokableBaseTypes = new Dictionary(SymbolEqualityComparer.Default); - foreach (var pair in ProxyBase.InvokableBaseTypes) + InvokableBaseTypes = invokableBaseTypes; + foreach (var methodAttr in Method.GetAttributes()) + { + if (methodAttr.AttributeClass is not { } attributeClass) { - invokableBaseTypes[pair.Key] = pair.Value; + continue; } - InvokableBaseTypes = invokableBaseTypes; - foreach (var methodAttr in Method.GetAttributes()) + if (attributeClass.GetAttributes(GenerationContext.LibraryTypes.InvokableBaseTypeAttribute, out var attrs)) { - if (methodAttr.AttributeClass.GetAttributes(CodeGenerator.LibraryTypes.InvokableBaseTypeAttribute, out var attrs)) + foreach (var attr in attrs) { - foreach (var attr in attrs) + var ctorArgs = attr.ConstructorArguments; + var proxyBaseType = (INamedTypeSymbol)ctorArgs[0].Value!; + var returnType = (INamedTypeSymbol)ctorArgs[1].Value!; + var invokableBaseType = (INamedTypeSymbol)ctorArgs[2].Value!; + if (!SymbolEqualityComparer.Default.Equals(ProxyBase.ProxyBaseType, proxyBaseType)) { - var ctorArgs = attr.ConstructorArguments; - var proxyBaseType = (INamedTypeSymbol)ctorArgs[0].Value; - var returnType = (INamedTypeSymbol)ctorArgs[1].Value; - var invokableBaseType = (INamedTypeSymbol)ctorArgs[2].Value; - if (!SymbolEqualityComparer.Default.Equals(ProxyBase.ProxyBaseType, proxyBaseType)) - { - // This attribute does not apply to this particular invoker, since it is for a different proxy base type. - continue; - } - - invokableBaseTypes[returnType] = invokableBaseType; + // This attribute does not apply to this particular invoker, since it is for a different proxy base type. + continue; } + + invokableBaseTypes[returnType] = invokableBaseType; } + } - if (methodAttr.AttributeClass.GetAttributes(CodeGenerator.LibraryTypes.InvokableCustomInitializerAttribute, out attrs)) + if (attributeClass.GetAttributes(GenerationContext.LibraryTypes.InvokableCustomInitializerAttribute, out attrs)) + { + foreach (var attr in attrs) { - foreach (var attr in attrs) - { - var methodName = (string)attr.ConstructorArguments[0].Value; + var methodName = (string)attr.ConstructorArguments[0].Value!; - TypedConstant methodArgument; - if (attr.ConstructorArguments.Length == 2) + TypedConstant methodArgument; + if (attr.ConstructorArguments.Length == 2) + { + // Take the value from the attribute directly. + methodArgument = attr.ConstructorArguments[1]; + } + else + { + // Take the value from the attribute which this attribute is attached to. + if (TryGetNamedArgument(attr.NamedArguments, "AttributeArgumentName", out var argNameArg) + && TryGetNamedArgument(methodAttr.NamedArguments, (string?)argNameArg.Value, out var namedArgument)) { - // Take the value from the attribute directly. - methodArgument = attr.ConstructorArguments[1]; + methodArgument = namedArgument; } else { - // Take the value from the attribute which this attribute is attached to. - if (TryGetNamedArgument(attr.NamedArguments, "AttributeArgumentName", out var argNameArg) - && TryGetNamedArgument(methodAttr.NamedArguments, (string)argNameArg.Value, out var namedArgument)) + var index = 0; + if (TryGetNamedArgument(attr.NamedArguments, "AttributeArgumentIndex", out var indexArg)) { - methodArgument = namedArgument; + index = (int)indexArg.Value!; } - else - { - var index = 0; - if (TryGetNamedArgument(attr.NamedArguments, "AttributeArgumentIndex", out var indexArg)) - { - index = (int)indexArg.Value; - } - methodArgument = methodAttr.ConstructorArguments[index]; - } + methodArgument = methodAttr.ConstructorArguments[index]; } - - CustomInitializerMethods.Add((methodName, methodArgument)); } - } - if (SymbolEqualityComparer.Default.Equals(methodAttr.AttributeClass, CodeGenerator.LibraryTypes.ResponseTimeoutAttribute)) - { - ResponseTimeoutTicks = TimeSpan.Parse((string)methodAttr.ConstructorArguments[0].Value).Ticks; + CustomInitializerMethods.Add((methodName, methodArgument)); } } - AllTypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>(); - MethodTypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>(); - - var names = new HashSet(StringComparer.Ordinal); - foreach (var typeParameter in ContainingInterface.GetAllTypeParameters()) + if (SymbolEqualityComparer.Default.Equals(attributeClass, GenerationContext.LibraryTypes.ResponseTimeoutAttribute)) { - var tpName = GetTypeParameterName(names, typeParameter); - AllTypeParameters.Add((tpName, typeParameter)); + ResponseTimeoutTicks = TimeSpan.Parse((string)methodAttr.ConstructorArguments[0].Value!).Ticks; } + } - foreach (var typeParameter in Method.TypeParameters) - { - var tpName = GetTypeParameterName(names, typeParameter); - AllTypeParameters.Add((tpName, typeParameter)); - MethodTypeParameters.Add((tpName, typeParameter)); - } + AllTypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>(); + MethodTypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>(); - TypeParameterSubstitutions = new(SymbolEqualityComparer.Default); - foreach (var (name, parameter) in AllTypeParameters) - { - TypeParameterSubstitutions[parameter] = name; - } + var names = new HashSet(StringComparer.Ordinal); + foreach (var typeParameter in ContainingInterface.GetAllTypeParameters()) + { + var tpName = GetTypeParameterName(names, typeParameter); + AllTypeParameters.Add((tpName, typeParameter)); + } - static string GetTypeParameterName(HashSet names, ITypeParameterSymbol typeParameter) - { - var count = 0; - var result = typeParameter.Name; - while (names.Contains(result)) - { - result = $"{typeParameter.Name}_{++count}"; - } + foreach (var typeParameter in Method.TypeParameters) + { + var tpName = GetTypeParameterName(names, typeParameter); + AllTypeParameters.Add((tpName, typeParameter)); + MethodTypeParameters.Add((tpName, typeParameter)); + } + + TypeParameterSubstitutions = new(SymbolEqualityComparer.Default); + foreach (var (name, parameter) in AllTypeParameters) + { + TypeParameterSubstitutions[parameter] = name; + } - names.Add(result); - return result.EscapeIdentifier(); + static string GetTypeParameterName(HashSet names, ITypeParameterSymbol typeParameter) + { + var count = 0; + var result = typeParameter.Name; + while (names.Contains(result)) + { + result = $"{typeParameter.Name}_{++count}"; } - static bool TryGetNamedArgument(ImmutableArray> arguments, string name, out TypedConstant value) + names.Add(result); + return result.EscapeIdentifier(); + } + + static bool TryGetNamedArgument(ImmutableArray> arguments, string? name, out TypedConstant value) + { + foreach (var arg in arguments) { - foreach (var arg in arguments) + if (string.Equals(arg.Key, name, StringComparison.Ordinal)) { - if (string.Equals(arg.Key, name, StringComparison.Ordinal)) - { - value = arg.Value; - return true; - } + value = arg.Value; + return true; } - - value = default; - return false; } - } - /// - /// Gets the source generator. - /// - public CodeGenerator CodeGenerator => ProxyBase.CodeGenerator; - - /// - /// Gets the method identifier. - /// - public InvokableMethodId Key { get; } - - /// - /// Gets the proxy base information for the method (eg, GrainReference, whether it is an extension). - /// - public InvokableMethodProxyBase ProxyBase => Key.ProxyBase; - - /// - /// Gets the method symbol. - /// - public IMethodSymbol Method => Key.Method; - - /// - /// Gets the dictionary of invokable base types. This indicates what invokable base type (eg, ValueTaskRequest) should be used for a given return type (eg, ValueTask). - /// - public IReadOnlyDictionary InvokableBaseTypes { get; } - - /// - /// Gets the response timeout ticks, if set. - /// - public long? ResponseTimeoutTicks { get; } - - /// - /// Gets the list of custom initializer method names and their corresponding argument. - /// - public List<(string MethodName, TypedConstant MethodArgument)> CustomInitializerMethods { get; } = new(); - - /// - /// Gets the generated method identifier. - /// - public string GeneratedMethodId { get; } - - /// - /// Gets the method identifier. - /// - public string MethodId { get; } - - public List<(string Name, ITypeParameterSymbol Parameter)> AllTypeParameters { get; } - public List<(string Name, ITypeParameterSymbol Parameter)> MethodTypeParameters { get; } - public Dictionary TypeParameterSubstitutions { get; } - - /// - /// Gets a value indicating whether this method has an alias. - /// - public bool HasAlias => !string.Equals(MethodId, GeneratedMethodId, StringComparison.Ordinal); - - /// - /// Gets the interface which this type is contained in. - /// - public INamedTypeSymbol ContainingInterface { get; } - - /// - /// Gets a value indicating whether this method is cancellable. - /// - public bool IsCancellable => Method.Parameters.Any(parameterSymbol => SymbolEqualityComparer.Default.Equals(CodeGenerator.LibraryTypes.CancellationToken, parameterSymbol.Type)); - - public bool Equals(InvokableMethodDescription other) => Key.Equals(other.Key); - public override bool Equals(object obj) => obj is InvokableMethodDescription imd && Equals(imd); - public override int GetHashCode() => Key.GetHashCode(); - public override string ToString() => $"{ProxyBase}/{ContainingInterface.Name}/{Method.Name}"; + value = default; + return false; + } } + + /// + /// Gets the proxy generation context. + /// + public ProxyGenerationContext GenerationContext => ProxyBase.GenerationContext; + + /// + /// Gets the method identifier. + /// + public InvokableMethodId Key { get; } + + /// + /// Gets the proxy base information for the method (eg, GrainReference, whether it is an extension). + /// + public InvokableMethodProxyBase ProxyBase => Key.ProxyBase; + + /// + /// Gets the method symbol. + /// + public IMethodSymbol Method => Key.Method; + + /// + /// Gets the dictionary of invokable base types. This indicates what invokable base type (eg, ValueTaskRequest) should be used for a given return type (eg, ValueTask). + /// + public IReadOnlyDictionary InvokableBaseTypes { get; } + + /// + /// Gets the response timeout ticks, if set. + /// + public long? ResponseTimeoutTicks { get; } + + /// + /// Gets the list of custom initializer method names and their corresponding argument. + /// + public List<(string MethodName, TypedConstant MethodArgument)> CustomInitializerMethods { get; } = new(); + + /// + /// Gets the generated method identifier. + /// + public string GeneratedMethodId { get; } + + /// + /// Gets the method identifier. + /// + public string MethodId { get; } + + public List<(string Name, ITypeParameterSymbol Parameter)> AllTypeParameters { get; } + public List<(string Name, ITypeParameterSymbol Parameter)> MethodTypeParameters { get; } + public Dictionary TypeParameterSubstitutions { get; } + + /// + /// Gets a value indicating whether this method has an alias. + /// + public bool HasAlias => !string.Equals(MethodId, GeneratedMethodId, StringComparison.Ordinal); + + /// + /// Gets the interface which this type is contained in. + /// + public INamedTypeSymbol ContainingInterface { get; } + + /// + /// Gets a value indicating whether this method is cancellable. + /// + public bool IsCancellable => Method.Parameters.Any(parameterSymbol => SymbolEqualityComparer.Default.Equals(GenerationContext.LibraryTypes.CancellationToken, parameterSymbol.Type)); + + public bool Equals(InvokableMethodDescription? other) => other is not null && Key.Equals(other.Key); + public override bool Equals(object? obj) => obj is InvokableMethodDescription imd && Equals(imd); + public override int GetHashCode() => Key.GetHashCode(); + public override string ToString() => $"{ProxyBase}/{ContainingInterface.Name}/{Method.Name}"; } diff --git a/src/Orleans.CodeGenerator/Model/InvokableMethodId.cs b/src/Orleans.CodeGenerator/Model/InvokableMethodId.cs index 4c3c2c08db2..3be266fd4f4 100644 --- a/src/Orleans.CodeGenerator/Model/InvokableMethodId.cs +++ b/src/Orleans.CodeGenerator/Model/InvokableMethodId.cs @@ -1,44 +1,42 @@ using Microsoft.CodeAnalysis; -using System; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +/// +/// Identifies an invokable method. +/// +internal readonly struct InvokableMethodId(InvokableMethodProxyBase proxyBaseInfo, INamedTypeSymbol interfaceType, IMethodSymbol method) : IEquatable { /// - /// Identifies an invokable method. + /// Gets the proxy base information for the method (eg, GrainReference, whether it is an extension). /// - internal readonly struct InvokableMethodId(InvokableMethodProxyBase proxyBaseInfo, INamedTypeSymbol interfaceType, IMethodSymbol method) : IEquatable - { - /// - /// Gets the proxy base information for the method (eg, GrainReference, whether it is an extension). - /// - public InvokableMethodProxyBase ProxyBase { get; } = proxyBaseInfo; + public InvokableMethodProxyBase ProxyBase { get; } = proxyBaseInfo; - /// - /// Gets the method symbol. - /// - public IMethodSymbol Method { get; } = method; + /// + /// Gets the method symbol. + /// + public IMethodSymbol Method { get; } = method; - /// - /// Gets the containing interface symbol. - /// - public INamedTypeSymbol InterfaceType { get; } = interfaceType; + /// + /// Gets the containing interface symbol. + /// + public INamedTypeSymbol InterfaceType { get; } = interfaceType; - public bool Equals(InvokableMethodId other) => - ProxyBase.Equals(other.ProxyBase) - && SymbolEqualityComparer.Default.Equals(Method, other.Method) - && SymbolEqualityComparer.Default.Equals(InterfaceType, other.InterfaceType); + public bool Equals(InvokableMethodId other) => + ProxyBase.Equals(other.ProxyBase) + && SymbolEqualityComparer.Default.Equals(Method, other.Method) + && SymbolEqualityComparer.Default.Equals(InterfaceType, other.InterfaceType); - public override bool Equals(object obj) => obj is InvokableMethodId imd && Equals(imd); - public override int GetHashCode() + public override bool Equals(object obj) => obj is InvokableMethodId imd && Equals(imd); + public override int GetHashCode() + { + unchecked { - unchecked - { - return ProxyBase.GetHashCode() - * 17 ^ SymbolEqualityComparer.Default.GetHashCode(Method) - * 17 ^ SymbolEqualityComparer.Default.GetHashCode(InterfaceType); - } + return ProxyBase.GetHashCode() + * 17 ^ SymbolEqualityComparer.Default.GetHashCode(Method) + * 17 ^ SymbolEqualityComparer.Default.GetHashCode(InterfaceType); } - - public override string ToString() => $"{ProxyBase}/{InterfaceType.Name}/{Method.Name}"; } + + public override string ToString() => $"{ProxyBase}/{InterfaceType.Name}/{Method.Name}"; } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/Model/InvokableMethodProxyBase.cs b/src/Orleans.CodeGenerator/Model/InvokableMethodProxyBase.cs index 8b8d8d61688..abc1ef0cabe 100644 --- a/src/Orleans.CodeGenerator/Model/InvokableMethodProxyBase.cs +++ b/src/Orleans.CodeGenerator/Model/InvokableMethodProxyBase.cs @@ -1,55 +1,46 @@ using Microsoft.CodeAnalysis; -using System; -using System.Collections.Generic; using System.Collections.Immutable; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +/// +/// Describes the proxy base for an invokable method, including whether the proxy is a grain reference or extension, and what invokable base types should be used for a given return type. +/// +internal sealed class InvokableMethodProxyBase(ProxyGenerationContext generationContext, InvokableMethodProxyBaseId descriptor, Dictionary invokableBaseTypes) : IEquatable { + + /// + /// Gets the proxy generation context. + /// + public ProxyGenerationContext GenerationContext { get; } = generationContext; + + /// + /// Gets the proxy base id. + /// + public InvokableMethodProxyBaseId Key { get; } = descriptor; + + /// + /// Gets the proxy base type, eg GrainReference. + /// + public INamedTypeSymbol ProxyBaseType => Key.ProxyBaseType; + + /// + /// Gets a value indicating whether this descriptor represents an extension. + /// + public bool IsExtension => Key.IsExtension; + /// - /// Describes the proxy base for an invokable method, including whether the proxy is a grain reference or extension, and what invokable base types should be used for a given return type. + /// Gets the components of the compound type alias used to refer to this proxy base. /// - internal sealed class InvokableMethodProxyBase : IEquatable - { - public InvokableMethodProxyBase(CodeGenerator codeGenerator, InvokableMethodProxyBaseId descriptor, Dictionary invokableBaseTypes) - { - CodeGenerator = codeGenerator; - Key = descriptor; - InvokableBaseTypes = invokableBaseTypes ?? throw new ArgumentNullException(nameof(invokableBaseTypes)); - } - - /// - /// Gets the source generator. - /// - public CodeGenerator CodeGenerator { get; } - - /// - /// Gets the proxy base id. - /// - public InvokableMethodProxyBaseId Key { get; } - - /// - /// Gets the proxy base type, eg GrainReference. - /// - public INamedTypeSymbol ProxyBaseType => Key.ProxyBaseType; - - /// - /// Gets a value indicating whether this descriptor represents an extension. - /// - public bool IsExtension => Key.IsExtension; - - /// - /// Gets the components of the compound type alias used to refer to this proxy base. - /// - public ImmutableArray CompositeAliasComponents => Key.CompositeAliasComponents; - - /// - /// Gets the dictionary of invokable base types. This indicates what invokable base type (eg, ValueTaskRequest) should be used for a given return type (eg, ValueTask). - /// - public IReadOnlyDictionary InvokableBaseTypes { get; } - - public bool Equals(InvokableMethodProxyBase other) => other is not null && Key.Equals(other.Key); - public override bool Equals(object obj) => obj is InvokableMethodProxyBase other && Equals(other); - public override int GetHashCode() => Key.GetHashCode(); - public override string ToString() => Key.ToString(); - } -} \ No newline at end of file + public ImmutableArray CompositeAliasComponents => Key.CompositeAliasComponents; + + /// + /// Gets the dictionary of invokable base types. This indicates what invokable base type (eg, ValueTaskRequest) should be used for a given return type (eg, ValueTask). + /// + public IReadOnlyDictionary InvokableBaseTypes { get; } = invokableBaseTypes ?? throw new ArgumentNullException(nameof(invokableBaseTypes)); + + public bool Equals(InvokableMethodProxyBase other) => other is not null && Key.Equals(other.Key); + public override bool Equals(object obj) => obj is InvokableMethodProxyBase other && Equals(other); + public override int GetHashCode() => Key.GetHashCode(); + public override string ToString() => Key.ToString(); +} diff --git a/src/Orleans.CodeGenerator/Model/InvokableMethodProxyBaseId.cs b/src/Orleans.CodeGenerator/Model/InvokableMethodProxyBaseId.cs index 6360d7c4d7e..12ddf504b40 100644 --- a/src/Orleans.CodeGenerator/Model/InvokableMethodProxyBaseId.cs +++ b/src/Orleans.CodeGenerator/Model/InvokableMethodProxyBaseId.cs @@ -1,59 +1,57 @@ using Microsoft.CodeAnalysis; -using System; using System.Collections.Immutable; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +/// +/// Identifies a proxy base, including whether the proxy is a grain reference or extension. +/// +internal readonly struct InvokableMethodProxyBaseId : IEquatable { - /// - /// Identifies a proxy base, including whether the proxy is a grain reference or extension. - /// - internal readonly struct InvokableMethodProxyBaseId : IEquatable + public InvokableMethodProxyBaseId(INamedTypeSymbol type, bool isExtension) { - public InvokableMethodProxyBaseId(INamedTypeSymbol type, bool isExtension) + if (!SymbolEqualityComparer.Default.Equals(type, type.OriginalDefinition)) { - if (!SymbolEqualityComparer.Default.Equals(type, type.OriginalDefinition)) - { - throw new ArgumentException("Type must be an original definition. This is a code generator bug."); - } - - ProxyBaseType = type; - IsExtension = isExtension; - - if (IsExtension) - { - CompositeAliasComponents = ImmutableArray.Create(new CompoundTypeAliasComponent[] { new(ProxyBaseType), new("Ext") }); - GeneratedClassNameComponent = $"{ProxyBaseType.Name}_Ext"; - } - else - { - CompositeAliasComponents = ImmutableArray.Create(new CompoundTypeAliasComponent[] { new(ProxyBaseType) }); - GeneratedClassNameComponent = ProxyBaseType.Name; - } + throw new ArgumentException("Type must be an original definition. This is a code generator bug."); } - /// - /// Gets the proxy base type, eg GrainReference. - /// - public INamedTypeSymbol ProxyBaseType { get; } - - /// - /// Gets a value indicating whether this descriptor represents an extension. - /// - public bool IsExtension { get; } - - /// - /// Gets the components of the compound type alias used to refer to this proxy base. - /// - public ImmutableArray CompositeAliasComponents { get; } - - /// - /// Gets a string used to distinguish this proxy base from others in generated class names. - /// - public string GeneratedClassNameComponent { get; } - - public bool Equals(InvokableMethodProxyBaseId other) => SymbolEqualityComparer.Default.Equals(ProxyBaseType, other.ProxyBaseType) && IsExtension == other.IsExtension; - public override bool Equals(object obj) => obj is InvokableMethodProxyBaseId other && Equals(other); - public override int GetHashCode() => IsExtension.GetHashCode() * 17 ^ SymbolEqualityComparer.Default.GetHashCode(ProxyBaseType); - public override string ToString() => GeneratedClassNameComponent; + ProxyBaseType = type; + IsExtension = isExtension; + + if (IsExtension) + { + CompositeAliasComponents = [new(ProxyBaseType), new("Ext")]; + GeneratedClassNameComponent = $"{ProxyBaseType.Name}_Ext"; + } + else + { + CompositeAliasComponents = [new(ProxyBaseType)]; + GeneratedClassNameComponent = ProxyBaseType.Name; + } } + + /// + /// Gets the proxy base type, eg GrainReference. + /// + public INamedTypeSymbol ProxyBaseType { get; } + + /// + /// Gets a value indicating whether this descriptor represents an extension. + /// + public bool IsExtension { get; } + + /// + /// Gets the components of the compound type alias used to refer to this proxy base. + /// + public ImmutableArray CompositeAliasComponents { get; } + + /// + /// Gets a string used to distinguish this proxy base from others in generated class names. + /// + public string GeneratedClassNameComponent { get; } + + public bool Equals(InvokableMethodProxyBaseId other) => SymbolEqualityComparer.Default.Equals(ProxyBaseType, other.ProxyBaseType) && IsExtension == other.IsExtension; + public override bool Equals(object obj) => obj is InvokableMethodProxyBaseId other && Equals(other); + public override int GetHashCode() => IsExtension.GetHashCode() * 17 ^ SymbolEqualityComparer.Default.GetHashCode(ProxyBaseType); + public override string ToString() => GeneratedClassNameComponent; } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/Model/MemberModel.cs b/src/Orleans.CodeGenerator/Model/MemberModel.cs new file mode 100644 index 00000000000..c3875a7797b --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/MemberModel.cs @@ -0,0 +1,119 @@ +namespace Orleans.CodeGenerator.Model; + +/// +/// Describes the kind of a serializable member. +/// +internal enum MemberKind : byte +{ + Field, + Property +} + +/// +/// Describes the accessibility strategy for getting/setting a member value during serialization. +/// +internal enum AccessStrategy : byte +{ + /// + /// The member can be accessed directly (public field or property with accessible getter/setter). + /// + Direct, + + /// + /// The member requires a generated delegate-based accessor (FieldAccessor utility). + /// + GeneratedAccessor, + + /// + /// The member requires an UnsafeAccessor-based accessor. + /// + UnsafeAccessor +} + +/// +/// Describes the strategy for constructing an instance of a serializable type during deserialization. +/// +internal enum ObjectCreationStrategy : byte +{ + /// + /// Use default(T) for value types. + /// + Default, + + /// + /// Use new T() — type has an accessible parameterless constructor. + /// + NewExpression, + + /// + /// Use RuntimeHelpers.GetUninitializedObject(typeof(T)). + /// + GetUninitializedObject +} + +/// +/// Describes a serializable field or property member in a . +/// Contains all data needed for serializer, copier, and activator generation without holding ISymbol references. +/// +internal sealed record class MemberModel +{ + public MemberModel( + uint fieldId, + string name, + TypeRef type, + TypeRef containingType, + string assemblyName, + string typeNameIdentifier, + bool isPrimaryConstructorParameter, + bool isSerializable, + bool isCopyable, + MemberKind kind, + AccessStrategy getterStrategy, + AccessStrategy setterStrategy, + bool isObsolete, + bool hasImmutableAttribute, + bool isShallowCopyable, + bool isValueType, + bool containingTypeIsValueType, + string? backingPropertyName) + { + FieldId = fieldId; + Name = name; + Type = type; + ContainingType = containingType; + AssemblyName = assemblyName; + TypeNameIdentifier = typeNameIdentifier; + IsPrimaryConstructorParameter = isPrimaryConstructorParameter; + IsSerializable = isSerializable; + IsCopyable = isCopyable; + Kind = kind; + GetterStrategy = getterStrategy; + SetterStrategy = setterStrategy; + IsObsolete = isObsolete; + HasImmutableAttribute = hasImmutableAttribute; + IsShallowCopyable = isShallowCopyable; + IsValueType = isValueType; + ContainingTypeIsValueType = containingTypeIsValueType; + BackingPropertyName = backingPropertyName; + } + + public uint FieldId { get; } + public string Name { get; } + public TypeRef Type { get; } + public TypeRef ContainingType { get; } + public string AssemblyName { get; } + public string TypeNameIdentifier { get; } + public bool IsPrimaryConstructorParameter { get; } + public bool IsSerializable { get; } + public bool IsCopyable { get; } + public MemberKind Kind { get; } + public AccessStrategy GetterStrategy { get; } + public AccessStrategy SetterStrategy { get; } + public bool IsObsolete { get; } + public bool HasImmutableAttribute { get; } + public bool IsShallowCopyable { get; } + public bool IsValueType { get; } + public bool ContainingTypeIsValueType { get; } + public string? BackingPropertyName { get; } + +} diff --git a/src/Orleans.CodeGenerator/Model/MetadataAggregateModel.cs b/src/Orleans.CodeGenerator/Model/MetadataAggregateModel.cs new file mode 100644 index 00000000000..5fe6b32ddd8 --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/MetadataAggregateModel.cs @@ -0,0 +1,24 @@ +namespace Orleans.CodeGenerator.Model; + +/// +/// Combined model aggregating all pipeline outputs for metadata class generation. +/// This is the input to the final RegisterSourceOutput that produces the +/// Metadata_{AssemblyName} class and assembly-level attributes. +/// +internal sealed record class MetadataAggregateModel( + string AssemblyName, + EquatableArray SerializableTypes, + EquatableArray ProxyInterfaces, + EquatableArray RegisteredCodecs, + ReferenceAssemblyModel ReferenceAssemblyData, + EquatableArray ActivatableTypes, + EquatableArray GeneratedProxyTypes, + EquatableArray InvokableInterfaces, + EquatableArray GeneratedInvokableActivatorMetadataNames, + EquatableArray InterfaceImplementations, + EquatableArray DefaultCopiers); + +/// +/// Describes a default copier mapping (type → copier type) for shallow-copyable types. +/// +internal readonly record struct DefaultCopierModel(TypeRef OriginalType, TypeRef CopierType); diff --git a/src/Orleans.CodeGenerator/Model/MetadataAggregateModelBuilder.cs b/src/Orleans.CodeGenerator/Model/MetadataAggregateModelBuilder.cs new file mode 100644 index 00000000000..f9a9871b6a8 --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/MetadataAggregateModelBuilder.cs @@ -0,0 +1,339 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Orleans.CodeGenerator.Model; + +namespace Orleans.CodeGenerator; + +internal static class MetadataAggregateModelBuilder +{ + /// + /// Creates a from the collected pipeline outputs. + /// This provides a single equality checkpoint so that downstream generation can be + /// skipped when no upstream pipeline has changed. + /// + internal static MetadataAggregateModel CreateMetadataAggregate( + string assemblyName, + ImmutableArray serializableTypes, + ImmutableArray proxyInterfaces, + ReferenceAssemblyModel refData) + => CreateMetadataAggregate( + assemblyName, + serializableTypes, + proxyInterfaces.IsDefaultOrEmpty + ? [] + : proxyInterfaces.Select(static proxy => new ProxyOutputModel(proxy, EquatableArray.Empty, useDeclaredInvokableFallback: false)).ToImmutableArray(), + refData); + + internal static MetadataAggregateModel CreateMetadataAggregate( + string assemblyName, + ImmutableArray serializableTypes, + ImmutableArray proxyOutputs, + ReferenceAssemblyModel refData) + { + var normalizedReferenceData = NormalizeReferenceAssemblyData(refData); + var normalizedSerializableTypes = MergeSerializableTypes(serializableTypes, normalizedReferenceData.ReferencedSerializableTypes); + var sourceProxyInterfaces = proxyOutputs.IsDefaultOrEmpty + ? [] + : proxyOutputs.Select(static output => output.ProxyInterface).ToImmutableArray(); + var normalizedProxyInterfaces = MergeProxyInterfaces(sourceProxyInterfaces, normalizedReferenceData.ReferencedProxyInterfaces); + var activatableTypes = GetActivatableTypes(normalizedSerializableTypes); + var generatedProxyTypes = GetGeneratedProxyTypes(normalizedProxyInterfaces); + var invokableInterfaces = GetInvokableInterfaces(normalizedProxyInterfaces); + var generatedInvokableActivatorMetadataNames = GetGeneratedInvokableActivatorMetadataNames(proxyOutputs); + var defaultCopiers = GetDefaultCopiers(normalizedSerializableTypes); + + return new MetadataAggregateModel( + AssemblyName: assemblyName, + SerializableTypes: normalizedSerializableTypes, + ProxyInterfaces: normalizedProxyInterfaces, + RegisteredCodecs: normalizedReferenceData.RegisteredCodecs, + ReferenceAssemblyData: normalizedReferenceData, + ActivatableTypes: activatableTypes, + GeneratedProxyTypes: generatedProxyTypes, + InvokableInterfaces: invokableInterfaces, + GeneratedInvokableActivatorMetadataNames: generatedInvokableActivatorMetadataNames, + InterfaceImplementations: normalizedReferenceData.InterfaceImplementations, + DefaultCopiers: defaultCopiers); + } + + private static ReferenceAssemblyModel NormalizeReferenceAssemblyData(ReferenceAssemblyModel referenceData) + { + var applicationParts = referenceData.ApplicationParts + .Distinct() + .ToImmutableArray(); + + var wellKnownTypeIds = referenceData.WellKnownTypeIds + .Distinct() + .OrderBy(static entry => entry.Type.SyntaxString, StringComparer.Ordinal) + .ThenBy(static entry => entry.Id) + .ToImmutableArray(); + + var typeAliases = referenceData.TypeAliases + .Distinct() + .OrderBy(static entry => entry.Type.SyntaxString, StringComparer.Ordinal) + .ThenBy(static entry => entry.Alias, StringComparer.Ordinal) + .ToImmutableArray(); + + var compoundTypeAliases = referenceData.CompoundTypeAliases + .Distinct() + .OrderBy(static entry => ReferenceAssemblyModelExtractor.GetCompoundTypeAliasOrderKey(entry), StringComparer.Ordinal) + .ThenBy(static entry => entry.TargetType.SyntaxString, StringComparer.Ordinal) + .ToImmutableArray(); + + var referencedSerializableTypes = DeduplicateSerializableTypes(referenceData.ReferencedSerializableTypes); + + var referencedProxyInterfaces = DeduplicateProxyInterfaces(referenceData.ReferencedProxyInterfaces); + + var registeredCodecs = referenceData.RegisteredCodecs + .Distinct() + .OrderBy(static entry => entry.Type.SyntaxString, StringComparer.Ordinal) + .ThenBy(static entry => entry.Kind) + .ToImmutableArray(); + + var interfaceImplementations = referenceData.InterfaceImplementations + .Distinct() + .OrderBy(static entry => entry.ImplementationType.SyntaxString, StringComparer.Ordinal) + .ToImmutableArray(); + + return new ReferenceAssemblyModel( + AssemblyName: referenceData.AssemblyName ?? string.Empty, + ApplicationParts: applicationParts, + WellKnownTypeIds: wellKnownTypeIds, + TypeAliases: typeAliases, + CompoundTypeAliases: compoundTypeAliases, + ReferencedSerializableTypes: referencedSerializableTypes, + ReferencedProxyInterfaces: referencedProxyInterfaces, + RegisteredCodecs: registeredCodecs, + InterfaceImplementations: interfaceImplementations); + } + + internal static ImmutableArray MergeSerializableTypes( + ImmutableArray source, + ImmutableArray referenced) + { + var merged = source.IsDefault ? [] : source; + if (!referenced.IsDefaultOrEmpty) + { + merged = merged.AddRange(referenced); + } + + return DeduplicateSerializableTypes(merged); + } + + internal static ImmutableArray MergeProxyInterfaces( + ImmutableArray source, + ImmutableArray referenced) + { + var merged = source.IsDefault ? [] : source; + if (!referenced.IsDefaultOrEmpty) + { + merged = merged.AddRange(referenced); + } + + return DeduplicateProxyInterfaces(merged); + } + + internal static ImmutableArray DeduplicateSerializableTypes( + ImmutableArray entries) + { + if (entries.IsDefaultOrEmpty) + { + return []; + } + + var selected = new Dictionary(StringComparer.Ordinal); + foreach (var entry in entries + .Where(static entry => entry is not null) + .OrderBy(static entry => entry.SourceLocation.SourceOrderGroup) + .ThenBy(static entry => entry.SourceLocation.FilePath, StringComparer.Ordinal) + .ThenBy(static entry => entry.SourceLocation.Position) + .ThenBy(static entry => entry.TypeSyntax.SyntaxString, StringComparer.Ordinal) + .ThenBy(static entry => entry.MetadataIdentity.MetadataName, StringComparer.Ordinal) + .ThenBy(static entry => entry.MetadataIdentity.AssemblyIdentity, StringComparer.Ordinal) + .ThenBy(static entry => entry.MetadataIdentity.AssemblyName, StringComparer.Ordinal) + .ThenBy(static entry => entry.GeneratedNamespace, StringComparer.Ordinal) + .ThenBy(static entry => entry.Name, StringComparer.Ordinal)) + { + var key = CreateTypeDedupeKey( + entry.MetadataIdentity, + entry.TypeSyntax.SyntaxString, + entry.GeneratedNamespace, + entry.Name); + if (!selected.ContainsKey(key)) + { + selected.Add(key, entry); + } + } + + return [.. OrderSerializableTypeModels(selected.Values)]; + } + + internal static ImmutableArray DeduplicateProxyInterfaces( + ImmutableArray entries) + { + if (entries.IsDefaultOrEmpty) + { + return []; + } + + var selected = new Dictionary(StringComparer.Ordinal); + foreach (var entry in entries + .Where(static entry => entry is not null) + .OrderBy(static entry => entry.SourceLocation.SourceOrderGroup) + .ThenBy(static entry => entry.SourceLocation.FilePath, StringComparer.Ordinal) + .ThenBy(static entry => entry.SourceLocation.Position) + .ThenBy(static entry => entry.InterfaceType.SyntaxString, StringComparer.Ordinal) + .ThenBy(static entry => entry.MetadataIdentity.MetadataName, StringComparer.Ordinal) + .ThenBy(static entry => entry.MetadataIdentity.AssemblyIdentity, StringComparer.Ordinal) + .ThenBy(static entry => entry.MetadataIdentity.AssemblyName, StringComparer.Ordinal) + .ThenBy(static entry => entry.GeneratedNamespace, StringComparer.Ordinal) + .ThenBy(static entry => entry.Name, StringComparer.Ordinal)) + { + var key = CreateTypeDedupeKey( + entry.MetadataIdentity, + entry.InterfaceType.SyntaxString, + entry.GeneratedNamespace, + entry.Name); + if (!selected.ContainsKey(key)) + { + selected.Add(key, entry); + } + } + + return [.. OrderProxyInterfaceModels(selected.Values)]; + } + + internal static ImmutableArray NormalizeSerializableTypeModels( + ImmutableArray entries) + { + if (entries.IsDefaultOrEmpty) + { + return []; + } + + return [.. OrderSerializableTypeModels(entries.Where(static entry => entry is not null))]; + } + + internal static ImmutableArray NormalizeProxyInterfaceModels( + ImmutableArray entries) + { + if (entries.IsDefaultOrEmpty) + { + return []; + } + + return [.. OrderProxyInterfaceModels(entries.Where(static entry => entry is not null))]; + } + + private static string CreateTypeDedupeKey( + TypeMetadataIdentity metadataIdentity, + string typeSyntax, + string generatedNamespace, + string name) + { + if (!metadataIdentity.IsEmpty) + { + return string.Join( + "|", + "M", + metadataIdentity.AssemblyIdentity ?? string.Empty, + metadataIdentity.AssemblyName ?? string.Empty, + metadataIdentity.MetadataName ?? string.Empty); + } + + return string.Join( + "|", + "S", + typeSyntax ?? string.Empty, + generatedNamespace ?? string.Empty, + name ?? string.Empty); + } + + internal static IOrderedEnumerable OrderSerializableTypeModels(IEnumerable entries) + => entries + .OrderBy(static entry => entry.MetadataIdentity.MetadataName, StringComparer.Ordinal) + .ThenBy(static entry => entry.MetadataIdentity.AssemblyIdentity, StringComparer.Ordinal) + .ThenBy(static entry => entry.MetadataIdentity.AssemblyName, StringComparer.Ordinal) + .ThenBy(static entry => entry.TypeSyntax.SyntaxString, StringComparer.Ordinal) + .ThenBy(static entry => entry.SourceLocation.SourceOrderGroup) + .ThenBy(static entry => entry.SourceLocation.FilePath, StringComparer.Ordinal) + .ThenBy(static entry => entry.SourceLocation.Position) + .ThenBy(static entry => entry.GeneratedNamespace, StringComparer.Ordinal) + .ThenBy(static entry => entry.Name, StringComparer.Ordinal); + + internal static IOrderedEnumerable OrderProxyInterfaceModels(IEnumerable entries) + => entries + .OrderBy(static entry => entry.MetadataIdentity.MetadataName, StringComparer.Ordinal) + .ThenBy(static entry => entry.MetadataIdentity.AssemblyIdentity, StringComparer.Ordinal) + .ThenBy(static entry => entry.MetadataIdentity.AssemblyName, StringComparer.Ordinal) + .ThenBy(static entry => entry.InterfaceType.SyntaxString, StringComparer.Ordinal) + .ThenBy(static entry => entry.SourceLocation.SourceOrderGroup) + .ThenBy(static entry => entry.SourceLocation.FilePath, StringComparer.Ordinal) + .ThenBy(static entry => entry.SourceLocation.Position) + .ThenBy(static entry => entry.GeneratedNamespace, StringComparer.Ordinal) + .ThenBy(static entry => entry.Name, StringComparer.Ordinal); + + private static ImmutableArray GetActivatableTypes(ImmutableArray serializableTypes) + => [.. serializableTypes + .Where(static type => ShouldGenerateActivator(type)) + .Select(static type => type.TypeSyntax) + .Distinct() + .OrderBy(static type => type.SyntaxString, StringComparer.Ordinal)]; + + private static ImmutableArray GetGeneratedProxyTypes(ImmutableArray proxyInterfaces) + => [.. proxyInterfaces + .Select(static proxy => CreateGeneratedTypeRef( + proxy.GeneratedNamespace, + ProxyGenerator.GetSimpleClassName(proxy.Name), + proxy.TypeParameters.Length)) + .Distinct() + .OrderBy(static type => type.SyntaxString, StringComparer.Ordinal)]; + + private static ImmutableArray GetInvokableInterfaces(ImmutableArray proxyInterfaces) + => [.. proxyInterfaces + .Select(static proxy => proxy.InterfaceType) + .Distinct() + .OrderBy(static type => type.SyntaxString, StringComparer.Ordinal)]; + + private static ImmutableArray GetGeneratedInvokableActivatorMetadataNames(ImmutableArray proxyOutputs) + { + if (proxyOutputs.IsDefaultOrEmpty) + { + return []; + } + + return [.. proxyOutputs + .SelectMany(static output => output.OwnedInvokableActivatorMetadataNames) + .Distinct(StringComparer.Ordinal) + .OrderBy(static metadataName => metadataName, StringComparer.Ordinal)]; + } + + private static ImmutableArray GetDefaultCopiers(ImmutableArray serializableTypes) + => [.. serializableTypes + .Where(static type => type.IsShallowCopyable && !type.IsGenericType) + .Select(static type => new DefaultCopierModel( + type.TypeSyntax, + new TypeRef($"global::Orleans.Serialization.Cloning.ShallowCopier<{type.TypeSyntax.SyntaxString}>"))) + .Distinct() + .OrderBy(static entry => entry.OriginalType.SyntaxString, StringComparer.Ordinal)]; + + private static bool ShouldGenerateActivator(SerializableTypeModel type) + { + return !type.IsAbstractType + && !type.IsEnumType + && (!type.IsValueType && type.IsEmptyConstructable && !type.UseActivator || type.HasActivatorConstructor); + } + + private static TypeRef CreateGeneratedTypeRef(string generatedNamespace, string simpleName, int genericArity) + { + var qualifiedName = string.IsNullOrWhiteSpace(generatedNamespace) + ? simpleName + : $"{generatedNamespace}.{simpleName}"; + + return genericArity > 0 + ? new TypeRef($"{qualifiedName}<{new string(',', genericArity - 1)}>") + : new TypeRef(qualifiedName); + } +} + + diff --git a/src/Orleans.CodeGenerator/Model/MetadataModel.cs b/src/Orleans.CodeGenerator/Model/MetadataModel.cs index 133d8b21d63..0c79c6782db 100644 --- a/src/Orleans.CodeGenerator/Model/MetadataModel.cs +++ b/src/Orleans.CodeGenerator/Model/MetadataModel.cs @@ -1,211 +1,70 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; -using System; -using System.Collections.Generic; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal class MetadataModel +{ + public Dictionary InvokableInterfaces { get; } = new(SymbolEqualityComparer.Default); + public Dictionary GeneratedInvokables { get; } = new(); + public Dictionary DefaultCopiers { get; } = new(); + internal Dictionary> ProxyBaseTypeInvokableBaseTypes { get; } = new (SymbolEqualityComparer.Default); +} +internal readonly struct CompoundTypeAliasComponent : IEquatable { -#nullable disable - internal class MetadataModel + private readonly Either _value; + public CompoundTypeAliasComponent(string value) => _value = new Either(value); + public CompoundTypeAliasComponent(ITypeSymbol value) => _value = new Either(value); + public static CompoundTypeAliasComponent Default => new(); + public bool IsDefault => _value.RawValue is null; + public bool IsString => _value.IsLeft; + public string? StringValue => _value.LeftValue; + public bool IsType => _value.IsRight; + public ITypeSymbol? TypeValue => _value.RightValue; + public object? Value => _value.RawValue; + + public bool Equals(CompoundTypeAliasComponent other) => (Value, other.Value) switch { - public List SerializableTypes { get; } = new(1024); - public Dictionary InvokableInterfaces { get; } = new(SymbolEqualityComparer.Default); - public List InvokableInterfaceImplementations { get; } = new(1024); - public Dictionary GeneratedInvokables { get; } = new(); - public List GeneratedProxies { get; } = new(1024); - public List ActivatableTypes { get; } = new(1024); - public List DetectedSerializers { get; } = new(); - public List DetectedActivators { get; } = new(); - public Dictionary DefaultCopiers { get; } = new(); - public List DetectedCopiers { get; } = new(); - public List DetectedConverters { get; } = new(); - public List<(TypeSyntax Type, string Alias)> TypeAliases { get; } = new(1024); - public CompoundTypeAliasTree CompoundTypeAliases { get; } = CompoundTypeAliasTree.Create(); - public List<(TypeSyntax Type, uint Id)> WellKnownTypeIds { get; } = new(1024); - public HashSet ApplicationParts { get; } = new(); - internal Dictionary> ProxyBaseTypeInvokableBaseTypes { get; } = new (SymbolEqualityComparer.Default); - } - - /// - /// Represents a compound type aliases as a prefix tree. - /// - internal sealed class CompoundTypeAliasTree + (null, null) => true, + (string stringValue, string otherStringValue) => string.Equals(stringValue, otherStringValue), + (ITypeSymbol typeValue, ITypeSymbol otherTypeValue) => SymbolEqualityComparer.Default.Equals(typeValue, otherTypeValue), + _ => false, + }; + public override bool Equals(object? obj) => obj is CompoundTypeAliasComponent other && Equals(other); + public override int GetHashCode() => _value.RawValue switch { - private Dictionary _children; - - /// - /// Initializes a new instance of the class. - /// - private CompoundTypeAliasTree(CompoundTypeAliasComponent key, TypeSyntax value) - { - Key = key; - Value = value; - } - - /// - /// Gets the key for this node. - /// - public CompoundTypeAliasComponent Key { get; } - - /// - /// Gets the value for this node. - /// - public TypeSyntax Value { get; private set; } - - /// - /// Creates a new tree with a root node which has no key or value. - /// - public static CompoundTypeAliasTree Create() => new(default, default); - - public Dictionary Children => _children; - - internal CompoundTypeAliasTree GetChildOrDefault(object key) - { - TryGetChild(key, out var result); - return result; - } - - internal bool TryGetChild(object key, out CompoundTypeAliasTree result) - { - if (_children is { } children) - { - return children.TryGetValue(key, out result); - } - - result = default; - return false; - } - - public void Add(CompoundTypeAliasComponent[] key, TypeSyntax value) - { - Add(key.AsSpan(), value); - } - - public void Add(ReadOnlySpan keys, TypeSyntax value) - { - if (keys.Length == 0) - { - throw new InvalidOperationException("No valid key specified."); - } - - var key = keys[0]; - if (keys.Length == 1) - { - AddInternal(key, value); - } - else - { - var childNode = GetChildOrDefault(key) ?? AddInternal(key); - childNode.Add(keys.Slice(1), value); - } - } + string stringValue => stringValue.GetHashCode(), + ITypeSymbol type => SymbolEqualityComparer.Default.GetHashCode(type), + _ => throw new InvalidOperationException($"Unsupported type {_value.RawValue}") + }; - /// - /// Adds a node to the tree. - /// - /// The key for the new node. - public CompoundTypeAliasTree Add(ITypeSymbol key) => AddInternal(new CompoundTypeAliasComponent(key)); - - /// - /// Adds a node to the tree. - /// - /// The key for the new node. - public CompoundTypeAliasTree Add(string key) => AddInternal(new CompoundTypeAliasComponent(key)); - - /// - /// Adds a node to the tree. - /// - /// The key for the new node. - /// The value for the new node. - public CompoundTypeAliasTree Add(string key, TypeSyntax value) => AddInternal(new CompoundTypeAliasComponent(key), value); - - /// - /// Adds a node to the tree. - /// - /// The key for the new node. - /// The value for the new node. - public CompoundTypeAliasTree Add(ITypeSymbol key, TypeSyntax value) => AddInternal(new CompoundTypeAliasComponent(key), value); - - private CompoundTypeAliasTree AddInternal(CompoundTypeAliasComponent key) => AddInternal(key, default); - private CompoundTypeAliasTree AddInternal(CompoundTypeAliasComponent key, TypeSyntax value) - { - _children ??= new(); - - if (_children.TryGetValue(key, out var existing)) - { - if (value is not null && existing.Value is not null) - { - throw new ArgumentException($"A key with the value '{key}' already exists. Existing value: '{existing.Value}', new value: '{value}'"); - } - - existing.Value = value; - return existing; - } - else - { - return _children[key] = new CompoundTypeAliasTree(key, value); - } - } - } - - internal readonly struct CompoundTypeAliasComponent : IEquatable + internal readonly struct EqualityComparer : IEqualityComparer { - private readonly Either _value; - public CompoundTypeAliasComponent(string value) => _value = new Either(value); - public CompoundTypeAliasComponent(ITypeSymbol value) => _value = new Either(value); - public static CompoundTypeAliasComponent Default => new(); - public bool IsDefault => _value.RawValue is null; - public bool IsString => _value.IsLeft; - public string StringValue => _value.LeftValue; - public bool IsType => _value.IsRight; - public ITypeSymbol TypeValue => _value.RightValue; - public object Value => _value.RawValue; - - public bool Equals(CompoundTypeAliasComponent other) => (Value, other.Value) switch - { - (null, null) => true, - (string stringValue, string otherStringValue) => string.Equals(stringValue, otherStringValue), - (ITypeSymbol typeValue, ITypeSymbol otherTypeValue) => SymbolEqualityComparer.Default.Equals(typeValue, otherTypeValue), - _ => false, - }; - public override bool Equals(object obj) => obj is CompoundTypeAliasComponent other && Equals(other); - public override int GetHashCode() => _value.RawValue switch - { - string stringValue => stringValue.GetHashCode(), - ITypeSymbol type => SymbolEqualityComparer.Default.GetHashCode(type), - _ => throw new InvalidOperationException($"Unsupported type {_value.RawValue}") - }; + public static EqualityComparer Default => default; + public bool Equals(CompoundTypeAliasComponent x, CompoundTypeAliasComponent y) => x.Equals(y); + public int GetHashCode(CompoundTypeAliasComponent obj) => obj.GetHashCode(); + } - internal readonly struct EqualityComparer : IEqualityComparer - { - public static EqualityComparer Default => default; - public bool Equals(CompoundTypeAliasComponent x, CompoundTypeAliasComponent y) => x.Equals(y); - public int GetHashCode(CompoundTypeAliasComponent obj) => obj.GetHashCode(); - } + public override string ToString() => _value.RawValue?.ToString() ?? string.Empty; +} - public override string ToString() => _value.RawValue?.ToString(); +internal readonly struct Either where T : class where U : class +{ + public Either(T value) + { + RawValue = value; + IsLeft = true; } - internal readonly struct Either where T : class where U : class + public Either(U value) { - private readonly bool _isLeft; - private readonly object _value; - public Either(T value) - { - _value = value; - _isLeft = true; - } - - public Either(U value) - { - _value = value; - _isLeft = false; - } - - public bool IsLeft => _isLeft; - public bool IsRight => !IsLeft; - public T LeftValue => (T)_value; - public U RightValue => (U)_value; - public object RawValue => _value; + RawValue = value; + IsLeft = false; } + + public bool IsLeft { get; } + public bool IsRight => !IsLeft; + public T? LeftValue => (T?)RawValue; + public U? RightValue => (U?)RawValue; + public object? RawValue { get; } } diff --git a/src/Orleans.CodeGenerator/Model/MethodModel.cs b/src/Orleans.CodeGenerator/Model/MethodModel.cs new file mode 100644 index 00000000000..46b9dc7f203 --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/MethodModel.cs @@ -0,0 +1,33 @@ +namespace Orleans.CodeGenerator.Model; + +/// +/// Describes a method parameter for invokable/proxy generation. +/// +internal readonly record struct MethodParameterModel(string Name, TypeRef Type, int Ordinal, bool IsCancellationToken); + +/// +/// Describes a method on a proxy interface for invokable generation. +/// +internal sealed record class MethodModel( + string Name, + TypeRef ReturnType, + EquatableArray Parameters, + EquatableArray TypeParameters, + TypeRef ContainingInterfaceType, + TypeRef OriginalContainingInterfaceType, + string ContainingInterfaceName, + string ContainingInterfaceGeneratedNamespace, + int ContainingInterfaceTypeParameterCount, + string GeneratedMethodId, + string MethodId, + long? ResponseTimeoutTicks, + EquatableArray CustomInitializerMethods, + bool IsCancellable) +{ + public bool HasAlias => !string.Equals(MethodId, GeneratedMethodId, StringComparison.Ordinal); +} + +/// +/// Describes a custom initializer method associated with an invokable method's attribute. +/// +internal readonly record struct CustomInitializerModel(string MethodName, string ArgumentValue); diff --git a/src/Orleans.CodeGenerator/Model/MethodSignatureComparer.cs b/src/Orleans.CodeGenerator/Model/MethodSignatureComparer.cs index 61592e54c38..18fc6079a6b 100644 --- a/src/Orleans.CodeGenerator/Model/MethodSignatureComparer.cs +++ b/src/Orleans.CodeGenerator/Model/MethodSignatureComparer.cs @@ -1,114 +1,111 @@ -using Microsoft.CodeAnalysis; -using System; -using System.Collections.Generic; +using Microsoft.CodeAnalysis; -namespace Orleans.CodeGenerator -{ - internal sealed class MethodSignatureComparer : IEqualityComparer, IComparer +namespace Orleans.CodeGenerator; + +internal sealed class MethodSignatureComparer : IEqualityComparer, IComparer + { + public static MethodSignatureComparer Default { get; } = new(); + + private MethodSignatureComparer() { - public static MethodSignatureComparer Default { get; } = new(); + } - private MethodSignatureComparer() + public bool Equals(IMethodSymbol x, IMethodSymbol y) + { + if (!string.Equals(x.Name, y.Name, StringComparison.Ordinal)) { + return false; } - public bool Equals(IMethodSymbol x, IMethodSymbol y) + if (x.TypeArguments.Length != y.TypeArguments.Length) { - if (!string.Equals(x.Name, y.Name, StringComparison.Ordinal)) - { - return false; - } + return false; + } - if (x.TypeArguments.Length != y.TypeArguments.Length) + for (var i = 0; i < x.TypeArguments.Length; i++) + { + if (!SymbolEqualityComparer.Default.Equals(x.TypeArguments[i], y.TypeArguments[i])) { return false; } + } - for (var i = 0; i < x.TypeArguments.Length; i++) - { - if (!SymbolEqualityComparer.Default.Equals(x.TypeArguments[i], y.TypeArguments[i])) - { - return false; - } - } + if (x.Parameters.Length != y.Parameters.Length) + { + return false; + } - if (x.Parameters.Length != y.Parameters.Length) + for (var i = 0; i < x.Parameters.Length; i++) + { + if (!SymbolEqualityComparer.Default.Equals(x.Parameters[i].Type, y.Parameters[i].Type)) { return false; } + } - for (var i = 0; i < x.Parameters.Length; i++) - { - if (!SymbolEqualityComparer.Default.Equals(x.Parameters[i].Type, y.Parameters[i].Type)) - { - return false; - } - } + return true; + } - return true; - } + public int GetHashCode(IMethodSymbol obj) + { + int hashCode = -499943048; + hashCode = hashCode * -1521134295 + StringComparer.Ordinal.GetHashCode(obj.Name); - public int GetHashCode(IMethodSymbol obj) + foreach (var arg in obj.TypeArguments) { - int hashCode = -499943048; - hashCode = hashCode * -1521134295 + StringComparer.Ordinal.GetHashCode(obj.Name); + hashCode = hashCode * -1521134295 + SymbolEqualityComparer.Default.GetHashCode(arg); + } - foreach (var arg in obj.TypeArguments) - { - hashCode = hashCode * -1521134295 + SymbolEqualityComparer.Default.GetHashCode(arg); - } + foreach (var parameter in obj.Parameters) + { + hashCode = hashCode * -1521134295 + SymbolEqualityComparer.Default.GetHashCode(parameter.Type); + } - foreach (var parameter in obj.Parameters) - { - hashCode = hashCode * -1521134295 + SymbolEqualityComparer.Default.GetHashCode(parameter.Type); - } + return hashCode; + } - return hashCode; + public int Compare(IMethodSymbol x, IMethodSymbol y) + { + var result = StringComparer.Ordinal.Compare(x.Name, y.Name); + if (result != 0) + { + return result; } - public int Compare(IMethodSymbol x, IMethodSymbol y) + result = x.TypeArguments.Length.CompareTo(y.TypeArguments.Length); + if (result != 0) { - var result = StringComparer.Ordinal.Compare(x.Name, y.Name); - if (result != 0) - { - return result; - } + return result; + } - result = x.TypeArguments.Length.CompareTo(y.TypeArguments.Length); + for (var i = 0; i < x.TypeArguments.Length; i++) + { + var xh = SymbolEqualityComparer.Default.GetHashCode(x.TypeArguments[i]); + var yh = SymbolEqualityComparer.Default.GetHashCode(y.TypeArguments[i]); + result = xh.CompareTo(yh); if (result != 0) { return result; } + } - for (var i = 0; i < x.TypeArguments.Length; i++) - { - var xh = SymbolEqualityComparer.Default.GetHashCode(x.TypeArguments[i]); - var yh = SymbolEqualityComparer.Default.GetHashCode(y.TypeArguments[i]); - result = xh.CompareTo(yh); - if (result != 0) - { - return result; - } - } + result = x.Parameters.Length.CompareTo(y.Parameters.Length); + if (result != 0) + { + return result; + } - result = x.Parameters.Length.CompareTo(y.Parameters.Length); + for (var i = 0; i < x.Parameters.Length; i++) + { + var xh = SymbolEqualityComparer.Default.GetHashCode(x.Parameters[i].Type); + var yh = SymbolEqualityComparer.Default.GetHashCode(y.Parameters[i].Type); + result = xh.CompareTo(yh); if (result != 0) { return result; } - - for (var i = 0; i < x.Parameters.Length; i++) - { - var xh = SymbolEqualityComparer.Default.GetHashCode(x.Parameters[i].Type); - var yh = SymbolEqualityComparer.Default.GetHashCode(y.Parameters[i].Type); - result = xh.CompareTo(yh); - if (result != 0) - { - return result; - } - } - - return 0; } + + return 0; } -} \ No newline at end of file + } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/Model/ModelExtractor.cs b/src/Orleans.CodeGenerator/Model/ModelExtractor.cs new file mode 100644 index 00000000000..b5951c2d8e5 --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/ModelExtractor.cs @@ -0,0 +1,100 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Orleans.CodeGenerator.Model; + +namespace Orleans.CodeGenerator; + +/// +/// Extracts and other incremental pipeline models +/// from Roslyn symbols, producing value-type representations suitable for pipeline caching. +/// +internal static class ModelExtractor +{ + public static SerializableTypeModel ExtractSerializableTypeModel( + ISerializableTypeDescription description, + SourceLocationModel sourceLocation = default) + => SerializableTypeModelExtractor.ExtractSerializableTypeModel(description, sourceLocation); + + public static ReferenceAssemblyModel ExtractReferenceAssemblyData( + Compilation compilation, + CodeGeneratorOptions options, + CancellationToken cancellationToken) + => ReferenceAssemblyModelExtractor.ExtractReferenceAssemblyData(compilation, options, cancellationToken); + + internal static ReferenceAssemblyModel ExtractReferenceAssemblyData( + Compilation compilation, + CodeGeneratorOptions options, + CancellationToken cancellationToken, + out ImmutableArray diagnostics) + => ReferenceAssemblyModelExtractor.ExtractReferenceAssemblyData(compilation, options, cancellationToken, out diagnostics); + + public static MetadataAggregateModel CreateMetadataAggregate( + string assemblyName, + ImmutableArray serializableTypes, + ImmutableArray proxyInterfaces, + ReferenceAssemblyModel refData) + => MetadataAggregateModelBuilder.CreateMetadataAggregate(assemblyName, serializableTypes, proxyInterfaces, refData); + + public static MetadataAggregateModel CreateMetadataAggregate( + string assemblyName, + ImmutableArray serializableTypes, + ImmutableArray proxyOutputs, + ReferenceAssemblyModel refData) + => MetadataAggregateModelBuilder.CreateMetadataAggregate(assemblyName, serializableTypes, proxyOutputs, refData); + + internal static ImmutableArray MergeSerializableTypes( + ImmutableArray source, + ImmutableArray referenced) + => MetadataAggregateModelBuilder.MergeSerializableTypes(source, referenced); + + internal static ImmutableArray MergeProxyInterfaces( + ImmutableArray source, + ImmutableArray referenced) + => MetadataAggregateModelBuilder.MergeProxyInterfaces(source, referenced); + + internal static ImmutableArray DeduplicateSerializableTypes( + ImmutableArray entries) + => MetadataAggregateModelBuilder.DeduplicateSerializableTypes(entries); + + internal static ImmutableArray DeduplicateProxyInterfaces( + ImmutableArray entries) + => MetadataAggregateModelBuilder.DeduplicateProxyInterfaces(entries); + + internal static ImmutableArray NormalizeSerializableTypeModels( + ImmutableArray entries) + => MetadataAggregateModelBuilder.NormalizeSerializableTypeModels(entries); + + internal static ImmutableArray NormalizeProxyInterfaceModels( + ImmutableArray entries) + => MetadataAggregateModelBuilder.NormalizeProxyInterfaceModels(entries); + + public static RegisteredCodecModel ExtractRegisteredCodec(INamedTypeSymbol symbol, RegisteredCodecKind kind) + => ReferenceAssemblyModelExtractor.ExtractRegisteredCodec(symbol, kind); + + internal static SerializableTypeModel? TryExtractSerializableTypeModel( + INamedTypeSymbol typeSymbol, + Compilation compilation, + LibraryTypes libraryTypes, + CodeGeneratorOptions options, + bool throwOnFailure = false) + => SerializableTypeModelExtractor.TryExtractSerializableTypeModel(typeSymbol, compilation, libraryTypes, options, throwOnFailure); + + public static ProxyInterfaceModel? ExtractProxyInterfaceFromAttributeContext( + GeneratorAttributeSyntaxContext context, + CancellationToken cancellationToken) + => ProxyInterfaceModelExtractor.ExtractProxyInterfaceFromAttributeContext(context, cancellationToken); + + public static ProxyInterfaceModel? ExtractProxyInterfaceModel( + INamedTypeSymbol typeSymbol, + Compilation compilation, + CancellationToken cancellationToken) + => ProxyInterfaceModelExtractor.ExtractProxyInterfaceModel(typeSymbol, compilation, cancellationToken); + + public static ProxyInterfaceModel? ExtractInheritedProxyInterfaceFromSyntaxContext( + GeneratorSyntaxContext context, + CancellationToken cancellationToken) + => ProxyInterfaceModelExtractor.ExtractInheritedProxyInterfaceFromSyntaxContext(context, cancellationToken); + + internal static SourceLocationModel GetSourceLocation(ISymbol? symbol) + => SymbolSourceLocationExtractor.GetSourceLocation(symbol); +} diff --git a/src/Orleans.CodeGenerator/Model/PropertyDescription.cs b/src/Orleans.CodeGenerator/Model/PropertyDescription.cs index 3ab0d55c7de..ffe2ab783d4 100644 --- a/src/Orleans.CodeGenerator/Model/PropertyDescription.cs +++ b/src/Orleans.CodeGenerator/Model/PropertyDescription.cs @@ -4,45 +4,44 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal interface IPropertyDescription : IMemberDescription { - internal interface IPropertyDescription : IMemberDescription - { - } +} - internal class PropertyDescription : IPropertyDescription +internal class PropertyDescription : IPropertyDescription +{ + public PropertyDescription(uint fieldId, bool isPrimaryConstructorParameter, IPropertySymbol property) { - public PropertyDescription(uint fieldId, bool isPrimaryConstructorParameter, IPropertySymbol property) + FieldId = fieldId; + IsPrimaryConstructorParameter = isPrimaryConstructorParameter; + Property = property; + + if (Type.TypeKind == TypeKind.Dynamic) { - FieldId = fieldId; - IsPrimaryConstructorParameter = isPrimaryConstructorParameter; - Property = property; - - if (Type.TypeKind == TypeKind.Dynamic) - { - TypeSyntax = PredefinedType(Token(SyntaxKind.ObjectKeyword)); - } - else - { - TypeSyntax = Type.ToTypeSyntax(); - } + TypeSyntax = PredefinedType(Token(SyntaxKind.ObjectKeyword)); } + else + { + TypeSyntax = Type.ToTypeSyntax(); + } + } - public uint FieldId { get; } - public ISymbol Symbol => Property; - public ITypeSymbol Type => Property.Type; - public INamedTypeSymbol ContainingType => Property.ContainingType; - public IPropertySymbol Property { get; } + public uint FieldId { get; } + public ISymbol Symbol => Property; + public ITypeSymbol Type => Property.Type; + public INamedTypeSymbol ContainingType => Property.ContainingType; + public IPropertySymbol Property { get; } - public TypeSyntax TypeSyntax { get; } + public TypeSyntax TypeSyntax { get; } - public string AssemblyName => Type.ContainingAssembly.ToDisplayName(); - public string TypeName => Type.ToDisplayName(); - public string TypeNameIdentifier => Type.GetValidIdentifier(); - public bool IsPrimaryConstructorParameter { get; set; } - public bool IsSerializable => true; - public bool IsCopyable => true; + public string AssemblyName => Type.ContainingAssembly.ToDisplayName(); + public string TypeName => Type.ToDisplayName(); + public string TypeNameIdentifier => Type.GetValidIdentifier(); + public bool IsPrimaryConstructorParameter { get; set; } + public bool IsSerializable => true; + public bool IsCopyable => true; - public TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol) => typeSymbol.ToTypeSyntax(); - } + public TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol) => typeSymbol.ToTypeSyntax(); } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/Model/ProxyInterfaceDescription.cs b/src/Orleans.CodeGenerator/Model/ProxyInterfaceDescription.cs index ba9c3264f4b..e959cd420ae 100644 --- a/src/Orleans.CodeGenerator/Model/ProxyInterfaceDescription.cs +++ b/src/Orleans.CodeGenerator/Model/ProxyInterfaceDescription.cs @@ -1,279 +1,272 @@ using Orleans.CodeGenerator.SyntaxGeneration; using Microsoft.CodeAnalysis; -using System; -using System.Collections.Generic; using Orleans.CodeGenerator.Diagnostics; -using System.Linq; using System.Diagnostics; -#nullable disable -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +[DebuggerDisplay("{InterfaceType} (proxy base {ProxyBaseType})")] +internal class ProxyInterfaceDescription : IEquatable { - [DebuggerDisplay("{InterfaceType} (proxy base {ProxyBaseType})")] - internal class ProxyInterfaceDescription : IEquatable + private static readonly char[] FilteredNameChars = ['`', '.']; + + public ProxyInterfaceDescription( + ProxyGenerationContext generationContext, + INamedTypeSymbol proxyBaseType, + INamedTypeSymbol interfaceType) { - private static readonly char[] FilteredNameChars = new char[] { '`', '.' }; - private List _methods; + ValidateBaseClass(generationContext.LibraryTypes, proxyBaseType); - public ProxyInterfaceDescription( - CodeGenerator codeGenerator, - INamedTypeSymbol proxyBaseType, - INamedTypeSymbol interfaceType) + var prop = interfaceType.GetAllMembers().FirstOrDefault(); + if (prop is { }) { - ValidateBaseClass(codeGenerator.LibraryTypes, proxyBaseType); - - var prop = interfaceType.GetAllMembers().FirstOrDefault(); - if (prop is { }) - { - throw new OrleansGeneratorDiagnosticAnalysisException(RpcInterfacePropertyDiagnostic.CreateDiagnostic(interfaceType, prop)); - } + throw new OrleansGeneratorDiagnosticAnalysisException(RpcInterfacePropertyDiagnostic.CreateDiagnostic(interfaceType, prop)); + } - CodeGenerator = codeGenerator; - InterfaceType = interfaceType; - Name = codeGenerator.GetAlias(interfaceType) ?? interfaceType.Name; - ProxyBaseType = proxyBaseType; + GenerationContext = generationContext; + InterfaceType = interfaceType; + Name = generationContext.GetAlias(interfaceType) ?? interfaceType.Name; + ProxyBaseType = proxyBaseType; - // If the name is a user-defined name which specified a generic arity, strip the arity backtick now - if (Name.IndexOfAny(FilteredNameChars) >= 0) + // If the name is a user-defined name which specified a generic arity, strip the arity backtick now + if (Name.IndexOfAny(FilteredNameChars) >= 0) + { + foreach (var c in FilteredNameChars) { - foreach (var c in FilteredNameChars) - { - Name = Name.Replace(c, '_'); - } + Name = Name.Replace(c, '_'); } + } - GeneratedNamespace = InterfaceType.GetNamespaceAndNesting() switch - { - { Length: > 0 } ns => $"{CodeGenerator.CodeGeneratorName}.{ns}", - _ => CodeGenerator.CodeGeneratorName - }; + GeneratedNamespace = InterfaceType.GetNamespaceAndNesting() switch + { + { Length: > 0 } ns => $"{GeneratedCodeUtilities.CodeGeneratorName}.{ns}", + _ => GeneratedCodeUtilities.CodeGeneratorName + }; - var names = new HashSet(StringComparer.Ordinal); - TypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>(); + var names = new HashSet(StringComparer.Ordinal); + TypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>(); - foreach (var tp in interfaceType.GetAllTypeParameters()) + foreach (var tp in interfaceType.GetAllTypeParameters()) + { + var tpName = GetTypeParameterName(names, tp); + TypeParameters.Add((tpName, tp)); + } + + static string GetTypeParameterName(HashSet names, ITypeParameterSymbol tp) + { + var count = 0; + var result = tp.Name; + while (names.Contains(result)) { - var tpName = GetTypeParameterName(names, tp); - TypeParameters.Add((tpName, tp)); + result = $"{tp.Name}_{++count}"; } - static string GetTypeParameterName(HashSet names, ITypeParameterSymbol tp) + names.Add(result); + return result; + } + } + + public ProxyGenerationContext GenerationContext { get; } + + private List GetMethods() + { + var result = new List(); + foreach (var iface in GetAllInterfaces(InterfaceType)) + { + foreach (var method in iface.GetDeclaredInstanceMembers()) { - var count = 0; - var result = tp.Name; - while (names.Contains(result)) + if (method.MethodKind == MethodKind.ExplicitInterfaceImplementation) { - result = $"{tp.Name}_{++count}"; + // Explicit implementations can be ignored when generating a proxy. + // Proxies must implement every method explicitly to ensure faithful reproduction of the interface behavior. + // At the calling side, the explicit implementation will be called if it was not overridden by a derived type. + continue; } - names.Add(result); - return result; + var methodDescription = GenerationContext.GetProxyMethodDescription(InterfaceType, method: method); + result.Add(methodDescription); } } - public CodeGenerator CodeGenerator { get; } + return result; - private List GetMethods() + static IEnumerable GetAllInterfaces(INamedTypeSymbol s) { - var result = new List(); - foreach (var iface in GetAllInterfaces(InterfaceType)) + if (s.TypeKind == TypeKind.Interface) { - foreach (var method in iface.GetDeclaredInstanceMembers()) - { - if (method.MethodKind == MethodKind.ExplicitInterfaceImplementation) - { - // Explicit implementations can be ignored when generating a proxy. - // Proxies must implement every method explicitly to ensure faithful reproduction of the interface behavior. - // At the calling side, the explicit implementation will be called if it was not overridden by a derived type. - continue; - } - - var methodDescription = CodeGenerator.GetProxyMethodDescription(InterfaceType, method: method); - result.Add(methodDescription); - } + yield return s; } - return result; + foreach (var i in s.AllInterfaces) - static IEnumerable GetAllInterfaces(INamedTypeSymbol s) { - if (s.TypeKind == TypeKind.Interface) - { - yield return s; - } - - foreach (var i in s.AllInterfaces) - - { - yield return i; - } + yield return i; } } + } - public string Name { get; } - public INamedTypeSymbol InterfaceType { get; } - public List Methods => _methods ??= GetMethods(); - public SemanticModel SemanticModel { get; } - public string GeneratedNamespace { get; } - public List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters { get; } - public INamedTypeSymbol ProxyBaseType { get; } + public string Name { get; } + public INamedTypeSymbol InterfaceType { get; } + public List Methods => field ??= GetMethods(); + public string GeneratedNamespace { get; } + public List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters { get; } + public INamedTypeSymbol ProxyBaseType { get; } - private static void ValidateBaseClass(LibraryTypes l, INamedTypeSymbol baseClass) - { - ValidateGenericInvokeAsync(l, baseClass); - ValidateNonGenericInvokeAsync(l, baseClass); + private static void ValidateBaseClass(LibraryTypes l, INamedTypeSymbol baseClass) + { + ValidateGenericInvokeAsync(l, baseClass); + ValidateNonGenericInvokeAsync(l, baseClass); - static void ValidateGenericInvokeAsync(LibraryTypes l, INamedTypeSymbol baseClass) + static void ValidateGenericInvokeAsync(LibraryTypes l, INamedTypeSymbol baseClass) + { + var found = false; + string? complaint = null; + ISymbol? complaintMember = null; + foreach (var member in baseClass.GetMembers("InvokeAsync")) { - var found = false; - string complaint = null; - ISymbol complaintMember = null; - foreach (var member in baseClass.GetMembers("InvokeAsync")) + if (member is not IMethodSymbol method) { - if (member is not IMethodSymbol method) - { - complaintMember = member; - complaint = "not a method"; - continue; - } + complaintMember = member; + complaint = "not a method"; + continue; + } - if (method.TypeParameters.Length != 1) - { - complaintMember = member; - complaint = "incorrect number of type parameters (expected one type parameter)"; - continue; - } + if (method.TypeParameters.Length != 1) + { + complaintMember = member; + complaint = "incorrect number of type parameters (expected one type parameter)"; + continue; + } - if (method.Parameters.Length != 1) - { - complaintMember = member; - complaint = $"missing parameter (expected a parameter of type {l.IInvokable.ToDisplayString()})"; - continue; - } + if (method.Parameters.Length != 1) + { + complaintMember = member; + complaint = $"missing parameter (expected a parameter of type {l.IInvokable.ToDisplayString()})"; + continue; + } - var paramType = method.Parameters[0].Type; - if (!SymbolEqualityComparer.Default.Equals(paramType, l.IInvokable)) + var paramType = method.Parameters[0].Type; + if (!SymbolEqualityComparer.Default.Equals(paramType, l.IInvokable)) + { + var implementsIInvokable = false; + foreach (var @interface in paramType.AllInterfaces) { - var implementsIInvokable = false; - foreach (var @interface in paramType.AllInterfaces) - { - if (SymbolEqualityComparer.Default.Equals(@interface, l.IInvokable)) - { - implementsIInvokable = true; - break; - } - } - - if (!implementsIInvokable) + if (SymbolEqualityComparer.Default.Equals(@interface, l.IInvokable)) { - complaintMember = member; - complaint = $"incorrect parameter type (found {paramType}, expected {l.IInvokable} or a type which implements {l.IInvokable})"; - continue; + implementsIInvokable = true; + break; } } - var expectedReturnType = l.ValueTask_1.Construct(method.TypeParameters[0]); - if (!SymbolEqualityComparer.Default.Equals(method.ReturnType, expectedReturnType)) + if (!implementsIInvokable) { complaintMember = member; - complaint = $"incorrect return type (found: {method.ReturnType.ToDisplayString()}, expected {expectedReturnType.ToDisplayString()})"; + complaint = $"incorrect parameter type (found {paramType}, expected {l.IInvokable} or a type which implements {l.IInvokable})"; continue; } - - found = true; } - if (!found) + var expectedReturnType = l.ValueTask_1.Construct(method.TypeParameters[0]); + if (!SymbolEqualityComparer.Default.Equals(method.ReturnType, expectedReturnType)) { - var notFoundMessage = $"Proxy base class {baseClass} does not contain a definition for ValueTask InvokeAsync(IInvokable)"; - var locationMember = complaintMember ?? baseClass; - var complaintMessage = complaint switch - { - { Length: > 0 } => $"{notFoundMessage}. Complaint: {complaint} for symbol: {complaintMember.ToDisplayString()}", - _ => notFoundMessage, - }; - var diagnostic = IncorrectProxyBaseClassSpecificationDiagnostic.CreateDiagnostic(baseClass, locationMember.Locations.First(), complaintMessage); - throw new OrleansGeneratorDiagnosticAnalysisException(diagnostic); + complaintMember = member; + complaint = $"incorrect return type (found: {method.ReturnType.ToDisplayString()}, expected {expectedReturnType.ToDisplayString()})"; + continue; } + + found = true; } - - static void ValidateNonGenericInvokeAsync(LibraryTypes l, INamedTypeSymbol baseClass) + + if (!found) { - var found = false; - string complaint = null; - ISymbol complaintMember = null; - foreach (var member in baseClass.GetMembers("InvokeAsync")) + var notFoundMessage = $"Proxy base class {baseClass} does not contain a definition for ValueTask InvokeAsync(IInvokable)"; + var locationMember = complaintMember ?? baseClass; + var complaintMessage = complaint switch { - if (member is not IMethodSymbol method) - { - complaintMember = member; - complaint = "not a method"; - continue; - } + { Length: > 0 } when complaintMember is not null => $"{notFoundMessage}. Complaint: {complaint} for symbol: {complaintMember.ToDisplayString()}", + _ => notFoundMessage, + }; + var diagnostic = IncorrectProxyBaseClassSpecificationDiagnostic.CreateDiagnostic(baseClass, locationMember.Locations.First(), complaintMessage); + throw new OrleansGeneratorDiagnosticAnalysisException(diagnostic); + } + } + + static void ValidateNonGenericInvokeAsync(LibraryTypes l, INamedTypeSymbol baseClass) + { + var found = false; + string? complaint = null; + ISymbol? complaintMember = null; + foreach (var member in baseClass.GetMembers("InvokeAsync")) + { + if (member is not IMethodSymbol method) + { + complaintMember = member; + complaint = "not a method"; + continue; + } - if (method.TypeParameters.Length != 0) - { - complaintMember = member; - complaint = "incorrect number of type parameters (expected zero)"; - continue; - } + if (method.TypeParameters.Length != 0) + { + complaintMember = member; + complaint = "incorrect number of type parameters (expected zero)"; + continue; + } - if (method.Parameters.Length != 1) - { - complaintMember = member; - complaint = $"missing parameter (expected a parameter of type {l.IInvokable.ToDisplayString()})"; - continue; - } + if (method.Parameters.Length != 1) + { + complaintMember = member; + complaint = $"missing parameter (expected a parameter of type {l.IInvokable.ToDisplayString()})"; + continue; + } - var paramType = method.Parameters[0].Type; - if (!SymbolEqualityComparer.Default.Equals(paramType, l.IInvokable)) + var paramType = method.Parameters[0].Type; + if (!SymbolEqualityComparer.Default.Equals(paramType, l.IInvokable)) + { + var implementsIInvokable = false; + foreach (var @interface in paramType.AllInterfaces) { - var implementsIInvokable = false; - foreach (var @interface in paramType.AllInterfaces) + if (SymbolEqualityComparer.Default.Equals(@interface, l.IInvokable)) { - if (SymbolEqualityComparer.Default.Equals(@interface, l.IInvokable)) - { - implementsIInvokable = true; - break; - } - } - - if (!implementsIInvokable) - { - complaintMember = member; - complaint = $"incorrect parameter type (found {method.Parameters[0].Type}, expected {l.IInvokable})"; - continue; + implementsIInvokable = true; + break; } } - if (!SymbolEqualityComparer.Default.Equals(method.ReturnType, l.ValueTask)) + if (!implementsIInvokable) { complaintMember = member; - complaint = $"incorrect return type (found: {method.ReturnType.ToDisplayString()}, expected {l.ValueTask.ToDisplayString()})"; + complaint = $"incorrect parameter type (found {method.Parameters[0].Type}, expected {l.IInvokable})"; continue; } - - found = true; } - if (!found) + if (!SymbolEqualityComparer.Default.Equals(method.ReturnType, l.ValueTask)) { - var notFoundMessage = $"Proxy base class {baseClass} does not contain a definition for ValueTask InvokeAsync(IInvokable)"; - var locationMember = complaintMember ?? baseClass; - var complaintMessage = complaint switch - { - { Length: > 0 } => $"{notFoundMessage}. Complaint: {complaint} for symbol: {complaintMember.ToDisplayString()}", - _ => notFoundMessage, - }; - var diagnostic = IncorrectProxyBaseClassSpecificationDiagnostic.CreateDiagnostic(baseClass, locationMember.Locations.First(), complaintMessage); - throw new OrleansGeneratorDiagnosticAnalysisException(diagnostic); + complaintMember = member; + complaint = $"incorrect return type (found: {method.ReturnType.ToDisplayString()}, expected {l.ValueTask.ToDisplayString()})"; + continue; } + + found = true; } - } - public bool Equals(ProxyInterfaceDescription other) => SymbolEqualityComparer.Default.Equals(InterfaceType, other.InterfaceType) && SymbolEqualityComparer.Default.Equals(ProxyBaseType, other.ProxyBaseType); - public override bool Equals(object obj) => obj is ProxyInterfaceDescription other && Equals(other); - public override int GetHashCode() => SymbolEqualityComparer.Default.GetHashCode(InterfaceType) * 17 ^ SymbolEqualityComparer.Default.GetHashCode(ProxyBaseType); - public override string ToString() => $"Type: {InterfaceType}, ProxyBaseType: {ProxyBaseType}"; + if (!found) + { + var notFoundMessage = $"Proxy base class {baseClass} does not contain a definition for ValueTask InvokeAsync(IInvokable)"; + var locationMember = complaintMember ?? baseClass; + var complaintMessage = complaint switch + { + { Length: > 0 } when complaintMember is not null => $"{notFoundMessage}. Complaint: {complaint} for symbol: {complaintMember.ToDisplayString()}", + _ => notFoundMessage, + }; + var diagnostic = IncorrectProxyBaseClassSpecificationDiagnostic.CreateDiagnostic(baseClass, locationMember.Locations.First(), complaintMessage); + throw new OrleansGeneratorDiagnosticAnalysisException(diagnostic); + } + } } + + public bool Equals(ProxyInterfaceDescription? other) => other is not null && SymbolEqualityComparer.Default.Equals(InterfaceType, other.InterfaceType) && SymbolEqualityComparer.Default.Equals(ProxyBaseType, other.ProxyBaseType); + public override bool Equals(object? obj) => obj is ProxyInterfaceDescription other && Equals(other); + public override int GetHashCode() => SymbolEqualityComparer.Default.GetHashCode(InterfaceType) * 17 ^ SymbolEqualityComparer.Default.GetHashCode(ProxyBaseType); + public override string ToString() => $"Type: {InterfaceType}, ProxyBaseType: {ProxyBaseType}"; } diff --git a/src/Orleans.CodeGenerator/Model/ProxyInterfaceModel.cs b/src/Orleans.CodeGenerator/Model/ProxyInterfaceModel.cs new file mode 100644 index 00000000000..55920dd969b --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/ProxyInterfaceModel.cs @@ -0,0 +1,28 @@ +namespace Orleans.CodeGenerator.Model; + +/// +/// Describes a mapping from a return type to an invokable base type (e.g., ValueTask → ValueTaskRequest). +/// +internal readonly record struct InvokableBaseTypeMapping(TypeRef ReturnType, TypeRef InvokableBaseType); + +/// +/// Describes a proxy base type used for RPC proxy generation. +/// +internal sealed record class ProxyBaseModel( + TypeRef ProxyBaseType, + bool IsExtension, + string GeneratedClassNameComponent, + EquatableArray InvokableBaseTypes); + +/// +/// Describes a [GenerateMethodSerializers]-annotated interface for incremental proxy/invokable generation. +/// +internal sealed record class ProxyInterfaceModel( + TypeRef InterfaceType, + string Name, + string GeneratedNamespace, + EquatableArray TypeParameters, + ProxyBaseModel ProxyBase, + EquatableArray Methods, + SourceLocationModel SourceLocation = default, + TypeMetadataIdentity MetadataIdentity = default); diff --git a/src/Orleans.CodeGenerator/Model/ProxyInterfaceModelExtractor.cs b/src/Orleans.CodeGenerator/Model/ProxyInterfaceModelExtractor.cs new file mode 100644 index 00000000000..4739853f193 --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/ProxyInterfaceModelExtractor.cs @@ -0,0 +1,526 @@ +using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Orleans.CodeGenerator.Model; +using Orleans.CodeGenerator.SyntaxGeneration; + +namespace Orleans.CodeGenerator; + +internal static class ProxyInterfaceModelExtractor +{ + /// + /// Extracts a from a + /// provided by the ForAttributeWithMetadataName incremental pipeline step for + /// [GenerateMethodSerializers]-annotated interfaces. + /// + internal static ProxyInterfaceModel? ExtractProxyInterfaceFromAttributeContext( + GeneratorAttributeSyntaxContext context, + CancellationToken cancellationToken) + { + if (context.TargetSymbol is not INamedTypeSymbol typeSymbol || typeSymbol.TypeKind != TypeKind.Interface) + { + return null; + } + + return ExtractProxyInterfaceModel(typeSymbol, context.SemanticModel.Compilation, context.Attributes, cancellationToken); + } + + internal static ProxyInterfaceModel? ExtractProxyInterfaceModel( + INamedTypeSymbol typeSymbol, + Compilation compilation, + CancellationToken cancellationToken) + { + if (typeSymbol is null || typeSymbol.TypeKind != TypeKind.Interface) + { + return null; + } + + return ExtractProxyInterfaceModel(typeSymbol, compilation, [], cancellationToken); + } + + internal static ProxyInterfaceModel? ExtractInheritedProxyInterfaceFromSyntaxContext( + GeneratorSyntaxContext context, + CancellationToken cancellationToken) + { + if (context.Node is not InterfaceDeclarationSyntax interfaceDeclaration) + { + return null; + } + + var compilation = context.SemanticModel.Compilation; + if (context.SemanticModel.GetDeclaredSymbol(interfaceDeclaration, cancellationToken) is not INamedTypeSymbol typeSymbol + || typeSymbol.TypeKind != TypeKind.Interface) + { + return null; + } + + var options = new CodeGeneratorOptions(); + var libraryTypes = LibraryTypes.FromCompilation(compilation, options); + if (typeSymbol.GetAttributes(libraryTypes.GenerateMethodSerializersAttribute, out var directAttributes, inherited: false) + && directAttributes.Any(static attribute => TryGetProxyBaseInfo(attribute, out _, out _))) + { + return null; + } + + return ExtractProxyInterfaceModel(typeSymbol, compilation, [], cancellationToken); + } + + private static ProxyInterfaceModel? ExtractProxyInterfaceModel( + INamedTypeSymbol typeSymbol, + Compilation compilation, + ImmutableArray candidateAttributes, + CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + var options = new CodeGeneratorOptions(); + var libraryTypes = LibraryTypes.FromCompilation(compilation, options); + + if (!TryGetGenerateMethodSerializersAttribute(typeSymbol, candidateAttributes, libraryTypes, out var attribute) + || !TryGetProxyBaseInfo(attribute, out var proxyBaseTypeSymbol, out var isExtension)) + { + return null; + } + + var proxyBaseType = proxyBaseTypeSymbol.OriginalDefinition; + var invokableBaseTypes = ExtractInvokableBaseTypeMappings(proxyBaseType, libraryTypes, cancellationToken); + var generatedClassNameComponent = isExtension ? $"{proxyBaseType.Name}_Ext" : proxyBaseType.Name; + var proxyBase = new ProxyBaseModel( + new TypeRef(proxyBaseType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), + isExtension, + generatedClassNameComponent, + invokableBaseTypes); + + var name = GetProxyInterfaceName(typeSymbol, libraryTypes); + var typeParameters = ExtractInterfaceTypeParameters(typeSymbol); + var methods = ExtractInterfaceMethods(typeSymbol, libraryTypes, isExtension, cancellationToken); + var generatedNamespace = GeneratedCodeUtilities.GetGeneratedNamespaceName(typeSymbol); + + return new ProxyInterfaceModel( + new TypeRef(typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), + name, + generatedNamespace, + typeParameters, + proxyBase, + methods, + SourceLocation: SymbolSourceLocationExtractor.GetSourceLocation(typeSymbol), + MetadataIdentity: TypeMetadataIdentity.Create(typeSymbol)); + } + + private static string GetProxyInterfaceName(INamedTypeSymbol typeSymbol, LibraryTypes libraryTypes) + { + var alias = typeSymbol.GetAttribute(libraryTypes.AliasAttribute); + var name = alias is { ConstructorArguments.Length: > 0 } + && alias.ConstructorArguments[0].Value is string aliasName + && !string.IsNullOrWhiteSpace(aliasName) + ? aliasName + : typeSymbol.Name; + + if (name.IndexOfAny(['`', '.']) >= 0) + { + name = name.Replace('`', '_').Replace('.', '_'); + } + + return name; + } + + + private static bool TryGetGenerateMethodSerializersAttribute( + INamedTypeSymbol typeSymbol, + ImmutableArray candidateAttributes, + LibraryTypes libraryTypes, + [NotNullWhen(true)] out AttributeData? attribute) + { + foreach (var candidate in candidateAttributes) + { + if (!TryGetProxyBaseInfo(candidate, out _, out _)) + { + continue; + } + + attribute = candidate; + return true; + } + + if (typeSymbol.GetAttributes(libraryTypes.GenerateMethodSerializersAttribute, out var inheritedAttributes, inherited: true)) + { + foreach (var inheritedAttribute in inheritedAttributes) + { + if (!TryGetProxyBaseInfo(inheritedAttribute, out _, out _)) + { + continue; + } + + attribute = inheritedAttribute; + return true; + } + } + + attribute = null; + return false; + } + + private static bool TryGetProxyBaseInfo(AttributeData? attribute, [NotNullWhen(true)] out INamedTypeSymbol? proxyBaseTypeSymbol, out bool isExtension) + { + proxyBaseTypeSymbol = null; + isExtension = false; + + if (attribute is null + || attribute.ConstructorArguments.Length < 1 + || attribute.ConstructorArguments[0].Value is not INamedTypeSymbol proxyBaseType) + { + return false; + } + + proxyBaseTypeSymbol = proxyBaseType; + if (attribute.ConstructorArguments.Length > 1 && attribute.ConstructorArguments[1].Value is bool extension) + { + isExtension = extension; + } + + return true; + } + + private static ImmutableArray ExtractInvokableBaseTypeMappings( + INamedTypeSymbol proxyBaseType, + LibraryTypes libraryTypes, + CancellationToken cancellationToken) + { + if (!proxyBaseType.GetAttributes(libraryTypes.DefaultInvokableBaseTypeAttribute, out var invokableBaseTypeAttributes)) + { + return []; + } + + var mappings = new Dictionary(StringComparer.Ordinal); + foreach (var attr in invokableBaseTypeAttributes) + { + cancellationToken.ThrowIfCancellationRequested(); + + var ctorArgs = attr.ConstructorArguments; + if (ctorArgs.Length < 2 + || ctorArgs[0].Value is not INamedTypeSymbol returnType + || ctorArgs[1].Value is not INamedTypeSymbol invokableBaseType) + { + continue; + } + + var returnTypeRef = new TypeRef(returnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + var invokableBaseTypeRef = new TypeRef(invokableBaseType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + mappings[returnTypeRef.SyntaxString] = new InvokableBaseTypeMapping(returnTypeRef, invokableBaseTypeRef); + } + + if (mappings.Count == 0) + { + return []; + } + + var builder = ImmutableArray.CreateBuilder(mappings.Count); + foreach (var mapping in mappings.OrderBy(static m => m.Key, StringComparer.Ordinal)) + { + builder.Add(mapping.Value); + } + + return builder.MoveToImmutable(); + } + + private static ImmutableArray ExtractInterfaceTypeParameters(INamedTypeSymbol typeSymbol) + { + if (typeSymbol.TypeParameters.Length == 0) + { + return []; + } + + var builder = ImmutableArray.CreateBuilder(typeSymbol.TypeParameters.Length); + foreach (var tp in typeSymbol.TypeParameters) + { + builder.Add(new TypeParameterModel(tp.Name, tp.Name, tp.Ordinal)); + } + + return builder.MoveToImmutable(); + } + + private static ImmutableArray ExtractInterfaceMethods( + INamedTypeSymbol interfaceType, + LibraryTypes libraryTypes, + bool isExtension, + CancellationToken cancellationToken) + { + var methods = new SortedDictionary(StringComparer.Ordinal); + + foreach (var iface in GetAllInterfaces(interfaceType)) + { + cancellationToken.ThrowIfCancellationRequested(); + + foreach (var member in iface.GetDeclaredInstanceMembers()) + { + if (member.MethodKind == MethodKind.ExplicitInterfaceImplementation) + { + continue; + } + + var originalMethod = member.OriginalDefinition; + var methodIdentity = GetMethodIdentity( + isExtension ? interfaceType : originalMethod.ContainingType, + originalMethod); + if (methods.ContainsKey(methodIdentity)) + { + continue; + } + + var containingInterface = isExtension ? interfaceType : originalMethod.ContainingType; + var methodModel = ExtractMethodModel(member, originalMethod, containingInterface, libraryTypes); + if (methodModel is not null) + { + methods.Add(methodIdentity, methodModel); + } + } + } + + if (methods.Count == 0) + { + return []; + } + + var builder = ImmutableArray.CreateBuilder(methods.Count); + foreach (var method in methods.Values) + { + builder.Add(method); + } + + return builder.MoveToImmutable(); + } + + private static IEnumerable GetAllInterfaces(INamedTypeSymbol symbol) + { + if (symbol.TypeKind == TypeKind.Interface) + { + yield return symbol; + } + + foreach (var iface in symbol.AllInterfaces) + { + yield return iface; + } + } + + private static string GetMethodIdentity(INamedTypeSymbol containingType, IMethodSymbol method) + { + var builder = new StringBuilder(); + builder.Append(containingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + builder.Append("::"); + builder.Append(method.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + builder.Append('.'); + builder.Append(method.Name); + builder.Append('`'); + builder.Append(method.Arity); + builder.Append('('); + + for (var i = 0; i < method.Parameters.Length; i++) + { + if (i > 0) + { + builder.Append(','); + } + + builder.Append(method.Parameters[i].Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + } + + builder.Append(')'); + builder.Append("->"); + builder.Append(method.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + return builder.ToString(); + } + + private static MethodModel ExtractMethodModel( + IMethodSymbol method, + IMethodSymbol originalMethod, + INamedTypeSymbol containingInterface, + LibraryTypes libraryTypes) + { + var generatedMethodId = GeneratedCodeUtilities.CreateHashedMethodId(originalMethod); + + // Determine method ID: explicit ID → alias → generated hash + string methodId; + var idValue = GeneratedCodeUtilities.GetId(libraryTypes, originalMethod); + if (idValue.HasValue) + { + methodId = idValue.Value.ToString(CultureInfo.InvariantCulture); + } + else + { + var aliasAttr = originalMethod.GetAttribute(libraryTypes.AliasAttribute); + methodId = aliasAttr is not null && aliasAttr.ConstructorArguments.Length > 0 + ? (string?)aliasAttr.ConstructorArguments[0].Value ?? generatedMethodId + : generatedMethodId; + } + + var parameters = ExtractMethodParameters(method, libraryTypes); + var typeParameters = ExtractMethodTypeParameters(method); + var returnType = new TypeRef(method.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + + // Response timeout + long? responseTimeoutTicks = null; + var timeoutAttr = originalMethod.GetAttribute(libraryTypes.ResponseTimeoutAttribute); + if (timeoutAttr is not null + && timeoutAttr.ConstructorArguments.Length > 0 + && timeoutAttr.ConstructorArguments[0].Value is string timeoutStr) + { + if (TimeSpan.TryParse(timeoutStr, out var timeout)) + { + responseTimeoutTicks = timeout.Ticks; + } + } + + var customInitializers = ExtractCustomInitializers(originalMethod, libraryTypes); + + var isCancellable = false; + foreach (var param in method.Parameters) + { + if (SymbolEqualityComparer.Default.Equals(libraryTypes.CancellationToken, param.Type)) + { + isCancellable = true; + break; + } + } + + return new MethodModel( + method.Name, + returnType, + parameters, + typeParameters, + new TypeRef(containingInterface.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), + new TypeRef(originalMethod.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), + containingInterface.Name, + GeneratedCodeUtilities.GetGeneratedNamespaceName(containingInterface), + containingInterface.GetAllTypeParameters().Count(), + generatedMethodId, + methodId, + responseTimeoutTicks, + customInitializers, + isCancellable); + } + + private static ImmutableArray ExtractMethodParameters( + IMethodSymbol method, + LibraryTypes libraryTypes) + { + if (method.Parameters.Length == 0) + { + return []; + } + + var builder = ImmutableArray.CreateBuilder(method.Parameters.Length); + foreach (var param in method.Parameters) + { + var isCancellationToken = SymbolEqualityComparer.Default.Equals(libraryTypes.CancellationToken, param.Type); + builder.Add(new MethodParameterModel( + param.Name, + new TypeRef(param.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), + param.Ordinal, + isCancellationToken)); + } + + return builder.MoveToImmutable(); + } + + private static ImmutableArray ExtractMethodTypeParameters(IMethodSymbol method) + { + if (method.TypeParameters.Length == 0) + { + return []; + } + + var builder = ImmutableArray.CreateBuilder(method.TypeParameters.Length); + foreach (var tp in method.TypeParameters) + { + builder.Add(new TypeParameterModel(tp.Name, tp.Name, tp.Ordinal)); + } + + return builder.MoveToImmutable(); + } + + private static ImmutableArray ExtractCustomInitializers( + IMethodSymbol method, + LibraryTypes libraryTypes) + { + ImmutableArray.Builder? builder = null; + + foreach (var methodAttr in method.GetAttributes()) + { + if (methodAttr.AttributeClass is null) + { + continue; + } + + if (methodAttr.AttributeClass.GetAttributes(libraryTypes.InvokableCustomInitializerAttribute, out var attrs)) + { + foreach (var attr in attrs) + { + if (attr.ConstructorArguments.Length == 0 || attr.ConstructorArguments[0].Value is not string methodName) + { + continue; + } + + string? argumentValue = null; + + if (attr.ConstructorArguments.Length == 2) + { + argumentValue = attr.ConstructorArguments[1].Value?.ToString(); + } + else + { + if (TryGetNamedArgument(attr.NamedArguments, "AttributeArgumentName", out var argNameArg) + && argNameArg.Value is string attributeArgumentName + && TryGetNamedArgument(methodAttr.NamedArguments, attributeArgumentName, out var namedArgument)) + { + argumentValue = namedArgument.Value?.ToString(); + } + else + { + var index = 0; + if (TryGetNamedArgument(attr.NamedArguments, "AttributeArgumentIndex", out var indexArg)) + { + index = indexArg.Value is int value ? value : index; + } + + if (methodAttr.ConstructorArguments.Length > index) + { + argumentValue = methodAttr.ConstructorArguments[index].Value?.ToString(); + } + } + } + + builder ??= ImmutableArray.CreateBuilder(); + builder.Add(new CustomInitializerModel(methodName ?? string.Empty, argumentValue ?? string.Empty)); + } + } + } + + return builder is not null + ? builder.ToImmutable() + : []; + } + + private static bool TryGetNamedArgument( + ImmutableArray> arguments, + string name, + out TypedConstant value) + { + foreach (var arg in arguments) + { + if (string.Equals(arg.Key, name, StringComparison.Ordinal)) + { + value = arg.Value; + return true; + } + } + + value = default; + return false; + } +} + + diff --git a/src/Orleans.CodeGenerator/Model/ProxyMethodDescription.cs b/src/Orleans.CodeGenerator/Model/ProxyMethodDescription.cs index a1abb2a33ac..cfb79b69f13 100644 --- a/src/Orleans.CodeGenerator/Model/ProxyMethodDescription.cs +++ b/src/Orleans.CodeGenerator/Model/ProxyMethodDescription.cs @@ -1,179 +1,171 @@ using Orleans.CodeGenerator.SyntaxGeneration; using Microsoft.CodeAnalysis; -using System; -using System.Collections.Generic; using System.Diagnostics; -using System.Linq; using System.Collections.Immutable; using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -using Microsoft.CodeAnalysis.CSharp; -#nullable disable -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +/// +/// Describes an invokable method on a proxy interface. +/// +[DebuggerDisplay("{Method} (from {ProxyInterface})")] +internal class ProxyMethodDescription : IEquatable { - /// - /// Describes an invokable method on a proxy interface. - /// - [DebuggerDisplay("{Method} (from {ProxyInterface})")] - internal class ProxyMethodDescription : IEquatable + private readonly GeneratedInvokableDescription _originalInvokable; + public static ProxyMethodDescription Create( + ProxyInterfaceDescription proxyInterface, + GeneratedInvokableDescription generatedInvokable, + IMethodSymbol method) + => new(proxyInterface, generatedInvokable, method); + + private ProxyMethodDescription(ProxyInterfaceDescription proxyInterface, GeneratedInvokableDescription generatedInvokable, IMethodSymbol method) { - private readonly GeneratedInvokableDescription _originalInvokable; - public static ProxyMethodDescription Create( - ProxyInterfaceDescription proxyInterface, - GeneratedInvokableDescription generatedInvokable, - IMethodSymbol method) - => new(proxyInterface, generatedInvokable, method); - - private ProxyMethodDescription(ProxyInterfaceDescription proxyInterface, GeneratedInvokableDescription generatedInvokable, IMethodSymbol method) - { - _originalInvokable = generatedInvokable; - Method = method; - ProxyInterface = proxyInterface; + _originalInvokable = generatedInvokable; + Method = method; + ProxyInterface = proxyInterface; - TypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>(); - MethodTypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>(); + TypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>(); + MethodTypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>(); - TypeParametersWithArguments = Method.ContainingType.GetAllTypeParameters().Zip(method.ContainingType.GetAllTypeArguments(), (a, b) => (a, b)).ToImmutableArray(); - var names = new HashSet(StringComparer.Ordinal); - foreach (var (typeParameter, typeArgument) in TypeParametersWithArguments) - { - var tpName = GetTypeParameterName(names, typeParameter); - TypeParameters.Add((tpName, typeParameter)); - } + TypeParametersWithArguments = [.. Method.ContainingType.GetAllTypeParameters().Zip(method.ContainingType.GetAllTypeArguments(), (a, b) => (a, b))]; + var names = new HashSet(StringComparer.Ordinal); + foreach (var (typeParameter, typeArgument) in TypeParametersWithArguments) + { + var tpName = GetTypeParameterName(names, typeParameter); + TypeParameters.Add((tpName, typeParameter)); + } - foreach (var typeParameter in Method.TypeParameters) - { - var tpName = GetTypeParameterName(names, typeParameter); - TypeParameters.Add((tpName, typeParameter)); - MethodTypeParameters.Add((tpName, typeParameter)); - } + foreach (var typeParameter in Method.TypeParameters) + { + var tpName = GetTypeParameterName(names, typeParameter); + TypeParameters.Add((tpName, typeParameter)); + MethodTypeParameters.Add((tpName, typeParameter)); + } - TypeParameterSubstitutions = new(SymbolEqualityComparer.Default); - foreach (var (name, parameter) in TypeParameters) - { - TypeParameterSubstitutions[parameter] = name; - } + TypeParameterSubstitutions = new(SymbolEqualityComparer.Default); + foreach (var (name, parameter) in TypeParameters) + { + TypeParameterSubstitutions[parameter] = name; + } - foreach (var (parameter, arg) in TypeParametersWithArguments) - { - TypeParameterSubstitutions[parameter] = arg.ToDisplayName(); - } + foreach (var (parameter, arg) in TypeParametersWithArguments) + { + TypeParameterSubstitutions[parameter] = arg.ToDisplayName(); + } - GeneratedInvokable = new ConstructedGeneratedInvokableDescription(generatedInvokable, this); - static string GetTypeParameterName(HashSet names, ITypeParameterSymbol typeParameter) + GeneratedInvokable = new ConstructedGeneratedInvokableDescription(generatedInvokable, this); + static string GetTypeParameterName(HashSet names, ITypeParameterSymbol typeParameter) + { + var count = 0; + var result = typeParameter.Name; + while (names.Contains(result)) { - var count = 0; - var result = typeParameter.Name; - while (names.Contains(result)) - { - result = $"{typeParameter.Name}_{++count}"; - } - - names.Add(result); - return result.EscapeIdentifier(); + result = $"{typeParameter.Name}_{++count}"; } + + names.Add(result); + return result.EscapeIdentifier(); } + } - public CodeGenerator CodeGenerator => InvokableMethod.CodeGenerator; - public InvokableMethodDescription InvokableMethod => _originalInvokable.MethodDescription; - public ConstructedGeneratedInvokableDescription GeneratedInvokable { get; } - public ProxyInterfaceDescription ProxyInterface { get; } - - public IMethodSymbol Method { get; } - public InvokableMethodId InvokableId { get; } - public List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters { get; } - public List<(string Name, ITypeParameterSymbol Parameter)> MethodTypeParameters { get; } - public ImmutableArray<(ITypeParameterSymbol Parameter, ITypeSymbol Argument)> TypeParametersWithArguments { get; } - public Dictionary TypeParameterSubstitutions { get; } - - /// - /// Mapping of method return types to invokable base type. The code generator will create a derived type with the method arguments as fields. - /// - public IReadOnlyDictionary InvokableBaseTypes => InvokableMethod.InvokableBaseTypes; - public InvokableMethodId InvokableKey => InvokableMethod.Key; - public List<(string, TypedConstant)> CustomInitializerMethods => InvokableMethod.CustomInitializerMethods; - public string GeneratedMethodId => InvokableMethod.GeneratedMethodId; - public string MethodId => InvokableMethod.MethodId; - public bool HasAlias => InvokableMethod.HasAlias; - public long? ResponseTimeoutTicks => InvokableMethod.ResponseTimeoutTicks; - - public override int GetHashCode() => ProxyInterface.GetHashCode() * 17 ^ InvokableMethod.GetHashCode(); - public bool Equals(ProxyMethodDescription other) => other is not null && InvokableMethod.Key.Equals(other.InvokableKey) && ProxyInterface.Equals(other.ProxyInterface); - public override bool Equals(object other) => other is ProxyMethodDescription imd && Equals(imd); - - internal sealed class ConstructedGeneratedInvokableDescription : ISerializableTypeDescription - { - private TypeSyntax _typeSyntax; - private TypeSyntax _baseTypeSyntax; - private readonly GeneratedInvokableDescription _invokableDescription; - private readonly ProxyMethodDescription _proxyMethod; + public ProxyGenerationContext GenerationContext => InvokableMethod.GenerationContext; + public InvokableMethodDescription InvokableMethod => _originalInvokable.MethodDescription; + public ConstructedGeneratedInvokableDescription GeneratedInvokable { get; } + public ProxyInterfaceDescription ProxyInterface { get; } + + public IMethodSymbol Method { get; } + public InvokableMethodId InvokableId { get; } + public List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters { get; } + public List<(string Name, ITypeParameterSymbol Parameter)> MethodTypeParameters { get; } + public ImmutableArray<(ITypeParameterSymbol Parameter, ITypeSymbol Argument)> TypeParametersWithArguments { get; } + public Dictionary TypeParameterSubstitutions { get; } - public ConstructedGeneratedInvokableDescription(GeneratedInvokableDescription invokableDescription, ProxyMethodDescription proxyMethod) + /// + /// Mapping of method return types to invokable base type. The code generator will create a derived type with the method arguments as fields. + /// + public IReadOnlyDictionary InvokableBaseTypes => InvokableMethod.InvokableBaseTypes; + public InvokableMethodId InvokableKey => InvokableMethod.Key; + public List<(string, TypedConstant)> CustomInitializerMethods => InvokableMethod.CustomInitializerMethods; + public string GeneratedMethodId => InvokableMethod.GeneratedMethodId; + public string MethodId => InvokableMethod.MethodId; + public bool HasAlias => InvokableMethod.HasAlias; + public long? ResponseTimeoutTicks => InvokableMethod.ResponseTimeoutTicks; + + public override int GetHashCode() => ProxyInterface.GetHashCode() * 17 ^ InvokableMethod.GetHashCode(); + public bool Equals(ProxyMethodDescription other) => other is not null && InvokableMethod.Key.Equals(other.InvokableKey) && ProxyInterface.Equals(other.ProxyInterface); + public override bool Equals(object other) => other is ProxyMethodDescription imd && Equals(imd); + + internal sealed class ConstructedGeneratedInvokableDescription : ISerializableTypeDescription + { + private readonly GeneratedInvokableDescription _invokableDescription; + private readonly ProxyMethodDescription _proxyMethod; + + public ConstructedGeneratedInvokableDescription(GeneratedInvokableDescription invokableDescription, ProxyMethodDescription proxyMethod) + { + _invokableDescription = invokableDescription; + _proxyMethod = proxyMethod; + Members = new List(invokableDescription.Members.Count); + var proxyMethodParameters = proxyMethod.Method.Parameters; + foreach (var member in invokableDescription.Members.OfType()) { - _invokableDescription = invokableDescription; - _proxyMethod = proxyMethod; - Members = new List(invokableDescription.Members.Count); - var proxyMethodParameters = proxyMethod.Method.Parameters; - foreach (var member in invokableDescription.Members.OfType()) - { - Members.Add(new InvokableGenerator.MethodParameterFieldDescription( - proxyMethod.CodeGenerator, - proxyMethodParameters[member.ParameterOrdinal], - member.FieldName, - member.FieldId, - proxyMethod.TypeParameterSubstitutions, - member.IsSerializable)); - } + Members.Add(new InvokableGenerator.MethodParameterFieldDescription( + member.LibraryTypes, + proxyMethodParameters[member.ParameterOrdinal], + member.FieldName, + member.FieldId, + proxyMethod.TypeParameterSubstitutions, + member.IsSerializable)); } + } - public Accessibility Accessibility => _invokableDescription.Accessibility; - public TypeSyntax TypeSyntax => _typeSyntax ??= CreateTypeSyntax(); - public TypeSyntax OpenTypeSyntax => _invokableDescription.OpenTypeSyntax; - public bool HasComplexBaseType => BaseType is { SpecialType: not SpecialType.System_Object }; - public bool IncludePrimaryConstructorParameters => false; - public INamedTypeSymbol BaseType => _invokableDescription.BaseType; - public TypeSyntax BaseTypeSyntax => _baseTypeSyntax ??= BaseType.ToTypeSyntax(_proxyMethod.TypeParameterSubstitutions); - public string Namespace => GeneratedNamespace; - public string GeneratedNamespace => _invokableDescription.GeneratedNamespace; - public string Name => _invokableDescription.Name; - public bool IsValueType => _invokableDescription.IsValueType; - public bool IsSealedType => _invokableDescription.IsSealedType; - public bool IsAbstractType => _invokableDescription.IsAbstractType; - public bool IsEnumType => _invokableDescription.IsEnumType; - public bool IsGenericType => TypeParameters.Count > 0; - public List Members { get; } - public Compilation Compilation => MethodDescription.CodeGenerator.Compilation; - public bool IsEmptyConstructable => ActivatorConstructorParameters is not { Count: > 0 }; - public bool UseActivator => ActivatorConstructorParameters is { Count: > 0 }; - public bool TrackReferences => _invokableDescription.TrackReferences; - public bool OmitDefaultMemberValues => _invokableDescription.OmitDefaultMemberValues; - public List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters => _proxyMethod.TypeParameters; - public List SerializationHooks => _invokableDescription.SerializationHooks; - public bool IsShallowCopyable => _invokableDescription.IsShallowCopyable; - public bool IsUnsealedImmutable => _invokableDescription.IsUnsealedImmutable; - public bool IsImmutable => _invokableDescription.IsImmutable; - public bool IsExceptionType => _invokableDescription.IsExceptionType; - public List ActivatorConstructorParameters => _invokableDescription.ActivatorConstructorParameters; - public bool HasActivatorConstructor => UseActivator; - public string ReturnValueInitializerMethod => _invokableDescription.ReturnValueInitializerMethod; - - public InvokableMethodDescription MethodDescription => _invokableDescription.MethodDescription; - - public ExpressionSyntax GetObjectCreationExpression() => ObjectCreationExpression(TypeSyntax, ArgumentList(), null); - - private TypeSyntax CreateTypeSyntax() + public Accessibility Accessibility => _invokableDescription.Accessibility; + public TypeSyntax TypeSyntax => field ??= CreateTypeSyntax(); + public TypeSyntax OpenTypeSyntax => _invokableDescription.OpenTypeSyntax; + public bool HasComplexBaseType => BaseType is { SpecialType: not SpecialType.System_Object }; + public bool IncludePrimaryConstructorParameters => false; + public INamedTypeSymbol BaseType => _invokableDescription.BaseType; + public TypeSyntax BaseTypeSyntax => field ??= BaseType.ToTypeSyntax(_proxyMethod.TypeParameterSubstitutions); + public string Namespace => GeneratedNamespace; + public string GeneratedNamespace => _invokableDescription.GeneratedNamespace; + public string Name => _invokableDescription.Name; + public bool IsValueType => _invokableDescription.IsValueType; + public bool IsSealedType => _invokableDescription.IsSealedType; + public bool IsAbstractType => _invokableDescription.IsAbstractType; + public bool IsEnumType => _invokableDescription.IsEnumType; + public bool IsGenericType => TypeParameters.Count > 0; + public List Members { get; } + public Compilation Compilation => MethodDescription.GenerationContext.Compilation; + public bool IsEmptyConstructable => ActivatorConstructorParameters is not { Count: > 0 }; + public bool UseActivator => ActivatorConstructorParameters is { Count: > 0 }; + public bool TrackReferences => _invokableDescription.TrackReferences; + public bool OmitDefaultMemberValues => _invokableDescription.OmitDefaultMemberValues; + public List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters => _proxyMethod.TypeParameters; + public List SerializationHooks => _invokableDescription.SerializationHooks; + public bool IsShallowCopyable => _invokableDescription.IsShallowCopyable; + public bool IsUnsealedImmutable => _invokableDescription.IsUnsealedImmutable; + public bool IsImmutable => _invokableDescription.IsImmutable; + public bool IsExceptionType => _invokableDescription.IsExceptionType; + public List ActivatorConstructorParameters => _invokableDescription.ActivatorConstructorParameters; + public bool HasActivatorConstructor => UseActivator; + public string? ReturnValueInitializerMethod => _invokableDescription.ReturnValueInitializerMethod; + + public InvokableMethodDescription MethodDescription => _invokableDescription.MethodDescription; + + public ExpressionSyntax GetObjectCreationExpression() => ObjectCreationExpression(TypeSyntax, ArgumentList(), null); + + private TypeSyntax CreateTypeSyntax() + { + var simpleName = InvokableGenerator.GetSimpleClassName(MethodDescription); + var subs = _proxyMethod.TypeParameterSubstitutions; + return (TypeParameters, Namespace) switch { - var simpleName = InvokableGenerator.GetSimpleClassName(MethodDescription); - var subs = _proxyMethod.TypeParameterSubstitutions; - return (TypeParameters, Namespace) switch - { - ({ Count: > 0 }, { Length: > 0 }) => QualifiedName(ParseName(Namespace), GenericName(Identifier(simpleName), TypeArgumentList(SeparatedList(TypeParameters.Select(p => IdentifierName(subs[p.Parameter])))))), - ({ Count: > 0 }, _) => GenericName(Identifier(simpleName), TypeArgumentList(SeparatedList(TypeParameters.Select(p => IdentifierName(subs[p.Parameter]))))), - (_, { Length: > 0 }) => QualifiedName(ParseName(Namespace), IdentifierName(simpleName)), - _ => IdentifierName(simpleName), - }; - } + ({ Count: > 0 }, { Length: > 0 }) => QualifiedName(ParseName(Namespace), GenericName(Identifier(simpleName), TypeArgumentList(SeparatedList(TypeParameters.Select(p => IdentifierName(subs[p.Parameter])))))), + ({ Count: > 0 }, _) => GenericName(Identifier(simpleName), TypeArgumentList(SeparatedList(TypeParameters.Select(p => IdentifierName(subs[p.Parameter]))))), + (_, { Length: > 0 }) => QualifiedName(ParseName(Namespace), IdentifierName(simpleName)), + _ => IdentifierName(simpleName), + }; } } } diff --git a/src/Orleans.CodeGenerator/Model/ProxyOutputModel.cs b/src/Orleans.CodeGenerator/Model/ProxyOutputModel.cs new file mode 100644 index 00000000000..3f126feae4e --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/ProxyOutputModel.cs @@ -0,0 +1,23 @@ +namespace Orleans.CodeGenerator.Model; + +/// +/// Describes the proxy output for a single interface, including the invokable metadata names owned by that file. +/// +internal sealed record class ProxyOutputModel( + ProxyInterfaceModel ProxyInterface, + EquatableArray OwnedInvokableMetadataNames, + EquatableArray OwnedInvokableActivatorMetadataNames, + bool UseDeclaredInvokableFallback) +{ + public ProxyOutputModel( + ProxyInterfaceModel proxyInterface, + EquatableArray ownedInvokableMetadataNames, + bool useDeclaredInvokableFallback) + : this( + proxyInterface, + ownedInvokableMetadataNames, + EquatableArray.Empty, + useDeclaredInvokableFallback) + { + } +} diff --git a/src/Orleans.CodeGenerator/Model/ReferenceAssemblyModel.cs b/src/Orleans.CodeGenerator/Model/ReferenceAssemblyModel.cs new file mode 100644 index 00000000000..6d7afd561e5 --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/ReferenceAssemblyModel.cs @@ -0,0 +1,73 @@ +namespace Orleans.CodeGenerator.Model; + +/// +/// Describes a well-known type ID mapping. +/// +internal readonly record struct WellKnownTypeIdModel(TypeRef Type, uint Id); + +/// +/// Describes a type alias mapping. +/// +internal readonly record struct TypeAliasModel(TypeRef Type, string Alias); + +/// +/// A single component in a compound type alias path. +/// +internal readonly record struct CompoundAliasComponentModel +{ + public CompoundAliasComponentModel(string stringValue) + { + StringValue = stringValue; + TypeValue = TypeRef.Empty; + IsType = false; + } + + public CompoundAliasComponentModel(TypeRef typeValue) + { + StringValue = null; + TypeValue = typeValue; + IsType = true; + } + + public bool IsString => !IsType && StringValue is not null; + public bool IsType { get; } + public string? StringValue { get; } + public TypeRef TypeValue { get; } + +} + +/// +/// Describes a compound type alias entry (a path of components mapping to a type). +/// +internal readonly record struct CompoundTypeAliasModel(EquatableArray Components, TypeRef TargetType); + +/// +/// Describes an interface implementation (a concrete type implementing an invokable interface). +/// +internal readonly record struct InterfaceImplementationModel +{ + public InterfaceImplementationModel(TypeRef implementationType, SourceLocationModel sourceLocation = default) + { + ImplementationType = implementationType; + SourceLocation = sourceLocation; + } + + public TypeRef ImplementationType { get; } + public SourceLocationModel SourceLocation { get; } +} + +/// +/// Aggregated data extracted from referenced assemblies via [GenerateCodeForDeclaringAssembly] +/// and [ApplicationPart] attributes. This model is produced by a CompilationProvider-based +/// pipeline and cached via structural equality. +/// +internal sealed record class ReferenceAssemblyModel( + string AssemblyName, + EquatableArray ApplicationParts, + EquatableArray WellKnownTypeIds, + EquatableArray TypeAliases, + EquatableArray CompoundTypeAliases, + EquatableArray ReferencedSerializableTypes, + EquatableArray ReferencedProxyInterfaces, + EquatableArray RegisteredCodecs, + EquatableArray InterfaceImplementations); diff --git a/src/Orleans.CodeGenerator/Model/ReferenceAssemblyModelExtractor.cs b/src/Orleans.CodeGenerator/Model/ReferenceAssemblyModelExtractor.cs new file mode 100644 index 00000000000..f48012d5d98 --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/ReferenceAssemblyModelExtractor.cs @@ -0,0 +1,323 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Orleans.CodeGenerator.Model; +using Orleans.CodeGenerator.SyntaxGeneration; + +namespace Orleans.CodeGenerator; + +internal static class ReferenceAssemblyModelExtractor +{ + /// + /// Extracts reference-assembly metadata from the compilation using the provided code generation options. + /// This isolates reference-assembly scanning into a cacheable pipeline step so that + /// downstream work can be skipped when references don't change. + /// + internal static ReferenceAssemblyModel ExtractReferenceAssemblyData( + Compilation compilation, + CodeGeneratorOptions options, + CancellationToken cancellationToken) + => ExtractReferenceAssemblyData(compilation, options, cancellationToken, out _); + + internal static ReferenceAssemblyModel ExtractReferenceAssemblyData( + Compilation compilation, + CodeGeneratorOptions options, + CancellationToken cancellationToken, + out ImmutableArray diagnostics) + { + var libraryTypes = LibraryTypes.FromCompilation(compilation, options); + + var applicationParts = new List(); + var applicationPartSet = new HashSet(StringComparer.Ordinal); + AddApplicationPart(compilation.Assembly.MetadataName); + + var assembliesToExamine = new HashSet(SymbolEqualityComparer.Default); + ComputeAssembliesToExamine( + compilation.Assembly, + assembliesToExamine, + libraryTypes.GenerateCodeForDeclaringAssemblyAttribute, + cancellationToken); + + var wellKnownTypeIds = new HashSet(); + var typeAliases = new HashSet(); + var compoundTypeAliases = new HashSet(); + var referencedSerializableTypes = new HashSet(); + var referencedProxyInterfaces = new HashSet(); + var registeredCodecs = new HashSet(); + var interfaceImplementations = new HashSet(); + var diagnosticBuilder = ImmutableArray.CreateBuilder(); + + foreach (var reference in compilation.References) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (compilation.GetAssemblyOrModuleSymbol(reference) is not IAssemblySymbol asm) + { + continue; + } + + if (!asm.GetAttributes(libraryTypes.ApplicationPartAttribute, out var attrs)) + { + continue; + } + + AddApplicationPart(asm.MetadataName); + foreach (var attr in attrs) + { + if (attr.ConstructorArguments.Length > 0 + && attr.ConstructorArguments[0].Value is string partName) + { + AddApplicationPart(partName); + } + } + } + + foreach (var asm in assembliesToExamine) + { + cancellationToken.ThrowIfCancellationRequested(); + + var reportSerializableTypeDiagnostics = !SymbolEqualityComparer.Default.Equals(asm, compilation.Assembly); + foreach (var symbol in asm.GetDeclaredTypes()) + { + cancellationToken.ThrowIfCancellationRequested(); + + try + { + if (SerializableTypeModelExtractor.TryExtractSerializableTypeModel(symbol, compilation, libraryTypes, options, reportSerializableTypeDiagnostics) is { } serializableTypeModel) + { + referencedSerializableTypes.Add(serializableTypeModel); + } + } + catch (OrleansGeneratorDiagnosticAnalysisException exception) when (reportSerializableTypeDiagnostics) + { + diagnosticBuilder.Add(exception.Diagnostic); + } + + if (ProxyInterfaceModelExtractor.ExtractProxyInterfaceModel(symbol, compilation, cancellationToken) is { } proxyInterfaceModel) + { + referencedProxyInterfaces.Add(proxyInterfaceModel); + } + + var typeRef = new TypeRef(symbol.ToOpenTypeSyntax().ToString()); + if (GeneratedCodeUtilities.GetId(libraryTypes, symbol) is uint wellKnownTypeId) + { + wellKnownTypeIds.Add(new WellKnownTypeIdModel(typeRef, wellKnownTypeId)); + } + + if (symbol.GetAttribute(libraryTypes.AliasAttribute) is { ConstructorArguments.Length: > 0 } aliasAttr + && aliasAttr.ConstructorArguments[0].Value is string alias) + { + typeAliases.Add(new TypeAliasModel(typeRef, alias)); + } + + if (TryExtractCompoundTypeAlias(symbol, libraryTypes.CompoundTypeAliasAttribute, out var components)) + { + compoundTypeAliases.Add(new CompoundTypeAliasModel(components, typeRef)); + } + + if ((symbol.TypeKind == TypeKind.Class || symbol.TypeKind == TypeKind.Struct) + && !symbol.IsAbstract + && (symbol.DeclaredAccessibility == Accessibility.Public || symbol.DeclaredAccessibility == Accessibility.Internal)) + { + if (symbol.HasAttribute(libraryTypes.RegisterSerializerAttribute)) + { + registeredCodecs.Add(new RegisteredCodecModel(typeRef, RegisteredCodecKind.Serializer)); + } + + if (symbol.HasAttribute(libraryTypes.RegisterCopierAttribute)) + { + registeredCodecs.Add(new RegisteredCodecModel(typeRef, RegisteredCodecKind.Copier)); + } + + if (symbol.HasAttribute(libraryTypes.RegisterActivatorAttribute)) + { + registeredCodecs.Add(new RegisteredCodecModel(typeRef, RegisteredCodecKind.Activator)); + } + + if (symbol.HasAttribute(libraryTypes.RegisterConverterAttribute)) + { + registeredCodecs.Add(new RegisteredCodecModel(typeRef, RegisteredCodecKind.Converter)); + } + + foreach (var iface in symbol.AllInterfaces) + { + if (iface.GetAttribute(libraryTypes.GenerateMethodSerializersAttribute, inherited: true) is not null) + { + interfaceImplementations.Add(new InterfaceImplementationModel(typeRef, SymbolSourceLocationExtractor.GetSourceLocation(symbol))); + break; + } + } + } + } + } + + var orderedApplicationParts = applicationParts.ToImmutableArray(); + + var sortedWellKnownTypeIds = wellKnownTypeIds + .OrderBy(static entry => entry.Type.SyntaxString, StringComparer.Ordinal) + .ThenBy(static entry => entry.Id) + .ToImmutableArray(); + + var sortedTypeAliases = typeAliases + .OrderBy(static entry => entry.Type.SyntaxString, StringComparer.Ordinal) + .ThenBy(static entry => entry.Alias, StringComparer.Ordinal) + .ToImmutableArray(); + + var sortedCompoundTypeAliases = compoundTypeAliases + .OrderBy(static entry => GetCompoundTypeAliasOrderKey(entry), StringComparer.Ordinal) + .ThenBy(static entry => entry.TargetType.SyntaxString, StringComparer.Ordinal) + .ToImmutableArray(); + + var sortedReferencedSerializableTypes = MetadataAggregateModelBuilder.OrderSerializableTypeModels(referencedSerializableTypes) + .ToImmutableArray(); + + var sortedReferencedProxyInterfaces = MetadataAggregateModelBuilder.OrderProxyInterfaceModels(referencedProxyInterfaces) + .ToImmutableArray(); + + var sortedRegisteredCodecs = registeredCodecs + .OrderBy(static entry => entry.Type.SyntaxString, StringComparer.Ordinal) + .ThenBy(static entry => entry.Kind) + .ToImmutableArray(); + + var sortedInterfaceImplementations = interfaceImplementations + .OrderBy(static entry => entry.ImplementationType.SyntaxString, StringComparer.Ordinal) + .ToImmutableArray(); + + diagnostics = diagnosticBuilder.ToImmutable(); + + return new ReferenceAssemblyModel( + AssemblyName: compilation.AssemblyName ?? string.Empty, + ApplicationParts: orderedApplicationParts, + WellKnownTypeIds: sortedWellKnownTypeIds, + TypeAliases: sortedTypeAliases, + CompoundTypeAliases: sortedCompoundTypeAliases, + ReferencedSerializableTypes: sortedReferencedSerializableTypes, + ReferencedProxyInterfaces: sortedReferencedProxyInterfaces, + RegisteredCodecs: sortedRegisteredCodecs, + InterfaceImplementations: sortedInterfaceImplementations); + + void AddApplicationPart(string applicationPart) + { + if (applicationPartSet.Add(applicationPart)) + { + applicationParts.Add(applicationPart); + } + } + } + + private static void ComputeAssembliesToExamine( + IAssemblySymbol asm, + HashSet expandedAssemblies, + INamedTypeSymbol generateCodeForDeclaringAssemblyAttribute, + CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (!expandedAssemblies.Add(asm)) + { + return; + } + + if (!asm.GetAttributes(generateCodeForDeclaringAssemblyAttribute, out var attrs)) + { + return; + } + + foreach (var attr in attrs) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (attr.ConstructorArguments.Length != 1) + { + continue; + } + + var argument = attr.ConstructorArguments[0]; + if (argument.Kind != TypedConstantKind.Type || argument.Value is not ITypeSymbol type) + { + continue; + } + + var declaringAssembly = type.OriginalDefinition.ContainingAssembly; + if (declaringAssembly is null) + { + continue; + } + + ComputeAssembliesToExamine( + declaringAssembly, + expandedAssemblies, + generateCodeForDeclaringAssemblyAttribute, + cancellationToken); + } + } + + private static bool TryExtractCompoundTypeAlias( + INamedTypeSymbol symbol, + INamedTypeSymbol compoundTypeAliasAttribute, + out ImmutableArray components) + { + var attr = symbol.GetAttribute(compoundTypeAliasAttribute); + if (attr is null) + { + components = []; + return false; + } + + var allArgs = attr.ConstructorArguments; + var attributeName = attr.AttributeClass?.Name ?? "unknown"; + var constructorArguments = string.Join(", ", allArgs.Select(static argument => argument.ToString())); + if (allArgs.Length != 1 || allArgs[0].Values.Length == 0) + { + throw new ArgumentException($"Unsupported arguments in attribute [{attributeName}({constructorArguments})]"); + } + + var args = allArgs[0].Values; + var result = ImmutableArray.CreateBuilder(args.Length); + for (var i = 0; i < args.Length; i++) + { + var arg = args[i]; + if (arg.IsNull) + { + throw new ArgumentNullException($"Unsupported null argument in attribute [{attributeName}({constructorArguments})]"); + } + + result.Add(arg.Value switch + { + ITypeSymbol type => new CompoundAliasComponentModel(new TypeRef(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))), + string str => new CompoundAliasComponentModel(str), + _ => throw new ArgumentException($"Unrecognized argument type for argument {arg} in attribute [{attributeName}({constructorArguments})]"), + }); + } + + components = result.MoveToImmutable(); + return true; + } + + internal static string GetCompoundTypeAliasOrderKey(CompoundTypeAliasModel entry) + { + if (entry.Components.Length == 0) + { + return string.Empty; + } + + return string.Join( + "\u001F", + entry.Components.Select(static component => component.IsString + ? $"S:{component.StringValue ?? string.Empty}" + : component.IsType + ? $"T:{component.TypeValue.SyntaxString}" + : string.Empty)); + } + + /// + /// Extracts a from a symbol with one of the Register* attributes. + /// + internal static RegisteredCodecModel ExtractRegisteredCodec(INamedTypeSymbol symbol, RegisteredCodecKind kind) + { + return new RegisteredCodecModel( + new TypeRef(symbol.ToOpenTypeSyntax().ToString()), + kind); + } +} + + diff --git a/src/Orleans.CodeGenerator/Model/RegisteredCodecModel.cs b/src/Orleans.CodeGenerator/Model/RegisteredCodecModel.cs new file mode 100644 index 00000000000..112372ac1e2 --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/RegisteredCodecModel.cs @@ -0,0 +1,18 @@ +namespace Orleans.CodeGenerator.Model; + +/// +/// The kind of manually-registered codec/provider. +/// +internal enum RegisteredCodecKind : byte +{ + Serializer, + Copier, + Activator, + Converter +} + +/// +/// Describes a type annotated with [RegisterSerializer], [RegisterCopier], +/// [RegisterActivator], or [RegisterConverter]. +/// +internal readonly record struct RegisteredCodecModel(TypeRef Type, RegisteredCodecKind Kind); diff --git a/src/Orleans.CodeGenerator/Model/SerializableTypeDescription.cs b/src/Orleans.CodeGenerator/Model/SerializableTypeDescription.cs index 530ab8f5787..f6d5defea5a 100644 --- a/src/Orleans.CodeGenerator/Model/SerializableTypeDescription.cs +++ b/src/Orleans.CodeGenerator/Model/SerializableTypeDescription.cs @@ -1,251 +1,244 @@ using Orleans.CodeGenerator.SyntaxGeneration; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; -using System; -using System.Collections.Generic; -using System.Linq; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -#nullable disable -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal class SerializableTypeDescription : ISerializableTypeDescription { - internal class SerializableTypeDescription : ISerializableTypeDescription - { - private readonly LibraryTypes _libraryTypes; - private TypeSyntax _typeSyntax; - private INamedTypeSymbol _baseType; - private TypeSyntax _baseTypeSyntax; + private readonly LibraryTypes _libraryTypes; - public SerializableTypeDescription(Compilation compilation, INamedTypeSymbol type, bool supportsPrimaryConstructorParameters, IEnumerable members, LibraryTypes libraryTypes) + public SerializableTypeDescription(Compilation compilation, INamedTypeSymbol type, bool supportsPrimaryConstructorParameters, IEnumerable members, LibraryTypes libraryTypes) + { + Type = type; + IncludePrimaryConstructorParameters = supportsPrimaryConstructorParameters; + Members = [.. members]; + Compilation = compilation; + _libraryTypes = libraryTypes; + + var t = type; + Accessibility accessibility = t.DeclaredAccessibility; + while (t is not null) { - Type = type; - IncludePrimaryConstructorParameters = supportsPrimaryConstructorParameters; - Members = members.ToList(); - Compilation = compilation; - _libraryTypes = libraryTypes; - - var t = type; - Accessibility accessibility = t.DeclaredAccessibility; - while (t is not null) + if ((int)t.DeclaredAccessibility < (int)accessibility) { - if ((int)t.DeclaredAccessibility < (int)accessibility) - { - accessibility = t.DeclaredAccessibility; - } - - t = t.ContainingType; + accessibility = t.DeclaredAccessibility; } - Accessibility = accessibility; - TypeParameters = new(); - var names = new HashSet(StringComparer.Ordinal); - foreach (var tp in type.GetAllTypeParameters()) - { - var tpName = GetTypeParameterName(names, tp); - TypeParameters.Add((tpName, tp)); - } + t = t.ContainingType; + } - SerializationHooks = new(); - if (type.GetAttributes(libraryTypes.SerializationCallbacksAttribute, out var hookAttributes)) + Accessibility = accessibility; + TypeParameters = new(); + var names = new HashSet(StringComparer.Ordinal); + foreach (var tp in type.GetAllTypeParameters()) + { + var tpName = GetTypeParameterName(names, tp); + TypeParameters.Add((tpName, tp)); + } + + SerializationHooks = new(); + if (type.GetAttributes(libraryTypes.SerializationCallbacksAttribute, out var hookAttributes)) + { + foreach (var hookAttribute in hookAttributes) { - foreach (var hookAttribute in hookAttributes) - { - var hookType = (INamedTypeSymbol)hookAttribute.ConstructorArguments[0].Value; - SerializationHooks.Add(hookType); - } + var hookType = (INamedTypeSymbol)hookAttribute.ConstructorArguments[0].Value!; + SerializationHooks.Add(hookType); } + } + + ActivatorConstructorParameters = []; + if (TryGetActivatorConstructor(type, _libraryTypes, out var constructorParameters)) + { + HasActivatorConstructor = true; + ActivatorConstructorParameters = constructorParameters; + } - if (TryGetActivatorConstructor(type, _libraryTypes, out var constructorParameters)) + static bool TryGetActivatorConstructor(INamedTypeSymbol type, LibraryTypes libraryTypes, [System.Diagnostics.CodeAnalysis.NotNullWhen(true)] out List? parameters) + { + parameters = null; + if (type.IsAbstract) { - HasActivatorConstructor = true; - ActivatorConstructorParameters = constructorParameters; + return false; } - static bool TryGetActivatorConstructor(INamedTypeSymbol type, LibraryTypes libraryTypes, out List parameters) + foreach (var constructor in type.GetAllMembers()) { - parameters = null; - if (type.IsAbstract) + if (constructor.MethodKind != MethodKind.Constructor || constructor.DeclaredAccessibility == Accessibility.Private || constructor.IsImplicitlyDeclared) { - return false; + continue; } - foreach (var constructor in type.GetAllMembers()) + if (constructor.HasAttribute(libraryTypes.GeneratedActivatorConstructorAttribute)) { - if (constructor.MethodKind != MethodKind.Constructor || constructor.DeclaredAccessibility == Accessibility.Private || constructor.IsImplicitlyDeclared) + foreach (var parameter in constructor.Parameters) { - continue; + var argumentType = parameter.Type.ToTypeSyntax(); + (parameters ??= new()).Add(argumentType); } - if (constructor.HasAttribute(libraryTypes.GeneratedActivatorConstructorAttribute)) - { - foreach (var parameter in constructor.Parameters) - { - var argumentType = parameter.Type.ToTypeSyntax(); - (parameters ??= new()).Add(argumentType); - } - - break; - } + break; } - - return parameters is not null; } - static string GetTypeParameterName(HashSet names, ITypeParameterSymbol tp) - { - var count = 0; - var result = tp.Name; - while (names.Contains(result)) - { - result = $"{tp.Name}_{++count}"; - } + return parameters is not null; + } - names.Add(result); - return result.EscapeIdentifier(); + static string GetTypeParameterName(HashSet names, ITypeParameterSymbol tp) + { + var count = 0; + var result = tp.Name; + while (names.Contains(result)) + { + result = $"{tp.Name}_{++count}"; } + + names.Add(result); + return result.EscapeIdentifier(); } + } - private INamedTypeSymbol Type { get; } + public INamedTypeSymbol Type { get; } - public Accessibility Accessibility { get; } + public Accessibility Accessibility { get; } - public TypeSyntax TypeSyntax => _typeSyntax ??= Type.ToTypeSyntax(); + public TypeSyntax TypeSyntax => field ??= Type.ToTypeSyntax(); - public TypeSyntax BaseTypeSyntax => _baseTypeSyntax ??= BaseType.ToTypeSyntax(); + public TypeSyntax BaseTypeSyntax => field ??= BaseType.ToTypeSyntax(); - public bool HasComplexBaseType => !IsValueType && BaseType is { SpecialType: not SpecialType.System_Object }; + public bool HasComplexBaseType => !IsValueType && BaseType is { SpecialType: not SpecialType.System_Object }; - public bool IncludePrimaryConstructorParameters { get; } + public bool IncludePrimaryConstructorParameters { get; } - public INamedTypeSymbol BaseType => _baseType ??= GetEffectiveBaseType(); + public INamedTypeSymbol BaseType => field ??= GetEffectiveBaseType(); - private INamedTypeSymbol GetEffectiveBaseType() - { - var type = Type.EnumUnderlyingType ?? Type.BaseType; - while (type != null && type.HasAttribute(_libraryTypes.SerializerTransparentAttribute)) - type = type.BaseType; - return type; - } + private INamedTypeSymbol GetEffectiveBaseType() + { + var type = Type.EnumUnderlyingType ?? Type.BaseType; + while (type != null && type.HasAttribute(_libraryTypes.SerializerTransparentAttribute)) + type = type.BaseType; + return type!; + } - public string Namespace => Type.GetNamespaceAndNesting(); + public string Namespace => Type.GetNamespaceAndNesting(); - public string GeneratedNamespace => Namespace switch - { - { Length: > 0 } ns => $"{CodeGenerator.CodeGeneratorName}.{ns}", - _ => CodeGenerator.CodeGeneratorName - }; + public string GeneratedNamespace => Namespace switch + { + { Length: > 0 } ns => $"{GeneratedCodeUtilities.CodeGeneratorName}.{ns}", + _ => GeneratedCodeUtilities.CodeGeneratorName + }; - public string Name => Type.Name; + public string Name => Type.Name; - public bool IsValueType => Type.IsValueType; - public bool IsSealedType => Type.IsSealed; - public bool IsAbstractType => Type.IsAbstract; - public bool IsEnumType => Type.EnumUnderlyingType != null; + public bool IsValueType => Type.IsValueType; + public bool IsSealedType => Type.IsSealed; + public bool IsAbstractType => Type.IsAbstract; + public bool IsEnumType => Type.EnumUnderlyingType != null; - public bool IsGenericType => Type.IsGenericType; + public bool IsGenericType => Type.IsGenericType; - public List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters { get; } + public List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters { get; } - public List Members { get; } - public Compilation Compilation { get; } - public List ActivatorConstructorParameters { get; } + public List Members { get; } + public Compilation Compilation { get; } + public List ActivatorConstructorParameters { get; } - public bool IsEmptyConstructable + public bool IsEmptyConstructable + { + get { - get + if (Type.Constructors.Length == 0) { - if (Type.Constructors.Length == 0) - { - return true; - } + return true; + } - // Types which have required members are not empty constructable for Orleans, at least not yet. - var t = Type; - while (t != null) + // Types which have required members are not empty constructable for Orleans, at least not yet. + var t = Type; + while (t != null) + { + foreach (var member in t.GetMembers()) { - foreach (var member in t.GetMembers()) + if (member is IPropertySymbol { IsRequired: true } or IFieldSymbol { IsRequired: true }) { - if (member is IPropertySymbol { IsRequired: true } or IFieldSymbol { IsRequired: true }) - { - return false; - } + return false; } - - t = t.BaseType; } - foreach (var ctor in Type.Constructors) - { - if (ctor.Parameters.Length != 0) - { - continue; - } + t = t.BaseType; + } - switch (ctor.DeclaredAccessibility) - { - case Accessibility.Public: - return true; - } + foreach (var ctor in Type.Constructors) + { + if (ctor.Parameters.Length != 0) + { + continue; } - return false; + switch (ctor.DeclaredAccessibility) + { + case Accessibility.Public: + return true; + } } + + return false; } + } - public bool HasActivatorConstructor { get; } + public bool HasActivatorConstructor { get; } - public bool UseActivator => Type.HasAttribute(_libraryTypes.UseActivatorAttribute) || !IsEmptyConstructable || HasActivatorConstructor; + public bool UseActivator => Type.HasAttribute(_libraryTypes.UseActivatorAttribute) || !IsEmptyConstructable || HasActivatorConstructor; - public bool TrackReferences => !IsValueType && !IsExceptionType && !Type.HasAttribute(_libraryTypes.SuppressReferenceTrackingAttribute); - public bool OmitDefaultMemberValues => Type.HasAttribute(_libraryTypes.OmitDefaultMemberValuesAttribute); + public bool TrackReferences => !IsValueType && !IsExceptionType && !Type.HasAttribute(_libraryTypes.SuppressReferenceTrackingAttribute); + public bool OmitDefaultMemberValues => Type.HasAttribute(_libraryTypes.OmitDefaultMemberValuesAttribute); - public List SerializationHooks { get; } + public List SerializationHooks { get; } - public bool IsShallowCopyable => IsEnumType || !Type.HasBaseType(_libraryTypes.Exception) && _libraryTypes.IsShallowCopyable(Type); + public bool IsShallowCopyable => IsEnumType || !Type.HasBaseType(_libraryTypes.Exception) && _libraryTypes.IsShallowCopyable(Type); - public bool IsUnsealedImmutable => !Type.IsSealed && IsImmutable; + public bool IsUnsealedImmutable => !Type.IsSealed && IsImmutable; - public bool IsImmutable => Type.HasAttribute(_libraryTypes.ImmutableAttribute); + public bool IsImmutable => Type.HasAttribute(_libraryTypes.ImmutableAttribute); - public bool IsExceptionType => Type.HasBaseType(_libraryTypes.Exception); + public bool IsExceptionType => Type.HasBaseType(_libraryTypes.Exception); - public ExpressionSyntax GetObjectCreationExpression() + public ExpressionSyntax GetObjectCreationExpression() + { + if (IsValueType) { - if (IsValueType) - { - return DefaultExpression(TypeSyntax); - } + return DefaultExpression(TypeSyntax); + } - var instanceConstructors = Type.InstanceConstructors; - var isConstructible = false; - if (!instanceConstructors.IsDefaultOrEmpty) + var instanceConstructors = Type.InstanceConstructors; + var isConstructible = false; + if (!instanceConstructors.IsDefaultOrEmpty) + { + foreach (var ctor in instanceConstructors) { - foreach (var ctor in instanceConstructors) + if (ctor.Parameters.IsDefaultOrEmpty) { - if (ctor.Parameters.IsDefaultOrEmpty) + if (ctor.IsImplicitlyDeclared || ctor.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal) { - if (ctor.IsImplicitlyDeclared || ctor.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal) - { - isConstructible = true; - } - - break; + isConstructible = true; } + + break; } } + } - if (isConstructible) - { - return ObjectCreationExpression(TypeSyntax, ArgumentList(), null); - } - else - { - return CastExpression( - TypeSyntax, - InvocationExpression(_libraryTypes.RuntimeHelpers.ToTypeSyntax().Member("GetUninitializedObject")) - .AddArgumentListArguments( - Argument(TypeOfExpression(TypeSyntax)))); - } + if (isConstructible) + { + return ObjectCreationExpression(TypeSyntax, ArgumentList(), null); + } + else + { + return CastExpression( + TypeSyntax, + InvocationExpression(_libraryTypes.RuntimeHelpers.ToTypeSyntax().Member("GetUninitializedObject")) + .AddArgumentListArguments( + Argument(TypeOfExpression(TypeSyntax)))); } } } diff --git a/src/Orleans.CodeGenerator/Model/SerializableTypeModel.cs b/src/Orleans.CodeGenerator/Model/SerializableTypeModel.cs new file mode 100644 index 00000000000..f2ecff3f8a2 --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/SerializableTypeModel.cs @@ -0,0 +1,38 @@ +using Microsoft.CodeAnalysis; + +namespace Orleans.CodeGenerator.Model; + +/// +/// Describes a [GenerateSerializer]-annotated type for incremental pipeline caching and generation. +/// Contains all data needed to generate a serializer, copier, and activator without holding ISymbol references. +/// +internal sealed record class SerializableTypeModel( + Accessibility Accessibility, + TypeRef TypeSyntax, + bool HasComplexBaseType, + bool IncludePrimaryConstructorParameters, + TypeRef BaseTypeSyntax, + string Namespace, + string GeneratedNamespace, + string Name, + bool IsValueType, + bool IsSealedType, + bool IsAbstractType, + bool IsEnumType, + bool IsGenericType, + EquatableArray TypeParameters, + EquatableArray Members, + bool UseActivator, + bool IsEmptyConstructable, + bool HasActivatorConstructor, + bool TrackReferences, + bool OmitDefaultMemberValues, + EquatableArray SerializationHooks, + bool IsShallowCopyable, + bool IsUnsealedImmutable, + bool IsImmutable, + bool IsExceptionType, + EquatableArray ActivatorConstructorParameters, + ObjectCreationStrategy CreationStrategy, + SourceLocationModel SourceLocation = default, + TypeMetadataIdentity MetadataIdentity = default); diff --git a/src/Orleans.CodeGenerator/Model/SerializableTypeModelExtractor.cs b/src/Orleans.CodeGenerator/Model/SerializableTypeModelExtractor.cs new file mode 100644 index 00000000000..a432cb8aed8 --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/SerializableTypeModelExtractor.cs @@ -0,0 +1,491 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Orleans.CodeGenerator.Diagnostics; +using Orleans.CodeGenerator.Model; +using Orleans.CodeGenerator.SyntaxGeneration; + +namespace Orleans.CodeGenerator; + +internal static class SerializableTypeModelExtractor +{ + /// + /// Extracts a from an . + /// Converts symbol-based descriptions into equatable value models for incremental pipeline caching. + /// + internal static SerializableTypeModel ExtractSerializableTypeModel( + ISerializableTypeDescription description, + SourceLocationModel sourceLocation = default) + { + var typeParameters = ExtractTypeParameters(description.TypeParameters); + var members = ExtractMembers(description.Members); + var serializationHooks = ExtractTypeRefs(description.SerializationHooks); + var activatorCtorParams = ExtractTypeRefSyntaxList(description.ActivatorConstructorParameters); + var creationStrategy = DetermineCreationStrategy(description); + + return new SerializableTypeModel( + Accessibility: description.Accessibility, + TypeSyntax: new TypeRef(description.TypeSyntax.ToString()), + HasComplexBaseType: description.HasComplexBaseType, + IncludePrimaryConstructorParameters: description.IncludePrimaryConstructorParameters, + BaseTypeSyntax: description.HasComplexBaseType ? new TypeRef(description.BaseTypeSyntax.ToString()) : TypeRef.Empty, + Namespace: description.Namespace ?? string.Empty, + GeneratedNamespace: description.GeneratedNamespace ?? string.Empty, + Name: description.Name ?? string.Empty, + IsValueType: description.IsValueType, + IsSealedType: description.IsSealedType, + IsAbstractType: description.IsAbstractType, + IsEnumType: description.IsEnumType, + IsGenericType: description.IsGenericType, + TypeParameters: typeParameters, + Members: members, + UseActivator: description.UseActivator, + IsEmptyConstructable: description.IsEmptyConstructable, + HasActivatorConstructor: description.HasActivatorConstructor, + TrackReferences: description.TrackReferences, + OmitDefaultMemberValues: description.OmitDefaultMemberValues, + SerializationHooks: serializationHooks, + IsShallowCopyable: description.IsShallowCopyable, + IsUnsealedImmutable: description.IsUnsealedImmutable, + IsImmutable: description.IsImmutable, + IsExceptionType: description.IsExceptionType, + ActivatorConstructorParameters: activatorCtorParams, + CreationStrategy: creationStrategy, + SourceLocation: sourceLocation, + MetadataIdentity: description is SerializableTypeDescription serializableDescription + ? TypeMetadataIdentity.Create(serializableDescription.Type) + : TypeMetadataIdentity.Empty); + } + + private static ImmutableArray ExtractTypeParameters( + List<(string Name, ITypeParameterSymbol Parameter)> typeParameters) + { + if (typeParameters is null || typeParameters.Count == 0) + { + return []; + } + + var builder = ImmutableArray.CreateBuilder(typeParameters.Count); + for (var i = 0; i < typeParameters.Count; i++) + { + var (name, param) = typeParameters[i]; + builder.Add(new TypeParameterModel(name, param.Name, param.Ordinal)); + } + return builder.MoveToImmutable(); + } + + private static ImmutableArray ExtractMembers(List members) + { + if (members is null || members.Count == 0) + { + return []; + } + + var builder = ImmutableArray.CreateBuilder(members.Count); + foreach (var member in members) + { + builder.Add(ExtractMember(member)); + } + return builder.MoveToImmutable(); + } + + private static MemberModel ExtractMember(IMemberDescription member) + { + var kind = member is IFieldDescription ? MemberKind.Field : MemberKind.Property; + var symbol = member.Symbol; + var containingType = member.ContainingType; + + // Determine getter/setter accessibility strategies + var getterStrategy = DetermineGetterStrategy(member); + var setterStrategy = DetermineSetterStrategy(member); + + // Determine if member has immutable attribute + var hasImmutableAttribute = false; + if (symbol is IPropertySymbol prop) + { + hasImmutableAttribute = prop.GetAttributes().Any(a => + a.AttributeClass?.Name == "ImmutableAttribute" + && a.AttributeClass.ContainingNamespace?.Name == "Orleans"); + } + if (!hasImmutableAttribute) + { + hasImmutableAttribute = symbol.GetAttributes().Any(a => + a.AttributeClass?.Name == "ImmutableAttribute" + && a.AttributeClass.ContainingNamespace?.Name == "Orleans"); + } + + // Determine if obsolete + var isObsolete = symbol.GetAttributes().Any(a => + a.AttributeClass?.Name == "ObsoleteAttribute" + && a.AttributeClass.ContainingNamespace?.Name == "System"); + + // Backing property name + string? backingPropertyName = null; + if (member is IFieldDescription fieldDesc) + { + var backingProp = PropertyUtility.GetMatchingProperty(fieldDesc.Field); + if (backingProp is not null) + { + backingPropertyName = backingProp.Name; + } + } + + return new MemberModel( + fieldId: member.FieldId, + name: symbol.Name, + type: new TypeRef(member.TypeSyntax.ToString()), + containingType: containingType is not null ? new TypeRef(containingType.ToTypeSyntax().ToString()) : TypeRef.Empty, + assemblyName: member.AssemblyName ?? string.Empty, + typeNameIdentifier: member.TypeNameIdentifier ?? string.Empty, + isPrimaryConstructorParameter: member.IsPrimaryConstructorParameter, + isSerializable: member.IsSerializable, + isCopyable: member.IsCopyable, + kind: kind, + getterStrategy: getterStrategy, + setterStrategy: setterStrategy, + isObsolete: isObsolete, + hasImmutableAttribute: hasImmutableAttribute, + isShallowCopyable: false, // Will be resolved later with LibraryTypes + isValueType: member.Type?.IsValueType ?? false, + containingTypeIsValueType: containingType?.IsValueType ?? false, + backingPropertyName: backingPropertyName); + } + + private static AccessStrategy DetermineGetterStrategy(IMemberDescription member) + { + if (member is IFieldDescription fieldDesc) + { + // Direct access if field is accessible + return AccessStrategy.Direct; + } + + if (member.Symbol is IPropertySymbol prop && prop.GetMethod is not null) + { + return AccessStrategy.Direct; + } + + return AccessStrategy.GeneratedAccessor; + } + + private static AccessStrategy DetermineSetterStrategy(IMemberDescription member) + { + if (member is IFieldDescription fieldDesc) + { + if (!fieldDesc.Field.IsReadOnly) + { + return AccessStrategy.Direct; + } + return AccessStrategy.GeneratedAccessor; + } + + if (member.Symbol is IPropertySymbol prop) + { + if (prop.SetMethod is not null && !prop.SetMethod.IsInitOnly) + { + return AccessStrategy.Direct; + } + if (member.IsPrimaryConstructorParameter) + { + return AccessStrategy.UnsafeAccessor; + } + return AccessStrategy.GeneratedAccessor; + } + + return AccessStrategy.GeneratedAccessor; + } + + private static ImmutableArray ExtractTypeRefs(List? symbols) + { + if (symbols is null || symbols.Count == 0) + { + return []; + } + + var builder = ImmutableArray.CreateBuilder(symbols.Count); + foreach (var s in symbols) + { + builder.Add(new TypeRef(s.ToTypeSyntax().ToString())); + } + return builder.MoveToImmutable(); + } + + private static ImmutableArray ExtractTypeRefSyntaxList( + List? syntaxList) + { + if (syntaxList is null || syntaxList.Count == 0) + { + return []; + } + + var builder = ImmutableArray.CreateBuilder(syntaxList.Count); + foreach (var ts in syntaxList) + { + builder.Add(new TypeRef(ts.ToString())); + } + return builder.MoveToImmutable(); + } + + private static ObjectCreationStrategy DetermineCreationStrategy(ISerializableTypeDescription description) + { + if (description.IsValueType) + { + return ObjectCreationStrategy.Default; + } + + // Check if we can determine from the existing expression + var expr = description.GetObjectCreationExpression(); + if (expr is Microsoft.CodeAnalysis.CSharp.Syntax.DefaultExpressionSyntax) + { + return ObjectCreationStrategy.Default; + } + + if (expr is Microsoft.CodeAnalysis.CSharp.Syntax.ObjectCreationExpressionSyntax) + { + return ObjectCreationStrategy.NewExpression; + } + + return ObjectCreationStrategy.GetUninitializedObject; + } + + internal static SerializableTypeModel? TryExtractSerializableTypeModel( + INamedTypeSymbol typeSymbol, + Compilation compilation, + LibraryTypes libraryTypes, + CodeGeneratorOptions options, + bool throwOnFailure = false) + { + if (typeSymbol is null) + { + return null; + } + + if (FSharpUtilities.IsUnionCase(libraryTypes, typeSymbol, out var sumType)) + { + if (!sumType.HasAttribute(libraryTypes.GenerateSerializerAttribute)) + { + return null; + } + + if (throwOnFailure && HasReferenceAssemblyAttribute(sumType.ContainingAssembly)) + { + throw new OrleansGeneratorDiagnosticAnalysisException( + ReferenceAssemblyWithGenerateSerializerDiagnostic.CreateDiagnostic(sumType, Location.None)); + } + + if (!compilation.IsSymbolAccessibleWithin(sumType, compilation.Assembly)) + { + if (throwOnFailure) + { + throw new OrleansGeneratorDiagnosticAnalysisException( + InaccessibleSerializableTypeDiagnostic.CreateDiagnostic(sumType, Location.None)); + } + + return null; + } + + var fsharpUnionCaseDescription = new FSharpUtilities.FSharpUnionCaseTypeDescription(compilation, typeSymbol, libraryTypes); + return ExtractSerializableTypeModel(fsharpUnionCaseDescription, SymbolSourceLocationExtractor.GetSourceLocation(typeSymbol)); + } + + if (!typeSymbol.HasAttribute(libraryTypes.GenerateSerializerAttribute)) + { + return null; + } + + if (throwOnFailure && HasReferenceAssemblyAttribute(typeSymbol.ContainingAssembly)) + { + throw new OrleansGeneratorDiagnosticAnalysisException( + ReferenceAssemblyWithGenerateSerializerDiagnostic.CreateDiagnostic(typeSymbol, Location.None)); + } + + if (!compilation.IsSymbolAccessibleWithin(typeSymbol, compilation.Assembly)) + { + if (throwOnFailure) + { + throw new OrleansGeneratorDiagnosticAnalysisException( + InaccessibleSerializableTypeDiagnostic.CreateDiagnostic(typeSymbol, Location.None)); + } + + return null; + } + + if (FSharpUtilities.IsRecord(libraryTypes, typeSymbol)) + { + var fsharpDescription = new FSharpUtilities.FSharpRecordTypeDescription(compilation, typeSymbol, libraryTypes); + return ExtractSerializableTypeModel(fsharpDescription, SymbolSourceLocationExtractor.GetSourceLocation(typeSymbol)); + } + + var includePrimaryCtorParams = GetIncludePrimaryConstructorParameters(typeSymbol, libraryTypes); + var ctorParams = ResolveConstructorParameters(typeSymbol, includePrimaryCtorParams, libraryTypes); + var implicitFieldIdStrategy = (options.GenerateFieldIds, GetFieldIdsOptionFromType(typeSymbol, libraryTypes)) switch + { + (_, GenerateFieldIds.PublicProperties) => GenerateFieldIds.PublicProperties, + (GenerateFieldIds.PublicProperties, _) => GenerateFieldIds.PublicProperties, + _ => GenerateFieldIds.None, + }; + var helper = new FieldIdAssignmentHelper(typeSymbol, ctorParams, implicitFieldIdStrategy, libraryTypes); + if (!helper.IsValidForSerialization) + { + if (throwOnFailure) + { + throw new OrleansGeneratorDiagnosticAnalysisException( + CanNotGenerateImplicitFieldIdsDiagnostic.CreateDiagnostic(typeSymbol, helper.FailureReason!, Location.None)); + } + + return null; + } + + var members = CollectDataMembers(helper); + var description = new SerializableTypeDescription(compilation, typeSymbol, includePrimaryCtorParams, members, libraryTypes); + return ExtractSerializableTypeModel(description, SymbolSourceLocationExtractor.GetSourceLocation(typeSymbol)); + } + + private static bool HasReferenceAssemblyAttribute(IAssemblySymbol assembly) + { + return assembly?.GetAttributes().Any(attributeData => attributeData.AttributeClass is + { + Name: "ReferenceAssemblyAttribute", + ContainingNamespace: + { + Name: "CompilerServices", + ContainingNamespace: + { + Name: "Runtime", + ContainingNamespace: + { + Name: "System", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + }) == true; + } + + private static bool GetIncludePrimaryConstructorParameters(INamedTypeSymbol typeSymbol, LibraryTypes libraryTypes) + { + var attribute = typeSymbol.GetAttribute(libraryTypes.GenerateSerializerAttribute); + if (attribute is not null) + { + foreach (var namedArgument in attribute.NamedArguments) + { + if (namedArgument.Key == "IncludePrimaryConstructorParameters" + && namedArgument.Value.Kind == TypedConstantKind.Primitive + && namedArgument.Value.Value is bool b) + { + return b; + } + } + } + + // Default to true for records + if (typeSymbol.IsRecord) + { + return true; + } + + // Detect primary constructor via compiler-generated properties + var properties = typeSymbol.GetMembers().OfType().ToImmutableArray(); + return typeSymbol.GetMembers() + .OfType() + .Where(m => m.MethodKind == MethodKind.Constructor && m.Parameters.Length > 0) + .Any(ctor => ctor.Parameters.All(prm => + properties.Any(prop => prop.Name.Equals(prm.Name, StringComparison.Ordinal) && prop.IsCompilerGenerated()))); + } + + private static ImmutableArray ResolveConstructorParameters( + INamedTypeSymbol typeSymbol, + bool includePrimaryCtorParams, + LibraryTypes libraryTypes) + { + if (!includePrimaryCtorParams) + { + return []; + } + + if (typeSymbol.IsRecord) + { + // Primary constructor is declared before the copy constructor for records + var potentialPrimaryConstructor = typeSymbol.Constructors[0]; + if (!potentialPrimaryConstructor.IsImplicitlyDeclared && !potentialPrimaryConstructor.IsCompilerGenerated()) + { + return potentialPrimaryConstructor.Parameters; + } + } + else + { + var annotatedConstructors = typeSymbol.Constructors + .Where(ctor => ctor.HasAnyAttribute(libraryTypes.ConstructorAttributeTypes)) + .ToList(); + if (annotatedConstructors.Count == 1) + { + return annotatedConstructors[0].Parameters; + } + + // Fallback: detect primary constructor via compiler-generated properties + var properties = typeSymbol.GetMembers().OfType().ToImmutableArray(); + var primaryConstructor = typeSymbol.GetMembers() + .OfType() + .Where(m => m.MethodKind == MethodKind.Constructor && m.Parameters.Length > 0) + .FirstOrDefault(ctor => ctor.Parameters.All(prm => + properties.Any(prop => prop.Name.Equals(prm.Name, StringComparison.Ordinal) && prop.IsCompilerGenerated()))); + + if (primaryConstructor is not null) + { + return primaryConstructor.Parameters; + } + } + + return []; + } + + private static GenerateFieldIds GetFieldIdsOptionFromType(INamedTypeSymbol typeSymbol, LibraryTypes libraryTypes) + { + var attribute = typeSymbol.GetAttribute(libraryTypes.GenerateSerializerAttribute); + if (attribute is null) + { + return GenerateFieldIds.None; + } + + foreach (var namedArgument in attribute.NamedArguments) + { + if (namedArgument.Key == "GenerateFieldIds") + { + var value = namedArgument.Value.Value; + return value is null ? GenerateFieldIds.None : (GenerateFieldIds)(int)value; + } + } + + return GenerateFieldIds.None; + } + + private static IEnumerable CollectDataMembers(FieldIdAssignmentHelper fieldIdAssignmentHelper) + { + var members = new Dictionary<(uint, bool), IMemberDescription>(); + + foreach (var member in fieldIdAssignmentHelper.Members) + { + if (!fieldIdAssignmentHelper.TryGetSymbolKey(member, out var key)) + { + continue; + } + + var (id, isConstructorParameter) = key; + + if (member is IPropertySymbol property && !members.ContainsKey((id, isConstructorParameter))) + { + members[(id, isConstructorParameter)] = new PropertyDescription(id, isConstructorParameter, property); + } + + if (member is IFieldSymbol field) + { + if (!members.TryGetValue((id, isConstructorParameter), out var existing) || existing is IPropertyDescription) + { + members[(id, isConstructorParameter)] = new FieldDescription(id, isConstructorParameter, field); + } + } + } + + return members.Values; + } +} + + diff --git a/src/Orleans.CodeGenerator/Model/SourceLocationModel.cs b/src/Orleans.CodeGenerator/Model/SourceLocationModel.cs new file mode 100644 index 00000000000..9f336217a14 --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/SourceLocationModel.cs @@ -0,0 +1,15 @@ +namespace Orleans.CodeGenerator.Model; + +internal readonly record struct SourceLocationModel +{ + public SourceLocationModel(int sourceOrderGroup, string filePath, int position) + { + SourceOrderGroup = sourceOrderGroup; + FilePath = filePath ?? string.Empty; + Position = position; + } + + public int SourceOrderGroup { get; } + public string FilePath { get; } + public int Position { get; } +} diff --git a/src/Orleans.CodeGenerator/Model/StructuralEquality.cs b/src/Orleans.CodeGenerator/Model/StructuralEquality.cs new file mode 100644 index 00000000000..adb84f08d7d --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/StructuralEquality.cs @@ -0,0 +1,95 @@ +using System.Collections.Immutable; + +namespace Orleans.CodeGenerator.Model; + +internal readonly struct EquatableArray(ImmutableArray values) : IEquatable>, IReadOnlyList +{ + private readonly ImmutableArray _values = StructuralEquality.Normalize(values); + + public static EquatableArray Empty { get; } = new([]); + + public ImmutableArray Values => StructuralEquality.Normalize(_values); + + public int Count => Values.Length; + + public int Length => Values.Length; + + public bool IsDefault => false; + + public bool IsEmpty => Values.IsEmpty; + + public bool IsDefaultOrEmpty => Values.IsEmpty; + + public T this[int index] => Values[index]; + + public ImmutableArray.Enumerator GetEnumerator() => Values.GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)Values).GetEnumerator(); + + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() + => ((System.Collections.IEnumerable)Values).GetEnumerator(); + + public bool Equals(EquatableArray other) => StructuralEquality.SequenceEqual(Values, other.Values); + + public override bool Equals(object? obj) => obj is EquatableArray other && Equals(other); + + public override int GetHashCode() => StructuralEquality.GetSequenceHashCode(Values); + + public static implicit operator EquatableArray(ImmutableArray values) => new(values); + + public static implicit operator ImmutableArray(EquatableArray values) => values.Values; + + public override string ToString() => Values.ToString(); +} + +internal static class StructuralEquality +{ + public static ImmutableArray Normalize(ImmutableArray values) + => values.IsDefault ? [] : values; + + public static bool SequenceEqual(ImmutableArray left, ImmutableArray right) + => Normalize(left).SequenceEqual(Normalize(right)); + + public static int GetSequenceHashCode(ImmutableArray values) + { + var normalizedValues = Normalize(values); + var comparer = EqualityComparer.Default; + + unchecked + { + var hash = 17; + foreach (var item in normalizedValues) + { + hash = Combine(hash, item is null ? 0 : comparer.GetHashCode(item)); + } + + return hash; + } + } + + public static int GetHashCode(string? value) + => StringComparer.Ordinal.GetHashCode(value ?? string.Empty); + + public static int Combine(int hash, int value) + { + unchecked + { + return hash * 31 + value; + } + } +} + +internal sealed class ImmutableArrayComparer : IEqualityComparer> +{ + public static ImmutableArrayComparer Instance { get; } = new(); + + private ImmutableArrayComparer() + { + } + + public bool Equals(ImmutableArray left, ImmutableArray right) + => StructuralEquality.SequenceEqual(left, right); + + public int GetHashCode(ImmutableArray values) + => StructuralEquality.GetSequenceHashCode(values); +} diff --git a/src/Orleans.CodeGenerator/Model/SymbolSourceLocationExtractor.cs b/src/Orleans.CodeGenerator/Model/SymbolSourceLocationExtractor.cs new file mode 100644 index 00000000000..4cb940d782a --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/SymbolSourceLocationExtractor.cs @@ -0,0 +1,18 @@ +using Microsoft.CodeAnalysis; +using Orleans.CodeGenerator.Model; + +namespace Orleans.CodeGenerator; + +internal static class SymbolSourceLocationExtractor +{ + internal static SourceLocationModel GetSourceLocation(ISymbol? symbol) + { + var sourceLocation = symbol?.Locations.FirstOrDefault(static location => location.IsInSource); + return sourceLocation is null + ? new SourceLocationModel(sourceOrderGroup: 1, filePath: string.Empty, position: int.MaxValue) + : new SourceLocationModel( + sourceOrderGroup: 0, + filePath: sourceLocation.SourceTree?.FilePath ?? string.Empty, + position: sourceLocation.SourceSpan.Start); + } +} diff --git a/src/Orleans.CodeGenerator/Model/TypeMetadataIdentity.cs b/src/Orleans.CodeGenerator/Model/TypeMetadataIdentity.cs new file mode 100644 index 00000000000..a23e00db211 --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/TypeMetadataIdentity.cs @@ -0,0 +1,67 @@ +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Orleans.CodeGenerator.Model; + +/// +/// Identifies a type by its Roslyn metadata name and containing assembly. +/// +internal readonly record struct TypeMetadataIdentity +{ + public TypeMetadataIdentity(string metadataName, string assemblyName, string assemblyIdentity) + { + MetadataName = metadataName ?? string.Empty; + AssemblyName = assemblyName ?? string.Empty; + AssemblyIdentity = assemblyIdentity ?? string.Empty; + } + + public string MetadataName { get; } + public string AssemblyName { get; } + public string AssemblyIdentity { get; } + public bool IsEmpty => string.IsNullOrEmpty(MetadataName); + + public static TypeMetadataIdentity Empty { get; } = new TypeMetadataIdentity( + metadataName: string.Empty, + assemblyName: string.Empty, + assemblyIdentity: string.Empty); + + public static TypeMetadataIdentity Create(INamedTypeSymbol symbol) + { + if (symbol is null) + { + return Empty; + } + + var originalDefinition = symbol.OriginalDefinition; + var assembly = originalDefinition.ContainingAssembly; + return new TypeMetadataIdentity( + GetMetadataName(originalDefinition), + assembly?.Identity.Name ?? string.Empty, + assembly?.Identity.GetDisplayName() ?? string.Empty); + } + + private static string GetMetadataName(INamedTypeSymbol symbol) + { + var builder = new StringBuilder(); + var ns = symbol.ContainingNamespace; + if (ns is not null && !ns.IsGlobalNamespace) + { + builder.Append(ns.ToDisplayString()); + builder.Append('.'); + } + + AppendMetadataName(builder, symbol); + return builder.ToString(); + + static void AppendMetadataName(StringBuilder builder, INamedTypeSymbol current) + { + if (current.ContainingType is { } containingType) + { + AppendMetadataName(builder, containingType); + builder.Append('+'); + } + + builder.Append(current.MetadataName); + } + } +} diff --git a/src/Orleans.CodeGenerator/Model/TypeParameterModel.cs b/src/Orleans.CodeGenerator/Model/TypeParameterModel.cs new file mode 100644 index 00000000000..eef244e96cd --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/TypeParameterModel.cs @@ -0,0 +1,6 @@ +namespace Orleans.CodeGenerator.Model; + +/// +/// Describes a type parameter in a serializable or proxy type. +/// +internal readonly record struct TypeParameterModel(string Name, string OriginalName, int Ordinal); diff --git a/src/Orleans.CodeGenerator/Model/TypeRef.cs b/src/Orleans.CodeGenerator/Model/TypeRef.cs new file mode 100644 index 00000000000..fe68b865b1c --- /dev/null +++ b/src/Orleans.CodeGenerator/Model/TypeRef.cs @@ -0,0 +1,26 @@ +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Orleans.CodeGenerator.Model; + +/// +/// A value-type reference to a type, stored as a string representation of its syntax. +/// Used in incremental pipeline models to avoid holding ITypeSymbol references. +/// +internal readonly record struct TypeRef +{ + public TypeRef(string syntaxString) => SyntaxString = syntaxString ?? string.Empty; + + public string SyntaxString { get; } + + /// + /// Reconstructs a from the stored string. + /// + public TypeSyntax ToTypeSyntax() => SyntaxFactory.ParseTypeName(SyntaxString); + + public override string ToString() => SyntaxString; + + public static TypeRef Empty { get; } = new TypeRef(string.Empty); + + public bool IsEmpty => string.IsNullOrEmpty(SyntaxString); +} diff --git a/src/Orleans.CodeGenerator/Model/WellKnownCodecDescription.cs b/src/Orleans.CodeGenerator/Model/WellKnownCodecDescription.cs index 82c46a12237..28d8d7cc8cd 100644 --- a/src/Orleans.CodeGenerator/Model/WellKnownCodecDescription.cs +++ b/src/Orleans.CodeGenerator/Model/WellKnownCodecDescription.cs @@ -1,30 +1,16 @@ using Microsoft.CodeAnalysis; -#nullable disable -namespace Orleans.CodeGenerator -{ - internal sealed class WellKnownCodecDescription - { - public WellKnownCodecDescription(ITypeSymbol underlyingType, INamedTypeSymbol codecType) - { - UnderlyingType = underlyingType; - CodecType = codecType; - } - - public readonly ITypeSymbol UnderlyingType; - public readonly INamedTypeSymbol CodecType; - } +namespace Orleans.CodeGenerator; - internal sealed class WellKnownCopierDescription : ICopierDescription - { - public WellKnownCopierDescription(ITypeSymbol underlyingType, INamedTypeSymbol codecType) - { - UnderlyingType = underlyingType; - CopierType = codecType; - } +internal sealed class WellKnownCodecDescription(ITypeSymbol? underlyingType, INamedTypeSymbol? codecType) +{ + public readonly ITypeSymbol UnderlyingType = underlyingType!; + public readonly INamedTypeSymbol CodecType = codecType!; +} - public ITypeSymbol UnderlyingType { get; } +internal sealed class WellKnownCopierDescription(ITypeSymbol underlyingType, INamedTypeSymbol codecType) : ICopierDescription +{ + public ITypeSymbol UnderlyingType { get; } = underlyingType; - public INamedTypeSymbol CopierType { get; } - } + public INamedTypeSymbol CopierType { get; } = codecType; } diff --git a/src/Orleans.CodeGenerator/OrleansGeneratorDiagnosticAnalysisException.cs b/src/Orleans.CodeGenerator/OrleansGeneratorDiagnosticAnalysisException.cs index 74300230f53..ae6ae3f01b3 100644 --- a/src/Orleans.CodeGenerator/OrleansGeneratorDiagnosticAnalysisException.cs +++ b/src/Orleans.CodeGenerator/OrleansGeneratorDiagnosticAnalysisException.cs @@ -1,15 +1,8 @@ using Microsoft.CodeAnalysis; -using System; -namespace Orleans.CodeGenerator -{ - public class OrleansGeneratorDiagnosticAnalysisException : Exception - { - public OrleansGeneratorDiagnosticAnalysisException(Diagnostic diagnostic) : base(diagnostic.GetMessage()) - { - Diagnostic = diagnostic; - } +namespace Orleans.CodeGenerator; - public Diagnostic Diagnostic { get; } - } +public class OrleansGeneratorDiagnosticAnalysisException(Diagnostic diagnostic) : Exception(diagnostic.GetMessage()) +{ + public Diagnostic Diagnostic { get; } = diagnostic; } diff --git a/src/Orleans.CodeGenerator/OrleansSourceGenerator.cs b/src/Orleans.CodeGenerator/OrleansSourceGenerator.cs index 39dfbd5a039..dc04c7df9b4 100644 --- a/src/Orleans.CodeGenerator/OrleansSourceGenerator.cs +++ b/src/Orleans.CodeGenerator/OrleansSourceGenerator.cs @@ -1,90 +1,212 @@ -using System; -using System.Diagnostics; -using System.Linq; using System.Text; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; -using Orleans.CodeGenerator.Diagnostics; using Orleans.CodeGenerator.Model; -#pragma warning disable RS1035 // Do not use APIs banned for analyzers -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +[Generator] +public sealed class OrleansSerializationSourceGenerator : IIncrementalGenerator { - [Generator] - public class OrleansSerializationSourceGenerator : ISourceGenerator + internal const string GeneratorOptionsTrackingName = "Orleans.GeneratorOptions"; + internal const string AssemblyNameTrackingName = "Orleans.AssemblyName"; + internal const string SerializableTypeResultsTrackingName = "Orleans.SerializableTypeResults"; + internal const string CollectedSerializableTypesTrackingName = "Orleans.CollectedSerializableTypes"; + internal const string DirectProxyInterfacesTrackingName = "Orleans.DirectProxyInterfaces"; + internal const string InheritedProxyInterfacesTrackingName = "Orleans.InheritedProxyInterfaces"; + internal const string CollectedProxyInterfacesTrackingName = "Orleans.CollectedProxyInterfaces"; + internal const string PreparedProxyOutputsTrackingName = "Orleans.PreparedProxyOutputs"; + internal const string ReferenceAssemblyDataTrackingName = "Orleans.ReferenceAssemblyData"; + internal const string MetadataAggregateTrackingName = "Orleans.MetadataAggregate"; + internal const string SerializerOutputsTrackingName = "Orleans.SerializerOutputs"; + internal const string ReferencedSerializerOutputsTrackingName = "Orleans.ReferencedSerializerOutputs"; + internal const string ProxyOutputsTrackingName = "Orleans.ProxyOutputs"; + internal const string MetadataOutputsTrackingName = "Orleans.MetadataOutputs"; + + + public void Initialize(IncrementalGeneratorInitializationContext context) { - public void Execute(GeneratorExecutionContext context) + var generatorOptions = context.AnalyzerConfigOptionsProvider + .Select(static (provider, _) => SourceGeneratorOptionsParser.ParseOptions(provider.GlobalOptions)) + .WithTrackingName(GeneratorOptionsTrackingName); + var compilationProvider = context.CompilationProvider; + var assemblyNameProvider = compilationProvider + .Select(static (compilation, _) => compilation.AssemblyName ?? "assembly") + .WithTrackingName(AssemblyNameTrackingName); + + // Incremental discovery of [GenerateSerializer] types + var serializableTypeContexts = context.SyntaxProvider + .ForAttributeWithMetadataName( + "Orleans.GenerateSerializerAttribute", + predicate: static (node, _) => node is TypeDeclarationSyntax or EnumDeclarationSyntax, + transform: static (ctx, _) => ctx); + + var serializableTypeResults = serializableTypeContexts + .Combine(generatorOptions) + .Select(static (input, ct) => SerializableSourceOutputGenerator.CreateSerializableTypeResult( + input.Left, + input.Right, + ct)) + .WithTrackingName(SerializableTypeResultsTrackingName); + + var collectedSerializableTypeResults = serializableTypeResults + .Collect() + .Select(static (input, _) => GeneratedSourceOutput.DeduplicateSerializableTypeResults(input)) + .WithComparer(ImmutableArrayComparer.Instance); + + var collectedTypes = collectedSerializableTypeResults + .Select(static (input, _) => GeneratedSourceOutput.GetSerializableTypeModels(input)) + .WithComparer(ImmutableArrayComparer.Instance) + .WithTrackingName(CollectedSerializableTypesTrackingName); + + context.RegisterSourceOutput(collectedSerializableTypeResults.SelectMany(static (input, _) => input), static (productionContext, result) => { - try + if (result.Diagnostic is { } diagnostic) { - var processName = Process.GetCurrentProcess().ProcessName.ToLowerInvariant(); - if (processName.Contains("devenv") || processName.Contains("servicehub")) - { - return; - } + productionContext.ReportDiagnostic(diagnostic); + } + }); - if (!Debugger.IsAttached && - context.AnalyzerConfigOptions.GlobalOptions.TryGetValue("build_property.orleans_designtimebuild", out var isDesignTimeBuild) - && string.Equals("true", isDesignTimeBuild, StringComparison.OrdinalIgnoreCase)) - { - return; - } + // Attribute-driven discovery of [GenerateMethodSerializers] interfaces, plus a + // constrained syntax provider for interfaces which inherit the attribute from a base interface. + var directProxyInterfaces = context.SyntaxProvider + .ForAttributeWithMetadataName( + "Orleans.GenerateMethodSerializersAttribute", + predicate: static (node, _) => node is InterfaceDeclarationSyntax, + transform: static (ctx, ct) => ModelExtractor.ExtractProxyInterfaceFromAttributeContext(ctx, ct)) + .Where(static model => model is not null) + .Select(static (model, _) => model!) + .WithTrackingName(DirectProxyInterfacesTrackingName); - if (context.AnalyzerConfigOptions.GlobalOptions.TryGetValue("build_property.orleans_attachdebugger", out var attachDebuggerOption) - && string.Equals("true", attachDebuggerOption, StringComparison.OrdinalIgnoreCase)) - { - Debugger.Launch(); - } + var inheritedProxyInterfaces = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (node, _) => node is InterfaceDeclarationSyntax { BaseList: not null }, + transform: static (ctx, ct) => ModelExtractor.ExtractInheritedProxyInterfaceFromSyntaxContext(ctx, ct)) + .Where(static model => model is not null) + .Select(static (model, _) => model!) + .Collect() + .Select(static (input, _) => ModelExtractor.NormalizeProxyInterfaceModels(input)) + .WithComparer(ImmutableArrayComparer.Instance) + .WithTrackingName(InheritedProxyInterfacesTrackingName); - var options = new CodeGeneratorOptions(); + var collectedDirectProxyInterfaces = directProxyInterfaces + .Collect() + .Select(static (input, _) => ModelExtractor.NormalizeProxyInterfaceModels(input)) + .WithComparer(ImmutableArrayComparer.Instance); - if (context.AnalyzerConfigOptions.GlobalOptions.TryGetValue("build_property.orleans_generatefieldids", out var generateFieldIds) && generateFieldIds is { Length: > 0 }) - { - if (Enum.TryParse(generateFieldIds, out GenerateFieldIds fieldIdOption)) - { - options.GenerateFieldIds = fieldIdOption; - } - } + var collectedProxies = collectedDirectProxyInterfaces + .Combine(inheritedProxyInterfaces) + .Select(static (input, _) => ModelExtractor.MergeProxyInterfaces(input.Left, input.Right)) + .WithComparer(ImmutableArrayComparer.Instance) + .WithTrackingName(CollectedProxyInterfacesTrackingName); - if (context.AnalyzerConfigOptions.GlobalOptions.TryGetValue("build_property.orleansgeneratecompatibilityinvokers", out var generateCompatInvokersValue) - && bool.TryParse(generateCompatInvokersValue, out var genCompatInvokers)) - { - options.GenerateCompatibilityInvokers = genCompatInvokers; - } + var preparedProxyOutputs = collectedProxies + .Combine(compilationProvider) + .Combine(generatorOptions) + .Select(static (input, ct) => ProxySourceOutputGenerator.CreateProxyOutputPreparation(input.Left.Right, input.Left.Left, input.Right, ct)) + .WithTrackingName(PreparedProxyOutputsTrackingName); - var codeGenerator = new CodeGenerator(context.Compilation, options); - var syntax = codeGenerator.GenerateCode(context.CancellationToken); - var sourceString = syntax.NormalizeWhitespace().ToFullString(); - var sourceText = SourceText.From(sourceString, Encoding.UTF8); - context.AddSource($"{context.Compilation.AssemblyName ?? "assembly"}.orleans.g.cs", sourceText); - } - catch (Exception exception) + context.RegisterSourceOutput(preparedProxyOutputs, static (productionContext, input) => + { + if (input.Diagnostic is { } diagnostic) { - if (!HandleException(context, exception)) - { - throw; - } + productionContext.ReportDiagnostic(diagnostic); } + }); + + // Extract reference assembly data (application parts, well-known type IDs, aliases) + var refAssemblyDataResults = compilationProvider + .Combine(generatorOptions) + .Select(static (input, ct) => ReferenceAssemblyDataProvider.CreateReferenceAssemblyDataResult( + input.Left, + input.Right, + ct)) + .WithTrackingName(ReferenceAssemblyDataTrackingName); - static bool HandleException(GeneratorExecutionContext context, Exception exception) + context.RegisterSourceOutput(refAssemblyDataResults, static (productionContext, result) => + { + if (!result.Diagnostics.IsDefaultOrEmpty) { - if (exception is OrleansGeneratorDiagnosticAnalysisException analysisException) + foreach (var diagnostic in result.Diagnostics) { - context.ReportDiagnostic(analysisException.Diagnostic); - return true; + productionContext.ReportDiagnostic(diagnostic); } - - context.ReportDiagnostic(UnhandledCodeGenerationExceptionDiagnostic.CreateDiagnostic(exception)); - Console.WriteLine(exception); - Console.WriteLine(exception.StackTrace); - return false; } - } + }); + + var refAssemblyData = refAssemblyDataResults + .Select(static (result, _) => result.Model); + + var preparedProxyOutputModels = preparedProxyOutputs + .Select(static (result, _) => result.ProxyOutputModels) + .WithComparer(ImmutableArrayComparer.Instance); + + // Combine source/reference models before metadata generation. + var metadataAggregate = collectedTypes + .Combine(preparedProxyOutputModels) + .Combine(refAssemblyData) + .Select(static (input, ct) => ModelExtractor.CreateMetadataAggregate( + input.Right.AssemblyName, + input.Left.Left, + input.Left.Right, + input.Right)) + .WithTrackingName(MetadataAggregateTrackingName); + + var serializerOutputs = collectedTypes + .Combine(compilationProvider) + .Combine(generatorOptions) + .Select(static (input, ct) => SerializableSourceOutputGenerator.CreateSerializableSourceOutputs( + input.Left.Right, + input.Left.Left, + input.Right, + ct)) + .WithComparer(ImmutableArrayComparer.Instance) + .WithTrackingName(SerializerOutputsTrackingName); + + context.RegisterSourceOutput(serializerOutputs.SelectMany(static (input, _) => input), static (productionContext, input) => + { + GeneratedSourceOutput.EmitSourceOutputResult(productionContext, input); + }); + + var referencedSerializerOutputs = refAssemblyData + .Combine(compilationProvider) + .Combine(generatorOptions) + .Select(static (input, ct) => SerializableSourceOutputGenerator.CreateReferencedSerializableSourceOutputs( + input.Left.Right, + input.Left.Left, + input.Right, + ct)) + .WithComparer(ImmutableArrayComparer.Instance) + .WithTrackingName(ReferencedSerializerOutputsTrackingName); + + context.RegisterSourceOutput(referencedSerializerOutputs.SelectMany(static (input, _) => input), static (productionContext, input) => + { + GeneratedSourceOutput.EmitSourceOutputResult(productionContext, input); + }); + + var proxyOutputs = preparedProxyOutputs + .SelectMany(static (result, _) => result.SourceOutputs) + .WithTrackingName(ProxyOutputsTrackingName); + + context.RegisterSourceOutput(proxyOutputs, static (productionContext, input) => + { + GeneratedSourceOutput.EmitSourceOutputResult(productionContext, input); + }); + + context.RegisterSourceOutput(assemblyNameProvider, static (productionContext, assemblyName) => + { + productionContext.AddSource($"{assemblyName}.orleans.g.cs", SourceText.From(string.Empty, Encoding.UTF8)); + }); + + var metadataOutputs = metadataAggregate + .Combine(generatorOptions) + .Select(static (input, _) => MetadataSourceOutputGenerator.CreateMetadataSourceOutput(input.Left, input.Right)) + .WithTrackingName(MetadataOutputsTrackingName); - public void Initialize(GeneratorInitializationContext context) + context.RegisterSourceOutput(metadataOutputs, static (productionContext, input) => { - } + GeneratedSourceOutput.EmitSourceOutputResult(productionContext, input); + }); } } -#pragma warning restore RS1035 // Do not use APIs banned for analyzers diff --git a/src/Orleans.CodeGenerator/Properties/IsExternalInit.cs b/src/Orleans.CodeGenerator/Properties/IsExternalInit.cs new file mode 100644 index 00000000000..f53b57d1a8a --- /dev/null +++ b/src/Orleans.CodeGenerator/Properties/IsExternalInit.cs @@ -0,0 +1,3 @@ +namespace System.Runtime.CompilerServices; + +internal static class IsExternalInit {} diff --git a/src/Orleans.CodeGenerator/Properties/NullableAttributes.cs b/src/Orleans.CodeGenerator/Properties/NullableAttributes.cs new file mode 100644 index 00000000000..84b1a4a33da --- /dev/null +++ b/src/Orleans.CodeGenerator/Properties/NullableAttributes.cs @@ -0,0 +1,9 @@ +#if NETSTANDARD2_0 +namespace System.Diagnostics.CodeAnalysis; + +[global::System.AttributeUsage(global::System.AttributeTargets.Parameter, Inherited = false)] +internal sealed class NotNullWhenAttribute(bool returnValue) : global::System.Attribute +{ + public bool ReturnValue { get; } = returnValue; +} +#endif diff --git a/src/Orleans.CodeGenerator/PropertyUtility.cs b/src/Orleans.CodeGenerator/PropertyUtility.cs index d9d3b334dd5..51e26acf6f1 100644 --- a/src/Orleans.CodeGenerator/PropertyUtility.cs +++ b/src/Orleans.CodeGenerator/PropertyUtility.cs @@ -1,76 +1,69 @@ using Microsoft.CodeAnalysis; -using System; -using System.Collections.Generic; -using System.Linq; using System.Text.RegularExpressions; -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +public static class PropertyUtility { - public static class PropertyUtility + private static readonly Regex PropertyMatchRegex = new("^<([^>]+)>.*$", RegexOptions.Compiled); + + public static IPropertySymbol? GetMatchingProperty(IFieldSymbol field) { - private static readonly Regex PropertyMatchRegex = new("^<([^>]+)>.*$", RegexOptions.Compiled); + if (field.ContainingType is null) + return null; + return GetMatchingProperty(field, field.ContainingType.GetMembers()); + } - public static IPropertySymbol? GetMatchingProperty(IFieldSymbol field) - { - if (field.ContainingType is null) - return null; - return GetMatchingProperty(field, field.ContainingType.GetMembers()); - } + public static bool IsCompilerGenerated(this ISymbol? symbol) + => symbol?.GetAttributes().Any(a => a.AttributeClass?.Name == "CompilerGeneratedAttribute") == true; - public static bool IsCompilerGenerated(this ISymbol? symbol) - => symbol?.GetAttributes().Any(a => a.AttributeClass?.Name == "CompilerGeneratedAttribute") == true; + public static bool IsCompilerGenerated(this IPropertySymbol? property) + => property?.GetMethod.IsCompilerGenerated() == true && property.SetMethod.IsCompilerGenerated(); - public static bool IsCompilerGenerated(this IPropertySymbol? property) - => property?.GetMethod.IsCompilerGenerated() == true && property.SetMethod.IsCompilerGenerated(); + public static IParameterSymbol? GetMatchingPrimaryConstructorParameter(IPropertySymbol property, IEnumerable constructorParameters) + { + return constructorParameters.FirstOrDefault(p => + string.Equals(GetCanonicalName(p.Name), GetCanonicalName(property.Name), StringComparison.Ordinal) && + SymbolEqualityComparer.Default.Equals(p.Type, property.Type)); + } - public static IParameterSymbol? GetMatchingPrimaryConstructorParameter(IPropertySymbol property, IEnumerable constructorParameters) + public static IPropertySymbol? GetMatchingProperty(IFieldSymbol field, IEnumerable memberSymbols) + { + var propertyName = PropertyMatchRegex.Match(field.Name); + if (!propertyName.Success) { - if (!property.IsCompilerGenerated()) - return null; - - return constructorParameters.FirstOrDefault(p => - string.Equals(p.Name, property.Name, StringComparison.Ordinal) && - SymbolEqualityComparer.Default.Equals(p.Type, property.Type)); + return null; } - public static IPropertySymbol? GetMatchingProperty(IFieldSymbol field, IEnumerable memberSymbols) - { - var propertyName = PropertyMatchRegex.Match(field.Name); - if (!propertyName.Success) - { - return null; - } - - var name = propertyName.Groups[1].Value; - var candidates = memberSymbols.OfType() - .Where(property => string.Equals(name, property.Name, StringComparison.Ordinal) - && SymbolEqualityComparer.Default.Equals(field.Type, property.Type)).ToArray(); - return candidates.Length == 1 ? candidates[0] : null; - } + var name = propertyName.Groups[1].Value; + var candidates = memberSymbols.OfType() + .Where(property => string.Equals(name, property.Name, StringComparison.Ordinal) + && SymbolEqualityComparer.Default.Equals(field.Type, property.Type)).ToArray(); + return candidates.Length == 1 ? candidates[0] : null; + } - public static IFieldSymbol? GetMatchingField(IPropertySymbol property) - { - if (property.ContainingType is null) - return null; - return GetMatchingField(property, property.ContainingType.GetMembers()); - } + public static IFieldSymbol? GetMatchingField(IPropertySymbol property) + { + if (property.ContainingType is null) + return null; + return GetMatchingField(property, property.ContainingType.GetMembers()); + } - public static IFieldSymbol? GetMatchingField(IPropertySymbol property, IEnumerable memberSymbols) - { - var backingFieldName = $"<{property.Name}>k__BackingField"; - var candidates = (from field in memberSymbols.OfType() - where SymbolEqualityComparer.Default.Equals(field.Type, property.Type) - where field.Name == backingFieldName || GetCanonicalName(field.Name) == GetCanonicalName(property.Name) - select field).ToArray(); - return candidates.Length == 1 ? candidates[0] : null; - } + public static IFieldSymbol? GetMatchingField(IPropertySymbol property, IEnumerable memberSymbols) + { + var backingFieldName = $"<{property.Name}>k__BackingField"; + var candidates = (from field in memberSymbols.OfType() + where SymbolEqualityComparer.Default.Equals(field.Type, property.Type) + where field.Name == backingFieldName || GetCanonicalName(field.Name) == GetCanonicalName(property.Name) + select field).ToArray(); + return candidates.Length == 1 ? candidates[0] : null; + } - public static string GetCanonicalName(string name) - { - name = name.TrimStart('_'); - if (name.Length > 0 && char.IsUpper(name[0])) - name = $"{char.ToLowerInvariant(name[0])}{name.Substring(1)}"; - return name; - } + public static string GetCanonicalName(string name) + { + name = name.TrimStart('_'); + if (name.Length > 0 && char.IsUpper(name[0])) + name = $"{char.ToLowerInvariant(name[0])}{name.Substring(1)}"; + return name; } -} \ No newline at end of file +} diff --git a/src/Orleans.CodeGenerator/ProxyGenerationContext.cs b/src/Orleans.CodeGenerator/ProxyGenerationContext.cs new file mode 100644 index 00000000000..f25858567fe --- /dev/null +++ b/src/Orleans.CodeGenerator/ProxyGenerationContext.cs @@ -0,0 +1,241 @@ +using System.Diagnostics.CodeAnalysis; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Orleans.CodeGenerator.SyntaxGeneration; + +namespace Orleans.CodeGenerator; + +internal sealed class ProxyGenerationContext : IGeneratorServices +{ + private readonly Dictionary> _namespacedMembers = new(); + private readonly Dictionary _invokableMethodDescriptions = new(); + private readonly HashSet _visitedInterfaces = new(SymbolEqualityComparer.Default); + + internal ProxyGenerationContext(Compilation compilation, CodeGeneratorOptions options) + : this(compilation, options, LibraryTypes.FromCompilation(compilation, options)) + { + } + + internal ProxyGenerationContext(Compilation compilation, CodeGeneratorOptions options, LibraryTypes libraryTypes) + { + Compilation = compilation ?? throw new ArgumentNullException(nameof(compilation)); + Options = options ?? throw new ArgumentNullException(nameof(options)); + LibraryTypes = libraryTypes ?? throw new ArgumentNullException(nameof(libraryTypes)); + MetadataModel = new MetadataModel(); + ProxyGenerator = new ProxyGenerator(this, new CopierGenerator(this)); + InvokableGenerator = new InvokableGenerator(this); + } + + public Compilation Compilation { get; } + public CodeGeneratorOptions Options { get; } + internal LibraryTypes LibraryTypes { get; } + LibraryTypes IGeneratorServices.LibraryTypes => LibraryTypes; + internal MetadataModel MetadataModel { get; } + internal ProxyGenerator ProxyGenerator { get; } + internal InvokableGenerator InvokableGenerator { get; } + + internal void AddMember(string ns, MemberDeclarationSyntax member) + { + if (!_namespacedMembers.TryGetValue(ns, out var existing)) + { + existing = _namespacedMembers[ns] = new List(); + } + + existing.Add(member); + } + + internal IEnumerable<(string Namespace, MemberDeclarationSyntax Member)> GetEmittedMembers() + { + foreach (var entry in _namespacedMembers) + { + foreach (var member in entry.Value) + { + yield return (entry.Key, member); + } + } + } + + internal uint? GetId(ISymbol memberSymbol) => GeneratedCodeUtilities.GetId(LibraryTypes, memberSymbol); + + internal string? GetAlias(ISymbol symbol) => GeneratedCodeUtilities.GetAlias(LibraryTypes, symbol); + + internal void VisitInterface(INamedTypeSymbol interfaceType) + { + // Get or generate an invokable for the original method definition. + if (!SymbolEqualityComparer.Default.Equals(interfaceType, interfaceType.OriginalDefinition)) + { + interfaceType = interfaceType.OriginalDefinition; + } + + if (!_visitedInterfaces.Add(interfaceType)) + { + return; + } + + foreach (var proxyBase in GetProxyBases(interfaceType)) + { + _ = GetInvokableInterfaceDescription(proxyBase.ProxyBaseType, interfaceType); + } + } + + internal bool TryGetInvokableInterfaceDescription(INamedTypeSymbol interfaceType, [NotNullWhen(true)] out ProxyInterfaceDescription? result) + { + if (!TryGetProxyBaseDescription(interfaceType, out var description)) + { + result = null; + return false; + } + + result = GetInvokableInterfaceDescription(description.ProxyBaseType, interfaceType); + return true; + } + + private readonly Dictionary> _interfaceProxyBases = new(SymbolEqualityComparer.Default); + internal List GetProxyBases(INamedTypeSymbol interfaceType) + { + if (_interfaceProxyBases.TryGetValue(interfaceType, out var result)) + { + return result; + } + + result = new List(); + if (interfaceType.GetAttributes(LibraryTypes.GenerateMethodSerializersAttribute, out var attributes, inherited: true)) + { + foreach (var attribute in attributes) + { + var proxyBase = GetProxyBaseDescription(attribute); + if (!result.Contains(proxyBase)) + { + result.Add(proxyBase); + } + } + } + + _interfaceProxyBases[interfaceType] = result; + return result; + } + + internal bool TryGetProxyBaseDescription(INamedTypeSymbol interfaceType, [NotNullWhen(true)] out InvokableMethodProxyBase? result) + { + var attribute = interfaceType.GetAttribute(LibraryTypes.GenerateMethodSerializersAttribute, inherited: true); + if (attribute == null) + { + result = null; + return false; + } + + result = GetProxyBaseDescription(attribute); + return true; + } + + private InvokableMethodProxyBase GetProxyBaseDescription(AttributeData attribute) + { + var proxyBaseType = ((INamedTypeSymbol)attribute.ConstructorArguments[0].Value!).OriginalDefinition; + var isExtension = (bool)attribute.ConstructorArguments[1].Value!; + var invokableBaseTypes = GetInvokableBaseTypes(proxyBaseType); + var descriptor = new InvokableMethodProxyBaseId(proxyBaseType, isExtension); + var description = new InvokableMethodProxyBase(this, descriptor, invokableBaseTypes); + return description; + + Dictionary GetInvokableBaseTypes(INamedTypeSymbol baseClass) + { + // Set the base invokable types which are used if attributes on individual methods do not override them. + if (!MetadataModel.ProxyBaseTypeInvokableBaseTypes.TryGetValue(baseClass, out var invokableBaseTypes)) + { + invokableBaseTypes = new Dictionary(SymbolEqualityComparer.Default); + if (baseClass.GetAttributes(LibraryTypes.DefaultInvokableBaseTypeAttribute, out var invokableBaseTypeAttributes)) + { + foreach (var attr in invokableBaseTypeAttributes) + { + var ctorArgs = attr.ConstructorArguments; + var returnType = (INamedTypeSymbol)ctorArgs[0].Value!; + var invokableBaseType = (INamedTypeSymbol)ctorArgs[1].Value!; + invokableBaseTypes[returnType] = invokableBaseType; + } + } + + MetadataModel.ProxyBaseTypeInvokableBaseTypes[baseClass] = invokableBaseTypes; + } + + return invokableBaseTypes; + } + } + + internal InvokableMethodProxyBase GetProxyBase(INamedTypeSymbol interfaceType) + { + if (!TryGetProxyBaseDescription(interfaceType, out var result)) + { + throw new InvalidOperationException($"Cannot get proxy base description for a type which does not have or inherit [{nameof(LibraryTypes.GenerateMethodSerializersAttribute)}]"); + } + + return result; + } + + private ProxyInterfaceDescription GetInvokableInterfaceDescription(INamedTypeSymbol proxyBaseType, INamedTypeSymbol interfaceType) + { + var originalInterface = interfaceType.OriginalDefinition; + if (MetadataModel.InvokableInterfaces.TryGetValue(originalInterface, out var description)) + { + return description; + } + + description = new ProxyInterfaceDescription(this, proxyBaseType, originalInterface); + MetadataModel.InvokableInterfaces.Add(originalInterface, description); + + // Generate a proxy. + var (generatedClass, proxyDescription) = ProxyGenerator.Generate(description); + + // Emit the generated proxy + if (Compilation.GetTypeByMetadataName(proxyDescription.MetadataName) == null) + { + AddMember(proxyDescription.InterfaceDescription.GeneratedNamespace, generatedClass); + } + + return description; + } + + internal ProxyMethodDescription GetProxyMethodDescription(INamedTypeSymbol interfaceType, IMethodSymbol method) + { + var originalMethod = method.OriginalDefinition; + var proxyBaseInfo = GetProxyBase(interfaceType); + + // For extensions, we want to ensure that the containing type is always the extension. + // This ensures that we will always know which 'component' to get in our SetTarget method. + // If the type is not an extension, use the original method definition's containing type. + // This is the interface where the type was originally defined. + var containingType = proxyBaseInfo.IsExtension ? interfaceType : originalMethod.ContainingType; + + var invokableId = new InvokableMethodId(proxyBaseInfo, containingType, originalMethod); + var interfaceDescription = GetInvokableInterfaceDescription(invokableId.ProxyBase.ProxyBaseType, interfaceType); + + // Get or generate an invokable for the original method definition. + if (!MetadataModel.GeneratedInvokables.TryGetValue(invokableId, out var generatedInvokable)) + { + if (!_invokableMethodDescriptions.TryGetValue(invokableId, out var methodDescription)) + { + methodDescription = _invokableMethodDescriptions[invokableId] = InvokableMethodDescription.Create(invokableId, containingType); + } + + generatedInvokable = MetadataModel.GeneratedInvokables[invokableId] = InvokableGenerator.Generate(methodDescription); + + if (Compilation.GetTypeByMetadataName(generatedInvokable.MetadataName) == null) + { + // Emit the generated code on-demand. + AddMember(generatedInvokable.GeneratedNamespace, generatedInvokable.ClassDeclarationSyntax); + } + } + + var proxyMethodDescription = ProxyMethodDescription.Create(interfaceDescription, generatedInvokable, method); + + // For backwards compatibility, generate invokers for the specific implementation types as well, where they differ. + if (Options.GenerateCompatibilityInvokers && !SymbolEqualityComparer.Default.Equals(method.OriginalDefinition.ContainingType, interfaceType)) + { + var compatInvokableId = new InvokableMethodId(proxyBaseInfo, interfaceType, method); + var compatMethodDescription = InvokableMethodDescription.Create(compatInvokableId, interfaceType); + var compatInvokable = InvokableGenerator.Generate(compatMethodDescription); + AddMember(compatInvokable.GeneratedNamespace, compatInvokable.ClassDeclarationSyntax); + } + + return proxyMethodDescription; + } +} diff --git a/src/Orleans.CodeGenerator/ProxyGenerator.cs b/src/Orleans.CodeGenerator/ProxyGenerator.cs index d9727ceb58f..aa9da3ce56d 100644 --- a/src/Orleans.CodeGenerator/ProxyGenerator.cs +++ b/src/Orleans.CodeGenerator/ProxyGenerator.cs @@ -1,6 +1,3 @@ -using System; -using System.Collections.Generic; -using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -11,424 +8,421 @@ using static Orleans.CodeGenerator.InvokableGenerator; using static Orleans.CodeGenerator.SerializerGenerator; -#nullable disable -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +/// +/// Generates RPC stub objects called invokers. +/// +internal class ProxyGenerator(IGeneratorServices generatorServices, CopierGenerator copierGenerator) { - /// - /// Generates RPC stub objects called invokers. - /// - internal class ProxyGenerator - { - private const string CopyContextPoolMemberName = "CopyContextPool"; - private const string CodecProviderMemberName = "CodecProvider"; - private readonly CodeGenerator _codeGenerator; + private const string CopyContextPoolMemberName = "CopyContextPool"; + private const string CodecProviderMemberName = "CodecProvider"; + private readonly IGeneratorServices _generatorServices = generatorServices; + private readonly CopierGenerator _copierGenerator = copierGenerator; + + private LibraryTypes LibraryTypes => _generatorServices.LibraryTypes; - public ProxyGenerator(CodeGenerator codeGenerator) + public (ClassDeclarationSyntax, GeneratedProxyDescription) Generate(ProxyInterfaceDescription interfaceDescription) + { + var generatedClassName = GetSimpleClassName(interfaceDescription); + + var fieldDescriptions = GetFieldDescriptions(interfaceDescription); + var fieldDeclarations = GetFieldDeclarations(fieldDescriptions); + var proxyMethods = CreateProxyMethods(fieldDescriptions, interfaceDescription); + + var ctors = GenerateConstructors(generatedClassName, fieldDescriptions, interfaceDescription.ProxyBaseType); + + var classDeclaration = ClassDeclaration(generatedClassName) + .AddBaseListTypes( + SimpleBaseType(interfaceDescription.ProxyBaseType.ToTypeSyntax()), + SimpleBaseType(interfaceDescription.InterfaceType.ToTypeSyntax())) + .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.SealedKeyword)) + .AddAttributeLists(GeneratedCodeUtilities.GetGeneratedCodeAttributes()) + .AddMembers(fieldDeclarations) + .AddMembers(ctors) + .AddMembers(proxyMethods); + + var typeParameters = interfaceDescription.TypeParameters; + if (typeParameters.Count > 0) { - _codeGenerator = codeGenerator; + classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, typeParameters); } - private LibraryTypes LibraryTypes => _codeGenerator.LibraryTypes; + return (classDeclaration, new GeneratedProxyDescription(interfaceDescription, generatedClassName)); + } + + public static string GetSimpleClassName(ProxyInterfaceDescription interfaceDescription) + => GetSimpleClassName(interfaceDescription.Name); - public (ClassDeclarationSyntax, GeneratedProxyDescription) Generate(ProxyInterfaceDescription interfaceDescription) - { - var generatedClassName = GetSimpleClassName(interfaceDescription); - - var fieldDescriptions = GetFieldDescriptions(interfaceDescription); - var fieldDeclarations = GetFieldDeclarations(fieldDescriptions); - var proxyMethods = CreateProxyMethods(fieldDescriptions, interfaceDescription); - - var ctors = GenerateConstructors(generatedClassName, fieldDescriptions, interfaceDescription.ProxyBaseType); - - var classDeclaration = ClassDeclaration(generatedClassName) - .AddBaseListTypes( - SimpleBaseType(interfaceDescription.ProxyBaseType.ToTypeSyntax()), - SimpleBaseType(interfaceDescription.InterfaceType.ToTypeSyntax())) - .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.SealedKeyword)) - .AddAttributeLists(CodeGenerator.GetGeneratedCodeAttributes()) - .AddMembers(fieldDeclarations) - .AddMembers(ctors) - .AddMembers(proxyMethods); - - var typeParameters = interfaceDescription.TypeParameters; - if (typeParameters.Count > 0) - { - classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, typeParameters); - } + public static string GetSimpleClassName(string name) + => $"Proxy_{SyntaxGeneration.Identifier.SanitizeIdentifierName(name)}"; - return (classDeclaration, new GeneratedProxyDescription(interfaceDescription, generatedClassName)); + private List GetFieldDescriptions( + ProxyInterfaceDescription interfaceDescription) + { + var fields = new List(); + + // Add a copier field for any method parameter which does not have a static codec. + var paramCopiers = interfaceDescription.Methods + .Where(method => method.MethodTypeParameters.Count == 0) + .SelectMany(method => method.GeneratedInvokable.Members); + _copierGenerator.GetCopierFieldDescriptions(paramCopiers, fields); + return fields; + } + + private static MemberDeclarationSyntax[] GetFieldDeclarations(List fieldDescriptions) + { + return [.. fieldDescriptions.Select(GetFieldDeclaration)]; + + static MemberDeclarationSyntax GetFieldDeclaration(GeneratedFieldDescription description) + { + return FieldDeclaration(VariableDeclaration(description.FieldType, SingletonSeparatedList(VariableDeclarator(description.FieldName)))) + .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword)); } + } - public static string GetSimpleClassName(ProxyInterfaceDescription interfaceDescription) - => $"Proxy_{SyntaxGeneration.Identifier.SanitizeIdentifierName(interfaceDescription.Name)}"; - - private List GetFieldDescriptions( - ProxyInterfaceDescription interfaceDescription) + private MemberDeclarationSyntax[] CreateProxyMethods( + List fieldDescriptions, + ProxyInterfaceDescription interfaceDescription) + { + var res = new List(); + foreach (var methodDescription in interfaceDescription.Methods) { - var fields = new List(); - - // Add a copier field for any method parameter which does not have a static codec. - var paramCopiers = interfaceDescription.Methods - .Where(method => method.MethodTypeParameters.Count == 0) - .SelectMany(method => method.GeneratedInvokable.Members); - _codeGenerator.CopierGenerator.GetCopierFieldDescriptions(paramCopiers, fields); - return fields; + res.Add(CreateProxyMethod(methodDescription)); } + return [.. res]; - private MemberDeclarationSyntax[] GetFieldDeclarations(List fieldDescriptions) + MethodDeclarationSyntax CreateProxyMethod(ProxyMethodDescription methodDescription) { - return fieldDescriptions.Select(GetFieldDeclaration).ToArray(); + var (isAsync, body) = CreateAsyncProxyMethodBody(fieldDescriptions, methodDescription); + var method = methodDescription.Method; + var declaration = MethodDeclaration(method.ReturnType.ToTypeSyntax(methodDescription.TypeParameterSubstitutions), method.Name.EscapeIdentifier()) + .AddParameterListParameters([.. method.Parameters.Select((p, i) => GetParameterSyntax(i, p, methodDescription.TypeParameterSubstitutions))]) + .WithBody(body); + + if (isAsync) + { + declaration = declaration.WithModifiers(TokenList(Token(SyntaxKind.AsyncKeyword))); + } + + var explicitInterfaceSpecifier = ExplicitInterfaceSpecifier(methodDescription.Method.ContainingType.ToNameSyntax()); + declaration = declaration.WithExplicitInterfaceSpecifier(explicitInterfaceSpecifier); - static MemberDeclarationSyntax GetFieldDeclaration(GeneratedFieldDescription description) + if (methodDescription.MethodTypeParameters.Count > 0) { - return FieldDeclaration(VariableDeclaration(description.FieldType, SingletonSeparatedList(VariableDeclarator(description.FieldName)))) - .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword)); + declaration = declaration.WithTypeParameterList( + TypeParameterList(SeparatedList(methodDescription.MethodTypeParameters.Select(tp => TypeParameter(tp.Name))))); } + + return declaration; } + } - private MemberDeclarationSyntax[] CreateProxyMethods( - List fieldDescriptions, - ProxyInterfaceDescription interfaceDescription) + private (bool IsAsync, BlockSyntax body) CreateAsyncProxyMethodBody( + List fieldDescriptions, + ProxyMethodDescription methodDescription) + { + var statements = new List(); + var requestVar = IdentifierName("request"); + var methodSymbol = methodDescription.Method; + var invokable = methodDescription.GeneratedInvokable; + ExpressionSyntax createRequestExpr = (!invokable.IsEmptyConstructable || invokable.UseActivator) switch { - var res = new List(); - foreach (var methodDescription in interfaceDescription.Methods) + true => InvocationExpression(ThisExpression().Member("GetInvokable", invokable.TypeSyntax)) + .WithArgumentList(ArgumentList(SeparatedList())), + _ => ObjectCreationExpression(invokable.TypeSyntax).WithArgumentList(ArgumentList()) + }; + + statements.Add( + LocalDeclarationStatement( + VariableDeclaration( + ParseTypeName("var"), + SingletonSeparatedList( + VariableDeclarator( + Identifier("request")) + .WithInitializer( + EqualsValueClause(createRequestExpr)))))); + + var codecs = fieldDescriptions.OfType() + .Concat(_generatorServices.LibraryTypes.StaticCopiers) + .ToList(); + + // Set request object fields from method parameters. + var parameterIndex = 0; + var parameters = invokable.Members.OfType().Select(member => new SerializableMethodMember(member)); + ExpressionSyntax copyContextPool = BaseExpression().Member(CopyContextPoolMemberName); + ExpressionSyntax copyContextVariable = IdentifierName("copyContext"); + var hasCopyContext = false; + foreach (var parameter in parameters) + { + // Only create a copy context as needed. + if (!hasCopyContext && !parameter.IsShallowCopyable) { - res.Add(CreateProxyMethod(methodDescription)); + // C#: using var copyContext = base.CopyContext.GetContext(); + statements.Add( + LocalDeclarationStatement( + VariableDeclaration( + ParseTypeName("var"), + SingletonSeparatedList( + VariableDeclarator(Identifier("copyContext")).WithInitializer( + EqualsValueClause(InvocationExpression( + copyContextPool.Member("GetContext"), + ArgumentList())))))).WithUsingKeyword(Token(SyntaxKind.UsingKeyword))); + hasCopyContext = true; } - return res.ToArray(); - MethodDeclarationSyntax CreateProxyMethod(ProxyMethodDescription methodDescription) - { - var (isAsync, body) = CreateAsyncProxyMethodBody(fieldDescriptions, methodDescription); - var method = methodDescription.Method; - var declaration = MethodDeclaration(method.ReturnType.ToTypeSyntax(methodDescription.TypeParameterSubstitutions), method.Name.EscapeIdentifier()) - .AddParameterListParameters(method.Parameters.Select((p, i) => GetParameterSyntax(i, p, methodDescription.TypeParameterSubstitutions)).ToArray()) - .WithBody(body); + var valueExpression = _copierGenerator.GenerateMemberCopy( + fieldDescriptions, + IdentifierName($"arg{parameterIndex}"), + copyContextVariable, + codecs, + parameter); - if (isAsync) - { - declaration = declaration.WithModifiers(TokenList(Token(SyntaxKind.AsyncKeyword))); - } + statements.Add( + ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + requestVar.Member($"arg{parameterIndex}"), + valueExpression))); - var explicitInterfaceSpecifier = ExplicitInterfaceSpecifier(methodDescription.Method.ContainingType.ToNameSyntax()); - declaration = declaration.WithExplicitInterfaceSpecifier(explicitInterfaceSpecifier); + parameterIndex++; + } - if (methodDescription.MethodTypeParameters.Count > 0) + string? invokeMethodName = default; + foreach (var attr in methodDescription.Method.GetAttributes()) + { + if (attr.AttributeClass is { } attributeClass && attributeClass.GetAttributes(LibraryTypes.InvokeMethodNameAttribute, out var attrs)) + { + foreach (var methodAttr in attrs) { - declaration = declaration.WithTypeParameterList( - TypeParameterList(SeparatedList(methodDescription.MethodTypeParameters.Select(tp => TypeParameter(tp.Name))))); + invokeMethodName = (string?)methodAttr.ConstructorArguments.First().Value; } - - return declaration; } } - private (bool IsAsync, BlockSyntax body) CreateAsyncProxyMethodBody( - List fieldDescriptions, - ProxyMethodDescription methodDescription) + var methodReturnType = methodDescription.Method.ReturnType; + if (methodReturnType is not INamedTypeSymbol namedMethodReturnType) { - var statements = new List(); - var requestVar = IdentifierName("request"); - var methodSymbol = methodDescription.Method; - var invokable = methodDescription.GeneratedInvokable; - ExpressionSyntax createRequestExpr = (!invokable.IsEmptyConstructable || invokable.UseActivator) switch - { - true => InvocationExpression(ThisExpression().Member("GetInvokable", invokable.TypeSyntax)) - .WithArgumentList(ArgumentList(SeparatedList())), - _ => ObjectCreationExpression(invokable.TypeSyntax).WithArgumentList(ArgumentList()) - }; + var diagnostic = InvalidRpcMethodReturnTypeDiagnostic.CreateDiagnostic(methodDescription.InvokableMethod); + throw new OrleansGeneratorDiagnosticAnalysisException(diagnostic); + } - statements.Add( - LocalDeclarationStatement( - VariableDeclaration( - ParseTypeName("var"), - SingletonSeparatedList( - VariableDeclarator( - Identifier("request")) - .WithInitializer( - EqualsValueClause(createRequestExpr)))))); - - var codecs = fieldDescriptions.OfType() - .Concat(_codeGenerator.LibraryTypes.StaticCopiers) - .ToList(); - - // Set request object fields from method parameters. - var parameterIndex = 0; - var parameters = invokable.Members.OfType().Select(member => new SerializableMethodMember(member)); - ExpressionSyntax copyContextPool = BaseExpression().Member(CopyContextPoolMemberName); - ExpressionSyntax copyContextVariable = IdentifierName("copyContext"); - var hasCopyContext = false; - foreach (var parameter in parameters) - { - // Only create a copy context as needed. - if (!hasCopyContext && !parameter.IsShallowCopyable) - { - // C#: using var copyContext = base.CopyContext.GetContext(); - statements.Add( - LocalDeclarationStatement( - VariableDeclaration( - ParseTypeName("var"), - SingletonSeparatedList( - VariableDeclarator(Identifier("copyContext")).WithInitializer( - EqualsValueClause(InvocationExpression( - copyContextPool.Member("GetContext"), - ArgumentList())))))).WithUsingKeyword(Token(SyntaxKind.UsingKeyword))); - hasCopyContext = true; - } + ExpressionSyntax baseInvokeExpression; + var isVoid = methodReturnType.SpecialType is SpecialType.System_Void; + if (namedMethodReturnType.TypeArguments.Length == 1) + { + // Task / ValueTask + var resultType = namedMethodReturnType.TypeArguments[0]; + baseInvokeExpression = BaseExpression().Member( + invokeMethodName ?? "InvokeAsync", + resultType.ToTypeSyntax(methodDescription.TypeParameterSubstitutions)); + } + else if (isVoid) + { + // void + baseInvokeExpression = BaseExpression().Member(invokeMethodName ?? "Invoke"); + } + else + { + // Task / ValueTask + baseInvokeExpression = BaseExpression().Member(invokeMethodName ?? "InvokeAsync"); + } - var valueExpression = _codeGenerator.CopierGenerator.GenerateMemberCopy( - fieldDescriptions, - IdentifierName($"arg{parameterIndex}"), - copyContextVariable, - codecs, - parameter); + // C#: base.InvokeAsync(request); + var invocationExpression = + InvocationExpression( + baseInvokeExpression, + ArgumentList(SeparatedList([Argument(requestVar)]))); - statements.Add( - ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - requestVar.Member($"arg{parameterIndex}"), - valueExpression))); + var rt = namedMethodReturnType.ConstructedFrom; + bool isAsync; + if (SymbolEqualityComparer.Default.Equals(rt, LibraryTypes.Task_1) || SymbolEqualityComparer.Default.Equals(methodReturnType, LibraryTypes.Task)) + { + // C#: return .AsTask() + statements.Add(ReturnStatement(InvocationExpression(invocationExpression.Member("AsTask"), ArgumentList()))); + isAsync = false; + } + else if (SymbolEqualityComparer.Default.Equals(rt, LibraryTypes.ValueTask_1) || SymbolEqualityComparer.Default.Equals(methodReturnType, LibraryTypes.ValueTask)) + { + // ValueTask / ValueTask + // C#: return + statements.Add(ReturnStatement(invocationExpression)); + isAsync = false; + } + else if (invokable.ReturnValueInitializerMethod is { } returnValueInitializerMethod) + { + // C#: return request.(this); + statements.Add(ReturnStatement(InvocationExpression(requestVar.Member(returnValueInitializerMethod), ArgumentList(SingletonSeparatedList(Argument(ThisExpression())))))); + isAsync = false; + } + else if (isVoid) + { + // C#: + statements.Add(ExpressionStatement(invocationExpression)); + isAsync = false; + } + else if (rt.Arity == 0) + { + // C#: await + statements.Add(ExpressionStatement(AwaitExpression(invocationExpression))); + isAsync = true; + } + else + { + // C#: return await + statements.Add(ReturnStatement(AwaitExpression(invocationExpression))); + isAsync = true; + } - parameterIndex++; - } + return (isAsync, Block(statements)); + } - string invokeMethodName = default; - foreach (var attr in methodDescription.Method.GetAttributes()) - { - if (attr.AttributeClass.GetAttributes(LibraryTypes.InvokeMethodNameAttribute, out var attrs)) - { - foreach (var methodAttr in attrs) - { - invokeMethodName = (string)methodAttr.ConstructorArguments.First().Value; - } - } - } + private MemberDeclarationSyntax[] GenerateConstructors( + string simpleClassName, + List fieldDescriptions, + INamedTypeSymbol baseType) + { + if (baseType is null) + { + return []; + } - var methodReturnType = methodDescription.Method.ReturnType; - if (methodReturnType is not INamedTypeSymbol namedMethodReturnType) + var bodyStatements = GetBodyStatements(); + var res = new List(); + foreach (var member in baseType.GetMembers()) + { + if (member is not IMethodSymbol method) { - var diagnostic = InvalidRpcMethodReturnTypeDiagnostic.CreateDiagnostic(methodDescription.InvokableMethod); - throw new OrleansGeneratorDiagnosticAnalysisException(diagnostic); + continue; } - ExpressionSyntax baseInvokeExpression; - var isVoid = methodReturnType.SpecialType is SpecialType.System_Void; - if (namedMethodReturnType.TypeArguments.Length == 1) + if (method.MethodKind != MethodKind.Constructor) { - // Task / ValueTask - var resultType = namedMethodReturnType.TypeArguments[0]; - baseInvokeExpression = BaseExpression().Member( - invokeMethodName ?? "InvokeAsync", - resultType.ToTypeSyntax(methodDescription.TypeParameterSubstitutions)); + continue; } - else if (isVoid) - { - // void - baseInvokeExpression = BaseExpression().Member(invokeMethodName ?? "Invoke"); - } - else - { - // Task / ValueTask - baseInvokeExpression = BaseExpression().Member(invokeMethodName ?? "InvokeAsync"); - } - - // C#: base.InvokeAsync(request); - var invocationExpression = - InvocationExpression( - baseInvokeExpression, - ArgumentList(SeparatedList(new[] { Argument(requestVar) }))); - var rt = namedMethodReturnType.ConstructedFrom; - bool isAsync; - if (SymbolEqualityComparer.Default.Equals(rt, LibraryTypes.Task_1) || SymbolEqualityComparer.Default.Equals(methodReturnType, LibraryTypes.Task)) - { - // C#: return .AsTask() - statements.Add(ReturnStatement(InvocationExpression(invocationExpression.Member("AsTask"), ArgumentList()))); - isAsync = false; - } - else if (SymbolEqualityComparer.Default.Equals(rt, LibraryTypes.ValueTask_1) || SymbolEqualityComparer.Default.Equals(methodReturnType, LibraryTypes.ValueTask)) + if (method.DeclaredAccessibility == Accessibility.Private) { - // ValueTask / ValueTask - // C#: return - statements.Add(ReturnStatement(invocationExpression)); - isAsync = false; - } - else if (invokable.ReturnValueInitializerMethod is { } returnValueInitializerMethod) - { - // C#: return request.(this); - statements.Add(ReturnStatement(InvocationExpression(requestVar.Member(returnValueInitializerMethod), ArgumentList(SingletonSeparatedList(Argument(ThisExpression())))))); - isAsync = false; - } - else if (isVoid) - { - // C#: - statements.Add(ExpressionStatement(invocationExpression)); - isAsync = false; - } - else if (rt.Arity == 0) - { - // C#: await - statements.Add(ExpressionStatement(AwaitExpression(invocationExpression))); - isAsync = true; - } - else - { - // C#: return await - statements.Add(ReturnStatement(AwaitExpression(invocationExpression))); - isAsync = true; + continue; } - return (isAsync, Block(statements)); + res.Add(CreateConstructor(method)); } + return [.. res]; - private MemberDeclarationSyntax[] GenerateConstructors( - string simpleClassName, - List fieldDescriptions, - INamedTypeSymbol baseType) + ConstructorDeclarationSyntax CreateConstructor(IMethodSymbol baseConstructor) { - if (baseType is null) - { - return Array.Empty(); - } + return ConstructorDeclaration(simpleClassName) + .AddParameterListParameters([.. baseConstructor.Parameters.Select((p, i) => GetParameterSyntax(i, p, typeParameterSubstitutions: null))]) + .WithModifiers(TokenList(GetModifiers(baseConstructor))) + .WithInitializer( + ConstructorInitializer( + SyntaxKind.BaseConstructorInitializer, + ArgumentList( + SeparatedList(baseConstructor.Parameters.Select(GetBaseInitializerArgument))))) + .WithBody(Block(bodyStatements)); + } - var bodyStatements = GetBodyStatements(); - var res = new List(); - foreach (var member in baseType.GetMembers()) + static SyntaxToken[] GetModifiers(IMethodSymbol method) + { + switch (method.DeclaredAccessibility) { - if (member is not IMethodSymbol method) - { - continue; - } - - if (method.MethodKind != MethodKind.Constructor) - { - continue; - } - - if (method.DeclaredAccessibility == Accessibility.Private) - { - continue; - } - - res.Add(CreateConstructor(method)); + case Accessibility.Public: + case Accessibility.Protected: + return [Token(SyntaxKind.PublicKeyword)]; + case Accessibility.Internal: + case Accessibility.ProtectedOrInternal: + case Accessibility.ProtectedAndInternal: + return [Token(SyntaxKind.InternalKeyword)]; + default: + return []; } - return res.ToArray(); + } - ConstructorDeclarationSyntax CreateConstructor(IMethodSymbol baseConstructor) + static ArgumentSyntax GetBaseInitializerArgument(IParameterSymbol parameter, int index) + { + var name = $"arg{index}"; + var result = Argument(IdentifierName(name)); + switch (parameter.RefKind) { - return ConstructorDeclaration(simpleClassName) - .AddParameterListParameters(baseConstructor.Parameters.Select((p, i) => GetParameterSyntax(i, p, typeParameterSubstitutions: null)).ToArray()) - .WithModifiers(TokenList(GetModifiers(baseConstructor))) - .WithInitializer( - ConstructorInitializer( - SyntaxKind.BaseConstructorInitializer, - ArgumentList( - SeparatedList(baseConstructor.Parameters.Select(GetBaseInitializerArgument))))) - .WithBody(Block(bodyStatements)); + case RefKind.None: + break; + case RefKind.Ref: + result = result.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); + break; + case RefKind.Out: + result = result.WithRefOrOutKeyword(Token(SyntaxKind.OutKeyword)); + break; + default: + break; } - static SyntaxToken[] GetModifiers(IMethodSymbol method) - { - switch (method.DeclaredAccessibility) - { - case Accessibility.Public: - case Accessibility.Protected: - return new[] { Token(SyntaxKind.PublicKeyword) }; - case Accessibility.Internal: - case Accessibility.ProtectedOrInternal: - case Accessibility.ProtectedAndInternal: - return new[] { Token(SyntaxKind.InternalKeyword) }; - default: - return Array.Empty(); - } - } + return result; + } - static ArgumentSyntax GetBaseInitializerArgument(IParameterSymbol parameter, int index) + List GetBodyStatements() + { + var res = new List(); + foreach (var field in fieldDescriptions) { - var name = $"arg{index}"; - var result = Argument(IdentifierName(name)); - switch (parameter.RefKind) + switch (field) { - case RefKind.None: - break; - case RefKind.Ref: - result = result.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); - break; - case RefKind.Out: - result = result.WithRefOrOutKeyword(Token(SyntaxKind.OutKeyword)); - break; - default: + case GeneratedFieldDescription _ when field.IsInjected: + res.Add(ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + ThisExpression().Member(field.FieldName.ToIdentifierName()), + Unwrapped(field.FieldName.ToIdentifierName())))); break; - } - - return result; - } - - List GetBodyStatements() - { - var res = new List(); - foreach (var field in fieldDescriptions) - { - switch (field) - { - case GeneratedFieldDescription _ when field.IsInjected: + case CopierFieldDescription codec: + { res.Add(ExpressionStatement( AssignmentExpression( SyntaxKind.SimpleAssignmentExpression, - ThisExpression().Member(field.FieldName.ToIdentifierName()), - Unwrapped(field.FieldName.ToIdentifierName())))); - break; - case CopierFieldDescription codec: - { - res.Add(ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - field.FieldName.ToIdentifierName(), - GetService(field.FieldType)))); - } - break; - } - } - return res; - - static ExpressionSyntax Unwrapped(ExpressionSyntax expr) - { - return InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("OrleansGeneratedCodeHelper"), IdentifierName("UnwrapService")), - ArgumentList(SeparatedList(new[] { Argument(ThisExpression()), Argument(expr) }))); + field.FieldName.ToIdentifierName(), + GetService(field.FieldType)))); + } + break; } + } + return res; - static ExpressionSyntax GetService(TypeSyntax type) - { - return InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("OrleansGeneratedCodeHelper"), GenericName(Identifier("GetService"), TypeArgumentList(SingletonSeparatedList(type)))), - ArgumentList(SeparatedList(new[] { Argument(ThisExpression()), Argument(IdentifierName(CodecProviderMemberName)) }))); - } + static ExpressionSyntax Unwrapped(ExpressionSyntax expr) + { + return InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("OrleansGeneratedCodeHelper"), IdentifierName("UnwrapService")), + ArgumentList(SeparatedList([Argument(ThisExpression()), Argument(expr)]))); } - } - private ParameterSyntax GetParameterSyntax(int index, IParameterSymbol parameter, Dictionary typeParameterSubstitutions) - { - var result = Parameter(Identifier($"arg{index}")).WithType(parameter.Type.ToTypeSyntax(typeParameterSubstitutions)); - switch (parameter.RefKind) + static ExpressionSyntax GetService(TypeSyntax type) { - case RefKind.None: - break; - case RefKind.Ref: - result = result.WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))); - break; - case RefKind.Out: - result = result.WithModifiers(TokenList(Token(SyntaxKind.OutKeyword))); - break; - case RefKind.In: - result = result.WithModifiers(TokenList(Token(SyntaxKind.InKeyword))); - break; - default: - break; + return InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("OrleansGeneratedCodeHelper"), GenericName(Identifier("GetService"), TypeArgumentList(SingletonSeparatedList(type)))), + ArgumentList(SeparatedList([Argument(ThisExpression()), Argument(IdentifierName(CodecProviderMemberName))]))); } + } + } - return result; + private static ParameterSyntax GetParameterSyntax(int index, IParameterSymbol parameter, Dictionary? typeParameterSubstitutions) + { + var result = Parameter(Identifier($"arg{index}")).WithType(parameter.Type.ToTypeSyntax(typeParameterSubstitutions)); + switch (parameter.RefKind) + { + case RefKind.None: + break; + case RefKind.Ref: + result = result.WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))); + break; + case RefKind.Out: + result = result.WithModifiers(TokenList(Token(SyntaxKind.OutKeyword))); + break; + case RefKind.In: + result = result.WithModifiers(TokenList(Token(SyntaxKind.InKeyword))); + break; + default: + break; } + + return result; } } diff --git a/src/Orleans.CodeGenerator/ProxySourceOutputGenerator.cs b/src/Orleans.CodeGenerator/ProxySourceOutputGenerator.cs new file mode 100644 index 00000000000..9517a12e45c --- /dev/null +++ b/src/Orleans.CodeGenerator/ProxySourceOutputGenerator.cs @@ -0,0 +1,430 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Orleans.CodeGenerator.Model; + +namespace Orleans.CodeGenerator; + +internal static class ProxySourceOutputGenerator +{ + internal static SourceOutputResult CreateProxySourceOutput( + Compilation compilation, + TypeSymbolResolver resolver, + ProxyOutputModel proxyOutputModel, + SourceGeneratorOptions options, + CancellationToken cancellationToken) + { + try + { + SourceGeneratorOptionsParser.AttachDebuggerIfRequested(options); + var codeGeneratorOptions = SourceGeneratorOptionsParser.CreateCodeGeneratorOptions(options); + var generatorServices = new GeneratorServices(compilation, codeGeneratorOptions); + var proxyContext = new ProxyGenerationContext(compilation, codeGeneratorOptions); + var model = proxyOutputModel.ProxyInterface; + PopulateProxyInterfaces(proxyContext, resolver, [model], cancellationToken); + + var assemblyName = compilation.AssemblyName ?? "assembly"; + var interfaceDescription = GetProxyInterfaceDescription(proxyContext, resolver, model, cancellationToken); + var proxyGenerator = new ProxyGenerator(generatorServices, new CopierGenerator(generatorServices)); + var (proxyClass, _) = proxyGenerator.Generate(interfaceDescription); + var targetHintName = GeneratedSourceOutput.CreateProxyHintName(assemblyName, interfaceDescription); + var ownedInvokableMetadataNames = new HashSet( + proxyOutputModel.OwnedInvokableMetadataNames, + StringComparer.Ordinal); + var emitDeclaredMethodsFallback = proxyOutputModel.UseDeclaredInvokableFallback; + var generatedInvokables = GetGeneratedInvokables(proxyContext, interfaceDescription).ToImmutableArray(); + var generatedInvokableClassNames = new HashSet( + generatedInvokables.Select(static invokable => invokable.ClassDeclarationSyntax.Identifier.ValueText), + StringComparer.Ordinal); + var additionalInvokableClasses = proxyContext.GetEmittedMembers() + .Where(entry => entry.Member is ClassDeclarationSyntax classDeclaration + && !string.Equals(classDeclaration.Identifier.ValueText, proxyClass.Identifier.ValueText, StringComparison.Ordinal) + && !generatedInvokableClassNames.Contains(classDeclaration.Identifier.ValueText)) + .Select(entry => (entry.Namespace, ClassDeclaration: (ClassDeclarationSyntax)entry.Member)) + .OrderBy(static entry => entry.Namespace, StringComparer.Ordinal) + .ThenBy(static entry => entry.ClassDeclaration.Identifier.ValueText, StringComparer.Ordinal) + .ToImmutableArray(); + + var serializerGenerator = new SerializerGenerator(generatorServices); + var copierGenerator = new CopierGenerator(generatorServices); + var activatorGenerator = new ActivatorGenerator(generatorServices); + var emittedInvokables = generatedInvokables + .Where(invokable => ShouldEmitInvokable( + invokable, + interfaceDescription.InterfaceType, + ownedInvokableMetadataNames, + emitDeclaredMethodsFallback)) + .ToImmutableArray(); + + var namespacedMembers = new Dictionary>(StringComparer.Ordinal); + foreach (var invokable in emittedInvokables) + { + GeneratedSourceOutput.AddMember(namespacedMembers, invokable.GeneratedNamespace, invokable.ClassDeclarationSyntax); + } + + GeneratedSourceOutput.AddMember(namespacedMembers, interfaceDescription.GeneratedNamespace, proxyClass); + + foreach (var invokable in emittedInvokables) + { + GeneratedSourceOutput.AddMember(namespacedMembers, invokable.GeneratedNamespace, serializerGenerator.Generate(invokable)); + + var copier = invokable.IsShallowCopyable && proxyContext.MetadataModel.DefaultCopiers.ContainsKey(invokable) + ? null + : copierGenerator.GenerateCopier(invokable, proxyContext.MetadataModel.DefaultCopiers); + if (copier is not null) + { + GeneratedSourceOutput.AddMember(namespacedMembers, invokable.GeneratedNamespace, copier); + } + + if (ActivatorGenerator.ShouldGenerateActivator(invokable)) + { + GeneratedSourceOutput.AddMember(namespacedMembers, invokable.GeneratedNamespace, activatorGenerator.GenerateActivator(invokable)); + } + } + + foreach (var (generatedNamespace, classDeclaration) in additionalInvokableClasses) + { + GeneratedSourceOutput.AddMember(namespacedMembers, generatedNamespace, classDeclaration); + } + + return SourceOutputResult.FromSource( + new GeneratedSourceEntry(targetHintName, GeneratedSourceOutput.CreateSourceString(GeneratedSourceOutput.CreateCompilationUnit(namespacedMembers)))); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + throw; + } + catch (OrleansGeneratorDiagnosticAnalysisException analysisException) + { + return SourceOutputResult.FromDiagnostic(analysisException.Diagnostic); + } + } + + internal static SourceOutputResult CreateProxySourceOutput( + ProxyGenerationContext proxyContext, + IGeneratorServices generatorServices, + TypeSymbolResolver resolver, + string assemblyName, + ProxyOutputModel proxyOutputModel, + CancellationToken cancellationToken) + { + try + { + cancellationToken.ThrowIfCancellationRequested(); + var model = proxyOutputModel.ProxyInterface; + var interfaceDescription = GetProxyInterfaceDescription(proxyContext, resolver, model, cancellationToken); + var proxyGenerator = new ProxyGenerator(generatorServices, new CopierGenerator(generatorServices)); + var (proxyClass, _) = proxyGenerator.Generate(interfaceDescription); + var targetHintName = GeneratedSourceOutput.CreateProxyHintName(assemblyName, interfaceDescription); + var ownedInvokableMetadataNames = new HashSet( + proxyOutputModel.OwnedInvokableMetadataNames, + StringComparer.Ordinal); + var emitDeclaredMethodsFallback = proxyOutputModel.UseDeclaredInvokableFallback; + var serializerGenerator = new SerializerGenerator(generatorServices); + var copierGenerator = new CopierGenerator(generatorServices); + var activatorGenerator = new ActivatorGenerator(generatorServices); + var defaultCopiers = new Dictionary(); + var generatedInvokables = GetGeneratedInvokables(proxyContext, interfaceDescription).ToImmutableArray(); + var emittedInvokables = generatedInvokables + .Where(invokable => ShouldEmitInvokable( + invokable, + interfaceDescription.InterfaceType, + ownedInvokableMetadataNames, + emitDeclaredMethodsFallback)) + .ToImmutableArray(); + + var namespacedMembers = new Dictionary>(StringComparer.Ordinal); + foreach (var invokable in emittedInvokables) + { + GeneratedSourceOutput.AddMember(namespacedMembers, invokable.GeneratedNamespace, invokable.ClassDeclarationSyntax); + } + + GeneratedSourceOutput.AddMember(namespacedMembers, interfaceDescription.GeneratedNamespace, proxyClass); + + foreach (var invokable in emittedInvokables) + { + GeneratedSourceOutput.AddMember(namespacedMembers, invokable.GeneratedNamespace, serializerGenerator.Generate(invokable)); + + var copier = invokable.IsShallowCopyable && defaultCopiers.ContainsKey(invokable) + ? null + : copierGenerator.GenerateCopier(invokable, defaultCopiers); + if (copier is not null) + { + GeneratedSourceOutput.AddMember(namespacedMembers, invokable.GeneratedNamespace, copier); + } + + if (ActivatorGenerator.ShouldGenerateActivator(invokable)) + { + GeneratedSourceOutput.AddMember(namespacedMembers, invokable.GeneratedNamespace, activatorGenerator.GenerateActivator(invokable)); + } + } + + return SourceOutputResult.FromSource( + new GeneratedSourceEntry(targetHintName, GeneratedSourceOutput.CreateSourceString(GeneratedSourceOutput.CreateCompilationUnit(namespacedMembers)))); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + throw; + } + catch (OrleansGeneratorDiagnosticAnalysisException analysisException) + { + return SourceOutputResult.FromDiagnostic(analysisException.Diagnostic); + } + } + + internal static ProxyOutputPreparationResult CreateProxyOutputPreparation( + Compilation compilation, + ImmutableArray models, + SourceGeneratorOptions options, + CancellationToken cancellationToken) + { + try + { + if (models.IsDefaultOrEmpty) + { + return ProxyOutputPreparationResult.FromModelsAndSources( + [], + []); + } + + var codeGeneratorOptions = SourceGeneratorOptionsParser.CreateCodeGeneratorOptions(options); + var libraryTypes = LibraryTypes.FromCompilation(compilation, codeGeneratorOptions); + var generatorServices = new GeneratorServices(compilation, codeGeneratorOptions, libraryTypes); + var proxyContext = new ProxyGenerationContext(compilation, codeGeneratorOptions, libraryTypes); + var resolver = new TypeSymbolResolver(compilation); + PopulateProxyInterfaces(proxyContext, resolver, models, cancellationToken); + + var proxyOutputModels = CreateProxyOutputModels( + compilation, + proxyContext, + resolver, + models, + cancellationToken); + + return ProxyOutputPreparationResult.FromModelsAndSources( + proxyOutputModels, + CreateProxySourceOutputs( + compilation, + proxyContext, + generatorServices, + resolver, + proxyOutputModels, + options, + cancellationToken)); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + throw; + } + catch (OrleansGeneratorDiagnosticAnalysisException analysisException) + { + return ProxyOutputPreparationResult.FromDiagnostic(analysisException.Diagnostic); + } + } + + internal static ImmutableArray CreateProxySourceOutputs( + Compilation compilation, + ProxyGenerationContext proxyContext, + IGeneratorServices generatorServices, + TypeSymbolResolver resolver, + ImmutableArray proxyOutputModels, + SourceGeneratorOptions options, + CancellationToken cancellationToken) + { + if (proxyOutputModels.IsDefaultOrEmpty) + { + return []; + } + + var sourceOutputs = ImmutableArray.CreateBuilder(proxyOutputModels.Length); + if (options.GenerateCompatibilityInvokers) + { + foreach (var proxyOutputModel in proxyOutputModels) + { + cancellationToken.ThrowIfCancellationRequested(); + sourceOutputs.Add(CreateProxySourceOutput(compilation, resolver, proxyOutputModel, options, cancellationToken)); + } + } + else + { + SourceGeneratorOptionsParser.AttachDebuggerIfRequested(options); + var assemblyName = compilation.AssemblyName ?? "assembly"; + foreach (var proxyOutputModel in proxyOutputModels) + { + cancellationToken.ThrowIfCancellationRequested(); + sourceOutputs.Add(CreateProxySourceOutput( + proxyContext, + generatorServices, + resolver, + assemblyName, + proxyOutputModel, + cancellationToken)); + } + } + + return GeneratedSourceOutput.DeduplicateSourceOutputs(sourceOutputs); + } + + internal static ImmutableArray CreateProxyOutputModels( + Compilation compilation, + ProxyGenerationContext proxyContext, + TypeSymbolResolver resolver, + ImmutableArray models, + CancellationToken cancellationToken) + { + if (models.IsDefaultOrEmpty) + { + return []; + } + + var assemblyName = compilation.AssemblyName ?? "assembly"; + var proxyEntries = proxyContext.MetadataModel.InvokableInterfaces.Values + .Where(desc => SymbolEqualityComparer.Default.Equals(desc.InterfaceType.ContainingAssembly, compilation.Assembly)) + .OrderBy(static desc => desc.InterfaceType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), StringComparer.Ordinal) + .Select(desc => (HintName: GeneratedSourceOutput.CreateProxyHintName(assemblyName, desc), Description: desc)) + .ToImmutableArray(); + + var invokableOwners = new Dictionary(StringComparer.Ordinal); + foreach (var entry in proxyEntries.OrderBy(static entry => entry.HintName, StringComparer.Ordinal)) + { + foreach (var invokable in GetGeneratedInvokables(proxyContext, entry.Description)) + { + if (!invokableOwners.TryGetValue(invokable.MetadataName, out _)) + { + invokableOwners.Add(invokable.MetadataName, entry.HintName); + } + } + } + + return [.. ModelExtractor.DeduplicateProxyInterfaces(models) + .OrderBy(static model => model.SourceLocation.SourceOrderGroup) + .ThenBy(static model => model.SourceLocation.FilePath, StringComparer.Ordinal) + .ThenBy(static model => model.SourceLocation.Position) + .ThenBy(static model => model.InterfaceType.SyntaxString, StringComparer.Ordinal) + .ThenBy(static model => model.MetadataIdentity.MetadataName, StringComparer.Ordinal) + .ThenBy(static model => model.MetadataIdentity.AssemblyIdentity, StringComparer.Ordinal) + .ThenBy(static model => model.MetadataIdentity.AssemblyName, StringComparer.Ordinal) + .ThenBy(static model => model.GeneratedNamespace, StringComparer.Ordinal) + .ThenBy(static model => model.Name, StringComparer.Ordinal) + .Select(model => + { + var interfaceDescription = GetProxyInterfaceDescription(proxyContext, resolver, model, cancellationToken); + var targetHintName = GeneratedSourceOutput.CreateProxyHintName(assemblyName, interfaceDescription); + var generatedInvokables = GetGeneratedInvokables(proxyContext, interfaceDescription) + .ToImmutableArray(); + var ownedInvokableMetadataNames = generatedInvokables + .Select(invokable => invokable.MetadataName) + .Where(metadataName => invokableOwners.TryGetValue(metadataName, out var ownerHintName) + && string.Equals(ownerHintName, targetHintName, StringComparison.Ordinal)) + .Distinct(StringComparer.Ordinal) + .OrderBy(static value => value, StringComparer.Ordinal) + .ToImmutableArray(); + var useDeclaredInvokableFallback = + generatedInvokables.Length == 0 + ? model.Methods.Any(method => method.ContainingInterfaceType.Equals(model.InterfaceType)) + : ownedInvokableMetadataNames.Length == 0 + && !generatedInvokables.Any(invokable => invokableOwners.ContainsKey(invokable.MetadataName)); + var ownedInvokableMetadataNameSet = new HashSet(ownedInvokableMetadataNames, StringComparer.Ordinal); + var ownedInvokableActivatorMetadataNames = generatedInvokables + .Where(invokable => ShouldEmitInvokable( + invokable, + interfaceDescription.InterfaceType, + ownedInvokableMetadataNameSet, + useDeclaredInvokableFallback)) + .Where(static invokable => ActivatorGenerator.ShouldGenerateActivator(invokable)) + .Select(static invokable => invokable.MetadataName) + .Distinct(StringComparer.Ordinal) + .OrderBy(static value => value, StringComparer.Ordinal) + .ToImmutableArray(); + + return new ProxyOutputModel( + model, + ownedInvokableMetadataNames, + ownedInvokableActivatorMetadataNames, + useDeclaredInvokableFallback); + })]; + } + + internal static bool ShouldEmitInvokable( + GeneratedInvokableDescription invokable, + INamedTypeSymbol interfaceType, + HashSet ownedInvokableMetadataNames, + bool useDeclaredInvokableFallback) + => ownedInvokableMetadataNames.Contains(invokable.MetadataName) + || useDeclaredInvokableFallback + && SymbolEqualityComparer.Default.Equals(invokable.MethodDescription.ContainingInterface, interfaceType); + + internal static void PopulateProxyInterfaces( + ProxyGenerationContext proxyContext, + TypeSymbolResolver resolver, + ImmutableArray models, + CancellationToken cancellationToken) + { + var processed = new HashSet(StringComparer.Ordinal); + var resolvedInterfaces = new List<(ProxyInterfaceModel Model, INamedTypeSymbol Symbol, int SourceOrderGroup, string FilePath, int Position)>(); + foreach (var model in models) + { + cancellationToken.ThrowIfCancellationRequested(); + + var modelKey = $"{model.MetadataIdentity.AssemblyIdentity}|{model.MetadataIdentity.AssemblyName}|{model.MetadataIdentity.MetadataName}|{model.InterfaceType.SyntaxString}|{model.GeneratedNamespace}|{model.Name}"; + if (!processed.Add(modelKey)) + { + continue; + } + + if (!resolver.TryResolveProxyInterface(model, cancellationToken, out var interfaceType)) + { + throw new InvalidOperationException($"Unable to resolve proxy interface '{model.InterfaceType.SyntaxString}'."); + } + + var sourceLocation = interfaceType.Locations.FirstOrDefault(static location => location.IsInSource); + resolvedInterfaces.Add(( + model, + interfaceType, + SourceOrderGroup: sourceLocation is null ? 1 : 0, + FilePath: sourceLocation?.SourceTree?.FilePath ?? string.Empty, + Position: sourceLocation?.SourceSpan.Start ?? int.MaxValue)); + } + + foreach (var entry in resolvedInterfaces + .OrderBy(static entry => entry.SourceOrderGroup) + .ThenBy(static entry => entry.FilePath, StringComparer.Ordinal) + .ThenBy(static entry => entry.Position) + .ThenBy(static entry => entry.Model.InterfaceType.SyntaxString, StringComparer.Ordinal)) + { + cancellationToken.ThrowIfCancellationRequested(); + proxyContext.VisitInterface(entry.Symbol.OriginalDefinition); + } + } + + internal static ProxyInterfaceDescription GetProxyInterfaceDescription( + ProxyGenerationContext proxyContext, + TypeSymbolResolver resolver, + ProxyInterfaceModel model, + CancellationToken cancellationToken) + { + if (!resolver.TryResolveProxyInterface(model, cancellationToken, out var interfaceType) + || !proxyContext.TryGetInvokableInterfaceDescription(interfaceType.OriginalDefinition, out var description)) + { + throw new InvalidOperationException($"Unable to resolve proxy interface '{model.InterfaceType.SyntaxString}'."); + } + + return description; + } + + internal static IEnumerable GetGeneratedInvokables( + ProxyGenerationContext proxyContext, + ProxyInterfaceDescription interfaceDescription) + { + return interfaceDescription.Methods + .Select(static method => method.InvokableKey) + .Distinct() + .Select(key => proxyContext.MetadataModel.GeneratedInvokables.TryGetValue(key, out var generatedInvokable) ? generatedInvokable : null) + .OfType() + .Where(generatedInvokable => proxyContext.Compilation.GetTypeByMetadataName(generatedInvokable.MetadataName) is null) + .OrderBy(static generatedInvokable => generatedInvokable.MetadataName, StringComparer.Ordinal); + } +} + + + diff --git a/src/Orleans.CodeGenerator/ReferenceAssemblyDataProvider.cs b/src/Orleans.CodeGenerator/ReferenceAssemblyDataProvider.cs new file mode 100644 index 00000000000..578abcdba75 --- /dev/null +++ b/src/Orleans.CodeGenerator/ReferenceAssemblyDataProvider.cs @@ -0,0 +1,47 @@ +using Microsoft.CodeAnalysis; +using Orleans.CodeGenerator.Model; + +namespace Orleans.CodeGenerator; + +internal static class ReferenceAssemblyDataProvider +{ + internal static ReferenceAssemblyDataResult CreateReferenceAssemblyDataResult( + Compilation compilation, + SourceGeneratorOptions options, + CancellationToken cancellationToken) + { + try + { + var model = ModelExtractor.ExtractReferenceAssemblyData( + compilation, + SourceGeneratorOptionsParser.CreateCodeGeneratorOptions(options), + cancellationToken, + out var diagnostics); + + return ReferenceAssemblyDataResult.FromModelAndDiagnostics(model, diagnostics); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + throw; + } + catch (OrleansGeneratorDiagnosticAnalysisException analysisException) + { + return ReferenceAssemblyDataResult.FromModelAndDiagnostics( + ReferenceAssemblyDataProvider.CreateEmptyReferenceAssemblyModel(compilation.AssemblyName ?? string.Empty), + [analysisException.Diagnostic]); + } + } + + internal static ReferenceAssemblyModel CreateEmptyReferenceAssemblyModel(string assemblyName) + => new( + assemblyName, + EquatableArray.Empty, + EquatableArray.Empty, + EquatableArray.Empty, + EquatableArray.Empty, + EquatableArray.Empty, + EquatableArray.Empty, + EquatableArray.Empty, + EquatableArray.Empty); +} + diff --git a/src/Orleans.CodeGenerator/SerializableSourceOutputGenerator.cs b/src/Orleans.CodeGenerator/SerializableSourceOutputGenerator.cs new file mode 100644 index 00000000000..fcc423ef8bf --- /dev/null +++ b/src/Orleans.CodeGenerator/SerializableSourceOutputGenerator.cs @@ -0,0 +1,473 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Orleans.CodeGenerator.Diagnostics; +using Orleans.CodeGenerator.Model; +using Orleans.CodeGenerator.SyntaxGeneration; + +namespace Orleans.CodeGenerator; + +internal static class SerializableSourceOutputGenerator +{ + internal static SerializableTypeResult CreateSerializableTypeResult( + GeneratorAttributeSyntaxContext context, + SourceGeneratorOptions options, + CancellationToken cancellationToken) + { + if (context.TargetSymbol is not INamedTypeSymbol symbol) + { + return default; + } + + var sourceLocation = ModelExtractor.GetSourceLocation(symbol); + var metadataIdentity = TypeMetadataIdentity.Create(symbol); + var typeSyntax = symbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + try + { + cancellationToken.ThrowIfCancellationRequested(); + SourceGeneratorOptionsParser.AttachDebuggerIfRequested(options); + + var compilation = context.SemanticModel.Compilation; + var codeGeneratorOptions = SourceGeneratorOptionsParser.CreateCodeGeneratorOptions(options); + var libraryTypes = LibraryTypes.FromCompilation(compilation, codeGeneratorOptions); + var typeDescription = CreateSerializableTypeDescription(compilation, libraryTypes, codeGeneratorOptions, symbol); + if (typeDescription is null) + { + return default; + } + + var model = ModelExtractor.ExtractSerializableTypeModel(typeDescription, sourceLocation); + return SerializableTypeResult.FromModel(model); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + throw; + } + catch (OrleansGeneratorDiagnosticAnalysisException analysisException) + { + return SerializableTypeResult.FromDiagnostic( + analysisException.Diagnostic, + metadataIdentity, + sourceLocation, + typeSyntax); + } + } + + internal static ImmutableArray CreateSerializableSourceOutputs( + Compilation compilation, + ImmutableArray models, + SourceGeneratorOptions options, + CancellationToken cancellationToken) + { + if (models.IsDefaultOrEmpty) + { + return []; + } + + SourceGeneratorOptionsParser.AttachDebuggerIfRequested(options); + var codeGeneratorOptions = SourceGeneratorOptionsParser.CreateCodeGeneratorOptions(options); + var generatorServices = new GeneratorServices(compilation, codeGeneratorOptions); + var resolver = new TypeSymbolResolver(compilation); + var assemblyName = compilation.AssemblyName ?? "assembly"; + var sourceEntries = ImmutableArray.CreateBuilder(); + var defaultCopiers = new Dictionary(); + var serializerGenerator = new SerializerGenerator(generatorServices); + var copierGenerator = new CopierGenerator(generatorServices); + var activatorGenerator = new ActivatorGenerator(generatorServices); + + foreach (var model in ModelExtractor.DeduplicateSerializableTypes(models)) + { + cancellationToken.ThrowIfCancellationRequested(); + + try + { + if (!resolver.TryResolveSerializableType(model, cancellationToken, out var symbol) + || !SymbolEqualityComparer.Default.Equals(symbol.ContainingAssembly, compilation.Assembly)) + { + continue; + } + + var typeDescription = CreateSerializableTypeDescription(generatorServices, symbol); + if (typeDescription is null) + { + continue; + } + + sourceEntries.Add(CreateSerializableSourceOutput( + assemblyName, + typeDescription, + serializerGenerator, + copierGenerator, + activatorGenerator, + defaultCopiers, + model.MetadataIdentity, + model.TypeSyntax.SyntaxString, + model.GeneratedNamespace, + model.TypeParameters.Length)); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + throw; + } + catch (OrleansGeneratorDiagnosticAnalysisException analysisException) + { + sourceEntries.Add(SourceOutputResult.FromDiagnostic(analysisException.Diagnostic)); + } + } + + return GeneratedSourceOutput.DeduplicateSourceOutputs(sourceEntries); + } + + internal static ImmutableArray CreateReferencedSerializableSourceOutputs( + Compilation compilation, + ReferenceAssemblyModel referenceData, + SourceGeneratorOptions options, + CancellationToken cancellationToken) + { + try + { + if (referenceData is null || referenceData.ReferencedSerializableTypes.IsDefaultOrEmpty) + { + return []; + } + + var processedModelTypes = new HashSet(StringComparer.Ordinal); + var modelsToResolve = ImmutableArray.CreateBuilder(); + foreach (var model in referenceData.ReferencedSerializableTypes + .Distinct() + .OrderBy(static model => model.TypeSyntax.SyntaxString, StringComparer.Ordinal) + .ThenBy(static model => model.MetadataIdentity.MetadataName, StringComparer.Ordinal) + .ThenBy(static model => model.MetadataIdentity.AssemblyIdentity, StringComparer.Ordinal) + .ThenBy(static model => model.MetadataIdentity.AssemblyName, StringComparer.Ordinal) + .ThenBy(static model => model.GeneratedNamespace, StringComparer.Ordinal) + .ThenBy(static model => model.Name, StringComparer.Ordinal)) + { + cancellationToken.ThrowIfCancellationRequested(); + if (IsCurrentCompilationAssembly(model.MetadataIdentity, compilation)) + { + continue; + } + + var modelTypeKey = $"{model.MetadataIdentity.AssemblyIdentity}|{model.MetadataIdentity.AssemblyName}|{model.MetadataIdentity.MetadataName}|{model.GeneratedNamespace}|{model.Name}|{model.TypeSyntax.SyntaxString}"; + if (processedModelTypes.Add(modelTypeKey)) + { + modelsToResolve.Add(model); + } + } + + if (modelsToResolve.Count == 0) + { + return []; + } + + SourceGeneratorOptionsParser.AttachDebuggerIfRequested(options); + var codeGeneratorOptions = SourceGeneratorOptionsParser.CreateCodeGeneratorOptions(options); + var generatorServices = new GeneratorServices(compilation, codeGeneratorOptions); + var resolver = new TypeSymbolResolver(compilation); + var assemblyName = compilation.AssemblyName ?? "assembly"; + var sourceEntries = ImmutableArray.CreateBuilder(); + var defaultCopiers = new Dictionary(); + var serializerGenerator = new SerializerGenerator(generatorServices); + var copierGenerator = new CopierGenerator(generatorServices); + var activatorGenerator = new ActivatorGenerator(generatorServices); + + foreach (var model in modelsToResolve) + { + cancellationToken.ThrowIfCancellationRequested(); + if (!resolver.TryResolveSerializableType(model, cancellationToken, out var symbol) + || SymbolEqualityComparer.Default.Equals(symbol.ContainingAssembly, compilation.Assembly)) + { + continue; + } + + var typeDescription = CreateSerializableTypeDescription(generatorServices, symbol); + if (typeDescription is null) + { + continue; + } + + sourceEntries.Add(CreateSerializableSourceOutput( + assemblyName, + typeDescription, + serializerGenerator, + copierGenerator, + activatorGenerator, + defaultCopiers, + model.MetadataIdentity, + model.TypeSyntax.SyntaxString, + model.GeneratedNamespace, + model.TypeParameters.Length)); + } + + return GeneratedSourceOutput.DeduplicateSourceOutputs(sourceEntries); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + throw; + } + catch (OrleansGeneratorDiagnosticAnalysisException analysisException) + { + return [SourceOutputResult.FromDiagnostic(analysisException.Diagnostic)]; + } + } + + internal static SourceOutputResult CreateSerializableSourceOutput( + string assemblyName, + ISerializableTypeDescription typeDescription, + SerializerGenerator serializerGenerator, + CopierGenerator copierGenerator, + ActivatorGenerator activatorGenerator, + Dictionary defaultCopiers, + TypeMetadataIdentity metadataIdentity, + string typeSyntax, + string hintGeneratedNamespace, + int genericArity) + { + var serializer = serializerGenerator.Generate(typeDescription); + var copier = typeDescription.IsShallowCopyable && defaultCopiers.ContainsKey(typeDescription) + ? null + : copierGenerator.GenerateCopier(typeDescription, defaultCopiers); + var activatorClass = ActivatorGenerator.ShouldGenerateActivator(typeDescription) + ? activatorGenerator.GenerateActivator(typeDescription) + : null; + + return SourceOutputResult.FromSource( + GeneratedSourceOutput.CreateSerializableSourceEntry( + assemblyName, + typeSyntax, + metadataIdentity, + hintGeneratedNamespace, + genericArity, + serializer, + copier, + activatorClass, + typeDescription.GeneratedNamespace)); + } + + internal static ISerializableTypeDescription? CreateSerializableTypeDescription(IGeneratorServices services, INamedTypeSymbol symbol) + => CreateSerializableTypeDescription(services.Compilation, services.LibraryTypes, services.Options, symbol); + + internal static ISerializableTypeDescription? CreateSerializableTypeDescription(Compilation compilation, LibraryTypes libraryTypes, CodeGeneratorOptions options, INamedTypeSymbol symbol) + { + + if (FSharpUtilities.IsUnionCase(libraryTypes, symbol, out var sumType) && sumType.HasAttribute(libraryTypes.GenerateSerializerAttribute)) + { + if (!compilation.IsSymbolAccessibleWithin(sumType, compilation.Assembly)) + { + throw new OrleansGeneratorDiagnosticAnalysisException(InaccessibleSerializableTypeDiagnostic.CreateDiagnostic(sumType)); + } + + return new FSharpUtilities.FSharpUnionCaseTypeDescription(compilation, symbol, libraryTypes); + } + + if (!symbol.HasAttribute(libraryTypes.GenerateSerializerAttribute)) + { + return null; + } + + if (HasReferenceAssemblyAttribute(symbol.ContainingAssembly)) + { + throw new OrleansGeneratorDiagnosticAnalysisException(ReferenceAssemblyWithGenerateSerializerDiagnostic.CreateDiagnostic(symbol)); + } + + if (!compilation.IsSymbolAccessibleWithin(symbol, compilation.Assembly)) + { + throw new OrleansGeneratorDiagnosticAnalysisException(InaccessibleSerializableTypeDiagnostic.CreateDiagnostic(symbol)); + } + + if (FSharpUtilities.IsRecord(libraryTypes, symbol)) + { + return new FSharpUtilities.FSharpRecordTypeDescription(compilation, symbol, libraryTypes); + } + + var includePrimaryConstructorParameters = ShouldIncludePrimaryConstructorParameters(symbol, libraryTypes); + var constructorParameters = ResolveConstructorParameters(symbol, includePrimaryConstructorParameters, libraryTypes); + var implicitMemberSelectionStrategy = (options.GenerateFieldIds, GetGenerateFieldIdsOptionFromType(symbol, libraryTypes)) switch + { + (_, GenerateFieldIds.PublicProperties) => GenerateFieldIds.PublicProperties, + (GenerateFieldIds.PublicProperties, _) => GenerateFieldIds.PublicProperties, + _ => GenerateFieldIds.None, + }; + var fieldIdAssignmentHelper = new FieldIdAssignmentHelper(symbol, constructorParameters, implicitMemberSelectionStrategy, libraryTypes); + if (!fieldIdAssignmentHelper.IsValidForSerialization) + { + throw new OrleansGeneratorDiagnosticAnalysisException(CanNotGenerateImplicitFieldIdsDiagnostic.CreateDiagnostic(symbol, fieldIdAssignmentHelper.FailureReason!)); + } + + return new SerializableTypeDescription(compilation, symbol, includePrimaryConstructorParameters, GetDataMembers(fieldIdAssignmentHelper), libraryTypes); + } + + internal static bool HasReferenceAssemblyAttribute(IAssemblySymbol assembly) + { + return assembly.GetAttributes().Any(attributeData => attributeData.AttributeClass is + { + Name: "ReferenceAssemblyAttribute", + ContainingNamespace: + { + Name: "CompilerServices", + ContainingNamespace: + { + Name: "Runtime", + ContainingNamespace: + { + Name: "System", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + }); + } + + internal static GenerateFieldIds GetGenerateFieldIdsOptionFromType(INamedTypeSymbol type, LibraryTypes libraryTypes) + { + var attribute = type.GetAttribute(libraryTypes.GenerateSerializerAttribute); + if (attribute is null) + { + return GenerateFieldIds.None; + } + + foreach (var namedArgument in attribute.NamedArguments) + { + if (namedArgument.Key == "GenerateFieldIds") + { + var value = namedArgument.Value.Value; + return value is null ? GenerateFieldIds.None : (GenerateFieldIds)(int)value; + } + } + + return GenerateFieldIds.None; + } + + internal static bool ShouldIncludePrimaryConstructorParameters(INamedTypeSymbol type, LibraryTypes libraryTypes) + { + static bool? GetNamedOption(INamedTypeSymbol type, INamedTypeSymbol attributeType) + { + var attribute = type.GetAttribute(attributeType); + if (attribute is null) + { + return null; + } + + foreach (var namedArgument in attribute.NamedArguments) + { + if (namedArgument.Key == "IncludePrimaryConstructorParameters" + && namedArgument.Value.Kind == TypedConstantKind.Primitive + && namedArgument.Value.Value is bool value) + { + return value; + } + } + + return null; + } + + if (GetNamedOption(type, libraryTypes.GenerateSerializerAttribute) is bool includePrimaryCtorParameters) + { + return includePrimaryCtorParameters; + } + + if (type.IsRecord) + { + return true; + } + + var properties = type.GetMembers().OfType().ToImmutableArray(); + return type.GetMembers() + .OfType() + .Where(static method => method.MethodKind == MethodKind.Constructor && method.Parameters.Length > 0) + .Any(ctor => ctor.Parameters.All(parameter => + properties.Any(property => property.Name.Equals(parameter.Name, StringComparison.Ordinal) && property.IsCompilerGenerated()))); + } + + internal static ImmutableArray ResolveConstructorParameters( + INamedTypeSymbol type, + bool includePrimaryConstructorParameters, + LibraryTypes libraryTypes) + { + if (!includePrimaryConstructorParameters) + { + return []; + } + + if (type.IsRecord) + { + var potentialPrimaryConstructor = type.Constructors[0]; + if (!potentialPrimaryConstructor.IsImplicitlyDeclared && !potentialPrimaryConstructor.IsCompilerGenerated()) + { + return potentialPrimaryConstructor.Parameters; + } + } + else + { + var annotatedConstructors = type.Constructors.Where(ctor => ctor.HasAnyAttribute(libraryTypes.ConstructorAttributeTypes)).ToList(); + if (annotatedConstructors.Count == 1) + { + return annotatedConstructors[0].Parameters; + } + + var properties = type.GetMembers().OfType().ToImmutableArray(); + var primaryConstructor = type.GetMembers() + .OfType() + .Where(static method => method.MethodKind == MethodKind.Constructor && method.Parameters.Length > 0) + .FirstOrDefault(ctor => ctor.Parameters.All(parameter => + properties.Any(property => property.Name.Equals(parameter.Name, StringComparison.Ordinal) && property.IsCompilerGenerated()))); + if (primaryConstructor is not null) + { + return primaryConstructor.Parameters; + } + } + + return []; + } + + internal static IEnumerable GetDataMembers(FieldIdAssignmentHelper fieldIdAssignmentHelper) + { + var members = new Dictionary<(uint Id, bool IsConstructorParameter), IMemberDescription>(); + foreach (var member in fieldIdAssignmentHelper.Members) + { + if (!fieldIdAssignmentHelper.TryGetSymbolKey(member, out var key)) + { + continue; + } + + var (id, isConstructorParameter) = key; + if (member is IPropertySymbol property + && !members.TryGetValue((id, isConstructorParameter), out _)) + { + members[(id, isConstructorParameter)] = new PropertyDescription(id, isConstructorParameter, property); + } + + if (member is IFieldSymbol field) + { + if (!members.TryGetValue((id, isConstructorParameter), out var existing) + || existing is PropertyDescription) + { + members[(id, isConstructorParameter)] = new FieldDescription(id, isConstructorParameter, field); + } + } + } + + return members.Values; + } + + internal static bool IsCurrentCompilationAssembly(TypeMetadataIdentity metadataIdentity, Compilation compilation) + { + if (metadataIdentity.IsEmpty) + { + return false; + } + + var assemblyIdentity = compilation.Assembly.Identity; + if (!string.IsNullOrEmpty(metadataIdentity.AssemblyIdentity)) + { + return string.Equals(metadataIdentity.AssemblyIdentity, assemblyIdentity.GetDisplayName(), StringComparison.Ordinal); + } + + return !string.IsNullOrEmpty(metadataIdentity.AssemblyName) + && string.Equals(metadataIdentity.AssemblyName, assemblyIdentity.Name, StringComparison.Ordinal); + } +} + + + diff --git a/src/Orleans.CodeGenerator/SerializerGenerator.cs b/src/Orleans.CodeGenerator/SerializerGenerator.cs index c63c2c05ee7..600008dc163 100644 --- a/src/Orleans.CodeGenerator/SerializerGenerator.cs +++ b/src/Orleans.CodeGenerator/SerializerGenerator.cs @@ -1,89 +1,82 @@ -using System.Collections.Generic; -using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using System.Diagnostics; using Orleans.CodeGenerator.Diagnostics; using Orleans.CodeGenerator.SyntaxGeneration; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; using static Orleans.CodeGenerator.InvokableGenerator; -#nullable disable -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal class SerializerGenerator(IGeneratorServices generatorServices) { - internal class SerializerGenerator + private const string BaseTypeSerializerFieldName = "_baseTypeSerializer"; + private const string ActivatorFieldName = "_activator"; + private const string SerializeMethodName = "Serialize"; + private const string DeserializeMethodName = "Deserialize"; + private const string WriteFieldMethodName = "WriteField"; + private const string ReadValueMethodName = "ReadValue"; + private const string CodecFieldTypeFieldName = "_codecFieldType"; + private readonly IGeneratorServices _generatorServices = generatorServices; + + private LibraryTypes LibraryTypes => _generatorServices.LibraryTypes; + + public ClassDeclarationSyntax Generate(ISerializableTypeDescription type) { - private const string BaseTypeSerializerFieldName = "_baseTypeSerializer"; - private const string ActivatorFieldName = "_activator"; - private const string SerializeMethodName = "Serialize"; - private const string DeserializeMethodName = "Deserialize"; - private const string WriteFieldMethodName = "WriteField"; - private const string ReadValueMethodName = "ReadValue"; - private const string CodecFieldTypeFieldName = "_codecFieldType"; - private readonly CodeGenerator _codeGenerator; - - public SerializerGenerator(CodeGenerator codeGenerator) - { - _codeGenerator = codeGenerator; - } - - private LibraryTypes LibraryTypes => _codeGenerator.LibraryTypes; + var simpleClassName = GetSimpleClassName(type); - public ClassDeclarationSyntax Generate(ISerializableTypeDescription type) + var members = new List(); + foreach (var member in type.Members) { - var simpleClassName = GetSimpleClassName(type); - - var members = new List(); - foreach (var member in type.Members) + if (!member.IsSerializable) { - if (!member.IsSerializable) - { - continue; - } + continue; + } - if (member is ISerializableMember serializable) - { - members.Add(serializable); - } - else if (member is IFieldDescription or IPropertyDescription) - { - members.Add(new SerializableMember(_codeGenerator, member, members.Count)); - } - else if (member is MethodParameterFieldDescription methodParameter) - { - members.Add(new SerializableMethodMember(methodParameter)); - } + if (member is ISerializableMember serializable) + { + members.Add(serializable); + } + else if (member is IFieldDescription or IPropertyDescription) + { + members.Add(new SerializableMember(_generatorServices, member, members.Count)); + } + else if (member is MethodParameterFieldDescription methodParameter) + { + members.Add(new SerializableMethodMember(methodParameter)); } + } - var fieldDescriptions = GetFieldDescriptions(type, members); - var fieldDeclarations = GetFieldDeclarations(fieldDescriptions); - var ctor = GenerateConstructor(simpleClassName, fieldDescriptions); + var fieldDescriptions = GetFieldDescriptions(type, members); + var fieldDeclarations = GetFieldDeclarations(fieldDescriptions); + var ctor = GenerateConstructor(simpleClassName, fieldDescriptions); - var accessibility = type.Accessibility switch - { - Accessibility.Public => SyntaxKind.PublicKeyword, - _ => SyntaxKind.InternalKeyword, - }; + var accessibility = type.Accessibility switch + { + Accessibility.Public => SyntaxKind.PublicKeyword, + _ => SyntaxKind.InternalKeyword, + }; - var baseType = (type.IsAbstractType ? LibraryTypes.AbstractTypeSerializer : LibraryTypes.FieldCodec_1).ToTypeSyntax(type.TypeSyntax); + var baseType = (type.IsAbstractType ? LibraryTypes.AbstractTypeSerializer : LibraryTypes.FieldCodec_1).ToTypeSyntax(type.TypeSyntax); - var classDeclaration = ClassDeclaration(simpleClassName) - .AddBaseListTypes(SimpleBaseType(baseType)) - .AddModifiers(Token(accessibility), Token(SyntaxKind.SealedKeyword)) - .AddAttributeLists(CodeGenerator.GetGeneratedCodeAttributes()) - .AddMembers(fieldDeclarations); + var classDeclaration = ClassDeclaration(simpleClassName) + .AddBaseListTypes(SimpleBaseType(baseType)) + .AddModifiers(Token(accessibility), Token(SyntaxKind.SealedKeyword)) + .AddAttributeLists(GeneratedCodeUtilities.GetGeneratedCodeAttributes()) + .AddMembers(fieldDeclarations); - if (ctor != null) - classDeclaration = classDeclaration.AddMembers(ctor); + if (ctor != null) + classDeclaration = classDeclaration.AddMembers(ctor); - if (type.IsEnumType) - { - var writeMethod = GenerateEnumWriteMethod(type); - var readMethod = GenerateEnumReadMethod(type); - classDeclaration = classDeclaration.AddMembers(writeMethod, readMethod); - } - else - { + if (type.IsEnumType) + { + var writeMethod = GenerateEnumWriteMethod(type); + var readMethod = GenerateEnumReadMethod(type); + classDeclaration = classDeclaration.AddMembers(writeMethod, readMethod); + } + else + { var serializeMethod = GenerateSerializeMethod(type, fieldDescriptions, members); var deserializeMethod = GenerateDeserializeMethod(type, fieldDescriptions, members); if (type.IsAbstractType) @@ -93,1346 +86,1290 @@ public ClassDeclarationSyntax Generate(ISerializableTypeDescription type) } else { + Debug.Assert(serializeMethod is not null); + Debug.Assert(deserializeMethod is not null); var writeFieldMethod = GenerateCompoundTypeWriteFieldMethod(type); var readValueMethod = GenerateCompoundTypeReadValueMethod(type, fieldDescriptions); - classDeclaration = classDeclaration.AddMembers(serializeMethod, deserializeMethod, writeFieldMethod, readValueMethod); - - var serializerInterface = type.IsValueType ? LibraryTypes.ValueSerializer : type.IsSealedType ? null : LibraryTypes.BaseCodec_1; - if (serializerInterface != null) - classDeclaration = classDeclaration.AddBaseListTypes(SimpleBaseType(serializerInterface.ToTypeSyntax(type.TypeSyntax))); - } - } + classDeclaration = classDeclaration.AddMembers(serializeMethod!, deserializeMethod!, writeFieldMethod, readValueMethod); - if (type.IsGenericType) - { - classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, type.TypeParameters); + var serializerInterface = type.IsValueType ? LibraryTypes.ValueSerializer : type.IsSealedType ? null : LibraryTypes.BaseCodec_1; + if (serializerInterface != null) + classDeclaration = classDeclaration.AddBaseListTypes(SimpleBaseType(serializerInterface.ToTypeSyntax(type.TypeSyntax))); } + } - return classDeclaration; + if (type.IsGenericType) + { + classDeclaration = SyntaxFactoryUtility.AddGenericTypeParameters(classDeclaration, type.TypeParameters); } - public static string GetSimpleClassName(ISerializableTypeDescription serializableType) => GetSimpleClassName(serializableType.Name); + return classDeclaration; + } - public static string GetSimpleClassName(string name) => $"Codec_{name}"; + public static string GetSimpleClassName(ISerializableTypeDescription serializableType) => GetSimpleClassName(serializableType.Name); - public static string GetGeneratedNamespaceName(ITypeSymbol type) => type.GetNamespaceAndNesting() switch - { - { Length: > 0 } ns => $"{CodeGenerator.CodeGeneratorName}.{ns}", - _ => CodeGenerator.CodeGeneratorName - }; + public static string GetSimpleClassName(string name) => $"Codec_{name}"; - private MemberDeclarationSyntax[] GetFieldDeclarations(List fieldDescriptions) - { - return fieldDescriptions.Select(GetFieldDeclaration).ToArray(); + public static string GetGeneratedNamespaceName(ITypeSymbol type) => type.GetNamespaceAndNesting() switch + { + { Length: > 0 } ns => $"{GeneratedCodeUtilities.CodeGeneratorName}.{ns}", + _ => GeneratedCodeUtilities.CodeGeneratorName + }; + + private static MemberDeclarationSyntax[] GetFieldDeclarations(List fieldDescriptions) + { + return [.. fieldDescriptions.Select(GetFieldDeclaration)]; - static MemberDeclarationSyntax GetFieldDeclaration(GeneratedFieldDescription description) + static MemberDeclarationSyntax GetFieldDeclaration(GeneratedFieldDescription description) + { + switch (description) { - switch (description) - { - case TypeFieldDescription type: - return FieldDeclaration( - VariableDeclaration( - type.FieldType, - SingletonSeparatedList(VariableDeclarator(type.FieldName) - .WithInitializer(EqualsValueClause(TypeOfExpression(type.UnderlyingTypeSyntax)))))) - .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword)); - case CodecFieldTypeFieldDescription type: - return FieldDeclaration( - VariableDeclaration( - type.FieldType, - SingletonSeparatedList(VariableDeclarator(type.FieldName) - .WithInitializer(EqualsValueClause(TypeOfExpression(type.CodecFieldType)))))) - .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword)); - case FieldAccessorDescription accessor when accessor.InitializationSyntax != null: - return - FieldDeclaration(VariableDeclaration(accessor.FieldType, - SingletonSeparatedList(VariableDeclarator(accessor.FieldName).WithInitializer(EqualsValueClause(accessor.InitializationSyntax))))) - .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword)); - case FieldAccessorDescription accessor when accessor.InitializationSyntax == null: - //[UnsafeAccessor(UnsafeAccessorKind.Method, Name = "set_Amount")] - //extern static void SetAmount(External instance, int value); - return - MethodDeclaration( - PredefinedType(Token(SyntaxKind.VoidKeyword)), - accessor.AccessorName) - .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ExternKeyword), Token(SyntaxKind.StaticKeyword)) - .AddAttributeLists(AttributeList(SingletonSeparatedList( - Attribute(IdentifierName("System.Runtime.CompilerServices.UnsafeAccessor")) - .AddArgumentListArguments( - AttributeArgument( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("System.Runtime.CompilerServices.UnsafeAccessorKind"), - IdentifierName("Method"))), - AttributeArgument( - LiteralExpression( - SyntaxKind.StringLiteralExpression, - Literal($"set_{accessor.FieldName}"))) - .WithNameEquals(NameEquals("Name")))))) - .WithParameterList( - ParameterList(SeparatedList(new[] - { - Parameter(Identifier("instance")).WithType(accessor.ContainingType), - Parameter(Identifier("value")).WithType(description.FieldType) - }))) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); - default: - return FieldDeclaration(VariableDeclaration(description.FieldType, SingletonSeparatedList(VariableDeclarator(description.FieldName)))) - .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword)); - } + case TypeFieldDescription type: + return FieldDeclaration( + VariableDeclaration( + type.FieldType, + SingletonSeparatedList(VariableDeclarator(type.FieldName) + .WithInitializer(EqualsValueClause(TypeOfExpression(type.UnderlyingTypeSyntax)))))) + .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword)); + case CodecFieldTypeFieldDescription type: + return FieldDeclaration( + VariableDeclaration( + type.FieldType, + SingletonSeparatedList(VariableDeclarator(type.FieldName) + .WithInitializer(EqualsValueClause(TypeOfExpression(type.CodecFieldType)))))) + .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword)); + case FieldAccessorDescription accessor when accessor.InitializationSyntax != null: + return + FieldDeclaration(VariableDeclaration(accessor.FieldType, + SingletonSeparatedList(VariableDeclarator(accessor.FieldName).WithInitializer(EqualsValueClause(accessor.InitializationSyntax))))) + .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword)); + case FieldAccessorDescription accessor when accessor.InitializationSyntax == null: + //[UnsafeAccessor(UnsafeAccessorKind.Method, Name = "set_Amount")] + //extern static void SetAmount(External instance, int value); + return + MethodDeclaration( + PredefinedType(Token(SyntaxKind.VoidKeyword)), + accessor.AccessorName) + .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ExternKeyword), Token(SyntaxKind.StaticKeyword)) + .AddAttributeLists(AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("System.Runtime.CompilerServices.UnsafeAccessor")) + .AddArgumentListArguments( + AttributeArgument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("System.Runtime.CompilerServices.UnsafeAccessorKind"), + IdentifierName("Method"))), + AttributeArgument( + LiteralExpression( + SyntaxKind.StringLiteralExpression, + Literal($"set_{accessor.FieldName}"))) + .WithNameEquals(NameEquals("Name")))))) + .WithParameterList( + ParameterList(SeparatedList( + [ + Parameter(Identifier("instance")).WithType(accessor.ContainingType), + Parameter(Identifier("value")).WithType(description.FieldType) + ]))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + default: + return FieldDeclaration(VariableDeclaration(description.FieldType, SingletonSeparatedList(VariableDeclarator(description.FieldName)))) + .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword)); } } + } - private ConstructorDeclarationSyntax GenerateConstructor(string simpleClassName, List fieldDescriptions) + private ConstructorDeclarationSyntax? GenerateConstructor(string simpleClassName, List fieldDescriptions) + { + var codecProviderAdded = false; + var parameters = new List(); + var statements = new List(); + foreach (var field in fieldDescriptions) { - var codecProviderAdded = false; - var parameters = new List(); - var statements = new List(); - foreach (var field in fieldDescriptions) + switch (field) { - switch (field) - { - case GeneratedFieldDescription _ when field.IsInjected: - parameters.Add(Parameter(field.FieldName.ToIdentifier()).WithType(field.FieldType)); - statements.Add(ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - ThisExpression().Member(field.FieldName.ToIdentifierName()), - Unwrapped(field.FieldName.ToIdentifierName())))); - break; - case CodecFieldDescription or BaseCodecFieldDescription when !field.IsInjected: - if (!codecProviderAdded) - { - parameters.Add(Parameter(Identifier("codecProvider")).WithType(LibraryTypes.ICodecProvider.ToTypeSyntax())); - codecProviderAdded = true; - } - - var codec = InvocationExpression( - IdentifierName("OrleansGeneratedCodeHelper").Member(GenericName(Identifier("GetService"), TypeArgumentList(SingletonSeparatedList(field.FieldType)))), - ArgumentList(SeparatedList(new[] { Argument(ThisExpression()), Argument(IdentifierName("codecProvider")) }))); - - statements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, field.FieldName.ToIdentifierName(), codec))); - break; - } - } + case GeneratedFieldDescription _ when field.IsInjected: + parameters.Add(Parameter(field.FieldName.ToIdentifier()).WithType(field.FieldType)); + statements.Add(ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + ThisExpression().Member(field.FieldName.ToIdentifierName()), + Unwrapped(field.FieldName.ToIdentifierName())))); + break; + case CodecFieldDescription or BaseCodecFieldDescription when !field.IsInjected: + if (!codecProviderAdded) + { + parameters.Add(Parameter(Identifier("codecProvider")).WithType(LibraryTypes.ICodecProvider.ToTypeSyntax())); + codecProviderAdded = true; + } - return statements.Count == 0 ? null : ConstructorDeclaration(simpleClassName) - .AddModifiers(Token(SyntaxKind.PublicKeyword)) - .AddParameterListParameters(parameters.ToArray()) - .AddBodyStatements(statements.ToArray()); + var codec = InvocationExpression( + IdentifierName("OrleansGeneratedCodeHelper").Member(GenericName(Identifier("GetService"), TypeArgumentList(SingletonSeparatedList(field.FieldType)))), + ArgumentList(SeparatedList([Argument(ThisExpression()), Argument(IdentifierName("codecProvider"))]))); - static ExpressionSyntax Unwrapped(ExpressionSyntax expr) - { - return InvocationExpression( - IdentifierName("OrleansGeneratedCodeHelper").Member("UnwrapService"), - ArgumentList(SeparatedList(new[] { Argument(ThisExpression()), Argument(expr) }))); + statements.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, field.FieldName.ToIdentifierName(), codec))); + break; } } - private List GetFieldDescriptions( - ISerializableTypeDescription serializableTypeDescription, - List members) + return statements.Count == 0 ? null : ConstructorDeclaration(simpleClassName) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddParameterListParameters([.. parameters]) + .AddBodyStatements([.. statements]); + + static ExpressionSyntax Unwrapped(ExpressionSyntax expr) { - var fields = new List(); + return InvocationExpression( + IdentifierName("OrleansGeneratedCodeHelper").Member("UnwrapService"), + ArgumentList(SeparatedList([Argument(ThisExpression()), Argument(expr)]))); + } + } - if (!serializableTypeDescription.IsAbstractType) - { - fields.Add(new CodecFieldTypeFieldDescription(LibraryTypes.Type.ToTypeSyntax(), CodecFieldTypeFieldName, serializableTypeDescription.TypeSyntax)); - } + private List GetFieldDescriptions( + ISerializableTypeDescription serializableTypeDescription, + List members) + { + var fields = new List(); + + if (!serializableTypeDescription.IsAbstractType) + { + fields.Add(new CodecFieldTypeFieldDescription(LibraryTypes.Type.ToTypeSyntax(), CodecFieldTypeFieldName, serializableTypeDescription.TypeSyntax)); + } - if (serializableTypeDescription.HasComplexBaseType) + if (serializableTypeDescription.HasComplexBaseType) + { + fields.Add(GetBaseTypeField(serializableTypeDescription)); + } + + if (serializableTypeDescription.UseActivator && !serializableTypeDescription.IsAbstractType) + { + fields.Add(new ActivatorFieldDescription(LibraryTypes.IActivator_1.ToTypeSyntax(serializableTypeDescription.TypeSyntax), ActivatorFieldName)); + } + + int typeIndex = 0; + foreach (var member in serializableTypeDescription.Members.Distinct(MemberDescriptionTypeComparer.Default)) + { + if (!member.IsSerializable) { - fields.Add(GetBaseTypeField(serializableTypeDescription)); + continue; } - if (serializableTypeDescription.UseActivator && !serializableTypeDescription.IsAbstractType) + // Add a codec field for any field in the target which does not have a static codec. + if (LibraryTypes.StaticCodecs.FindByUnderlyingType(member.Type) is not null) + continue; + + fields.Add(new TypeFieldDescription(LibraryTypes.Type.ToTypeSyntax(), $"_type{typeIndex}", member.TypeSyntax, member.Type)); + fields.Add(GetCodecDescription(member, typeIndex)); + typeIndex++; + } + + foreach (var member in members) + { + if (member.GetGetterFieldDescription() is { } getterFieldDescription) { - fields.Add(new ActivatorFieldDescription(LibraryTypes.IActivator_1.ToTypeSyntax(serializableTypeDescription.TypeSyntax), ActivatorFieldName)); + fields.Add(getterFieldDescription); } - int typeIndex = 0; - foreach (var member in serializableTypeDescription.Members.Distinct(MemberDescriptionTypeComparer.Default)) + if (member.GetSetterFieldDescription() is { } setterFieldDescription) { - if (!member.IsSerializable) - { - continue; - } + fields.Add(setterFieldDescription); + } + } - // Add a codec field for any field in the target which does not have a static codec. - if (LibraryTypes.StaticCodecs.FindByUnderlyingType(member.Type) is not null) - continue; + for (var hookIndex = 0; hookIndex < serializableTypeDescription.SerializationHooks.Count; ++hookIndex) + { + var hookType = serializableTypeDescription.SerializationHooks[hookIndex]; + fields.Add(new SerializationHookFieldDescription(hookType.ToTypeSyntax(), $"_hook{hookIndex}")); + } - fields.Add(new TypeFieldDescription(LibraryTypes.Type.ToTypeSyntax(), $"_type{typeIndex}", member.TypeSyntax, member.Type)); - fields.Add(GetCodecDescription(member, typeIndex)); - typeIndex++; - } + return fields; - foreach (var member in members) + CodecFieldDescription GetCodecDescription(IMemberDescription member, int index) + { + var t = member.Type; + TypeSyntax? codecType = null; + if (t.HasAttribute(LibraryTypes.GenerateSerializerAttribute) + && (SymbolEqualityComparer.Default.Equals(t.ContainingAssembly, LibraryTypes.Compilation.Assembly) || t.ContainingAssembly.HasAttribute(LibraryTypes.TypeManifestProviderAttribute)) + && t is not INamedTypeSymbol { IsGenericType: true, TypeArguments.Length: 0 }) { - if (member.GetGetterFieldDescription() is { } getterFieldDescription) + // Use the concrete generated type and avoid expensive interface dispatch (except for complex nested cases that will fall back to IFieldCodec) + SimpleNameSyntax name; + if (t is INamedTypeSymbol namedTypeSymbol && namedTypeSymbol.IsGenericType) { - fields.Add(getterFieldDescription); + // Construct the full generic type name + name = GenericName(Identifier(GetSimpleClassName(t.Name)), TypeArgumentList(SeparatedList(namedTypeSymbol.TypeArguments.Select(arg => member.GetTypeSyntax(arg))))); } - - if (member.GetSetterFieldDescription() is { } setterFieldDescription) + else { - fields.Add(setterFieldDescription); + name = IdentifierName(GetSimpleClassName(t.Name)); } + codecType = QualifiedName(ParseName(GetGeneratedNamespaceName(t)), name); } - - for (var hookIndex = 0; hookIndex < serializableTypeDescription.SerializationHooks.Count; ++hookIndex) + else if (t is IArrayTypeSymbol { IsSZArray: true } array) + { + codecType = LibraryTypes.ArrayCodec.Construct(array.ElementType).ToTypeSyntax(); + } + else if (LibraryTypes.WellKnownCodecs.FindByUnderlyingType(t) is { } codec) + { + // The codec is not a static codec and is also not a generic codec. + codecType = codec.CodecType.ToTypeSyntax(); + } + else if (t is INamedTypeSymbol { ConstructedFrom: { } unboundFieldType } named && LibraryTypes.WellKnownCodecs.FindByUnderlyingType(unboundFieldType) is { } genericCodec) + { + // Construct the generic codec type using the field's type arguments. + codecType = genericCodec.CodecType.Construct([.. named.TypeArguments]).ToTypeSyntax(); + } + else { - var hookType = serializableTypeDescription.SerializationHooks[hookIndex]; - fields.Add(new SerializationHookFieldDescription(hookType.ToTypeSyntax(), $"_hook{hookIndex}")); + // Use the IFieldCodec interface + codecType = LibraryTypes.FieldCodec_1.ToTypeSyntax(member.TypeSyntax); } - return fields; + return new CodecFieldDescription(codecType, $"_codec{index}", t); + } + } + + private BaseCodecFieldDescription GetBaseTypeField(ISerializableTypeDescription serializableTypeDescription) + { + var baseType = serializableTypeDescription.BaseType; + if (baseType.HasAttribute(LibraryTypes.GenerateSerializerAttribute) + && (SymbolEqualityComparer.Default.Equals(baseType.ContainingAssembly, LibraryTypes.Compilation.Assembly) || baseType.ContainingAssembly.HasAttribute(LibraryTypes.TypeManifestProviderAttribute)) + && baseType is not INamedTypeSymbol { IsGenericType: true }) + { + // Use the concrete generated type and avoid expensive interface dispatch (except for generic types that will fall back to IBaseCodec) + return new(QualifiedName(ParseName(GetGeneratedNamespaceName(baseType)), IdentifierName(GetSimpleClassName(baseType.Name))), true); + } - CodecFieldDescription GetCodecDescription(IMemberDescription member, int index) - { - var t = member.Type; - TypeSyntax codecType = null; - if (t.HasAttribute(LibraryTypes.GenerateSerializerAttribute) - && (SymbolEqualityComparer.Default.Equals(t.ContainingAssembly, LibraryTypes.Compilation.Assembly) || t.ContainingAssembly.HasAttribute(LibraryTypes.TypeManifestProviderAttribute)) - && t is not INamedTypeSymbol { IsGenericType: true, TypeArguments.Length: 0 }) - { - // Use the concrete generated type and avoid expensive interface dispatch (except for complex nested cases that will fall back to IFieldCodec) - SimpleNameSyntax name; - if (t is INamedTypeSymbol namedTypeSymbol && namedTypeSymbol.IsGenericType) - { - // Construct the full generic type name - name = GenericName(Identifier(GetSimpleClassName(t.Name)), TypeArgumentList(SeparatedList(namedTypeSymbol.TypeArguments.Select(arg => member.GetTypeSyntax(arg))))); - } - else - { - name = IdentifierName(GetSimpleClassName(t.Name)); - } - codecType = QualifiedName(ParseName(GetGeneratedNamespaceName(t)), name); - } - else if (t is IArrayTypeSymbol { IsSZArray: true } array) - { - codecType = LibraryTypes.ArrayCodec.Construct(array.ElementType).ToTypeSyntax(); - } - else if (LibraryTypes.WellKnownCodecs.FindByUnderlyingType(t) is { } codec) - { - // The codec is not a static codec and is also not a generic codec. - codecType = codec.CodecType.ToTypeSyntax(); - } - else if (t is INamedTypeSymbol { ConstructedFrom: { } unboundFieldType } named && LibraryTypes.WellKnownCodecs.FindByUnderlyingType(unboundFieldType) is { } genericCodec) - { - // Construct the generic codec type using the field's type arguments. - codecType = genericCodec.CodecType.Construct(named.TypeArguments.ToArray()).ToTypeSyntax(); - } - else - { - // Use the IFieldCodec interface - codecType = LibraryTypes.FieldCodec_1.ToTypeSyntax(member.TypeSyntax); - } + return new(LibraryTypes.BaseCodec_1.ToTypeSyntax(serializableTypeDescription.BaseTypeSyntax)); + } - return new CodecFieldDescription(codecType, $"_codec{index}", t); - } + private MemberDeclarationSyntax? GenerateSerializeMethod( + ISerializableTypeDescription type, + List serializerFields, + List members) + { + var returnType = PredefinedType(Token(SyntaxKind.VoidKeyword)); + + var writerParam = "writer".ToIdentifierName(); + var instanceParam = "instance".ToIdentifierName(); + + var body = new List(); + if (type.HasComplexBaseType) + { + body.Add( + ExpressionStatement( + InvocationExpression( + BaseTypeSerializerFieldName.ToIdentifierName().Member(SerializeMethodName), + ArgumentList(SeparatedList([Argument(writerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), Argument(instanceParam)]))))); + body.Add(ExpressionStatement(InvocationExpression(writerParam.Member("WriteEndBase"), ArgumentList()))); } - private BaseCodecFieldDescription GetBaseTypeField(ISerializableTypeDescription serializableTypeDescription) + AddSerializationCallbacks(type, instanceParam, "OnSerializing", body); + + // Order members according to their FieldId, since fields must be serialized in order and FieldIds are serialized as deltas. + var previousFieldIdVar = "previousFieldId".ToIdentifierName(); + if (type.OmitDefaultMemberValues && members.Count > 0) { - var baseType = serializableTypeDescription.BaseType; - if (baseType.HasAttribute(LibraryTypes.GenerateSerializerAttribute) - && (SymbolEqualityComparer.Default.Equals(baseType.ContainingAssembly, LibraryTypes.Compilation.Assembly) || baseType.ContainingAssembly.HasAttribute(LibraryTypes.TypeManifestProviderAttribute)) - && baseType is not INamedTypeSymbol { IsGenericType: true }) - { - // Use the concrete generated type and avoid expensive interface dispatch (except for generic types that will fall back to IBaseCodec) - return new(QualifiedName(ParseName(GetGeneratedNamespaceName(baseType)), IdentifierName(GetSimpleClassName(baseType.Name))), true); - } + // C#: uint previousFieldId = 0; + body.Add(LocalDeclarationStatement( + VariableDeclaration( + PredefinedType(Token(SyntaxKind.UIntKeyword)), + SingletonSeparatedList(VariableDeclarator(previousFieldIdVar.Identifier) + .WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0U)))))))); + } - return new(LibraryTypes.BaseCodec_1.ToTypeSyntax(serializableTypeDescription.BaseTypeSyntax)); + if (type.IncludePrimaryConstructorParameters) + { + AddSerializationMembers(type, serializerFields, members.Where(m => m.IsPrimaryConstructorParameter), writerParam, instanceParam, previousFieldIdVar, body); + body.Add(ExpressionStatement(InvocationExpression(writerParam.Member("WriteEndBase"), ArgumentList()))); } - private MemberDeclarationSyntax GenerateSerializeMethod( - ISerializableTypeDescription type, - List serializerFields, - List members) + AddSerializationMembers(type, serializerFields, members.Where(m => !m.IsPrimaryConstructorParameter), writerParam, instanceParam, previousFieldIdVar, body); + + AddSerializationCallbacks(type, instanceParam, "OnSerialized", body); + + if (body.Count == 0 && type.IsAbstractType) + return null; + + var parameters = new[] + { + Parameter("writer".ToIdentifier()).WithType(LibraryTypes.Writer.ToTypeSyntax()).WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))), + Parameter("instance".ToIdentifier()).WithType(type.TypeSyntax) + }; + + if (type.IsValueType) { - var returnType = PredefinedType(Token(SyntaxKind.VoidKeyword)); + parameters[1] = parameters[1].WithModifiers(LibraryTypes.HasScopedKeyword() ? TokenList(Token(SyntaxKind.ScopedKeyword), Token(SyntaxKind.RefKeyword)) : TokenList(Token(SyntaxKind.RefKeyword))); + } + + var res = MethodDeclaration(returnType, SerializeMethodName) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddParameterListParameters(parameters) + .AddTypeParameterListParameters(TypeParameter("TBufferWriter")) + .AddAttributeLists(AttributeList(SingletonSeparatedList(GeneratedCodeUtilities.GetMethodImplAttributeSyntax()))) + .AddBodyStatements([.. body]); + + res = type.IsAbstractType + ? res.AddModifiers(Token(SyntaxKind.OverrideKeyword)) + : res.AddConstraintClauses(TypeParameterConstraintClause("TBufferWriter").AddConstraints(TypeConstraint(LibraryTypes.IBufferWriter.ToTypeSyntax(PredefinedType(Token(SyntaxKind.ByteKeyword)))))); - var writerParam = "writer".ToIdentifierName(); - var instanceParam = "instance".ToIdentifierName(); + return res; + } - var body = new List(); - if (type.HasComplexBaseType) + private void AddSerializationMembers(ISerializableTypeDescription type, List serializerFields, IEnumerable members, IdentifierNameSyntax writerParam, IdentifierNameSyntax instanceParam, IdentifierNameSyntax previousFieldIdVar, List body) + { + uint previousFieldId = 0; + foreach (var member in members.OrderBy(m => m.Member.FieldId)) + { + var description = member.Member; + ExpressionSyntax fieldIdDeltaExpr; + if (type.OmitDefaultMemberValues) { - body.Add( - ExpressionStatement( - InvocationExpression( - BaseTypeSerializerFieldName.ToIdentifierName().Member(SerializeMethodName), - ArgumentList(SeparatedList(new[] { Argument(writerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), Argument(instanceParam) }))))); - body.Add(ExpressionStatement(InvocationExpression(writerParam.Member("WriteEndBase"), ArgumentList()))); + // C#: - previousFieldId + fieldIdDeltaExpr = BinaryExpression(SyntaxKind.SubtractExpression, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(description.FieldId)), previousFieldIdVar); + } + else + { + var fieldIdDelta = description.FieldId - previousFieldId; + previousFieldId = description.FieldId; + fieldIdDeltaExpr = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(fieldIdDelta)); + } + + // Codecs can either be static classes or injected into the constructor. + // Either way, the member signatures are the same. + var memberType = description.Type; + var staticCodec = LibraryTypes.StaticCodecs.FindByUnderlyingType(memberType); + ExpressionSyntax codecExpression; + if (staticCodec != null) + { + codecExpression = staticCodec.CodecType.ToNameSyntax(); + } + else + { + var instanceCodec = serializerFields.First(f => f is CodecFieldDescription cf && SymbolEqualityComparer.Default.Equals(cf.UnderlyingType, memberType)); + codecExpression = IdentifierName(instanceCodec.FieldName); + } + + // When a static codec is available, we can call it directly and can skip passing the expected type, + // since it is known to be the static codec's field type: + // C#: .WriteField(ref writer, ) + // When no static codec is available: + // C#: .WriteField(ref writer, , , ) + var writeFieldArgs = new List { + Argument(writerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), + Argument(fieldIdDeltaExpr) + }; + + if (staticCodec is null) + writeFieldArgs.Add(Argument(serializerFields.First(f => f is TypeFieldDescription tf && SymbolEqualityComparer.Default.Equals(tf.UnderlyingType, memberType)).FieldName.ToIdentifierName())); + + writeFieldArgs.Add(Argument(member.GetGetter(instanceParam))); + + var writeFieldExpr = ExpressionStatement(InvocationExpression(codecExpression.Member("WriteField"), ArgumentList(SeparatedList(writeFieldArgs)))); + + if (!type.OmitDefaultMemberValues) + { + body.Add(writeFieldExpr); + } + else + { + ExpressionSyntax condition = member.IsValueType switch + { + true => BinaryExpression(SyntaxKind.NotEqualsExpression, member.GetGetter(instanceParam), LiteralExpression(SyntaxKind.DefaultLiteralExpression)), + false => IsPatternExpression(member.GetGetter(instanceParam), TypePattern(PredefinedType(Token(SyntaxKind.ObjectKeyword)))) + }; + + body.Add(IfStatement( + condition, + Block( + writeFieldExpr, + ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, previousFieldIdVar, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(description.FieldId))))))); } + } + } + + private MemberDeclarationSyntax? GenerateDeserializeMethod( + ISerializableTypeDescription type, + List serializerFields, + List members) + { + var returnType = PredefinedType(Token(SyntaxKind.VoidKeyword)); + + var readerParam = "reader".ToIdentifierName(); + var instanceParam = "instance".ToIdentifierName(); + var idVar = "id".ToIdentifierName(); + var headerVar = "header".ToIdentifierName(); + + var body = new List(); + + if (type.HasComplexBaseType) + { + // C#: _baseTypeSerializer.Deserialize(ref reader, instance); + body.Add( + ExpressionStatement( + InvocationExpression( + BaseTypeSerializerFieldName.ToIdentifierName().Member(DeserializeMethodName), + ArgumentList(SeparatedList( + [ + Argument(readerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), + Argument(instanceParam) + ]))))); + } - AddSerializationCallbacks(type, instanceParam, "OnSerializing", body); + AddSerializationCallbacks(type, instanceParam, "OnDeserializing", body); - // Order members according to their FieldId, since fields must be serialized in order and FieldIds are serialized as deltas. - var previousFieldIdVar = "previousFieldId".ToIdentifierName(); - if (type.OmitDefaultMemberValues && members.Count > 0) + int emptyBodyCount; + var nonCtorMembers = type.IncludePrimaryConstructorParameters ? members.FindAll(static m => !m.IsPrimaryConstructorParameter) : members; + if ((members.Count == 0 || nonCtorMembers.Count == 0) && !type.IncludePrimaryConstructorParameters) + { + // C#: reader.ConsumeEndBaseOrEndObject(); + body.Add(ExpressionStatement(InvocationExpression(readerParam.Member("ConsumeEndBaseOrEndObject")))); + emptyBodyCount = 1; + } + else + { + // C#: uint id = 0; + if (members.Count > 0) { - // C#: uint previousFieldId = 0; body.Add(LocalDeclarationStatement( VariableDeclaration( PredefinedType(Token(SyntaxKind.UIntKeyword)), - SingletonSeparatedList(VariableDeclarator(previousFieldIdVar.Identifier) - .WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0U)))))))); + SingletonSeparatedList(VariableDeclarator(idVar.Identifier, null, EqualsValueClause(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0U)))))))); } + // C#: Field header = default; + body.Add(LocalDeclarationStatement( + VariableDeclaration( + LibraryTypes.Field.ToTypeSyntax(), + SingletonSeparatedList(VariableDeclarator(headerVar.Identifier, null, EqualsValueClause(LiteralExpression(SyntaxKind.DefaultLiteralExpression))))))); + + emptyBodyCount = 2; + if (type.IncludePrimaryConstructorParameters) { - AddSerializationMembers(type, serializerFields, members.Where(m => m.IsPrimaryConstructorParameter), writerParam, instanceParam, previousFieldIdVar, body); - body.Add(ExpressionStatement(InvocationExpression(writerParam.Member("WriteEndBase"), ArgumentList()))); + var constructorParameterMembers = members.FindAll(m => m.IsPrimaryConstructorParameter); + body.Add(GetDeserializerLoop(constructorParameterMembers)); + if (members.Count > 0) + { + body.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, idVar, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0U))))); + } + + body.Add(IfStatement(headerVar.Member("IsEndBaseFields"), GetDeserializerLoop(nonCtorMembers))); + } + else + { + body.Add(GetDeserializerLoop(nonCtorMembers)); } + } - AddSerializationMembers(type, serializerFields, members.Where(m => !m.IsPrimaryConstructorParameter), writerParam, instanceParam, previousFieldIdVar, body); + AddSerializationCallbacks(type, instanceParam, "OnDeserialized", body); - AddSerializationCallbacks(type, instanceParam, "OnSerialized", body); + if (body.Count == emptyBodyCount && type.IsAbstractType) + return null; - if (body.Count == 0 && type.IsAbstractType) - return null; + var genericParam = ParseTypeName("TReaderInput"); + var parameters = new[] + { + Parameter(readerParam.Identifier).WithType(LibraryTypes.Reader.ToTypeSyntax(genericParam)).WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))), + Parameter(instanceParam.Identifier).WithType(type.TypeSyntax) + }; - var parameters = new[] - { - Parameter("writer".ToIdentifier()).WithType(LibraryTypes.Writer.ToTypeSyntax()).WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))), - Parameter("instance".ToIdentifier()).WithType(type.TypeSyntax) - }; + if (type.IsValueType) + { + parameters[1] = parameters[1].WithModifiers(LibraryTypes.HasScopedKeyword() ? TokenList(Token(SyntaxKind.ScopedKeyword), Token(SyntaxKind.RefKeyword)) : TokenList(Token(SyntaxKind.RefKeyword))); + } + + var res = MethodDeclaration(returnType, DeserializeMethodName) + .AddTypeParameterListParameters(TypeParameter("TReaderInput")) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddParameterListParameters(parameters) + .AddAttributeLists(AttributeList(SingletonSeparatedList(GeneratedCodeUtilities.GetMethodImplAttributeSyntax()))) + .AddBodyStatements([.. body]); + + if (type.IsAbstractType) + res = res.AddModifiers(Token(SyntaxKind.OverrideKeyword)); + + return res; - if (type.IsValueType) + // Create the loop body. + StatementSyntax GetDeserializerLoop(List members) + { + var refHeaderVar = ArgumentList(SingletonSeparatedList(Argument(null, Token(SyntaxKind.RefKeyword), headerVar))); + if (members.Count == 0) { - parameters[1] = parameters[1].WithModifiers(LibraryTypes.HasScopedKeyword() ? TokenList(Token(SyntaxKind.ScopedKeyword), Token(SyntaxKind.RefKeyword)) : TokenList(Token(SyntaxKind.RefKeyword))); + // C#: reader.ReadFieldHeader(ref header); + // C#: reader.ConsumeEndBaseOrEndObject(ref header); + return Block( + ExpressionStatement(InvocationExpression(readerParam.Member("ReadFieldHeader"), refHeaderVar)), + ExpressionStatement(InvocationExpression(readerParam.Member("ConsumeEndBaseOrEndObject"), refHeaderVar))); } - var res = MethodDeclaration(returnType, SerializeMethodName) - .AddModifiers(Token(SyntaxKind.PublicKeyword)) - .AddParameterListParameters(parameters) - .AddTypeParameterListParameters(TypeParameter("TBufferWriter")) - .AddAttributeLists(AttributeList(SingletonSeparatedList(CodeGenerator.GetMethodImplAttributeSyntax()))) - .AddBodyStatements(body.ToArray()); + var loopBody = new List(); - res = type.IsAbstractType - ? res.AddModifiers(Token(SyntaxKind.OverrideKeyword)) - : res.AddConstraintClauses(TypeParameterConstraintClause("TBufferWriter").AddConstraints(TypeConstraint(LibraryTypes.IBufferWriter.ToTypeSyntax(PredefinedType(Token(SyntaxKind.ByteKeyword)))))); - - return res; - } + // C#: reader.ReadFieldHeader(ref header); + // C#: if (header.IsEndBaseOrEndObject) break; + // C#: id += header.FieldIdDelta; + var readFieldHeader = ExpressionStatement(InvocationExpression(readerParam.Member("ReadFieldHeader"), refHeaderVar)); + var endObjectCheck = IfStatement(headerVar.Member("IsEndBaseOrEndObject"), BreakStatement()); + var idUpdate = ExpressionStatement(AssignmentExpression(SyntaxKind.AddAssignmentExpression, idVar, headerVar.Member("FieldIdDelta"))); + loopBody.Add(readFieldHeader); + loopBody.Add(endObjectCheck); + loopBody.Add(idUpdate); - private void AddSerializationMembers(ISerializableTypeDescription type, List serializerFields, IEnumerable members, IdentifierNameSyntax writerParam, IdentifierNameSyntax instanceParam, IdentifierNameSyntax previousFieldIdVar, List body) - { - uint previousFieldId = 0; - foreach (var member in members.OrderBy(m => m.Member.FieldId)) + members.Sort((x, y) => x.Member.FieldId.CompareTo(y.Member.FieldId)); + var contiguousIds = members[members.Count - 1].Member.FieldId == members.Count - 1; + foreach (var member in members) { var description = member.Member; - ExpressionSyntax fieldIdDeltaExpr; - if (type.OmitDefaultMemberValues) - { - // C#: - previousFieldId - fieldIdDeltaExpr = BinaryExpression(SyntaxKind.SubtractExpression, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(description.FieldId)), previousFieldIdVar); - } - else - { - var fieldIdDelta = description.FieldId - previousFieldId; - previousFieldId = description.FieldId; - fieldIdDeltaExpr = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(fieldIdDelta)); - } + // C#: instance. = .ReadValue(ref reader, header); // Codecs can either be static classes or injected into the constructor. // Either way, the member signatures are the same. - var memberType = description.Type; - var staticCodec = LibraryTypes.StaticCodecs.FindByUnderlyingType(memberType); ExpressionSyntax codecExpression; - if (staticCodec != null) + if (LibraryTypes.StaticCodecs.FindByUnderlyingType(description.Type) is { } staticCodec) { codecExpression = staticCodec.CodecType.ToNameSyntax(); } else { - var instanceCodec = serializerFields.First(f => f is CodecFieldDescription cf && SymbolEqualityComparer.Default.Equals(cf.UnderlyingType, memberType)); + var instanceCodec = serializerFields.OfType().First(c => SymbolEqualityComparer.Default.Equals(c.UnderlyingType, description.Type)); codecExpression = IdentifierName(instanceCodec.FieldName); } - // When a static codec is available, we can call it directly and can skip passing the expected type, - // since it is known to be the static codec's field type: - // C#: .WriteField(ref writer, ) - // When no static codec is available: - // C#: .WriteField(ref writer, , , ) - var writeFieldArgs = new List { - Argument(writerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), - Argument(fieldIdDeltaExpr) - }; - - if (staticCodec is null) - writeFieldArgs.Add(Argument(serializerFields.First(f => f is TypeFieldDescription tf && SymbolEqualityComparer.Default.Equals(tf.UnderlyingType, memberType)).FieldName.ToIdentifierName())); - - writeFieldArgs.Add(Argument(member.GetGetter(instanceParam))); + ExpressionSyntax readValueExpression = InvocationExpression( + codecExpression.Member("ReadValue"), + ArgumentList(SeparatedList([Argument(readerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), Argument(headerVar)]))); - var writeFieldExpr = ExpressionStatement(InvocationExpression(codecExpression.Member("WriteField"), ArgumentList(SeparatedList(writeFieldArgs)))); + var memberAssignment = ExpressionStatement(member.GetSetter(instanceParam, readValueExpression)); - if (!type.OmitDefaultMemberValues) + BlockSyntax ifBody; + if (member != members[members.Count - 1]) + { + ifBody = Block(memberAssignment, readFieldHeader, endObjectCheck, idUpdate); + } + else if (contiguousIds) { - body.Add(writeFieldExpr); + ifBody = Block(memberAssignment, readFieldHeader); } else { - ExpressionSyntax condition = member.IsValueType switch - { - true => BinaryExpression(SyntaxKind.NotEqualsExpression, member.GetGetter(instanceParam), LiteralExpression(SyntaxKind.DefaultLiteralExpression)), - false => IsPatternExpression(member.GetGetter(instanceParam), TypePattern(PredefinedType(Token(SyntaxKind.ObjectKeyword)))) - }; - - body.Add(IfStatement( - condition, - Block( - writeFieldExpr, - ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, previousFieldIdVar, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(description.FieldId))))))); + idUpdate = ExpressionStatement(PostfixUnaryExpression(SyntaxKind.PostIncrementExpression, idVar)); + ifBody = Block(memberAssignment, readFieldHeader, endObjectCheck, idUpdate); } - } - } - - private MemberDeclarationSyntax GenerateDeserializeMethod( - ISerializableTypeDescription type, - List serializerFields, - List members) - { - var returnType = PredefinedType(Token(SyntaxKind.VoidKeyword)); - - var readerParam = "reader".ToIdentifierName(); - var instanceParam = "instance".ToIdentifierName(); - var idVar = "id".ToIdentifierName(); - var headerVar = "header".ToIdentifierName(); - var body = new List(); + // C#: if (id == ) { ... } + var ifStatement = IfStatement(BinaryExpression(SyntaxKind.EqualsExpression, idVar, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(description.FieldId))), + ifBody); - if (type.HasComplexBaseType) - { - // C#: _baseTypeSerializer.Deserialize(ref reader, instance); - body.Add( - ExpressionStatement( - InvocationExpression( - BaseTypeSerializerFieldName.ToIdentifierName().Member(DeserializeMethodName), - ArgumentList(SeparatedList(new[] - { - Argument(readerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), - Argument(instanceParam) - }))))); + loopBody.Add(ifStatement); } - AddSerializationCallbacks(type, instanceParam, "OnDeserializing", body); - - int emptyBodyCount; - var nonCtorMembers = type.IncludePrimaryConstructorParameters ? members.FindAll(static m => !m.IsPrimaryConstructorParameter) : members; - if ((members.Count == 0 || nonCtorMembers.Count == 0) && !type.IncludePrimaryConstructorParameters) + // Consume any unknown fields + if (contiguousIds) { - // C#: reader.ConsumeEndBaseOrEndObject(); - body.Add(ExpressionStatement(InvocationExpression(readerParam.Member("ConsumeEndBaseOrEndObject")))); - emptyBodyCount = 1; + // C#: reader.ConsumeEndBaseOrEndObject(ref header); break; + loopBody.Add(ExpressionStatement(InvocationExpression(readerParam.Member("ConsumeEndBaseOrEndObject"), refHeaderVar))); + loopBody.Add(BreakStatement()); } else { - // C#: uint id = 0; - if (members.Count > 0) - { - body.Add(LocalDeclarationStatement( - VariableDeclaration( - PredefinedType(Token(SyntaxKind.UIntKeyword)), - SingletonSeparatedList(VariableDeclarator(idVar.Identifier, null, EqualsValueClause(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0U)))))))); - } - - // C#: Field header = default; - body.Add(LocalDeclarationStatement( - VariableDeclaration( - LibraryTypes.Field.ToTypeSyntax(), - SingletonSeparatedList(VariableDeclarator(headerVar.Identifier, null, EqualsValueClause(LiteralExpression(SyntaxKind.DefaultLiteralExpression))))))); - - emptyBodyCount = 2; - - if (type.IncludePrimaryConstructorParameters) - { - var constructorParameterMembers = members.FindAll(m => m.IsPrimaryConstructorParameter); - body.Add(GetDeserializerLoop(constructorParameterMembers)); - if (members.Count > 0) - { - body.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, idVar, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0U))))); - } - - body.Add(IfStatement(headerVar.Member("IsEndBaseFields"), GetDeserializerLoop(nonCtorMembers))); - } - else - { - body.Add(GetDeserializerLoop(nonCtorMembers)); - } + // C#: reader.ConsumeUnknownField(ref header); + loopBody.Add(ExpressionStatement(InvocationExpression(readerParam.Member("ConsumeUnknownField"), refHeaderVar))); } - AddSerializationCallbacks(type, instanceParam, "OnDeserialized", body); - - if (body.Count == emptyBodyCount && type.IsAbstractType) - return null; + return WhileStatement(LiteralExpression(SyntaxKind.TrueLiteralExpression), Block(loopBody)); + } + } - var genericParam = ParseTypeName("TReaderInput"); - var parameters = new[] + private static void AddSerializationCallbacks(ISerializableTypeDescription type, IdentifierNameSyntax instanceParam, string callbackMethodName, List body) + { + for (var hookIndex = 0; hookIndex < type.SerializationHooks.Count; ++hookIndex) + { + var hookType = type.SerializationHooks[hookIndex]; + var member = hookType.GetAllMembers(callbackMethodName, Accessibility.Public).FirstOrDefault(); + if (member is null || member.Parameters.Length != 1) { - Parameter(readerParam.Identifier).WithType(LibraryTypes.Reader.ToTypeSyntax(genericParam)).WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))), - Parameter(instanceParam.Identifier).WithType(type.TypeSyntax) - }; + continue; + } - if (type.IsValueType) + var argument = Argument(instanceParam); + if (member.Parameters[0].RefKind == RefKind.Ref) { - parameters[1] = parameters[1].WithModifiers(LibraryTypes.HasScopedKeyword() ? TokenList(Token(SyntaxKind.ScopedKeyword), Token(SyntaxKind.RefKeyword)) : TokenList(Token(SyntaxKind.RefKeyword))); + argument = argument.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); } - var res = MethodDeclaration(returnType, DeserializeMethodName) - .AddTypeParameterListParameters(TypeParameter("TReaderInput")) - .AddModifiers(Token(SyntaxKind.PublicKeyword)) - .AddParameterListParameters(parameters) - .AddAttributeLists(AttributeList(SingletonSeparatedList(CodeGenerator.GetMethodImplAttributeSyntax()))) - .AddBodyStatements(body.ToArray()); - - if (type.IsAbstractType) - res = res.AddModifiers(Token(SyntaxKind.OverrideKeyword)); + body.Add(ExpressionStatement(InvocationExpression( + IdentifierName($"_hook{hookIndex}").Member(callbackMethodName), + ArgumentList(SeparatedList([argument]))))); + } + } - return res; + private MemberDeclarationSyntax GenerateCompoundTypeWriteFieldMethod( + ISerializableTypeDescription type) + { + var returnType = PredefinedType(Token(SyntaxKind.VoidKeyword)); - // Create the loop body. - StatementSyntax GetDeserializerLoop(List members) - { - var refHeaderVar = ArgumentList(SingletonSeparatedList(Argument(null, Token(SyntaxKind.RefKeyword), headerVar))); - if (members.Count == 0) - { - // C#: reader.ReadFieldHeader(ref header); - // C#: reader.ConsumeEndBaseOrEndObject(ref header); - return Block( - ExpressionStatement(InvocationExpression(readerParam.Member("ReadFieldHeader"), refHeaderVar)), - ExpressionStatement(InvocationExpression(readerParam.Member("ConsumeEndBaseOrEndObject"), refHeaderVar))); - } + var writerParam = "writer".ToIdentifierName(); + var fieldIdDeltaParam = "fieldIdDelta".ToIdentifierName(); + var expectedTypeParam = "expectedType".ToIdentifierName(); + var valueParam = "value".ToIdentifierName(); - var loopBody = new List(); + var innerBody = new List(); - // C#: reader.ReadFieldHeader(ref header); - // C#: if (header.IsEndBaseOrEndObject) break; - // C#: id += header.FieldIdDelta; - var readFieldHeader = ExpressionStatement(InvocationExpression(readerParam.Member("ReadFieldHeader"), refHeaderVar)); - var endObjectCheck = IfStatement(headerVar.Member("IsEndBaseOrEndObject"), BreakStatement()); - var idUpdate = ExpressionStatement(AssignmentExpression(SyntaxKind.AddAssignmentExpression, idVar, headerVar.Member("FieldIdDelta"))); - loopBody.Add(readFieldHeader); - loopBody.Add(endObjectCheck); - loopBody.Add(idUpdate); - - members.Sort((x, y) => x.Member.FieldId.CompareTo(y.Member.FieldId)); - var contiguousIds = members[members.Count - 1].Member.FieldId == members.Count - 1; - foreach (var member in members) - { - var description = member.Member; + if (type.IsValueType) + { + // C#: ReferenceCodec.MarkValueField(reader.Session); + innerBody.Add(ExpressionStatement(InvocationExpression(IdentifierName("ReferenceCodec").Member("MarkValueField"), ArgumentList(SingletonSeparatedList(Argument(writerParam.Member("Session"))))))); + } + else + { + if (type.TrackReferences) + { + // C#: if (ReferenceCodec.TryWriteReferenceField(ref writer, fieldIdDelta, expectedType, value)) return; + innerBody.Add( + IfStatement( + InvocationExpression( + IdentifierName("ReferenceCodec").Member("TryWriteReferenceField"), + ArgumentList(SeparatedList( + [ + Argument(writerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), + Argument(fieldIdDeltaParam), + Argument(expectedTypeParam), + Argument(valueParam) + ]))), + ReturnStatement()) + ); + } + else + { + // C#: if (value is null) { ReferenceCodec.WriteNullReference(ref writer, fieldIdDelta); return; } + innerBody.Add( + IfStatement( + IsPatternExpression(valueParam, ConstantPattern(LiteralExpression(SyntaxKind.NullLiteralExpression))), + Block( + ExpressionStatement(InvocationExpression(IdentifierName("ReferenceCodec").Member("WriteNullReference"), + ArgumentList(SeparatedList( + [ + Argument(writerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), + Argument(fieldIdDeltaParam) + ])))), + ReturnStatement())) + ); - // C#: instance. = .ReadValue(ref reader, header); - // Codecs can either be static classes or injected into the constructor. - // Either way, the member signatures are the same. - ExpressionSyntax codecExpression; - if (LibraryTypes.StaticCodecs.FindByUnderlyingType(description.Type) is { } staticCodec) - { - codecExpression = staticCodec.CodecType.ToNameSyntax(); - } - else - { - var instanceCodec = serializerFields.Find(c => c is CodecFieldDescription f && SymbolEqualityComparer.Default.Equals(f.UnderlyingType, description.Type)); - codecExpression = IdentifierName(instanceCodec.FieldName); - } + // C#: ReferenceCodec.MarkValueField(reader.Session); + innerBody.Add(ExpressionStatement(InvocationExpression(IdentifierName("ReferenceCodec").Member("MarkValueField"), ArgumentList(SingletonSeparatedList(Argument(writerParam.Member("Session"))))))); + } + } - ExpressionSyntax readValueExpression = InvocationExpression( - codecExpression.Member("ReadValue"), - ArgumentList(SeparatedList(new[] { Argument(readerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), Argument(headerVar) }))); + // C#: writer.WriteStartObject(fieldIdDelta, expectedType, _codecFieldType); + innerBody.Add( + ExpressionStatement(InvocationExpression(writerParam.Member("WriteStartObject"), + ArgumentList(SeparatedList([ + Argument(fieldIdDeltaParam), + Argument(expectedTypeParam), + Argument(IdentifierName(CodecFieldTypeFieldName)) + ]))) + )); + + // C#: this.Serialize(ref writer, [ref] value); + var valueParamArgument = type.IsValueType switch + { + true => Argument(valueParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), + false => Argument(valueParam) + }; - var memberAssignment = ExpressionStatement(member.GetSetter(instanceParam, readValueExpression)); + innerBody.Add( + ExpressionStatement( + InvocationExpression( + IdentifierName(SerializeMethodName), + ArgumentList( + SeparatedList( + [ + Argument(writerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), + valueParamArgument + ]))))); + + // C#: writer.WriteEndObject(); + innerBody.Add(ExpressionStatement(InvocationExpression(writerParam.Member("WriteEndObject")))); + + List body; + if (type.IsSealedType) + { + body = innerBody; + } + else + { + // For types which are not sealed/value types, add some extra logic to support sub-types: + body = new() + { + // C#: if (value is null || value.GetType() == typeof(TField)) { } + // C#: else writer.SerializeUnexpectedType(fieldIdDelta, expectedType, value); + IfStatement( + BinaryExpression(SyntaxKind.LogicalOrExpression, + IsPatternExpression(valueParam, ConstantPattern(LiteralExpression(SyntaxKind.NullLiteralExpression))), + BinaryExpression(SyntaxKind.EqualsExpression, InvocationExpression(valueParam.Member("GetType")), TypeOfExpression(type.TypeSyntax))), + Block(innerBody), + ElseClause(ExpressionStatement( + InvocationExpression( + writerParam.Member("SerializeUnexpectedType"), + ArgumentList( + SeparatedList([ + Argument(fieldIdDeltaParam), + Argument(expectedTypeParam), + Argument(valueParam) + ]))) + ))) + }; + } - BlockSyntax ifBody; - if (member != members[members.Count - 1]) - { - ifBody = Block(memberAssignment, readFieldHeader, endObjectCheck, idUpdate); - } - else if (contiguousIds) - { - ifBody = Block(memberAssignment, readFieldHeader); - } - else - { - idUpdate = ExpressionStatement(PostfixUnaryExpression(SyntaxKind.PostIncrementExpression, idVar)); - ifBody = Block(memberAssignment, readFieldHeader, endObjectCheck, idUpdate); - } + var parameters = new[] + { + Parameter("writer".ToIdentifier()).WithType(LibraryTypes.Writer.ToTypeSyntax()).WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))), + Parameter("fieldIdDelta".ToIdentifier()).WithType(PredefinedType(Token(SyntaxKind.UIntKeyword))), + Parameter("expectedType".ToIdentifier()).WithType(LibraryTypes.Type.ToTypeSyntax()), + Parameter("value".ToIdentifier()).WithType(type.TypeSyntax) + }; - // C#: if (id == ) { ... } - var ifStatement = IfStatement(BinaryExpression(SyntaxKind.EqualsExpression, idVar, LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(description.FieldId))), - ifBody); + return MethodDeclaration(returnType, WriteFieldMethodName) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddParameterListParameters(parameters) + .AddTypeParameterListParameters(TypeParameter("TBufferWriter")) + .AddConstraintClauses(TypeParameterConstraintClause("TBufferWriter").AddConstraints(TypeConstraint(LibraryTypes.IBufferWriter.ToTypeSyntax(PredefinedType(Token(SyntaxKind.ByteKeyword)))))) + .AddAttributeLists(AttributeList(SingletonSeparatedList(GeneratedCodeUtilities.GetMethodImplAttributeSyntax()))) + .AddBodyStatements([.. body]); + } - loopBody.Add(ifStatement); - } + private MemberDeclarationSyntax GenerateCompoundTypeReadValueMethod( + ISerializableTypeDescription type, + List serializerFields) + { + var readerParam = "reader".ToIdentifierName(); + var fieldParam = "field".ToIdentifierName(); + var resultVar = "result".ToIdentifierName(); + var readerInputTypeParam = ParseTypeName("TReaderInput"); - // Consume any unknown fields - if (contiguousIds) - { - // C#: reader.ConsumeEndBaseOrEndObject(ref header); break; - loopBody.Add(ExpressionStatement(InvocationExpression(readerParam.Member("ConsumeEndBaseOrEndObject"), refHeaderVar))); - loopBody.Add(BreakStatement()); - } - else - { - // C#: reader.ConsumeUnknownField(ref header); - loopBody.Add(ExpressionStatement(InvocationExpression(readerParam.Member("ConsumeUnknownField"), refHeaderVar))); - } + var body = new List(); + var innerBody = type.IsSealedType ? body : new List(); - return WhileStatement(LiteralExpression(SyntaxKind.TrueLiteralExpression), Block(loopBody)); - } + if (!type.IsValueType) + { + // C#: if (field.IsReference) return ReferenceCodec.ReadReference(ref reader, field); + body.Add( + IfStatement( + fieldParam.Member("IsReference"), + ReturnStatement(InvocationExpression( + IdentifierName("ReferenceCodec").Member("ReadReference", [type.TypeSyntax, readerInputTypeParam]), + ArgumentList(SeparatedList( + [ + Argument(readerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), + Argument(fieldParam), + ]))))) + ); } - private void AddSerializationCallbacks(ISerializableTypeDescription type, IdentifierNameSyntax instanceParam, string callbackMethodName, List body) + // C#: field.EnsureWireTypeTagDelimited(); + body.Add(ExpressionStatement(InvocationExpression(fieldParam.Member("EnsureWireTypeTagDelimited")))); + + ExpressionSyntax createValueExpression = type.UseActivator switch { - for (var hookIndex = 0; hookIndex < type.SerializationHooks.Count; ++hookIndex) - { - var hookType = type.SerializationHooks[hookIndex]; - var member = hookType.GetAllMembers(callbackMethodName, Accessibility.Public).FirstOrDefault(); - if (member is null || member.Parameters.Length != 1) - { - continue; - } + true => InvocationExpression(serializerFields.OfType().Single().FieldName.ToIdentifierName().Member("Create")), + false => type.GetObjectCreationExpression() + }; - var argument = Argument(instanceParam); - if (member.Parameters[0].RefKind == RefKind.Ref) - { - argument = argument.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); - } + // C#: var result = _activator.Create(); + // or C#: var result = new TField(); + // or C#: var result = default(TField); + innerBody.Add(LocalDeclarationStatement( + VariableDeclaration( + IdentifierName("var"), + SingletonSeparatedList(VariableDeclarator(resultVar.Identifier) + .WithInitializer(EqualsValueClause(createValueExpression)))))); - body.Add(ExpressionStatement(InvocationExpression( - IdentifierName($"_hook{hookIndex}").Member(callbackMethodName), - ArgumentList(SeparatedList(new[] { argument }))))); - } + if (type.TrackReferences) + { + // C#: ReferenceCodec.RecordObject(reader.Session, result); + innerBody.Add(ExpressionStatement(InvocationExpression(IdentifierName("ReferenceCodec").Member("RecordObject"), ArgumentList(SeparatedList([Argument(readerParam.Member("Session")), Argument(resultVar)]))))); } - - private MemberDeclarationSyntax GenerateCompoundTypeWriteFieldMethod( - ISerializableTypeDescription type) + else { - var returnType = PredefinedType(Token(SyntaxKind.VoidKeyword)); - - var writerParam = "writer".ToIdentifierName(); - var fieldIdDeltaParam = "fieldIdDelta".ToIdentifierName(); - var expectedTypeParam = "expectedType".ToIdentifierName(); - var valueParam = "value".ToIdentifierName(); - - var innerBody = new List(); - - if (type.IsValueType) - { - // C#: ReferenceCodec.MarkValueField(reader.Session); - innerBody.Add(ExpressionStatement(InvocationExpression(IdentifierName("ReferenceCodec").Member("MarkValueField"), ArgumentList(SingletonSeparatedList(Argument(writerParam.Member("Session"))))))); - } - else - { - if (type.TrackReferences) - { - // C#: if (ReferenceCodec.TryWriteReferenceField(ref writer, fieldIdDelta, expectedType, value)) return; - innerBody.Add( - IfStatement( - InvocationExpression( - IdentifierName("ReferenceCodec").Member("TryWriteReferenceField"), - ArgumentList(SeparatedList(new[] - { - Argument(writerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), - Argument(fieldIdDeltaParam), - Argument(expectedTypeParam), - Argument(valueParam) - }))), - ReturnStatement()) - ); - } - else - { - // C#: if (value is null) { ReferenceCodec.WriteNullReference(ref writer, fieldIdDelta); return; } - innerBody.Add( - IfStatement( - IsPatternExpression(valueParam, ConstantPattern(LiteralExpression(SyntaxKind.NullLiteralExpression))), - Block( - ExpressionStatement(InvocationExpression(IdentifierName("ReferenceCodec").Member("WriteNullReference"), - ArgumentList(SeparatedList(new[] - { - Argument(writerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), - Argument(fieldIdDeltaParam) - })))), - ReturnStatement())) - ); - - // C#: ReferenceCodec.MarkValueField(reader.Session); - innerBody.Add(ExpressionStatement(InvocationExpression(IdentifierName("ReferenceCodec").Member("MarkValueField"), ArgumentList(SingletonSeparatedList(Argument(writerParam.Member("Session"))))))); - } - } + // C#: ReferenceCodec.MarkValueField(reader.Session); + innerBody.Add(ExpressionStatement(InvocationExpression(IdentifierName("ReferenceCodec").Member("MarkValueField"), ArgumentList(SingletonSeparatedList(Argument(readerParam.Member("Session"))))))); + } - // C#: writer.WriteStartObject(fieldIdDelta, expectedType, _codecFieldType); - innerBody.Add( - ExpressionStatement(InvocationExpression(writerParam.Member("WriteStartObject"), - ArgumentList(SeparatedList(new[]{ - Argument(fieldIdDeltaParam), - Argument(expectedTypeParam), - Argument(IdentifierName(CodecFieldTypeFieldName)) - }))) - )); - - // C#: this.Serialize(ref writer, [ref] value); - var valueParamArgument = type.IsValueType switch - { - true => Argument(valueParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), - false => Argument(valueParam) - }; + // C#: this.Deserializer(ref reader, [ref] result); + var resultArgument = type.IsValueType switch + { + true => Argument(resultVar).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), + false => Argument(resultVar) + }; + innerBody.Add( + ExpressionStatement( + InvocationExpression( + IdentifierName(DeserializeMethodName), + ArgumentList( + SeparatedList( + [ + Argument(readerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), + resultArgument + ]))))); - innerBody.Add( - ExpressionStatement( - InvocationExpression( - IdentifierName(SerializeMethodName), - ArgumentList( - SeparatedList( - new[] - { - Argument(writerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), - valueParamArgument - }))))); + innerBody.Add(ReturnStatement(resultVar)); - // C#: writer.WriteEndObject(); - innerBody.Add(ExpressionStatement(InvocationExpression(writerParam.Member("WriteEndObject")))); + if (!type.IsSealedType) + { + // C#: var fieldType = field.FieldType; + var valueTypeField = "valueType".ToIdentifierName(); + body.Add( + LocalDeclarationStatement( + VariableDeclaration( + LibraryTypes.Type.ToTypeSyntax(), + SingletonSeparatedList(VariableDeclarator(valueTypeField.Identifier) + .WithInitializer(EqualsValueClause(fieldParam.Member("FieldType"))))))); + body.Add( + IfStatement( + BinaryExpression(SyntaxKind.LogicalOrExpression, + IsPatternExpression(valueTypeField, ConstantPattern(LiteralExpression(SyntaxKind.NullLiteralExpression))), + BinaryExpression(SyntaxKind.EqualsExpression, valueTypeField, IdentifierName(CodecFieldTypeFieldName))), + Block(innerBody))); - List body; - if (type.IsSealedType) - { - body = innerBody; - } - else - { - // For types which are not sealed/value types, add some extra logic to support sub-types: - body = new() - { - // C#: if (value is null || value.GetType() == typeof(TField)) { } - // C#: else writer.SerializeUnexpectedType(fieldIdDelta, expectedType, value); - IfStatement( - BinaryExpression(SyntaxKind.LogicalOrExpression, - IsPatternExpression(valueParam, ConstantPattern(LiteralExpression(SyntaxKind.NullLiteralExpression))), - BinaryExpression(SyntaxKind.EqualsExpression, InvocationExpression(valueParam.Member("GetType")), TypeOfExpression(type.TypeSyntax))), - Block(innerBody), - ElseClause(ExpressionStatement( + body.Add(ReturnStatement( InvocationExpression( - writerParam.Member("SerializeUnexpectedType"), + readerParam.Member("DeserializeUnexpectedType", [readerInputTypeParam, type.TypeSyntax]), ArgumentList( - SeparatedList(new [] { - Argument(fieldIdDeltaParam), - Argument(expectedTypeParam), - Argument(valueParam) - }))) - ))) - }; - } - - var parameters = new[] - { - Parameter("writer".ToIdentifier()).WithType(LibraryTypes.Writer.ToTypeSyntax()).WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))), - Parameter("fieldIdDelta".ToIdentifier()).WithType(PredefinedType(Token(SyntaxKind.UIntKeyword))), - Parameter("expectedType".ToIdentifier()).WithType(LibraryTypes.Type.ToTypeSyntax()), - Parameter("value".ToIdentifier()).WithType(type.TypeSyntax) - }; - - return MethodDeclaration(returnType, WriteFieldMethodName) - .AddModifiers(Token(SyntaxKind.PublicKeyword)) - .AddParameterListParameters(parameters) - .AddTypeParameterListParameters(TypeParameter("TBufferWriter")) - .AddConstraintClauses(TypeParameterConstraintClause("TBufferWriter").AddConstraints(TypeConstraint(LibraryTypes.IBufferWriter.ToTypeSyntax(PredefinedType(Token(SyntaxKind.ByteKeyword)))))) - .AddAttributeLists(AttributeList(SingletonSeparatedList(CodeGenerator.GetMethodImplAttributeSyntax()))) - .AddBodyStatements(body.ToArray()); + SingletonSeparatedList(Argument(null, Token(SyntaxKind.RefKeyword), fieldParam)))))); } - private MemberDeclarationSyntax GenerateCompoundTypeReadValueMethod( - ISerializableTypeDescription type, - List serializerFields) + var parameters = new[] { - var readerParam = "reader".ToIdentifierName(); - var fieldParam = "field".ToIdentifierName(); - var resultVar = "result".ToIdentifierName(); - var readerInputTypeParam = ParseTypeName("TReaderInput"); + Parameter(readerParam.Identifier).WithType(LibraryTypes.Reader.ToTypeSyntax(readerInputTypeParam)).WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))), + Parameter(fieldParam.Identifier).WithType(LibraryTypes.Field.ToTypeSyntax()) + }; - var body = new List(); - var innerBody = type.IsSealedType ? body : new List(); + return MethodDeclaration(type.TypeSyntax, ReadValueMethodName) + .AddTypeParameterListParameters(TypeParameter("TReaderInput")) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddParameterListParameters(parameters) + .AddAttributeLists(AttributeList(SingletonSeparatedList(GeneratedCodeUtilities.GetMethodImplAttributeSyntax()))) + .AddBodyStatements([.. body]); + } - if (!type.IsValueType) - { - // C#: if (field.IsReference) return ReferenceCodec.ReadReference(ref reader, field); - body.Add( - IfStatement( - fieldParam.Member("IsReference"), - ReturnStatement(InvocationExpression( - IdentifierName("ReferenceCodec").Member("ReadReference", new[] { type.TypeSyntax, readerInputTypeParam }), - ArgumentList(SeparatedList(new[] - { - Argument(readerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), - Argument(fieldParam), - }))))) - ); - } + private MemberDeclarationSyntax GenerateEnumWriteMethod( + ISerializableTypeDescription type) + { + var returnType = PredefinedType(Token(SyntaxKind.VoidKeyword)); + + var writerParam = "writer".ToIdentifierName(); + var fieldIdDeltaParam = "fieldIdDelta".ToIdentifierName(); + var expectedTypeParam = "expectedType".ToIdentifierName(); + var valueParam = "value".ToIdentifierName(); + + var body = new List(); + + // Codecs can either be static classes or injected into the constructor. + // Either way, the member signatures are the same. + var staticCodec = LibraryTypes.StaticCodecs.FindByUnderlyingType(type.BaseType); + Debug.Assert(staticCodec is not null); + var codecExpression = staticCodec!.CodecType.ToNameSyntax(); + + body.Add( + ExpressionStatement( + InvocationExpression( + codecExpression.Member("WriteField"), + ArgumentList( + SeparatedList( + [ + Argument(writerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), + Argument(fieldIdDeltaParam), + Argument(expectedTypeParam), + Argument(CastExpression(type.BaseTypeSyntax, valueParam)), + Argument(IdentifierName(CodecFieldTypeFieldName)) + ]))))); + + var parameters = new[] + { + Parameter("writer".ToIdentifier()).WithType(LibraryTypes.Writer.ToTypeSyntax()).WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))), + Parameter("fieldIdDelta".ToIdentifier()).WithType(PredefinedType(Token(SyntaxKind.UIntKeyword))), + Parameter("expectedType".ToIdentifier()).WithType(LibraryTypes.Type.ToTypeSyntax()), + Parameter("value".ToIdentifier()).WithType(type.TypeSyntax) + }; - // C#: field.EnsureWireTypeTagDelimited(); - body.Add(ExpressionStatement(InvocationExpression(fieldParam.Member("EnsureWireTypeTagDelimited")))); + return MethodDeclaration(returnType, WriteFieldMethodName) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddParameterListParameters(parameters) + .AddTypeParameterListParameters(TypeParameter("TBufferWriter")) + .AddConstraintClauses(TypeParameterConstraintClause("TBufferWriter").AddConstraints(TypeConstraint(LibraryTypes.IBufferWriter.ToTypeSyntax(PredefinedType(Token(SyntaxKind.ByteKeyword)))))) + .AddAttributeLists(AttributeList(SingletonSeparatedList(GeneratedCodeUtilities.GetMethodImplAttributeSyntax()))) + .AddBodyStatements([.. body]); + } - ExpressionSyntax createValueExpression = type.UseActivator switch - { - true => InvocationExpression(serializerFields.OfType().Single().FieldName.ToIdentifierName().Member("Create")), - false => type.GetObjectCreationExpression() - }; + private MemberDeclarationSyntax GenerateEnumReadMethod( + ISerializableTypeDescription type) + { + var readerParam = "reader".ToIdentifierName(); + var fieldParam = "field".ToIdentifierName(); + + var staticCodec = LibraryTypes.StaticCodecs.FindByUnderlyingType(type.BaseType); + Debug.Assert(staticCodec is not null); + ExpressionSyntax codecExpression = staticCodec!.CodecType.ToNameSyntax(); + ExpressionSyntax readValueExpression = InvocationExpression( + codecExpression.Member("ReadValue"), + ArgumentList(SeparatedList([Argument(readerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), Argument(fieldParam)]))); + + readValueExpression = CastExpression(type.TypeSyntax, readValueExpression); + var body = new List + { + ReturnStatement(readValueExpression) + }; - // C#: var result = _activator.Create(); - // or C#: var result = new TField(); - // or C#: var result = default(TField); - innerBody.Add(LocalDeclarationStatement( - VariableDeclaration( - IdentifierName("var"), - SingletonSeparatedList(VariableDeclarator(resultVar.Identifier) - .WithInitializer(EqualsValueClause(createValueExpression)))))); + var genericParam = ParseTypeName("TReaderInput"); + var parameters = new[] + { + Parameter(readerParam.Identifier).WithType(LibraryTypes.Reader.ToTypeSyntax(genericParam)).WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))), + Parameter(fieldParam.Identifier).WithType(LibraryTypes.Field.ToTypeSyntax()) + }; - if (type.TrackReferences) - { - // C#: ReferenceCodec.RecordObject(reader.Session, result); - innerBody.Add(ExpressionStatement(InvocationExpression(IdentifierName("ReferenceCodec").Member("RecordObject"), ArgumentList(SeparatedList(new[] { Argument(readerParam.Member("Session")), Argument(resultVar) }))))); - } - else - { - // C#: ReferenceCodec.MarkValueField(reader.Session); - innerBody.Add(ExpressionStatement(InvocationExpression(IdentifierName("ReferenceCodec").Member("MarkValueField"), ArgumentList(SingletonSeparatedList(Argument(readerParam.Member("Session"))))))); - } + return MethodDeclaration(type.TypeSyntax, ReadValueMethodName) + .AddTypeParameterListParameters(TypeParameter("TReaderInput")) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddParameterListParameters(parameters) + .AddAttributeLists(AttributeList(SingletonSeparatedList(GeneratedCodeUtilities.GetMethodImplAttributeSyntax()))) + .AddBodyStatements([.. body]); + } - // C#: this.Deserializer(ref reader, [ref] result); - var resultArgument = type.IsValueType switch - { - true => Argument(resultVar).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), - false => Argument(resultVar) - }; - innerBody.Add( - ExpressionStatement( - InvocationExpression( - IdentifierName(DeserializeMethodName), - ArgumentList( - SeparatedList( - new[] - { - Argument(readerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), - resultArgument - }))))); - - innerBody.Add(ReturnStatement(resultVar)); - - if (!type.IsSealedType) - { - // C#: var fieldType = field.FieldType; - var valueTypeField = "valueType".ToIdentifierName(); - body.Add( - LocalDeclarationStatement( - VariableDeclaration( - LibraryTypes.Type.ToTypeSyntax(), - SingletonSeparatedList(VariableDeclarator(valueTypeField.Identifier) - .WithInitializer(EqualsValueClause(fieldParam.Member("FieldType"))))))); - body.Add( - IfStatement( - BinaryExpression(SyntaxKind.LogicalOrExpression, - IsPatternExpression(valueTypeField, ConstantPattern(LiteralExpression(SyntaxKind.NullLiteralExpression))), - BinaryExpression(SyntaxKind.EqualsExpression, valueTypeField, IdentifierName(CodecFieldTypeFieldName))), - Block(innerBody))); - - body.Add(ReturnStatement( - InvocationExpression( - readerParam.Member("DeserializeUnexpectedType", new[] { readerInputTypeParam, type.TypeSyntax }), - ArgumentList( - SingletonSeparatedList(Argument(null, Token(SyntaxKind.RefKeyword), fieldParam)))))); - } + internal abstract class GeneratedFieldDescription(TypeSyntax fieldType, string fieldName) + { + public readonly TypeSyntax FieldType = fieldType; + public readonly string FieldName = fieldName; + public abstract bool IsInjected { get; } + } - var parameters = new[] - { - Parameter(readerParam.Identifier).WithType(LibraryTypes.Reader.ToTypeSyntax(readerInputTypeParam)).WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))), - Parameter(fieldParam.Identifier).WithType(LibraryTypes.Field.ToTypeSyntax()) - }; + internal sealed class BaseCodecFieldDescription(TypeSyntax fieldType, bool concreteType = false) : GeneratedFieldDescription(fieldType, BaseTypeSerializerFieldName) + { + public override bool IsInjected { get; } = !concreteType; + } - return MethodDeclaration(type.TypeSyntax, ReadValueMethodName) - .AddTypeParameterListParameters(TypeParameter("TReaderInput")) - .AddModifiers(Token(SyntaxKind.PublicKeyword)) - .AddParameterListParameters(parameters) - .AddAttributeLists(AttributeList(SingletonSeparatedList(CodeGenerator.GetMethodImplAttributeSyntax()))) - .AddBodyStatements(body.ToArray()); - } + internal sealed class ActivatorFieldDescription(TypeSyntax fieldType, string fieldName) : GeneratedFieldDescription(fieldType, fieldName) + { + public override bool IsInjected => true; + } - private MemberDeclarationSyntax GenerateEnumWriteMethod( - ISerializableTypeDescription type) - { - var returnType = PredefinedType(Token(SyntaxKind.VoidKeyword)); + internal sealed class CodecFieldDescription(TypeSyntax fieldType, string fieldName, ITypeSymbol underlyingType) : GeneratedFieldDescription(fieldType, fieldName) + { + public ITypeSymbol UnderlyingType { get; } = underlyingType; + public override bool IsInjected => false; + } - var writerParam = "writer".ToIdentifierName(); - var fieldIdDeltaParam = "fieldIdDelta".ToIdentifierName(); - var expectedTypeParam = "expectedType".ToIdentifierName(); - var valueParam = "value".ToIdentifierName(); + internal sealed class TypeFieldDescription(TypeSyntax fieldType, string fieldName, TypeSyntax underlyingTypeSyntax, ITypeSymbol underlyingType) : GeneratedFieldDescription(fieldType, fieldName) + { + public TypeSyntax UnderlyingTypeSyntax { get; } = underlyingTypeSyntax; + public ITypeSymbol UnderlyingType { get; } = underlyingType; + public override bool IsInjected => false; + } - var body = new List(); + internal sealed class CodecFieldTypeFieldDescription(TypeSyntax fieldType, string fieldName, TypeSyntax codecFieldType) : GeneratedFieldDescription(fieldType, fieldName) + { + public TypeSyntax CodecFieldType { get; } = codecFieldType; + public override bool IsInjected => false; + } - // Codecs can either be static classes or injected into the constructor. - // Either way, the member signatures are the same. - var staticCodec = LibraryTypes.StaticCodecs.FindByUnderlyingType(type.BaseType); - var codecExpression = staticCodec.CodecType.ToNameSyntax(); + internal sealed class FieldAccessorDescription(TypeSyntax containingType, TypeSyntax fieldType, string fieldName, string accessorName, ExpressionSyntax? initializationSyntax = null) : GeneratedFieldDescription(fieldType, fieldName) + { + public override bool IsInjected => false; + public readonly string AccessorName = accessorName; + public readonly TypeSyntax ContainingType = containingType; + public readonly ExpressionSyntax? InitializationSyntax = initializationSyntax; + } - body.Add( - ExpressionStatement( - InvocationExpression( - codecExpression.Member("WriteField"), - ArgumentList( - SeparatedList( - new[] - { - Argument(writerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), - Argument(fieldIdDeltaParam), - Argument(expectedTypeParam), - Argument(CastExpression(type.BaseTypeSyntax, valueParam)), - Argument(IdentifierName(CodecFieldTypeFieldName)) - }))))); + internal sealed class SerializationHookFieldDescription(TypeSyntax fieldType, string fieldName) : GeneratedFieldDescription(fieldType, fieldName) + { + public override bool IsInjected => true; + } - var parameters = new[] - { - Parameter("writer".ToIdentifier()).WithType(LibraryTypes.Writer.ToTypeSyntax()).WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))), - Parameter("fieldIdDelta".ToIdentifier()).WithType(PredefinedType(Token(SyntaxKind.UIntKeyword))), - Parameter("expectedType".ToIdentifier()).WithType(LibraryTypes.Type.ToTypeSyntax()), - Parameter("value".ToIdentifier()).WithType(type.TypeSyntax) - }; + internal interface ISerializableMember + { + bool IsShallowCopyable { get; } + bool IsValueType { get; } + bool IsPrimaryConstructorParameter { get; } - return MethodDeclaration(returnType, WriteFieldMethodName) - .AddModifiers(Token(SyntaxKind.PublicKeyword)) - .AddParameterListParameters(parameters) - .AddTypeParameterListParameters(TypeParameter("TBufferWriter")) - .AddConstraintClauses(TypeParameterConstraintClause("TBufferWriter").AddConstraints(TypeConstraint(LibraryTypes.IBufferWriter.ToTypeSyntax(PredefinedType(Token(SyntaxKind.ByteKeyword)))))) - .AddAttributeLists(AttributeList(SingletonSeparatedList(CodeGenerator.GetMethodImplAttributeSyntax()))) - .AddBodyStatements(body.ToArray()); - } + IMemberDescription Member { get; } - private MemberDeclarationSyntax GenerateEnumReadMethod( - ISerializableTypeDescription type) - { - var readerParam = "reader".ToIdentifierName(); - var fieldParam = "field".ToIdentifierName(); + /// + /// Gets syntax representing the type of this field. + /// + TypeSyntax TypeSyntax { get; } - var staticCodec = LibraryTypes.StaticCodecs.FindByUnderlyingType(type.BaseType); - ExpressionSyntax codecExpression = staticCodec.CodecType.ToNameSyntax(); - ExpressionSyntax readValueExpression = InvocationExpression( - codecExpression.Member("ReadValue"), - ArgumentList(SeparatedList(new[] { Argument(readerParam).WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)), Argument(fieldParam) }))); + /// + /// Returns syntax for retrieving the value of this field, deep copying it if necessary. + /// + /// The instance of the containing type. + /// Syntax for retrieving the value of this field. + ExpressionSyntax GetGetter(ExpressionSyntax instance); - readValueExpression = CastExpression(type.TypeSyntax, readValueExpression); - var body = new List - { - ReturnStatement(readValueExpression) - }; + /// + /// Returns syntax for setting the value of this field. + /// + /// The instance of the containing type. + /// Syntax for the new value. + /// Syntax for setting the value of this field. + ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax value); - var genericParam = ParseTypeName("TReaderInput"); - var parameters = new[] - { - Parameter(readerParam.Identifier).WithType(LibraryTypes.Reader.ToTypeSyntax(genericParam)).WithModifiers(TokenList(Token(SyntaxKind.RefKeyword))), - Parameter(fieldParam.Identifier).WithType(LibraryTypes.Field.ToTypeSyntax()) - }; + FieldAccessorDescription? GetGetterFieldDescription(); + FieldAccessorDescription? GetSetterFieldDescription(); + } - return MethodDeclaration(type.TypeSyntax, ReadValueMethodName) - .AddTypeParameterListParameters(TypeParameter("TReaderInput")) - .AddModifiers(Token(SyntaxKind.PublicKeyword)) - .AddParameterListParameters(parameters) - .AddAttributeLists(AttributeList(SingletonSeparatedList(CodeGenerator.GetMethodImplAttributeSyntax()))) - .AddBodyStatements(body.ToArray()); - } + /// + /// Represents a serializable member (field/property) of a type. + /// + internal class SerializableMethodMember(InvokableGenerator.MethodParameterFieldDescription member) : ISerializableMember + { + IMemberDescription ISerializableMember.Member => Member; + public MethodParameterFieldDescription Member { get; } = member; - internal abstract class GeneratedFieldDescription - { - protected GeneratedFieldDescription(TypeSyntax fieldType, string fieldName) - { - FieldType = fieldType; - FieldName = fieldName; - } + private LibraryTypes LibraryTypes => Member.LibraryTypes; - public readonly TypeSyntax FieldType; - public readonly string FieldName; - public abstract bool IsInjected { get; } - } + public bool IsShallowCopyable => LibraryTypes.IsShallowCopyable(Member.Parameter.Type) || Member.Parameter.HasAttribute(LibraryTypes.ImmutableAttribute); - internal sealed class BaseCodecFieldDescription : GeneratedFieldDescription - { - public BaseCodecFieldDescription(TypeSyntax fieldType, bool concreteType = false) : base(fieldType, BaseTypeSerializerFieldName) - => IsInjected = !concreteType; + /// + /// Gets syntax representing the type of this field. + /// + public TypeSyntax TypeSyntax => Member.TypeSyntax; - public override bool IsInjected { get; } - } + public bool IsValueType => Member.Type.IsValueType; - internal sealed class ActivatorFieldDescription : GeneratedFieldDescription - { - public ActivatorFieldDescription(TypeSyntax fieldType, string fieldName) : base(fieldType, fieldName) - { - } + public bool IsPrimaryConstructorParameter => Member.IsPrimaryConstructorParameter; - public override bool IsInjected => true; - } + /// + /// Returns syntax for retrieving the value of this field, deep copying it if necessary. + /// + /// The instance of the containing type. + /// Syntax for retrieving the value of this field. + public ExpressionSyntax GetGetter(ExpressionSyntax instance) => instance.Member(Member.FieldName); - internal sealed class CodecFieldDescription : GeneratedFieldDescription - { - public CodecFieldDescription(TypeSyntax fieldType, string fieldName, ITypeSymbol underlyingType) : base(fieldType, fieldName) - { - UnderlyingType = underlyingType; - } + /// + /// Returns syntax for setting the value of this field. + /// + /// The instance of the containing type. + /// Syntax for the new value. + /// Syntax for setting the value of this field. + public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax value) => AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + instance.Member(Member.FieldName), + value); + + public FieldAccessorDescription? GetGetterFieldDescription() => null; + public FieldAccessorDescription? GetSetterFieldDescription() => null; + } - public ITypeSymbol UnderlyingType { get; } - public override bool IsInjected => false; - } + /// + /// Represents a serializable member (field/property) of a type. + /// + internal class SerializableMember(IGeneratorServices generatorServices, IMemberDescription member, int ordinal) : ISerializableMember + { + private readonly IGeneratorServices _generatorServices = generatorServices; - internal sealed class TypeFieldDescription : GeneratedFieldDescription - { - public TypeFieldDescription(TypeSyntax fieldType, string fieldName, TypeSyntax underlyingTypeSyntax, ITypeSymbol underlyingType) : base(fieldType, fieldName) - { - UnderlyingType = underlyingType; - UnderlyingTypeSyntax = underlyingTypeSyntax; - } + /// + /// The ordinal assigned to this field. + /// + private readonly int _ordinal = ordinal; - public TypeSyntax UnderlyingTypeSyntax { get; } - public ITypeSymbol UnderlyingType { get; } - public override bool IsInjected => false; - } + private Compilation Compilation => _generatorServices.Compilation; + private LibraryTypes LibraryTypes => _generatorServices.LibraryTypes; - internal sealed class CodecFieldTypeFieldDescription : GeneratedFieldDescription - { - public CodecFieldTypeFieldDescription(TypeSyntax fieldType, string fieldName, TypeSyntax codecFieldType) : base(fieldType, fieldName) - { - CodecFieldType = codecFieldType; - } + public bool IsShallowCopyable => + LibraryTypes.IsShallowCopyable(Member.Type) + || Property is { } prop && prop.HasAttribute(LibraryTypes.ImmutableAttribute) + || Member.Symbol.HasAttribute(LibraryTypes.ImmutableAttribute); - public TypeSyntax CodecFieldType { get; } - public override bool IsInjected => false; - } + public bool IsValueType => Type.IsValueType; - internal sealed class FieldAccessorDescription : GeneratedFieldDescription - { - public FieldAccessorDescription(TypeSyntax containingType, TypeSyntax fieldType, string fieldName, string accessorName, ExpressionSyntax initializationSyntax = null) : base(fieldType, fieldName) - { - ContainingType = containingType; - AccessorName = accessorName; - InitializationSyntax = initializationSyntax; - } + public IMemberDescription Member { get; } = member; - public override bool IsInjected => false; - public readonly string AccessorName; - public readonly TypeSyntax ContainingType; - public readonly ExpressionSyntax InitializationSyntax; - } + /// + /// Gets the underlying instance. + /// + private IFieldSymbol? Field => (Member as IFieldDescription)?.Field; - internal sealed class SerializationHookFieldDescription : GeneratedFieldDescription - { - public SerializationHookFieldDescription(TypeSyntax fieldType, string fieldName) : base(fieldType, fieldName) - { - } + public ITypeSymbol Type => Member.Type; - public override bool IsInjected => true; - } + public INamedTypeSymbol ContainingType => Member.ContainingType; - internal interface ISerializableMember - { - bool IsShallowCopyable { get; } - bool IsValueType { get; } - bool IsPrimaryConstructorParameter { get; } - - IMemberDescription Member { get; } - - /// - /// Gets syntax representing the type of this field. - /// - TypeSyntax TypeSyntax { get; } - - /// - /// Returns syntax for retrieving the value of this field, deep copying it if necessary. - /// - /// The instance of the containing type. - /// Syntax for retrieving the value of this field. - ExpressionSyntax GetGetter(ExpressionSyntax instance); - - /// - /// Returns syntax for setting the value of this field. - /// - /// The instance of the containing type. - /// Syntax for the new value. - /// Syntax for setting the value of this field. - ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax value); - - FieldAccessorDescription GetGetterFieldDescription(); - FieldAccessorDescription GetSetterFieldDescription(); - } + public string MemberName => Field?.Name ?? Property?.Name ?? Member.Symbol.Name; /// - /// Represents a serializable member (field/property) of a type. + /// Gets the name of the getter field. /// - internal class SerializableMethodMember : ISerializableMember - { - private readonly MethodParameterFieldDescription _member; + private string GetterFieldName => $"getField{_ordinal}"; - public SerializableMethodMember(MethodParameterFieldDescription member) - { - _member = member; - } + /// + /// Gets the name of the setter field. + /// + private string SetterFieldName => $"setField{_ordinal}"; - IMemberDescription ISerializableMember.Member => _member; - public MethodParameterFieldDescription Member => _member; + /// + /// Gets a value indicating if the member is a property. + /// + private bool IsProperty => Member.Symbol is IPropertySymbol; - private LibraryTypes LibraryTypes => _member.CodeGenerator.LibraryTypes; + /// + /// Gets a value indicating whether or not this member represents an accessible field. + /// + private bool IsGettableField => Field is { } fieldInfo && _generatorServices.Compilation.IsSymbolAccessibleWithin(fieldInfo, Compilation.Assembly) && !IsObsolete; - public bool IsShallowCopyable => LibraryTypes.IsShallowCopyable(_member.Parameter.Type) || _member.Parameter.HasAttribute(LibraryTypes.ImmutableAttribute); + /// + /// Gets a value indicating whether or not this member represents an accessible, mutable field. + /// + private bool IsSettableField => Field is { } fieldInfo && IsGettableField && !fieldInfo.IsReadOnly; - /// - /// Gets syntax representing the type of this field. - /// - public TypeSyntax TypeSyntax => _member.TypeSyntax; + /// + /// Gets a value indicating whether or not this member represents a property with an accessible, non-obsolete getter. + /// + private bool IsGettableProperty => Property?.GetMethod is { } getMethod && Compilation.IsSymbolAccessibleWithin(getMethod, Compilation.Assembly) && !IsObsolete; - public bool IsValueType => _member.Type.IsValueType; + /// + /// Gets a value indicating whether or not this member represents a property with an accessible, non-obsolete setter. + /// + private bool IsSettableProperty => Property?.SetMethod is { } setMethod && Compilation.IsSymbolAccessibleWithin(setMethod, Compilation.Assembly) && !setMethod.IsInitOnly && !IsObsolete; - public bool IsPrimaryConstructorParameter => _member.IsPrimaryConstructorParameter; + /// + /// Gets syntax representing the type of this field. + /// + public TypeSyntax TypeSyntax => Member.Type.TypeKind == TypeKind.Dynamic + ? PredefinedType(Token(SyntaxKind.ObjectKeyword)) + : Member.GetTypeSyntax(Member.Type); - /// - /// Returns syntax for retrieving the value of this field, deep copying it if necessary. - /// - /// The instance of the containing type. - /// Syntax for retrieving the value of this field. - public ExpressionSyntax GetGetter(ExpressionSyntax instance) => instance.Member(_member.FieldName); + /// + /// Gets the which this field is the backing property for, or + /// if this is not the backing field of an auto-property. + /// + private IPropertySymbol? Property => field ??= field = Member.Symbol as IPropertySymbol ?? (Field is { } fieldSymbol ? PropertyUtility.GetMatchingProperty(fieldSymbol) : null); - /// - /// Returns syntax for setting the value of this field. - /// - /// The instance of the containing type. - /// Syntax for the new value. - /// Syntax for setting the value of this field. - public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax value) => AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - instance.Member(_member.FieldName), - value); + /// + /// Gets a value indicating whether or not this field is obsolete. + /// + private bool IsObsolete => Member.Symbol.HasAttribute(LibraryTypes.ObsoleteAttribute) || + Property != null && Property.HasAttribute(LibraryTypes.ObsoleteAttribute); - public FieldAccessorDescription GetGetterFieldDescription() => null; - public FieldAccessorDescription GetSetterFieldDescription() => null; - } + public bool IsPrimaryConstructorParameter => Member.IsPrimaryConstructorParameter; /// - /// Represents a serializable member (field/property) of a type. + /// Returns syntax for retrieving the value of this field, deep copying it if necessary. /// - internal class SerializableMember : ISerializableMember + /// The instance of the containing type. + /// Syntax for retrieving the value of this field. + public ExpressionSyntax GetGetter(ExpressionSyntax instance) { - private readonly IMemberDescription _member; - private readonly CodeGenerator _codeGenerator; - private IPropertySymbol _property; - - /// - /// The ordinal assigned to this field. - /// - private readonly int _ordinal; - - public SerializableMember(CodeGenerator codeGenerator, IMemberDescription member, int ordinal) + // If the field is the backing field for an accessible auto-property use the property directly. + ExpressionSyntax result; + if (Property is { } property && IsGettableProperty) { - _codeGenerator = codeGenerator; - _ordinal = ordinal; - _member = member; + result = instance.Member(property.Name); } - - private Compilation Compilation => _codeGenerator.Compilation; - private LibraryTypes LibraryTypes => _codeGenerator.LibraryTypes; - - public bool IsShallowCopyable => - LibraryTypes.IsShallowCopyable(_member.Type) - || Property is { } prop && prop.HasAttribute(LibraryTypes.ImmutableAttribute) - || _member.Symbol.HasAttribute(LibraryTypes.ImmutableAttribute); - - public bool IsValueType => Type.IsValueType; - - public IMemberDescription Member => _member; - - /// - /// Gets the underlying instance. - /// - private IFieldSymbol Field => (_member as IFieldDescription)?.Field; - - public ITypeSymbol Type => _member.Type; - - public INamedTypeSymbol ContainingType => _member.ContainingType; - - public string MemberName => Field?.Name ?? Property?.Name; - - /// - /// Gets the name of the getter field. - /// - private string GetterFieldName => $"getField{_ordinal}"; - - /// - /// Gets the name of the setter field. - /// - private string SetterFieldName => $"setField{_ordinal}"; - - /// - /// Gets a value indicating if the member is a property. - /// - private bool IsProperty => Member.Symbol is IPropertySymbol; - - /// - /// Gets a value indicating whether or not this member represents an accessible field. - /// - private bool IsGettableField => Field is { } fieldInfo && _codeGenerator.Compilation.IsSymbolAccessibleWithin(fieldInfo, Compilation.Assembly) && !IsObsolete; - - /// - /// Gets a value indicating whether or not this member represents an accessible, mutable field. - /// - private bool IsSettableField => Field is { } fieldInfo && IsGettableField && !fieldInfo.IsReadOnly; - - /// - /// Gets a value indicating whether or not this member represents a property with an accessible, non-obsolete getter. - /// - private bool IsGettableProperty => Property?.GetMethod is { } getMethod && Compilation.IsSymbolAccessibleWithin(getMethod, Compilation.Assembly) && !IsObsolete; - - /// - /// Gets a value indicating whether or not this member represents a property with an accessible, non-obsolete setter. - /// - private bool IsSettableProperty => Property?.SetMethod is { } setMethod && Compilation.IsSymbolAccessibleWithin(setMethod, Compilation.Assembly) && !setMethod.IsInitOnly && !IsObsolete; - - /// - /// Gets syntax representing the type of this field. - /// - public TypeSyntax TypeSyntax => Member.Type.TypeKind == TypeKind.Dynamic - ? PredefinedType(Token(SyntaxKind.ObjectKeyword)) - : _member.GetTypeSyntax(Member.Type); - - /// - /// Gets the which this field is the backing property for, or - /// if this is not the backing field of an auto-property. - /// - private IPropertySymbol Property => _property ??= _property = Member.Symbol as IPropertySymbol ?? PropertyUtility.GetMatchingProperty(Field); - - /// - /// Gets a value indicating whether or not this field is obsolete. - /// - private bool IsObsolete => Member.Symbol.HasAttribute(LibraryTypes.ObsoleteAttribute) || - Property != null && Property.HasAttribute(LibraryTypes.ObsoleteAttribute); - - public bool IsPrimaryConstructorParameter => _member.IsPrimaryConstructorParameter; - - /// - /// Returns syntax for retrieving the value of this field, deep copying it if necessary. - /// - /// The instance of the containing type. - /// Syntax for retrieving the value of this field. - public ExpressionSyntax GetGetter(ExpressionSyntax instance) + else if (Field is { } field && IsGettableField) { - // If the field is the backing field for an accessible auto-property use the property directly. - ExpressionSyntax result; - if (IsGettableProperty) - { - result = instance.Member(Property.Name); - } - else if (IsGettableField) + result = instance.Member(field.Name); + } + else + { + + var instanceArg = Argument(instance); + if (ContainingType?.IsValueType == true) { - result = instance.Member(Field.Name); + instanceArg = instanceArg.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); } - else - { - var instanceArg = Argument(instance); - if (ContainingType?.IsValueType == true) - { - instanceArg = instanceArg.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); - } + // Retrieve the field using the generated getter. + result = + InvocationExpression(IdentifierName(GetterFieldName)) + .AddArgumentListArguments(instanceArg); + } - // Retrieve the field using the generated getter. - result = - InvocationExpression(IdentifierName(GetterFieldName)) - .AddArgumentListArguments(instanceArg); - } + return result; + } - return result; + /// + /// Returns syntax for setting the value of this field. + /// + /// The instance of the containing type. + /// Syntax for the new value. + /// Syntax for setting the value of this field. + public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax value) + { + // If the field is the backing field for an accessible auto-property use the property directly. + if (Property is { } property && IsSettableProperty) + { + return AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + instance.Member(property.Name), + value); } - /// - /// Returns syntax for setting the value of this field. - /// - /// The instance of the containing type. - /// Syntax for the new value. - /// Syntax for setting the value of this field. - public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax value) + if (Field is { } field && IsSettableField) { - // If the field is the backing field for an accessible auto-property use the property directly. - if (IsSettableProperty) - { - return AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - instance.Member(Property.Name), - value); - } - - if (IsSettableField) - { - return AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - instance.Member(Field.Name), - value); - } + return AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + instance.Member(field.Name), + value); + } - // If the symbol itself is a property but is not settable, then error out, since we do not know how to set it value - if (IsProperty && !IsPrimaryConstructorParameter) + // If the symbol itself is a property but is not settable, then error out, since we do not know how to set it value + if (IsProperty && !IsPrimaryConstructorParameter) + { + Location? location = default; + if (Member.Symbol is IPropertySymbol prop && prop.SetMethod is { } setMethod) { - Location location = default; - if (Member.Symbol is IPropertySymbol prop && prop.SetMethod is { } setMethod) - { - location = setMethod.Locations.FirstOrDefault(); - } - - location ??= Member.Symbol.Locations.FirstOrDefault(); - - throw new OrleansGeneratorDiagnosticAnalysisException(InaccessibleSetterDiagnostic.CreateDiagnostic(location, Member.Symbol?.ToDisplayString() ?? $"{ContainingType.ToDisplayString()}.{MemberName}")); + location = setMethod.Locations.FirstOrDefault(); } - var instanceArg = Argument(instance); - if (ContainingType?.IsValueType == true) - { - instanceArg = instanceArg.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); - } + location ??= Member.Symbol.Locations.FirstOrDefault(); - return - InvocationExpression(IdentifierName(SetterFieldName)) - .AddArgumentListArguments(instanceArg, Argument(value)); + throw new OrleansGeneratorDiagnosticAnalysisException(InaccessibleSetterDiagnostic.CreateDiagnostic(location, Member.Symbol?.ToDisplayString() ?? $"{ContainingType.ToDisplayString()}.{MemberName}")); } - public FieldAccessorDescription GetGetterFieldDescription() + var instanceArg = Argument(instance); + if (ContainingType?.IsValueType == true) { - if (IsGettableField || IsGettableProperty) return null; - return GetFieldAccessor(ContainingType, TypeSyntax, MemberName, GetterFieldName, LibraryTypes, false, - IsPrimaryConstructorParameter && IsProperty); + instanceArg = instanceArg.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); } - public FieldAccessorDescription GetSetterFieldDescription() - { - if (IsSettableField || IsSettableProperty) return null; - return GetFieldAccessor(ContainingType, TypeSyntax, MemberName, SetterFieldName, LibraryTypes, true, - IsPrimaryConstructorParameter && IsProperty); - } + return + InvocationExpression(IdentifierName(SetterFieldName)) + .AddArgumentListArguments(instanceArg, Argument(value)); + } - public static FieldAccessorDescription GetFieldAccessor(INamedTypeSymbol containingType, TypeSyntax fieldType, string fieldName, string accessorName, LibraryTypes library, bool setter, bool useUnsafeAccessor = false) - { - var containingTypeSyntax = containingType.ToTypeSyntax(); + public FieldAccessorDescription? GetGetterFieldDescription() + { + if (IsGettableField || IsGettableProperty) return null; + return GetFieldAccessor(ContainingType, TypeSyntax, MemberName, GetterFieldName, LibraryTypes, false, + IsPrimaryConstructorParameter && IsProperty); + } - if (useUnsafeAccessor) - return new(containingTypeSyntax, fieldType, fieldName, accessorName); + public FieldAccessorDescription? GetSetterFieldDescription() + { + if (IsSettableField || IsSettableProperty) return null; + return GetFieldAccessor(ContainingType, TypeSyntax, MemberName, SetterFieldName, LibraryTypes, true, + IsPrimaryConstructorParameter && IsProperty); + } - var valueType = containingType.IsValueType; + public static FieldAccessorDescription GetFieldAccessor(INamedTypeSymbol containingType, TypeSyntax fieldType, string fieldName, string accessorName, LibraryTypes library, bool setter, bool useUnsafeAccessor = false) + { + var containingTypeSyntax = containingType.ToTypeSyntax(); - var delegateType = (setter ? (valueType ? library.ValueTypeSetter_2 : library.Action_2) : (valueType ? library.ValueTypeGetter_2 : library.Func_2)) - .ToTypeSyntax(containingTypeSyntax, fieldType); + if (useUnsafeAccessor) + return new(containingTypeSyntax, fieldType, fieldName, accessorName); - // Generate syntax to initialize the field in the constructor - var fieldAccessorUtility = AliasQualifiedName("global", IdentifierName("Orleans.Serialization")).Member("Utilities").Member("FieldAccessor"); - var accessorMethod = setter ? (valueType ? "GetValueSetter" : "GetReferenceSetter") : (valueType ? "GetValueGetter" : "GetGetter"); - var accessorInvoke = CastExpression(delegateType, - InvocationExpression(fieldAccessorUtility.Member(accessorMethod)) - .AddArgumentListArguments(Argument(TypeOfExpression(containingTypeSyntax)), Argument(fieldName.GetLiteralExpression()))); + var valueType = containingType.IsValueType; - // Existing case, accessor is the field in both cases - return new(containingTypeSyntax, delegateType, accessorName, accessorName, accessorInvoke); - } + var delegateType = (setter ? (valueType ? library.ValueTypeSetter_2 : library.Action_2) : (valueType ? library.ValueTypeGetter_2 : library.Func_2)) + .ToTypeSyntax(containingTypeSyntax, fieldType); + + // Generate syntax to initialize the field in the constructor + var fieldAccessorUtility = AliasQualifiedName("global", IdentifierName("Orleans.Serialization")).Member("Utilities").Member("FieldAccessor"); + var accessorMethod = setter ? (valueType ? "GetValueSetter" : "GetReferenceSetter") : (valueType ? "GetValueGetter" : "GetGetter"); + var accessorInvoke = CastExpression(delegateType, + InvocationExpression(fieldAccessorUtility.Member(accessorMethod)) + .AddArgumentListArguments(Argument(TypeOfExpression(containingTypeSyntax)), Argument(fieldName.GetLiteralExpression()))); + + // Existing case, accessor is the field in both cases + return new(containingTypeSyntax, delegateType, accessorName, accessorName, accessorInvoke); } } } diff --git a/src/Orleans.CodeGenerator/SourceGeneratorOptionsParser.cs b/src/Orleans.CodeGenerator/SourceGeneratorOptionsParser.cs new file mode 100644 index 00000000000..7348db1d0aa --- /dev/null +++ b/src/Orleans.CodeGenerator/SourceGeneratorOptionsParser.cs @@ -0,0 +1,83 @@ +using System.Diagnostics; +using Microsoft.CodeAnalysis.Diagnostics; +using Orleans.CodeGenerator.Model; + +namespace Orleans.CodeGenerator; + +internal static class SourceGeneratorOptionsParser +{ + private static int _debuggerLaunchState; + + internal static CodeGeneratorOptions CreateCodeGeneratorOptions(SourceGeneratorOptions options) + { + return new CodeGeneratorOptions + { + GenerateFieldIds = options.GenerateFieldIds, + GenerateCompatibilityInvokers = options.GenerateCompatibilityInvokers, + }; + } + + internal static void AttachDebuggerIfRequested(SourceGeneratorOptions options) + { + if (!options.AttachDebugger || Debugger.IsAttached) + { + return; + } + + if (Interlocked.Exchange(ref _debuggerLaunchState, 1) == 0) + { + Debugger.Launch(); + } + } + + internal static SourceGeneratorOptions ParseOptions(AnalyzerConfigOptions globalOptions) + { + var result = new SourceGeneratorOptions(); + + if (globalOptions.TryGetValue("build_property.orleans_attachdebugger", out var attachDebuggerOption) + && string.Equals("true", attachDebuggerOption, StringComparison.OrdinalIgnoreCase)) + { + result.AttachDebugger = true; + } + + if (globalOptions.TryGetValue("build_property.orleans_generatefieldids", out var generateFieldIds) && generateFieldIds is { Length: > 0 } + && Enum.TryParse(generateFieldIds, out GenerateFieldIds fieldIdOption)) + { + result.GenerateFieldIds = fieldIdOption; + } + + if (globalOptions.TryGetValue("build_property.orleansgeneratecompatibilityinvokers", out var generateCompatInvokersValue) + && bool.TryParse(generateCompatInvokersValue, out var genCompatInvokers)) + { + result.GenerateCompatibilityInvokers = genCompatInvokers; + } + + return result; + } + +} + +internal struct SourceGeneratorOptions : IEquatable +{ + public GenerateFieldIds GenerateFieldIds { get; set; } + public bool GenerateCompatibilityInvokers { get; set; } + public bool AttachDebugger { get; set; } + + public readonly bool Equals(SourceGeneratorOptions other) + => GenerateFieldIds == other.GenerateFieldIds + && GenerateCompatibilityInvokers == other.GenerateCompatibilityInvokers + && AttachDebugger == other.AttachDebugger; + + public override readonly bool Equals(object obj) => obj is SourceGeneratorOptions other && Equals(other); + + public override readonly int GetHashCode() + { + unchecked + { + var hash = (int)GenerateFieldIds; + hash = hash * 31 + (GenerateCompatibilityInvokers ? 1 : 0); + hash = hash * 31 + (AttachDebugger ? 1 : 0); + return hash; + } + } +} diff --git a/src/Orleans.CodeGenerator/SourceGeneratorResults.cs b/src/Orleans.CodeGenerator/SourceGeneratorResults.cs new file mode 100644 index 00000000000..2b12f3e0060 --- /dev/null +++ b/src/Orleans.CodeGenerator/SourceGeneratorResults.cs @@ -0,0 +1,247 @@ +using System.Collections.Immutable; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; +using Orleans.CodeGenerator.Model; + +namespace Orleans.CodeGenerator; + +internal readonly struct GeneratedSourceEntry(string hintName, string source) : IEquatable +{ + public string HintName { get; } = hintName; + public string Source { get; } = source; + public SourceText SourceText => SourceText.From(Source ?? string.Empty, Encoding.UTF8); + + public bool Equals(GeneratedSourceEntry other) + => string.Equals(HintName, other.HintName, StringComparison.Ordinal) + && string.Equals(Source, other.Source, StringComparison.Ordinal); + + public override bool Equals(object obj) => obj is GeneratedSourceEntry other && Equals(other); + + public override int GetHashCode() + { + unchecked + { + var hash = StringComparer.Ordinal.GetHashCode(HintName ?? string.Empty); + hash = hash * 31 + StringComparer.Ordinal.GetHashCode(Source ?? string.Empty); + return hash; + } + } +} + +internal readonly struct SourceOutputResult(GeneratedSourceEntry? sourceEntry, Diagnostic? diagnostic) : IEquatable +{ + public GeneratedSourceEntry? SourceEntry { get; } = sourceEntry; + public Diagnostic? Diagnostic { get; } = diagnostic; + + public static SourceOutputResult FromSource(GeneratedSourceEntry sourceEntry) => new(sourceEntry, null); + public static SourceOutputResult FromDiagnostic(Diagnostic diagnostic) => new(null, diagnostic); + + public bool Equals(SourceOutputResult other) + => Nullable.Equals(SourceEntry, other.SourceEntry) + && SourceGeneratorDiagnosticComparer.AreEqual(Diagnostic, other.Diagnostic); + + public override bool Equals(object? obj) => obj is SourceOutputResult other && Equals(other); + + public override int GetHashCode() + { + unchecked + { + var hash = SourceEntry.GetHashCode(); + hash = hash * 31 + SourceGeneratorDiagnosticComparer.GetHashCode(Diagnostic); + return hash; + } + } +} + +internal readonly struct ReferenceAssemblyDataResult(ReferenceAssemblyModel model, ImmutableArray diagnostics) : IEquatable +{ + public ReferenceAssemblyModel Model { get; } = model; + public ImmutableArray Diagnostics { get; } = diagnostics.IsDefault ? [] : diagnostics; + + public static ReferenceAssemblyDataResult FromModelAndDiagnostics(ReferenceAssemblyModel model, ImmutableArray diagnostics) + => new(model, diagnostics); + + public bool Equals(ReferenceAssemblyDataResult other) + => EqualityComparer.Default.Equals(Model, other.Model) + && SourceGeneratorDiagnosticComparer.AreSequencesEqual(Diagnostics, other.Diagnostics); + + public override bool Equals(object? obj) => obj is ReferenceAssemblyDataResult other && Equals(other); + + public override int GetHashCode() + { + unchecked + { + var hash = Model?.GetHashCode() ?? 0; + hash = hash * 31 + SourceGeneratorDiagnosticComparer.GetSequenceHashCode(Diagnostics); + return hash; + } + } +} + +internal readonly struct SerializableTypeResult( + SerializableTypeModel? model, + Diagnostic? diagnostic, + TypeMetadataIdentity metadataIdentity, + SourceLocationModel sourceLocation, + string typeSyntax) : IEquatable +{ + public SerializableTypeModel? Model { get; } = model; + public Diagnostic? Diagnostic { get; } = diagnostic; + public TypeMetadataIdentity MetadataIdentity { get; } = metadataIdentity; + public SourceLocationModel SourceLocation { get; } = sourceLocation; + public string TypeSyntax { get; } = typeSyntax; + + public static SerializableTypeResult FromModel(SerializableTypeModel model) + => new( + model, + diagnostic: null, + model?.MetadataIdentity ?? TypeMetadataIdentity.Empty, + model?.SourceLocation ?? default, + model?.TypeSyntax.SyntaxString ?? string.Empty); + + public static SerializableTypeResult FromDiagnostic( + Diagnostic diagnostic, + TypeMetadataIdentity metadataIdentity, + SourceLocationModel sourceLocation, + string typeSyntax) + => new(model: null, diagnostic, metadataIdentity, sourceLocation, typeSyntax ?? string.Empty); + + public bool Equals(SerializableTypeResult other) + => Nullable.Equals(Model, other.Model) + && SourceGeneratorDiagnosticComparer.AreEqual(Diagnostic, other.Diagnostic) + && MetadataIdentity.Equals(other.MetadataIdentity) + && SourceLocation.Equals(other.SourceLocation) + && string.Equals(TypeSyntax, other.TypeSyntax, StringComparison.Ordinal); + + public override bool Equals(object? obj) => obj is SerializableTypeResult other && Equals(other); + + public override int GetHashCode() + { + unchecked + { + var hash = Model?.GetHashCode() ?? 0; + hash = hash * 31 + SourceGeneratorDiagnosticComparer.GetHashCode(Diagnostic); + hash = hash * 31 + MetadataIdentity.GetHashCode(); + hash = hash * 31 + SourceLocation.GetHashCode(); + hash = hash * 31 + StringComparer.Ordinal.GetHashCode(TypeSyntax ?? string.Empty); + return hash; + } + } +} + +internal readonly struct ProxyOutputPreparationResult( + ImmutableArray proxyOutputModels, + ImmutableArray sourceOutputs, + Diagnostic? diagnostic) : IEquatable +{ + public ImmutableArray ProxyOutputModels { get; } = proxyOutputModels; + public ImmutableArray SourceOutputs { get; } = sourceOutputs; + public Diagnostic? Diagnostic { get; } = diagnostic; + + public static ProxyOutputPreparationResult FromModelsAndSources( + ImmutableArray proxyOutputModels, + ImmutableArray sourceOutputs) + => new(proxyOutputModels, sourceOutputs, diagnostic: null); + + public static ProxyOutputPreparationResult FromDiagnostic(Diagnostic diagnostic) + => new([], [], diagnostic); + + public bool Equals(ProxyOutputPreparationResult other) + => StructuralEquality.SequenceEqual(ProxyOutputModels, other.ProxyOutputModels) + && StructuralEquality.SequenceEqual(SourceOutputs, other.SourceOutputs) + && SourceGeneratorDiagnosticComparer.AreEqual(Diagnostic, other.Diagnostic); + + public override bool Equals(object? obj) => obj is ProxyOutputPreparationResult other && Equals(other); + + public override int GetHashCode() + { + unchecked + { + var hash = StructuralEquality.GetSequenceHashCode(ProxyOutputModels); + hash = hash * 31 + StructuralEquality.GetSequenceHashCode(SourceOutputs); + hash = hash * 31 + SourceGeneratorDiagnosticComparer.GetHashCode(Diagnostic); + return hash; + } + } +} + +internal static class SourceGeneratorDiagnosticComparer +{ + internal static bool AreSequencesEqual(ImmutableArray left, ImmutableArray right) + { + if (left.IsDefaultOrEmpty) + { + return right.IsDefaultOrEmpty; + } + + if (right.IsDefaultOrEmpty || left.Length != right.Length) + { + return false; + } + + for (var i = 0; i < left.Length; i++) + { + if (!AreEqual(left[i], right[i])) + { + return false; + } + } + + return true; + } + + internal static int GetSequenceHashCode(ImmutableArray diagnostics) + { + if (diagnostics.IsDefaultOrEmpty) + { + return 0; + } + + unchecked + { + var hash = 0; + foreach (var diagnostic in diagnostics) + { + hash = hash * 31 + GetHashCode(diagnostic); + } + + return hash; + } + } + + internal static bool AreEqual(Diagnostic? left, Diagnostic? right) + { + if (ReferenceEquals(left, right)) + { + return true; + } + + if (left is null || right is null) + { + return false; + } + + return string.Equals(left.Id, right.Id, StringComparison.Ordinal) + && left.Severity == right.Severity + && left.WarningLevel == right.WarningLevel + && string.Equals(left.ToString(), right.ToString(), StringComparison.Ordinal); + } + + internal static int GetHashCode(Diagnostic? diagnostic) + { + if (diagnostic is null) + { + return 0; + } + + unchecked + { + var hash = StringComparer.Ordinal.GetHashCode(diagnostic.Id ?? string.Empty); + hash = hash * 31 + (int)diagnostic.Severity; + hash = hash * 31 + diagnostic.WarningLevel; + hash = hash * 31 + StringComparer.Ordinal.GetHashCode(diagnostic.ToString() ?? string.Empty); + return hash; + } + } +} diff --git a/src/Orleans.CodeGenerator/SyntaxGeneration/FSharpUtils.cs b/src/Orleans.CodeGenerator/SyntaxGeneration/FSharpUtils.cs index d5b80984f4b..a4d667f1b38 100644 --- a/src/Orleans.CodeGenerator/SyntaxGeneration/FSharpUtils.cs +++ b/src/Orleans.CodeGenerator/SyntaxGeneration/FSharpUtils.cs @@ -1,4 +1,4 @@ -using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -6,349 +6,324 @@ using static Orleans.CodeGenerator.SerializerGenerator; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -#nullable disable -namespace Orleans.CodeGenerator +namespace Orleans.CodeGenerator; + +internal static class FSharpUtilities { - internal static class FSharpUtilities + private const int SourceConstructFlagsSumTypeValue = 1; + private const int SourceConstructFlagsKindMaskValue = 31; + private const int SourceConstructFlagsRecordTypeValue = 2; + + public static bool IsUnionCase(LibraryTypes libraryTypes, INamedTypeSymbol symbol, [NotNullWhen(true)] out INamedTypeSymbol? sumType) { - private const int SourceConstructFlagsSumTypeValue = 1; - private const int SourceConstructFlagsKindMaskValue = 31; - private const int SourceConstructFlagsRecordTypeValue = 2; + sumType = null; + var compilationAttributeType = libraryTypes.FSharpCompilationMappingAttributeOrDefault; + var sourceConstructFlagsType = libraryTypes.FSharpSourceConstructFlagsOrDefault; + var baseType = symbol.BaseType; + if (compilationAttributeType is null || sourceConstructFlagsType is null || baseType is null) + { + return false; + } - public static bool IsUnionCase(LibraryTypes libraryTypes, INamedTypeSymbol symbol, out INamedTypeSymbol sumType) + INamedTypeSymbol sumTypeCandidate; + if (symbol.GetAttributes(compilationAttributeType, out var compilationAttributes) && compilationAttributes.Length > 0) { - sumType = default; - var compilationAttributeType = libraryTypes.FSharpCompilationMappingAttributeOrDefault; - var sourceConstructFlagsType = libraryTypes.FSharpSourceConstructFlagsOrDefault; - var baseType = symbol.BaseType; - if (compilationAttributeType is null || sourceConstructFlagsType is null || baseType is null) - { - return false; - } + sumTypeCandidate = symbol; + } + else if (baseType.GetAttributes(compilationAttributeType, out compilationAttributes) && compilationAttributes.Length > 0) + { + sumTypeCandidate = baseType; + } + else + { + return false; + } - INamedTypeSymbol sumTypeCandidate; - if (symbol.GetAttributes(compilationAttributeType, out var compilationAttributes) && compilationAttributes.Length > 0) - { - sumTypeCandidate = symbol; - } - else if (baseType.GetAttributes(compilationAttributeType, out compilationAttributes) && compilationAttributes.Length > 0) - { - sumTypeCandidate = baseType; - } - else + var compilationAttribute = compilationAttributes[0]; + var foundArg = false; + TypedConstant sourceConstructFlagsArgument = default; + foreach (var arg in compilationAttribute.ConstructorArguments) + { + if (SymbolEqualityComparer.Default.Equals(arg.Type, sourceConstructFlagsType)) { - return false; + sourceConstructFlagsArgument = arg; + foundArg = true; + break; } + } - var compilationAttribute = compilationAttributes[0]; - var foundArg = false; - TypedConstant sourceConstructFlagsArgument = default; - foreach (var arg in compilationAttribute.ConstructorArguments) - { - if (SymbolEqualityComparer.Default.Equals(arg.Type, sourceConstructFlagsType)) - { - sourceConstructFlagsArgument = arg; - foundArg = true; - break; - } - } + if (!foundArg) + { + return false; + } - if (!foundArg) - { - return false; - } + if (sourceConstructFlagsArgument.Value != null && ((int)sourceConstructFlagsArgument.Value & SourceConstructFlagsKindMaskValue) != SourceConstructFlagsSumTypeValue) + { + return false; + } - if (sourceConstructFlagsArgument.Value != null && ((int)sourceConstructFlagsArgument.Value & SourceConstructFlagsKindMaskValue) != SourceConstructFlagsSumTypeValue) - { - return false; - } + sumType = sumTypeCandidate; + return true; + } - sumType = sumTypeCandidate; - return true; + public static bool IsRecord(LibraryTypes libraryTypes, INamedTypeSymbol symbol) + { + var compilationAttributeType = libraryTypes.FSharpCompilationMappingAttributeOrDefault; + var sourceConstructFlagsType = libraryTypes.FSharpSourceConstructFlagsOrDefault; + if (compilationAttributeType is null || sourceConstructFlagsType is null) + { + return false; } - public static bool IsRecord(LibraryTypes libraryTypes, INamedTypeSymbol symbol) + if (!symbol.GetAttributes(compilationAttributeType, out var compilationAttributes) || compilationAttributes.Length == 0) { - var compilationAttributeType = libraryTypes.FSharpCompilationMappingAttributeOrDefault; - var sourceConstructFlagsType = libraryTypes.FSharpSourceConstructFlagsOrDefault; - if (compilationAttributeType is null || sourceConstructFlagsType is null) - { - return false; - } + return false; + } - if (!symbol.GetAttributes(compilationAttributeType, out var compilationAttributes) || compilationAttributes.Length == 0) + var compilationAttribute = compilationAttributes[0]; + var foundArg = false; + TypedConstant sourceConstructFlagsArgument = default; + foreach (var arg in compilationAttribute.ConstructorArguments) + { + if (SymbolEqualityComparer.Default.Equals(arg.Type, sourceConstructFlagsType)) { - return false; + sourceConstructFlagsArgument = arg; + foundArg = true; + break; } + } - var compilationAttribute = compilationAttributes[0]; - var foundArg = false; - TypedConstant sourceConstructFlagsArgument = default; - foreach (var arg in compilationAttribute.ConstructorArguments) - { - if (SymbolEqualityComparer.Default.Equals(arg.Type, sourceConstructFlagsType)) - { - sourceConstructFlagsArgument = arg; - foundArg = true; - break; - } - } + if (!foundArg) + { + return false; + } - if (!foundArg) + if ((int)sourceConstructFlagsArgument.Value! != SourceConstructFlagsRecordTypeValue) + { + return false; + } + + return true; + } + + public class FSharpUnionCaseTypeDescription(Compilation compilation, INamedTypeSymbol type, LibraryTypes libraryTypes) : SerializableTypeDescription(compilation, type, false, GetUnionCaseDataMembers(libraryTypes, type), libraryTypes) + { + private static IEnumerable GetUnionCaseDataMembers(LibraryTypes libraryTypes, INamedTypeSymbol symbol) + { + List dataMembers = new(); + foreach (var field in symbol.GetDeclaredInstanceMembers()) { - return false; + dataMembers.Add(field); } - if ((int)sourceConstructFlagsArgument.Value != SourceConstructFlagsRecordTypeValue) + dataMembers.Sort(FSharpUnionCasePropertyNameComparer.Default); + + uint id = 0; + foreach (var field in dataMembers) { - return false; + yield return new FSharpUnionCaseFieldDescription(libraryTypes, field, id); + id++; } - - return true; } - public class FSharpUnionCaseTypeDescription : SerializableTypeDescription + private class FSharpUnionCasePropertyNameComparer : IComparer { - public FSharpUnionCaseTypeDescription(Compilation compilation, INamedTypeSymbol type, LibraryTypes libraryTypes) : base(compilation, type, false, GetUnionCaseDataMembers(libraryTypes, type), libraryTypes) - { - } + public static FSharpUnionCasePropertyNameComparer Default { get; } = new FSharpUnionCasePropertyNameComparer(); - private static IEnumerable GetUnionCaseDataMembers(LibraryTypes libraryTypes, INamedTypeSymbol symbol) + public int Compare(IFieldSymbol x, IFieldSymbol y) { - List dataMembers = new(); - foreach (var field in symbol.GetDeclaredInstanceMembers()) + var xName = x.Name; + var yName = y.Name; + if (xName.Length > yName.Length) { - dataMembers.Add(field); + return 1; } - dataMembers.Sort(FSharpUnionCasePropertyNameComparer.Default); - - uint id = 0; - foreach (var field in dataMembers) + if (xName.Length < yName.Length) { - yield return new FSharpUnionCaseFieldDescription(libraryTypes, field, id); - id++; + return -1; } - } - - private class FSharpUnionCasePropertyNameComparer : IComparer - { - public static FSharpUnionCasePropertyNameComparer Default { get; } = new FSharpUnionCasePropertyNameComparer(); - public int Compare(IFieldSymbol x, IFieldSymbol y) - { - var xName = x.Name; - var yName = y.Name; - if (xName.Length > yName.Length) - { - return 1; - } - - if (xName.Length < yName.Length) - { - return -1; - } - - return string.CompareOrdinal(xName, yName); - } + return string.CompareOrdinal(xName, yName); } + } - private class FSharpUnionCaseFieldDescription : IMemberDescription, ISerializableMember - { - private readonly LibraryTypes _libraryTypes; - private readonly IFieldSymbol _field; - - public FSharpUnionCaseFieldDescription(LibraryTypes libraryTypes, IFieldSymbol field, uint ordinal) - { - _libraryTypes = libraryTypes; - FieldId = ordinal; - _field = field; - } + private class FSharpUnionCaseFieldDescription(LibraryTypes libraryTypes, IFieldSymbol field, uint ordinal) : IMemberDescription, ISerializableMember + { + private readonly LibraryTypes _libraryTypes = libraryTypes; + private readonly IFieldSymbol _field = field; - public uint FieldId { get; } + public uint FieldId { get; } = ordinal; - public bool IsShallowCopyable => _libraryTypes.IsShallowCopyable(Type) || _field.HasAttribute(_libraryTypes.ImmutableAttribute); + public bool IsShallowCopyable => _libraryTypes.IsShallowCopyable(Type) || _field.HasAttribute(_libraryTypes.ImmutableAttribute); - public bool IsValueType => Type.IsValueType; + public bool IsValueType => Type.IsValueType; - public IMemberDescription Member => this; + public IMemberDescription Member => this; - public ITypeSymbol Type => _field.Type; + public ITypeSymbol Type => _field.Type; - public INamedTypeSymbol ContainingType => _field.ContainingType; + public INamedTypeSymbol ContainingType => _field.ContainingType; - public ISymbol Symbol => _field; + public ISymbol Symbol => _field; - /// - /// Gets the name of the setter field. - /// - private string SetterFieldName => "setField" + FieldId; + /// + /// Gets the name of the setter field. + /// + private string SetterFieldName => "setField" + FieldId; - /// - /// Gets syntax representing the type of this field. - /// - public TypeSyntax TypeSyntax => Type.TypeKind == TypeKind.Dynamic - ? PredefinedType(Token(SyntaxKind.ObjectKeyword)) - : GetTypeSyntax(Type); + /// + /// Gets syntax representing the type of this field. + /// + public TypeSyntax TypeSyntax => Type.TypeKind == TypeKind.Dynamic + ? PredefinedType(Token(SyntaxKind.ObjectKeyword)) + : GetTypeSyntax(Type); - public string AssemblyName => Type.ContainingAssembly.ToDisplayName(); - public string TypeName => Type.ToDisplayName(); - public string TypeNameIdentifier => Type.GetValidIdentifier(); + public string AssemblyName => Type.ContainingAssembly.ToDisplayName(); + public string TypeName => Type.ToDisplayName(); + public string TypeNameIdentifier => Type.GetValidIdentifier(); - public bool IsPrimaryConstructorParameter => false; + public bool IsPrimaryConstructorParameter => false; - public bool IsSerializable => true; - public bool IsCopyable => true; + public bool IsSerializable => true; + public bool IsCopyable => true; - public TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol) => typeSymbol.ToTypeSyntax(); + public TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol) => typeSymbol.ToTypeSyntax(); - /// - /// Returns syntax for retrieving the value of this field, deep copying it if necessary. - /// - /// The instance of the containing type. - /// Syntax for retrieving the value of this field. - public ExpressionSyntax GetGetter(ExpressionSyntax instance) => instance.Member(_field.Name); + /// + /// Returns syntax for retrieving the value of this field, deep copying it if necessary. + /// + /// The instance of the containing type. + /// Syntax for retrieving the value of this field. + public ExpressionSyntax GetGetter(ExpressionSyntax instance) => instance.Member(_field.Name); - /// - /// Returns syntax for setting the value of this field. - /// - /// The instance of the containing type. - /// Syntax for the new value. - /// Syntax for setting the value of this field. - public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax value) + /// + /// Returns syntax for setting the value of this field. + /// + /// The instance of the containing type. + /// Syntax for the new value. + /// Syntax for setting the value of this field. + public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax value) + { + var instanceArg = Argument(instance); + if (ContainingType != null && ContainingType.IsValueType) { - var instanceArg = Argument(instance); - if (ContainingType != null && ContainingType.IsValueType) - { - instanceArg = instanceArg.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); - } - - return - InvocationExpression(IdentifierName(SetterFieldName)) - .AddArgumentListArguments(instanceArg, Argument(value)); + instanceArg = instanceArg.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); } - public FieldAccessorDescription GetGetterFieldDescription() => null; - - public FieldAccessorDescription GetSetterFieldDescription() - => SerializableMember.GetFieldAccessor(ContainingType, TypeSyntax, _field.Name, SetterFieldName, _libraryTypes, true); + return + InvocationExpression(IdentifierName(SetterFieldName)) + .AddArgumentListArguments(instanceArg, Argument(value)); } + + public FieldAccessorDescription? GetGetterFieldDescription() => null; + + public FieldAccessorDescription? GetSetterFieldDescription() + => SerializableMember.GetFieldAccessor(ContainingType, TypeSyntax, _field.Name, SetterFieldName, _libraryTypes, true); } + } - public class FSharpRecordTypeDescription : SerializableTypeDescription + public class FSharpRecordTypeDescription(Compilation compilation, INamedTypeSymbol type, LibraryTypes libraryTypes) : SerializableTypeDescription(compilation, type, false, GetRecordDataMembers(libraryTypes, type), libraryTypes) + { + private static IEnumerable GetRecordDataMembers(LibraryTypes libraryTypes, INamedTypeSymbol symbol) { - public FSharpRecordTypeDescription(Compilation compilation, INamedTypeSymbol type, LibraryTypes libraryTypes) : base(compilation, type, false, GetRecordDataMembers(libraryTypes, type), libraryTypes) + List<(IPropertySymbol, uint)> dataMembers = new(); + foreach (var property in symbol.GetDeclaredInstanceMembers()) { - } - - private static IEnumerable GetRecordDataMembers(LibraryTypes libraryTypes, INamedTypeSymbol symbol) - { - List<(IPropertySymbol, uint)> dataMembers = new(); - foreach (var property in symbol.GetDeclaredInstanceMembers()) + var id = GeneratedCodeUtilities.GetId(libraryTypes, property); + if (!id.HasValue) { - var id = CodeGenerator.GetId(libraryTypes, property); - if (!id.HasValue) - { - continue; - } - - dataMembers.Add((property, id.Value)); + continue; } - foreach (var (property, id) in dataMembers) - { - yield return new FSharpRecordPropertyDescription(libraryTypes, property, id); - } + dataMembers.Add((property, id.Value)); } - private class FSharpRecordPropertyDescription : IMemberDescription, ISerializableMember + foreach (var (property, id) in dataMembers) { - private readonly LibraryTypes _libraryTypes; - private readonly IPropertySymbol _property; + yield return new FSharpRecordPropertyDescription(libraryTypes, property, id); + } + } - public FSharpRecordPropertyDescription(LibraryTypes libraryTypes, IPropertySymbol property, uint ordinal) - { - _libraryTypes = libraryTypes; - FieldId = ordinal; - _property = property; - } + private class FSharpRecordPropertyDescription(LibraryTypes libraryTypes, IPropertySymbol property, uint ordinal) : IMemberDescription, ISerializableMember + { + private readonly LibraryTypes _libraryTypes = libraryTypes; - public uint FieldId { get; } + public uint FieldId { get; } = ordinal; - public bool IsShallowCopyable => _libraryTypes.IsShallowCopyable(Type) || _property.HasAttribute(_libraryTypes.ImmutableAttribute); + public bool IsShallowCopyable => _libraryTypes.IsShallowCopyable(Type) || Property.HasAttribute(_libraryTypes.ImmutableAttribute); - public bool IsValueType => Type.IsValueType; + public bool IsValueType => Type.IsValueType; - public IMemberDescription Member => this; + public IMemberDescription Member => this; - public ITypeSymbol Type => _property.Type; + public ITypeSymbol Type => Property.Type; - public ISymbol Symbol => _property; + public ISymbol Symbol => Property; - public INamedTypeSymbol ContainingType => _property.ContainingType; + public INamedTypeSymbol ContainingType => Property.ContainingType; - public string FieldName => _property.Name + "@"; + public string FieldName => Property.Name + "@"; - /// - /// Gets the name of the setter field. - /// - private string SetterFieldName => "setField" + FieldId; + /// + /// Gets the name of the setter field. + /// + private string SetterFieldName => "setField" + FieldId; - /// - /// Gets syntax representing the type of this field. - /// - public TypeSyntax TypeSyntax => Type.TypeKind == TypeKind.Dynamic - ? PredefinedType(Token(SyntaxKind.ObjectKeyword)) - : GetTypeSyntax(Type); + /// + /// Gets syntax representing the type of this field. + /// + public TypeSyntax TypeSyntax => Type.TypeKind == TypeKind.Dynamic + ? PredefinedType(Token(SyntaxKind.ObjectKeyword)) + : GetTypeSyntax(Type); - /// - /// Gets the which this field is the backing property for, or - /// if this is not the backing field of an auto-property. - /// - private IPropertySymbol Property => _property; + /// + /// Gets the which this field is the backing property for, or + /// if this is not the backing field of an auto-property. + /// + private IPropertySymbol Property { get; } = property; - public string AssemblyName => Type.ContainingAssembly.ToDisplayName(); - public string TypeName => Type.ToDisplayName(); - public string TypeNameIdentifier => Type.GetValidIdentifier(); + public string AssemblyName => Type.ContainingAssembly.ToDisplayName(); + public string TypeName => Type.ToDisplayName(); + public string TypeNameIdentifier => Type.GetValidIdentifier(); - public bool IsPrimaryConstructorParameter => false; + public bool IsPrimaryConstructorParameter => false; - public bool IsSerializable => true; - public bool IsCopyable => true; + public bool IsSerializable => true; + public bool IsCopyable => true; - public TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol) => typeSymbol.ToTypeSyntax(); + public TypeSyntax GetTypeSyntax(ITypeSymbol typeSymbol) => typeSymbol.ToTypeSyntax(); - /// - /// Returns syntax for retrieving the value of this field, deep copying it if necessary. - /// - /// The instance of the containing type. - /// Syntax for retrieving the value of this field. - public ExpressionSyntax GetGetter(ExpressionSyntax instance) => instance.Member(Property.Name); + /// + /// Returns syntax for retrieving the value of this field, deep copying it if necessary. + /// + /// The instance of the containing type. + /// Syntax for retrieving the value of this field. + public ExpressionSyntax GetGetter(ExpressionSyntax instance) => instance.Member(Property.Name); - /// - /// Returns syntax for setting the value of this field. - /// - /// The instance of the containing type. - /// Syntax for the new value. - /// Syntax for setting the value of this field. - public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax value) + /// + /// Returns syntax for setting the value of this field. + /// + /// The instance of the containing type. + /// Syntax for the new value. + /// Syntax for setting the value of this field. + public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax value) + { + var instanceArg = Argument(instance); + if (ContainingType != null && ContainingType.IsValueType) { - var instanceArg = Argument(instance); - if (ContainingType != null && ContainingType.IsValueType) - { - instanceArg = instanceArg.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); - } - - return - InvocationExpression(IdentifierName(SetterFieldName)) - .AddArgumentListArguments(instanceArg, Argument(value)); + instanceArg = instanceArg.WithRefOrOutKeyword(Token(SyntaxKind.RefKeyword)); } - public FieldAccessorDescription GetGetterFieldDescription() => null; - - public FieldAccessorDescription GetSetterFieldDescription() - => SerializableMember.GetFieldAccessor(ContainingType, TypeSyntax, FieldName, SetterFieldName, _libraryTypes, true); + return + InvocationExpression(IdentifierName(SetterFieldName)) + .AddArgumentListArguments(instanceArg, Argument(value)); } + + public FieldAccessorDescription? GetGetterFieldDescription() => null; + + public FieldAccessorDescription? GetSetterFieldDescription() + => SerializableMember.GetFieldAccessor(ContainingType, TypeSyntax, FieldName, SetterFieldName, _libraryTypes, true); } } } diff --git a/src/Orleans.CodeGenerator/SyntaxGeneration/Identifier.cs b/src/Orleans.CodeGenerator/SyntaxGeneration/Identifier.cs index fdb35b86138..a09b3447ecd 100644 --- a/src/Orleans.CodeGenerator/SyntaxGeneration/Identifier.cs +++ b/src/Orleans.CodeGenerator/SyntaxGeneration/Identifier.cs @@ -1,132 +1,131 @@ using System.Text.RegularExpressions; -namespace Orleans.CodeGenerator.SyntaxGeneration +namespace Orleans.CodeGenerator.SyntaxGeneration; + +internal static class Identifier { - internal static class Identifier + internal static bool IsCSharpKeyword(string identifier) { - internal static bool IsCSharpKeyword(string identifier) + switch (identifier) { - switch (identifier) - { - case "abstract": - case "add": - case "alias": - case "as": - case "ascending": - case "async": - case "await": - case "base": - case "bool": - case "break": - case "byte": - case "case": - case "catch": - case "char": - case "checked": - case "class": - case "const": - case "continue": - case "decimal": - case "default": - case "delegate": - case "descending": - case "do": - case "double": - case "dynamic": - case "else": - case "enum": - case "event": - case "explicit": - case "extern": - case "false": - case "finally": - case "fixed": - case "float": - case "for": - case "foreach": - case "from": - case "get": - case "global": - case "goto": - case "group": - case "if": - case "implicit": - case "in": - case "int": - case "interface": - case "internal": - case "into": - case "is": - case "join": - case "let": - case "lock": - case "long": - case "nameof": - case "namespace": - case "new": - case "null": - case "object": - case "operator": - case "orderby": - case "out": - case "override": - case "params": - case "partial": - case "private": - case "protected": - case "public": - case "readonly": - case "ref": - case "remove": - case "return": - case "sbyte": - case "sealed": - case "select": - case "set": - case "short": - case "sizeof": - case "stackalloc": - case "static": - case "string": - case "struct": - case "switch": - case "this": - case "throw": - case "true": - case "try": - case "typeof": - case "uint": - case "ulong": - case "unchecked": - case "unsafe": - case "ushort": - case "using": - case "value": - case "var": - case "virtual": - case "void": - case "volatile": - case "when": - case "where": - case "while": - case "yield": - return true; - default: - return false; - } + case "abstract": + case "add": + case "alias": + case "as": + case "ascending": + case "async": + case "await": + case "base": + case "bool": + case "break": + case "byte": + case "case": + case "catch": + case "char": + case "checked": + case "class": + case "const": + case "continue": + case "decimal": + case "default": + case "delegate": + case "descending": + case "do": + case "double": + case "dynamic": + case "else": + case "enum": + case "event": + case "explicit": + case "extern": + case "false": + case "finally": + case "fixed": + case "float": + case "for": + case "foreach": + case "from": + case "get": + case "global": + case "goto": + case "group": + case "if": + case "implicit": + case "in": + case "int": + case "interface": + case "internal": + case "into": + case "is": + case "join": + case "let": + case "lock": + case "long": + case "nameof": + case "namespace": + case "new": + case "null": + case "object": + case "operator": + case "orderby": + case "out": + case "override": + case "params": + case "partial": + case "private": + case "protected": + case "public": + case "readonly": + case "ref": + case "remove": + case "return": + case "sbyte": + case "sealed": + case "select": + case "set": + case "short": + case "sizeof": + case "stackalloc": + case "static": + case "string": + case "struct": + case "switch": + case "this": + case "throw": + case "true": + case "try": + case "typeof": + case "uint": + case "ulong": + case "unchecked": + case "unsafe": + case "ushort": + case "using": + case "value": + case "var": + case "virtual": + case "void": + case "volatile": + case "when": + case "where": + case "while": + case "yield": + return true; + default: + return false; } + } - private static readonly Regex SanitizeIdentifierRegex = new("^([0-9]+)|([^0-9a-zA-Z_]+)", RegexOptions.Compiled); + private static readonly Regex SanitizeIdentifierRegex = new("^([0-9]+)|([^0-9a-zA-Z_]+)", RegexOptions.Compiled); - public static string SanitizeIdentifierName(string input) => SanitizeIdentifierRegex.Replace( - input, - static match => match.Value switch - { - // Prefix leading digits with an '_' to make them a valid identifier. - { Length: > 0 } value when char.IsDigit(value[0]) => $"_{value}", + public static string SanitizeIdentifierName(string input) => SanitizeIdentifierRegex.Replace( + input, + static match => match.Value switch + { + // Prefix leading digits with an '_' to make them a valid identifier. + { Length: > 0 } value when char.IsDigit(value[0]) => $"_{value}", - // Eliminate all other matches by replacing them with an empty string. - _ => "" - }); - } + // Eliminate all other matches by replacing them with an empty string. + _ => "" + }); } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/SyntaxGeneration/StringExtensions.cs b/src/Orleans.CodeGenerator/SyntaxGeneration/StringExtensions.cs index 0d949739b38..9356c002e06 100644 --- a/src/Orleans.CodeGenerator/SyntaxGeneration/StringExtensions.cs +++ b/src/Orleans.CodeGenerator/SyntaxGeneration/StringExtensions.cs @@ -2,57 +2,56 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -namespace Orleans.CodeGenerator.SyntaxGeneration +namespace Orleans.CodeGenerator.SyntaxGeneration; + +/// +/// Extensions to the class to support code generation. +/// +internal static class StringExtensions { /// - /// Extensions to the class to support code generation. + /// Returns the provided string as a literal expression. /// - internal static class StringExtensions + /// + /// The string. + /// + /// + /// The literal expression. + /// + public static LiteralExpressionSyntax GetLiteralExpression(this string str) { - /// - /// Returns the provided string as a literal expression. - /// - /// - /// The string. - /// - /// - /// The literal expression. - /// - public static LiteralExpressionSyntax GetLiteralExpression(this string str) - { - var syntaxToken = SyntaxFactory.Literal( - SyntaxFactory.TriviaList(), - @"""" + str + @"""", - str, - SyntaxFactory.TriviaList()); - return SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, syntaxToken); - } + var syntaxToken = SyntaxFactory.Literal( + SyntaxFactory.TriviaList(), + @"""" + str + @"""", + str, + SyntaxFactory.TriviaList()); + return SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, syntaxToken); + } - public static SyntaxToken ToIdentifier(this string identifier) + public static SyntaxToken ToIdentifier(this string identifier) + { + identifier = identifier.TrimStart('@'); + if (Identifier.IsCSharpKeyword(identifier)) { - identifier = identifier.TrimStart('@'); - if (Identifier.IsCSharpKeyword(identifier)) - { - return SyntaxFactory.VerbatimIdentifier( - SyntaxTriviaList.Empty, - identifier, - identifier, - SyntaxTriviaList.Empty); - } - - return SyntaxFactory.Identifier(SyntaxTriviaList.Empty, identifier, SyntaxTriviaList.Empty); + return SyntaxFactory.VerbatimIdentifier( + SyntaxTriviaList.Empty, + identifier, + identifier, + SyntaxTriviaList.Empty); } - public static string EscapeIdentifier(this string str) - { - if (Identifier.IsCSharpKeyword(str)) - { - return "@" + str; - } + return SyntaxFactory.Identifier(SyntaxTriviaList.Empty, identifier, SyntaxTriviaList.Empty); + } - return str; + public static string EscapeIdentifier(this string str) + { + if (Identifier.IsCSharpKeyword(str)) + { + return "@" + str; } - public static IdentifierNameSyntax ToIdentifierName(this string identifier) => SyntaxFactory.IdentifierName(identifier.ToIdentifier()); + return str; } + + public static IdentifierNameSyntax ToIdentifierName(this string identifier) => SyntaxFactory.IdentifierName(identifier.ToIdentifier()); } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/SyntaxGeneration/SymbolExtensions.cs b/src/Orleans.CodeGenerator/SyntaxGeneration/SymbolExtensions.cs index 5df79c2a084..1a63030baef 100644 --- a/src/Orleans.CodeGenerator/SyntaxGeneration/SymbolExtensions.cs +++ b/src/Orleans.CodeGenerator/SyntaxGeneration/SymbolExtensions.cs @@ -1,353 +1,397 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -using System; using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; +using System.Diagnostics.CodeAnalysis; using System.Text; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -namespace Orleans.CodeGenerator.SyntaxGeneration +namespace Orleans.CodeGenerator.SyntaxGeneration; + +internal static class SymbolExtensions { - internal static class SymbolExtensions - { - private static readonly ConcurrentDictionary TypeCache = new(SymbolEqualityComparer.Default); - private static readonly ConcurrentDictionary NameCache = new(SymbolEqualityComparer.Default); + private static readonly ConcurrentDictionary TypeCache = new(SymbolEqualityComparer.Default); + private static readonly ConcurrentDictionary NameCache = new(SymbolEqualityComparer.Default); - public struct DisplayNameOptions + public struct DisplayNameOptions + { + public DisplayNameOptions() { - public DisplayNameOptions() - { - Substitutions = null; - } - - public Dictionary? Substitutions { get; set; } - public bool IncludeGlobalSpecifier { get; set; } = true; - public bool IncludeNamespace { get; set; } = true; + Substitutions = null; } - public static bool HasAttribute(this INamedTypeSymbol symbol, INamedTypeSymbol attributeType, bool inherited) => GetAttribute(symbol, attributeType, inherited) is not null; + public Dictionary? Substitutions { get; set; } + public bool IncludeGlobalSpecifier { get; set; } = true; + public bool IncludeNamespace { get; set; } = true; + } + + public static bool HasAttribute(this INamedTypeSymbol symbol, INamedTypeSymbol attributeType, bool inherited) => GetAttribute(symbol, attributeType, inherited) is not null; - public static AttributeData? GetAttribute(this INamedTypeSymbol symbol, INamedTypeSymbol attributeType, bool inherited) + public static AttributeData? GetAttribute(this INamedTypeSymbol symbol, INamedTypeSymbol attributeType, bool inherited) + { + var s = symbol; + if (s.GetAttribute(attributeType) is { } attribute) { - var s = symbol; - if (s.GetAttribute(attributeType) is { } attribute) - { - return attribute; - } + return attribute; + } - if (inherited) + if (inherited) + { + foreach (var iface in symbol.AllInterfaces) { - foreach (var iface in symbol.AllInterfaces) + if (iface.GetAttribute(attributeType) is { } iattr) { - if (iface.GetAttribute(attributeType) is { } iattr) - { - return iattr; - } + return iattr; } + } - while ((s = s.BaseType) != null) + while ((s = s.BaseType) != null) + { + if (s.GetAttribute(attributeType) is { } attr) { - if (s.GetAttribute(attributeType) is { } attr) - { - return attr; - } + return attr; } } + } + + return null; + } - return null; + public static TypeSyntax ToTypeSyntax(this ITypeSymbol typeSymbol) + { + if (typeSymbol.SpecialType == SpecialType.System_Void) + { + return PredefinedType(Token(SyntaxKind.VoidKeyword)); } - public static TypeSyntax ToTypeSyntax(this ITypeSymbol typeSymbol) + if (!TypeCache.TryGetValue(typeSymbol, out var result)) { - if (typeSymbol.SpecialType == SpecialType.System_Void) - { - return PredefinedType(Token(SyntaxKind.VoidKeyword)); - } + result = TypeCache[typeSymbol] = ParseTypeName(typeSymbol.ToDisplayName()); + } - if (!TypeCache.TryGetValue(typeSymbol, out var result)) - { - result = TypeCache[typeSymbol] = ParseTypeName(typeSymbol.ToDisplayName()); - } + return result; + } - return result; + public static TypeSyntax ToTypeSyntax(this ITypeSymbol typeSymbol, Dictionary? substitutions) + { + if (substitutions is null or { Count: 0 }) + { + return typeSymbol.ToTypeSyntax(); } - public static TypeSyntax ToTypeSyntax(this ITypeSymbol typeSymbol, Dictionary substitutions) + if (typeSymbol.SpecialType == SpecialType.System_Void) { - if (substitutions is null or { Count: 0 }) - { - return typeSymbol.ToTypeSyntax(); - } + return PredefinedType(Token(SyntaxKind.VoidKeyword)); + } - if (typeSymbol.SpecialType == SpecialType.System_Void) - { - return PredefinedType(Token(SyntaxKind.VoidKeyword)); - } + var res = new StringBuilder(); + var options = new DisplayNameOptions + { + Substitutions = substitutions, + }; + ToTypeSyntaxInner(typeSymbol, res, options); + var result = ParseTypeName(res.ToString()); + return result; + } - var res = new StringBuilder(); - var options = new DisplayNameOptions - { - Substitutions = substitutions, - }; - ToTypeSyntaxInner(typeSymbol, res, options); - var result = ParseTypeName(res.ToString()); - return result; - } + public static string ToDisplayName(this ITypeSymbol typeSymbol, Dictionary? substitutions, bool includeGlobalSpecifier = true, bool includeNamespace = true) + { + return ToDisplayName(typeSymbol, new DisplayNameOptions { Substitutions = substitutions, IncludeGlobalSpecifier = includeGlobalSpecifier, IncludeNamespace = includeNamespace }); + } - public static string ToDisplayName(this ITypeSymbol typeSymbol, Dictionary? substitutions, bool includeGlobalSpecifier = true, bool includeNamespace = true) + public static string ToDisplayName(this ITypeSymbol typeSymbol, DisplayNameOptions options) + { + if (typeSymbol.SpecialType == SpecialType.System_Void) { - return ToDisplayName(typeSymbol, new DisplayNameOptions { Substitutions = substitutions, IncludeGlobalSpecifier = includeGlobalSpecifier, IncludeNamespace = includeNamespace }); + return "void"; } - public static string ToDisplayName(this ITypeSymbol typeSymbol, DisplayNameOptions options) - { - if (typeSymbol.SpecialType == SpecialType.System_Void) - { - return "void"; - } + var result = new StringBuilder(); + ToTypeSyntaxInner(typeSymbol, result, options); + return result.ToString(); + } - var result = new StringBuilder(); - ToTypeSyntaxInner(typeSymbol, result, options); - return result.ToString(); + public static string ToDisplayName(this ITypeSymbol typeSymbol) + { + if (typeSymbol.SpecialType == SpecialType.System_Void) + { + return "void"; } - public static string ToDisplayName(this ITypeSymbol typeSymbol) + if (!NameCache.TryGetValue(typeSymbol, out var result)) { - if (typeSymbol.SpecialType == SpecialType.System_Void) - { - return "void"; - } + result = NameCache[typeSymbol] = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + } - if (!NameCache.TryGetValue(typeSymbol, out var result)) - { - result = NameCache[typeSymbol] = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - } + return result; + } - return result; + public static string ToDisplayName(this IAssemblySymbol? assemblySymbol) + { + if (assemblySymbol is null) + { + return string.Empty; } - public static string ToDisplayName(this IAssemblySymbol assemblySymbol) + if (!NameCache.TryGetValue(assemblySymbol, out var result)) { - if (assemblySymbol is null) - { - return string.Empty; - } + result = NameCache[assemblySymbol] = assemblySymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); + } - if (!NameCache.TryGetValue(assemblySymbol, out var result)) - { - result = NameCache[assemblySymbol] = assemblySymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); - } + return result; + } - return result; + private static void ToTypeSyntaxInner(ITypeSymbol typeSymbol, StringBuilder res, DisplayNameOptions options) + { + switch (typeSymbol) + { + case IDynamicTypeSymbol: + res.Append("dynamic"); + break; + case IArrayTypeSymbol a: + ToTypeSyntaxInner(a.ElementType, res, options); + res.Append('['); + if (a.Rank > 1) + { + res.Append(new string(',', a.Rank - 1)); + } + + res.Append(']'); + break; + case ITypeParameterSymbol tp: + if (options.Substitutions is { } substitutions && substitutions.TryGetValue(tp, out var sub)) + { + res.Append(sub); + } + else + { + res.Append(tp.Name.EscapeIdentifier()); + } + break; + case INamedTypeSymbol n: + OnNamedTypeSymbol(n, res, options); + break; + default: + throw new NotSupportedException($"Symbols of type {typeSymbol?.GetType().ToString() ?? "null"} are not supported"); } - private static void ToTypeSyntaxInner(ITypeSymbol typeSymbol, StringBuilder res, DisplayNameOptions options) + static void OnNamedTypeSymbol(INamedTypeSymbol symbol, StringBuilder res, DisplayNameOptions options) { - switch (typeSymbol) + switch (symbol.ContainingSymbol) { - case IDynamicTypeSymbol: - res.Append("dynamic"); - break; - case IArrayTypeSymbol a: - ToTypeSyntaxInner(a.ElementType, res, options); - res.Append('['); - if (a.Rank > 1) - { - res.Append(new string(',', a.Rank - 1)); - } - - res.Append(']'); + case INamespaceSymbol ns when options.IncludeNamespace: + AddFullNamespace(ns, res, options.IncludeGlobalSpecifier); break; - case ITypeParameterSymbol tp: - if (options.Substitutions is { } substitutions && substitutions.TryGetValue(tp, out var sub)) - { - res.Append(sub); - } - else - { - res.Append(tp.Name.EscapeIdentifier()); - } - break; - case INamedTypeSymbol n: - OnNamedTypeSymbol(n, res, options); + case INamedTypeSymbol containingType: + OnNamedTypeSymbol(containingType, res, options); + res.Append('.'); break; - default: - throw new NotSupportedException($"Symbols of type {typeSymbol?.GetType().ToString() ?? "null"} are not supported"); } - static void OnNamedTypeSymbol(INamedTypeSymbol symbol, StringBuilder res, DisplayNameOptions options) + res.Append(symbol.Name.EscapeIdentifier()); + if (symbol.TypeArguments.Length > 0) { - switch (symbol.ContainingSymbol) + res.Append('<'); + bool first = true; + foreach (var typeParameter in symbol.TypeArguments) { - case INamespaceSymbol ns when options.IncludeNamespace: - AddFullNamespace(ns, res, options.IncludeGlobalSpecifier); - break; - case INamedTypeSymbol containingType: - OnNamedTypeSymbol(containingType, res, options); - res.Append('.'); - break; - } - - res.Append(symbol.Name.EscapeIdentifier()); - if (symbol.TypeArguments.Length > 0) - { - res.Append('<'); - bool first = true; - foreach (var typeParameter in symbol.TypeArguments) + if (!first) { - if (!first) - { - res.Append(','); - } - - ToTypeSyntaxInner(typeParameter, res, options); - first = false; + res.Append(','); } - res.Append('>'); + + ToTypeSyntaxInner(typeParameter, res, options); + first = false; } + res.Append('>'); } + } - static void AddFullNamespace(INamespaceSymbol symbol, StringBuilder res, bool includeGlobalSpecifier) + static void AddFullNamespace(INamespaceSymbol symbol, StringBuilder res, bool includeGlobalSpecifier) + { + if (symbol.ContainingNamespace is { } parent) { - if (symbol.ContainingNamespace is { } parent) - { - AddFullNamespace(parent, res, includeGlobalSpecifier); - } + AddFullNamespace(parent, res, includeGlobalSpecifier); + } - if (symbol.IsGlobalNamespace) - { - if (includeGlobalSpecifier) - { - res.Append("global::"); - } - } - else + if (symbol.IsGlobalNamespace) + { + if (includeGlobalSpecifier) { - res.Append(symbol.Name.EscapeIdentifier()); - res.Append('.'); + res.Append("global::"); } } + else + { + res.Append(symbol.Name.EscapeIdentifier()); + res.Append('.'); + } + } + } + + public static TypeSyntax ToTypeSyntax(this ITypeSymbol typeSymbol, params TypeSyntax[] genericParameters) + { + var displayString = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var nameSyntax = ParseName(displayString); + + switch (nameSyntax) + { + case AliasQualifiedNameSyntax aliased: + return aliased.WithName(WithGenericParameters(aliased.Name)); + case QualifiedNameSyntax qualified: + return qualified.WithRight(WithGenericParameters(qualified.Right)); + case GenericNameSyntax g: + return WithGenericParameters(g); + default: + throw new InvalidOperationException( + $"Attempted to add generic parameters to non-generic type {displayString} ({nameSyntax.GetType()}, adding parameters {string.Join(", ", genericParameters.Select(n => n.ToFullString()))}"); } - public static TypeSyntax ToTypeSyntax(this ITypeSymbol typeSymbol, params TypeSyntax[] genericParameters) + SimpleNameSyntax WithGenericParameters(SimpleNameSyntax simpleNameSyntax) { - var displayString = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var nameSyntax = ParseName(displayString); + if (simpleNameSyntax is GenericNameSyntax generic) + { + return generic.WithTypeArgumentList(TypeArgumentList(SeparatedList(genericParameters))); + } + throw new InvalidOperationException( + $"Attempted to add generic parameters to non-generic type {displayString} ({nameSyntax.GetType()}, adding parameters {string.Join(", ", genericParameters.Select(n => n.ToFullString()))}"); + } + } + + public static TypeSyntax ToOpenTypeSyntax(this ITypeSymbol typeSymbol) + { + var displayString = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var nameSyntax = ParseName(displayString); + return Visit(nameSyntax); + + static NameSyntax Visit(NameSyntax nameSyntax) + { switch (nameSyntax) { + case GenericNameSyntax generic: + { + var argCount = generic.TypeArgumentList.Arguments.Count; + return generic.WithTypeArgumentList(TypeArgumentList(SeparatedList(Enumerable.Range(0, argCount).Select(_ => OmittedTypeArgument())))); + } case AliasQualifiedNameSyntax aliased: - return aliased.WithName(WithGenericParameters(aliased.Name)); + return aliased.WithName((SimpleNameSyntax)Visit(aliased.Name)); case QualifiedNameSyntax qualified: - return qualified.WithRight(WithGenericParameters(qualified.Right)); - case GenericNameSyntax g: - return WithGenericParameters(g); + return qualified.WithRight((SimpleNameSyntax)Visit(qualified.Right)).WithLeft(Visit(qualified.Left)); default: - throw new InvalidOperationException( - $"Attempted to add generic parameters to non-generic type {displayString} ({nameSyntax.GetType()}, adding parameters {string.Join(", ", genericParameters.Select(n => n.ToFullString()))}"); + return nameSyntax; } + } + } - SimpleNameSyntax WithGenericParameters(SimpleNameSyntax simpleNameSyntax) - { - if (simpleNameSyntax is GenericNameSyntax generic) - { - return generic.WithTypeArgumentList(TypeArgumentList(SeparatedList(genericParameters))); - } + public static NameSyntax ToNameSyntax(this ITypeSymbol typeSymbol) => ParseName(typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); - throw new InvalidOperationException( - $"Attempted to add generic parameters to non-generic type {displayString} ({nameSyntax.GetType()}, adding parameters {string.Join(", ", genericParameters.Select(n => n.ToFullString()))}"); + public static string GetValidIdentifier(this ITypeSymbol type) => type switch + { + INamedTypeSymbol named when !named.IsGenericType => $"{named.Name}", + INamedTypeSymbol named => $"{named.Name}_{string.Join("_", named.TypeArguments.Select(GetValidIdentifier))}", + IArrayTypeSymbol array => $"{GetValidIdentifier(array.ElementType)}_{array.Rank}", + ITypeParameterSymbol parameter => $"{parameter.Name}", + _ => throw new NotSupportedException($"Unable to format type of kind {type.GetType()} with name \"{type.Name}\""), + }; + + public static bool HasBaseType(this ITypeSymbol? typeSymbol, INamedTypeSymbol baseType) + { + var current = typeSymbol; + for (; current != null; current = current.BaseType) + { + if (SymbolEqualityComparer.Default.Equals(baseType, current)) + { + return true; } } + return false; + } - public static TypeSyntax ToOpenTypeSyntax(this ITypeSymbol typeSymbol) - { - var displayString = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var nameSyntax = ParseName(displayString); - return Visit(nameSyntax); + public static bool HasAnyAttribute(this ISymbol symbol, INamedTypeSymbol[] attributeTypes) => GetAnyAttribute(symbol, attributeTypes) != null; - static NameSyntax Visit(NameSyntax nameSyntax) + public static AttributeData? GetAnyAttribute(this ISymbol symbol, INamedTypeSymbol[] attributeTypes) + { + foreach (var attr in symbol.GetAttributes()) + { + foreach (var t in attributeTypes) { - switch (nameSyntax) + if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, t)) { - case GenericNameSyntax generic: - { - var argCount = generic.TypeArgumentList.Arguments.Count; - return generic.WithTypeArgumentList(TypeArgumentList(SeparatedList(Enumerable.Range(0, argCount).Select(_ => OmittedTypeArgument())))); - } - case AliasQualifiedNameSyntax aliased: - return aliased.WithName((SimpleNameSyntax)Visit(aliased.Name)); - case QualifiedNameSyntax qualified: - return qualified.WithRight((SimpleNameSyntax)Visit(qualified.Right)).WithLeft(Visit(qualified.Left)); - default: - return nameSyntax; + return attr; } } } + return null; + } - public static NameSyntax ToNameSyntax(this ITypeSymbol typeSymbol) => ParseName(typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); - - public static string GetValidIdentifier(this ITypeSymbol type) => type switch - { - INamedTypeSymbol named when !named.IsGenericType => $"{named.Name}", - INamedTypeSymbol named => $"{named.Name}_{string.Join("_", named.TypeArguments.Select(GetValidIdentifier))}", - IArrayTypeSymbol array => $"{GetValidIdentifier(array.ElementType)}_{array.Rank}", - ITypeParameterSymbol parameter => $"{parameter.Name}", - _ => throw new NotSupportedException($"Unable to format type of kind {type.GetType()} with name \"{type.Name}\""), - }; + public static bool HasAttribute(this ISymbol symbol, INamedTypeSymbol attributeType) => GetAttribute(symbol, attributeType) != null; - public static bool HasBaseType(this ITypeSymbol typeSymbol, INamedTypeSymbol baseType) + public static AttributeData? GetAttribute(this ISymbol symbol, INamedTypeSymbol attributeType) + { + foreach (var attr in symbol.GetAttributes()) { - var current = typeSymbol; - for (; current != null; current = current.BaseType) + if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, attributeType)) { - if (SymbolEqualityComparer.Default.Equals(baseType, current)) - { - return true; - } + return attr; } - return false; } - public static bool HasAnyAttribute(this ISymbol symbol, INamedTypeSymbol[] attributeTypes) => GetAnyAttribute(symbol, attributeTypes) != null; + return null; + } - public static AttributeData? GetAnyAttribute(this ISymbol symbol, INamedTypeSymbol[] attributeTypes) + /// + /// Gets all attributes which are assignable to the specified attribute type. + /// + public static bool GetAttributes(this ISymbol symbol, INamedTypeSymbol attributeType, [NotNullWhen(true)] out AttributeData[]? attributes) + { + var result = default(List); + foreach (var attr in symbol.GetAttributes()) { - foreach (var attr in symbol.GetAttributes()) + if (attr.AttributeClass is { } attrClass && !attrClass.HasBaseType(attributeType)) { - foreach (var t in attributeTypes) - { - if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, t)) - { - return attr; - } - } + continue; } - return null; + + if (result is null) + { + result = new List(); + } + + result.Add(attr); } - public static bool HasAttribute(this ISymbol symbol, INamedTypeSymbol attributeType) => GetAttribute(symbol, attributeType) != null; + attributes = result?.ToArray(); + return attributes != null && attributes.Length > 0; + } + + /// + /// Gets all attributes which are assignable to the specified attribute type. + /// + public static bool GetAttributes(this INamedTypeSymbol symbol, INamedTypeSymbol attributeType, [NotNullWhen(true)] out AttributeData[]? attributes, bool inherited = false) + { + var result = default(List); + AddSymbolAttributes(symbol, attributeType, ref result); - public static AttributeData? GetAttribute(this ISymbol symbol, INamedTypeSymbol attributeType) + if (inherited) { - foreach (var attr in symbol.GetAttributes()) + foreach (var iface in symbol.AllInterfaces) { - if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass, attributeType)) - { - return attr; - } + AddSymbolAttributes(iface, attributeType, ref result); } - return null; + var s = symbol; + while ((s = s.BaseType) != null) + { + AddSymbolAttributes(s, attributeType, ref result); + } } - /// - /// Gets all attributes which are assignable to the specified attribute type. - /// - public static bool GetAttributes(this ISymbol symbol, INamedTypeSymbol attributeType, out AttributeData[]? attributes) + attributes = result?.ToArray(); + return attributes != null && attributes.Length > 0; + + static void AddSymbolAttributes(ISymbol symbol, INamedTypeSymbol attributeType, ref List? result) { - var result = default(List); foreach (var attr in symbol.GetAttributes()) { if (attr.AttributeClass is { } attrClass && !attrClass.HasBaseType(attributeType)) @@ -362,252 +406,205 @@ public static bool GetAttributes(this ISymbol symbol, INamedTypeSymbol attribute result.Add(attr); } - - attributes = result?.ToArray(); - return attributes != null && attributes.Length > 0; } + } - /// - /// Gets all attributes which are assignable to the specified attribute type. - /// - public static bool GetAttributes(this INamedTypeSymbol symbol, INamedTypeSymbol attributeType, out AttributeData[]? attributes, bool inherited = false) + public static IEnumerable GetAllMembers(this ITypeSymbol type, string name) where TSymbol : ISymbol + { + foreach (var member in type.GetAllMembers()) { - var result = default(List); - AddSymbolAttributes(symbol, attributeType, ref result); - - if (inherited) + if (!string.Equals(member.Name, name, StringComparison.Ordinal)) { - foreach (var iface in symbol.AllInterfaces) - { - AddSymbolAttributes(iface, attributeType, ref result); - } - - var s = symbol; - while ((s = s.BaseType) != null) - { - AddSymbolAttributes(s, attributeType, ref result); - } + continue; } - attributes = result?.ToArray(); - return attributes != null && attributes.Length > 0; + yield return member; + } + } - static void AddSymbolAttributes(ISymbol symbol, INamedTypeSymbol attributeType, ref List? result) + public static IEnumerable GetAllMembers(this ITypeSymbol type, string name, Accessibility accessibility) where TSymbol : ISymbol + { + foreach (var member in type.GetAllMembers(name)) + { + if (member.DeclaredAccessibility != accessibility) { - foreach (var attr in symbol.GetAttributes()) - { - if (attr.AttributeClass is { } attrClass && !attrClass.HasBaseType(attributeType)) - { - continue; - } + continue; + } - if (result is null) - { - result = new List(); - } + yield return member; + } + } - result.Add(attr); - } - } + public static IEnumerable GetAllMembers(this ITypeSymbol type) where TSymbol : ISymbol + { + var bases = new Stack(); + var b = type.BaseType; + while (b is { }) + { + bases.Push(b); + b = b.BaseType; } - public static IEnumerable GetAllMembers(this ITypeSymbol type, string name) where TSymbol : ISymbol + foreach (var @base in bases) { - foreach (var member in type.GetAllMembers()) + foreach (var member in @base.GetDeclaredInstanceMembers()) { - if (!string.Equals(member.Name, name, StringComparison.Ordinal)) - { - continue; - } - yield return member; } } - public static IEnumerable GetAllMembers(this ITypeSymbol type, string name, Accessibility accessibility) where TSymbol : ISymbol + foreach (var iface in type.AllInterfaces) { - foreach (var member in type.GetAllMembers(name)) + foreach (var member in iface.GetDeclaredInstanceMembers()) { - if (member.DeclaredAccessibility != accessibility) - { - continue; - } - yield return member; } } - public static IEnumerable GetAllMembers(this ITypeSymbol type) where TSymbol : ISymbol + foreach (var member in type.GetDeclaredInstanceMembers()) { - var bases = new Stack(); - var b = type.BaseType; - while (b is { }) - { - bases.Push(b); - b = b.BaseType; - } - - foreach (var @base in bases) - { - foreach (var member in @base.GetDeclaredInstanceMembers()) - { - yield return member; - } - } - - foreach (var iface in type.AllInterfaces) + yield return member; + } + } + + public static IEnumerable GetDeclaredInstanceMembers(this ITypeSymbol type) where TSymbol : ISymbol + { + foreach (var candidate in type.GetMembers()) + { + if (candidate.IsStatic) { - foreach (var member in iface.GetDeclaredInstanceMembers()) - { - yield return member; - } + continue; } - foreach (var member in type.GetDeclaredInstanceMembers()) + if (candidate is TSymbol symbol) { - yield return member; + yield return symbol; } } - - public static IEnumerable GetDeclaredInstanceMembers(this ITypeSymbol type) where TSymbol : ISymbol - { - foreach (var candidate in type.GetMembers()) - { - if (candidate.IsStatic) - { - continue; - } + } - if (candidate is TSymbol symbol) - { - yield return symbol; - } - } - } + public static string GetNamespaceAndNesting(this ISymbol symbol) + { + var result = new StringBuilder(); + Visit(symbol, result); + return result.ToString(); - public static string GetNamespaceAndNesting(this ISymbol symbol) + static void Visit(ISymbol symbol, StringBuilder res) { - var result = new StringBuilder(); - Visit(symbol, result); - return result.ToString(); - - static void Visit(ISymbol symbol, StringBuilder res) + switch (symbol.ContainingSymbol) { - switch (symbol.ContainingSymbol) - { - case INamespaceOrTypeSymbol parent: - Visit(parent, res); + case INamespaceOrTypeSymbol parent: + Visit(parent, res); - if (res is { Length: > 0 }) - { - res.Append('.'); - } + if (res is { Length: > 0 }) + { + res.Append('.'); + } - res.Append(parent.Name); - break; - } + res.Append(parent.Name); + break; } } + } - public static IEnumerable GetAllTypeParameters(this INamedTypeSymbol symbol) + public static IEnumerable GetAllTypeParameters(this INamedTypeSymbol symbol) + { + // Note that this will not work if multiple points in the inheritance hierarchy are containing within a single generic type. + // To solve that, we could retain some context throughout the recursive calls. + if (symbol.ContainingType is { } containingType && containingType.IsGenericType) { - // Note that this will not work if multiple points in the inheritance hierarchy are containing within a single generic type. - // To solve that, we could retain some context throughout the recursive calls. - if (symbol.ContainingType is { } containingType && containingType.IsGenericType) - { - foreach (var containingTypeParameter in containingType.GetAllTypeParameters()) - { - yield return containingTypeParameter; - } - } - - foreach (var tp in symbol.TypeParameters) + foreach (var containingTypeParameter in containingType.GetAllTypeParameters()) { - yield return tp; + yield return containingTypeParameter; } } - public static IEnumerable GetAllTypeArguments(this INamedTypeSymbol symbol) + foreach (var tp in symbol.TypeParameters) { - if (symbol.ContainingType is { } containingType && containingType.IsGenericType) - { - foreach (var containingTypeParameter in containingType.GetAllTypeArguments()) - { - yield return containingTypeParameter; - } - } + yield return tp; + } + } - foreach (var tp in symbol.TypeArguments) + public static IEnumerable GetAllTypeArguments(this INamedTypeSymbol symbol) + { + if (symbol.ContainingType is { } containingType && containingType.IsGenericType) + { + foreach (var containingTypeParameter in containingType.GetAllTypeArguments()) { - yield return tp; + yield return containingTypeParameter; } } - public static bool IsAssignableFrom(this INamedTypeSymbol symbol, INamedTypeSymbol type) + foreach (var tp in symbol.TypeArguments) { - if (symbol.TypeKind == TypeKind.Interface) - { - return IsInterfaceAssignableFromInternal(symbol, type); - } + yield return tp; + } + } - return IsBaseAssignableFromInternal(symbol, type); + public static bool IsAssignableFrom(this INamedTypeSymbol symbol, INamedTypeSymbol type) + { + if (symbol.TypeKind == TypeKind.Interface) + { + return IsInterfaceAssignableFromInternal(symbol, type); } - private static bool IsBaseAssignableFromInternal(this INamedTypeSymbol symbol, INamedTypeSymbol? type) + return IsBaseAssignableFromInternal(symbol, type); + } + + private static bool IsBaseAssignableFromInternal(this INamedTypeSymbol symbol, INamedTypeSymbol? type) + { + if (type is null) return false; + if (SymbolEqualityComparer.Default.Equals(symbol, type)) { - if (type is null) return false; - if (SymbolEqualityComparer.Default.Equals(symbol, type)) - { - return true; - } + return true; + } + + return IsBaseAssignableFromInternal(symbol, type.BaseType); + } - return IsBaseAssignableFromInternal(symbol, type.BaseType); + private static bool IsInterfaceAssignableFromInternal(this INamedTypeSymbol iface, INamedTypeSymbol type) + { + if (SymbolEqualityComparer.Default.Equals(iface, type)) + { + return true; } - private static bool IsInterfaceAssignableFromInternal(this INamedTypeSymbol iface, INamedTypeSymbol type) + foreach (var typeInterface in type.AllInterfaces) { - if (SymbolEqualityComparer.Default.Equals(iface, type)) + if (SymbolEqualityComparer.Default.Equals(iface, typeInterface)) { return true; } - - foreach (var typeInterface in type.AllInterfaces) - { - if (SymbolEqualityComparer.Default.Equals(iface, typeInterface)) - { - return true; - } - } - - return false; } - public static IEnumerable GetDeclaredTypes(this IAssemblySymbol reference) + return false; + } + + public static IEnumerable GetDeclaredTypes(this IAssemblySymbol reference) + { + foreach (var module in reference.Modules) { - foreach (var module in reference.Modules) + foreach (var type in GetDeclaredTypes(module.GlobalNamespace)) { - foreach (var type in GetDeclaredTypes(module.GlobalNamespace)) - { - yield return type; - } + yield return type; } + } - IEnumerable GetDeclaredTypes(INamespaceOrTypeSymbol ns) + IEnumerable GetDeclaredTypes(INamespaceOrTypeSymbol ns) + { + foreach (var member in ns.GetMembers()) { - foreach (var member in ns.GetMembers()) + switch (member) { - switch (member) - { - case INamespaceSymbol nestedNamespace: - foreach (var nested in GetDeclaredTypes(nestedNamespace)) yield return nested; - break; - case ITypeSymbol type: - if (type is INamedTypeSymbol namedType) yield return namedType; - foreach (var nested in GetDeclaredTypes(type)) yield return nested; - break; - } + case INamespaceSymbol nestedNamespace: + foreach (var nested in GetDeclaredTypes(nestedNamespace)) yield return nested; + break; + case ITypeSymbol type: + if (type is INamedTypeSymbol namedType) yield return namedType; + foreach (var nested in GetDeclaredTypes(type)) yield return nested; + break; } } } } -} \ No newline at end of file +} diff --git a/src/Orleans.CodeGenerator/SyntaxGeneration/SymbolSyntaxExtensions.cs b/src/Orleans.CodeGenerator/SyntaxGeneration/SymbolSyntaxExtensions.cs index 4f711ddc5ab..3362e29668d 100644 --- a/src/Orleans.CodeGenerator/SyntaxGeneration/SymbolSyntaxExtensions.cs +++ b/src/Orleans.CodeGenerator/SyntaxGeneration/SymbolSyntaxExtensions.cs @@ -1,133 +1,131 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -using System; using System.Reflection; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; using System.Diagnostics; using Microsoft.CodeAnalysis; -#nullable disable -namespace Orleans.CodeGenerator.SyntaxGeneration +namespace Orleans.CodeGenerator.SyntaxGeneration; + +internal static class SymbolSyntaxExtensions { - internal static class SymbolSyntaxExtensions + public static ParenthesizedExpressionSyntax GetBindingFlagsParenthesizedExpressionSyntax( + SyntaxKind operationKind, + params BindingFlags[] bindingFlags) { - public static ParenthesizedExpressionSyntax GetBindingFlagsParenthesizedExpressionSyntax( - SyntaxKind operationKind, - params BindingFlags[] bindingFlags) + if (bindingFlags.Length < 2) { - if (bindingFlags.Length < 2) - { - throw new ArgumentOutOfRangeException( - nameof(bindingFlags), - $"Can't create parenthesized binary expression with {bindingFlags.Length} arguments"); - } + throw new ArgumentOutOfRangeException( + nameof(bindingFlags), + $"Can't create parenthesized binary expression with {bindingFlags.Length} arguments"); + } - var flags = AliasQualifiedName("global", IdentifierName("System")).Member("Reflection").Member("BindingFlags"); - var bindingFlagsBinaryExpression = BinaryExpression( + var flags = AliasQualifiedName("global", IdentifierName("System")).Member("Reflection").Member("BindingFlags"); + var bindingFlagsBinaryExpression = BinaryExpression( + operationKind, + flags.Member(bindingFlags[0].ToString()), + flags.Member(bindingFlags[1].ToString())); + for (var i = 2; i < bindingFlags.Length; i++) + { + bindingFlagsBinaryExpression = BinaryExpression( operationKind, - flags.Member(bindingFlags[0].ToString()), - flags.Member(bindingFlags[1].ToString())); - for (var i = 2; i < bindingFlags.Length; i++) - { - bindingFlagsBinaryExpression = BinaryExpression( - operationKind, - bindingFlagsBinaryExpression, - flags.Member(bindingFlags[i].ToString())); - } - - return ParenthesizedExpression(bindingFlagsBinaryExpression); + bindingFlagsBinaryExpression, + flags.Member(bindingFlags[i].ToString())); } - /// - /// Returns the System.String that represents the current TypedConstant. - /// - /// A System.String that represents the current TypedConstant. - public static ExpressionSyntax ToExpression(this TypedConstant constant) - { - if (constant.IsNull) - { - return LiteralExpression(SyntaxKind.NullLiteralExpression); - } + return ParenthesizedExpression(bindingFlagsBinaryExpression); + } - if (constant.Kind == TypedConstantKind.Array) - { - throw new NotSupportedException($"Unsupported TypedConstant: {constant.ToCSharpString()}"); - } + /// + /// Returns the System.String that represents the current TypedConstant. + /// + /// A System.String that represents the current TypedConstant. + public static ExpressionSyntax ToExpression(this TypedConstant constant) + { + if (constant.IsNull) + { + return LiteralExpression(SyntaxKind.NullLiteralExpression); + } - if (constant.Kind == TypedConstantKind.Type) - { - Debug.Assert(constant.Value is not null); - return TypeOfExpression(((ITypeSymbol)constant.Value).ToTypeSyntax()); - } + if (constant.Kind == TypedConstantKind.Array) + { + throw new NotSupportedException($"Unsupported TypedConstant: {constant.ToCSharpString()}"); + } - if (constant.Kind == TypedConstantKind.Enum) - { - return DisplayEnumConstant(constant); - } + if (constant.Kind == TypedConstantKind.Type) + { + Debug.Assert(constant.Value is ITypeSymbol); + return TypeOfExpression(((ITypeSymbol)constant.Value!).ToTypeSyntax()); + } - return ParseExpression(constant.ToCSharpString()); + if (constant.Kind == TypedConstantKind.Enum) + { + return DisplayEnumConstant(constant); } - // Decode the value of enum constant - private static ExpressionSyntax DisplayEnumConstant(TypedConstant constant) + return ParseExpression(constant.ToCSharpString()); + } + + // Decode the value of enum constant + private static ExpressionSyntax DisplayEnumConstant(TypedConstant constant) + { + //string typeName = constant.Type.ToDisplayName(); + var constantToDecode = ConvertToUInt64(constant.Value); + ulong curValue = 0; + + // Iterate through all the constant members in the enum type + var members = constant.Type!.GetMembers(); + var type = constant.Type.ToTypeSyntax(); + ExpressionSyntax? result = null; + foreach (var member in members) { - //string typeName = constant.Type.ToDisplayName(); - var constantToDecode = ConvertToUInt64(constant.Value); - ulong curValue = 0; - - // Iterate through all the constant members in the enum type - var members = constant.Type!.GetMembers(); - var type = constant.Type.ToTypeSyntax(); - ExpressionSyntax result = null; - foreach (var member in members) + var field = member as IFieldSymbol; + + if (field is object && field.HasConstantValue) { - var field = member as IFieldSymbol; + ulong memberValue = ConvertToUInt64(field.ConstantValue); - if (field is object && field.HasConstantValue) + if (memberValue == constantToDecode) { - ulong memberValue = ConvertToUInt64(field.ConstantValue); + return constant.Type.ToTypeSyntax().Member(field.Name); + } + + if ((memberValue & constantToDecode) == memberValue) + { + // update the current value + curValue = curValue | memberValue; - if (memberValue == constantToDecode) + var valueExpression = type.Member(field.Name); + if (result is null) { - return constant.Type.ToTypeSyntax().Member(field.Name); + result = valueExpression; } - - if ((memberValue & constantToDecode) == memberValue) + else { - // update the current value - curValue = curValue | memberValue; - - var valueExpression = type.Member(field.Name); - if (result is null) - { - result = valueExpression; - } - else - { - result = BinaryExpression(SyntaxKind.BitwiseOrExpression, result, valueExpression); - } + result = BinaryExpression(SyntaxKind.BitwiseOrExpression, result, valueExpression); } } } - - return result; } - private static ulong ConvertToUInt64(object value) + Debug.Assert(result is not null); + return result!; + } + + private static ulong ConvertToUInt64(object? value) + { + return value switch { - return value switch - { - byte b => b, - sbyte sb => (ulong)sb, - short s => (ulong)s, - ushort us => us, - int i => (ulong)i, - uint ui => ui, - long l => (ulong)l, - ulong ul => ul, - _ => throw new NotSupportedException($"Type {value?.GetType()} not supported") - }; - } + byte b => b, + sbyte sb => (ulong)sb, + short s => (ulong)s, + ushort us => us, + int i => (ulong)i, + uint ui => ui, + long l => (ulong)l, + ulong ul => ul, + _ => throw new NotSupportedException($"Type {value?.GetType()} not supported") + }; } } diff --git a/src/Orleans.CodeGenerator/SyntaxGeneration/SyntaxFactoryUtility.cs b/src/Orleans.CodeGenerator/SyntaxGeneration/SyntaxFactoryUtility.cs index 81ed413015c..0ef64a1e8c4 100644 --- a/src/Orleans.CodeGenerator/SyntaxGeneration/SyntaxFactoryUtility.cs +++ b/src/Orleans.CodeGenerator/SyntaxGeneration/SyntaxFactoryUtility.cs @@ -1,114 +1,111 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -using System.Collections.Generic; -using System.Linq; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -namespace Orleans.CodeGenerator.SyntaxGeneration +namespace Orleans.CodeGenerator.SyntaxGeneration; + +internal static class SyntaxFactoryUtility { - internal static class SyntaxFactoryUtility - { - /// - /// Returns member access syntax. - /// - /// - /// The instance. - /// - /// - /// The member. - /// - /// - /// The resulting . - /// - public static MemberAccessExpressionSyntax Member(this ExpressionSyntax instance, string member) => instance.Member(member.ToIdentifierName()); + /// + /// Returns member access syntax. + /// + /// + /// The instance. + /// + /// + /// The member. + /// + /// + /// The resulting . + /// + public static MemberAccessExpressionSyntax Member(this ExpressionSyntax instance, string member) => instance.Member(member.ToIdentifierName()); - /// - /// Returns member access syntax. - /// - /// - /// The instance. - /// - /// - /// The member. - /// - /// - /// The resulting . - /// - public static MemberAccessExpressionSyntax Member(this ExpressionSyntax instance, IdentifierNameSyntax member) => MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, instance, member); + /// + /// Returns member access syntax. + /// + /// + /// The instance. + /// + /// + /// The member. + /// + /// + /// The resulting . + /// + public static MemberAccessExpressionSyntax Member(this ExpressionSyntax instance, IdentifierNameSyntax member) => MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, instance, member); - public static MemberAccessExpressionSyntax Member(this ExpressionSyntax instance, GenericNameSyntax member) => MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, instance, member); + public static MemberAccessExpressionSyntax Member(this ExpressionSyntax instance, GenericNameSyntax member) => MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, instance, member); - public static MemberAccessExpressionSyntax Member( - this ExpressionSyntax instance, - string member, - params TypeSyntax[] genericTypes) => instance.Member( - member.ToGenericName() - .AddTypeArgumentListArguments(genericTypes)); + public static MemberAccessExpressionSyntax Member( + this ExpressionSyntax instance, + string member, + params TypeSyntax[] genericTypes) => instance.Member( + member.ToGenericName() + .AddTypeArgumentListArguments(genericTypes)); - public static GenericNameSyntax ToGenericName(this string identifier) => GenericName(identifier.ToIdentifier()); + public static GenericNameSyntax ToGenericName(this string identifier) => GenericName(identifier.ToIdentifier()); - public static ClassDeclarationSyntax AddGenericTypeParameters( - ClassDeclarationSyntax classDeclaration, - List<(string Name, ITypeParameterSymbol Parameter)> typeParameters) + public static ClassDeclarationSyntax AddGenericTypeParameters( + ClassDeclarationSyntax classDeclaration, + List<(string Name, ITypeParameterSymbol Parameter)> typeParameters) + { + var typeParametersWithConstraints = GetTypeParameterConstraints(typeParameters); + foreach (var (name, constraints) in typeParametersWithConstraints) { - var typeParametersWithConstraints = GetTypeParameterConstraints(typeParameters); - foreach (var (name, constraints) in typeParametersWithConstraints) + if (constraints.Count > 0) { - if (constraints.Count > 0) - { - classDeclaration = classDeclaration.AddConstraintClauses( - TypeParameterConstraintClause(name).AddConstraints(constraints.ToArray())); - } + classDeclaration = classDeclaration.AddConstraintClauses( + TypeParameterConstraintClause(name).AddConstraints([.. constraints])); } - - if (typeParametersWithConstraints.Count > 0) - { - classDeclaration = classDeclaration.WithTypeParameterList( - TypeParameterList(SeparatedList(typeParametersWithConstraints.Select(tp => TypeParameter(tp.Name))))); - } - - return classDeclaration; } - public static List<(string Name, List Constraints)> GetTypeParameterConstraints(List<(string Name, ITypeParameterSymbol Parameter)> typeParameter) + if (typeParametersWithConstraints.Count > 0) { - var allConstraints = new List<(string, List)>(); - foreach (var (name, tp) in typeParameter) - { - var constraints = new List(); + classDeclaration = classDeclaration.WithTypeParameterList( + TypeParameterList(SeparatedList(typeParametersWithConstraints.Select(tp => TypeParameter(tp.Name))))); + } - if (tp.HasUnmanagedTypeConstraint) - { - constraints.Add(TypeConstraint(IdentifierName("unmanaged"))); - } - else if (tp.HasValueTypeConstraint) - { - constraints.Add(ClassOrStructConstraint(SyntaxKind.StructConstraint)); - } - else if (tp.HasNotNullConstraint) - { - constraints.Add(TypeConstraint(IdentifierName("notnull"))); - } - else if (tp.HasReferenceTypeConstraint) - { - constraints.Add(ClassOrStructConstraint(SyntaxKind.ClassConstraint)); - } + return classDeclaration; + } - foreach (var c in tp.ConstraintTypes) - { - constraints.Add(TypeConstraint(c.ToTypeSyntax())); - } + public static List<(string Name, List Constraints)> GetTypeParameterConstraints(List<(string Name, ITypeParameterSymbol Parameter)> typeParameter) + { + var allConstraints = new List<(string, List)>(); + foreach (var (name, tp) in typeParameter) + { + var constraints = new List(); - if (tp.HasConstructorConstraint) - { - constraints.Add(ConstructorConstraint()); - } + if (tp.HasUnmanagedTypeConstraint) + { + constraints.Add(TypeConstraint(IdentifierName("unmanaged"))); + } + else if (tp.HasValueTypeConstraint) + { + constraints.Add(ClassOrStructConstraint(SyntaxKind.StructConstraint)); + } + else if (tp.HasNotNullConstraint) + { + constraints.Add(TypeConstraint(IdentifierName("notnull"))); + } + else if (tp.HasReferenceTypeConstraint) + { + constraints.Add(ClassOrStructConstraint(SyntaxKind.ClassConstraint)); + } + + foreach (var c in tp.ConstraintTypes) + { + constraints.Add(TypeConstraint(c.ToTypeSyntax())); + } - allConstraints.Add((name, constraints)); + if (tp.HasConstructorConstraint) + { + constraints.Add(ConstructorConstraint()); } - return allConstraints; + allConstraints.Add((name, constraints)); } + + return allConstraints; } } \ No newline at end of file diff --git a/src/Orleans.CodeGenerator/TypeSymbolResolver.cs b/src/Orleans.CodeGenerator/TypeSymbolResolver.cs new file mode 100644 index 00000000000..4efe8d93208 --- /dev/null +++ b/src/Orleans.CodeGenerator/TypeSymbolResolver.cs @@ -0,0 +1,350 @@ +using System.Diagnostics.CodeAnalysis; +using Microsoft.CodeAnalysis; +using Orleans.CodeGenerator.Model; +using Orleans.CodeGenerator.SyntaxGeneration; + +namespace Orleans.CodeGenerator; + +internal sealed class TypeSymbolResolver(Compilation compilation) +{ + private readonly Compilation _compilation = compilation; + private FallbackIndex? _fallbackIndex; + + public bool TryResolveSerializableType( + SerializableTypeModel model, + CancellationToken cancellationToken, + [NotNullWhen(true)] out INamedTypeSymbol? symbol) + { + if (model is null) + { + symbol = null; + return false; + } + + cancellationToken.ThrowIfCancellationRequested(); + if (TryResolveMetadataIdentity(model.MetadataIdentity, cancellationToken, out symbol) + || TryResolveTypeSyntax(model.TypeSyntax.SyntaxString, cancellationToken, out symbol)) + { + return true; + } + + foreach (var candidate in GetFallbackIndex(cancellationToken).AllTypes) + { + cancellationToken.ThrowIfCancellationRequested(); + if (string.Equals(candidate.Name, model.Name, StringComparison.Ordinal) + && string.Equals(candidate.GetNamespaceAndNesting(), model.Namespace, StringComparison.Ordinal) + && candidate.GetAllTypeParameters().Count() == model.TypeParameters.Length) + { + symbol = candidate; + return true; + } + } + + symbol = null; + return false; + } + + public bool TryResolveProxyInterface( + ProxyInterfaceModel model, + CancellationToken cancellationToken, + [NotNullWhen(true)] out INamedTypeSymbol? symbol) + { + if (model is null) + { + symbol = null; + return false; + } + + cancellationToken.ThrowIfCancellationRequested(); + if (TryResolveMetadataIdentity(model.MetadataIdentity, cancellationToken, out symbol) + || TryResolveTypeSyntax(model.InterfaceType.SyntaxString, cancellationToken, out symbol)) + { + return symbol.TypeKind == TypeKind.Interface; + } + + foreach (var candidate in GetFallbackIndex(cancellationToken).AllTypes) + { + cancellationToken.ThrowIfCancellationRequested(); + if (candidate.TypeKind == TypeKind.Interface + && string.Equals(candidate.Name, model.Name, StringComparison.Ordinal) + && string.Equals(candidate.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), model.InterfaceType.SyntaxString, StringComparison.Ordinal)) + { + symbol = candidate; + return true; + } + } + + symbol = null; + return false; + } + + private bool TryResolveMetadataIdentity( + TypeMetadataIdentity metadataIdentity, + CancellationToken cancellationToken, + [NotNullWhen(true)] out INamedTypeSymbol? symbol) + { + if (metadataIdentity.IsEmpty) + { + symbol = null; + return false; + } + + if (!string.IsNullOrEmpty(metadataIdentity.AssemblyIdentity) + || !string.IsNullOrEmpty(metadataIdentity.AssemblyName)) + { + if (TryGetAssembly(metadataIdentity, cancellationToken, out var assembly)) + { + symbol = assembly.GetTypeByMetadataName(metadataIdentity.MetadataName); + return symbol is not null; + } + + symbol = null; + return false; + } + + return TryResolveMetadataName(metadataIdentity.MetadataName, out symbol); + } + + private bool TryGetAssembly( + TypeMetadataIdentity metadataIdentity, + CancellationToken cancellationToken, + [NotNullWhen(true)] out IAssemblySymbol? assembly) + { + if (IsMatchingAssembly(_compilation.Assembly, metadataIdentity)) + { + assembly = _compilation.Assembly; + return true; + } + + IAssemblySymbol? assemblyByName = null; + foreach (var reference in _compilation.References) + { + cancellationToken.ThrowIfCancellationRequested(); + if (_compilation.GetAssemblyOrModuleSymbol(reference) is not IAssemblySymbol candidate) + { + continue; + } + + if (IsMatchingAssembly(candidate, metadataIdentity)) + { + assembly = candidate; + return true; + } + + if (string.IsNullOrEmpty(metadataIdentity.AssemblyIdentity) + && !string.IsNullOrEmpty(metadataIdentity.AssemblyName) + && string.Equals(candidate.Identity.Name, metadataIdentity.AssemblyName, StringComparison.Ordinal)) + { + if (assemblyByName is not null) + { + assembly = null; + return false; + } + + assemblyByName = candidate; + } + } + + if (assemblyByName is not null) + { + assembly = assemblyByName; + return true; + } + + assembly = null; + return false; + } + + private static bool IsMatchingAssembly(IAssemblySymbol assembly, TypeMetadataIdentity metadataIdentity) + { + if (!string.IsNullOrEmpty(metadataIdentity.AssemblyIdentity)) + { + return string.Equals(assembly.Identity.GetDisplayName(), metadataIdentity.AssemblyIdentity, StringComparison.Ordinal); + } + + return !string.IsNullOrEmpty(metadataIdentity.AssemblyName) + && string.Equals(assembly.Identity.Name, metadataIdentity.AssemblyName, StringComparison.Ordinal); + } + + private bool TryResolveTypeSyntax( + string typeSyntax, + CancellationToken cancellationToken, + [NotNullWhen(true)] out INamedTypeSymbol? symbol) + { + if (string.IsNullOrWhiteSpace(typeSyntax)) + { + symbol = null; + return false; + } + + if (TryGetMetadataName(typeSyntax, allowGenericSyntax: false, out var metadataName) + && TryResolveMetadataName(metadataName, out symbol)) + { + return true; + } + + var fallbackIndex = GetFallbackIndex(cancellationToken); + if (fallbackIndex.TypesByKey.TryGetValue(NormalizeTypeKey(typeSyntax), out symbol)) + { + return true; + } + + return TryGetMetadataName(typeSyntax, allowGenericSyntax: true, out metadataName) + && TryResolveMetadataName(metadataName, out symbol); + } + + private bool TryResolveMetadataName(string metadataName, [NotNullWhen(true)] out INamedTypeSymbol? symbol) + { + symbol = _compilation.GetTypeByMetadataName(metadataName); + if (symbol is null && TryGetSpecialType(metadataName, out var specialType)) + { + symbol = _compilation.GetSpecialType(specialType); + } + + return symbol is not null; + } + + private static bool TryGetMetadataName(string typeSyntax, bool allowGenericSyntax, [NotNullWhen(true)] out string? metadataName) + { + metadataName = typeSyntax.Trim(); + if (metadataName.StartsWith("global::", StringComparison.Ordinal)) + { + metadataName = metadataName.Substring("global::".Length); + } + + var genericStart = metadataName.IndexOf('<'); + if (genericStart >= 0) + { + if (!allowGenericSyntax) + { + metadataName = null; + return false; + } + + metadataName = metadataName.Substring(0, genericStart); + } + + metadataName = metadataName.Trim(); + if (metadataName.StartsWith("global::", StringComparison.Ordinal)) + { + metadataName = metadataName.Substring("global::".Length); + } + + metadataName = metadataName switch + { + "bool" => "System.Boolean", + "byte" => "System.Byte", + "sbyte" => "System.SByte", + "short" => "System.Int16", + "ushort" => "System.UInt16", + "int" => "System.Int32", + "uint" => "System.UInt32", + "long" => "System.Int64", + "ulong" => "System.UInt64", + "float" => "System.Single", + "double" => "System.Double", + "decimal" => "System.Decimal", + "char" => "System.Char", + "string" => "System.String", + "object" => "System.Object", + _ => metadataName, + }; + + return !string.IsNullOrWhiteSpace(metadataName); + } + + private static bool TryGetSpecialType(string metadataName, out SpecialType specialType) + { + specialType = metadataName switch + { + "System.Boolean" => SpecialType.System_Boolean, + "System.Byte" => SpecialType.System_Byte, + "System.SByte" => SpecialType.System_SByte, + "System.Int16" => SpecialType.System_Int16, + "System.UInt16" => SpecialType.System_UInt16, + "System.Int32" => SpecialType.System_Int32, + "System.UInt32" => SpecialType.System_UInt32, + "System.Int64" => SpecialType.System_Int64, + "System.UInt64" => SpecialType.System_UInt64, + "System.Single" => SpecialType.System_Single, + "System.Double" => SpecialType.System_Double, + "System.Decimal" => SpecialType.System_Decimal, + "System.Char" => SpecialType.System_Char, + "System.String" => SpecialType.System_String, + "System.Object" => SpecialType.System_Object, + _ => SpecialType.None, + }; + + return specialType != SpecialType.None; + } + + private FallbackIndex GetFallbackIndex(CancellationToken cancellationToken) + { + if (_fallbackIndex is { } fallbackIndex) + { + return fallbackIndex; + } + + fallbackIndex = BuildFallbackIndex(cancellationToken); + _fallbackIndex = fallbackIndex; + return fallbackIndex; + } + + private FallbackIndex BuildFallbackIndex(CancellationToken cancellationToken) + { + var typesByKey = new Dictionary(StringComparer.Ordinal); + var allTypes = new List(); + AddAssembly(_compilation.Assembly); + + foreach (var reference in _compilation.References) + { + cancellationToken.ThrowIfCancellationRequested(); + if (_compilation.GetAssemblyOrModuleSymbol(reference) is IAssemblySymbol assembly) + { + AddAssembly(assembly); + } + } + + return new FallbackIndex(typesByKey, allTypes); + + void AddAssembly(IAssemblySymbol assembly) + { + foreach (var type in assembly.GetDeclaredTypes()) + { + cancellationToken.ThrowIfCancellationRequested(); + AddType(type); + } + } + + void AddType(INamedTypeSymbol type) + { + allTypes.Add(type); + AddKey(type.ToOpenTypeSyntax().ToString(), type); + AddKey(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), type); + AddKey(type.ToDisplayString(), type); + } + + void AddKey(string key, INamedTypeSymbol type) + { + if (string.IsNullOrWhiteSpace(key)) + { + return; + } + + var normalizedKey = NormalizeTypeKey(key); + if (!typesByKey.TryGetValue(normalizedKey, out _)) + { + typesByKey.Add(normalizedKey, type); + } + } + } + + private sealed class FallbackIndex(Dictionary typesByKey, List allTypes) + { + public Dictionary TypesByKey { get; } = typesByKey; + public List AllTypes { get; } = allTypes; + } + + private static string NormalizeTypeKey(string value) + => string.Concat(value.Where(static character => !char.IsWhiteSpace(character))); +} diff --git a/src/api/Orleans.Transactions/Orleans.Transactions.cs b/src/api/Orleans.Transactions/Orleans.Transactions.cs index dd303e2c383..12b373cd2ea 100644 --- a/src/api/Orleans.Transactions/Orleans.Transactions.cs +++ b/src/api/Orleans.Transactions/Orleans.Transactions.cs @@ -2106,4 +2106,4 @@ public Copier_OperationState(global::Orleans.Serialization.Serializers.ICodecPro public global::Orleans.Transactions.TransactionCommitter.OperationState DeepCopy(global::Orleans.Transactions.TransactionCommitter.OperationState original, global::Orleans.Serialization.Cloning.CopyContext context) { throw null; } } -} \ No newline at end of file +} diff --git a/test/Grains/TestGrainInterfaces/IErrorGrain.cs b/test/Grains/TestGrainInterfaces/IErrorGrain.cs index e72c67c0e82..9042c896662 100644 --- a/test/Grains/TestGrainInterfaces/IErrorGrain.cs +++ b/test/Grains/TestGrainInterfaces/IErrorGrain.cs @@ -1,4 +1,4 @@ -namespace UnitTests.GrainInterfaces +namespace UnitTests.GrainInterfaces { public interface IErrorGrain : ISimpleGrain { diff --git a/test/Orleans.CodeGenerator.Tests/DiagnosticTests.cs b/test/Orleans.CodeGenerator.Tests/DiagnosticTests.cs new file mode 100644 index 00000000000..ddfef128c40 --- /dev/null +++ b/test/Orleans.CodeGenerator.Tests/DiagnosticTests.cs @@ -0,0 +1,475 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; +using Orleans.CodeGenerator.Diagnostics; + +namespace Orleans.CodeGenerator.Tests; + +/// +/// Tests that verify the Orleans source generator emits correct diagnostics +/// for invalid or unsupported code patterns. +/// +public class DiagnosticTests +{ + [Fact] + public async Task GenerateSerializer_OnAccessibleType_ProducesOutput() + { + // Verifying that accessible types produce expected output without diagnostics. + // The inaccessibility diagnostic (ORLEANS0107) is tested implicitly via cross-assembly + // scenarios in the BVT suite; it requires complex assembly setup not easily done in a unit test. + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public class PublicDto + { + [Id(0)] + public string Name { get; set; } + } + + [GenerateSerializer] + internal class InternalDto + { + [Id(0)] + public int Value { get; set; } + } + """; + + var result = await RunGenerator(code); + + Assert.Empty(result.Diagnostics); + Assert.NotEmpty(result.GeneratedSources); + + var generatedSource = ConcatenateGeneratedSources(result); + Assert.Contains("PublicDto", generatedSource); + Assert.Contains("InternalDto", generatedSource); + } + + [Fact] + public async Task ReferenceAssembly_WithGenerateSerializer_EmitsWarning() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public class RefAsmType + { + [Id(0)] + public string Value { get; set; } = string.Empty; + } + """; + + var compilation = await CreateCompilation(code); + + // Add ReferenceAssemblyAttribute to the assembly + var referenceAssemblyAttribute = SyntaxFactory.Attribute( + SyntaxFactory.ParseName("System.Runtime.CompilerServices.ReferenceAssemblyAttribute")); + var assemblyAttr = SyntaxFactory.AttributeList( + SyntaxFactory.SingletonSeparatedList(referenceAssemblyAttribute)) + .WithTarget(SyntaxFactory.AttributeTargetSpecifier(SyntaxFactory.Token(SyntaxKind.AssemblyKeyword))); + var root = (CompilationUnitSyntax)compilation.SyntaxTrees[0].GetRoot(); + var newRoot = root.AddAttributeLists(assemblyAttr); + var newTree = compilation.SyntaxTrees[0].WithRootAndOptions(newRoot, compilation.SyntaxTrees[0].Options); + compilation = compilation.RemoveSyntaxTrees(compilation.SyntaxTrees[0]).AddSyntaxTrees(newTree); + + var result = RunGeneratorOnCompilation(compilation); + + Assert.Contains(result.Diagnostics, + d => d.Id == DiagnosticRuleId.ReferenceAssemblyWithGenerateSerializer); + } + + [Fact] + public async Task RpcInterfaceWithProperty_EmitsDiagnostic() + { + const string code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public interface IMyGrain : IGrainWithIntegerKey + { + string Name { get; set; } + Task SayHello(string name); + } + """; + + var result = await RunGenerator(code); + + Assert.Contains(result.Diagnostics, + d => d.Id == DiagnosticRuleId.RpcInterfaceProperty); + } + + [Fact] + public async Task InvalidRpcMethodReturnType_EmitsDiagnostic() + { + const string code = """ + using Orleans; + + namespace TestProject; + + public interface IMyGrain : IGrainWithIntegerKey + { + string SayHello(string name); + } + """; + + var result = await RunGenerator(code); + + Assert.Contains(result.Diagnostics, + d => d.Id == DiagnosticRuleId.InvalidRpcMethodReturnType); + } + + [Fact] + public async Task GenerateSerializer_WithInaccessibleSetter_EmitsDiagnostic() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public class MyDto + { + [Id(0)] + public string Name => "computed"; + } + """; + + var result = await RunGenerator(code); + + Assert.Contains(result.Diagnostics, + d => d.Id == DiagnosticRuleId.InaccessibleSetter); + } + + [Fact] + public async Task IncorrectProxyBaseClassSpecification_EmitsDiagnostic() + { + const string code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public class InvalidProxyBase + { + } + + [GenerateMethodSerializers(typeof(InvalidProxyBase))] + public interface IMyGrain : IGrainWithIntegerKey + { + Task Ping(); + } + """; + + var result = await RunGenerator(code); + + Assert.Contains(result.Diagnostics, + d => d.Id == DiagnosticRuleId.IncorrectProxyBaseClassSpecification); + } + + [Fact] + public async Task ValidSerializableType_EmitsNoDiagnostics() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public class ValidDto + { + [Id(0)] + public string Name { get; set; } + + [Id(1)] + public int Value { get; set; } + } + """; + + var result = await RunGenerator(code); + + Assert.Empty(result.Diagnostics); + Assert.NotEmpty(result.GeneratedSources); + } + + [Fact] + public async Task ValidGrainInterface_EmitsNoDiagnostics() + { + const string code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public interface IMyGrain : IGrainWithIntegerKey + { + Task SayHello(string name); + Task DoWork(); + } + """; + + var result = await RunGenerator(code); + + Assert.Empty(result.Diagnostics); + Assert.NotEmpty(result.GeneratedSources); + } + + [Fact] + public async Task SerializableRecord_EmitsNoDiagnostics() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public record RecordDto + { + [Id(0)] + public string Name { get; init; } + + [Id(1)] + public int Value { get; init; } + } + """; + + var result = await RunGenerator(code); + + Assert.Empty(result.Diagnostics); + Assert.NotEmpty(result.GeneratedSources); + } + + [Fact] + public async Task SerializableStruct_EmitsNoDiagnostics() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public struct StructDto + { + [Id(0)] + public string Name { get; set; } + + [Id(1)] + public int Value { get; set; } + } + """; + + var result = await RunGenerator(code); + + Assert.Empty(result.Diagnostics); + Assert.NotEmpty(result.GeneratedSources); + } + + [Fact] + public async Task GenerateSerializerOnEnum_EmitsNoDiagnostics() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public enum MyEnum + { + None = 0, + Value1 = 1, + Value2 = 2, + } + """; + + var result = await RunGenerator(code); + + Assert.Empty(result.Diagnostics); + Assert.NotEmpty(result.GeneratedSources); + } + + [Fact] + public async Task GenerateSerializerOnAbstractClass_EmitsNoDiagnostics() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public abstract class AbstractDto + { + [Id(0)] + public string Name { get; set; } + } + + [GenerateSerializer] + public class ConcreteDto : AbstractDto + { + [Id(1)] + public int Value { get; set; } + } + """; + + var result = await RunGenerator(code); + + Assert.Empty(result.Diagnostics); + Assert.NotEmpty(result.GeneratedSources); + } + + [Fact] + public async Task CompilationErrors_DoNotPreventSourceGeneration() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public class DtoWithMissingType + { + [Id(0)] + public MissingNamespace.MissingType Value { get; set; } + } + """; + + var compilation = await CreateCompilation(code); + Assert.Contains(compilation.GetDiagnostics(), diagnostic => diagnostic.Severity == DiagnosticSeverity.Error); + + var result = RunGeneratorOnCompilation(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Contains("Codec_DtoWithMissingType", ConcatenateGeneratedSources(result)); + } + + [Fact] + public async Task EmptyCompilation_ProducesNoOutput() + { + const string code = """ + namespace TestProject; + + public class PlainClass + { + public string Name { get; set; } + } + """; + + var result = await RunGenerator(code); + + // No Orleans attributes → no generated output and no diagnostics + Assert.Empty(result.Diagnostics); + + // May still produce metadata output; verify no serializer-specific outputs + var serializerSources = result.GeneratedSources + .Where(s => s.HintName.Contains("serializer", StringComparison.OrdinalIgnoreCase) + || s.HintName.Contains("copier", StringComparison.OrdinalIgnoreCase) + || s.HintName.Contains("activator", StringComparison.OrdinalIgnoreCase)) + .ToList(); + Assert.Empty(serializerSources); + } + + [Fact] + public async Task MultipleGrainInterfaces_AllProduceProxies() + { + const string code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public interface IGrainA : IGrainWithIntegerKey + { + Task MethodA(); + } + + public interface IGrainB : IGrainWithGuidKey + { + Task MethodB(string input); + } + + public interface IGrainC : IGrainWithStringKey + { + Task MethodC(int a, int b); + } + """; + + var result = await RunGenerator(code); + + Assert.Empty(result.Diagnostics); + Assert.NotEmpty(result.GeneratedSources); + + var generatedSource = ConcatenateGeneratedSources(result); + Assert.Contains("IGrainA", generatedSource); + Assert.Contains("IGrainB", generatedSource); + Assert.Contains("IGrainC", generatedSource); + } + + #region Helpers + + private static async Task RunGenerator(string code) + { + var compilation = await CreateCompilation(code); + return RunGeneratorOnCompilation(compilation); + } + + private static GeneratorRunResult RunGeneratorOnCompilation( + CSharpCompilation compilation, + IReadOnlyDictionary? globalOptions = null) + { + AnalyzerConfigOptionsProvider? optionsProvider = globalOptions is null + ? null + : new TestAnalyzerConfigOptionsProvider(globalOptions); + + var generator = new OrleansSerializationSourceGenerator().AsSourceGenerator(); + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: [generator], + optionsProvider: optionsProvider, + driverOptions: new GeneratorDriverOptions(default)); + driver = driver.RunGenerators(compilation); + return driver.GetRunResult().Results.Single(); + } + + private static string ConcatenateGeneratedSources(GeneratorRunResult result) + => string.Join( + $"{Environment.NewLine}{Environment.NewLine}", + result.GeneratedSources + .OrderBy(source => source.HintName, StringComparer.Ordinal) + .Select(source => source.SourceText.ToString().TrimStart('\uFEFF').TrimEnd())); + + private static Task CreateCompilation(string sourceCode, string assemblyName = "TestProject") + => TestCompilationHelper.CreateCompilation(sourceCode, assemblyName); + + private sealed class TestAnalyzerConfigOptionsProvider : AnalyzerConfigOptionsProvider + { + private static readonly AnalyzerConfigOptions EmptyOptions = new TestAnalyzerConfigOptions(new Dictionary()); + private readonly AnalyzerConfigOptions _globalOptions; + + public TestAnalyzerConfigOptionsProvider(IReadOnlyDictionary globalOptions) + { + _globalOptions = new TestAnalyzerConfigOptions(globalOptions); + } + + public override AnalyzerConfigOptions GlobalOptions => _globalOptions; + public override AnalyzerConfigOptions GetOptions(SyntaxTree tree) => EmptyOptions; + public override AnalyzerConfigOptions GetOptions(AdditionalText textFile) => EmptyOptions; + } + + private sealed class TestAnalyzerConfigOptions : AnalyzerConfigOptions + { + private readonly IReadOnlyDictionary _options; + + public TestAnalyzerConfigOptions(IReadOnlyDictionary options) + { + _options = options; + } + + public override bool TryGetValue(string key, out string value) => _options.TryGetValue(key, out value!); + } + + #endregion +} diff --git a/test/Orleans.CodeGenerator.Tests/GeneratedWarningSuppressionTests.cs b/test/Orleans.CodeGenerator.Tests/GeneratedWarningSuppressionTests.cs new file mode 100644 index 00000000000..a224ffb4add --- /dev/null +++ b/test/Orleans.CodeGenerator.Tests/GeneratedWarningSuppressionTests.cs @@ -0,0 +1,200 @@ +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; + +namespace Orleans.CodeGenerator.Tests; + +public sealed class GeneratedWarningSuppressionTests +{ + private const string MissingXmlCommentDiagnosticId = "CS1591"; + private const string PublicApiAnalyzerDiagnosticId = "RS0016"; + private const string CompilerApiAnalyzerDiagnosticId = "RS0041"; + private static readonly CSharpParseOptions DocumentationParseOptions = + CSharpParseOptions.Default.WithDocumentationMode(DocumentationMode.Diagnose); + + [Fact] + public async Task GeneratedSerializerProxyAndMetadataSourcesDoNotProduceMissingXmlCommentErrors() + { + const string source = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + /// + /// A DTO which forces serializer generation. + /// + [GenerateSerializer] + public sealed class WarningDto + { + /// + /// Gets or sets the name. + /// + [Id(0)] + public string Name { get; set; } = string.Empty; + } + + /// + /// A grain interface which forces proxy and invokable generation. + /// + public interface IWarningGrain : IGrainWithIntegerKey + { + /// + /// Gets the generated DTO. + /// + Task GetDto(); + } + """; + + var compilation = await CreateCompilationWithDocumentationDiagnostics(source); + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics.Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); + Assert.Contains(result.GeneratedSources, static source => source.HintName.Contains(".orleans.ser.", StringComparison.Ordinal)); + Assert.Contains(result.GeneratedSources, static source => source.HintName.Contains(".orleans.proxy.", StringComparison.Ordinal)); + Assert.Contains(result.GeneratedSources, static source => source.HintName.EndsWith(".orleans.metadata.g.cs", StringComparison.Ordinal)); + + var generatedCompilation = compilation.AddSyntaxTrees(CreateGeneratedSyntaxTrees(result)); + var generatedTreePaths = result.GeneratedSources + .Select(static source => source.HintName) + .ToHashSet(StringComparer.Ordinal); + var missingXmlCommentErrors = generatedCompilation.GetDiagnostics() + .Where(diagnostic => diagnostic.Id == MissingXmlCommentDiagnosticId + && diagnostic.Severity == DiagnosticSeverity.Error + && diagnostic.Location.SourceTree is { } tree + && generatedTreePaths.Contains(tree.FilePath)) + .Select(static diagnostic => $"{diagnostic.Location.GetLineSpan().Path}: {diagnostic.Id} {diagnostic.Severity}: {diagnostic.GetMessage()}") + .ToArray(); + + Assert.True( + missingXmlCommentErrors.Length == 0, + string.Join(Environment.NewLine, missingXmlCommentErrors)); + } + + [Fact] + public async Task GeneratedSources_CompileCleanlyUnderStrictDiagnosticsForMixedSourceAndReferences() + { + const string librarySource = """ + using Orleans; + + namespace LibraryProject; + + public sealed class Marker + { + } + + [GenerateSerializer] + public sealed class ReferencedWarningDto + { + [Id(0)] + public string Name { get; set; } = string.Empty; + } + """; + + const string consumerSource = """ + using Orleans; + using System.Threading.Tasks; + + [assembly: GenerateCodeForDeclaringAssembly(typeof(LibraryProject.Marker))] + + namespace TestProject; + + /// + /// Local DTO for strict generated-source compilation. + /// + [GenerateSerializer] + public sealed class LocalWarningDto + { + /// + /// Gets or sets the name. + /// + [Id(0)] + public string Name { get; set; } = string.Empty; + } + + /// + /// Local grain interface for strict generated-source compilation. + /// + public interface ILocalWarningGrain : IGrainWithIntegerKey + { + /// + /// Gets a local DTO. + /// + Task GetLocal(); + + /// + /// Gets a referenced DTO. + /// + Task GetReferenced(); + } + """; + + var libraryCompilation = await TestCompilationHelper.CreateCompilation(librarySource, "LibraryProject"); + Assert.Empty(libraryCompilation.GetDiagnostics().Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); + + var compilation = await CreateCompilationWithDocumentationDiagnostics( + consumerSource, + libraryCompilation.ToMetadataReference()); + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics.Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); + Assert.Contains(result.GeneratedSources, static source => source.HintName.Contains(".orleans.ser.", StringComparison.Ordinal)); + Assert.Contains(result.GeneratedSources, static source => source.HintName.Contains(".orleans.proxy.", StringComparison.Ordinal)); + Assert.Contains(result.GeneratedSources, static source => source.HintName.EndsWith(".orleans.metadata.g.cs", StringComparison.Ordinal)); + + var generatedCompilation = compilation.AddSyntaxTrees(CreateGeneratedSyntaxTrees(result)); + var generatedTreePaths = result.GeneratedSources + .Select(static source => source.HintName) + .ToHashSet(StringComparer.Ordinal); + var generatedErrors = generatedCompilation.GetDiagnostics() + .Where(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error + && diagnostic.Location.SourceTree is { } tree + && generatedTreePaths.Contains(tree.FilePath)) + .Select(static diagnostic => $"{diagnostic.Location.GetLineSpan().Path}: {diagnostic.Id} {diagnostic.Severity}: {diagnostic.GetMessage()}") + .ToArray(); + + Assert.True( + generatedErrors.Length == 0, + string.Join(Environment.NewLine, generatedErrors)); + } + + private static async Task CreateCompilationWithDocumentationDiagnostics( + string source, + params MetadataReference[] additionalReferences) + { + var compilation = await TestCompilationHelper.CreateCompilation(source, additionalReferences: additionalReferences); + var syntaxTree = CSharpSyntaxTree.ParseText( + source, + DocumentationParseOptions, + path: "WarningSuppressionInput.cs", + encoding: Encoding.UTF8); + var options = compilation.Options.WithSpecificDiagnosticOptions( + compilation.Options.SpecificDiagnosticOptions + .SetItem(MissingXmlCommentDiagnosticId, ReportDiagnostic.Error) + .SetItem(PublicApiAnalyzerDiagnosticId, ReportDiagnostic.Error) + .SetItem(CompilerApiAnalyzerDiagnosticId, ReportDiagnostic.Error)); + + return compilation + .ReplaceSyntaxTree(compilation.SyntaxTrees.Single(), syntaxTree) + .WithOptions(options); + } + + private static GeneratorRunResult RunGenerator(CSharpCompilation compilation) + { + var generator = new OrleansSerializationSourceGenerator().AsSourceGenerator(); + GeneratorDriver driver = CSharpGeneratorDriver.Create(generator); + driver = driver.RunGenerators(compilation); + return driver.GetRunResult().Results.Single(); + } + + private static IEnumerable CreateGeneratedSyntaxTrees(GeneratorRunResult result) + { + foreach (var source in result.GeneratedSources) + { + yield return CSharpSyntaxTree.ParseText( + source.SourceText, + DocumentationParseOptions, + path: source.HintName); + } + } +} diff --git a/test/Orleans.CodeGenerator.Tests/HintNameCollisionTests.cs b/test/Orleans.CodeGenerator.Tests/HintNameCollisionTests.cs new file mode 100644 index 00000000000..900ea579cff --- /dev/null +++ b/test/Orleans.CodeGenerator.Tests/HintNameCollisionTests.cs @@ -0,0 +1,161 @@ +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Orleans.CodeGenerator.Tests; + +/// +/// Characterization tests for generated source hint-name collisions. +/// +public class HintNameCollisionTests +{ + [Fact] + public async Task GenerateSerializerTypesWithCollidingSanitizedHintNamesAreBothEmitted() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public sealed class Colliding + { + [Id(0)] + public T Value { get; set; } = default!; + } + + [GenerateSerializer] + public sealed class Colliding_T + { + [Id(0)] + public int Value { get; set; } + } + """; + + Assert.Equal( + SanitizeHintComponent("global::TestProject.Colliding"), + SanitizeHintComponent("global::TestProject.Colliding_T")); + + var (result, outputCompilation) = await RunGenerator(code); + + Assert.Empty(result.Diagnostics); + + var serializerSources = GetGeneratedSources(result, ".orleans.ser."); + Assert.Equal(2, serializerSources.Length); + + var serializerClassNames = GetGeneratedClassNames(serializerSources, "Codec_"); + Assert.Contains("Codec_Colliding", serializerClassNames); + Assert.Contains("Codec_Colliding_T", serializerClassNames); + + AssertNoCompilationErrors(outputCompilation); + } + + [Fact] + public async Task GrainInterfacesWithCollidingSanitizedProxyHintNamesAreBothEmitted() + { + const string code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public interface ICollidingGrain : IGrainWithGuidKey + { + Task Ping(); + } + + public interface ICollidingGrain_T : IGrainWithGuidKey + { + Task Ping(); + } + """; + + Assert.Equal( + SanitizeHintComponent("global::TestProject.ICollidingGrain"), + SanitizeHintComponent("global::TestProject.ICollidingGrain_T")); + + var (result, outputCompilation) = await RunGenerator(code); + + Assert.Empty(result.Diagnostics); + + var proxySources = GetGeneratedSources(result, ".orleans.proxy."); + Assert.Equal(2, proxySources.Length); + + var proxyClassNames = GetGeneratedClassNames(proxySources, "Proxy_"); + Assert.Contains("Proxy_ICollidingGrain", proxyClassNames); + Assert.Contains("Proxy_ICollidingGrain_T", proxyClassNames); + + AssertNoCompilationErrors(outputCompilation); + } + + private static async Task<(GeneratorRunResult Result, Compilation OutputCompilation)> RunGenerator( + string code, + string assemblyName = "TestProject") + { + var compilation = await TestCompilationHelper.CreateCompilation(code, assemblyName); + AssertNoCompilationErrors(compilation); + + var generator = new OrleansSerializationSourceGenerator().AsSourceGenerator(); + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: [generator], + driverOptions: new GeneratorDriverOptions(default)); + + driver = driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out var generatorDiagnostics); + AssertNoErrors(generatorDiagnostics); + + return (driver.GetRunResult().Results.Single(), outputCompilation); + } + + private static GeneratedSourceResult[] GetGeneratedSources(GeneratorRunResult result, string hintNameFragment) + => result.GeneratedSources + .Where(source => source.HintName.Contains(hintNameFragment, StringComparison.Ordinal)) + .OrderBy(source => source.HintName, StringComparer.Ordinal) + .ToArray(); + + private static string[] GetGeneratedClassNames(IEnumerable generatedSources, string classNamePrefix) + => generatedSources + .SelectMany(static source => CSharpSyntaxTree.ParseText(source.SourceText.ToString().TrimStart('\uFEFF')) + .GetCompilationUnitRoot() + .DescendantNodes() + .OfType()) + .Select(static declaration => declaration.Identifier.ValueText) + .Where(name => name.StartsWith(classNamePrefix, StringComparison.Ordinal)) + .Distinct(StringComparer.Ordinal) + .OrderBy(static name => name, StringComparer.Ordinal) + .ToArray(); + + private static void AssertNoCompilationErrors(Compilation compilation) + => AssertNoErrors(compilation.GetDiagnostics()); + + private static void AssertNoErrors(IEnumerable diagnostics) + { + var errors = diagnostics + .Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error) + .ToArray(); + + Assert.True(errors.Length == 0, string.Join(Environment.NewLine, errors.Select(static error => error.ToString()))); + } + + private static string SanitizeHintComponent(string value) + { + var builder = new StringBuilder(value.Length); + var previousCharacterWasUnderscore = false; + foreach (var character in value) + { + if (char.IsLetterOrDigit(character) || character is '_' or '.') + { + builder.Append(character); + previousCharacterWasUnderscore = false; + } + else if (!previousCharacterWasUnderscore) + { + builder.Append('_'); + previousCharacterWasUnderscore = true; + } + } + + var result = builder.ToString().Trim('_', '.'); + return result.Length > 0 ? result : "generated"; + } +} diff --git a/test/Orleans.CodeGenerator.Tests/IncrementalCachingTests.cs b/test/Orleans.CodeGenerator.Tests/IncrementalCachingTests.cs new file mode 100644 index 00000000000..1d186caf839 --- /dev/null +++ b/test/Orleans.CodeGenerator.Tests/IncrementalCachingTests.cs @@ -0,0 +1,981 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Diagnostics; +using Orleans.CodeGenerator.Diagnostics; + +namespace Orleans.CodeGenerator.Tests; + +/// +/// Tests that verify the Orleans incremental source generator correctly caches +/// pipeline outputs when inputs have not changed, avoiding unnecessary regeneration. +/// +public class IncrementalCachingTests +{ + [Fact] + public async Task UnchangedSource_ProducesCachedOutput() + { + const string code = """ + using Orleans; + + [GenerateSerializer] + public sealed class MyDto + { + [Id(0)] + public string Name { get; set; } + } + """; + + var compilation = await CreateCompilation(code); + var (_, result2) = await RunTwice(compilation, compilation); + + AssertAllOutputsCachedOrUnchanged(result2); + } + + [Fact] + public async Task ChangedSerializableType_TriggersRegeneration() + { + const string originalCode = """ + using Orleans; + + [GenerateSerializer] + public sealed class MyDto + { + [Id(0)] + public string Name { get; set; } + } + """; + + const string modifiedCode = """ + using Orleans; + + [GenerateSerializer] + public sealed class MyDto + { + [Id(0)] + public string Name { get; set; } + + [Id(1)] + public int Age { get; set; } + } + """; + + var compilation = await CreateCompilation(originalCode); + var newCompilation = ReplaceSource(compilation, modifiedCode); + var (_, result2) = await RunTwice(compilation, newCompilation); + + AssertAnyOutputModifiedOrNew(result2); + } + + [Fact] + public async Task UnrelatedChange_DoesNotTriggerRegeneration() + { + const string originalCode = """ + using Orleans; + + [GenerateSerializer] + public sealed class MyDto + { + [Id(0)] + public string Name { get; set; } + } + """; + + const string modifiedCode = """ + using Orleans; + + [GenerateSerializer] + public sealed class MyDto + { + [Id(0)] + public string Name { get; set; } + } + + public class UnrelatedClass + { + public int Value { get; set; } + } + """; + + var compilation = await CreateCompilation(originalCode); + var newCompilation = ReplaceSource(compilation, modifiedCode); + var (result1, result2) = await RunTwice(compilation, newCompilation); + + AssertTrackedStepsCachedOrUnchanged( + result2, + OrleansSerializationSourceGenerator.SerializableTypeResultsTrackingName, + OrleansSerializationSourceGenerator.CollectedSerializableTypesTrackingName, + OrleansSerializationSourceGenerator.SerializerOutputsTrackingName, + OrleansSerializationSourceGenerator.ReferenceAssemblyDataTrackingName, + OrleansSerializationSourceGenerator.MetadataAggregateTrackingName, + OrleansSerializationSourceGenerator.MetadataOutputsTrackingName); + AssertGeneratedSourcesIdentical(result1, result2); + } + + [Fact] + public async Task AddingNewSerializableType_TriggersRegeneration() + { + const string originalCode = """ + using Orleans; + + [GenerateSerializer] + public sealed class TypeA + { + [Id(0)] + public string Name { get; set; } + } + """; + + const string modifiedCode = """ + using Orleans; + + [GenerateSerializer] + public sealed class TypeA + { + [Id(0)] + public string Name { get; set; } + } + + [GenerateSerializer] + public sealed class TypeB + { + [Id(0)] + public int Value { get; set; } + } + """; + + var compilation = await CreateCompilation(originalCode); + var newCompilation = ReplaceSource(compilation, modifiedCode); + var (result1, result2) = await RunTwice(compilation, newCompilation); + + AssertAnyOutputModifiedOrNew(result2); + + // Second run should produce more generated sources than the first + Assert.True( + result2.GeneratedSources.Length >= result1.GeneratedSources.Length, + "Adding a new serializable type should produce at least as many generated sources."); + } + + [Fact] + public async Task RemovingSerializableType_TriggersRegeneration() + { + const string originalCode = """ + using Orleans; + + [GenerateSerializer] + public sealed class TypeA + { + [Id(0)] + public string Name { get; set; } + } + + [GenerateSerializer] + public sealed class TypeB + { + [Id(0)] + public int Value { get; set; } + } + """; + + const string modifiedCode = """ + using Orleans; + + [GenerateSerializer] + public sealed class TypeA + { + [Id(0)] + public string Name { get; set; } + } + """; + + var compilation = await CreateCompilation(originalCode); + var newCompilation = ReplaceSource(compilation, modifiedCode); + var (result1, result2) = await RunTwice(compilation, newCompilation); + + AssertAnyOutputModifiedOrNew(result2); + Assert.True( + result2.GeneratedSources.Length <= result1.GeneratedSources.Length, + "Removing a serializable type should produce fewer or equal generated sources."); + } + + [Fact] + public async Task ChangedProxyInterface_TriggersRegeneration() + { + const string originalCode = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public interface IMyGrain : IGrainWithIntegerKey + { + Task SayHello(string name); + } + """; + + const string modifiedCode = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public interface IMyGrain : IGrainWithIntegerKey + { + Task SayHello(string name); + Task GetCount(); + } + """; + + var compilation = await CreateCompilation(originalCode); + var newCompilation = ReplaceSource(compilation, modifiedCode); + var (_, result2) = await RunTwice(compilation, newCompilation); + + AssertAnyOutputModifiedOrNew(result2); + } + + [Fact] + public async Task UnchangedProxyInterface_ProducesCachedOutput() + { + const string code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public interface IMyGrain : IGrainWithIntegerKey + { + Task SayHello(string name); + } + """; + + var compilation = await CreateCompilation(code); + var (_, result2) = await RunTwice(compilation, compilation); + + AssertAllOutputsCachedOrUnchanged(result2); + } + + [Fact] + public async Task AddedSyntaxTreeWithoutProxyInterfaces_ProducesIdenticalOutput() + { + const string code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public interface IMyGrain : IGrainWithIntegerKey + { + Task SayHello(string name); + } + """; + + const string additionalFile = """ + namespace TestProject; + + public static class Helpers + { + public static string Format(string input) => input.ToUpperInvariant(); + } + """; + + var compilation = await CreateCompilation(code); + var newCompilation = compilation.AddSyntaxTrees(CSharpSyntaxTree.ParseText(additionalFile)); + var (result1, result2) = await RunTwice(compilation, newCompilation); + + AssertTrackedStepsCachedOrUnchanged( + result2, + OrleansSerializationSourceGenerator.InheritedProxyInterfacesTrackingName, + OrleansSerializationSourceGenerator.CollectedProxyInterfacesTrackingName, + OrleansSerializationSourceGenerator.PreparedProxyOutputsTrackingName, + OrleansSerializationSourceGenerator.ProxyOutputsTrackingName, + OrleansSerializationSourceGenerator.MetadataAggregateTrackingName, + OrleansSerializationSourceGenerator.MetadataOutputsTrackingName); + AssertGeneratedSourcesIdentical(result1, result2); + } + + [Fact] + public async Task AddedSyntaxTreeWithoutSerializableTypes_ProducesIdenticalOutput() + { + const string code = """ + using Orleans; + + [GenerateSerializer] + public sealed class MyDto + { + [Id(0)] + public string Name { get; set; } + } + """; + + const string additionalFile = """ + namespace TestProject; + + public static class Helpers + { + public static string Format(string input) => input.ToUpperInvariant(); + } + """; + + var compilation = await CreateCompilation(code); + var newCompilation = compilation.AddSyntaxTrees(CSharpSyntaxTree.ParseText(additionalFile)); + var (result1, result2) = await RunTwice(compilation, newCompilation); + + AssertTrackedStepsCachedOrUnchanged( + result2, + OrleansSerializationSourceGenerator.SerializableTypeResultsTrackingName, + OrleansSerializationSourceGenerator.CollectedSerializableTypesTrackingName, + OrleansSerializationSourceGenerator.SerializerOutputsTrackingName, + OrleansSerializationSourceGenerator.ReferenceAssemblyDataTrackingName, + OrleansSerializationSourceGenerator.MetadataAggregateTrackingName, + OrleansSerializationSourceGenerator.MetadataOutputsTrackingName); + AssertGeneratedSourcesIdentical(result1, result2); + } + + [Fact] + public async Task MixedSerializableAndProxy_BothCachedWhenUnchanged() + { + const string code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateSerializer] + public sealed class MyDto + { + [Id(0)] + public string Name { get; set; } + } + + public interface IMyGrain : IGrainWithIntegerKey + { + Task GetDto(); + } + """; + + var compilation = await CreateCompilation(code); + var (_, result2) = await RunTwice(compilation, compilation); + + AssertAllOutputsCachedOrUnchanged(result2); + } + + [Fact] + public async Task MixedSerializableAndProxy_OnlySerializableChanged_ProducesModifiedOutput() + { + const string originalCode = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateSerializer] + public sealed class MyDto + { + [Id(0)] + public string Name { get; set; } + } + + public interface IMyGrain : IGrainWithIntegerKey + { + Task GetDto(); + } + """; + + const string modifiedCode = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateSerializer] + public sealed class MyDto + { + [Id(0)] + public string Name { get; set; } + + [Id(1)] + public int Age { get; set; } + } + + public interface IMyGrain : IGrainWithIntegerKey + { + Task GetDto(); + } + """; + + var compilation = await CreateCompilation(originalCode); + var newCompilation = ReplaceSource(compilation, modifiedCode); + var (_, result2) = await RunTwice(compilation, newCompilation); + + AssertAnyOutputModifiedOrNew(result2); + } + + [Fact] + public async Task MixedSerializableAndProxy_OnlySerializableChanged_LeavesProxyOutputsUnchanged() + { + const string originalCode = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateSerializer] + public sealed class MyDto + { + [Id(0)] + public string Name { get; set; } + } + + public interface IMyGrain : IGrainWithIntegerKey + { + Task GetDto(); + } + """; + + const string modifiedCode = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateSerializer] + public sealed class MyDto + { + [Id(0)] + public string Name { get; set; } + + [Id(1)] + public int Age { get; set; } + } + + public interface IMyGrain : IGrainWithIntegerKey + { + Task GetDto(); + } + """; + + var compilation = await CreateCompilation(originalCode); + var newCompilation = ReplaceSource(compilation, modifiedCode); + var (result1, result2) = await RunTwice(compilation, newCompilation); + + AssertSourcesChanged(result1, result2, static hint => hint.Contains(".orleans.ser.", StringComparison.Ordinal)); + AssertSourcesUnchanged(result1, result2, static hint => hint.Contains(".orleans.proxy.", StringComparison.Ordinal)); + } + + [Fact] + public async Task MixedSerializableAndProxy_OnlyProxyChanged_LeavesSerializerOutputsUnchanged() + { + const string originalCode = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateSerializer] + public sealed class MyDto + { + [Id(0)] + public string Name { get; set; } + } + + public interface IMyGrain : IGrainWithIntegerKey + { + Task GetDto(); + } + """; + + const string modifiedCode = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateSerializer] + public sealed class MyDto + { + [Id(0)] + public string Name { get; set; } + } + + public interface IMyGrain : IGrainWithIntegerKey + { + Task GetDto(); + Task Ping(); + } + """; + + var compilation = await CreateCompilation(originalCode); + var newCompilation = ReplaceSource(compilation, modifiedCode); + var (result1, result2) = await RunTwice(compilation, newCompilation); + + AssertSourcesChanged(result1, result2, static hint => hint.Contains(".orleans.proxy.", StringComparison.Ordinal)); + AssertSourcesUnchanged(result1, result2, static hint => hint.Contains(".orleans.ser.", StringComparison.Ordinal)); + } + + [Fact] + public async Task ChangedReferenceAssembly_InvalidatesReferenceAssemblyPipelineAndDropsStaleOutputs() + { + const string libraryV1Code = """ + using Orleans; + + namespace LibraryProject; + + public sealed class Marker + { + } + + [GenerateSerializer] + public sealed class ReferencedDto + { + [Id(0)] + public string LegacyValue { get; set; } = string.Empty; + } + """; + + const string libraryV2Code = """ + using Orleans; + + namespace LibraryProject; + + public sealed class Marker + { + } + + [GenerateSerializer] + public sealed class ReferencedDto + { + [Id(0)] + public string CurrentValue { get; set; } = string.Empty; + + [Id(1)] + public int Version { get; set; } + } + """; + + const string consumerCode = """ + using Orleans; + + [assembly: GenerateCodeForDeclaringAssembly(typeof(LibraryProject.Marker))] + """; + + var consumerV1 = await CreateConsumerCompilationWithLibrary(libraryV1Code, consumerCode); + var consumerV2 = await CreateConsumerCompilationWithLibrary(libraryV2Code, consumerCode); + var (result1, result2) = await RunTwice(consumerV1, consumerV2); + + Assert.Empty(result1.Diagnostics); + Assert.Empty(result2.Diagnostics); + AssertTrackedStepModifiedOrNew(result2, OrleansSerializationSourceGenerator.ReferenceAssemblyDataTrackingName); + AssertTrackedStepModifiedOrNew(result2, OrleansSerializationSourceGenerator.ReferencedSerializerOutputsTrackingName); + + var firstGeneratedSource = ConcatenateGeneratedSources(result1); + var secondGeneratedSource = ConcatenateGeneratedSources(result2); + Assert.Contains("LegacyValue", firstGeneratedSource, StringComparison.Ordinal); + Assert.DoesNotContain("CurrentValue", firstGeneratedSource, StringComparison.Ordinal); + Assert.DoesNotContain("LegacyValue", secondGeneratedSource, StringComparison.Ordinal); + Assert.Contains("CurrentValue", secondGeneratedSource, StringComparison.Ordinal); + Assert.Contains("Version", secondGeneratedSource, StringComparison.Ordinal); + } + + [Fact] + public async Task SameDriverSameCompilation_ProducesIdenticalDiagnosticsAndSources() + { + const string code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateSerializer] + public sealed class StableDto + { + [Id(0)] + public string Name { get; set; } = string.Empty; + } + + public interface IStableGrain : IGrainWithIntegerKey + { + Task Get(); + } + """; + + var compilation = await CreateCompilation(code); + var (result1, result2) = await RunTwice(compilation, compilation); + + AssertNoDuplicateHintNames(result1); + AssertNoDuplicateHintNames(result2); + AssertDiagnosticsIdentical(result1.Diagnostics, result2.Diagnostics); + AssertGeneratedSourcesIdentical(result1, result2); + AssertAllOutputsCachedOrUnchanged(result2); + } + + [Fact] + public async Task UnrelatedAnalyzerConfigOption_DoesNotInvalidateGeneratedModels() + { + const string code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateSerializer] + public sealed class StableDto + { + [Id(0)] + public string Name { get; set; } = string.Empty; + } + + public interface IStableGrain : IGrainWithIntegerKey + { + Task Get(); + } + """; + + var compilation = await CreateCompilation(code); + var (result1, result2) = await RunTwice( + compilation, + compilation, + new Dictionary + { + ["build_property.unrelated_option"] = "before", + }, + new Dictionary + { + ["build_property.unrelated_option"] = "after", + }); + + Assert.Empty(result1.Diagnostics); + Assert.Empty(result2.Diagnostics); + AssertGeneratedSourcesIdentical(result1, result2); + AssertTrackedStepsCachedOrUnchanged( + result2, + OrleansSerializationSourceGenerator.SerializableTypeResultsTrackingName, + OrleansSerializationSourceGenerator.CollectedSerializableTypesTrackingName, + OrleansSerializationSourceGenerator.SerializerOutputsTrackingName, + OrleansSerializationSourceGenerator.InheritedProxyInterfacesTrackingName, + OrleansSerializationSourceGenerator.CollectedProxyInterfacesTrackingName, + OrleansSerializationSourceGenerator.PreparedProxyOutputsTrackingName, + OrleansSerializationSourceGenerator.ProxyOutputsTrackingName, + OrleansSerializationSourceGenerator.ReferenceAssemblyDataTrackingName, + OrleansSerializationSourceGenerator.MetadataAggregateTrackingName, + OrleansSerializationSourceGenerator.MetadataOutputsTrackingName); + } + + [Fact] + public async Task GenerateFieldIdsOption_ChangesSerializerDiagnosticsWithoutChangingProxyOutputs() + { + const string code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateSerializer] + public sealed class OptionDto + { + public string Name { get; set; } = string.Empty; + public int Age { get; set; } + } + + public interface IOptionGrain : IGrainWithIntegerKey + { + Task Ping(); + } + """; + + var compilation = await CreateCompilation(code); + var (baselineResult, configuredResult) = await RunTwice( + compilation, + compilation, + firstGlobalOptions: null, + secondGlobalOptions: new Dictionary + { + ["build_property.orleans_generatefieldids"] = "PublicProperties", + }); + + Assert.Contains(baselineResult.Diagnostics, diagnostic => diagnostic.Id == DiagnosticRuleId.CanNotGenerateImplicitFieldIds); + Assert.Empty(configuredResult.Diagnostics); + Assert.Contains(configuredResult.GeneratedSources, static source => source.HintName.Contains(".orleans.ser.", StringComparison.Ordinal)); + AssertSourcesUnchanged(baselineResult, configuredResult, static hint => hint.Contains(".orleans.proxy.", StringComparison.Ordinal)); + AssertTrackedStepModifiedOrNew(configuredResult, OrleansSerializationSourceGenerator.SerializableTypeResultsTrackingName); + AssertTrackedStepModifiedOrNew(configuredResult, OrleansSerializationSourceGenerator.SerializerOutputsTrackingName); + } + + [Fact] + public async Task CompatibilityInvokersOption_ChangesProxyOutputsWithoutChangingSerializerOutputs() + { + const string code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateSerializer] + public sealed class StableDto + { + [Id(0)] + public string Name { get; set; } = string.Empty; + } + + public interface IBaseOptionGrain : IGrainWithIntegerKey + { + Task Get(); + } + + public interface IDerivedOptionGrain : IBaseOptionGrain + { + } + """; + + var compilation = await CreateCompilation(code); + var (baselineResult, configuredResult) = await RunTwice( + compilation, + compilation, + firstGlobalOptions: null, + secondGlobalOptions: new Dictionary + { + ["build_property.orleansgeneratecompatibilityinvokers"] = "true", + }); + + Assert.Empty(baselineResult.Diagnostics); + Assert.Empty(configuredResult.Diagnostics); + AssertSourcesUnchanged(baselineResult, configuredResult, static hint => hint.Contains(".orleans.ser.", StringComparison.Ordinal)); + AssertSourcesChanged(baselineResult, configuredResult, static hint => hint.Contains(".orleans.proxy.", StringComparison.Ordinal)); + AssertTrackedStepModifiedOrNew(configuredResult, OrleansSerializationSourceGenerator.PreparedProxyOutputsTrackingName); + AssertTrackedStepModifiedOrNew(configuredResult, OrleansSerializationSourceGenerator.ProxyOutputsTrackingName); + } + + #region Helpers + + private static CSharpCompilation ReplaceSource(CSharpCompilation compilation, string newSource) + { + var newTree = CSharpSyntaxTree.ParseText(newSource); + return compilation.ReplaceSyntaxTree(compilation.SyntaxTrees.First(), newTree); + } + + private static async Task<(GeneratorRunResult First, GeneratorRunResult Second)> RunTwice( + CSharpCompilation firstCompilation, + CSharpCompilation secondCompilation, + IReadOnlyDictionary? firstGlobalOptions = null, + IReadOnlyDictionary? secondGlobalOptions = null) + { + var generator = new OrleansSerializationSourceGenerator().AsSourceGenerator(); + var hasOptions = firstGlobalOptions is not null || secondGlobalOptions is not null; + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: [generator], + optionsProvider: hasOptions ? new TestAnalyzerConfigOptionsProvider(firstGlobalOptions) : null, + driverOptions: new GeneratorDriverOptions( + disabledOutputs: default, + trackIncrementalGeneratorSteps: true)); + + driver = driver.RunGenerators(firstCompilation); + var result1 = driver.GetRunResult(); + Assert.NotEmpty(result1.Results[0].GeneratedSources); + + if (hasOptions) + { + driver = driver.WithUpdatedAnalyzerConfigOptions(new TestAnalyzerConfigOptionsProvider(secondGlobalOptions)); + } + + driver = driver.RunGenerators(secondCompilation); + var result2 = driver.GetRunResult(); + + await Task.CompletedTask; + return (result1.Results[0], result2.Results[0]); + } + + private static void AssertAllOutputsCachedOrUnchanged(GeneratorRunResult result) + { + var outputSteps = result.TrackedOutputSteps; + Assert.NotEmpty(outputSteps); + + foreach (var (stepName, steps) in outputSteps) + { + foreach (var step in steps) + { + foreach (var (_, reason) in step.Outputs) + { + Assert.True( + reason is IncrementalStepRunReason.Cached or IncrementalStepRunReason.Unchanged, + $"Step '{stepName}' had reason '{reason}' — expected Cached or Unchanged."); + } + } + } + } + + private static void AssertAnyOutputModifiedOrNew(GeneratorRunResult result) + { + var outputSteps = result.TrackedOutputSteps; + Assert.NotEmpty(outputSteps); + + var allReasons = outputSteps + .SelectMany(kvp => kvp.Value) + .SelectMany(step => step.Outputs) + .Select(o => o.Reason) + .ToList(); + + Assert.Contains(allReasons, reason => + reason is IncrementalStepRunReason.Modified or IncrementalStepRunReason.New); + } + + private static void AssertTrackedStepsCachedOrUnchanged(GeneratorRunResult result, params string[] stepNames) + { + var trackedSteps = result.TrackedSteps; + Assert.NotEmpty(trackedSteps); + + foreach (var stepName in stepNames) + { + Assert.True(trackedSteps.TryGetValue(stepName, out var steps), $"Missing tracked step '{stepName}'."); + Assert.NotEmpty(steps); + + foreach (var step in steps) + { + foreach (var (_, reason) in step.Outputs) + { + Assert.True( + reason is IncrementalStepRunReason.Cached or IncrementalStepRunReason.Unchanged, + $"Step '{stepName}' had reason '{reason}' — expected Cached or Unchanged."); + } + } + } + } + + private static void AssertTrackedStepModifiedOrNew(GeneratorRunResult result, string stepName) + { + var trackedSteps = result.TrackedSteps; + Assert.NotEmpty(trackedSteps); + Assert.True(trackedSteps.TryGetValue(stepName, out var steps), $"Missing tracked step '{stepName}'."); + + var reasons = steps + .SelectMany(static step => step.Outputs) + .Select(static output => output.Reason) + .ToArray(); + Assert.Contains(reasons, static reason => reason is IncrementalStepRunReason.Modified or IncrementalStepRunReason.New); + } + + private static void AssertGeneratedSourcesIdentical(GeneratorRunResult result1, GeneratorRunResult result2) + { + Assert.Equal(result1.GeneratedSources.Length, result2.GeneratedSources.Length); + + var sources1 = result1.GeneratedSources.OrderBy(s => s.HintName).ToList(); + var sources2 = result2.GeneratedSources.OrderBy(s => s.HintName).ToList(); + + for (int i = 0; i < sources1.Count; i++) + { + Assert.Equal(sources1[i].HintName, sources2[i].HintName); + Assert.Equal(sources1[i].SourceText.ToString(), sources2[i].SourceText.ToString()); + } + } + + private static void AssertDiagnosticsIdentical(IEnumerable diagnostics, IEnumerable otherDiagnostics) + => Assert.Equal( + diagnostics.Select(GetDiagnosticShape).OrderBy(static value => value, StringComparer.Ordinal), + otherDiagnostics.Select(GetDiagnosticShape).OrderBy(static value => value, StringComparer.Ordinal)); + + private static string GetDiagnosticShape(Diagnostic diagnostic) + { + var lineSpan = diagnostic.Location.GetLineSpan(); + return string.Join( + "|", + diagnostic.Id, + diagnostic.Severity.ToString(), + diagnostic.GetMessage(), + lineSpan.Path ?? string.Empty, + lineSpan.StartLinePosition.Line.ToString(), + lineSpan.StartLinePosition.Character.ToString()); + } + + private static void AssertNoDuplicateHintNames(GeneratorRunResult result) + { + var duplicateHintNames = result.GeneratedSources + .GroupBy(static source => source.HintName, StringComparer.Ordinal) + .Where(static group => group.Count() > 1) + .Select(static group => group.Key) + .ToArray(); + + Assert.True(duplicateHintNames.Length == 0, $"Duplicate generated source hint names: {string.Join(", ", duplicateHintNames)}"); + } + + private static void AssertSourcesChanged( + GeneratorRunResult result1, + GeneratorRunResult result2, + Func predicate) + { + var sourceMap1 = GetGeneratedSourceMap(result1); + var sourceMap2 = GetGeneratedSourceMap(result2); + var matchingHints = sourceMap1.Keys.Intersect(sourceMap2.Keys).Where(predicate).ToList(); + + Assert.NotEmpty(matchingHints); + Assert.Contains(matchingHints, hint => !string.Equals(sourceMap1[hint], sourceMap2[hint], StringComparison.Ordinal)); + } + + private static void AssertSourcesUnchanged( + GeneratorRunResult result1, + GeneratorRunResult result2, + Func predicate) + { + var sourceMap1 = GetGeneratedSourceMap(result1); + var sourceMap2 = GetGeneratedSourceMap(result2); + var matchingHints = sourceMap1.Keys.Intersect(sourceMap2.Keys).Where(predicate).ToList(); + + Assert.NotEmpty(matchingHints); + Assert.All(matchingHints, hint => Assert.Equal(sourceMap1[hint], sourceMap2[hint])); + } + + private static Dictionary GetGeneratedSourceMap(GeneratorRunResult result) + => result.GeneratedSources.ToDictionary(source => source.HintName, source => source.SourceText.ToString(), StringComparer.Ordinal); + + private static string ConcatenateGeneratedSources(GeneratorRunResult result) + => string.Join( + Environment.NewLine, + result.GeneratedSources + .OrderBy(static source => source.HintName, StringComparer.Ordinal) + .Select(static source => source.SourceText.ToString())); + + private static async Task CreateConsumerCompilationWithLibrary( + string libraryCode, + string consumerCode) + { + var libraryCompilation = await CreateCompilation(libraryCode, "LibraryProject"); + Assert.Empty(libraryCompilation.GetDiagnostics().Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); + + var consumerCompilation = await TestCompilationHelper.CreateCompilation( + consumerCode, + "ConsumerProject", + libraryCompilation.ToMetadataReference()); + Assert.Empty(consumerCompilation.GetDiagnostics().Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); + return consumerCompilation; + } + + private static Task CreateCompilation(string sourceCode, string assemblyName = "TestProject") + => TestCompilationHelper.CreateCompilation(sourceCode, assemblyName); + + private sealed class TestAnalyzerConfigOptionsProvider : AnalyzerConfigOptionsProvider + { + private static readonly AnalyzerConfigOptions EmptyOptions = new TestAnalyzerConfigOptions(new Dictionary()); + private readonly AnalyzerConfigOptions _globalOptions; + + public TestAnalyzerConfigOptionsProvider(IReadOnlyDictionary? globalOptions) + { + _globalOptions = new TestAnalyzerConfigOptions(globalOptions ?? new Dictionary()); + } + + public override AnalyzerConfigOptions GlobalOptions => _globalOptions; + + public override AnalyzerConfigOptions GetOptions(SyntaxTree tree) => EmptyOptions; + + public override AnalyzerConfigOptions GetOptions(AdditionalText textFile) => EmptyOptions; + } + + private sealed class TestAnalyzerConfigOptions : AnalyzerConfigOptions + { + private readonly IReadOnlyDictionary _options; + + public TestAnalyzerConfigOptions(IReadOnlyDictionary options) + { + _options = options; + } + + public override bool TryGetValue(string key, out string value) => _options.TryGetValue(key, out value!); + } + + #endregion +} diff --git a/test/Orleans.CodeGenerator.Tests/IncrementalModelEqualityTests.cs b/test/Orleans.CodeGenerator.Tests/IncrementalModelEqualityTests.cs new file mode 100644 index 00000000000..80271f5cf7c --- /dev/null +++ b/test/Orleans.CodeGenerator.Tests/IncrementalModelEqualityTests.cs @@ -0,0 +1,259 @@ +using System.Collections.Immutable; +using System.Linq; +using Microsoft.CodeAnalysis; +using Orleans.CodeGenerator; +using Orleans.CodeGenerator.Model; +using Xunit; + +namespace Orleans.CodeGenerator.Tests; + +/// +/// Tests structural equality behavior that Roslyn incremental caching depends on. +/// +public class IncrementalModelEqualityTests +{ + [Fact] + public void StructuralEquality_DefaultArray_EqualsEmptyArray() + { + var defaultArray = default(ImmutableArray); + var emptyArray = ImmutableArray.Empty; + + Assert.True(StructuralEquality.SequenceEqual(defaultArray, emptyArray)); + Assert.Equal( + StructuralEquality.GetSequenceHashCode(defaultArray), + StructuralEquality.GetSequenceHashCode(emptyArray)); + } + + [Fact] + public void StructuralEquality_UsesElementValues() + { + var left = ImmutableArray.Create("alpha", "beta"); + var right = ImmutableArray.Create("alpha", "beta"); + var different = ImmutableArray.Create("alpha", "gamma"); + + Assert.True(StructuralEquality.SequenceEqual(left, right)); + Assert.Equal( + StructuralEquality.GetSequenceHashCode(left), + StructuralEquality.GetSequenceHashCode(right)); + Assert.False(StructuralEquality.SequenceEqual(left, different)); + } + + [Fact] + public void TypeRef_ToTypeSyntax_ProducesValidSyntax() + { + var typeRef = new TypeRef("global::System.Collections.Generic.List"); + + Assert.Equal("global::System.Collections.Generic.List", typeRef.ToTypeSyntax().ToString()); + Assert.True(TypeRef.Empty.IsEmpty); + } + + [Fact] + public void CompoundAliasComponentModel_DistinguishesStringAndTypeComponents() + { + var stringComponent = new CompoundAliasComponentModel("part"); + var matchingStringComponent = new CompoundAliasComponentModel("part"); + var typeComponent = new CompoundAliasComponentModel(new TypeRef("global::Example.Part")); + + Assert.Equal(stringComponent, matchingStringComponent); + Assert.NotEqual(stringComponent, typeComponent); + Assert.True(stringComponent.IsString); + Assert.True(typeComponent.IsType); + } + + [Fact] + public void SerializableTypeModel_DefaultArrays_AreEqualToEmptyArrays() + { + var defaultArrays = CreateSerializableTypeModel("MyType", "MyNamespace", members: default); + var emptyArrays = CreateSerializableTypeModel("MyType", "MyNamespace", members: ImmutableArray.Empty); + + Assert.Equal(defaultArrays, emptyArrays); + Assert.Equal(defaultArrays.GetHashCode(), emptyArrays.GetHashCode()); + } + + [Fact] + public void SerializableTypeModel_DifferentStructuralArrayValues_AreNotEqual() + { + var oneMember = CreateSerializableTypeModel( + "MyType", + "MyNamespace", + members: ImmutableArray.Create(CreateMemberModel(0, "Value", "int"))); + var twoMembers = CreateSerializableTypeModel( + "MyType", + "MyNamespace", + members: ImmutableArray.Create( + CreateMemberModel(0, "Value", "int"), + CreateMemberModel(1, "Other", "int"))); + + Assert.NotEqual(oneMember, twoMembers); + } + + [Fact] + public void MetadataAggregateModel_CreateMetadataAggregate_MergesAndSortsDeterministically() + { + var aggregate = ModelExtractor.CreateMetadataAggregate( + "TestAssembly", + ImmutableArray.Create( + CreateSerializableTypeModel("ZuluType", "MyNamespace"), + CreateSerializableTypeModel("AlphaType", "MyNamespace")), + ImmutableArray.Create( + CreateProxyInterfaceModel("IZulu"), + CreateProxyInterfaceModel("IAlpha")), + CreateReferenceAssemblyModel( + applicationParts: ImmutableArray.Create("PartZ", "PartA"), + referencedSerializableTypes: ImmutableArray.Create( + CreateSerializableTypeModel("MiddleType", "MyNamespace"), + CreateSerializableTypeModel("AlphaType", "MyNamespace")), + referencedProxyInterfaces: ImmutableArray.Create( + CreateProxyInterfaceModel("IMiddle"), + CreateProxyInterfaceModel("IAlpha")), + registeredCodecs: ImmutableArray.Create( + new RegisteredCodecModel(new TypeRef("global::Codecs.ZuluCodec"), RegisteredCodecKind.Serializer), + new RegisteredCodecModel(new TypeRef("global::Codecs.AlphaCodec"), RegisteredCodecKind.Serializer)), + interfaceImplementations: ImmutableArray.Create( + new InterfaceImplementationModel(new TypeRef("global::Impl.Zulu")), + new InterfaceImplementationModel(new TypeRef("global::Impl.Alpha"))))); + + Assert.Equal( + ["global::MyNamespace.AlphaType", "global::MyNamespace.MiddleType", "global::MyNamespace.ZuluType"], + aggregate.SerializableTypes.Select(static type => type.TypeSyntax.SyntaxString).ToArray()); + Assert.Equal( + ["global::MyNamespace.IAlpha", "global::MyNamespace.IMiddle", "global::MyNamespace.IZulu"], + aggregate.ProxyInterfaces.Select(static proxy => proxy.InterfaceType.SyntaxString).ToArray()); + Assert.Equal( + ["global::Codecs.AlphaCodec", "global::Codecs.ZuluCodec"], + aggregate.RegisteredCodecs.Select(static codec => codec.Type.SyntaxString).ToArray()); + Assert.Equal( + ["global::Impl.Alpha", "global::Impl.Zulu"], + aggregate.InterfaceImplementations.Select(static implementation => implementation.ImplementationType.SyntaxString).ToArray()); + Assert.Equal( + ["PartZ", "PartA"], + aggregate.ReferenceAssemblyData.ApplicationParts.ToArray()); + } + + [Fact] + public void MetadataAggregateModel_CreateMetadataAggregate_OrderIndependentInputs_AreEqual() + { + var serializableAlpha = CreateSerializableTypeModel("AlphaType", "MyNamespace"); + var serializableBeta = CreateSerializableTypeModel("BetaType", "MyNamespace"); + var serializableGamma = CreateSerializableTypeModel("GammaType", "MyNamespace"); + var proxyAlpha = CreateProxyInterfaceModel("IAlpha"); + var proxyBeta = CreateProxyInterfaceModel("IBeta"); + var proxyGamma = CreateProxyInterfaceModel("IGamma"); + + var aggregateA = ModelExtractor.CreateMetadataAggregate( + "TestAssembly", + ImmutableArray.Create(serializableBeta, serializableAlpha), + ImmutableArray.Create(proxyBeta, proxyAlpha), + CreateReferenceAssemblyModel( + applicationParts: ImmutableArray.Create("PartZ", "PartA"), + referencedSerializableTypes: ImmutableArray.Create(serializableGamma, serializableAlpha), + referencedProxyInterfaces: ImmutableArray.Create(proxyGamma, proxyAlpha), + registeredCodecs: ImmutableArray.Create( + new RegisteredCodecModel(new TypeRef("global::Codecs.ZuluCodec"), RegisteredCodecKind.Serializer), + new RegisteredCodecModel(new TypeRef("global::Codecs.AlphaCodec"), RegisteredCodecKind.Serializer)), + interfaceImplementations: ImmutableArray.Create( + new InterfaceImplementationModel(new TypeRef("global::Impl.Zulu")), + new InterfaceImplementationModel(new TypeRef("global::Impl.Alpha"))))); + + var aggregateB = ModelExtractor.CreateMetadataAggregate( + "TestAssembly", + ImmutableArray.Create(serializableAlpha, serializableBeta), + ImmutableArray.Create(proxyAlpha, proxyBeta), + CreateReferenceAssemblyModel( + applicationParts: ImmutableArray.Create("PartZ", "PartA"), + referencedSerializableTypes: ImmutableArray.Create(serializableAlpha, serializableGamma), + referencedProxyInterfaces: ImmutableArray.Create(proxyAlpha, proxyGamma), + registeredCodecs: ImmutableArray.Create( + new RegisteredCodecModel(new TypeRef("global::Codecs.AlphaCodec"), RegisteredCodecKind.Serializer), + new RegisteredCodecModel(new TypeRef("global::Codecs.ZuluCodec"), RegisteredCodecKind.Serializer)), + interfaceImplementations: ImmutableArray.Create( + new InterfaceImplementationModel(new TypeRef("global::Impl.Alpha")), + new InterfaceImplementationModel(new TypeRef("global::Impl.Zulu"))))); + + Assert.Equal(aggregateA, aggregateB); + Assert.Equal(aggregateA.GetHashCode(), aggregateB.GetHashCode()); + } + + private static MemberModel CreateMemberModel(uint fieldId, string name, string type) => new( + fieldId, + name, + new TypeRef(type), + new TypeRef("global::MyNamespace.MyType"), + "TestAssembly", + type.Replace("global::", string.Empty).Replace(".", "_"), + isPrimaryConstructorParameter: false, + isSerializable: true, + isCopyable: true, + MemberKind.Property, + AccessStrategy.Direct, + AccessStrategy.Direct, + isObsolete: false, + hasImmutableAttribute: false, + isShallowCopyable: false, + isValueType: type is "int" or "bool" or "double", + containingTypeIsValueType: false, + backingPropertyName: name); + + private static SerializableTypeModel CreateSerializableTypeModel( + string name, + string ns, + ImmutableArray members = default, + TypeMetadataIdentity metadataIdentity = default) => new( + Accessibility.Public, + new TypeRef($"global::{ns}.{name}"), + HasComplexBaseType: false, + IncludePrimaryConstructorParameters: false, + TypeRef.Empty, + ns, + $"OrleansCodeGen.{ns}", + name, + IsValueType: false, + IsSealedType: true, + IsAbstractType: false, + IsEnumType: false, + IsGenericType: false, + ImmutableArray.Empty, + members, + UseActivator: false, + IsEmptyConstructable: true, + HasActivatorConstructor: false, + TrackReferences: true, + OmitDefaultMemberValues: false, + ImmutableArray.Empty, + IsShallowCopyable: false, + IsUnsealedImmutable: false, + IsImmutable: false, + IsExceptionType: false, + ImmutableArray.Empty, + ObjectCreationStrategy.NewExpression, + MetadataIdentity: metadataIdentity); + + private static ProxyInterfaceModel CreateProxyInterfaceModel(string name) => new( + new TypeRef($"global::MyNamespace.{name}"), + name, + "OrleansCodeGen.MyNamespace", + ImmutableArray.Empty, + new ProxyBaseModel( + new TypeRef("global::Orleans.Runtime.GrainReference"), + IsExtension: false, + GeneratedClassNameComponent: "GrainReference", + ImmutableArray.Empty), + ImmutableArray.Empty, + MetadataIdentity: new TypeMetadataIdentity($"MyNamespace.{name}", "TestAssembly", "TestAssembly")); + + private static ReferenceAssemblyModel CreateReferenceAssemblyModel( + ImmutableArray applicationParts = default, + ImmutableArray referencedSerializableTypes = default, + ImmutableArray referencedProxyInterfaces = default, + ImmutableArray registeredCodecs = default, + ImmutableArray interfaceImplementations = default) => new( + "TestAssembly", + applicationParts, + ImmutableArray.Empty, + ImmutableArray.Empty, + ImmutableArray.Empty, + referencedSerializableTypes, + referencedProxyInterfaces, + registeredCodecs, + interfaceImplementations); +} diff --git a/test/Orleans.CodeGenerator.Tests/IncrementalOrderingStabilityTests.cs b/test/Orleans.CodeGenerator.Tests/IncrementalOrderingStabilityTests.cs new file mode 100644 index 00000000000..f11399babe4 --- /dev/null +++ b/test/Orleans.CodeGenerator.Tests/IncrementalOrderingStabilityTests.cs @@ -0,0 +1,411 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Orleans.CodeGenerator.Tests; + +/// +/// Characterization tests for order-sensitive incremental source generator stability. +/// +public class IncrementalOrderingStabilityTests +{ + private const string SerializableTypeA = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public sealed class OrderingDtoA + { + [Id(0)] + public string Name { get; set; } + } + """; + + private const string SerializableTypeB = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public sealed class OrderingDtoB + { + [Id(0)] + public int Value { get; set; } + } + """; + + private const string ProxyInterface = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public interface IOrderingGrain : IGrainWithIntegerKey + { + Task GetA(); + Task GetB(); + } + """; + + private const string UnrelatedType = """ + namespace TestProject; + + public sealed class UnrelatedOrderingType + { + public string Format(string value) => value.ToUpperInvariant(); + } + """; + + [Fact] + public async Task ReorderedSyntaxTreesWithSameGeneratorTargets_ProducesIdenticalGeneratedSources() + { + var compilation = await CreateCompilation( + "OrderingProject", + SerializableTypeA, + ProxyInterface, + SerializableTypeB); + var reorderedCompilation = await CreateCompilation( + "OrderingProject", + SerializableTypeB, + SerializableTypeA, + ProxyInterface); + + var result = RunGenerator(compilation); + var reorderedResult = RunGenerator(reorderedCompilation); + + AssertGeneratedSourcesIdentical(result, reorderedResult); + } + + [Fact] + public async Task AddingUnrelatedSyntaxTree_PreservesGeneratedSourcesAndCachesStableSteps() + { + var compilation = await CreateCompilation( + "OrderingProject", + SerializableTypeA, + ProxyInterface, + SerializableTypeB); + var updatedCompilation = compilation.AddSyntaxTrees(ParseSource(UnrelatedType)); + + var (result, updatedResult) = RunTwice(compilation, updatedCompilation); + + AssertTrackedStepsCachedOrUnchanged( + updatedResult, + OrleansSerializationSourceGenerator.SerializableTypeResultsTrackingName, + OrleansSerializationSourceGenerator.CollectedSerializableTypesTrackingName, + OrleansSerializationSourceGenerator.SerializerOutputsTrackingName, + OrleansSerializationSourceGenerator.InheritedProxyInterfacesTrackingName, + OrleansSerializationSourceGenerator.CollectedProxyInterfacesTrackingName, + OrleansSerializationSourceGenerator.PreparedProxyOutputsTrackingName, + OrleansSerializationSourceGenerator.ProxyOutputsTrackingName, + OrleansSerializationSourceGenerator.ReferenceAssemblyDataTrackingName, + OrleansSerializationSourceGenerator.MetadataAggregateTrackingName, + OrleansSerializationSourceGenerator.MetadataOutputsTrackingName); + AssertGeneratedSourcesIdentical(result, updatedResult); + } + + [Fact] + public async Task ReorderedInterfaceInheritanceGraph_ProducesIdenticalProxyAndMetadata() + { + const string featureInterfaces = """ + using System.Threading.Tasks; + + namespace TestProject; + + public interface IFirstOrderingFeature + { + Task First(); + } + + public interface ISecondOrderingFeature + { + Task Second(); + } + """; + + const string derivedInterface = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public interface ICompositeOrderingGrain : IGrainWithIntegerKey, IFirstOrderingFeature, ISecondOrderingFeature + { + Task Own(); + } + """; + + const string reorderedDerivedInterface = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public interface ICompositeOrderingGrain : IGrainWithIntegerKey, IFirstOrderingFeature, ISecondOrderingFeature + { + Task Own(); + } + """; + + var compilation = await CreateCompilation( + "OrderingProject", + featureInterfaces, + derivedInterface); + var reorderedCompilation = await CreateCompilation( + "OrderingProject", + reorderedDerivedInterface, + featureInterfaces); + + var result = RunGenerator(compilation); + var reorderedResult = RunGenerator(reorderedCompilation); + + AssertGeneratedSourcesIdentical(result, reorderedResult); + Assert.Contains(GetGeneratedSourceMap(result).Keys, static hint => hint.Contains(".orleans.proxy.", StringComparison.Ordinal)); + Assert.Contains(GetGeneratedSourceMap(result).Keys, static hint => hint.EndsWith(".orleans.metadata.g.cs", StringComparison.Ordinal)); + } + + [Fact] + public async Task ReorderedMetadataInputsWithAliasesCodecsAndApplicationParts_ProducesIdenticalMetadata() + { + var compilation = await CreateMetadataStabilityCompilation(reverseReferenceOrder: false); + var reorderedCompilation = await CreateMetadataStabilityCompilation(reverseReferenceOrder: true); + + var result = RunGenerator(compilation); + var reorderedResult = RunGenerator(reorderedCompilation); + + var metadataSource = GetMetadataSource(result); + var reorderedMetadataSource = GetMetadataSource(reorderedResult); + Assert.Equal(metadataSource, reorderedMetadataSource); + Assert.Equal(1, CountOccurrences(metadataSource, "WellKnownTypeAliases.Add(\"A.Alias\"")); + Assert.Equal(1, CountOccurrences(metadataSource, "WellKnownTypeAliases.Add(\"B.Alias\"")); + Assert.Equal(1, CountOccurrences(metadataSource, "Serializers.Add(typeof(global::LibraryB.SerializerType))")); + Assert.Equal(1, CountOccurrences(metadataSource, "Copiers.Add(typeof(global::LibraryB.CopierType))")); + Assert.Equal(1, CountOccurrences(metadataSource, "Activators.Add(typeof(global::LibraryB.ActivatorType))")); + Assert.Equal(1, CountOccurrences(metadataSource, "Converters.Add(typeof(global::LibraryB.ConverterType))")); + Assert.Equal(1, CountOccurrences(metadataSource, "InterfaceImplementations.Add(typeof(global::LibraryB.GeneratedInterfaceImplementation))")); + } + + private static async Task CreateCompilation(string assemblyName, params string[] sources) + { + Assert.NotEmpty(sources); + + var compilation = await TestCompilationHelper.CreateCompilation(sources[0], assemblyName); + if (sources.Length == 1) + { + return compilation; + } + + return compilation.AddSyntaxTrees(sources.Skip(1).Select(ParseSource)); + } + + private static async Task CreateMetadataStabilityCompilation(bool reverseReferenceOrder) + { + const string libraryBCode = """ + using Orleans; + using Orleans.Runtime; + using System.Threading.Tasks; + + namespace LibraryB; + + [Id(200)] + [Alias("B.Alias")] + [CompoundTypeAlias("B", typeof(LibraryB.BetaType))] + public sealed class BetaType + { + } + + [RegisterSerializer] + public sealed class SerializerType + { + } + + [RegisterCopier] + public sealed class CopierType + { + } + + [RegisterActivator] + public sealed class ActivatorType + { + } + + [RegisterConverter] + public sealed class ConverterType + { + } + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface IGeneratedInterface + { + Task Ping(); + } + + public sealed class GeneratedInterfaceImplementation : IGeneratedInterface + { + public Task Ping() => Task.CompletedTask; + } + """; + + const string libraryACode = """ + using Orleans; + using LibraryB; + + [assembly: ApplicationPart("Zeta.Part")] + [assembly: ApplicationPart("Alpha.Part")] + [assembly: GenerateCodeForDeclaringAssembly(typeof(LibraryB.BetaType))] + + namespace LibraryA; + + [Id(100)] + [Alias("A.Alias")] + public sealed class AlphaType + { + } + """; + + const string consumerCode = """ + using Orleans; + + [assembly: GenerateCodeForDeclaringAssembly(typeof(LibraryA.AlphaType))] + + namespace ConsumerProject; + + public sealed class ConsumerMarker + { + } + """; + + var libraryBCompilation = await TestCompilationHelper.CreateCompilation(libraryBCode, "LibraryB"); + Assert.Empty(libraryBCompilation.GetDiagnostics().Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); + + var libraryACompilation = await TestCompilationHelper.CreateCompilation( + libraryACode, + "LibraryA", + libraryBCompilation.ToMetadataReference()); + Assert.Empty(libraryACompilation.GetDiagnostics().Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); + + var libraryAReference = libraryACompilation.ToMetadataReference(); + var libraryBReference = libraryBCompilation.ToMetadataReference(); + var consumerCompilation = reverseReferenceOrder + ? await TestCompilationHelper.CreateCompilation(consumerCode, "ConsumerProject", libraryBReference, libraryAReference) + : await TestCompilationHelper.CreateCompilation(consumerCode, "ConsumerProject", libraryAReference, libraryBReference); + + Assert.Empty(consumerCompilation.GetDiagnostics().Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); + return consumerCompilation; + } + + private static SyntaxTree ParseSource(string source) + => CSharpSyntaxTree.ParseText(source); + + private static GeneratorRunResult RunGenerator(Compilation compilation) + { + var generator = new OrleansSerializationSourceGenerator().AsSourceGenerator(); + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: [generator], + driverOptions: new GeneratorDriverOptions( + disabledOutputs: default, + trackIncrementalGeneratorSteps: true)); + + driver = driver.RunGenerators(compilation); + var runResult = driver.GetRunResult(); + Assert.Empty(runResult.Diagnostics); + + var result = Assert.Single(runResult.Results); + Assert.Empty(result.Diagnostics); + Assert.NotEmpty(result.GeneratedSources); + + return result; + } + + private static (GeneratorRunResult First, GeneratorRunResult Second) RunTwice( + Compilation firstCompilation, + Compilation secondCompilation) + { + var generator = new OrleansSerializationSourceGenerator().AsSourceGenerator(); + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: [generator], + driverOptions: new GeneratorDriverOptions( + disabledOutputs: default, + trackIncrementalGeneratorSteps: true)); + + driver = driver.RunGenerators(firstCompilation); + var result1 = GetSingleGeneratorResult(driver); + + driver = driver.RunGenerators(secondCompilation); + var result2 = GetSingleGeneratorResult(driver); + + return (result1, result2); + } + + private static GeneratorRunResult GetSingleGeneratorResult(GeneratorDriver driver) + { + var runResult = driver.GetRunResult(); + Assert.Empty(runResult.Diagnostics); + + var result = Assert.Single(runResult.Results); + Assert.Empty(result.Diagnostics); + Assert.NotEmpty(result.GeneratedSources); + + return result; + } + + private static void AssertTrackedStepsCachedOrUnchanged(GeneratorRunResult result, params string[] stepNames) + { + var trackedSteps = result.TrackedSteps; + Assert.NotEmpty(trackedSteps); + + foreach (var stepName in stepNames) + { + Assert.True(trackedSteps.TryGetValue(stepName, out var steps), $"Missing tracked step '{stepName}'."); + Assert.NotEmpty(steps); + + foreach (var step in steps) + { + foreach (var (_, reason) in step.Outputs) + { + Assert.True( + reason is IncrementalStepRunReason.Cached or IncrementalStepRunReason.Unchanged, + $"Step '{stepName}' had reason '{reason}' — expected Cached or Unchanged."); + } + } + } + } + + private static void AssertGeneratedSourcesIdentical(GeneratorRunResult result, GeneratorRunResult other) + { + var sources = GetGeneratedSourceMap(result); + var otherSources = GetGeneratedSourceMap(other); + + Assert.Equal(sources.Count, otherSources.Count); + + foreach (var (hintName, sourceText) in sources) + { + Assert.True(otherSources.TryGetValue(hintName, out var otherSourceText), $"Missing generated source '{hintName}'."); + Assert.Equal(sourceText, otherSourceText); + } + } + + private static SortedDictionary GetGeneratedSourceMap(GeneratorRunResult result) + => new( + result.GeneratedSources.ToDictionary(source => source.HintName, source => source.SourceText.ToString(), StringComparer.Ordinal), + StringComparer.Ordinal); + + private static string GetMetadataSource(GeneratorRunResult result) + { + var source = Assert.Single(result.GeneratedSources, static source => source.HintName.EndsWith(".orleans.metadata.g.cs", StringComparison.Ordinal)); + return CSharpSyntaxTree.ParseText(source.SourceText.ToString().TrimStart('\uFEFF')).GetCompilationUnitRoot().NormalizeWhitespace().ToFullString(); + } + + private static int CountOccurrences(string value, string substring) + { + var count = 0; + var index = 0; + while ((index = value.IndexOf(substring, index, StringComparison.Ordinal)) >= 0) + { + count++; + index += substring.Length; + } + + return count; + } +} diff --git a/test/Orleans.CodeGenerator.Tests/ModelExtractorTests.cs b/test/Orleans.CodeGenerator.Tests/ModelExtractorTests.cs new file mode 100644 index 00000000000..8c0cbf13891 --- /dev/null +++ b/test/Orleans.CodeGenerator.Tests/ModelExtractorTests.cs @@ -0,0 +1,795 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; +using Microsoft.CodeAnalysis.Testing; +using Microsoft.Extensions.DependencyInjection; +using Orleans.CodeGenerator.Model; +using Orleans.Serialization; +using Xunit; + +namespace Orleans.CodeGenerator.Tests; + +/// +/// Tests that correctly extracts value models from source declarations. +/// Validates that extracted models capture all necessary data for incremental pipeline caching +/// and that identical inputs produce equal models. +/// +public class ModelExtractorTests +{ + [Fact] + public async Task ExtractSerializableTypeModel_BasicClass_CapturesCorrectData() + { + var code = @" +using Orleans; +namespace TestProject; + +[GenerateSerializer] +public class DemoData +{ + [Id(0)] + public string Value { get; set; } = string.Empty; + + [Id(1)] + public int Count { get; set; } +}"; + var (model, _) = await ExtractFirstSerializableType(code); + + Assert.Equal("DemoData", model.Name); + Assert.Equal("TestProject", model.Namespace); + Assert.Equal(Accessibility.Public, model.Accessibility); + Assert.False(model.IsValueType); + Assert.False(model.IsEnumType); + Assert.False(model.IsAbstractType); + Assert.False(model.IsGenericType); + Assert.Equal(2, model.Members.Length); + + // Auto-properties are stored as backing fields; match via BackingPropertyName + var valueMember = Assert.Single(model.Members, m => m.BackingPropertyName == "Value"); + var countMember = Assert.Single(model.Members, m => m.BackingPropertyName == "Count"); + Assert.Equal((uint)0, valueMember.FieldId); + Assert.Equal((uint)1, countMember.FieldId); + } + + [Fact] + public async Task ExtractSerializableTypeModel_ValueType_CapturesCorrectFlags() + { + var code = @" +using Orleans; +namespace TestProject; + +[GenerateSerializer] +public struct DemoStruct +{ + [Id(0)] + public int X { get; set; } +}"; + var (model, _) = await ExtractFirstSerializableType(code); + + Assert.True(model.IsValueType); + Assert.Equal("DemoStruct", model.Name); + Assert.Equal(ObjectCreationStrategy.Default, model.CreationStrategy); + } + + [Fact] + public async Task ExtractSerializableTypeModel_GenericType_CapturesTypeParameters() + { + var code = @" +using Orleans; +namespace TestProject; + +[GenerateSerializer] +public class GenericData +{ + [Id(0)] + public T Value { get; set; } +}"; + var (model, _) = await ExtractFirstSerializableType(code); + + Assert.True(model.IsGenericType); + Assert.Single(model.TypeParameters); + Assert.Equal("T", model.TypeParameters[0].Name); + } + + [Fact] + public async Task ExtractSerializableTypeModel_MetadataIdentity_CapturesTopLevelGenericAndNestedTypes() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public sealed class TopLevelDto + { + [Id(0)] + public int Value { get; set; } + } + + [GenerateSerializer] + public sealed class GenericDto + { + [Id(0)] + public T Value { get; set; } = default!; + } + + public sealed class Container + { + [GenerateSerializer] + public sealed class NestedDto + { + [Id(0)] + public int Value { get; set; } + } + + [GenerateSerializer] + public sealed class NestedGenericDto + { + [Id(0)] + public T Value { get; set; } = default!; + } + } + """; + var compilation = await CreateCompilation(code); + + AssertMetadataIdentity( + ExtractSerializableTypeModel(compilation, "TestProject.TopLevelDto").MetadataIdentity, + compilation, + "TestProject.TopLevelDto"); + AssertMetadataIdentity( + ExtractSerializableTypeModel(compilation, "TestProject.GenericDto`1").MetadataIdentity, + compilation, + "TestProject.GenericDto`1"); + AssertMetadataIdentity( + ExtractSerializableTypeModel(compilation, "TestProject.Container+NestedDto").MetadataIdentity, + compilation, + "TestProject.Container+NestedDto"); + AssertMetadataIdentity( + ExtractSerializableTypeModel(compilation, "TestProject.Container+NestedGenericDto`1").MetadataIdentity, + compilation, + "TestProject.Container+NestedGenericDto`1"); + } + + [Fact] + public async Task ExtractSerializableTypeModel_SameInput_ProducesEqualModels() + { + var code = @" +using Orleans; +namespace TestProject; + +[GenerateSerializer] +public class DemoData +{ + [Id(0)] + public string Value { get; set; } = string.Empty; + + [Id(1)] + public int Count { get; set; } +}"; + var (model1, compilation) = await ExtractFirstSerializableType(code); + + // Extract again from the same compilation — should produce identical model + var model2 = ExtractFromCompilation(compilation); + + Assert.Equal(model1, model2); + Assert.Equal(model1.GetHashCode(), model2.GetHashCode()); + } + + [Fact] + public async Task ExtractSerializableTypeModel_DifferentInput_ProducesUnequalModels() + { + var code1 = @" +using Orleans; +namespace TestProject; + +[GenerateSerializer] +public class DemoData +{ + [Id(0)] + public string Value { get; set; } = string.Empty; +}"; + var code2 = @" +using Orleans; +namespace TestProject; + +[GenerateSerializer] +public class DemoData +{ + [Id(0)] + public string Value { get; set; } = string.Empty; + + [Id(1)] + public int NewField { get; set; } +}"; + var (model1, _) = await ExtractFirstSerializableType(code1); + var (model2, _) = await ExtractFirstSerializableType(code2); + + Assert.NotEqual(model1, model2); + } + + [Fact] + public async Task ExtractSerializableTypeModel_Enum_CapturesEnumFlag() + { + var code = @" +using Orleans; +namespace TestProject; + +[GenerateSerializer] +public enum DemoEnum +{ + A, + B, + C +}"; + var (model, _) = await ExtractFirstSerializableType(code); + + Assert.True(model.IsEnumType); + Assert.True(model.IsValueType); + Assert.Equal("DemoEnum", model.Name); + } + + [Fact] + public async Task ExtractSerializableTypeModel_SealedClass_CapturesSealedFlag() + { + var code = @" +using Orleans; +namespace TestProject; + +[GenerateSerializer] +public sealed class DemoData +{ + [Id(0)] + public string Value { get; set; } = string.Empty; +}"; + var (model, _) = await ExtractFirstSerializableType(code); + Assert.True(model.IsSealedType); + } + + [Fact] + public async Task ExtractSerializableTypeModel_RecordWithPropertyTargetedIds_PreservesExistingMemberSection() + { + var code = @" +using Orleans; +namespace TestProject; + +[GenerateSerializer] +public record DemoRecord([property: Id(0)] string Value, [property: Id(1)] int Count); +"; + var (model, _) = await ExtractFirstSerializableType(code); + + Assert.Equal("DemoRecord", model.Name); + Assert.True(model.IncludePrimaryConstructorParameters); + + var members = model.Members.OrderBy(member => member.FieldId).ToArray(); + Assert.Collection( + members, + member => + { + Assert.Equal((uint)0, member.FieldId); + Assert.Equal("Value", member.BackingPropertyName); + Assert.False(member.IsPrimaryConstructorParameter); + }, + member => + { + Assert.Equal((uint)1, member.FieldId); + Assert.Equal("Count", member.BackingPropertyName); + Assert.False(member.IsPrimaryConstructorParameter); + }); + } + + [Fact] + public async Task ExtractSerializableTypeModel_RecordWithFieldTargetedIds_PreservesExistingMemberSection() + { + var code = @" +using Orleans; +namespace TestProject; + +[GenerateSerializer] +public record DemoRecord([field: Id(0)] string Value); +"; + var (model, _) = await ExtractFirstSerializableType(code); + var member = Assert.Single(model.Members); + + Assert.Equal((uint)0, member.FieldId); + Assert.Equal("Value", member.BackingPropertyName); + Assert.False(member.IsPrimaryConstructorParameter); + } + + [Fact] + public async Task ExtractSerializableTypeModel_RecordWithoutIds_AssignsPrimaryConstructorParameterIds() + { + var code = @" +using Orleans; +namespace TestProject; + +[GenerateSerializer] +public readonly record struct DemoRecord(int Value, string Name); +"; + var (model, _) = await ExtractFirstSerializableType(code); + var members = model.Members.OrderBy(member => member.FieldId).ToArray(); + + Assert.Collection( + members, + member => + { + Assert.Equal((uint)0, member.FieldId); + Assert.Equal("Value", member.BackingPropertyName); + Assert.True(member.IsPrimaryConstructorParameter); + }, + member => + { + Assert.Equal((uint)1, member.FieldId); + Assert.Equal("Name", member.BackingPropertyName); + Assert.True(member.IsPrimaryConstructorParameter); + }); + } + + [Fact] + public async Task FieldIdAssignmentHelper_TypeWithComputedPropertyAndNoIds_RemainsValidWithoutSerializableCandidates() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public class ComputedDto + { + public string Value => string.Empty; + } + """; + var compilation = await CreateCompilation(code); + var helper = CreateFieldIdAssignmentHelper(compilation, "TestProject.ComputedDto"); + + Assert.True(helper.IsValidForSerialization); + Assert.Null(helper.FailureReason); + Assert.DoesNotContain(helper.Members, member => helper.TryGetSymbolKey(member, out _)); + } + + [Fact] + public async Task ExtractReferenceAssemblyData_CollectsCrossAssemblyMetadataAndDeterministicOrdering() + { + var consumerCompilation = await CreateReferenceExtractionCompilation(); + var model = ModelExtractor.ExtractReferenceAssemblyData(consumerCompilation, new CodeGeneratorOptions(), default); + + Assert.Equal("ConsumerProject", model.ApplicationParts[0]); + Assert.Equal(model.ApplicationParts.Length, model.ApplicationParts.Distinct(StringComparer.Ordinal).Count()); + Assert.Contains("ConsumerProject", model.ApplicationParts); + Assert.Contains("LibraryA", model.ApplicationParts); + Assert.Contains("Alpha.Part", model.ApplicationParts); + Assert.Contains("Zeta.Part", model.ApplicationParts); + + Assert.Contains(model.WellKnownTypeIds, entry => entry.Type.SyntaxString == "global::LibraryA.AlphaType" && entry.Id == 100u); + Assert.Contains(model.WellKnownTypeIds, entry => entry.Type.SyntaxString == "global::LibraryB.BetaType" && entry.Id == 200u); + + Assert.Contains(model.TypeAliases, entry => entry.Type.SyntaxString == "global::LibraryA.AlphaType" && entry.Alias == "A.Alias"); + Assert.Contains(model.TypeAliases, entry => entry.Type.SyntaxString == "global::LibraryB.BetaType" && entry.Alias == "B.Alias"); + + var compoundAlias = Assert.Single(model.CompoundTypeAliases, entry => entry.TargetType.SyntaxString == "global::LibraryB.BetaType"); + Assert.Equal(2, compoundAlias.Components.Length); + Assert.True(compoundAlias.Components[0].IsString); + Assert.Equal("B", compoundAlias.Components[0].StringValue); + Assert.True(compoundAlias.Components[1].IsType); + Assert.Equal("global::LibraryB.BetaType", compoundAlias.Components[1].TypeValue.SyntaxString); + + var registeredCodecTypes = model.RegisteredCodecs.Select(static codec => codec.Type.SyntaxString).ToArray(); + Assert.Equal( + registeredCodecTypes.OrderBy(static name => name, StringComparer.Ordinal), + registeredCodecTypes); + Assert.Contains("global::LibraryB.ActivatorType", registeredCodecTypes); + Assert.Contains("global::LibraryB.CopierType", registeredCodecTypes); + Assert.Contains("global::LibraryB.ConverterType", registeredCodecTypes); + Assert.Contains("global::LibraryB.SerializerType", registeredCodecTypes); + Assert.Contains(model.InterfaceImplementations, implementation => implementation.ImplementationType.SyntaxString == "global::LibraryB.GeneratedInterfaceImplementation"); + } + + [Fact] + public async Task ExtractReferenceAssemblyData_IsStableWhenReferenceOrderChanges() + { + var compilationA = await CreateReferenceExtractionCompilation(); + var compilationB = await CreateReferenceExtractionCompilation(reverseReferenceOrder: true); + + var modelA = ModelExtractor.ExtractReferenceAssemblyData(compilationA, new CodeGeneratorOptions(), default); + var modelB = ModelExtractor.ExtractReferenceAssemblyData(compilationB, new CodeGeneratorOptions(), default); + + Assert.Equal(modelA, modelB); + Assert.Equal(modelA.GetHashCode(), modelB.GetHashCode()); + } + + [Fact] + public async Task ExtractProxyInterfaceModel_InheritedGenerateMethodSerializers_FallsBackToInheritedAttribute() + { + const string code = """ + using Orleans; + using Orleans.Runtime; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface IBaseGrain + { + ValueTask Ping(); + } + + public interface IDerivedGrain : IBaseGrain + { + } + """; + var compilation = await CreateCompilation(code); + var model = ExtractProxyInterfaceModel(compilation, "TestProject.IDerivedGrain"); + + Assert.Equal("IDerivedGrain", model.Name); + Assert.Single(model.Methods, method => method.Name == "Ping"); + } + + [Fact] + public async Task ExtractProxyInterfaceModel_ProxyBase_IncludesInvokableBaseMappings() + { + const string code = """ + using Orleans; + using Orleans.Runtime; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface ITestGrain + { + ValueTask Ping(); + } + """; + var compilation = await CreateCompilation(code); + var model = ExtractProxyInterfaceModel(compilation, "TestProject.ITestGrain"); + + Assert.NotEmpty(model.ProxyBase.InvokableBaseTypes); + Assert.Contains( + model.ProxyBase.InvokableBaseTypes, + mapping => mapping.ReturnType.SyntaxString.Contains("ValueTask", StringComparison.Ordinal) + && mapping.InvokableBaseType.SyntaxString.Contains("Request", StringComparison.Ordinal)); + } + + [Fact] + public async Task ExtractProxyInterfaceModel_MetadataIdentity_CapturesTopLevelGenericAndNestedInterfaces() + { + const string code = """ + using Orleans; + using Orleans.Runtime; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface ITopLevelGrain : IGrainWithIntegerKey + { + Task Ping(); + } + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface IGenericGrain : IGrainWithIntegerKey + { + Task Echo(T value); + } + + public sealed class Container + { + [GenerateMethodSerializers(typeof(GrainReference))] + public interface INestedGrain : IGrainWithIntegerKey + { + Task Ping(); + } + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface INestedGenericGrain : IGrainWithIntegerKey + { + Task Echo(T value); + } + } + """; + var compilation = await CreateCompilation(code); + + AssertMetadataIdentity( + ExtractProxyInterfaceModel(compilation, "TestProject.ITopLevelGrain").MetadataIdentity, + compilation, + "TestProject.ITopLevelGrain"); + AssertMetadataIdentity( + ExtractProxyInterfaceModel(compilation, "TestProject.IGenericGrain`1").MetadataIdentity, + compilation, + "TestProject.IGenericGrain`1"); + AssertMetadataIdentity( + ExtractProxyInterfaceModel(compilation, "TestProject.Container+INestedGrain").MetadataIdentity, + compilation, + "TestProject.Container+INestedGrain"); + AssertMetadataIdentity( + ExtractProxyInterfaceModel(compilation, "TestProject.Container+INestedGenericGrain`1").MetadataIdentity, + compilation, + "TestProject.Container+INestedGenericGrain`1"); + } + + [Fact] + public async Task ExtractProxyInterfaceModel_BaseInterfaceOrder_DoesNotAffectMethodOrdering() + { + const string code1 = """ + using Orleans; + using Orleans.Runtime; + using System.Threading.Tasks; + + namespace TestProject; + + public interface IFirst + { + ValueTask First(); + } + + public interface ISecond + { + ValueTask Second(); + } + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface ITestGrain : IFirst, ISecond + { + } + """; + const string code2 = """ + using Orleans; + using Orleans.Runtime; + using System.Threading.Tasks; + + namespace TestProject; + + public interface IFirst + { + ValueTask First(); + } + + public interface ISecond + { + ValueTask Second(); + } + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface ITestGrain : ISecond, IFirst + { + } + """; + + var model1 = ExtractProxyInterfaceModel(await CreateCompilation(code1), "TestProject.ITestGrain"); + var model2 = ExtractProxyInterfaceModel(await CreateCompilation(code2), "TestProject.ITestGrain"); + + Assert.Equal(model1.Methods.Select(method => method.Name), model2.Methods.Select(method => method.Name)); + } + + [Fact] + public async Task ExtractProxyInterfaceModel_UsesOriginalMethodDefinitionForGeneratedMethodId() + { + const string code = """ + using Orleans; + using Orleans.Runtime; + using System.Threading.Tasks; + + namespace TestProject; + + public interface IBaseGrain + { + ValueTask Echo(T value); + } + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface ITestGrain : IBaseGrain + { + } + """; + var compilation = await CreateCompilation(code); + var model = ExtractProxyInterfaceModel(compilation, "TestProject.ITestGrain"); + + var method = Assert.Single(model.Methods); + var baseInterface = compilation.GetTypeByMetadataName("TestProject.IBaseGrain`1"); + Assert.NotNull(baseInterface); + var originalMethod = Assert.Single(baseInterface.GetMembers("Echo").OfType()); + var expectedMethodId = GeneratedCodeUtilities.CreateHashedMethodId(originalMethod); + + Assert.Equal(expectedMethodId, method.GeneratedMethodId); + } + + #region Helpers + + private static async Task CreateReferenceExtractionCompilation(bool reverseReferenceOrder = false) + { + const string libraryBCode = """ + using Orleans; + using System.Threading.Tasks; + + namespace LibraryB; + + [Id(200)] + [Alias("B.Alias")] + [CompoundTypeAlias("B", typeof(LibraryB.BetaType))] + public sealed class BetaType + { + } + + [RegisterSerializer] + public sealed class SerializerType + { + } + + [RegisterCopier] + public sealed class CopierType + { + } + + [RegisterActivator] + public sealed class ActivatorType + { + } + + [RegisterConverter] + public sealed class ConverterType + { + } + + [GenerateMethodSerializers(typeof(object))] + public interface IGeneratedInterface + { + Task Ping(); + } + + public sealed class GeneratedInterfaceImplementation : IGeneratedInterface + { + public Task Ping() => Task.CompletedTask; + } + """; + + const string libraryACode = """ + using Orleans; + using LibraryB; + + [assembly: ApplicationPart("Zeta.Part")] + [assembly: ApplicationPart("Alpha.Part")] + [assembly: GenerateCodeForDeclaringAssembly(typeof(LibraryB.BetaType))] + + namespace LibraryA; + + [Id(100)] + [Alias("A.Alias")] + public sealed class AlphaType + { + } + """; + + const string consumerCode = """ + using Orleans; + + [assembly: GenerateCodeForDeclaringAssembly(typeof(LibraryA.AlphaType))] + + namespace ConsumerProject; + + public sealed class ConsumerMarker + { + } + """; + + var libraryBCompilation = await CreateCompilation(libraryBCode, "LibraryB"); + Assert.Empty(libraryBCompilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error)); + + var libraryACompilation = await CreateCompilation( + libraryACode, + "LibraryA", + libraryBCompilation.ToMetadataReference()); + Assert.Empty(libraryACompilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error)); + + var libraryAReference = libraryACompilation.ToMetadataReference(); + var libraryBReference = libraryBCompilation.ToMetadataReference(); + var consumerCompilation = reverseReferenceOrder + ? await CreateCompilation(consumerCode, "ConsumerProject", libraryBReference, libraryAReference) + : await CreateCompilation(consumerCode, "ConsumerProject", libraryAReference, libraryBReference); + + Assert.Empty(consumerCompilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error)); + return consumerCompilation; + } + + private static async Task<(SerializableTypeModel Model, CSharpCompilation Compilation)> ExtractFirstSerializableType(string code) + { + var compilation = await CreateCompilation(code); + var model = ExtractFromCompilation(compilation); + return (model, compilation); + } + + private static SerializableTypeModel ExtractFromCompilation(CSharpCompilation compilation) + { + var syntaxTree = Assert.Single(compilation.SyntaxTrees); + var semanticModel = compilation.GetSemanticModel(syntaxTree); + var generateSerializerAttribute = compilation.GetTypeByMetadataName("Orleans.GenerateSerializerAttribute"); + Assert.NotNull(generateSerializerAttribute); + + foreach (var declaration in syntaxTree.GetRoot().DescendantNodes()) + { + var symbol = declaration switch + { + TypeDeclarationSyntax typeDeclaration => semanticModel.GetDeclaredSymbol(typeDeclaration), + EnumDeclarationSyntax enumDeclaration => semanticModel.GetDeclaredSymbol(enumDeclaration), + _ => null, + }; + + if (symbol is not INamedTypeSymbol typeSymbol + || !typeSymbol.GetAttributes().Any(attribute => SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, generateSerializerAttribute))) + { + continue; + } + + return ExtractSerializableTypeModel(compilation, typeSymbol); + } + + throw new InvalidOperationException("No [GenerateSerializer] declaration was found."); + } + + private static SerializableTypeModel ExtractSerializableTypeModel(CSharpCompilation compilation, string metadataName) + { + var typeSymbol = compilation.GetTypeByMetadataName(metadataName); + Assert.NotNull(typeSymbol); + + return ExtractSerializableTypeModel(compilation, typeSymbol); + } + + private static SerializableTypeModel ExtractSerializableTypeModel(CSharpCompilation compilation, INamedTypeSymbol typeSymbol) + { + var options = new CodeGeneratorOptions(); + var libraryTypes = LibraryTypes.FromCompilation(compilation, options); + var model = ModelExtractor.TryExtractSerializableTypeModel( + typeSymbol, + compilation, + libraryTypes, + options, + throwOnFailure: true); + + Assert.NotNull(model); + return model; + } + + private static FieldIdAssignmentHelper CreateFieldIdAssignmentHelper( + CSharpCompilation compilation, + string metadataName, + Orleans.CodeGenerator.Model.GenerateFieldIds generateFieldIds = Orleans.CodeGenerator.Model.GenerateFieldIds.None) + { + var typeSymbol = compilation.GetTypeByMetadataName(metadataName); + Assert.NotNull(typeSymbol); + + var options = new CodeGeneratorOptions + { + GenerateFieldIds = generateFieldIds, + }; + var libraryTypes = LibraryTypes.FromCompilation(compilation, options); + return new FieldIdAssignmentHelper(typeSymbol, ImmutableArray.Empty, generateFieldIds, libraryTypes); + } + + private static ProxyInterfaceModel ExtractProxyInterfaceModel(CSharpCompilation compilation, string metadataName) + { + var interfaceType = compilation.GetTypeByMetadataName(metadataName); + Assert.NotNull(interfaceType); + + var model = ModelExtractor.ExtractProxyInterfaceModel(interfaceType, compilation, default); + Assert.NotNull(model); + return model; + } + + private static void AssertMetadataIdentity( + TypeMetadataIdentity metadataIdentity, + CSharpCompilation compilation, + string expectedMetadataName) + { + Assert.False(metadataIdentity.IsEmpty); + Assert.Equal(expectedMetadataName, metadataIdentity.MetadataName); + Assert.Equal(compilation.Assembly.Identity.Name, metadataIdentity.AssemblyName); + Assert.Equal(compilation.Assembly.Identity.GetDisplayName(), metadataIdentity.AssemblyIdentity); + } + + private static Task CreateCompilation( + string sourceCode, + string assemblyName = "TestProject", + params MetadataReference[] additionalReferences) + => TestCompilationHelper.CreateCompilation(sourceCode, assemblyName, additionalReferences); + + #endregion +} diff --git a/test/Orleans.CodeGenerator.Tests/OrleansSourceGeneratorTests.cs b/test/Orleans.CodeGenerator.Tests/OrleansSourceGeneratorTests.cs index 297d742d13b..31ee4cdc233 100644 --- a/test/Orleans.CodeGenerator.Tests/OrleansSourceGeneratorTests.cs +++ b/test/Orleans.CodeGenerator.Tests/OrleansSourceGeneratorTests.cs @@ -977,6 +977,317 @@ public class DemoClass public string Value { get; set; } }"); + /// + /// Tests that invokable deduplication works correctly when multiple grain interfaces + /// share a common base interface. The base method DoWork is inherited by both + /// IGrainA and IGrainB, but the generator should produce only one + /// invokable class for that method per declaring interface, not duplicate it for each + /// derived interface. + /// This behavior is currently handled by the active proxy pipeline via the + /// _invokableMethodDescriptions dictionary on the proxy-generation state. + /// + [Fact] + public async Task SharedBaseInterfaceMethodProducesDeduplicatedInvokable() + { + var code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public interface IBaseGrain : IGrainWithGuidKey + { + Task DoWork(); + } + + public interface IGrainA : IBaseGrain + { + Task DoExtraA(); + } + + public interface IGrainB : IBaseGrain + { + Task DoExtraB(); + } + """; + + var compilation = await CreateCompilation(code, "TestProject"); + Assert.Empty(compilation.GetDiagnostics()); + + var result = RunSourceGenerator(compilation); + Assert.Empty(result.Diagnostics); + + var generatedSource = ConcatenateGeneratedSources(result); + + // The base method DoWork is declared on IBaseGrain. Regardless of how many interfaces + // inherit from IBaseGrain, only one invokable should be generated for DoWork on IBaseGrain. + var invokableDoWorkCount = System.Text.RegularExpressions.Regex.Matches( + generatedSource, @"sealed class Invokable_IBaseGrain_GrainReference_\w+").Count; + Assert.Equal(1, invokableDoWorkCount); + + // Each derived interface should get its own invokable for its unique method. + Assert.Contains("Invokable_IGrainA_GrainReference_", generatedSource); + Assert.Contains("Invokable_IGrainB_GrainReference_", generatedSource); + } + + [Fact] + public async Task DerivedInterfaceInheritingGenerateMethodSerializersProducesProxy() + { + var code = """ + using Orleans; + using Orleans.Runtime; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface IBaseGrain : IGrainWithIntegerKey + { + Task Ping(); + } + + public interface IDerivedGrain : IBaseGrain + { + } + """; + + var compilation = await CreateCompilation(code, "TestProject"); + Assert.Empty(compilation.GetDiagnostics()); + + var result = RunSourceGenerator(compilation); + Assert.Empty(result.Diagnostics); + + var generatedSource = ConcatenateGeneratedSources(result); + Assert.Contains("Proxy_IDerivedGrain", generatedSource); + Assert.Contains("Invokable_IBaseGrain_GrainReference_", generatedSource); + } + + /// + /// Tests that the compilation-level reference-assembly extraction path correctly handles + /// types discovered via [GenerateCodeForDeclaringAssembly]. These types cannot be found by + /// ForAttributeWithMetadataName since they live in other assemblies. + /// + [Fact] + public async Task GeneratesSerializersForReferencedAssemblyTypesViaGenerateCodeForDeclaringAssembly() + { + var libraryCode = """ + using Orleans; + + namespace LibraryProject; + + [GenerateSerializer] + public class LibraryDto + { + [Id(0)] + public string Name { get; set; } = string.Empty; + + [Id(1)] + public int Value { get; set; } + } + """; + + var libraryCompilation = await CreateCompilation(libraryCode, "LibraryProject"); + Assert.Empty(libraryCompilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error)); + + // Run the generator on the library so its output assembly includes the generated metadata + var libraryGeneratorResult = RunSourceGenerator(libraryCompilation); + Assert.Empty(libraryGeneratorResult.Diagnostics); + + // Build the consumer that references the library via [GenerateCodeForDeclaringAssembly] + var consumerCode = """ + using Orleans; + + [assembly: GenerateCodeForDeclaringAssembly(typeof(LibraryProject.LibraryDto))] + """; + + var consumerCompilation = await CreateCompilation(consumerCode, "ConsumerProject"); + + // Add the library as a metadata reference so the consumer can see its types + consumerCompilation = consumerCompilation.AddReferences(libraryCompilation.ToMetadataReference()); + Assert.Empty(consumerCompilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error)); + + var consumerResult = RunSourceGenerator(consumerCompilation); + Assert.Empty(consumerResult.Diagnostics); + Assert.NotEmpty(consumerResult.GeneratedSources); + + var generatedSource = ConcatenateGeneratedSources(consumerResult); + + // The generator should produce a serializer (codec) for LibraryDto from the referenced assembly + Assert.Contains("LibraryDto", generatedSource); + Assert.Contains("Codec", generatedSource); + } + + [Fact] + public async Task GeneratesSerializersForReferencedNestedAndGenericAssemblyTypesViaGenerateCodeForDeclaringAssembly() + { + var libraryCode = """ + using Orleans; + + namespace LibraryProject; + + [GenerateSerializer] + public sealed class GenericDto + { + [Id(0)] + public T Value { get; set; } = default!; + } + + public sealed class Container + { + [GenerateSerializer] + public sealed class NestedDto + { + [Id(0)] + public int Value { get; set; } + } + + [GenerateSerializer] + public sealed class NestedGenericDto + { + [Id(0)] + public T Value { get; set; } = default!; + } + } + """; + + var libraryCompilation = await CreateCompilation(libraryCode, "LibraryProject"); + Assert.Empty(libraryCompilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error)); + + var consumerCode = """ + using Orleans; + + [assembly: GenerateCodeForDeclaringAssembly(typeof(LibraryProject.Container.NestedGenericDto<>))] + """; + + var consumerCompilation = (await CreateCompilation(consumerCode, "ConsumerProject")) + .AddReferences(libraryCompilation.ToMetadataReference()); + Assert.Empty(consumerCompilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error)); + + var consumerResult = RunSourceGenerator(consumerCompilation); + Assert.Empty(consumerResult.Diagnostics); + + Assert.Contains( + consumerResult.GeneratedSources, + source => source.HintName.Contains("LibraryProject.GenericDto", StringComparison.Ordinal)); + Assert.Contains( + consumerResult.GeneratedSources, + source => source.HintName.Contains("LibraryProject.Container.NestedDto", StringComparison.Ordinal)); + Assert.Contains( + consumerResult.GeneratedSources, + source => source.HintName.Contains("LibraryProject.Container.NestedGenericDto", StringComparison.Ordinal)); + + var generatedSource = ConcatenateGeneratedSources(consumerResult); + Assert.Contains("global::LibraryProject.GenericDto", generatedSource); + Assert.Contains("global::LibraryProject.Container.NestedDto", generatedSource); + Assert.Contains("global::LibraryProject.Container.NestedGenericDto", generatedSource); + } + + [Fact] + public async Task ReferencedSerializerResolutionUsesAssemblyMetadataIdentityWhenConsumerShadowsFullName() + { + var libraryCode = """ + using Orleans; + + namespace LibraryProject + { + public sealed class Marker + { + } + } + + namespace Shadowed + { + [GenerateSerializer] + public sealed class DuplicateDto + { + [Id(0)] + public int LibraryValue { get; set; } + } + } + """; + + var libraryCompilation = await CreateCompilation(libraryCode, "LibraryProject"); + Assert.Empty(libraryCompilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error)); + + var consumerCode = """ + using Orleans; + + [assembly: GenerateCodeForDeclaringAssembly(typeof(LibraryProject.Marker))] + + namespace Shadowed; + + public sealed class DuplicateDto + { + public string ConsumerValue { get; set; } = string.Empty; + } + """; + + var consumerCompilation = (await CreateCompilation(consumerCode, "ConsumerProject")) + .AddReferences(libraryCompilation.ToMetadataReference()); + Assert.Empty(consumerCompilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error)); + + var consumerResult = RunSourceGenerator(consumerCompilation); + Assert.Empty(consumerResult.Diagnostics); + + var generatedSource = ConcatenateGeneratedSources(consumerResult); + Assert.Contains("LibraryValue", generatedSource); + Assert.DoesNotContain("ConsumerValue", generatedSource); + } + + [Fact] + public async Task GeneratesProxiesForGenericAndNestedInterfaces() + { + var code = """ + using Orleans; + using Orleans.Runtime; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface IGenericGrain : IGrainWithIntegerKey + { + Task Echo(T value); + } + + public sealed class Container + { + [GenerateMethodSerializers(typeof(GrainReference))] + public interface INestedGrain : IGrainWithIntegerKey + { + Task Ping(); + } + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface INestedGenericGrain : IGrainWithIntegerKey + { + Task Echo(T value); + } + } + """; + + var compilation = await CreateCompilation(code, "TestProject"); + Assert.Empty(compilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error)); + + var result = RunSourceGenerator(compilation); + Assert.Empty(result.Diagnostics); + + Assert.Contains( + result.GeneratedSources, + source => source.HintName.Contains("IGenericGrain", StringComparison.Ordinal)); + Assert.Contains( + result.GeneratedSources, + source => source.HintName.Contains("Container.INestedGrain", StringComparison.Ordinal)); + Assert.Contains( + result.GeneratedSources, + source => source.HintName.Contains("Container.INestedGenericGrain", StringComparison.Ordinal)); + + var generatedSource = ConcatenateGeneratedSources(result); + Assert.Contains("Proxy_IGenericGrain", generatedSource); + Assert.Contains("Proxy_INestedGrain", generatedSource); + Assert.Contains("Proxy_INestedGenericGrain", generatedSource); + } + /// /// Tests that the generator emits a warning when [GenerateSerializer] is used in a reference assembly. /// Reference assemblies contain only metadata, no implementation, so generating serializers @@ -1016,6 +1327,32 @@ public class RefAsmType Assert.Contains(result.Diagnostics, d => d.Id == DiagnosticRuleId.ReferenceAssemblyWithGenerateSerializer); } + [Fact] + public async Task EmitsInvalidFieldIdDiagnosticForImplicitPublicProperties() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public class AutoDto + { + public string Value { get; set; } = string.Empty; + } + """; + + var compilation = await CreateCompilation(code, "TestProject"); + Assert.Empty(compilation.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error)); + + var result = RunSourceGenerator(compilation); + var diagnostic = Assert.Single(result.Diagnostics, d => d.Id == CanNotGenerateImplicitFieldIdsDiagnostic.DiagnosticId); + + Assert.Equal(DiagnosticRuleId.CanNotGenerateImplicitFieldIds, diagnostic.Id); + Assert.Equal(DiagnosticSeverity.Error, diagnostic.Severity); + Assert.Contains("AutoDto", diagnostic.GetMessage(), StringComparison.Ordinal); + } + [Fact] public async Task RemovedCustomAttributeBuildPropertiesAreIgnored() { @@ -1071,16 +1408,256 @@ public class DemoData ["build_property.orleans_generateserializerattributes"] = "TestProject.CustomGenerateSerializerAttribute", }); - Assert.Single(baselineResult.GeneratedSources); - Assert.Single(configuredResult.GeneratedSources); + Assert.NotEmpty(baselineResult.GeneratedSources); + Assert.NotEmpty(configuredResult.GeneratedSources); + Assert.Equal( + baselineResult.GeneratedSources.Select(source => source.HintName).OrderBy(name => name), + configuredResult.GeneratedSources.Select(source => source.HintName).OrderBy(name => name)); Assert.Equal( - baselineResult.GeneratedSources[0].SourceText.ToString(), - configuredResult.GeneratedSources[0].SourceText.ToString()); + ConcatenateGeneratedSources(baselineResult), + ConcatenateGeneratedSources(configuredResult)); Assert.Equal( baselineResult.Diagnostics.Select(d => d.Id).OrderBy(id => id), configuredResult.Diagnostics.Select(d => d.Id).OrderBy(id => id)); } + [Fact] + public async Task GlobalGenerateFieldIdsOption_AllowsImplicitPublicProperties() + { + var code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public class DemoData + { + public string Value { get; set; } = string.Empty; + public int Count { get; set; } + } + """; + + var compilation = await CreateCompilation(code, "TestProject"); + var baselineResult = RunSourceGenerator(compilation); + var configuredResult = RunSourceGenerator( + compilation, + new Dictionary + { + ["build_property.orleans_generatefieldids"] = "PublicProperties", + }); + + Assert.Contains(baselineResult.Diagnostics, d => d.Id == DiagnosticRuleId.CanNotGenerateImplicitFieldIds); + Assert.Empty(configuredResult.Diagnostics); + Assert.Contains(configuredResult.GeneratedSources, source => source.HintName.Contains(".orleans.ser.", StringComparison.Ordinal)); + } + + [Fact] + public async Task CompatibilityInvokersOption_GeneratesAdditionalInheritedInvokables() + { + var code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + public interface IBaseGrain : IGrainWithIntegerKey + { + Task Ping(); + } + + public interface IDerivedGrain : IBaseGrain + { + } + """; + + var compilation = await CreateCompilation(code, "TestProject"); + var baselineResult = RunSourceGenerator(compilation); + var configuredResult = RunSourceGenerator( + compilation, + new Dictionary + { + ["build_property.orleansgeneratecompatibilityinvokers"] = "true", + }); + + Assert.Empty(baselineResult.Diagnostics); + Assert.Empty(configuredResult.Diagnostics); + + var baselineSource = ConcatenateGeneratedSources(baselineResult); + var configuredSource = ConcatenateGeneratedSources(configuredResult); + + Assert.True( + CountGeneratedInvokableClasses(configuredSource) > CountGeneratedInvokableClasses(baselineSource), + "Enabling compatibility invokers should generate additional invokable classes for inherited grain methods."); + } + + [Fact] + public async Task AttachDebuggerFalseOption_DoesNotChangeOutput() + { + var code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public class DemoData + { + [Id(0)] + public string Value { get; set; } = string.Empty; + } + """; + + var compilation = await CreateCompilation(code, "TestProject"); + var baselineResult = RunSourceGenerator(compilation); + var configuredResult = RunSourceGenerator( + compilation, + new Dictionary + { + ["build_property.orleans_attachdebugger"] = "false", + }); + + Assert.Empty(baselineResult.Diagnostics); + Assert.Empty(configuredResult.Diagnostics); + Assert.Equal( + baselineResult.GeneratedSources.Select(source => source.HintName).OrderBy(name => name), + configuredResult.GeneratedSources.Select(source => source.HintName).OrderBy(name => name)); + Assert.Equal( + ConcatenateGeneratedSources(baselineResult), + ConcatenateGeneratedSources(configuredResult)); + } + + [Fact] + public async Task GeneratedSources_EmitMetadataLast() + { + var code = """ + using Orleans; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateSerializer] + public class DemoData + { + [Id(0)] + public string Value { get; set; } = string.Empty; + } + + public interface IMyGrain : IGrainWithIntegerKey + { + Task Get(); + } + """; + + var compilation = await CreateCompilation(code, "TestProject"); + var result = RunSourceGenerator(compilation); + + Assert.Empty(result.Diagnostics); + + var emittedHintNames = result.GeneratedSources + .Where(static source => !string.IsNullOrWhiteSpace(source.SourceText.ToString())) + .Select(static source => source.HintName) + .ToArray(); + + Assert.NotEmpty(emittedHintNames); + Assert.Contains(emittedHintNames, static hintName => hintName.Contains(".orleans.ser.", StringComparison.Ordinal)); + Assert.Contains(emittedHintNames, static hintName => hintName.Contains(".orleans.proxy.", StringComparison.Ordinal)); + Assert.EndsWith(".orleans.metadata.g.cs", emittedHintNames[^1], StringComparison.Ordinal); + } + + [Fact] + public async Task GeneratedInvokableActivators_AreRegisteredInMetadata() + { + var code = """ + using System; + using System.Threading.Tasks; + using Orleans; + using Orleans.Runtime; + + namespace TestProject; + + [InvokableBaseType(typeof(GrainReference), typeof(Task<>), typeof(ActivatingTaskRequest<>))] + [AttributeUsage(AttributeTargets.Method)] + public sealed class ActivatingAttribute : Attribute + { + } + + public abstract class ActivatingTaskRequest : TaskRequest + { + [GeneratedActivatorConstructor] + protected ActivatingTaskRequest(IServiceProvider serviceProvider) + { + } + } + + public interface IActivatingGrain : IGrainWithIntegerKey + { + [Activating] + Task GetValue(); + } + + public interface INormalGrain : IGrainWithIntegerKey + { + Task GetValue(); + } + """; + + var compilation = await CreateCompilation(code, "TestProject"); + var result = RunSourceGenerator(compilation); + + Assert.Empty(result.Diagnostics); + + var emittedActivatorNames = GetGeneratedClassNames(result, ".orleans.proxy.", "Activator_Invokable_"); + var registeredActivatorNames = GetRegisteredGeneratedInvokableActivatorNames(result); + + Assert.Single(emittedActivatorNames); + Assert.Equal(emittedActivatorNames, registeredActivatorNames); + Assert.Contains(emittedActivatorNames, static name => name.Contains("IActivatingGrain", StringComparison.Ordinal)); + Assert.DoesNotContain(registeredActivatorNames, static name => name.Contains("INormalGrain", StringComparison.Ordinal)); + } + + [Fact] + public async Task RecordPrimaryConstructorFieldIds_AreNotDuplicatedOrDropped() + { + var code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public sealed record PrimaryCtorRecord( + [property: Id(0)] string Value, + [field: Id(1)] int Count) + { + [Id(2)] + public string Extra { get; init; } = string.Empty; + } + """; + + var compilation = await CreateCompilation(code, "TestProject"); + var generator = new OrleansSerializationSourceGenerator().AsSourceGenerator(); + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: [generator], + driverOptions: new GeneratorDriverOptions(default)); + driver = driver.RunGeneratorsAndUpdateCompilation(compilation, out var outputCompilation, out var generatorDiagnostics); + var result = driver.GetRunResult().Results.Single(); + + Assert.Empty(generatorDiagnostics.Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); + Assert.Empty(result.Diagnostics); + Assert.Empty(outputCompilation.GetDiagnostics().Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); + + var serializerSource = Assert.Single( + result.GeneratedSources, + static source => source.HintName.Contains(".orleans.ser.", StringComparison.Ordinal) + && source.SourceText.ToString().Contains("Codec_PrimaryCtorRecord", StringComparison.Ordinal)); + var serializerText = serializerSource.SourceText.ToString(); + + Assert.Equal(1, CountOccurrences(serializerText, "instance.Value")); + Assert.Equal(1, CountOccurrences(serializerText, "instance.Count")); + Assert.Equal(1, CountOccurrences(serializerText, "instance.Extra")); + Assert.Equal(1, CountOccurrences(serializerText, "if (id == 0U)")); + Assert.Equal(1, CountOccurrences(serializerText, "if (id == 1U)")); + Assert.Equal(1, CountOccurrences(serializerText, "if (id == 2U)")); + } + private static GeneratorRunResult RunSourceGenerator( CSharpCompilation compilation, IReadOnlyDictionary? globalOptions = null) @@ -1089,7 +1666,7 @@ private static GeneratorRunResult RunSourceGenerator( ? null : new TestAnalyzerConfigOptionsProvider(globalOptions); - var generator = new OrleansSerializationSourceGenerator(); + var generator = new OrleansSerializationSourceGenerator().AsSourceGenerator(); GeneratorDriver driver = CSharpGeneratorDriver.Create( generators: [generator], optionsProvider: optionsProvider, @@ -1111,13 +1688,271 @@ private static async Task AssertSuccessfulSourceGeneration(string code) var result = RunSourceGenerator(compilation); Assert.Empty(result.Diagnostics); - Assert.Single(result.GeneratedSources); - Assert.Equal($"{projectName}.orleans.g.cs", result.GeneratedSources[0].HintName); - var generatedSource = result.GeneratedSources[0].SourceText.ToString(); + Assert.NotEmpty(result.GeneratedSources); + Assert.All(result.GeneratedSources, generated => + Assert.StartsWith($"{projectName}.orleans.", generated.HintName, StringComparison.Ordinal)); + var generatedSource = ConcatenateGeneratedSources(result); await Verify(generatedSource, extension: "cs").UseDirectory("snapshots"); } + private static string ConcatenateGeneratedSources(GeneratorRunResult result) + { + var assemblyAttributes = new List(); + var topLevelMembers = new List(); + var namespaces = new Dictionary(StringComparer.Ordinal); + + foreach (var item in result.GeneratedSources + .Select(static (source, index) => (Source: source, Index: index)) + .Where(static item => !string.IsNullOrWhiteSpace(item.Source.SourceText.ToString())) + .OrderBy(static item => GetSnapshotSourceOrder(item.Source.HintName)) + .ThenBy(static item => item.Index)) + { + var root = CSharpSyntaxTree.ParseText(item.Source.SourceText.ToString().TrimStart('\uFEFF').TrimEnd()).GetCompilationUnitRoot(); + assemblyAttributes.AddRange(root.AttributeLists); + + foreach (var member in root.Members) + { + if (member is NamespaceDeclarationSyntax namespaceDeclaration) + { + var namespaceName = namespaceDeclaration.Name.ToString(); + if (!namespaces.TryGetValue(namespaceName, out var namespaceMembers)) + { + namespaceMembers = new NamespaceMembers(namespaceDeclaration.Usings); + namespaces.Add(namespaceName, namespaceMembers); + } + + namespaceMembers.Members.AddRange(namespaceDeclaration.Members); + } + else + { + topLevelMembers.Add(member); + } + } + } + + var combinedMembers = new List(); + combinedMembers.AddRange(topLevelMembers); + foreach (var pair in namespaces.OrderBy(static pair => pair.Key, StringComparer.Ordinal)) + { + combinedMembers.Add( + SyntaxFactory.NamespaceDeclaration(SyntaxFactory.ParseName(pair.Key)) + .WithUsings(pair.Value.Usings) + .WithMembers(SyntaxFactory.List(OrderLegacyGeneratedMembers(pair.Value.Members)))); + } + + var unit = SyntaxFactory.CompilationUnit() + .WithAttributeLists(SyntaxFactory.List(assemblyAttributes)) + .WithMembers(SyntaxFactory.List(combinedMembers)); + var resultText = unit.NormalizeWhitespace().ToFullString(); + resultText = resultText.Replace(".Add(typeof(int));", ".Add(typeof( int ));", StringComparison.Ordinal); + if (assemblyAttributes.Count > 0) + { + resultText += $"{Environment.NewLine}#pragma warning restore CS1591, RS0016, RS0041"; + } + + return resultText; + + static int GetSnapshotSourceOrder(string hintName) + { + if (hintName.Contains(".orleans.proxy.", StringComparison.Ordinal)) + { + return 0; + } + + if (hintName.Contains(".orleans.ser.", StringComparison.Ordinal)) + { + return 1; + } + + if (hintName.EndsWith(".orleans.metadata.g.cs", StringComparison.Ordinal)) + { + return 3; + } + + return 2; + } + } + + private static List OrderLegacyGeneratedMembers(List members) + { + var serializerOrder = GetSerializerOrder(members); + return members + .Select(static (member, index) => (Member: member, Index: index)) + .OrderBy(static item => GetLegacyMemberCategory(item.Member)) + .ThenBy(item => GetSerializerOrderIndex(item.Member, serializerOrder)) + .ThenBy(static item => GetSerializerMemberKindOrder(item.Member)) + .ThenBy(static item => item.Index) + .Select(static item => item.Member) + .ToList(); + } + + private static Dictionary GetSerializerOrder(List members) + { + var result = new Dictionary(StringComparer.Ordinal); + foreach (var metadataClass in members + .OfType() + .Where(static declaration => declaration.Identifier.ValueText.StartsWith("Metadata_", StringComparison.Ordinal))) + { + foreach (var invocation in metadataClass.DescendantNodes().OfType()) + { + if (invocation.Expression is not MemberAccessExpressionSyntax { Name.Identifier.ValueText: "Add", Expression: MemberAccessExpressionSyntax { Name.Identifier.ValueText: "Serializers" } } + || invocation.ArgumentList.Arguments.FirstOrDefault()?.Expression is not TypeOfExpressionSyntax typeOfExpression) + { + continue; + } + + var serializerClassName = GetGeneratedClassIdentifier(typeOfExpression.Type.ToString().Split('.').Last()); + if (serializerClassName.StartsWith("Codec_", StringComparison.Ordinal)) + { + result.TryAdd(serializerClassName["Codec_".Length..], result.Count); + } + } + } + + return result; + } + + private static string GetGeneratedClassIdentifier(string typeName) + { + var genericMarkerIndex = typeName.IndexOf('<'); + return genericMarkerIndex >= 0 ? typeName[..genericMarkerIndex] : typeName; + } + + private static int GetLegacyMemberCategory(MemberDeclarationSyntax member) + { + if (member is not ClassDeclarationSyntax classDeclaration) + { + return 1; + } + + var className = classDeclaration.Identifier.ValueText; + if (className.StartsWith("Invokable_", StringComparison.Ordinal) + || className.StartsWith("Proxy_", StringComparison.Ordinal)) + { + return 0; + } + + if (className.StartsWith("Metadata_", StringComparison.Ordinal)) + { + return 3; + } + + return 1; + } + + private static int GetSerializerOrderIndex(MemberDeclarationSyntax member, Dictionary serializerOrder) + { + if (member is not ClassDeclarationSyntax classDeclaration) + { + return int.MaxValue; + } + + var className = classDeclaration.Identifier.ValueText; + var serializableTypeName = GetSerializableTypeName(className); + if (serializableTypeName is not null && serializerOrder.TryGetValue(serializableTypeName, out var order)) + { + return order; + } + + return int.MaxValue; + } + + private static int GetSerializerMemberKindOrder(MemberDeclarationSyntax member) + { + if (member is not ClassDeclarationSyntax classDeclaration) + { + return 0; + } + + var className = classDeclaration.Identifier.ValueText; + if (className.StartsWith("Codec_", StringComparison.Ordinal)) + { + return 0; + } + + if (className.StartsWith("Copier_", StringComparison.Ordinal)) + { + return 1; + } + + if (className.StartsWith("Activator_", StringComparison.Ordinal)) + { + return 2; + } + + return 0; + } + + private static string? GetSerializableTypeName(string generatedClassName) + { + if (generatedClassName.StartsWith("Codec_", StringComparison.Ordinal)) + { + return generatedClassName["Codec_".Length..]; + } + + if (generatedClassName.StartsWith("Copier_", StringComparison.Ordinal)) + { + return generatedClassName["Copier_".Length..]; + } + + if (generatedClassName.StartsWith("Activator_", StringComparison.Ordinal)) + { + return generatedClassName["Activator_".Length..]; + } + + return null; + } + + private static int CountGeneratedInvokableClasses(string source) + => source.Split(Environment.NewLine) + .Count(line => line.Contains("public sealed class Invokable_", StringComparison.Ordinal)); + + private static int CountOccurrences(string value, string substring) + { + var count = 0; + var index = 0; + while ((index = value.IndexOf(substring, index, StringComparison.Ordinal)) >= 0) + { + count++; + index += substring.Length; + } + + return count; + } + + private static string[] GetGeneratedClassNames(GeneratorRunResult result, string hintNameFragment, string classNamePrefix) + => result.GeneratedSources + .Where(source => source.HintName.Contains(hintNameFragment, StringComparison.Ordinal)) + .SelectMany(static source => CSharpSyntaxTree.ParseText(source.SourceText.ToString().TrimStart('\uFEFF')).GetCompilationUnitRoot() + .DescendantNodes() + .OfType()) + .Select(static declaration => declaration.Identifier.ValueText) + .Where(name => name.StartsWith(classNamePrefix, StringComparison.Ordinal)) + .Distinct(StringComparer.Ordinal) + .OrderBy(static name => name, StringComparer.Ordinal) + .ToArray(); + + private static string[] GetRegisteredGeneratedInvokableActivatorNames(GeneratorRunResult result) + => result.GeneratedSources + .Where(static source => source.HintName.EndsWith(".orleans.metadata.g.cs", StringComparison.Ordinal)) + .SelectMany(static source => CSharpSyntaxTree.ParseText(source.SourceText.ToString().TrimStart('\uFEFF')).GetCompilationUnitRoot() + .DescendantNodes() + .OfType()) + .Where(static invocation => + invocation.Expression is MemberAccessExpressionSyntax + { + Name.Identifier.ValueText: "Add", + Expression: MemberAccessExpressionSyntax { Name.Identifier.ValueText: "Activators" } + }) + .Select(static invocation => invocation.ArgumentList.Arguments.FirstOrDefault()?.Expression) + .OfType() + .Select(static typeOfExpression => GetGeneratedClassIdentifier(typeOfExpression.Type.ToString().Split('.').Last())) + .Where(static name => name.StartsWith("Activator_Invokable_", StringComparison.Ordinal)) + .Distinct(StringComparer.Ordinal) + .OrderBy(static name => name, StringComparer.Ordinal) + .ToArray(); + private sealed class TestAnalyzerConfigOptionsProvider : AnalyzerConfigOptionsProvider { private static readonly AnalyzerConfigOptions EmptyOptions = new TestAnalyzerConfigOptions(new Dictionary()); @@ -1147,43 +1982,18 @@ public TestAnalyzerConfigOptions(IReadOnlyDictionary options) public override bool TryGetValue(string key, out string value) => _options.TryGetValue(key, out value!); } + private sealed class NamespaceMembers(SyntaxList usings) + { + public SyntaxList Usings { get; } = usings; + + public List Members { get; } = []; + } + /// /// Creates a Roslyn compilation with the necessary Orleans references. /// This simulates the build environment where the source generator runs, /// including all required Orleans assemblies and .NET framework references. /// - private static async Task CreateCompilation(string sourceCode, string assemblyName = "TestProject") - { -#if NET10_0_OR_GREATER - // Manually construct .NET 10.0 reference assemblies - var net10References = new ReferenceAssemblies( - "net10.0", - new PackageIdentity("Microsoft.NETCore.App.Ref", "10.0.0"), - Path.Combine("ref", "net10.0")); - - var references = await net10References.ResolveAsync(LanguageNames.CSharp, default); -#else - var references = await ReferenceAssemblies.Net.Net80.ResolveAsync(LanguageNames.CSharp, default); -#endif - - // Add the Orleans Orleans.Core.Abstractions assembly - references = references.AddRange( - // Orleans.Core.Abstractions - MetadataReference.CreateFromFile(typeof(GrainId).Assembly.Location), - // Orleans.Core - MetadataReference.CreateFromFile(typeof(IClusterClientLifecycle).Assembly.Location), - // Orleans.Runtime - MetadataReference.CreateFromFile(typeof(IGrainActivator).Assembly.Location), - // Orleans.Serialization - MetadataReference.CreateFromFile(typeof(Serializer).Assembly.Location), - // Orleans.Serialization.Abstractions - MetadataReference.CreateFromFile(typeof(GenerateFieldIds).Assembly.Location), - // Microsoft.Extensions.DependencyInjection.Abstractions - MetadataReference.CreateFromFile(typeof(ActivatorUtilitiesConstructorAttribute).Assembly.Location) - ); - - var syntaxTree = CSharpSyntaxTree.ParseText(sourceCode); - - return CSharpCompilation.Create(assemblyName, [syntaxTree], references, new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); - } + private static Task CreateCompilation(string sourceCode, string assemblyName = "TestProject") + => TestCompilationHelper.CreateCompilation(sourceCode, assemblyName); } diff --git a/test/Orleans.CodeGenerator.Tests/PartialDeclarationDeduplicationTests.cs b/test/Orleans.CodeGenerator.Tests/PartialDeclarationDeduplicationTests.cs new file mode 100644 index 00000000000..5b703d746b5 --- /dev/null +++ b/test/Orleans.CodeGenerator.Tests/PartialDeclarationDeduplicationTests.cs @@ -0,0 +1,290 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Orleans.CodeGenerator.Diagnostics; + +namespace Orleans.CodeGenerator.Tests; + +public class PartialDeclarationDeduplicationTests +{ + [Fact] + public async Task PartialGenerateSerializerAttributesProduceOneSerializerArtifactSetAndMetadataSet() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public partial class PartialDto + { + [Id(0)] + public string First { get; set; } = string.Empty; + } + + [GenerateSerializer] + public partial class PartialDto + { + [Id(1)] + public int Second { get; set; } + } + """; + + var compilation = await CreateCompilation(code); + var result = RunSourceGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(1, CountGeneratedClassDeclarations(result, "Codec_PartialDto")); + Assert.Equal(1, CountGeneratedClassDeclarations(result, "Copier_PartialDto")); + Assert.Equal(1, CountGeneratedClassDeclarations(result, "Activator_PartialDto")); + Assert.Equal(1, CountMetadataTypeRegistrations(result, "Serializers", "Codec_PartialDto")); + Assert.Equal(1, CountMetadataTypeRegistrations(result, "Copiers", "Copier_PartialDto")); + Assert.Equal(1, CountMetadataTypeRegistrations(result, "Activators", "Activator_PartialDto")); + } + + [Fact] + public async Task PartialGenerateSerializerAttributesWithInvalidMemberProduceOneLogicalDiagnostic() + { + const string code = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public partial class InvalidPartialDto + { + [Id(0)] + public string Name => "computed"; + } + + [GenerateSerializer] + public partial class InvalidPartialDto + { + [Id(1)] + public int Value { get; set; } + } + """; + + var compilation = await CreateCompilation(code); + var result = RunSourceGenerator(compilation); + + var diagnostics = result.Diagnostics + .Where(static diagnostic => diagnostic.Id == DiagnosticRuleId.InaccessibleSetter) + .ToArray(); + + Assert.Single(diagnostics); + } + + [Fact] + public async Task ReorderedPartialDeclarationsWithInvalidMember_ProducesOneStableDiagnostic() + { + const string invalidPart = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public partial class InvalidPartialDto + { + [Id(0)] + public string Name => "computed"; + } + """; + + const string validPart = """ + using Orleans; + + namespace TestProject; + + [GenerateSerializer] + public partial class InvalidPartialDto + { + [Id(1)] + public int Value { get; set; } + } + """; + + var compilation = await CreateCompilation( + "PartialOrderingProject", + ("InvalidPart.cs", invalidPart), + ("ValidPart.cs", validPart)); + var reorderedCompilation = await CreateCompilation( + "PartialOrderingProject", + ("ValidPart.cs", validPart), + ("InvalidPart.cs", invalidPart)); + + var result = RunSourceGenerator(compilation); + var reorderedResult = RunSourceGenerator(reorderedCompilation); + + var diagnostics = result.Diagnostics + .Where(static diagnostic => diagnostic.Id == DiagnosticRuleId.InaccessibleSetter) + .ToArray(); + var reorderedDiagnostics = reorderedResult.Diagnostics + .Where(static diagnostic => diagnostic.Id == DiagnosticRuleId.InaccessibleSetter) + .ToArray(); + + Assert.Single(diagnostics); + Assert.Single(reorderedDiagnostics); + AssertDiagnosticsIdentical(diagnostics, reorderedDiagnostics); + AssertGeneratedSourcesIdentical(result, reorderedResult); + } + + [Fact] + public async Task DirectAndInheritedGenerateMethodSerializersDiscoveryProducesOneProxyAndMetadataSet() + { + const string code = """ + using Orleans; + using Orleans.Runtime; + using System.Threading.Tasks; + + namespace TestProject; + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface IBaseGrain : IGrainWithIntegerKey + { + Task Ping(); + } + + [GenerateMethodSerializers(typeof(GrainReference))] + public interface IDerivedGrain : IBaseGrain + { + Task Pong(); + } + """; + + var compilation = await CreateCompilation(code); + Assert.Empty(compilation.GetDiagnostics().Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); + + var result = RunSourceGenerator(compilation); + + Assert.Empty(result.Diagnostics); + Assert.Equal(1, CountGeneratedClassDeclarations(result, "Proxy_IDerivedGrain")); + Assert.Equal(1, CountGeneratedClassDeclarationsWithPrefix(result, "Invokable_IDerivedGrain_GrainReference_")); + Assert.Equal(1, CountMetadataTypeRegistrations(result, "InterfaceProxies", "Proxy_IDerivedGrain")); + Assert.Equal(1, CountMetadataTypeRegistrations(result, "Interfaces", "IDerivedGrain")); + } + + private static GeneratorRunResult RunSourceGenerator(CSharpCompilation compilation) + { + var generator = new OrleansSerializationSourceGenerator().AsSourceGenerator(); + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: [generator], + driverOptions: new GeneratorDriverOptions(default)); + + driver = driver.RunGenerators(compilation); + return driver.GetRunResult().Results.Single(); + } + + private static int CountGeneratedClassDeclarations(GeneratorRunResult result, string className) + => GetGeneratedClassDeclarations(result) + .Count(declaration => string.Equals(declaration.Identifier.ValueText, className, StringComparison.Ordinal)); + + private static int CountGeneratedClassDeclarationsWithPrefix(GeneratorRunResult result, string classNamePrefix) + => GetGeneratedClassDeclarations(result) + .Count(declaration => declaration.Identifier.ValueText.StartsWith(classNamePrefix, StringComparison.Ordinal)); + + private static IEnumerable GetGeneratedClassDeclarations(GeneratorRunResult result) + { + foreach (var root in GetGeneratedCompilationUnits(result)) + { + foreach (var declaration in root.DescendantNodes().OfType()) + { + yield return declaration; + } + } + } + + private static int CountMetadataTypeRegistrations(GeneratorRunResult result, string collectionName, string registeredTypeName) + => GetGeneratedCompilationUnits(result) + .Where(static root => root.SyntaxTree.FilePath.EndsWith(".orleans.metadata.g.cs", StringComparison.Ordinal)) + .SelectMany(static root => root.DescendantNodes().OfType()) + .Count(invocation => IsMetadataRegistration(invocation, collectionName, registeredTypeName)); + + private static bool IsMetadataRegistration(InvocationExpressionSyntax invocation, string collectionName, string registeredTypeName) + { + if (invocation.Expression is not MemberAccessExpressionSyntax addExpression + || !string.Equals(addExpression.Name.Identifier.ValueText, "Add", StringComparison.Ordinal) + || addExpression.Expression is not MemberAccessExpressionSyntax collectionExpression + || !string.Equals(collectionExpression.Name.Identifier.ValueText, collectionName, StringComparison.Ordinal) + || invocation.ArgumentList.Arguments.FirstOrDefault()?.Expression is not TypeOfExpressionSyntax typeOfExpression) + { + return false; + } + + var typeName = typeOfExpression.Type.ToString().Split('.').Last(); + return string.Equals(typeName, registeredTypeName, StringComparison.Ordinal); + } + + private static IEnumerable GetGeneratedCompilationUnits(GeneratorRunResult result) + { + foreach (var source in result.GeneratedSources) + { + var sourceText = source.SourceText.ToString().TrimStart('\uFEFF'); + if (string.IsNullOrWhiteSpace(sourceText)) + { + continue; + } + + var tree = CSharpSyntaxTree.ParseText(sourceText, path: source.HintName); + yield return tree.GetCompilationUnitRoot(); + } + } + + private static Task CreateCompilation(string sourceCode, string assemblyName = "TestProject") + => TestCompilationHelper.CreateCompilation(sourceCode, assemblyName); + + private static async Task CreateCompilation( + string assemblyName, + params (string Path, string Source)[] sources) + { + Assert.NotEmpty(sources); + + var compilation = await TestCompilationHelper.CreateCompilation(sources[0].Source, assemblyName); + var firstTree = CSharpSyntaxTree.ParseText(sources[0].Source, path: sources[0].Path); + compilation = compilation.ReplaceSyntaxTree(compilation.SyntaxTrees.Single(), firstTree); + + if (sources.Length == 1) + { + return compilation; + } + + return compilation.AddSyntaxTrees(sources.Skip(1).Select(static source => CSharpSyntaxTree.ParseText(source.Source, path: source.Path))); + } + + private static void AssertDiagnosticsIdentical(IEnumerable diagnostics, IEnumerable otherDiagnostics) + => Assert.Equal( + diagnostics.Select(GetDiagnosticShape).OrderBy(static value => value, StringComparer.Ordinal), + otherDiagnostics.Select(GetDiagnosticShape).OrderBy(static value => value, StringComparer.Ordinal)); + + private static string GetDiagnosticShape(Diagnostic diagnostic) + { + var lineSpan = diagnostic.Location.GetLineSpan(); + return string.Join( + "|", + diagnostic.Id, + diagnostic.Severity.ToString(), + diagnostic.GetMessage(), + lineSpan.Path ?? string.Empty, + lineSpan.StartLinePosition.Line.ToString(), + lineSpan.StartLinePosition.Character.ToString()); + } + + private static void AssertGeneratedSourcesIdentical(GeneratorRunResult result, GeneratorRunResult other) + { + var sources = GetGeneratedSourceMap(result); + var otherSources = GetGeneratedSourceMap(other); + + Assert.Equal(sources.Count, otherSources.Count); + + foreach (var (hintName, sourceText) in sources) + { + Assert.True(otherSources.TryGetValue(hintName, out var otherSourceText), $"Missing generated source '{hintName}'."); + Assert.Equal(sourceText, otherSourceText); + } + } + + private static SortedDictionary GetGeneratedSourceMap(GeneratorRunResult result) + => new( + result.GeneratedSources.ToDictionary(source => source.HintName, source => source.SourceText.ToString(), StringComparer.Ordinal), + StringComparer.Ordinal); +} diff --git a/test/Orleans.CodeGenerator.Tests/ReferencedAssemblyDiagnosticParityTests.cs b/test/Orleans.CodeGenerator.Tests/ReferencedAssemblyDiagnosticParityTests.cs new file mode 100644 index 00000000000..718581976b8 --- /dev/null +++ b/test/Orleans.CodeGenerator.Tests/ReferencedAssemblyDiagnosticParityTests.cs @@ -0,0 +1,225 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Orleans.CodeGenerator.Diagnostics; + +namespace Orleans.CodeGenerator.Tests; + +public class ReferencedAssemblyDiagnosticParityTests +{ + [Fact] + public async Task GenerateCodeForDeclaringAssembly_ReportsInaccessibleSerializableTypesFromReferencedAssembly() + { + const string libraryCode = """ + using Orleans; + + namespace LibraryProject; + + public sealed class Marker + { + } + + [GenerateSerializer] + internal sealed class InternalDto + { + [Id(0)] + public string Value { get; set; } = string.Empty; + } + """; + + const string consumerCode = """ + using Orleans; + + [assembly: GenerateCodeForDeclaringAssembly(typeof(LibraryProject.Marker))] + """; + + var result = await RunSourceGeneratorForConsumer(libraryCode, consumerCode); + var diagnostic = Assert.Single(result.Diagnostics, diagnostic => diagnostic.Id == InaccessibleSerializableTypeDiagnostic.RuleId); + + Assert.Equal(DiagnosticSeverity.Error, diagnostic.Severity); + Assert.Equal(Location.None, diagnostic.Location); + Assert.Contains("InternalDto", diagnostic.GetMessage(), StringComparison.Ordinal); + } + + [Fact] + public async Task GenerateCodeForDeclaringAssembly_ReportsImplicitFieldIdFailuresFromReferencedAssembly() + { + const string libraryCode = """ + using Orleans; + + namespace LibraryProject; + + public sealed class Marker + { + } + + [GenerateSerializer] + public sealed class AutoDto + { + public string Value { get; set; } = string.Empty; + public int Count { get; set; } + } + """; + + const string consumerCode = """ + using Orleans; + + [assembly: GenerateCodeForDeclaringAssembly(typeof(LibraryProject.Marker))] + """; + + var result = await RunSourceGeneratorForConsumer(libraryCode, consumerCode); + var diagnostic = Assert.Single(result.Diagnostics, diagnostic => diagnostic.Id == CanNotGenerateImplicitFieldIdsDiagnostic.DiagnosticId); + + Assert.Equal(DiagnosticSeverity.Error, diagnostic.Severity); + Assert.Equal(Location.None, diagnostic.Location); + Assert.Contains("AutoDto", diagnostic.GetMessage(), StringComparison.Ordinal); + } + + [Fact] + public async Task GenerateCodeForDeclaringAssembly_EmitsReferencedSerializersLocalProxiesAndMetadataOnce() + { + const string libraryCode = """ + using Orleans; + + namespace LibraryProject; + + public sealed class Marker + { + } + + [GenerateSerializer] + public sealed class PublicReferencedDto + { + [Id(0)] + public string Value { get; set; } = string.Empty; + } + """; + + const string consumerCode = """ + using Orleans; + using System.Threading.Tasks; + + [assembly: GenerateCodeForDeclaringAssembly(typeof(LibraryProject.Marker))] + + namespace ConsumerProject; + + public interface IConsumerGrain : IGrainWithIntegerKey + { + Task Get(); + } + """; + + var (result, consumerCompilation) = await RunSourceGeneratorForConsumerWithCompilation(libraryCode, consumerCode); + + Assert.Empty(result.Diagnostics); + Assert.Equal(1, CountGeneratedClassDeclarations(result, "Codec_PublicReferencedDto")); + Assert.Equal(1, CountGeneratedClassDeclarations(result, "Copier_PublicReferencedDto")); + Assert.Equal(1, CountGeneratedClassDeclarations(result, "Activator_PublicReferencedDto")); + Assert.Equal(1, CountGeneratedClassDeclarations(result, "Proxy_IConsumerGrain")); + Assert.Equal(1, CountMetadataTypeRegistrations(result, "Serializers", "Codec_PublicReferencedDto")); + Assert.Equal(1, CountMetadataTypeRegistrations(result, "Copiers", "Copier_PublicReferencedDto")); + Assert.Equal(1, CountMetadataTypeRegistrations(result, "Activators", "Activator_PublicReferencedDto")); + Assert.Equal(1, CountMetadataTypeRegistrations(result, "InterfaceProxies", "Proxy_IConsumerGrain")); + Assert.Equal(1, CountMetadataTypeRegistrations(result, "Interfaces", "IConsumerGrain")); + + var outputCompilation = consumerCompilation.AddSyntaxTrees(CreateGeneratedSyntaxTrees(result)); + AssertNoCompilationErrors(outputCompilation); + } + + private static async Task RunSourceGeneratorForConsumer( + string libraryCode, + string consumerCode) + => (await RunSourceGeneratorForConsumerWithCompilation(libraryCode, consumerCode)).Result; + + private static async Task<(GeneratorRunResult Result, CSharpCompilation ConsumerCompilation)> RunSourceGeneratorForConsumerWithCompilation( + string libraryCode, + string consumerCode) + { + var libraryCompilation = await TestCompilationHelper.CreateCompilation(libraryCode, "LibraryProject"); + Assert.Empty(libraryCompilation.GetDiagnostics().Where(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); + + var consumerCompilation = await TestCompilationHelper.CreateCompilation( + consumerCode, + "ConsumerProject", + libraryCompilation.ToMetadataReference()); + Assert.Empty(consumerCompilation.GetDiagnostics().Where(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)); + + return (RunSourceGenerator(consumerCompilation), consumerCompilation); + } + + private static GeneratorRunResult RunSourceGenerator(CSharpCompilation compilation) + { + var generator = new OrleansSerializationSourceGenerator().AsSourceGenerator(); + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: [generator], + driverOptions: new GeneratorDriverOptions(default)); + driver = driver.RunGenerators(compilation); + + return driver.GetRunResult().Results.Single(); + } + + private static int CountGeneratedClassDeclarations(GeneratorRunResult result, string className) + => GetGeneratedCompilationUnits(result) + .SelectMany(static root => root.DescendantNodes().OfType()) + .Count(declaration => string.Equals(declaration.Identifier.ValueText, className, StringComparison.Ordinal)); + + private static int CountMetadataTypeRegistrations(GeneratorRunResult result, string collectionName, string registeredTypeName) + => GetGeneratedCompilationUnits(result) + .Where(static root => root.SyntaxTree.FilePath.EndsWith(".orleans.metadata.g.cs", StringComparison.Ordinal)) + .SelectMany(static root => root.DescendantNodes().OfType()) + .Count(invocation => IsMetadataRegistration(invocation, collectionName, registeredTypeName)); + + private static bool IsMetadataRegistration(InvocationExpressionSyntax invocation, string collectionName, string registeredTypeName) + { + if (invocation.Expression is not MemberAccessExpressionSyntax addExpression + || !string.Equals(addExpression.Name.Identifier.ValueText, "Add", StringComparison.Ordinal) + || addExpression.Expression is not MemberAccessExpressionSyntax collectionExpression + || !string.Equals(collectionExpression.Name.Identifier.ValueText, collectionName, StringComparison.Ordinal) + || invocation.ArgumentList.Arguments.FirstOrDefault()?.Expression is not TypeOfExpressionSyntax typeOfExpression) + { + return false; + } + + var typeName = typeOfExpression.Type.ToString().Split('.').Last(); + return string.Equals(GetGeneratedClassIdentifier(typeName), registeredTypeName, StringComparison.Ordinal); + } + + private static IEnumerable GetGeneratedCompilationUnits(GeneratorRunResult result) + { + foreach (var source in result.GeneratedSources) + { + var sourceText = source.SourceText.ToString().TrimStart('\uFEFF'); + if (string.IsNullOrWhiteSpace(sourceText)) + { + continue; + } + + var tree = CSharpSyntaxTree.ParseText(sourceText, path: source.HintName); + yield return tree.GetCompilationUnitRoot(); + } + } + + private static IEnumerable CreateGeneratedSyntaxTrees(GeneratorRunResult result) + { + foreach (var source in result.GeneratedSources) + { + yield return CSharpSyntaxTree.ParseText(source.SourceText, path: source.HintName); + } + } + + private static string GetGeneratedClassIdentifier(string typeName) + { + var genericMarkerIndex = typeName.IndexOf('<'); + return genericMarkerIndex >= 0 ? typeName[..genericMarkerIndex] : typeName; + } + + private static void AssertNoCompilationErrors(Compilation compilation) + { + var errors = compilation.GetDiagnostics() + .Where(static diagnostic => diagnostic.Severity == DiagnosticSeverity.Error) + .Select(static diagnostic => diagnostic.ToString()) + .ToArray(); + + Assert.True(errors.Length == 0, string.Join(Environment.NewLine, errors)); + } +} diff --git a/test/Orleans.CodeGenerator.Tests/TestCompilationHelper.cs b/test/Orleans.CodeGenerator.Tests/TestCompilationHelper.cs new file mode 100644 index 00000000000..c3a638e111d --- /dev/null +++ b/test/Orleans.CodeGenerator.Tests/TestCompilationHelper.cs @@ -0,0 +1,53 @@ +using System.Collections.Immutable; +using System.IO; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Testing; +using Microsoft.Extensions.DependencyInjection; +using Orleans.Serialization; + +namespace Orleans.CodeGenerator.Tests; + +/// +/// Shared helper for creating Roslyn compilations with the necessary Orleans references. +/// Used across all code generator test files. +/// +internal static class TestCompilationHelper +{ + /// + /// Creates a with the .NET framework and Orleans assembly references. + /// + public static async Task CreateCompilation( + string sourceCode, + string assemblyName = "TestProject", + params MetadataReference[] additionalReferences) + { +#if NET10_0_OR_GREATER + var net10References = new ReferenceAssemblies( + "net10.0", + new PackageIdentity("Microsoft.NETCore.App.Ref", "10.0.0"), + Path.Combine("ref", "net10.0")); + + var references = await net10References.ResolveAsync(LanguageNames.CSharp, default); +#else + var references = await ReferenceAssemblies.Net.Net80.ResolveAsync(LanguageNames.CSharp, default); +#endif + + references = references.AddRange( + MetadataReference.CreateFromFile(typeof(GrainId).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IClusterClientLifecycle).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IGrainActivator).Assembly.Location), + MetadataReference.CreateFromFile(typeof(Serializer).Assembly.Location), + MetadataReference.CreateFromFile(typeof(GenerateFieldIds).Assembly.Location), + MetadataReference.CreateFromFile(typeof(ActivatorUtilitiesConstructorAttribute).Assembly.Location)); + + if (additionalReferences.Length > 0) + { + references = references.AddRange(additionalReferences); + } + + var syntaxTree = CSharpSyntaxTree.ParseText(sourceCode); + return CSharpCompilation.Create(assemblyName, [syntaxTree], references, new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + } +}