11
11
12
12
from modAL .models import ActiveLearner
13
13
from modAL .utils .data import modALinput , data_vstack
14
- from modAL .utils .selection import multi_argmax
14
+ from modAL .utils .selection import multi_argmax , shuffled_argmax
15
15
from modAL .uncertainty import _proba_uncertainty , _proba_entropy
16
16
17
17
18
18
def expected_error_reduction (learner : ActiveLearner , X : modALinput , loss : str = 'binary' ,
19
- p_subsample : np .float = 1.0 , n_instances : int = 1 ) -> Tuple [np .ndarray , modALinput ]:
19
+ p_subsample : np .float = 1.0 , n_instances : int = 1 ,
20
+ random_tie_break : bool = False ) -> Tuple [np .ndarray , modALinput ]:
20
21
"""
21
22
Expected error reduction query strategy.
22
23
@@ -32,6 +33,8 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
32
33
calculating expected error. Significantly improves runtime
33
34
for large sample pools.
34
35
n_instances: The number of instances to be sampled.
36
+ random_tie_break: If True, shuffles utility scores to randomize the order. This
37
+ can be used to break the tie when the highest utility score is not unique.
35
38
36
39
37
40
Returns:
@@ -73,6 +76,9 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
73
76
else :
74
77
expected_error [x_idx ] = np .inf
75
78
76
- query_idx = multi_argmax (expected_error , n_instances )
79
+ if not random_tie_break :
80
+ query_idx = multi_argmax (expected_error , n_instances )
81
+ else :
82
+ query_idx = shuffled_argmax (expected_error , n_instances )
77
83
78
84
return query_idx , X [query_idx ]
0 commit comments