
Commit 974a934: Add GradCAM analyzer (#155)
Parent: ac1356d

9 files changed: +39 / -5 lines


README.md

Lines changed: 1 addition & 0 deletions

@@ -101,6 +101,7 @@ Currently, the following analyzers are implemented:
 * `InputTimesGradient`
 * `SmoothGrad`
 * `IntegratedGradients`
+* `GradCAM`

 One of the design goals of the [Julia-XAI ecosystem][juliaxai-docs] is extensibility.
 To implement an XAI method, take a look at the [common interface

docs/src/api.md

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ Gradient
 InputTimesGradient
 SmoothGrad
 IntegratedGradients
+GradCAM
 ```

 # Input augmentations

src/ExplainableAI.jl

Lines changed: 2 additions & 0 deletions

@@ -12,9 +12,11 @@ include("compat.jl")
 include("bibliography.jl")
 include("input_augmentation.jl")
 include("gradient.jl")
+include("gradcam.jl")

 export Gradient, InputTimesGradient
 export NoiseAugmentation, SmoothGrad
 export InterpolationAugmentation, IntegratedGradients
+export GradCAM

 end # module

src/bibliography.jl

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 # Gradient methods:
 const REF_SMILKOV_SMOOTHGRAD = "Smilkov et al., *SmoothGrad: removing noise by adding noise*"
 const REF_SUNDARARAJAN_AXIOMATIC = "Sundararajan et al., *Axiomatic Attribution for Deep Networks*"
+const REF_SELVARAJU_GRADCAM = "Selvaraju et al., *Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization*"

src/gradcam.jl

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+"""
+    GradCAM(feature_layers, adaptation_layers)
+
+Calculates the Gradient-weighted Class Activation Map (GradCAM).
+GradCAM provides a visual explanation of the regions with significant neuron importance for the model's classification decision.
+
+# Parameters
+- `feature_layers`: The layers of a convolutional neural network (CNN) responsible for extracting feature maps.
+- `adaptation_layers`: The layers of the CNN used for adaptation and classification.
+
+# Note
+Flux is not required for GradCAM.
+GradCAM is compatible with a wide variety of CNN model-families.
+
+# References
+- $REF_SELVARAJU_GRADCAM
+"""
+struct GradCAM{F,A} <: AbstractXAIMethod
+    feature_layers::F
+    adaptation_layers::A
+end
+function (analyzer::GradCAM)(input, ns::AbstractNeuronSelector)
+    A = analyzer.feature_layers(input)  # feature map
+    feature_map_size = size(A, 1) * size(A, 2)
+
+    # Determine neuron importance αₖᶜ = 1/Z * ∑ᵢ ∑ⱼ ∂yᶜ / ∂Aᵢⱼᵏ
+    grad, output, output_indices = gradient_wrt_input(analyzer.adaptation_layers, A, ns)
+    αᶜ = sum(grad; dims=(1, 2)) / feature_map_size
+    Lᶜ = max.(sum(αᶜ .* A; dims=3), 0)
+    return Explanation(Lᶜ, output, output_indices, :GradCAM, :cam, nothing)
+end
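A minimal usage sketch, not part of this commit: the CNN architecture and layer sizes below are hypothetical, chosen only to show how a Flux `Chain` can be split into the two parts that `GradCAM` expects. `GradCAM` and the `:cam` heatmap type come from the code above, and `analyze` is the same entry point used in the tests of this commit.

    using Flux, ExplainableAI

    # Hypothetical CNN, split into the two parts GradCAM expects:
    feature_layers = Chain(                 # convolutional feature extractor
        Conv((3, 3), 3 => 8, relu; pad=1),
        MaxPool((2, 2)),
        Conv((3, 3), 8 => 16, relu; pad=1),
        MaxPool((2, 2)),
    )
    adaptation_layers = Chain(Flux.flatten, Dense(16 * 8 * 8, 10))  # classifier head

    analyzer = GradCAM(feature_layers, adaptation_layers)
    input = rand(Float32, 32, 32, 3, 1)     # WHCN layout, batch size 1
    expl = analyze(input, analyzer)         # Explanation with heatmap type :cam

Because `Lᶜ` is computed on the feature map `A`, `expl.val` has the spatial size of the feature map rather than of the input, which is why the `size(expl.val) == size(input)` assertions are removed from `test/test_cnn.jl` below.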

test/references/cnn/GradCAM_max.jld2

Binary reference file, 1.03 KB (contents not shown).

test/references/cnn/GradCAM_ns1.jld2

Binary reference file, 1.03 KB (contents not shown).

test/test_batches.jl

Lines changed: 2 additions & 1 deletion

@@ -8,7 +8,7 @@ ins = 20
 outs = 10
 batchsize = 15

-model = Chain(Dense(ins, outs, relu; init=pseudorand))
+model = Chain(Dense(ins, 15, relu; init=pseudorand), Dense(15, outs, relu; init=pseudorand))

 # Input 1 w/o batch dimension
 input1_no_bd = rand(MersenneTwister(1), Float32, ins)

@@ -24,6 +24,7 @@ ANALYZERS = Dict(
     "InputTimesGradient" => InputTimesGradient,
     "SmoothGrad" => m -> SmoothGrad(m, 5, 0.1, MersenneTwister(123)),
     "IntegratedGradients" => m -> IntegratedGradients(m, 5),
+    "GradCAM" => m -> GradCAM(m[1], m[2]),
 )

 for (name, method) in ANALYZERS
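The batch-test model gains a second `Dense` layer here, presumably so that the new `GradCAM(m[1], m[2])` entry can split the `Chain` into a feature part and an adaptation part. A rough sketch of that indexing, using the test's layer sizes but omitting its `init=pseudorand` helper:

    using Flux

    model = Chain(Dense(20, 15, relu), Dense(15, 10, relu))
    model[1]  # first layer,  passed to GradCAM as feature_layers
    model[2]  # second layer, passed as adaptation_layers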

test/test_cnn.jl

Lines changed: 1 addition & 4 deletions

@@ -6,6 +6,7 @@ const GRADIENT_ANALYZERS = Dict(
     "InputTimesGradient" => InputTimesGradient,
     "SmoothGrad" => m -> SmoothGrad(m, 5, 0.1, MersenneTwister(123)),
     "IntegratedGradients" => m -> IntegratedGradients(m, 5),
+    "GradCAM" => m -> GradCAM(m[1], m[2]),
 )

 input_size = (32, 32, 3, 1)

@@ -67,17 +68,13 @@ function test_cnn(name, method)
         println("Timing $name...")
         print("cold:")
         @time expl = analyze(input, analyzer)
-
-        @test size(expl.val) == size(input)
         @test_reference "references/cnn/$(name)_max.jld2" Dict("expl" => expl.val) by =
             (r, a) -> isapprox(r["expl"], a["expl"]; rtol=0.05)
     end
     @testset "Neuron selection" begin
         analyzer = method(model)
         print("warm:")
         @time expl = analyze(input, analyzer, 1)
-
-        @test size(expl.val) == size(input)
         @test_reference "references/cnn/$(name)_ns1.jld2" Dict("expl" => expl.val) by =
             (r, a) -> isapprox(r["expl"], a["expl"]; rtol=0.05)
     end
