ehrapy.tools.x_learner
- ehrapy.tools.x_learner(edata, treatment, outcome, *, covariates, outcome_model='auto', propensity_model='logistic', cate_model='auto', clip=(0.01, 0.99), key_added=None, layer=None)
X-learner CATE estimator of Künzel et al. (2019).
First fits group-specific outcome models μ_0 and μ_1 (like the T-learner), then imputes individual treatment effects on each group’s own units:
D_1 = Y_1 − μ_0(X_1), D_0 = μ_1(X_0) − Y_0
Two CATE models τ_0 and τ_1 are fitted on these imputed effects and combined as
τ(x) = g(x) τ_0(x) + (1 − g(x)) τ_1(x)
where g is the propensity score. More efficient than the T-learner when treatment groups are imbalanced.
- Parameters:
  - edata (EHRData) – Central data object.
  - treatment (str) – Column name of the binary (0/1) treatment variable.
  - outcome (str) – Column name of the outcome variable.
  - covariates (Sequence[str]) – Adjustment set used for both the outcome and propensity models. Each entry must refer to a name in edata.var_names or edata.obs.columns.
  - outcome_model (str | BaseEstimator, default: 'auto') – Outcome model specification for the first-stage μ models (see g_computation() for the accepted values).
  - propensity_model (str | BaseEstimator, default: 'logistic') – Propensity model specification (see iptw() for the accepted values).
  - cate_model (str | BaseEstimator, default: 'auto') – Regressor used for the second-stage τ models. Accepts 'auto' / 'linear' / 'gradient_boosting' / 'random_forest' or any sklearn-compatible regressor. 'auto' resolves to linear regression. Classifiers are rejected because the imputed effects are continuous.
  - clip (tuple[float, float] | None, default: (0.01, 0.99)) – (lo, hi) propensity-score clipping range for the combination weight g. Use None to disable clipping.
  - key_added (str | None, default: None) – Optional edata.obs column name into which the per-observation CATE vector is written. Observations dropped during NaN filtering are filled with NaN.
  - layer (str | None, default: None) – Layer of edata to draw the var-side variables from. If None, edata.X is used.
- Return type: CausalEstimate
- Returns: A CausalEstimate whose value is the average CATE and whose params['cate'] is the per-observation CATE vector.
Examples
>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.mimic_2_preprocessed()
>>> est = ep.tl.x_learner(
...     edata,
...     "aline_flg",
...     "day_28_flg",
...     covariates=["age", "sofa_first", "sapsi_first"],
... )
>>> print(f"average CATE: {est.value:+.4f} (n={est.n})")
average CATE: -0.0237 (n=1776)
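The three stages described above (group-specific outcome models, imputed effects, propensity-weighted combination) can be sketched on synthetic data with NumPy alone. This is an illustrative sketch, not ehrapy's implementation: the helpers fit_linear, fit_logistic and x_learner_sketch are hypothetical names, ordinary least squares stands in for the 'auto' outcome and CATE models, a hand-rolled Newton fit stands in for the 'logistic' propensity model, and the simulated data and effect sizes are assumptions chosen only so the result can be checked.

```python
import numpy as np


def fit_linear(X, y):
    """Least-squares fit with intercept; returns a predict function (hypothetical helper)."""
    A = np.column_stack([np.ones(len(X)), X])
    beta, *_ = np.linalg.lstsq(A, y, rcond=None)
    return lambda Xn: np.column_stack([np.ones(len(Xn)), Xn]) @ beta


def fit_logistic(X, t, n_iter=25):
    """Logistic regression via Newton's method; returns x -> P(T=1 | x) (hypothetical helper)."""
    A = np.column_stack([np.ones(len(X)), X])
    w = np.zeros(A.shape[1])
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-A @ w))
        grad = A.T @ (t - p)
        # Small ridge term keeps the Hessian invertible.
        hess = (A * (p * (1 - p))[:, None]).T @ A + 1e-8 * np.eye(A.shape[1])
        w += np.linalg.solve(hess, grad)
    return lambda Xn: 1.0 / (
        1.0 + np.exp(-np.column_stack([np.ones(len(Xn)), Xn]) @ w)
    )


def x_learner_sketch(X, t, y, clip=(0.01, 0.99)):
    X1, y1 = X[t == 1], y[t == 1]
    X0, y0 = X[t == 0], y[t == 0]
    # Stage 1: group-specific outcome models mu_0 and mu_1 (as in the T-learner).
    mu0 = fit_linear(X0, y0)
    mu1 = fit_linear(X1, y1)
    # Stage 2: impute effects on each group's own units, then fit tau_1 and tau_0.
    d1 = y1 - mu0(X1)  # D_1 = Y_1 - mu_0(X_1)
    d0 = mu1(X0) - y0  # D_0 = mu_1(X_0) - Y_0
    tau1 = fit_linear(X1, d1)
    tau0 = fit_linear(X0, d0)
    # Stage 3: combine with the (clipped) propensity score g as the weight.
    g = fit_logistic(X, t)(X)
    if clip is not None:
        g = np.clip(g, *clip)
    return g * tau0(X) + (1 - g) * tau1(X)


# Synthetic data (assumed): imbalanced treatment, effect depends on the second covariate.
rng = np.random.default_rng(0)
n = 2000
X = rng.normal(size=(n, 2))
p_treat = 1.0 / (1.0 + np.exp(-(-1.5 + 0.5 * X[:, 0])))
t = (rng.random(n) < p_treat).astype(int)
true_cate = 1.0 + 0.5 * X[:, 1]
y = X[:, 0] + t * true_cate + rng.normal(scale=0.1, size=n)

cate = x_learner_sketch(X, t, y)
print(f"average CATE: {cate.mean():+.3f} (n={n})")
```

Because g weights τ_0, units that are likely to be treated lean on the model trained on the control group's imputed effects, and vice versa. When one arm is small, most units have a low propensity and therefore rely mainly on τ_1, whose imputed effects D_1 were built from the outcome model fitted on the larger control arm.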