Skip to content

Commit a5e05d9

Browse files
authored
Update DifferentiationInterface dependency to v0.6 (#181)
1 parent efc29f8 commit a5e05d9

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
1414

1515
[compat]
1616
ADTypes = "1"
17-
DifferentiationInterface = "0.5"
17+
DifferentiationInterface = "0.6"
1818
Distributions = "0.25"
1919
Random = "<0.0.1, 1"
2020
Reexport = "1"

src/gradient.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ function gradient_wrt_input(
1818
dy = zero(output)
1919
dy[output_selection] .= 1
2020

21-
output, grad = value_and_pullback(model, backend, input, dy)
21+
output, pbs = value_and_pullback(model, backend, input, tuple(dy))
22+
grad = only(pbs)
2223
return grad, output, output_selection
2324
end
2425

0 commit comments

Comments
 (0)