Commit 3954711

Merge branch 'master' of github.com:DoubleML/doubleml-for-py into 0.2.X

2 parents 17221d0 + ce44e84

10 files changed: 449 additions, 50 deletions

README.md

Lines changed: 11 additions & 9 deletions
````diff
@@ -90,19 +90,21 @@ Detailed [installation instructions](https://docs.doubleml.org/stable/intro/inst
 
 If you use the DoubleML package a citation is highly appreciated:
 
-Bach, P., Chernozhukov, V., Kurz, M. S., and Spindler, M. (2020),
-DoubleML - Double Machine Learning in Python.
-URL: [https://github.com/DoubleML/doubleml-for-py](https://github.com/DoubleML/doubleml-for-py),
-Python-Package version 0.2.0.
+Bach, P., Chernozhukov, V., Kurz, M. S., and Spindler, M. (2021), DoubleML - An
+Object-Oriented Implementation of Double Machine Learning in Python,
+arXiv:[2104.03220](https://arxiv.org/abs/2104.03220).
 
 Bibtex-entry:
 
 ```
-@Manual{DoubleML2020,
-  title = {DoubleML - Double Machine Learning in Python},
-  author = {Bach, P., Chernozhukov, V., Kurz, M. S., and Spindler, M.},
-  year = {2020},
-  note = {URL: \url{https://github.com/DoubleML/doubleml-for-py}, Python-Package version 0.2.0}
+@misc{DoubleML2021,
+  title={{DoubleML} -- {A}n Object-Oriented Implementation of Double Machine Learning in {P}ython},
+  author={Philipp Bach and Victor Chernozhukov and Malte S. Kurz and Martin Spindler},
+  year={2021},
+  eprint={2104.03220},
+  archivePrefix={arXiv},
+  primaryClass={stat.ML},
+  note={arXiv:\href{https://arxiv.org/abs/2104.03220}{2104.03220} [stat.ML]}
 }
 ```
````

doubleml/double_ml.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -869,7 +869,9 @@ def set_ml_nuisance_params(self, learner, treat_var, params):
             raise ValueError('Invalid treatment variable ' + treat_var + '. ' +
                              'Valid treatment variable ' + ' or '.join(self._dml_data.d_cols) + '.')
 
-        if isinstance(params, dict):
+        if params is None:
+            all_params = [None] * self.n_rep
+        elif isinstance(params, dict):
            if self.apply_cross_fitting:
                all_params = [[params] * self.n_folds] * self.n_rep
            else:
```
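
The new `params is None` branch lets externally set nuisance hyperparameters be reset to the learner defaults. A minimal sketch of the resulting call pattern; the PLR model, the `make_plr_CCDDHNR2018` data generator and the names `'ml_g'`/`'d'` follow the package conventions of this release and are illustrative, not part of this diff:

```python
# Sketch: set tuned hyperparameters for a nuisance part, then reset them via params=None.
from sklearn.linear_model import LinearRegression
from doubleml import DoubleMLPLR
from doubleml.datasets import make_plr_CCDDHNR2018

dml_data = make_plr_CCDDHNR2018(n_obs=500)  # simulated PLR data with treatment column 'd'
dml_plr = DoubleMLPLR(dml_data, ml_g=LinearRegression(), ml_m=LinearRegression())

# externally chosen hyperparameters for nuisance part ml_g and treatment variable 'd'
dml_plr.set_ml_nuisance_params('ml_g', 'd', {'fit_intercept': False})
# with this commit, passing None resets the nuisance part to the learner's defaults
dml_plr.set_ml_nuisance_params('ml_g', 'd', None)
dml_plr.fit()
```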

doubleml/double_ml_data.py

Lines changed: 53 additions & 2 deletions
```diff
@@ -56,12 +56,19 @@ def __init__(self,
                  x_cols=None,
                  z_cols=None,
                  use_other_treat_as_covariate=True):
+        if not isinstance(data, pd.DataFrame):
+            raise TypeError('data must be of pd.DataFrame type. '
+                            f'{str(data)} of type {str(type(data))} was passed.')
+        if not data.columns.is_unique:
+            raise ValueError('Invalid pd.DataFrame: '
+                             'Contains duplicate column names.')
         self._data = data
 
         self.y_col = y_col
         self.d_cols = d_cols
         self.z_cols = z_cols
         self.x_cols = x_cols
+        self._check_disjoint_sets()
         self.use_other_treat_as_covariate = use_other_treat_as_covariate
         self._binary_treats = self._check_binary_treats()
         self._set_y_z()
@@ -245,6 +252,9 @@ def x_cols(self, value):
             if not isinstance(value, list):
                 raise TypeError('The covariates x_cols must be of str or list type (or None). '
                                 f'{str(value)} of type {str(type(value))} was passed.')
+            if not len(set(value)) == len(value):
+                raise ValueError('Invalid covariates x_cols: '
+                                 'Contains duplicate values.')
             if not set(value).issubset(set(self.all_variables)):
                 raise ValueError('Invalid covariates x_cols. '
                                  'At least one covariate is no data column.')
@@ -253,13 +263,14 @@ def x_cols(self, value):
         else:
             # x_cols defaults to all columns but y_col, d_cols and z_cols
             if self.z_cols is not None:
-                y_d_z = set.union(set(self.y_col), set(self.d_cols), set(self.z_cols))
+                y_d_z = set.union({self.y_col}, set(self.d_cols), set(self.z_cols))
                 x_cols = [col for col in self.data.columns if col not in y_d_z]
             else:
-                y_d = set.union(set(self.y_col), set(self.d_cols))
+                y_d = set.union({self.y_col}, set(self.d_cols))
                 x_cols = [col for col in self.data.columns if col not in y_d]
         self._x_cols = x_cols
         if reset_value:
+            self._check_disjoint_sets()
             # by default, we initialize to the first treatment variable
             self.set_x_d(self.d_cols[0])
 
@@ -278,11 +289,15 @@ def d_cols(self, value):
         if not isinstance(value, list):
             raise TypeError('The treatment variable(s) d_cols must be of str or list type. '
                             f'{str(value)} of type {str(type(value))} was passed.')
+        if not len(set(value)) == len(value):
+            raise ValueError('Invalid treatment variable(s) d_cols: '
+                             'Contains duplicate values.')
         if not set(value).issubset(set(self.all_variables)):
             raise ValueError('Invalid treatment variable(s) d_cols. '
                              'At least one treatment variable is no data column.')
         self._d_cols = value
         if reset_value:
+            self._check_disjoint_sets()
             # by default, we initialize to the first treatment variable
             self.set_x_d(self.d_cols[0])
 
@@ -304,6 +319,7 @@ def y_col(self, value):
                              f'{value} is no data column.')
         self._y_col = value
         if reset_value:
+            self._check_disjoint_sets()
             self._set_y_z()
 
     @property
@@ -322,13 +338,17 @@ def z_cols(self, value):
             if not isinstance(value, list):
                 raise TypeError('The instrumental variable(s) z_cols must be of str or list type (or None). '
                                 f'{str(value)} of type {str(type(value))} was passed.')
+            if not len(set(value)) == len(value):
+                raise ValueError('Invalid instrumental variable(s) z_cols: '
+                                 'Contains duplicate values.')
             if not set(value).issubset(set(self.all_variables)):
                 raise ValueError('Invalid instrumental variable(s) z_cols. '
                                  'At least one instrumental variable is no data column.')
             self._z_cols = value
         else:
             self._z_cols = None
         if reset_value:
+            self._check_disjoint_sets()
             self._set_y_z()
 
     @property
@@ -368,6 +388,8 @@ def set_x_d(self, treatment_var):
             raise ValueError('Invalid treatment_var. '
                              f'{treatment_var} is not in d_cols.')
         if self.use_other_treat_as_covariate:
+            # note that the following line needs to be adapted in case an intersection of x_cols and d_cols is allowed
+            # (see https://github.com/DoubleML/doubleml-for-py/issues/83)
             xd_list = self.x_cols + self.d_cols
             xd_list.remove(treatment_var)
         else:
@@ -383,3 +405,32 @@ def _check_binary_treats(self):
             zero_one_treat = np.all((np.power(this_d, 2) - this_d) == 0)
             is_binary[treatment_var] = (binary_treat & zero_one_treat)
         return is_binary
+
+    def _check_disjoint_sets(self):
+        y_col_set = {self.y_col}
+        x_cols_set = set(self.x_cols)
+        d_cols_set = set(self.d_cols)
+
+        if not y_col_set.isdisjoint(x_cols_set):
+            raise ValueError(f'{str(self.y_col)} cannot be set as outcome variable ``y_col`` and covariate in '
+                             '``x_cols``.')
+        if not y_col_set.isdisjoint(d_cols_set):
+            raise ValueError(f'{str(self.y_col)} cannot be set as outcome variable ``y_col`` and treatment variable in '
+                             '``d_cols``.')
+        # note that the line xd_list = self.x_cols + self.d_cols in method set_x_d needs adaption if an intersection of
+        # x_cols and d_cols is allowed (see https://github.com/DoubleML/doubleml-for-py/issues/83)
+        if not d_cols_set.isdisjoint(x_cols_set):
+            raise ValueError('At least one variable/column is set as treatment variable (``d_cols``) and as covariate '
+                             '(``x_cols``). Consider using parameter ``use_other_treat_as_covariate``.')
+
+        if self.z_cols is not None:
+            z_cols_set = set(self.z_cols)
+            if not y_col_set.isdisjoint(z_cols_set):
+                raise ValueError(f'{str(self.y_col)} cannot be set as outcome variable ``y_col`` and instrumental '
+                                 'variable in ``z_cols``.')
+            if not d_cols_set.isdisjoint(z_cols_set):
+                raise ValueError('At least one variable/column is set as treatment variable (``d_cols``) and '
+                                 'instrumental variable in ``z_cols``.')
+            if not x_cols_set.isdisjoint(z_cols_set):
+                raise ValueError('At least one variable/column is set as covariate (``x_cols``) and instrumental '
+                                 'variable in ``z_cols``.')
```
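
These checks make invalid column assignments fail fast at construction time. A short sketch of what now raises; the column names and data are made up for illustration:

```python
# Sketch: the new DoubleMLData input validation at construction time.
import numpy as np
import pandas as pd
from doubleml import DoubleMLData

df = pd.DataFrame(np.random.normal(size=(100, 4)), columns=['y', 'd', 'x1', 'x2'])

# non-DataFrame input now raises a TypeError
try:
    DoubleMLData(df.to_numpy(), y_col='y', d_cols='d')
except TypeError as err:
    print(err)

# overlapping variable sets now raise a ValueError via _check_disjoint_sets()
try:
    DoubleMLData(df, y_col='y', d_cols='d', x_cols=['y', 'x1', 'x2'])
except ValueError as err:
    print(err)
```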

doubleml/double_ml_iivm.py

Lines changed: 53 additions & 12 deletions
```diff
@@ -40,6 +40,12 @@ class DoubleMLIIVM(DoubleML):
         ``psi_a, psi_b = score(y, z, d, g_hat0, g_hat1, m_hat, r_hat0, r_hat1, smpls)``.
         Default is ``'LATE'``.
 
+    subgroups : dict or None
+        Dictionary with options to adapt to cases with and without the subgroups of always-takers and never-takers.
+        The logical item ``always_takers`` specifies whether there are always-takers in the sample. The logical
+        item ``never_takers`` specifies whether there are never-takers in the sample.
+        Default is ``{'always_takers': True, 'never_takers': True}``.
+
     dml_procedure : str
         A str (``'dml1'`` or ``'dml2'``) specifying the double machine learning algorithm.
         Default is ``'dml2'``.
@@ -115,6 +121,7 @@ def __init__(self,
                  n_folds=5,
                  n_rep=1,
                  score='LATE',
+                 subgroups=None,
                  dml_procedure='dml2',
                  trimming_rule='truncate',
                  trimming_threshold=1e-12,
@@ -138,6 +145,25 @@
         if trimming_rule not in valid_trimming_rule:
             raise ValueError('Invalid trimming_rule ' + trimming_rule + '. ' +
                              'Valid trimming_rule ' + ' or '.join(valid_trimming_rule) + '.')
+
+        if subgroups is None:
+            # this is the default for subgroups; via None to prevent a mutable default argument
+            subgroups = {'always_takers': True, 'never_takers': True}
+        else:
+            if not isinstance(subgroups, dict):
+                raise TypeError('Invalid subgroups ' + str(subgroups) + '. ' +
+                                'subgroups must be of type dictionary.')
+            if (not all(k in subgroups for k in ['always_takers', 'never_takers'])) \
+                    | (not all(k in ['always_takers', 'never_takers'] for k in subgroups)):
+                raise ValueError('Invalid subgroups ' + str(subgroups) + '. ' +
+                                 'subgroups must be a dictionary with keys always_takers and never_takers.')
+            if not isinstance(subgroups['always_takers'], bool):
+                raise TypeError("subgroups['always_takers'] must be True or False. "
+                                f'Got {str(subgroups["always_takers"])}.')
+            if not isinstance(subgroups['never_takers'], bool):
+                raise TypeError("subgroups['never_takers'] must be True or False. "
+                                f'Got {str(subgroups["never_takers"])}.')
+        self.subgroups = subgroups
         self.trimming_rule = trimming_rule
         self.trimming_threshold = trimming_threshold
 
@@ -196,10 +222,16 @@ def _ml_nuisance_and_score_elements(self, smpls, n_jobs_cv):
                                  est_params=self._get_params('ml_m'), method=self._predict_method['ml_m'])
 
         # nuisance r
-        r_hat0 = _dml_cv_predict(self._learner['ml_r'], x, d, smpls=smpls_z0, n_jobs=n_jobs_cv,
-                                 est_params=self._get_params('ml_r0'), method=self._predict_method['ml_r'])
-        r_hat1 = _dml_cv_predict(self._learner['ml_r'], x, d, smpls=smpls_z1, n_jobs=n_jobs_cv,
-                                 est_params=self._get_params('ml_r1'), method=self._predict_method['ml_r'])
+        if self.subgroups['always_takers']:
+            r_hat0 = _dml_cv_predict(self._learner['ml_r'], x, d, smpls=smpls_z0, n_jobs=n_jobs_cv,
+                                     est_params=self._get_params('ml_r0'), method=self._predict_method['ml_r'])
+        else:
+            r_hat0 = np.zeros_like(d)
+        if self.subgroups['never_takers']:
+            r_hat1 = _dml_cv_predict(self._learner['ml_r'], x, d, smpls=smpls_z1, n_jobs=n_jobs_cv,
+                                     est_params=self._get_params('ml_r1'), method=self._predict_method['ml_r'])
+        else:
+            r_hat1 = np.ones_like(d)
 
         psi_a, psi_b = self._score_elements(y, z, d, g_hat0, g_hat1, m_hat, r_hat0, r_hat1, smpls)
         preds = {'ml_g0': g_hat0,
@@ -262,18 +294,27 @@ def _ml_nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune,
         m_tune_res = _dml_tune(z, x, train_inds,
                                self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
                                n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
-        r0_tune_res = _dml_tune(d, x, train_inds_z0,
-                                self._learner['ml_r'], param_grids['ml_r'], scoring_methods['ml_r'],
-                                n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
-        r1_tune_res = _dml_tune(d, x, train_inds_z1,
-                                self._learner['ml_r'], param_grids['ml_r'], scoring_methods['ml_r'],
-                                n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
+
+        if self.subgroups['always_takers']:
+            r0_tune_res = _dml_tune(d, x, train_inds_z0,
+                                    self._learner['ml_r'], param_grids['ml_r'], scoring_methods['ml_r'],
+                                    n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
+            r0_best_params = [xx.best_params_ for xx in r0_tune_res]
+        else:
+            r0_tune_res = None
+            r0_best_params = [None] * len(smpls)
+        if self.subgroups['never_takers']:
+            r1_tune_res = _dml_tune(d, x, train_inds_z1,
+                                    self._learner['ml_r'], param_grids['ml_r'], scoring_methods['ml_r'],
+                                    n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
+            r1_best_params = [xx.best_params_ for xx in r1_tune_res]
+        else:
+            r1_tune_res = None
+            r1_best_params = [None] * len(smpls)
 
         g0_best_params = [xx.best_params_ for xx in g0_tune_res]
         g1_best_params = [xx.best_params_ for xx in g1_tune_res]
         m_best_params = [xx.best_params_ for xx in m_tune_res]
-        r0_best_params = [xx.best_params_ for xx in r0_tune_res]
-        r1_best_params = [xx.best_params_ for xx in r1_tune_res]
 
         params = {'ml_g0': g0_best_params,
                   'ml_g1': g1_best_params,
```
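
With the new `subgroups` option, estimation of the nuisance part `r` can be switched off for subgroups that are absent by design, e.g. under one-sided non-compliance. A hedged usage sketch; the `make_iivm_data` generator and the random-forest learners are illustrative choices, not part of this commit:

```python
# Sketch: DoubleMLIIVM in a setting assumed to have no always-takers.
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from doubleml import DoubleMLIIVM
from doubleml.datasets import make_iivm_data

dml_data = make_iivm_data(n_obs=1000)
dml_iivm = DoubleMLIIVM(dml_data,
                        ml_g=RandomForestRegressor(),
                        ml_m=RandomForestClassifier(),
                        ml_r=RandomForestClassifier(),
                        subgroups={'always_takers': False, 'never_takers': True})
# r_hat0 is set to zero instead of being estimated; r_hat1 is still cross-fitted
dml_iivm.fit()
print(dml_iivm.summary)
```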

doubleml/tests/_utils_iivm_manual.py

Lines changed: 34 additions & 21 deletions
```diff
@@ -7,7 +7,7 @@
 
 def fit_nuisance_iivm(y, x, d, z, learner_m, learner_g, learner_r, smpls,
                       g0_params=None, g1_params=None, m_params=None, r0_params=None, r1_params=None,
-                      trimming_threshold=1e-12):
+                      trimming_threshold=1e-12, always_takers=True, never_takers=True):
     ml_g0 = clone(learner_g)
     g_hat0 = []
     for idx, (train_index, test_index) in enumerate(smpls):
@@ -41,21 +41,28 @@ def fit_nuisance_iivm(y, x, d, z, learner_m, learner_g, learner_r, smpls,
         if r0_params is not None:
             ml_r0.set_params(**r0_params[idx])
         train_index0 = np.intersect1d(np.where(z == 0)[0], train_index)
-        r_hat0.append(ml_r0.fit(x[train_index0], d[train_index0]).predict_proba(x[test_index])[:, 1])
+        if always_takers:
+            r_hat0.append(ml_r0.fit(x[train_index0], d[train_index0]).predict_proba(x[test_index])[:, 1])
+        else:
+            r_hat0.append(np.zeros_like(d[test_index]))
 
     ml_r1 = clone(learner_r)
     r_hat1 = []
     for idx, (train_index, test_index) in enumerate(smpls):
         if r1_params is not None:
             ml_r1.set_params(**r1_params[idx])
         train_index1 = np.intersect1d(np.where(z == 1)[0], train_index)
-        r_hat1.append(ml_r1.fit(x[train_index1], d[train_index1]).predict_proba(x[test_index])[:, 1])
+        if never_takers:
+            r_hat1.append(ml_r1.fit(x[train_index1], d[train_index1]).predict_proba(x[test_index])[:, 1])
+        else:
+            r_hat1.append(np.ones_like(d[test_index]))
 
     return g_hat0, g_hat1, m_hat, r_hat0, r_hat1
 
 
 def tune_nuisance_iivm(y, x, d, z, ml_m, ml_g, ml_r, smpls, n_folds_tune,
-                       param_grid_g, param_grid_m, param_grid_r):
+                       param_grid_g, param_grid_m, param_grid_r,
+                       always_takers=True, never_takers=True):
     g0_tune_res = [None] * len(smpls)
     for idx, (train_index, _) in enumerate(smpls):
         g0_tune_resampling = KFold(n_splits=n_folds_tune, shuffle=True)
@@ -79,27 +86,33 @@ def tune_nuisance_iivm(y, x, d, z, ml_m, ml_g, ml_r, smpls, n_folds_tune,
                                       cv=m_tune_resampling)
         m_tune_res[idx] = m_grid_search.fit(x[train_index, :], z[train_index])
 
-    r0_tune_res = [None] * len(smpls)
-    for idx, (train_index, _) in enumerate(smpls):
-        r0_tune_resampling = KFold(n_splits=n_folds_tune, shuffle=True)
-        r0_grid_search = GridSearchCV(ml_r, param_grid_r,
-                                      cv=r0_tune_resampling)
-        train_index0 = np.intersect1d(np.where(z == 0)[0], train_index)
-        r0_tune_res[idx] = r0_grid_search.fit(x[train_index0, :], d[train_index0])
-
-    r1_tune_res = [None] * len(smpls)
-    for idx, (train_index, _) in enumerate(smpls):
-        r1_tune_resampling = KFold(n_splits=n_folds_tune, shuffle=True)
-        r1_grid_search = GridSearchCV(ml_r, param_grid_r,
-                                      cv=r1_tune_resampling)
-        train_index1 = np.intersect1d(np.where(z == 1)[0], train_index)
-        r1_tune_res[idx] = r1_grid_search.fit(x[train_index1, :], d[train_index1])
+    if always_takers:
+        r0_tune_res = [None] * len(smpls)
+        for idx, (train_index, _) in enumerate(smpls):
+            r0_tune_resampling = KFold(n_splits=n_folds_tune, shuffle=True)
+            r0_grid_search = GridSearchCV(ml_r, param_grid_r,
+                                          cv=r0_tune_resampling)
+            train_index0 = np.intersect1d(np.where(z == 0)[0], train_index)
+            r0_tune_res[idx] = r0_grid_search.fit(x[train_index0, :], d[train_index0])
+        r0_best_params = [xx.best_params_ for xx in r0_tune_res]
+    else:
+        r0_best_params = None
+
+    if never_takers:
+        r1_tune_res = [None] * len(smpls)
+        for idx, (train_index, _) in enumerate(smpls):
+            r1_tune_resampling = KFold(n_splits=n_folds_tune, shuffle=True)
+            r1_grid_search = GridSearchCV(ml_r, param_grid_r,
+                                          cv=r1_tune_resampling)
+            train_index1 = np.intersect1d(np.where(z == 1)[0], train_index)
+            r1_tune_res[idx] = r1_grid_search.fit(x[train_index1, :], d[train_index1])
+        r1_best_params = [xx.best_params_ for xx in r1_tune_res]
+    else:
+        r1_best_params = None
 
     g0_best_params = [xx.best_params_ for xx in g0_tune_res]
     g1_best_params = [xx.best_params_ for xx in g1_tune_res]
     m_best_params = [xx.best_params_ for xx in m_tune_res]
-    r0_best_params = [xx.best_params_ for xx in r0_tune_res]
-    r1_best_params = [xx.best_params_ for xx in r1_tune_res]
 
     return g0_best_params, g1_best_params, m_best_params, r0_best_params, r1_best_params
```