Skip to content

Commit bd0e400

Browse files
authored
Fix default noise level for NoiseAugmentation (#179)
1 parent 61cf3f4 commit bd0e400

File tree

3 files changed

+20
-11
lines changed

3 files changed

+20
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ExplainableAI"
22
uuid = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
33
authors = ["Adrian Hill <gh@adrianhill.de>"]
4-
version = "0.9.0"
4+
version = "0.10.0-DEV"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/gradient.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ function call_analyzer(
7575
end
7676

7777
"""
78-
SmoothGrad(analyzer, [n=50, std=0.1, rng=GLOBAL_RNG])
79-
SmoothGrad(analyzer, [n=50, distribution=Normal(0, σ²=0.01), rng=GLOBAL_RNG])
78+
SmoothGrad(analyzer, [n=50, std=1.0f0, rng=GLOBAL_RNG])
79+
SmoothGrad(analyzer, [n=50, distribution=Normal(0.0f0, 1.0f0), rng=GLOBAL_RNG])
8080
8181
Analyze model by calculating a smoothed sensitivity map.
8282
This is done by averaging sensitivity maps of a `Gradient` analyzer over random samples

src/input_augmentation.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ end
5656
"""
5757
augment_indices(indices, n)
5858
59-
Strip batch indices and return inidices for batch augmented by n samples.
59+
Strip batch indices and return indices for batch augmented by n samples.
6060
6161
## Example
6262
```julia-repl
@@ -83,11 +83,20 @@ function augment_indices(inds::Vector{CartesianIndex{N}}, n) where {N}
8383
end
8484

8585
"""
86-
NoiseAugmentation(analyzer, n, [std=1, rng=GLOBAL_RNG])
87-
NoiseAugmentation(analyzer, n, distribution, [rng=GLOBAL_RNG])
86+
NoiseAugmentation(analyzer, n)
87+
NoiseAugmentation(analyzer, n, std::Real)
88+
NoiseAugmentation(analyzer, n, distribution::Sampleable)
8889
89-
A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from `distribution`.
90+
A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from a scalar `distribution`.
9091
This input augmentation is then averaged to return an `Explanation`.
92+
93+
Defaults to the normal distribution `Normal(0, std^2)` with `std=1.0f0`.
94+
For optimal results, $REF_SMILKOV_SMOOTHGRAD recommends setting `std` between 10% and 20% of the input range of each sample,
95+
e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
96+
97+
## Keyword arguments
98+
- `rng::AbstractRNG`: Specify the random number generator that is used to sample noise from the `distribution`.
99+
Defaults to `GLOBAL_RNG`.
91100
"""
92101
struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
93102
AbstractXAIMethod
@@ -96,11 +105,11 @@ struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
96105
distribution::D
97106
rng::R
98107
end
99-
function NoiseAugmentation(analyzer, n, distr::Sampleable, rng=GLOBAL_RNG)
100-
return NoiseAugmentation(analyzer, n, distr::Sampleable, rng)
108+
function NoiseAugmentation(analyzer, n, distribution::Sampleable, rng=GLOBAL_RNG)
109+
return NoiseAugmentation(analyzer, n, distribution::Sampleable, rng)
101110
end
102-
function NoiseAugmentation(analyzer, n, σ::Real=0.1f0, args...)
103-
return NoiseAugmentation(analyzer, n, Normal(0.0f0, Float32(σ)^2), args...)
111+
function NoiseAugmentation(analyzer, n, std::T=1.0f0, rng=GLOBAL_RNG) where {T<:Real}
112+
return NoiseAugmentation(analyzer, n, Normal(zero(T), std^2), rng)
104113
end
105114

106115
function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector; kwargs...)

0 commit comments

Comments
 (0)