Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/SplineInterpolations/smoothing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,25 @@ function StatsAPI.fit(
else # multidimensional data
Z = eltype(eltype(ys)) # for example, if eltype(ys) <: SVector{D,Z}
@assert Z <: Number
cs_lin = reinterpret(reshape, Z, cs)'
cs_lin = reinterpret(reshape, Z, cs)
cs_lin = reshape(cs_lin, reverse(size(cs_lin))) # produce a <:StridedMatrix
zs = copy(reinterpret(reshape, Z, ys)') # dimensions (N, D)
end
if weights !== nothing
eachindex(weights) == eachindex(xs) || throw(DimensionMismatch("the `weights` vector must have the same length as the data"))
lmul!(Diagonal(weights), zs) # zs = W * ys
end
mul!(cs_lin, A', zs) # cs = A' * (W * ys)
lmul!(2, cs_lin) # cs = 2 * A' * (W * ys)
# specialized sparse matmul requires the destination array (i.e. `cs_lin`) to be
# <:StridedMatrix, otherwise we get a generic matmul
mul!(cs_lin, A', zs, 2, false) # cs = 2 * A' * (W * ys)

cs_lin .= F \ cs_lin # solve linear system (allocates intermediate array)
# solve linear system (allocates intermediate array)
if ndims(Y) == 0
cs_lin .= F \ cs_lin
else
# undo the reshaped reinterpret
cs .= reinterpret(reshape, eltype(ys), (F \ cs_lin)')
end

Spline(R, cs)
end