Skip to content

Commit fad91e8

Browse files
committed
add: random tie break for expected error reduction
1 parent b7815d8 commit fad91e8

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

modAL/expected_error.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111

1212
from modAL.models import ActiveLearner
1313
from modAL.utils.data import modALinput, data_vstack
14-
from modAL.utils.selection import multi_argmax
14+
from modAL.utils.selection import multi_argmax, shuffled_argmax
1515
from modAL.uncertainty import _proba_uncertainty, _proba_entropy
1616

1717

1818
def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str = 'binary',
19-
p_subsample: np.float = 1.0, n_instances: int = 1) -> Tuple[np.ndarray, modALinput]:
19+
p_subsample: np.float = 1.0, n_instances: int = 1,
20+
random_tie_break: bool = False) -> Tuple[np.ndarray, modALinput]:
2021
"""
2122
Expected error reduction query strategy.
2223
@@ -32,6 +33,8 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
3233
calculating expected error. Significantly improves runtime
3334
for large sample pools.
3435
n_instances: The number of instances to be sampled.
36+
random_tie_break: If True, shuffles utility scores to randomize the order. This
37+
can be used to break the tie when the highest utility score is not unique.
3538
3639
3740
Returns:
@@ -73,6 +76,9 @@ def expected_error_reduction(learner: ActiveLearner, X: modALinput, loss: str =
7376
else:
7477
expected_error[x_idx] = np.inf
7578

76-
query_idx = multi_argmax(expected_error, n_instances)
79+
if not random_tie_break:
80+
query_idx = multi_argmax(expected_error, n_instances)
81+
else:
82+
query_idx = shuffled_argmax(expected_error, n_instances)
7783

7884
return query_idx, X[query_idx]

tests/core_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ def test_eer(self):
462462
X_training=X_training, y_training=y_training)
463463

464464
modAL.expected_error.expected_error_reduction(learner, X_pool)
465+
modAL.expected_error.expected_error_reduction(learner, X_pool, random_tie_break=True)
465466
modAL.expected_error.expected_error_reduction(learner, X_pool, p_subsample=0.1)
466467
modAL.expected_error.expected_error_reduction(learner, X_pool, loss='binary')
467468
modAL.expected_error.expected_error_reduction(learner, X_pool, p_subsample=0.1, loss='log')

0 commit comments

Comments
 (0)