Commit 0bb2b17

Use DifferentiationInterface for gradient-based analyzers (#167)
1 parent aaa3c72 commit 0bb2b17

File tree: 5 files changed (+70 additions, -18 deletions)

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
@@ -1,4 +1,9 @@
 # ExplainableAI.jl
+## Version `v0.9.0`
+- ![Feature][badge-feature] Support selection of AD backend via DifferentiationInterface.jl ([#167])
+- `Gradient`, `InputTimesGradient` and `GradCAM` analyzers now have an additional `backend` field and type parameter
+- ![Maintenance][badge-maintenance] Update XAIBase interface to v4 ([#166])
+
 ## Version `v0.8.0`
 This release removes the automatic reexport of heatmapping functionality.
 Users are now required to manually load
@@ -210,6 +215,8 @@ Performance improvements:
 [VisionHeatmaps]: https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/
 [TextHeatmaps]: https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/
 
+[#167]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/167
+[#166]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/166
 [#162]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/162
 [#159]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/159
 [#157]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/157
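
For context, a minimal sketch of how backend selection might look from the user side after this change; the Flux model below is an assumption for illustration and is not part of this commit:

using ExplainableAI
using ADTypes: AutoZygote
using Flux  # assumed model-building dependency, not part of this diff

model = Chain(Dense(10 => 5, relu), Dense(5 => 2))

# Default behaviour is unchanged: Zygote via AutoZygote()
analyzer = Gradient(model)

# New in v0.9.0: any ADTypes.jl backend object can be passed explicitly,
# provided the corresponding AD package is loaded.
analyzer = Gradient(model, AutoZygote())
analyzer = InputTimesGradient(model, AutoZygote())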

Project.toml

Lines changed: 5 additions & 1 deletion
@@ -1,9 +1,11 @@
 name = "ExplainableAI"
 uuid = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
 authors = ["Adrian Hill <gh@adrianhill.de>"]
-version = "0.8.1"
+version = "0.9.0-DEV"
 
 [deps]
+ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -12,6 +14,8 @@ XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
+ADTypes = "1"
+DifferentiationInterface = "0.5"
 Distributions = "0.25"
 Random = "<0.0.1, 1"
 Reexport = "1"
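
For reference, a sketch of reproducing these dependency and compat entries in a development checkout via the Pkg API; this workflow is an assumption (it requires Julia ≥ 1.8 for `Pkg.compat`) and is not part of this diff:

using Pkg
Pkg.activate(".")                                # the ExplainableAI.jl checkout
Pkg.add(["ADTypes", "DifferentiationInterface"]) # adds the two new [deps] entries
Pkg.compat("ADTypes", "1")                       # mirrors the [compat] bounds above
Pkg.compat("DifferentiationInterface", "0.5")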

src/ExplainableAI.jl

Lines changed: 5 additions & 0 deletions
@@ -7,7 +7,12 @@ import XAIBase: call_analyzer
 using Base.Iterators
 using Distributions: Distribution, Sampleable, Normal
 using Random: AbstractRNG, GLOBAL_RNG
+
+# Automatic differentiation
+using ADTypes: AbstractADType, AutoZygote
+using DifferentiationInterface: value_and_pullback
 using Zygote
+const DEFAULT_AD_BACKEND = AutoZygote()
 
 include("compat.jl")
 include("bibliography.jl")

src/gradcam.jl

Lines changed: 11 additions & 2 deletions
@@ -15,16 +15,25 @@ GradCAM is compatible with a wide variety of CNN model-families.
 # References
 - $REF_SELVARAJU_GRADCAM
 """
-struct GradCAM{F,A} <: AbstractXAIMethod
+struct GradCAM{F,A,B<:AbstractADType} <: AbstractXAIMethod
     feature_layers::F
     adaptation_layers::A
+    backend::B
+
+    function GradCAM(
+        feature_layers::F, adaptation_layers::A, backend::B=DEFAULT_AD_BACKEND
+    ) where {F,A,B<:AbstractADType}
+        new{F,A,B}(feature_layers, adaptation_layers, backend)
+    end
 end
 function call_analyzer(input, analyzer::GradCAM, ns::AbstractOutputSelector; kwargs...)
     A = analyzer.feature_layers(input) # feature map
     feature_map_size = size(A, 1) * size(A, 2)
 
     # Determine neuron importance αₖᶜ = 1/Z * ∑ᵢ ∑ⱼ ∂yᶜ / ∂Aᵢⱼᵏ
-    grad, output, output_indices = gradient_wrt_input(analyzer.adaptation_layers, A, ns)
+    grad, output, output_indices = gradient_wrt_input(
+        analyzer.adaptation_layers, A, ns, analyzer.backend
+    )
     αᶜ = sum(grad; dims=(1, 2)) / feature_map_size
     Lᶜ = max.(sum(αᶜ .* A; dims=3), 0)
     return Explanation(Lᶜ, input, output, output_indices, :GradCAM, :cam, nothing)
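
For illustration, a sketch of constructing the updated `GradCAM` analyzer with an explicit backend; the Flux layers and input size are assumptions, not part of this diff:

using ExplainableAI
using ADTypes: AutoZygote
using Flux  # assumed for the CNN layers; not part of this diff

# GradCAM takes the model split into a feature extractor and a classifier head.
feature_layers    = Chain(Conv((3, 3), 3 => 8, relu), MaxPool((2, 2)))
adaptation_layers = Chain(Flux.flatten, Dense(8 * 15 * 15 => 10))  # for 32×32×3 inputs

analyzer = GradCAM(feature_layers, adaptation_layers)                # default: AutoZygote()
analyzer = GradCAM(feature_layers, adaptation_layers, AutoZygote())  # explicit backend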

src/gradient.jl

Lines changed: 42 additions & 15 deletions
@@ -1,26 +1,45 @@
-function gradient_wrt_input(model, input, ns::AbstractOutputSelector)
-    output, back = Zygote.pullback(model, input)
-    output_indices = ns(output)
-
-    # Compute VJP w.r.t. full model output, selecting vector s.t. it masks output neurons
-    v = zero(output)
-    v[output_indices] .= 1
-    grad = only(back(v))
-    return grad, output, output_indices
+function forward_with_output_selection(model, input, selector::AbstractOutputSelector)
+    output = model(input)
+    sel = selector(output)
+    return output[sel]
+end
+
+function gradient_wrt_input(
+    model, input, output_selector::AbstractOutputSelector, backend::AbstractADType
+)
+    output = model(input)
+    return gradient_wrt_input(model, input, output, output_selector, backend)
+end
+
+function gradient_wrt_input(
+    model, input, output, output_selector::AbstractOutputSelector, backend::AbstractADType
+)
+    output_selection = output_selector(output)
+    dy = zero(output)
+    dy[output_selection] .= 1
+
+    output, grad = value_and_pullback(model, backend, input, dy)
+    return grad, output, output_selection
 end
 
 """
     Gradient(model)
 
 Analyze model by calculating the gradient of a neuron activation with respect to the input.
 """
-struct Gradient{M} <: AbstractXAIMethod
+struct Gradient{M,B<:AbstractADType} <: AbstractXAIMethod
     model::M
-    Gradient(model) = new{typeof(model)}(model)
+    backend::B
+
+    function Gradient(model::M, backend::B=DEFAULT_AD_BACKEND) where {M,B<:AbstractADType}
+        new{M,B}(model, backend)
+    end
 end
 
 function call_analyzer(input, analyzer::Gradient, ns::AbstractOutputSelector; kwargs...)
-    grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
+    grad, output, output_indices = gradient_wrt_input(
+        analyzer.model, input, ns, analyzer.backend
+    )
     return Explanation(
         grad, input, output, output_indices, :Gradient, :sensitivity, nothing
     )
@@ -32,15 +51,23 @@ end
 Analyze model by calculating the gradient of a neuron activation with respect to the input.
 This gradient is then multiplied element-wise with the input.
 """
-struct InputTimesGradient{M} <: AbstractXAIMethod
+struct InputTimesGradient{M,B<:AbstractADType} <: AbstractXAIMethod
     model::M
-    InputTimesGradient(model) = new{typeof(model)}(model)
+    backend::B
+
+    function InputTimesGradient(
+        model::M, backend::B=DEFAULT_AD_BACKEND
+    ) where {M,B<:AbstractADType}
+        new{M,B}(model, backend)
+    end
 end
 
 function call_analyzer(
     input, analyzer::InputTimesGradient, ns::AbstractOutputSelector; kwargs...
 )
-    grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
+    grad, output, output_indices = gradient_wrt_input(
+        analyzer.model, input, ns, analyzer.backend
+    )
     attr = input .* grad
     return Explanation(
         attr, input, output, output_indices, :InputTimesGradient, :attribution, nothing
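
A small sketch of the masking idea behind the rewritten `gradient_wrt_input`: seeding the pullback with a one-hot `dy` yields the gradient of only the selected output neuron (toy model for illustration; `AutoZygote()` matches the package default):

using ADTypes: AutoZygote
using DifferentiationInterface: value_and_pullback
using Zygote

model(x) = [sum(x), sum(abs2, x)]   # toy two-output "model"
x  = [1.0, 2.0, 3.0]
dy = [0.0, 1.0]                     # mask selecting the second output neuron

y, grad = value_and_pullback(model, AutoZygote(), x, dy)
# y    == [6.0, 14.0]
# grad == [2.0, 4.0, 6.0] == ∂(model(x)[2]) / ∂x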
