Skip to content

Commit 0511f6b

Browse files
authored
Merge pull request #255 from DoubleML/s-set-sample-splitting
Extend set_sample_splitting for cluster data
2 parents 296114d + 5388568 commit 0511f6b

File tree

6 files changed

+345
-109
lines changed

6 files changed

+345
-109
lines changed

doubleml/double_ml.py

Lines changed: 14 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,8 @@
1515

1616
from .utils.resampling import DoubleMLResampling, DoubleMLClusterResampling
1717
from .utils._estimation import _rmse, _aggregate_coefs_and_ses, _var_est, _set_external_predictions
18-
from .utils._checks import _check_in_zero_one, _check_integer, _check_float, _check_bool, _check_is_partition, \
19-
_check_all_smpls, _check_smpl_split, _check_smpl_split_tpl, _check_benchmarks, _check_external_predictions
18+
from .utils._checks import _check_in_zero_one, _check_integer, _check_float, _check_bool, \
19+
_check_benchmarks, _check_external_predictions, _check_sample_splitting
2020
from .utils._plots import _sensitivity_contour_plot
2121
from .utils.gain_statistics import gain_statistics
2222

@@ -289,11 +289,8 @@ def smpls(self):
289289
The partition used for cross-fitting.
290290
"""
291291
if self._smpls is None:
292-
if self._is_cluster_data:
293-
err_msg = 'Sample splitting not specified. Draw samples via .draw_sample splitting().'
294-
else:
295-
err_msg = ('Sample splitting not specified. Either draw samples via .draw_sample splitting() ' +
296-
'or set external samples via .set_sample_splitting().')
292+
err_msg = ('Sample splitting not specified. Either draw samples via .draw_sample splitting() ' +
293+
'or set external samples via .set_sample_splitting().')
297294
raise ValueError(err_msg)
298295
return self._smpls
299296

@@ -302,9 +299,6 @@ def smpls_cluster(self):
302299
"""
303300
The partition of clusters used for cross-fitting.
304301
"""
305-
if self._is_cluster_data:
306-
if self._smpls_cluster is None:
307-
raise ValueError('Sample splitting not specified. Draw samples via .draw_sample splitting().')
308302
return self._smpls_cluster
309303

310304
@property
@@ -1155,7 +1149,7 @@ def draw_sample_splitting(self):
11551149

11561150
return self
11571151

1158-
def set_sample_splitting(self, all_smpls):
1152+
def set_sample_splitting(self, all_smpls, all_smpls_cluster=None):
11591153
"""
11601154
Set the sample splitting for DoubleML models.
11611155
@@ -1177,6 +1171,13 @@ def set_sample_splitting(self, all_smpls):
11771171
train_ind and test_ind to np.arange(n_obs), which corresponds to no sample splitting.
11781172
``n_folds=1`` and ``n_rep=1`` is always set.
11791173
1174+
all_smpls_cluster : list or None
1175+
Nested list or ``None``. The first level of nesting corresponds to the number of repetitions. The second level
1176+
of nesting corresponds to the number of folds. The third level of nesting contains a tuple of training and
1177+
testing lists. Both training and testing contain an array for each cluster variable, which form a partition of
1178+
the clusters.
1179+
Default is ``None``.
1180+
11801181
Returns
11811182
-------
11821183
self : object
@@ -1194,8 +1195,6 @@ def set_sample_splitting(self, all_smpls):
11941195
>>> ml_m = learner
11951196
>>> obj_dml_data = make_plr_CCDDHNR2018(n_obs=10, alpha=0.5)
11961197
>>> dml_plr_obj = dml.DoubleMLPLR(obj_dml_data, ml_g, ml_m)
1197-
>>> # simple sample splitting with two folds and without cross-fitting
1198-
>>> smpls = ([0, 1, 2, 3, 4], [5, 6, 7, 8, 9])
11991198
>>> dml_plr_obj.set_sample_splitting(smpls)
12001199
>>> # sample splitting with two folds and cross-fitting
12011200
>>> smpls = [([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]),
@@ -1208,71 +1207,8 @@ def set_sample_splitting(self, all_smpls):
12081207
>>> ([1, 3, 5, 7, 9], [0, 2, 4, 6, 8])]]
12091208
>>> dml_plr_obj.set_sample_splitting(smpls)
12101209
"""
1211-
if self._is_cluster_data:
1212-
raise NotImplementedError('Externally setting the sample splitting for DoubleML is '
1213-
'not yet implemented with clustering.')
1214-
if isinstance(all_smpls, tuple):
1215-
if not len(all_smpls) == 2:
1216-
raise ValueError('Invalid partition provided. '
1217-
'Tuple for train_ind and test_ind must consist of exactly two elements.')
1218-
all_smpls = _check_smpl_split_tpl(all_smpls, self._dml_data.n_obs)
1219-
if (_check_is_partition([all_smpls], self._dml_data.n_obs) &
1220-
_check_is_partition([(all_smpls[1], all_smpls[0])], self._dml_data.n_obs)):
1221-
self._n_rep = 1
1222-
self._n_folds = 1
1223-
self._smpls = [[all_smpls]]
1224-
else:
1225-
raise ValueError('Invalid partition provided. '
1226-
'Tuple provided that doesn\'t form a partition.')
1227-
else:
1228-
if not isinstance(all_smpls, list):
1229-
raise TypeError('all_smpls must be of list or tuple type. '
1230-
f'{str(all_smpls)} of type {str(type(all_smpls))} was passed.')
1231-
all_tuple = all([isinstance(tpl, tuple) for tpl in all_smpls])
1232-
if all_tuple:
1233-
if not all([len(tpl) == 2 for tpl in all_smpls]):
1234-
raise ValueError('Invalid partition provided. '
1235-
'All tuples for train_ind and test_ind must consist of exactly two elements.')
1236-
self._n_rep = 1
1237-
all_smpls = _check_smpl_split(all_smpls, self._dml_data.n_obs)
1238-
if _check_is_partition(all_smpls, self._dml_data.n_obs):
1239-
if ((len(all_smpls) == 1) &
1240-
_check_is_partition([(all_smpls[0][1], all_smpls[0][0])], self._dml_data.n_obs)):
1241-
self._n_folds = 1
1242-
self._smpls = [all_smpls]
1243-
else:
1244-
self._n_folds = len(all_smpls)
1245-
self._smpls = _check_all_smpls([all_smpls], self._dml_data.n_obs, check_intersect=True)
1246-
else:
1247-
raise ValueError('Invalid partition provided. '
1248-
'Tuples provided that don\'t form a partition.')
1249-
else:
1250-
all_list = all([isinstance(smpl, list) for smpl in all_smpls])
1251-
if not all_list:
1252-
raise ValueError('Invalid partition provided. '
1253-
'all_smpls is a list where neither all elements are tuples '
1254-
'nor all elements are lists.')
1255-
all_tuple = all([all([isinstance(tpl, tuple) for tpl in smpl]) for smpl in all_smpls])
1256-
if not all_tuple:
1257-
raise TypeError('For repeated sample splitting all_smpls must be list of lists of tuples.')
1258-
all_pairs = all([all([len(tpl) == 2 for tpl in smpl]) for smpl in all_smpls])
1259-
if not all_pairs:
1260-
raise ValueError('Invalid partition provided. '
1261-
'All tuples for train_ind and test_ind must consist of exactly two elements.')
1262-
n_folds_each_smpl = np.array([len(smpl) for smpl in all_smpls])
1263-
if not np.all(n_folds_each_smpl == n_folds_each_smpl[0]):
1264-
raise ValueError('Invalid partition provided. '
1265-
'Different number of folds for repeated sample splitting.')
1266-
all_smpls = _check_all_smpls(all_smpls, self._dml_data.n_obs)
1267-
smpls_are_partitions = [_check_is_partition(smpl, self._dml_data.n_obs) for smpl in all_smpls]
1268-
1269-
if all(smpls_are_partitions):
1270-
self._n_rep = len(all_smpls)
1271-
self._n_folds = n_folds_each_smpl[0]
1272-
self._smpls = _check_all_smpls(all_smpls, self._dml_data.n_obs, check_intersect=True)
1273-
else:
1274-
raise ValueError('Invalid partition provided. '
1275-
'At least one inner list does not form a partition.')
1210+
self._smpls, self._smpls_cluster, self._n_rep, self._n_folds = _check_sample_splitting(
1211+
all_smpls, all_smpls_cluster, self._dml_data, self._is_cluster_data)
12761212

12771213
self._psi, self._psi_deriv, self._psi_elements, self._var_scaling_factors, \
12781214
self._coef, self._se, self._all_coef, self._all_se = self._initialize_arrays()

doubleml/irm/qte.py

Lines changed: 76 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313

1414
from ..utils._estimation import _default_kde
1515
from ..utils.resampling import DoubleMLResampling
16-
from ..utils._checks import _check_score, _check_trimming, _check_zero_one_treatment
16+
from ..utils._checks import _check_score, _check_trimming, _check_zero_one_treatment, _check_sample_splitting
1717

1818

1919
class DoubleMLQTE:
@@ -143,16 +143,15 @@ def __init__(self,
143143
raise TypeError('Normalization indicator has to be boolean. ' +
144144
f'Object of type {str(type(self.normalize_ipw))} passed.')
145145

146+
self._learner = {'ml_g': clone(ml_g), 'ml_m': clone(ml_m)}
147+
self._predict_method = {'ml_g': 'predict_proba', 'ml_m': 'predict_proba'}
148+
146149
# perform sample splitting
147150
self._smpls = None
148151
if draw_sample_splitting:
149152
self.draw_sample_splitting()
150-
151-
self._learner = {'ml_g': clone(ml_g), 'ml_m': clone(ml_m)}
152-
self._predict_method = {'ml_g': 'predict_proba', 'ml_m': 'predict_proba'}
153-
154-
# initialize all models
155-
self._modellist_0, self._modellist_1 = self._initialize_models()
153+
# initialize all models
154+
self._modellist_0, self._modellist_1 = self._initialize_models()
156155

157156
def __str__(self):
158157
class_name = self.__class__.__name__
@@ -204,8 +203,8 @@ def smpls(self):
204203
The partition used for cross-fitting.
205204
"""
206205
if self._smpls is None:
207-
err_msg = ('Sample splitting not specified. Draw samples via .draw_sample splitting(). ' +
208-
'External samples not implemented yet.')
206+
err_msg = ('Sample splitting not specified. Either draw samples via .draw_sample splitting() ' +
207+
'or set external samples via .set_sample_splitting().')
209208
raise ValueError(err_msg)
210209
return self._smpls
211210

@@ -465,6 +464,74 @@ def draw_sample_splitting(self):
465464
n_obs=self._dml_data.n_obs,
466465
stratify=self._dml_data.d)
467466
self._smpls = obj_dml_resampling.split_samples()
467+
# initialize all models
468+
self._modellist_0, self._modellist_1 = self._initialize_models()
469+
470+
return self
471+
472+
def set_sample_splitting(self, all_smpls, all_smpls_cluster=None):
473+
"""
474+
Set the sample splitting for DoubleML models.
475+
476+
The attributes ``n_folds`` and ``n_rep`` are derived from the provided partition.
477+
478+
Parameters
479+
----------
480+
all_smpls : list or tuple
481+
If nested list of lists of tuples:
482+
The outer list needs to provide an entry per repeated sample splitting (length of list is set as
483+
``n_rep``).
484+
The inner list needs to provide a tuple (train_ind, test_ind) per fold (length of list is set as
485+
``n_folds``). test_ind must form a partition for each inner list.
486+
If list of tuples:
487+
The list needs to provide a tuple (train_ind, test_ind) per fold (length of list is set as
488+
``n_folds``). test_ind must form a partition. ``n_rep=1`` is always set.
489+
If tuple:
490+
Must be a tuple with two elements train_ind and test_ind. Only viable option is to set
491+
train_ind and test_ind to np.arange(n_obs), which corresponds to no sample splitting.
492+
``n_folds=1`` and ``n_rep=1`` is always set.
493+
494+
all_smpls_cluster : list or None
495+
Nested list or ``None``. The first level of nesting corresponds to the number of repetitions. The second level
496+
of nesting corresponds to the number of folds. The third level of nesting contains a tuple of training and
497+
testing lists. Both training and testing contain an array for each cluster variable, which form a partition of
498+
the clusters.
499+
Default is ``None``.
500+
501+
Returns
502+
-------
503+
self : object
504+
505+
Examples
506+
--------
507+
>>> import numpy as np
508+
>>> import doubleml as dml
509+
>>> from doubleml.datasets import make_plr_CCDDHNR2018
510+
>>> from sklearn.ensemble import RandomForestRegressor
511+
>>> from sklearn.base import clone
512+
>>> np.random.seed(3141)
513+
>>> learner = RandomForestRegressor(max_depth=2, n_estimators=10)
514+
>>> ml_g = learner
515+
>>> ml_m = learner
516+
>>> obj_dml_data = make_plr_CCDDHNR2018(n_obs=10, alpha=0.5)
517+
>>> dml_plr_obj = dml.DoubleMLPLR(obj_dml_data, ml_g, ml_m)
518+
>>> dml_plr_obj.set_sample_splitting(smpls)
519+
>>> # sample splitting with two folds and cross-fitting
520+
>>> smpls = [([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]),
521+
>>> ([5, 6, 7, 8, 9], [0, 1, 2, 3, 4])]
522+
>>> dml_plr_obj.set_sample_splitting(smpls)
523+
>>> # sample splitting with two folds and repeated cross-fitting with n_rep = 2
524+
>>> smpls = [[([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]),
525+
>>> ([5, 6, 7, 8, 9], [0, 1, 2, 3, 4])],
526+
>>> [([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]),
527+
>>> ([1, 3, 5, 7, 9], [0, 2, 4, 6, 8])]]
528+
>>> dml_plr_obj.set_sample_splitting(smpls)
529+
"""
530+
self._smpls, self._smpls_cluster, self._n_rep, self._n_folds = _check_sample_splitting(
531+
all_smpls, all_smpls_cluster, self._dml_data, self._is_cluster_data)
532+
533+
# initialize all models
534+
self._modellist_0, self._modellist_1 = self._initialize_models()
468535

469536
return self
470537

doubleml/irm/tests/test_qte.py

Lines changed: 36 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -54,18 +54,36 @@ def dml_qte_fixture(generate_data_quantiles, learner, normalize_ipw, kde):
5454
ml_g = clone(learner)
5555
ml_m = clone(learner)
5656

57+
input_args = {
58+
"quantiles": quantiles,
59+
"n_folds": n_folds,
60+
"n_rep": n_rep,
61+
"normalize_ipw": normalize_ipw,
62+
"trimming_threshold": 1e-12,
63+
"kde": kde
64+
}
65+
5766
np.random.seed(42)
58-
dml_qte_obj = dml.DoubleMLQTE(obj_dml_data,
59-
ml_g, ml_m,
60-
quantiles=quantiles,
61-
n_folds=n_folds,
62-
n_rep=n_rep,
63-
normalize_ipw=normalize_ipw,
64-
trimming_threshold=1e-12,
65-
kde=kde)
67+
dml_qte_obj = dml.DoubleMLQTE(
68+
obj_dml_data,
69+
ml_g, ml_m,
70+
**input_args
71+
)
6672
unfitted_qte_model = copy.copy(dml_qte_obj)
73+
np.random.seed(42)
6774
dml_qte_obj.fit()
6875

76+
np.random.seed(42)
77+
dml_qte_obj_ext_smpls = dml.DoubleMLQTE(
78+
obj_dml_data,
79+
ml_g, ml_m,
80+
draw_sample_splitting=False,
81+
**input_args
82+
)
83+
dml_qte_obj_ext_smpls.set_sample_splitting(dml_qte_obj.smpls)
84+
np.random.seed(42)
85+
dml_qte_obj_ext_smpls.fit()
86+
6987
np.random.seed(42)
7088
n_obs = len(y)
7189
all_smpls = draw_smpls(n_obs, n_folds, n_rep=1, groups=d)
@@ -80,8 +98,10 @@ def dml_qte_fixture(generate_data_quantiles, learner, normalize_ipw, kde):
8098
boot_t_stat=None, joint=False, level=0.95)
8199
res_dict = {'coef': dml_qte_obj.coef,
82100
'coef_manual': res_manual['qte'],
101+
'coef_ext_smpls': dml_qte_obj_ext_smpls.coef,
83102
'se': dml_qte_obj.se,
84103
'se_manual': res_manual['se'],
104+
'se_ext_smpls': dml_qte_obj_ext_smpls.se,
85105
'boot_methods': boot_methods,
86106
'ci': ci.to_numpy(),
87107
'ci_manual': ci_manual.to_numpy(),
@@ -112,13 +132,19 @@ def test_dml_qte_coef(dml_qte_fixture):
112132
assert np.allclose(dml_qte_fixture['coef'],
113133
dml_qte_fixture['coef_manual'],
114134
rtol=1e-9, atol=1e-4)
135+
assert np.allclose(dml_qte_fixture['coef'],
136+
dml_qte_fixture['coef_ext_smpls'],
137+
rtol=1e-9, atol=1e-4)
115138

116139

117140
@pytest.mark.ci
118141
def test_dml_qte_se(dml_qte_fixture):
119142
assert np.allclose(dml_qte_fixture['se'],
120143
dml_qte_fixture['se_manual'],
121144
rtol=1e-9, atol=1e-4)
145+
assert np.allclose(dml_qte_fixture['se'],
146+
dml_qte_fixture['se_ext_smpls'],
147+
rtol=1e-9, atol=1e-4)
122148

123149

124150
@pytest.mark.ci
@@ -148,8 +174,8 @@ def test_doubleml_qte_exceptions():
148174
ml_g = RandomForestClassifier(n_estimators=20)
149175
ml_m = RandomForestClassifier(n_estimators=20)
150176

151-
msg = r'Sample splitting not specified. Draw samples via .draw_sample splitting\(\). ' \
152-
'External samples not implemented yet.'
177+
msg = ('Sample splitting not specified. '
178+
r'Either draw samples via .draw_sample splitting\(\) or set external samples via .set_sample_splitting\(\).')
153179
with pytest.raises(ValueError, match=msg):
154180
dml_obj = dml.DoubleMLQTE(obj_dml_data, ml_g, ml_m, draw_sample_splitting=False)
155181
_ = dml_obj.smpls

0 commit comments

Comments
 (0)