|
3 | 3 | const HEATMAPPING_PRESETS = Dict{Symbol,Tuple{ColorScheme,Symbol,Symbol}}(
|
4 | 4 | # Analyzer => (colorscheme, reduce, rangescale)
|
5 | 5 | :LRP => (ColorSchemes.seismic, :sum, :centered), # attribution
|
| 6 | + :CRP => (ColorSchemes.seismic, :sum, :centered), # attribution |
6 | 7 | :InputTimesGradient => (ColorSchemes.seismic, :sum, :centered), # attribution
|
7 | 8 | :Gradient => (ColorSchemes.grays, :norm, :extrema), # gradient
|
8 | 9 | )
|
@@ -36,70 +37,75 @@ See also [`analyze`](@ref).
|
36 | 37 | - `permute::Bool`: Whether to flip W&H input channels. Default is `true`.
|
37 | 38 | - `unpack_singleton::Bool`: When heatmapping a batch with a single sample, setting `unpack_singleton=true`
|
38 | 39 | will return an image instead of an Vector containing a single image.
|
39 |
| -
|
40 |
| -**Note:** keyword arguments can't be used when calling `heatmap` with an analyzer. |
| 40 | +- `process_batch::Bool`: When heatmapping a batch, setting `process_batch=true` |
| 41 | + will apply the color channel reduction and normalization to the entire batch |
| 42 | + instead of computing it individually for each sample. Defaults to `false`. |
41 | 43 | """
|
42 | 44 | function heatmap(
|
43 |
| - attr::AbstractArray{T,N}; |
| 45 | + val::AbstractArray{T,N}; |
44 | 46 | cs::ColorScheme=ColorSchemes.seismic,
|
45 | 47 | reduce::Symbol=:sum,
|
46 | 48 | rangescale::Symbol=:centered,
|
47 | 49 | permute::Bool=true,
|
48 | 50 | unpack_singleton::Bool=true,
|
| 51 | + process_batch::Bool=false, |
49 | 52 | ) where {T,N}
|
50 | 53 | N != 4 && throw(
|
51 |
| - DomainError( |
52 |
| - N, |
53 |
| - """heatmap assumes Flux's WHCN convention (width, height, color channels, batch size) for the input. |
54 |
| - Please reshape your explanation to match this format if your model doesn't adhere to this convention.""", |
| 54 | + ArgumentError( |
| 55 | + "heatmap assumes Flux's WHCN convention (width, height, color channels, batch size) for the input. |
| 56 | + Please reshape your explanation to match this format if your model doesn't adhere to this convention.", |
55 | 57 | ),
|
56 | 58 | )
|
57 |
| - if unpack_singleton && size(attr, 4) == 1 |
58 |
| - return _heatmap(attr[:, :, :, 1], cs, reduce, rangescale, permute) |
| 59 | + if unpack_singleton && size(val, 4) == 1 |
| 60 | + return _heatmap(val[:, :, :, 1], cs, reduce, rangescale, permute) |
| 61 | + end |
| 62 | + if process_batch |
| 63 | + hs = _heatmap(val, cs, reduce, rangescale, permute) |
| 64 | + return [hs[:, :, i] for i in axes(hs, 3)] |
59 | 65 | end
|
60 |
| - return map(a -> _heatmap(a, cs, reduce, rangescale, permute), eachslice(attr; dims=4)) |
| 66 | + return [_heatmap(v, cs, reduce, rangescale, permute) for v in eachslice(val; dims=4)] |
61 | 67 | end
|
62 | 68 |
|
63 | 69 | # Use HEATMAPPING_PRESETS for default kwargs when dispatching on Explanation
|
64 |
| -function heatmap(expl::Explanation; permute::Bool=true, kwargs...) |
| 70 | +function heatmap(expl::Explanation; kwargs...) |
65 | 71 | _cs, _reduce, _rangescale = HEATMAPPING_PRESETS[expl.analyzer]
|
66 | 72 | return heatmap(
|
67 | 73 | expl.val;
|
68 | 74 | reduce=get(kwargs, :reduce, _reduce),
|
69 | 75 | rangescale=get(kwargs, :rangescale, _rangescale),
|
70 | 76 | cs=get(kwargs, :cs, _cs),
|
71 |
| - permute=permute, |
| 77 | + kwargs..., |
72 | 78 | )
|
73 | 79 | end
|
74 | 80 | # Analyze & heatmap in one go
|
75 | 81 | function heatmap(input, analyzer::AbstractXAIMethod, args...; kwargs...)
|
76 |
| - return heatmap(analyze(input, analyzer, args...; kwargs...)) |
| 82 | + expl = analyze(input, analyzer, args...) |
| 83 | + return heatmap(expl; kwargs...) |
77 | 84 | end
|
78 | 85 |
|
79 |
| -# Lower level function that is mapped along batch dimension |
80 |
| -function _heatmap( |
81 |
| - attr::AbstractArray{T,3}, |
82 |
| - cs::ColorScheme, |
83 |
| - reduce::Symbol, |
84 |
| - rangescale::Symbol, |
85 |
| - permute::Bool, |
86 |
| -) where {T<:Real} |
87 |
| - img = dropdims(_reduce(attr, reduce); dims=3) |
88 |
| - permute && (img = permutedims(img)) |
| 86 | +# Lower level function that can be mapped along batch dimension |
| 87 | +function _heatmap(val, cs::ColorScheme, reduce::Symbol, rangescale::Symbol, permute::Bool) |
| 88 | + img = dropdims(reduce_color_channel(val, reduce); dims=3) |
| 89 | + permute && (img = flip_wh(img)) |
89 | 90 | return ColorSchemes.get(cs, img, rangescale)
|
90 | 91 | end
|
91 | 92 |
|
| 93 | +flip_wh(img::AbstractArray{T,2}) where {T} = permutedims(img, (2, 1)) |
| 94 | +flip_wh(img::AbstractArray{T,3}) where {T} = permutedims(img, (2, 1, 3)) |
| 95 | + |
92 | 96 | # Reduce explanations across color channels into a single scalar – assumes WHCN convention
|
93 |
| -function _reduce(attr::AbstractArray{T,3}, method::Symbol) where {T} |
94 |
| - if size(attr, 3) == 1 # nothing to reduce |
95 |
| - return attr |
| 97 | +function reduce_color_channel(val::AbstractArray, method::Symbol) |
| 98 | + init = zero(eltype(val)) |
| 99 | + if size(val, 3) == 1 # nothing to reduce |
| 100 | + return val |
96 | 101 | elseif method == :sum
|
97 |
| - return reduce(+, attr; dims=3) |
| 102 | + return reduce(+, val; dims=3) |
98 | 103 | elseif method == :maxabs
|
99 |
| - return reduce((c...) -> maximum(abs.(c)), attr; dims=3, init=zero(T)) |
| 104 | + return reduce((c...) -> maximum(abs.(c)), val; dims=3, init=init) |
100 | 105 | elseif method == :norm
|
101 |
| - return reduce((c...) -> sqrt(sum(c .^ 2)), attr; dims=3, init=zero(T)) |
| 106 | + return reduce((c...) -> sqrt(sum(c .^ 2)), val; dims=3, init=init) |
102 | 107 | end
|
| 108 | + |
103 | 109 | throw(
|
104 | 110 | ArgumentError(
|
105 | 111 | "Color channel reducer :$method not supported, `reduce` should be :maxabs, :sum or :norm",
|
|
0 commit comments