Commit abe02c0

Add SmoothGrad and InputAugmentation (#50)
1 parent 55c6e8e commit abe02c0

18 files changed (+261, -18 lines changed)

Project.toml

Lines changed: 3 additions & 0 deletions
@@ -5,17 +5,20 @@ version = "0.3.1"
 
 [deps]
 ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
 PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 ColorSchemes = "3"
+Distributions = "0.25"
 Flux = "0.12"
 ImageCore = "0.8, 0.9"
 PrettyTables = "1"

README.md

Lines changed: 1 addition & 1 deletion
@@ -54,6 +54,7 @@ Currently, the following analyzers are implemented:
 ```
 ├── Gradient
 ├── InputTimesGradient
+├── SmoothGrad
 └── LRP
     ├── LRPZero
     ├── LRPEpsilon
@@ -65,7 +66,6 @@ Individual LRP rules like `ZeroRule`, `EpsilonRule`, `GammaRule` and `ZBoxRule`
 
 ## Roadmap
 In the future, we would like to include:
-- [SmoothGrad](https://arxiv.org/abs/1706.03825)
 - [Integrated Gradients](https://arxiv.org/abs/1703.01365)
 - [PatternNet](https://arxiv.org/abs/1705.05598)
 - [DeepLift](https://arxiv.org/abs/1704.02685)

benchmark/benchmarks.jl

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ algs = Dict(
     "InputTimesGradient" => InputTimesGradient,
     "LRPZero" => LRPZero,
     "LRPCustom" => LRPCustom, #modifies weights
+    "SmoothGrad" => model -> SmoothGrad(model, 10),
 )
 
 # Define benchmark

docs/literate/example.jl

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ mosaic(heatmap(batch, analyzer, 1); nrow=10)
 # ```
 # ├── Gradient
 # ├── InputTimesGradient
+# ├── SmoothGrad
 # └── LRP
 #     ├── LRPZero
 #     ├── LRPEpsilon

docs/src/api.md

Lines changed: 6 additions & 0 deletions
@@ -10,6 +10,12 @@ heatmap
 LRP
 Gradient
 InputTimesGradient
+SmoothGrad
+```
+
+`SmoothGrad` is a special case of `InputAugmentation`, which can be applied as a wrapper to any analyzer:
+```@docs
+InputAugmentation
 ```
 
 # LRP
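For context, a minimal sketch of the wrapper usage described above, based on the constructors added in this commit; the Flux model and the sample count of 25 are placeholders, not part of the diff:

```julia
using ExplainableAI, Flux

# Hypothetical toy model; any Flux model with a batch dimension should work.
model = Chain(Dense(784, 100, relu), Dense(100, 10), softmax)

# Wrap an existing analyzer: 25 noisy copies per input, default noise std σ = 0.1.
analyzer = InputAugmentation(Gradient(model), 25)

# A custom noise level can be passed as the third positional argument.
analyzer = InputAugmentation(InputTimesGradient(model), 25, 0.2f0)
```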

src/ExplainableAI.jl

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,8 @@ module ExplainableAI
 
 using Base.Iterators
 using LinearAlgebra
+using Distributions
+using Random: AbstractRNG, GLOBAL_RNG
 using Flux
 using Zygote
 using Tullio
@@ -18,6 +20,7 @@ include("neuron_selection.jl")
 include("analyze_api.jl")
 include("flux.jl")
 include("utils.jl")
+include("input_augmentation.jl")
 include("gradient.jl")
 include("lrp_checks.jl")
 include("lrp_rules.jl")
@@ -29,6 +32,7 @@ export analyze
 # Analyzers
 export AbstractXAIMethod
 export Gradient, InputTimesGradient
+export InputAugmentation, SmoothGrad
 export LRP, LRPZero, LRPEpsilon, LRPGamma
 
 # LRP rules

src/gradient.jl

Lines changed: 15 additions & 2 deletions
@@ -1,4 +1,4 @@
-function gradient_wrt_input(model, input::T, output_indices) where {T}
+function gradient_wrt_input(model, input, output_indices)
     return only(gradient((in) -> model(in)[output_indices], input))
 end
 
@@ -8,7 +8,7 @@ function gradients_wrt_batch(model, input::AbstractArray{T,N}, output_indices) w
     return mapreduce(
         (gs...) -> cat(gs...; dims=N), zip(eachslice(input; dims=N), output_indices)
     ) do (in, idx)
-        gradient_wrt_input(model, batch_dim_view(in), drop_batch_dim(idx))
+        gradient_wrt_input(model, batch_dim_view(in), drop_batch_index(idx))
     end
 end
 
@@ -46,3 +46,16 @@ function (analyzer::InputTimesGradient)(input, ns::AbstractNeuronSelector)
     attr = input .* gradients_wrt_batch(analyzer.model, input, output_indices)
     return Explanation(attr, output, output_indices, :InputTimesGradient, Nothing)
 end
+
+"""
+    SmoothGrad(model, [n=50, std=0.1, rng=GLOBAL_RNG])
+    SmoothGrad(model, [n=50, distribution=Normal(0, σ²=0.01), rng=GLOBAL_RNG])
+
+Analyze model by calculating a smoothed sensitivity map.
+This is done by averaging sensitivity maps of a `Gradient` analyzer over random samples
+in a neighborhood of the input, typically by adding Gaussian noise with mean 0.
+
+# References
+[1] Smilkov et al., SmoothGrad: removing noise by adding noise
+"""
+SmoothGrad(model, n=50, args...) = InputAugmentation(Gradient(model), n, args...)
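A hedged usage sketch of the new `SmoothGrad` constructor; the model, input shape, and concrete values are illustrative only, with the analyzer applied through the package's exported `analyze` function:

```julia
using ExplainableAI, Flux

# Placeholder model and input batch (batch dimension last).
model = Chain(Dense(784, 100, relu), Dense(100, 10), softmax)
input = rand(Float32, 784, 1)

analyzer = SmoothGrad(model, 50)   # average gradients over 50 noisy copies, default σ = 0.1
expl = analyze(input, analyzer)    # returns an Explanation, as with any other analyzer
```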

src/input_augmentation.jl

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+"""
+    InputAugmentation(analyzer, n, [std=0.1, rng=GLOBAL_RNG])
+    InputAugmentation(analyzer, n, distribution, [rng=GLOBAL_RNG])
+
+A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from `distribution`.
+This input augmentation is then averaged to return an `Explanation`.
+"""
+struct InputAugmentation{A<:AbstractXAIMethod,D<:Distribution,R<:AbstractRNG} <:
+       AbstractXAIMethod
+    analyzer::A
+    n::Integer
+    distribution::D
+    rng::R
+end
+function InputAugmentation(analyzer, n, distr, rng=GLOBAL_RNG)
+    return InputAugmentation(analyzer, n, distr, rng)
+end
+function InputAugmentation(analyzer, n, σ::Real=0.1f0, args...)
+    return InputAugmentation(analyzer, n, Normal(0.0f0, Float32(σ)^2), args...)
+end
+
+function (aug::InputAugmentation)(input, ns::AbstractNeuronSelector)
+    # Regular forward pass of model
+    output = aug.analyzer.model(input)
+    output_indices = ns(output)
+
+    # Call regular analyzer on augmented batch
+    augmented_input = add_noise(augment_batch_dim(input, aug.n), aug.distribution, aug.rng)
+    augmented_indices = augment_indices(output_indices, aug.n)
+    augmented_expl = aug.analyzer(augmented_input, AugmentationSelector(augmented_indices))
+
+    # Average explanation
+    return Explanation(
+        reduce_augmentation(augmented_expl.attribution, aug.n),
+        output,
+        output_indices,
+        augmented_expl.analyzer,
+        Nothing,
+    )
+end
+
+function add_noise(A::AbstractArray{T}, distr::Distribution, rng::AbstractRNG) where {T}
+    return A + T.(rand(rng, distr, size(A)))
+end
+
+"""
+    augment_batch_dim(input, n)
+
+Repeat each sample in the input batch `n` times along the batch dimension.
+This turns arrays of size `(..., B)` into arrays of size `(..., B*n)`.
+
+## Example
+```julia-repl
+julia> A = [1 2; 3 4]
+2×2 Matrix{Int64}:
+ 1  2
+ 3  4
+
+julia> augment_batch_dim(A, 3)
+2×6 Matrix{Int64}:
+ 1  1  1  2  2  2
+ 3  3  3  4  4  4
+```
+"""
+function augment_batch_dim(input::AbstractArray{T,N}, n) where {T,N}
+    return repeat(input; inner=(ntuple(_ -> 1, Val(N - 1))..., n))
+end
+
+"""
+    reduce_augmentation(augmented_input, n)
+
+Reduce an augmented input batch by averaging the explanation for each augmented sample.
+"""
+function reduce_augmentation(input::AbstractArray{T,N}, n) where {T<:AbstractFloat,N}
+    return cat(
+        (
+            Iterators.map(1:n:size(input, N)) do i
+                augmentation_range = ntuple(_ -> :, Val(N - 1))..., i:(i + n - 1)
+                sum(view(input, augmentation_range...); dims=N) / n
+            end
+        )...; dims=N
+    )::Array{T,N}
+end
+"""
+    augment_indices(indices, n)
+
+Strip batch indices and return indices for a batch augmented by `n` samples.
+
+## Example
+```julia-repl
+julia> inds = [CartesianIndex(5,1), CartesianIndex(3,2)]
+2-element Vector{CartesianIndex{2}}:
+ CartesianIndex(5, 1)
+ CartesianIndex(3, 2)
+
+julia> augment_indices(inds, 3)
+6-element Vector{CartesianIndex{2}}:
+ CartesianIndex(5, 1)
+ CartesianIndex(5, 2)
+ CartesianIndex(5, 3)
+ CartesianIndex(3, 4)
+ CartesianIndex(3, 5)
+ CartesianIndex(3, 6)
+```
+"""
+function augment_indices(inds::Vector{CartesianIndex{N}}, n) where {N}
+    indices_wo_batch = [i.I[1:(end - 1)] for i in inds]
+    return map(enumerate(repeat(indices_wo_batch; inner=n))) do (i, idx)
+        CartesianIndex{N}(idx..., i)
+    end
+end
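To make the averaging step concrete, a small REPL-style illustration of `reduce_augmentation`, consistent with the definition above (the matrix values are arbitrary): each group of `n` consecutive slices along the batch dimension, i.e. the `n` noisy copies of one sample, is collapsed into its mean.

```julia-repl
julia> A = Float32[1 2 3 4; 5 6 7 8]   # augmented batch: 2 samples × 2 noisy copies each
2×4 Matrix{Float32}:
 1.0  2.0  3.0  4.0
 5.0  6.0  7.0  8.0

julia> reduce_augmentation(A, 2)        # average every 2 consecutive slices
2×2 Matrix{Float32}:
 1.5  3.5
 5.5  7.5
```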

src/neuron_selection.jl

Lines changed: 10 additions & 0 deletions
@@ -27,3 +27,13 @@ function (s::IndexSelector{I})(out::AbstractArray{T,N}) where {I,T,N}
     N < 2 && throw(BATCHDIM_MISSING)
     return CartesianIndex{N}.(s.index..., 1:size(out, N))
 end
+
+"""
+    AugmentationSelector(index)
+
+Neuron selector that passes through an augmented neuron selection.
+"""
+struct AugmentationSelector{I} <: AbstractNeuronSelector
+    indices::I
+end
+(s::AugmentationSelector)(out) = s.indices
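A brief illustration (values arbitrary) of why this selector is all `InputAugmentation` needs: it ignores the model output entirely and simply returns the indices it was constructed with, so the pre-computed augmented indices pass straight through to the wrapped analyzer.

```julia-repl
julia> sel = AugmentationSelector([CartesianIndex(5, 1), CartesianIndex(5, 2)]);

julia> sel(rand(Float32, 10, 2))   # the output array is ignored
2-element Vector{CartesianIndex{2}}:
 CartesianIndex(5, 1)
 CartesianIndex(5, 2)
```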

src/utils.jl

Lines changed: 3 additions & 3 deletions
@@ -42,15 +42,15 @@ julia> batch_dim_view(A)
 batch_dim_view(A::AbstractArray{T,N}) where {T,N} = view(A, ntuple(_ -> :, Val(N + 1))...)
 
 """
-    drop_batch_dim(I)
+    drop_batch_index(I)
 
 Drop batch dimension index (last value) from CartesianIndex.
 
 ## Example
-julia> drop_batch_dim(CartesianIndex(5,3,2))
+julia> drop_batch_index(CartesianIndex(5,3,2))
 CartesianIndex(5, 3)
 """
-drop_batch_dim(C::CartesianIndex) = CartesianIndex(C.I[1:(end - 1)])
+drop_batch_index(C::CartesianIndex) = CartesianIndex(C.I[1:(end - 1)])
 
 # Utils for printing model check summary using PrettyTable.jl
 _print_name(layer) = "$layer"
