Commit 67ab1c9

Document CRP (#148)
* Document use of CRP
* Document `heatmap` kwarg `process_batch`
* Define `show` on `IndexedConcepts`
1 parent 3ede2af commit 67ab1c9

File tree

4 files changed: +86 −8 lines changed

docs/make.jl
Lines changed: 10 additions & 6 deletions

@@ -35,11 +35,12 @@ makedocs(;
             "Input augmentations" => "generated/augmentations.md",
         ],
         "LRP" => Any[
-            "Basic usage" => "generated/lrp/basics.md",
-            "Assigning rules to layers" => "generated/lrp/composites.md",
-            "Supporting new layer types" => "generated/lrp/custom_layer.md",
-            "Custom LRP rules" => "generated/lrp/custom_rules.md",
-            "Developer documentation" => "lrp/developer.md"
+            "Basic usage" => "generated/lrp/basics.md",
+            "Assigning rules to layers" => "generated/lrp/composites.md",
+            "Supporting new layer types" => "generated/lrp/custom_layer.md",
+            "Custom LRP rules" => "generated/lrp/custom_rules.md",
+            "Concept Relevance Propagation" => "generated/lrp/crp.md",
+            "Developer documentation" => "lrp/developer.md"
         ],
         "API Reference" => Any[
             "General" => "api.md",
@@ -48,7 +49,10 @@ makedocs(;
         ],
     #! format: on
     linkcheck=true,
-    linkcheck_ignore=[r"https://link.springer.com/chapter/10.1007/978-3-030-28954-6_10"],
+    linkcheck_ignore=[
+        r"https://link.springer.com/chapter/10.1007/978-3-030-28954-6_10",
+        r"https://www.nature.com/articles/s42256-023-00711-8",
+    ],
     checkdocs=:exports, # only check docstrings in API reference if they are exported
 )

docs/src/literate/heatmapping.jl
Lines changed: 16 additions & 2 deletions

@@ -64,8 +64,8 @@ heatmap(expl; reduce=:norm)
 #-
 heatmap(expl; reduce=:maxabs)
 
-# Since MNIST only has a single color channel, there is no need for reduction
-# and heatmaps look identical.
+# In this example, the heatmaps look identical.
+# Since MNIST only has a single color channel, there is no need for color channel reduction.
 
 # ### [Mapping explanations onto the color scheme](@id docs-heatmap-rangescale)
 # To map a [color-channel-reduced](@ref docs-heatmap-reduce) explanation onto a color scheme,
@@ -103,6 +103,20 @@ heatmaps = heatmap(batch, analyzer)
 # Image.jl's `mosaic` function can be used to display them in a grid:
 mosaic(heatmaps; nrow=10)
 
+# When heatmapping batches, the mapping to the color scheme is applied per sample.
+# For example, `rangescale=:extrema` will normalize each heatmap
+# to the minimum and maximum value of each sample in the batch.
+# This ensures that heatmaps don't depend on other samples in the batch.
+#
+# If this behavior is not desired,
+# `heatmap` can be called with the keyword argument `process_batch=true`:
+heatmaps = heatmap(batch, analyzer; process_batch=true)
+mosaic(heatmaps; nrow=10)
+
+# This can be useful when comparing heatmaps for a fixed output neuron:
+heatmaps = heatmap(batch, analyzer, 7; process_batch=true) # heatmaps for digit "6"
+mosaic(heatmaps; nrow=10)
+
 #md # !!! note "Output type consistency"
 #md #
 #md # To obtain a singleton `Vector` containing a single heatmap for non-batched inputs,
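To make the per-sample vs. whole-batch distinction concrete, here is a minimal standalone sketch of the two `rangescale=:extrema` normalization behaviors described above. The function names `normalize_per_sample` and `normalize_whole_batch` are hypothetical illustrations, not part of the package's API:

```julia
# Hypothetical sketch (NOT the package implementation) of per-sample vs.
# whole-batch extrema normalization on a WHCN batch.

# Default behavior: normalize each sample to its own [0, 1] range.
function normalize_per_sample(batch::AbstractArray{T,4}) where {T}
    out = similar(batch, float(T))
    for i in axes(batch, 4)
        sample = view(batch, :, :, :, i)
        lo, hi = extrema(sample)
        out[:, :, :, i] .= (sample .- lo) ./ (hi - lo)
    end
    return out
end

# Analogous to `process_batch=true`: use the extrema of the entire batch,
# so values remain comparable across samples.
function normalize_whole_batch(batch::AbstractArray{T,4}) where {T}
    lo, hi = extrema(batch)
    return (batch .- lo) ./ (hi - lo)
end
```

With per-sample normalization, a sample with small relevance values is stretched to the full color range just like a sample with large ones; whole-batch normalization preserves their relative magnitudes.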

docs/src/literate/lrp/crp.jl
Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+# # [Concept relevance propagation](@id docs-crp)
+# In [*From attribution maps to human-understandable explanations through Concept Relevance Propagation*](https://www.nature.com/articles/s42256-023-00711-8) (CRP),
+# Achtibat et al. propose conditioning LRP relevances on individual features of a model.
+#
+# This example builds on the basics shown in the [*Getting started*](@ref docs-getting-started) section.
+# We start out by loading the same pre-trained LeNet5 model and MNIST input data:
+using ExplainableAI
+using Flux
+
+using BSON # hide
+model = BSON.load("../../model.bson", @__MODULE__)[:model] # hide
+model
+#-
+using MLDatasets
+using ImageCore, ImageIO, ImageShow
+
+index = 10
+x, y = MNIST(Float32, :test)[index]
+input = reshape(x, 28, 28, 1, :)
+
+convert2image(MNIST, x)
+
+# ## Step 1: Create LRP analyzer
+# To create a CRP analyzer, first define an LRP analyzer with your desired rules:
+composite = EpsilonPlusFlat()
+lrp_analyzer = LRP(model, composite)
+
+# ## Step 2: Define concepts
+# Then, specify the index of the layer whose outputs you want to condition the explanation on.
+# In this example, we are interested in the outputs of the last convolutional layer, layer 3:
+concept_layer = 3 # index of relevant layer in model
+model[concept_layer] # show layer
+
+# Then, specify the concepts you are interested in.
+# To automatically select the $n$ most relevant concepts, use [`TopNConcepts`](@ref).
+#
+# Note that for convolutional layers,
+# a feature corresponds to an entire output channel of the layer.
+concepts = TopNConcepts(5)
+
+# To manually specify features, use [`IndexedConcepts`](@ref):
+concepts = IndexedConcepts(1, 2, 10)
+
+# ## Step 3: Use CRP analyzer
+# We can now create a [`CRP`](@ref) analyzer
+# and use it like any other analyzer from ExplainableAI.jl:
+analyzer = CRP(lrp_analyzer, concept_layer, concepts)
+heatmap(input, analyzer)
+
+# ## Using CRP on input batches
+# Note that `CRP` uses the batch dimension to return explanations.
+# When using CRP on batches, the explanations are first sorted by concept, then by input,
+# e.g. `[c1_i1, c1_i2, c2_i1, c2_i2, c3_i1, c3_i2]` in the following example:
+x, y = MNIST(Float32, :test)[10:11]
+batch = reshape(x, 28, 28, 1, :)
+
+heatmap(batch, analyzer)
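The concept-major output ordering described in the tutorial above can be sketched in plain Julia. The helper `output_index` is a hypothetical illustration of where each heatmap lands, not part of the package's API:

```julia
# Illustration of the concept-major ordering of CRP batch outputs:
# with a batch of `batchsize` inputs, the heatmap for concept `c` and input `i`
# sits at position (c - 1) * batchsize + i along the returned batch dimension.
output_index(c, i, batchsize) = (c - 1) * batchsize + i

n_concepts, batchsize = 3, 2
labels = ["c$(c)_i$(i)" for c in 1:n_concepts for i in 1:batchsize]
# labels == ["c1_i1", "c1_i2", "c2_i1", "c2_i2", "c3_i1", "c3_i2"]
```

For example, the heatmap for the second concept applied to the first input is found at index `output_index(2, 1, 2) == 3`, matching `c2_i1` in the ordering above.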

src/lrp/crp.jl
Lines changed: 3 additions & 0 deletions

@@ -130,6 +130,9 @@ IndexedConcepts(args...) = IndexedConcepts(tuple(args...))
 
 number_of_concepts(c::IndexedConcepts) = length(c.inds)
 
+# Pretty printing
+Base.show(io::IO, c::IndexedConcepts) = print(io, "IndexedConcepts$(c.inds)")
+
 # Index concepts on 2D arrays, e.g. Dense layers with batch dimension
 function (c::IndexedConcepts)(A::AbstractMatrix)
     batchsize = size(A, 2)

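The `Base.show` method added above can be exercised in isolation. This is a minimal standalone sketch using a stand-in struct mirroring the `IndexedConcepts` definitions visible in the diff; the real type lives in ExplainableAI.jl:

```julia
# Stand-in for ExplainableAI.jl's IndexedConcepts, mirroring the constructors
# shown in the diff, to demonstrate the added pretty printing.
struct IndexedConcepts{N}
    inds::NTuple{N,Int}
end
IndexedConcepts(args...) = IndexedConcepts(tuple(args...))

# The pretty-printing method added by this commit:
Base.show(io::IO, c::IndexedConcepts) = print(io, "IndexedConcepts$(c.inds)")

repr(IndexedConcepts(1, 2, 10)) # "IndexedConcepts(1, 2, 10)"
```

Interpolating the tuple `c.inds` into the string reuses the tuple's own printed form, so the output reads like the constructor call that created the value.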