ehrapy.tools.x_learner


ehrapy.tools.x_learner(edata, treatment, outcome, *, covariates, outcome_model='auto', propensity_model='logistic', cate_model='auto', clip=(0.01, 0.99), key_added=None, layer=None)

X-learner CATE estimator of Künzel et al. (2019).

First fits group-specific outcome models μ_0 and μ_1 (like the T-learner), then imputes individual treatment effects on each group’s own units:

D_1 = Y_1 − μ_0(X_1),   D_0 = μ_1(X_0) − Y_0

Two CATE models τ_0 and τ_1 are then fitted on the imputed effects D_0 and D_1, respectively, and combined as τ(x) = g(x) τ_0(x) + (1 − g(x)) τ_1(x), where g is the propensity score. The X-learner is typically more efficient than the T-learner when the treatment groups are imbalanced.
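The three stages described above can be sketched with plain scikit-learn on synthetic data. This is an illustration only, not ehrapy's implementation: the helper name x_learner_sketch, the choice of linear and logistic models, and the simulated data (a constant true effect of 2.0) are all assumptions made for the example.

```python
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression


def x_learner_sketch(X, t, y, clip=(0.01, 0.99)):
    """Hypothetical helper illustrating the X-learner; not ehrapy code."""
    X1, y1 = X[t == 1], y[t == 1]  # treated units
    X0, y0 = X[t == 0], y[t == 0]  # control units

    # Stage 1: group-specific outcome models mu_0 and mu_1.
    mu0 = LinearRegression().fit(X0, y0)
    mu1 = LinearRegression().fit(X1, y1)

    # Stage 2: impute effects on each group's own units,
    # D_1 = Y_1 - mu_0(X_1) and D_0 = mu_1(X_0) - Y_0,
    # then fit one CATE model per group.
    d1 = y1 - mu0.predict(X1)
    d0 = mu1.predict(X0) - y0
    tau1 = LinearRegression().fit(X1, d1)
    tau0 = LinearRegression().fit(X0, d0)

    # Stage 3: propensity score g(x), optionally clipped (assumed np.clip).
    g = LogisticRegression().fit(X, t).predict_proba(X)[:, 1]
    if clip is not None:
        g = np.clip(g, *clip)

    # tau(x) = g(x) * tau_0(x) + (1 - g(x)) * tau_1(x)
    return g * tau0.predict(X) + (1.0 - g) * tau1.predict(X)


# Synthetic, imbalanced data with a constant true treatment effect of 2.0.
rng = np.random.default_rng(0)
n = 2000
X = rng.normal(size=(n, 2))
p = 1.0 / (1.0 + np.exp(-(X[:, 0] - 1.5)))  # treated units are the minority
t = rng.binomial(1, p)
y = X[:, 0] + 2.0 * t + rng.normal(scale=0.1, size=n)

cate = x_learner_sketch(X, t, y)
print(f"average CATE: {cate.mean():+.3f}")
```

Because the simulated effect is constant and the outcome is linear in the covariates, the average of the returned vector should land close to 2.0. With real data the estimate depends on how well the chosen models fit each stage.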

Parameters:
  • edata (EHRData) – Central data object.

  • treatment (str) – Column name of the binary (0/1) treatment variable.

  • outcome (str) – Column name of the outcome variable.

  • covariates (Sequence[str]) – Adjustment set used for both the outcome and propensity models. Each entry must refer to a name in edata.var_names or edata.obs.columns.

  • outcome_model (str | BaseEstimator, default: 'auto') – Outcome model specification for the first-stage μ models (see g_computation() for the accepted values).

  • propensity_model (str | BaseEstimator, default: 'logistic') – Propensity model specification (see iptw() for the accepted values).

  • cate_model (str | BaseEstimator, default: 'auto') – Regressor used for the second-stage τ models. Accepts 'auto'/'linear'/'gradient_boosting'/'random_forest' or any sklearn-compatible regressor. 'auto' resolves to linear regression. Classifiers are rejected because the imputed effects are continuous.

  • clip (tuple[float, float] | None, default: (0.01, 0.99)) – (lo, hi) propensity-score clipping range for the combination weight g. Use None to disable clipping.

  • key_added (str | None, default: None) – Optional edata.obs column name into which the per-observation CATE vector is written. Observations dropped during NaN filtering are filled with NaN.

  • layer (str | None, default: None) – Layer of edata to draw the var-side variables from. If None, edata.X is used.
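To make the role of clip concrete, the following sketch shows how bounding the combination weight g changes the blended estimate for observations with extreme propensity scores. It assumes clipping is a simple bound on g (here via np.clip); the function combine and all numbers are hypothetical and chosen only for illustration.

```python
import numpy as np


def combine(g, tau0, tau1, clip=(0.01, 0.99)):
    """Hypothetical helper: blend tau_0 and tau_1 with weight g."""
    if clip is not None:
        lo, hi = clip
        g = np.clip(g, lo, hi)  # assumed clipping rule
    return g * tau0 + (1.0 - g) * tau1


# Made-up propensity scores and second-stage predictions.
g = np.array([0.001, 0.30, 0.999])
tau0 = np.array([0.10, 0.10, 0.10])
tau1 = np.array([0.50, 0.50, 0.50])

clipped = combine(g, tau0, tau1)               # default clip=(0.01, 0.99)
unclipped = combine(g, tau0, tau1, clip=None)  # clipping disabled

print(clipped)
print(unclipped)
```

Only the first and last observations differ between the two results, since their propensity scores fall outside the clipping range. The middle observation is unaffected.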

Return type:

CausalEstimate

Returns:

A CausalEstimate whose value is the average CATE and whose params['cate'] is the per-observation CATE vector.

Examples

>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.mimic_2_preprocessed()
>>> est = ep.tl.x_learner(
...     edata,
...     "aline_flg",
...     "day_28_flg",
...     covariates=["age", "sofa_first", "sapsi_first"],
... )
>>> print(f"average CATE: {est.value:+.4f}  (n={est.n})")
average CATE: -0.0237  (n=1776)