15
15
16
16
from .utils .resampling import DoubleMLResampling , DoubleMLClusterResampling
17
17
from .utils ._estimation import _rmse , _aggregate_coefs_and_ses , _var_est , _set_external_predictions
18
- from .utils ._checks import _check_in_zero_one , _check_integer , _check_float , _check_bool , _check_is_partition , \
19
- _check_all_smpls , _check_smpl_split , _check_smpl_split_tpl , _check_benchmarks , _check_external_predictions
18
+ from .utils ._checks import _check_in_zero_one , _check_integer , _check_float , _check_bool , \
19
+ _check_benchmarks , _check_external_predictions , _check_sample_splitting
20
20
from .utils ._plots import _sensitivity_contour_plot
21
21
from .utils .gain_statistics import gain_statistics
22
22
@@ -289,11 +289,8 @@ def smpls(self):
289
289
The partition used for cross-fitting.
290
290
"""
291
291
if self ._smpls is None :
292
- if self ._is_cluster_data :
293
- err_msg = 'Sample splitting not specified. Draw samples via .draw_sample splitting().'
294
- else :
295
- err_msg = ('Sample splitting not specified. Either draw samples via .draw_sample splitting() ' +
296
- 'or set external samples via .set_sample_splitting().' )
292
+ err_msg = ('Sample splitting not specified. Either draw samples via .draw_sample splitting() ' +
293
+ 'or set external samples via .set_sample_splitting().' )
297
294
raise ValueError (err_msg )
298
295
return self ._smpls
299
296
@@ -302,9 +299,6 @@ def smpls_cluster(self):
302
299
"""
303
300
The partition of clusters used for cross-fitting.
304
301
"""
305
- if self ._is_cluster_data :
306
- if self ._smpls_cluster is None :
307
- raise ValueError ('Sample splitting not specified. Draw samples via .draw_sample splitting().' )
308
302
return self ._smpls_cluster
309
303
310
304
@property
@@ -1155,7 +1149,7 @@ def draw_sample_splitting(self):
1155
1149
1156
1150
return self
1157
1151
1158
- def set_sample_splitting (self , all_smpls ):
1152
+ def set_sample_splitting (self , all_smpls , all_smpls_cluster = None ):
1159
1153
"""
1160
1154
Set the sample splitting for DoubleML models.
1161
1155
@@ -1177,6 +1171,13 @@ def set_sample_splitting(self, all_smpls):
1177
1171
train_ind and test_ind to np.arange(n_obs), which corresponds to no sample splitting.
1178
1172
``n_folds=1`` and ``n_rep=1`` is always set.
1179
1173
1174
+ all_smpls_cluster : list or None
1175
+ Nested list or ``None``. The first level of nesting corresponds to the number of repetitions. The second level
1176
+ of nesting corresponds to the number of folds. The third level of nesting contains a tuple of training and
1177
+ testing lists. Both training and testing contain an array for each cluster variable, which form a partition of
1178
+ the clusters.
1179
+ Default is ``None``.
1180
+
1180
1181
Returns
1181
1182
-------
1182
1183
self : object
@@ -1194,8 +1195,6 @@ def set_sample_splitting(self, all_smpls):
1194
1195
>>> ml_m = learner
1195
1196
>>> obj_dml_data = make_plr_CCDDHNR2018(n_obs=10, alpha=0.5)
1196
1197
>>> dml_plr_obj = dml.DoubleMLPLR(obj_dml_data, ml_g, ml_m)
1197
- >>> # simple sample splitting with two folds and without cross-fitting
1198
- >>> smpls = ([0, 1, 2, 3, 4], [5, 6, 7, 8, 9])
1199
1198
>>> dml_plr_obj.set_sample_splitting(smpls)
1200
1199
>>> # sample splitting with two folds and cross-fitting
1201
1200
>>> smpls = [([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]),
@@ -1208,71 +1207,8 @@ def set_sample_splitting(self, all_smpls):
1208
1207
>>> ([1, 3, 5, 7, 9], [0, 2, 4, 6, 8])]]
1209
1208
>>> dml_plr_obj.set_sample_splitting(smpls)
1210
1209
"""
1211
- if self ._is_cluster_data :
1212
- raise NotImplementedError ('Externally setting the sample splitting for DoubleML is '
1213
- 'not yet implemented with clustering.' )
1214
- if isinstance (all_smpls , tuple ):
1215
- if not len (all_smpls ) == 2 :
1216
- raise ValueError ('Invalid partition provided. '
1217
- 'Tuple for train_ind and test_ind must consist of exactly two elements.' )
1218
- all_smpls = _check_smpl_split_tpl (all_smpls , self ._dml_data .n_obs )
1219
- if (_check_is_partition ([all_smpls ], self ._dml_data .n_obs ) &
1220
- _check_is_partition ([(all_smpls [1 ], all_smpls [0 ])], self ._dml_data .n_obs )):
1221
- self ._n_rep = 1
1222
- self ._n_folds = 1
1223
- self ._smpls = [[all_smpls ]]
1224
- else :
1225
- raise ValueError ('Invalid partition provided. '
1226
- 'Tuple provided that doesn\' t form a partition.' )
1227
- else :
1228
- if not isinstance (all_smpls , list ):
1229
- raise TypeError ('all_smpls must be of list or tuple type. '
1230
- f'{ str (all_smpls )} of type { str (type (all_smpls ))} was passed.' )
1231
- all_tuple = all ([isinstance (tpl , tuple ) for tpl in all_smpls ])
1232
- if all_tuple :
1233
- if not all ([len (tpl ) == 2 for tpl in all_smpls ]):
1234
- raise ValueError ('Invalid partition provided. '
1235
- 'All tuples for train_ind and test_ind must consist of exactly two elements.' )
1236
- self ._n_rep = 1
1237
- all_smpls = _check_smpl_split (all_smpls , self ._dml_data .n_obs )
1238
- if _check_is_partition (all_smpls , self ._dml_data .n_obs ):
1239
- if ((len (all_smpls ) == 1 ) &
1240
- _check_is_partition ([(all_smpls [0 ][1 ], all_smpls [0 ][0 ])], self ._dml_data .n_obs )):
1241
- self ._n_folds = 1
1242
- self ._smpls = [all_smpls ]
1243
- else :
1244
- self ._n_folds = len (all_smpls )
1245
- self ._smpls = _check_all_smpls ([all_smpls ], self ._dml_data .n_obs , check_intersect = True )
1246
- else :
1247
- raise ValueError ('Invalid partition provided. '
1248
- 'Tuples provided that don\' t form a partition.' )
1249
- else :
1250
- all_list = all ([isinstance (smpl , list ) for smpl in all_smpls ])
1251
- if not all_list :
1252
- raise ValueError ('Invalid partition provided. '
1253
- 'all_smpls is a list where neither all elements are tuples '
1254
- 'nor all elements are lists.' )
1255
- all_tuple = all ([all ([isinstance (tpl , tuple ) for tpl in smpl ]) for smpl in all_smpls ])
1256
- if not all_tuple :
1257
- raise TypeError ('For repeated sample splitting all_smpls must be list of lists of tuples.' )
1258
- all_pairs = all ([all ([len (tpl ) == 2 for tpl in smpl ]) for smpl in all_smpls ])
1259
- if not all_pairs :
1260
- raise ValueError ('Invalid partition provided. '
1261
- 'All tuples for train_ind and test_ind must consist of exactly two elements.' )
1262
- n_folds_each_smpl = np .array ([len (smpl ) for smpl in all_smpls ])
1263
- if not np .all (n_folds_each_smpl == n_folds_each_smpl [0 ]):
1264
- raise ValueError ('Invalid partition provided. '
1265
- 'Different number of folds for repeated sample splitting.' )
1266
- all_smpls = _check_all_smpls (all_smpls , self ._dml_data .n_obs )
1267
- smpls_are_partitions = [_check_is_partition (smpl , self ._dml_data .n_obs ) for smpl in all_smpls ]
1268
-
1269
- if all (smpls_are_partitions ):
1270
- self ._n_rep = len (all_smpls )
1271
- self ._n_folds = n_folds_each_smpl [0 ]
1272
- self ._smpls = _check_all_smpls (all_smpls , self ._dml_data .n_obs , check_intersect = True )
1273
- else :
1274
- raise ValueError ('Invalid partition provided. '
1275
- 'At least one inner list does not form a partition.' )
1210
+ self ._smpls , self ._smpls_cluster , self ._n_rep , self ._n_folds = _check_sample_splitting (
1211
+ all_smpls , all_smpls_cluster , self ._dml_data , self ._is_cluster_data )
1276
1212
1277
1213
self ._psi , self ._psi_deriv , self ._psi_elements , self ._var_scaling_factors , \
1278
1214
self ._coef , self ._se , self ._all_coef , self ._all_se = self ._initialize_arrays ()
0 commit comments