From 76f09d0afdb9cfaca2527050c85b290ab0836aaa Mon Sep 17 00:00:00 2001 From: spielman Date: Thu, 8 Jan 2026 09:40:00 -0500 Subject: [PATCH 1/2] Add support for recursing Chains in ReactantCompatibleOptimisers --- src/helpers/optimizers.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/helpers/optimizers.jl b/src/helpers/optimizers.jl index f930e5a245..d4665a8fe9 100644 --- a/src/helpers/optimizers.jl +++ b/src/helpers/optimizers.jl @@ -47,6 +47,22 @@ function Optimisers.init( return zero(x), zero(x), Utils.convert_eltype.((T,), opt.opt.beta), zero(x) end +# Recurse through chains +function Optimisers.adjust( + chain::ReactantOptimiser{<:Optimisers.OptimiserChain}, + eta::Real +) + results = Optimisers.OptimiserChain([Optimisers.adjust(opt, eta) for opt in chain.opt.opts]...) + return ReactantOptimiser(results) +end +function Optimisers.adjust( + chain:ReactantOptimiser{<:Optimisers.OptimiserChain}; + kw... +) + results = Optimisers.OptimiserChain([Optimisers.adjust(opt; kw...) for opt in chain.opt.opts]...) + return ReactantOptimiser(results) +end + function Optimisers._adjust(opt::ReactantOptimiser, nt::NamedTuple) dev = with_track_numbers(get_device(opt), AbstractFloat) return ReactantOptimiser(Optimisers._adjust(opt.opt, dev(nt))) From ed66c0bbd75ca0f504a056609227bcad276bff96 Mon Sep 17 00:00:00 2001 From: spielman Date: Thu, 8 Jan 2026 09:47:33 -0500 Subject: [PATCH 2/2] Add missing : --- src/helpers/optimizers.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/helpers/optimizers.jl b/src/helpers/optimizers.jl index d4665a8fe9..0261e6603b 100644 --- a/src/helpers/optimizers.jl +++ b/src/helpers/optimizers.jl @@ -56,7 +56,7 @@ function Optimisers.adjust( return ReactantOptimiser(results) end function Optimisers.adjust( - chain:ReactantOptimiser{<:Optimisers.OptimiserChain}; + chain::ReactantOptimiser{<:Optimisers.OptimiserChain}; kw... ) results = Optimisers.OptimiserChain([Optimisers.adjust(opt; kw...) for opt in chain.opt.opts]...)