Skip to content

Commit 5a389f3

Browse files
Move Shapley fix from #167 (#168)
* Update shapley_sensitivity.jl * Update shapley_method.jl
1 parent 394dc74 commit 5a389f3

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

src/shapley_sensitivity.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,11 @@ function gsa(f, method::Shapley, input_distribution::SklarDist; batch = false)
203203
sample_complement = rand(
204204
Copulas.subsetdims(input_distribution, idx_minus), n_outer)
205205

206+
if size(sample_complement, 2) == 1
207+
sample_complement = reshape(
208+
sample_complement, (1, length(sample_complement)))
209+
end
210+
206211
for l in 1:n_outer
207212
curr_sample = @view sample_complement[:, l]
208213
# Sampling of the set conditionally to the complementary element

test/shapley_method.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,33 +25,31 @@ n_perms = -1;
2525
n_var = 10_000;
2626
n_outer = 1000;
2727
n_inner = 3;
28-
dim = 3;
29-
margins = (Uniform(-pi, pi), Uniform(-pi, pi), Uniform(-pi, pi));
28+
dim = 4;
29+
margins = (Uniform(-pi, pi), Uniform(-pi, pi), Uniform(-pi, pi), Uniform(-pi, pi));
3030
dependency_matrix = Matrix(4 * I, dim, dim);
3131
C = GaussianCopula(dependency_matrix);
3232
input_distribution = SklarDist(C, margins);
33-
3433
method = Shapley(n_perms = n_perms,
3534
n_var = n_var,
3635
n_outer = n_outer,
3736
n_inner = n_inner);
38-
3937
#---> non batch
4038
@time result = gsa(ishi, method, input_distribution, batch = false)
4139

4240
@test result.shapley_effects[1]0.43813841765976547 atol=1e-1
4341
@test result.shapley_effects[2]0.44673952698721386 atol=1e-1
44-
@test result.shapley_effects[3]0.23144736934254417 atol=1e-1
45-
# @test result.shapley_effects[4]≈0.0 atol=1e-1
42+
@test result.shapley_effects[3]0.11855122481995543 atol=1e-1
43+
@test result.shapley_effects[4]0.0 atol=1e-1
4644
#<---- non batch
4745

4846
#---> batch
4947
result = gsa(ishi_batch, method, input_distribution, batch = true);
5048

5149
@test result.shapley_effects[1]0.44080027198796035 atol=1e-1
5250
@test result.shapley_effects[2]0.43029987176805085 atol=1e-1
53-
@test result.shapley_effects[3]0.23144736934254417 atol=1e-1
54-
# @test result.shapley_effects[4]≈0.0 atol=1e-1
51+
@test result.shapley_effects[3]0.11855122481995543 atol=1e-1
52+
@test result.shapley_effects[4]0.0 atol=1e-1
5553
#<--- batch
5654

5755
d = 3

0 commit comments

Comments
 (0)