Skip to content

Commit eefe448

Browse files
authored
Merge pull request #465 from mcabbott/patch-1
Splat arguments in gradient `rrule`
2 parents 69812cd + 8fb478c commit eefe448

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

src/chainrules/chainrules.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
export rrule
22
"""
33
ChainRulesCore.rrule(itp::AbstractInterpolation, x...)
4+
45
ChainRulesCore.jl `rrule` for integration with automatic differentiation libraries.
6+
Note that it gives the gradient only with respect to the evaluation point `x`,
7+
and not the data inside `itp`.
58
"""
69
function ChainRulesCore.rrule(itp::AbstractInterpolation, x...)
710
y = itp(x...)
8-
function pullback(Δy)
9-
(ChainRulesCore.NoTangent(), Δy * Interpolations.gradient(itp, x...))
11+
function interpolate_pullback(Δy)
12+
nope = ChainRulesCore.@not_implemented "`Interpolations.gradient` does not calculate a gradient with respect to the original data, only the evaluation point"
13+
(nope, (Δy * Interpolations.gradient(itp, x...))...)
1014
end
11-
y, pullback
12-
end
15+
y, interpolate_pullback
16+
end

test/chainrules.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ using Zygote
99
itp = interpolate(y,BSpline(Linear()))
1010

1111
if VERSION v"1.3"
12-
@test Zygote.gradient(itp, 1)[1] == Interpolations.gradient(itp, 1)
12+
@test Zygote.gradient(itp, 1) == Tuple(Interpolations.gradient(itp, 1))
1313
else
14-
@test_skip Zygote.gradient(itp, 1)[1] == Interpolations.gradient(itp, 1)
14+
@test_skip Zygote.gradient(itp, 1) == Tuple(Interpolations.gradient(itp, 1))
1515
end
1616

1717
# 2D example
@@ -21,8 +21,8 @@ using Zygote
2121
itp2 = interpolate(y2, BSpline(Cubic(Line(OnGrid()))))
2222

2323
if VERSION v"1.3"
24-
@test Zygote.gradient(itp2,1,2)[1] == Interpolations.gradient(itp2,1,2)
24+
@test Zygote.gradient(itp2,1,2) == Tuple(Interpolations.gradient(itp2,1,2))
2525
else
26-
@test_skip Zygote.gradient(itp2,1,2)[1] == Interpolations.gradient(itp2,1,2)
26+
@test_skip Zygote.gradient(itp2,1,2) == Tuple(Interpolations.gradient(itp2,1,2))
2727
end
28-
end
28+
end

0 commit comments

Comments
 (0)