56
56
"""
57
57
augment_indices(indices, n)
58
58
59
- Strip batch indices and return inidices for batch augmented by n samples.
59
+ Strip batch indices and return indices for batch augmented by n samples.
60
60
61
61
## Example
62
62
```julia-repl
@@ -83,11 +83,20 @@ function augment_indices(inds::Vector{CartesianIndex{N}}, n) where {N}
83
83
end
84
84
85
85
"""
86
- NoiseAugmentation(analyzer, n, [std=1, rng=GLOBAL_RNG])
87
- NoiseAugmentation(analyzer, n, distribution, [rng=GLOBAL_RNG])
86
+ NoiseAugmentation(analyzer, n)
87
+ NoiseAugmentation(analyzer, n, std::Real)
88
+ NoiseAugmentation(analyzer, n, distribution::Sampleable)
88
89
89
- A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from `distribution`.
90
+ A wrapper around analyzers that augments the input with `n` samples of additive noise sampled from a scalar `distribution`.
90
91
This input augmentation is then averaged to return an `Explanation`.
92
+
93
+ Defaults to the normal distribution `Normal(0, std^2)` with `std=1.0f0`.
94
+ For optimal results, $REF_SMILKOV_SMOOTHGRAD recommends setting `std` between 10% and 20% of the input range of each sample,
95
+ e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
96
+
97
+ ## Keyword arguments
98
+ - `rng::AbstractRNG`: Specify the random number generator that is used to sample noise from the `distribution`.
99
+ Defaults to `GLOBAL_RNG`.
91
100
"""
92
101
struct NoiseAugmentation{A<: AbstractXAIMethod ,D<: Sampleable ,R<: AbstractRNG } < :
93
102
AbstractXAIMethod
@@ -96,11 +105,11 @@ struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
96
105
distribution:: D
97
106
rng:: R
98
107
end
99
- function NoiseAugmentation (analyzer, n, distr :: Sampleable , rng= GLOBAL_RNG)
100
- return NoiseAugmentation (analyzer, n, distr :: Sampleable , rng)
108
+ function NoiseAugmentation (analyzer, n, distribution :: Sampleable , rng= GLOBAL_RNG)
109
+ return NoiseAugmentation (analyzer, n, distribution :: Sampleable , rng)
101
110
end
102
- function NoiseAugmentation (analyzer, n, σ :: Real = 0.1f0 , args ... )
103
- return NoiseAugmentation (analyzer, n, Normal (0.0f0 , Float32 (σ) ^ 2 ), args ... )
111
+ function NoiseAugmentation (analyzer, n, std :: T = 1.0f0 , rng = GLOBAL_RNG) where {T <: Real }
112
+ return NoiseAugmentation (analyzer, n, Normal (zero (T), std ^ 2 ), rng )
104
113
end
105
114
106
115
function call_analyzer (input, aug:: NoiseAugmentation , ns:: AbstractOutputSelector ; kwargs... )
0 commit comments