Skip to content

Commit 0851e88

Browse files
authored
Merge pull request #212 from DoubleML/o-policy-tree
Add Policy Tree class for learning policies based on IRM
2 parents 8ee9519 + 45e8292 commit 0851e88

9 files changed

+397
-4
lines changed

doubleml/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .double_ml_pq import DoubleMLPQ
1313
from .double_ml_lpq import DoubleMLLPQ
1414
from .double_ml_cvar import DoubleMLCVAR
15+
from .double_ml_policytree import DoubleMLPolicyTree
1516

1617
__all__ = ['DoubleMLPLR',
1718
'DoubleMLPLIV',
@@ -25,6 +26,7 @@
2526
'DoubleMLPQ',
2627
'DoubleMLQTE',
2728
'DoubleMLLPQ',
28-
'DoubleMLCVAR']
29+
'DoubleMLCVAR',
30+
'DoubleMLPolicyTree']
2931

3032
__version__ = get_distribution('doubleml').version

doubleml/double_ml_irm.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from .double_ml import DoubleML
88

99
from .double_ml_blp import DoubleMLBLP
10+
from .double_ml_policytree import DoubleMLPolicyTree
1011
from .double_ml_data import DoubleMLData
1112
from .double_ml_score_mixins import LinearScoreMixin
1213

1314
from ._utils import _dml_cv_predict, _get_cond_smpls, _dml_tune, _trimm, _normalize_ipw
14-
from ._utils_checks import _check_score, _check_trimming, _check_finite_predictions, _check_is_propensity
15+
from ._utils_checks import _check_score, _check_trimming, _check_finite_predictions, _check_is_propensity, _check_integer
1516

1617

1718
class DoubleMLIRM(LinearScoreMixin, DoubleML):
@@ -472,3 +473,49 @@ def gate(self, groups):
472473
model = DoubleMLBLP(orth_signal, basis=groups, is_gate=True).fit()
473474

474475
return model
476+
477+
def policy_tree(self, features, depth=2, **tree_params):
478+
"""
479+
Estimate a decision tree for optimal treatment policy by weighted classification.
480+
481+
Parameters
482+
----------
483+
depth : int
484+
The depth of the estimated decision tree.
485+
Has to be larger than 0. Deeper trees derive a more complex decision policy. Default is ``2``.
486+
487+
features : :class:`pandas.DataFrame`
488+
The covariates on which the policy tree is learned.
489+
Has to be of shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
490+
and ``d`` is the number of covariates to be included.
491+
492+
**tree_params : dict
493+
Parameters that are forwarded to the :class:`sklearn.tree.DecisionTreeClassifier`.
494+
Note that by default we perform minimal pruning by setting the ``ccp_alpha = 0.01`` and
495+
``min_samples_leaf = 8``. This can be adjusted.
496+
497+
Returns
498+
-------
499+
model : :class:`doubleML.DoubleMLPolicyTree`
500+
Policy tree model.
501+
"""
502+
valid_score = ['ATE']
503+
if self.score not in valid_score:
504+
raise ValueError('Invalid score ' + self.score + '. ' +
505+
'Valid score ' + ' or '.join(valid_score) + '.')
506+
507+
if self.n_rep != 1:
508+
raise NotImplementedError('Only implemented for one repetition. ' +
509+
f'Number of repetitions is {str(self.n_rep)}.')
510+
511+
_check_integer(depth, "Depth", 0)
512+
513+
if not isinstance(features, pd.DataFrame):
514+
raise TypeError('Covariates must be of DataFrame type. '
515+
f'Covariates of type {str(type(features))} was passed.')
516+
517+
orth_signal = self.psi_elements['psi_b'].reshape(-1)
518+
519+
model = DoubleMLPolicyTree(orth_signal, depth=depth, features=features, **tree_params).fit()
520+
521+
return model

doubleml/double_ml_policytree.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
from sklearn.tree import DecisionTreeClassifier, plot_tree
5+
from sklearn.utils.validation import check_is_fitted
6+
7+
8+
class DoubleMLPolicyTree:
9+
"""Policy Tree fitting for DoubleML.
10+
Currently avaivable for IRM models.
11+
12+
Parameters
13+
----------
14+
orth_signal : :class:`numpy.array`
15+
The orthogonal signal to be predicted. Has to be of shape ``(n_obs,)``,
16+
where ``n_obs`` is the number of observations.
17+
18+
features : :class:`pandas.DataFrame`
19+
The covariates for estimating the policy tree. Has to have the shape ``(n_obs, d)``,
20+
where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
21+
22+
depth : int
23+
The depth of the policy tree that will be built. Default is ``2``.
24+
25+
**tree_params : dict
26+
Parameters that are forwarded to the :class:`sklearn.tree.DecisionTreeClassifier`.
27+
Note that by default we perform minimal pruning by setting the ``ccp_alpha = 0.01`` and
28+
``min_samples_leaf = 8``. This can be adjusted.
29+
30+
"""
31+
32+
def __init__(self,
33+
orth_signal,
34+
features,
35+
depth=2,
36+
**tree_params):
37+
38+
if not isinstance(orth_signal, np.ndarray):
39+
raise TypeError('The signal must be of np.ndarray type. '
40+
f'Signal of type {str(type(orth_signal))} was passed.')
41+
42+
if orth_signal.ndim != 1:
43+
raise ValueError('The signal must be of one dimensional. '
44+
f'Signal of dimensions {str(orth_signal.ndim)} was passed.')
45+
46+
if not isinstance(features, pd.DataFrame):
47+
raise TypeError('The features must be of DataFrame type. '
48+
f'Features of type {str(type(features))} was passed.')
49+
50+
if not features.columns.is_unique:
51+
raise ValueError('Invalid pd.DataFrame: '
52+
'Contains duplicate column names.')
53+
54+
self._orth_signal = orth_signal
55+
self._features = features
56+
self._depth = depth
57+
self._tree_params = tree_params
58+
59+
self._tree_params.setdefault("ccp_alpha", .01)
60+
self._tree_params.setdefault("min_samples_leaf", 8)
61+
62+
# initialize tree
63+
self._policy_tree = DecisionTreeClassifier(max_depth=self._depth,
64+
**self._tree_params)
65+
66+
def __str__(self):
67+
class_name = self.__class__.__name__
68+
header = f'================== {class_name} Object ==================\n'
69+
fit_summary = str(self.summary)
70+
res = header + \
71+
'\n------------------ Summary ------------------\n' + fit_summary
72+
return res
73+
74+
@property
75+
def policy_tree(self):
76+
"""
77+
Policy tree model.
78+
"""
79+
return self._policy_tree
80+
81+
@property
82+
def orth_signal(self):
83+
"""
84+
Orthogonal signal.
85+
"""
86+
return self._orth_signal
87+
88+
@property
89+
def features(self):
90+
"""
91+
Covariates.
92+
"""
93+
return self._features
94+
95+
@property
96+
def summary(self):
97+
"""
98+
A summary for the policy tree.
99+
"""
100+
summary = pd.DataFrame({"Decision Variables": self._features.keys(), "Max Depth": self._depth})
101+
return summary
102+
103+
def fit(self):
104+
"""
105+
Estimate DoubleMLPolicyTree models.
106+
107+
Returns
108+
-------
109+
self : object
110+
"""
111+
bin_signal = (np.sign(self._orth_signal) + 1) / 2
112+
abs_signal = np.abs(self._orth_signal)
113+
114+
# fit the tree with target binary score, sample weights absolute score and
115+
# provided feature variables
116+
self._policy_tree.fit(X=self._features, y=bin_signal,
117+
sample_weight=abs_signal)
118+
119+
return self
120+
121+
def plot_tree(self):
122+
"""
123+
Plots the DoubleMLPolicyTree.
124+
125+
Returns
126+
-------
127+
self : object
128+
"""
129+
check_is_fitted(self._policy_tree, msg='Policy Tree not yet fitted. Call fit before plot_tree.')
130+
131+
artists = plot_tree(self.policy_tree, feature_names=list(self._features.keys()), filled=True,
132+
class_names=["No Treatment", "Treatment"], impurity=False)
133+
return artists
134+
135+
def predict(self, features):
136+
"""
137+
Predicts policy based on the DoubleMLPolicyTree.
138+
139+
Parameters
140+
----------
141+
features : :class:`pandas.DataFrame`
142+
The covariates for predicting based on the policy tree. Has to have the shape ``(n_obs, d)``,
143+
where ``n_obs`` is the number of observations and ``d`` is the number of predictors. Has to
144+
have the identical keys as the original covariates.
145+
146+
Returns
147+
-------
148+
self : object
149+
"""
150+
check_is_fitted(self._policy_tree, msg='Policy Tree not yet fitted. Call fit before predict.')
151+
152+
if not isinstance(features, pd.DataFrame):
153+
raise TypeError('The features must be of DataFrame type. '
154+
f'Features of type {str(type(features))} was passed.')
155+
156+
if not set(features.keys()) == set(self._features.keys()):
157+
raise KeyError(f'The features must have the keys {self._features.keys()}. '
158+
f'Features with keys {features.keys()} were passed.')
159+
160+
predictions = self.policy_tree.predict(features)
161+
162+
return features.assign(pred_treatment=predictions.astype(int))

doubleml/tests/_utils_pt_manual.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import numpy as np
2+
from sklearn.tree import DecisionTreeClassifier
3+
4+
5+
def fit_policytree(orth_signal, features, depth):
6+
policytree_model = DecisionTreeClassifier(max_depth=depth,
7+
ccp_alpha=.01,
8+
min_samples_leaf=8).fit(X=features,
9+
y=(np.sign(orth_signal) + 1) / 2,
10+
sample_weight=np.abs(orth_signal))
11+
12+
return policytree_model

doubleml/tests/test_doubleml_exceptions.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,3 +1356,49 @@ def eval_fct(y_pred, y_true):
13561356
return np.nan
13571357
with pytest.raises(ValueError, match=msg):
13581358
dml_irm_obj.evaluate_learners(metric=eval_fct)
1359+
1360+
1361+
@pytest.mark.ci
1362+
def test_doubleml_exception_policytree():
1363+
dml_irm_obj = DoubleMLIRM(dml_data_irm,
1364+
ml_g=Lasso(),
1365+
ml_m=LogisticRegression(),
1366+
trimming_threshold=0.05,
1367+
n_folds=5)
1368+
dml_irm_obj.fit()
1369+
1370+
msg = "Covariates must be of DataFrame type. Covariates of type <class 'int'> was passed."
1371+
with pytest.raises(TypeError, match=msg):
1372+
dml_irm_obj.policy_tree(features=2)
1373+
msg = "Depth must be larger or equal to 0. -1 was passed."
1374+
with pytest.raises(ValueError, match=msg):
1375+
dml_irm_obj.policy_tree(features=pd.DataFrame(np.random.normal(0, 1, size=(dml_data_irm.n_obs, 3))),
1376+
depth=-1)
1377+
msg = "Depth must be an integer. 0.1 of type <class 'float'> was passed."
1378+
with pytest.raises(TypeError, match=msg):
1379+
dml_irm_obj.policy_tree(features=pd.DataFrame(np.random.normal(0, 1, size=(dml_data_irm.n_obs, 3))),
1380+
depth=.1)
1381+
1382+
dml_irm_obj = DoubleMLIRM(dml_data_irm,
1383+
ml_g=Lasso(),
1384+
ml_m=LogisticRegression(),
1385+
trimming_threshold=0.05,
1386+
n_folds=5,
1387+
score='ATTE')
1388+
dml_irm_obj.fit()
1389+
1390+
msg = 'Invalid score ATTE. Valid score ATE.'
1391+
with pytest.raises(ValueError, match=msg):
1392+
dml_irm_obj.policy_tree(features=2, depth=1)
1393+
1394+
dml_irm_obj = DoubleMLIRM(dml_data_irm,
1395+
ml_g=Lasso(),
1396+
ml_m=LogisticRegression(),
1397+
trimming_threshold=0.05,
1398+
n_folds=5,
1399+
score='ATE',
1400+
n_rep=2)
1401+
dml_irm_obj.fit()
1402+
msg = 'Only implemented for one repetition. Number of repetitions is 2.'
1403+
with pytest.raises(NotImplementedError, match=msg):
1404+
dml_irm_obj.policy_tree(features=2, depth=1)

doubleml/tests/test_doubleml_model_defaults.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
dml_lpq.bootstrap()
5555
dml_qte.bootstrap()
5656

57+
policy_tree = dml_irm.policy_tree(features=dml_data_irm.data.drop(columns=["y", "d"]))
58+
5759

5860
def _assert_resampling_default_settings(dml_obj):
5961
assert dml_obj.n_folds == 5
@@ -188,3 +190,10 @@ def test_sensitivity_defaults():
188190

189191
dml_plr.sensitivity_analysis()
190192
assert dml_plr._sensitivity_params['input'] == input_dict
193+
194+
195+
@pytest.mark.ci
196+
def test_policytree_defaults():
197+
assert policy_tree.policy_tree.max_depth == 2
198+
assert policy_tree.policy_tree.min_samples_leaf == 8
199+
assert policy_tree.policy_tree.ccp_alpha == 0.01

doubleml/tests/test_doubleml_return_types.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import plotly
55

66
from doubleml import DoubleMLPLR, DoubleMLIRM, DoubleMLIIVM, DoubleMLPLIV, DoubleMLData, DoubleMLClusterData, \
7-
DoubleMLCVAR, DoubleMLPQ, DoubleMLLPQ, DoubleMLDID, DoubleMLDIDCS
7+
DoubleMLCVAR, DoubleMLPQ, DoubleMLLPQ, DoubleMLDID, DoubleMLDIDCS, DoubleMLPolicyTree
88
from doubleml.datasets import make_plr_CCDDHNR2018, make_irm_data, make_pliv_CHS2015, make_iivm_data,\
99
make_pliv_multiway_cluster_CKMS2021, make_did_SZ2020
1010

@@ -395,3 +395,13 @@ def test_sensitivity():
395395
assert isinstance(did_cs_dml1._calc_robustness_value(null_hypothesis=0.0, level=0.95, rho=1.0, idx_treatment=0), tuple)
396396
did_cs_benchmark = did_cs_dml1.sensitivity_benchmark(benchmarking_set=['Z1'])
397397
assert isinstance(did_cs_benchmark, pd.DataFrame)
398+
399+
400+
@pytest.mark.ci
401+
def test_policytree():
402+
features = dml_data_irm.data.drop(columns=["y", "d"])
403+
policy_tree = dml_irm.policy_tree(features, depth=1)
404+
assert isinstance(policy_tree, DoubleMLPolicyTree)
405+
assert isinstance(policy_tree.plot_tree(), list)
406+
predict_features = pd.DataFrame(np.random.normal(size=(5, 20)), columns=features.keys())
407+
assert isinstance(policy_tree.predict(predict_features), pd.DataFrame)

0 commit comments

Comments
 (0)