Skip to content

Commit 12b3118

Browse files
authored
Add Concept Relevance Propagation (#146)
* add `CRP` analyzer that wraps `LRP` * add two concept selectors: `TopNConcepts` and `IndexedConcepts` * add `process_batch` argument to `heatmap`
1 parent cbb2906 commit 12b3118

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

44 files changed

+528
-187
lines changed

README.md

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ using ImageInTerminal # show heatmap in terminal
3333
# Load model
3434
model = VGG(16, pretrain=true).layers
3535
model = strip_softmax(model)
36+
model = canonize(model)
3637

3738
# Load input
3839
url = HTTP.URI("https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/castle.jpg")
@@ -87,30 +88,29 @@ Check out our talk at JuliaCon 2022 for a demonstration of the package.
8788
## Methods
8889
Currently, the following analyzers are implemented:
8990

90-
```
91-
├── Gradient
92-
├── InputTimesGradient
93-
├── SmoothGrad
94-
├── IntegratedGradients
95-
└── LRP
96-
├── Rules
97-
│ ├── ZeroRule
98-
│ ├── EpsilonRule
99-
│ ├── GammaRule
100-
│ ├── GeneralizedGammaRule
101-
│ ├── WSquareRule
102-
│ ├── FlatRule
103-
│ ├── ZBoxRule
104-
│ ├── ZPlusRule
105-
│ ├── AlphaBetaRule
106-
│ └── PassRule
107-
└── Composite
108-
├── EpsilonGammaBox
109-
├── EpsilonPlus
110-
├── EpsilonPlusFlat
111-
├── EpsilonAlpha2Beta1
112-
└── EpsilonAlpha2Beta1Flat
113-
```
91+
* `Gradient`
92+
* `InputTimesGradient`
93+
* `SmoothGrad`
94+
* `IntegratedGradients`
95+
* `LRP`
96+
* Rules
97+
* `ZeroRule`
98+
* `EpsilonRule`
99+
* `GammaRule`
100+
* `GeneralizedGammaRule`
101+
* `WSquareRule`
102+
* `FlatRule`
103+
* `ZBoxRule`
104+
* `ZPlusRule`
105+
* `AlphaBetaRule`
106+
* `PassRule`
107+
* Composites
108+
* `EpsilonGammaBox`
109+
* `EpsilonPlus`
110+
* `EpsilonPlusFlat`
111+
* `EpsilonAlpha2Beta1`
112+
* `EpsilonAlpha2Beta1Flat`
113+
* `CRP`
114114

115115
One of the design goals of ExplainableAI.jl is extensibility.
116116
Custom [composites][docs-composites] are easily defined

docs/src/literate/augmentations.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,14 @@ heatmap(input, analyzer)
7171
analyzer = IntegratedGradients(model, 50)
7272
heatmap(input, analyzer)
7373

74-
# To select a different reference input, pass it to the `analyze` or `heatmap` function
74+
# To select a different reference input, pass it to the `analyze` function
7575
# using the keyword argument `input_ref`.
7676
# Note that this is an arbitrary example for the sake of demonstration.
7777
matrix_of_ones = ones(Float32, size(input))
7878

7979
analyzer = InterpolationAugmentation(Gradient(model), 50)
80-
heatmap(input, analyzer; input_ref=matrix_of_ones)
80+
expl = analyzer(input; input_ref=matrix_of_ones)
81+
heatmap(expl)
8182

8283
# Once again, `InterpolationAugmentation` can be combined with any analyzer type,
8384
# for example [`LRP`](@ref):

docs/src/lrp/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@ LRP_CONFIG.supports_layer
9494
LRP_CONFIG.supports_activation
9595
```
9696

97+
# CRP
98+
```@docs
99+
CRP
100+
TopNConcepts
101+
IndexedConcepts
102+
```
103+
97104
# Index
98105
```@index
99106
```

ext/TullioLRPRulesExt.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@ import ExplainableAI: ZeroRule, EpsilonRule, GammaRule, WSquareRule
66

77
# Fast implementation for Dense layer using Tullio.jl's einsum notation:
88
for R in (ZeroRule, EpsilonRule, GammaRule)
9-
@eval function lrp!(Rᵏ, rule::$R, layer::Dense, modified_layer, aᵏ, Rᵏ⁺¹)
10-
layer = isnothing(modified_layer) ? layer : modified_layer
9+
@eval function lrp!(Rᵏ, rule::$R, _layer::Dense, modified_layer, aᵏ, Rᵏ⁺¹)
1110
ãᵏ = modify_input(rule, aᵏ)
12-
z = modify_denominator(rule, layer(ãᵏ))
13-
@tullio Rᵏ[j, b] = layer.weight[i, j] * ãᵏ[j, b] / z[i, b] * Rᵏ⁺¹[i, b]
11+
z = modify_denominator(rule, modified_layer(ãᵏ))
12+
@tullio Rᵏ[j, b] = modified_layer.weight[i, j] * ãᵏ[j, b] / z[i, b] * Rᵏ⁺¹[i, b]
1413
end
1514
end
1615

src/ExplainableAI.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ include("lrp/composite.jl")
3030
include("lrp/lrp.jl")
3131
include("lrp/show.jl")
3232
include("lrp/composite_presets.jl") # uses lrp/show.jl
33+
include("lrp/crp.jl")
3334
include("heatmap.jl")
3435
include("preprocessing.jl")
3536
export analyze
@@ -61,6 +62,9 @@ export EpsilonAlpha2Beta1Flat
6162
# Useful type unions
6263
export ConvLayer, PoolingLayer, DropoutLayer, ReshapingLayer, NormalizationLayer
6364

65+
# CRP
66+
export CRP, TopNConcepts, IndexedConcepts
67+
6468
# heatmapping
6569
export heatmap
6670

src/heatmap.jl

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
const HEATMAPPING_PRESETS = Dict{Symbol,Tuple{ColorScheme,Symbol,Symbol}}(
44
# Analyzer => (colorscheme, reduce, rangescale)
55
:LRP => (ColorSchemes.seismic, :sum, :centered), # attribution
6+
:CRP => (ColorSchemes.seismic, :sum, :centered), # attribution
67
:InputTimesGradient => (ColorSchemes.seismic, :sum, :centered), # attribution
78
:Gradient => (ColorSchemes.grays, :norm, :extrema), # gradient
89
)
@@ -36,70 +37,75 @@ See also [`analyze`](@ref).
3637
- `permute::Bool`: Whether to flip W&H input channels. Default is `true`.
3738
- `unpack_singleton::Bool`: When heatmapping a batch with a single sample, setting `unpack_singleton=true`
3839
will return an image instead of a Vector containing a single image.
39-
40-
**Note:** keyword arguments can't be used when calling `heatmap` with an analyzer.
40+
- `process_batch::Bool`: When heatmapping a batch, setting `process_batch=true`
41+
will apply the color channel reduction and normalization to the entire batch
42+
instead of computing it individually for each sample. Defaults to `false`.
4143
"""
4244
function heatmap(
43-
attr::AbstractArray{T,N};
45+
val::AbstractArray{T,N};
4446
cs::ColorScheme=ColorSchemes.seismic,
4547
reduce::Symbol=:sum,
4648
rangescale::Symbol=:centered,
4749
permute::Bool=true,
4850
unpack_singleton::Bool=true,
51+
process_batch::Bool=false,
4952
) where {T,N}
5053
N != 4 && throw(
51-
DomainError(
52-
N,
53-
"""heatmap assumes Flux's WHCN convention (width, height, color channels, batch size) for the input.
54-
Please reshape your explanation to match this format if your model doesn't adhere to this convention.""",
54+
ArgumentError(
55+
"heatmap assumes Flux's WHCN convention (width, height, color channels, batch size) for the input.
56+
Please reshape your explanation to match this format if your model doesn't adhere to this convention.",
5557
),
5658
)
57-
if unpack_singleton && size(attr, 4) == 1
58-
return _heatmap(attr[:, :, :, 1], cs, reduce, rangescale, permute)
59+
if unpack_singleton && size(val, 4) == 1
60+
return _heatmap(val[:, :, :, 1], cs, reduce, rangescale, permute)
61+
end
62+
if process_batch
63+
hs = _heatmap(val, cs, reduce, rangescale, permute)
64+
return [hs[:, :, i] for i in axes(hs, 3)]
5965
end
60-
return map(a -> _heatmap(a, cs, reduce, rangescale, permute), eachslice(attr; dims=4))
66+
return [_heatmap(v, cs, reduce, rangescale, permute) for v in eachslice(val; dims=4)]
6167
end
6268

6369
# Use HEATMAPPING_PRESETS for default kwargs when dispatching on Explanation
64-
function heatmap(expl::Explanation; permute::Bool=true, kwargs...)
70+
function heatmap(expl::Explanation; kwargs...)
6571
_cs, _reduce, _rangescale = HEATMAPPING_PRESETS[expl.analyzer]
6672
return heatmap(
6773
expl.val;
6874
reduce=get(kwargs, :reduce, _reduce),
6975
rangescale=get(kwargs, :rangescale, _rangescale),
7076
cs=get(kwargs, :cs, _cs),
71-
permute=permute,
77+
kwargs...,
7278
)
7379
end
7480
# Analyze & heatmap in one go
7581
function heatmap(input, analyzer::AbstractXAIMethod, args...; kwargs...)
76-
return heatmap(analyze(input, analyzer, args...; kwargs...))
82+
expl = analyze(input, analyzer, args...)
83+
return heatmap(expl; kwargs...)
7784
end
7885

79-
# Lower level function that is mapped along batch dimension
80-
function _heatmap(
81-
attr::AbstractArray{T,3},
82-
cs::ColorScheme,
83-
reduce::Symbol,
84-
rangescale::Symbol,
85-
permute::Bool,
86-
) where {T<:Real}
87-
img = dropdims(_reduce(attr, reduce); dims=3)
88-
permute && (img = permutedims(img))
86+
# Lower level function that can be mapped along batch dimension
87+
function _heatmap(val, cs::ColorScheme, reduce::Symbol, rangescale::Symbol, permute::Bool)
88+
img = dropdims(reduce_color_channel(val, reduce); dims=3)
89+
permute && (img = flip_wh(img))
8990
return ColorSchemes.get(cs, img, rangescale)
9091
end
9192

93+
flip_wh(img::AbstractArray{T,2}) where {T} = permutedims(img, (2, 1))
94+
flip_wh(img::AbstractArray{T,3}) where {T} = permutedims(img, (2, 1, 3))
95+
9296
# Reduce explanations across color channels into a single scalar – assumes WHCN convention
93-
function _reduce(attr::AbstractArray{T,3}, method::Symbol) where {T}
94-
if size(attr, 3) == 1 # nothing to reduce
95-
return attr
97+
function reduce_color_channel(val::AbstractArray, method::Symbol)
98+
init = zero(eltype(val))
99+
if size(val, 3) == 1 # nothing to reduce
100+
return val
96101
elseif method == :sum
97-
return reduce(+, attr; dims=3)
102+
return reduce(+, val; dims=3)
98103
elseif method == :maxabs
99-
return reduce((c...) -> maximum(abs.(c)), attr; dims=3, init=zero(T))
104+
return reduce((c...) -> maximum(abs.(c)), val; dims=3, init=init)
100105
elseif method == :norm
101-
return reduce((c...) -> sqrt(sum(c .^ 2)), attr; dims=3, init=zero(T))
106+
return reduce((c...) -> sqrt(sum(c .^ 2)), val; dims=3, init=init)
102107
end
108+
103109
throw(
104110
ArgumentError(
105111
"Color channel reducer :$method not supported, `reduce` should be :maxabs, :sum or :norm",

0 commit comments

Comments
 (0)