ehrapy.tools.t_learner

ehrapy.tools.t_learner(edata, treatment, outcome, *, covariates, outcome_model='auto', key_added=None, layer=None)

Two-model (T-learner) estimator of the conditional average treatment effect (CATE).

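Fits one outcome model on the treated subset (μ_1) and a separate one on the untreated subset (μ_0), then computes τ(X_i) = μ_1(X_i) - μ_0(X_i) for every row. The approach is simple, but statistically inefficient when the two groups are imbalanced, because each model is trained only on its own subset and the smaller group's model therefore sees little data.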
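The two-model idea can be illustrated with a minimal sketch. It assumes ordinary least squares as the outcome model for both arms; `t_learner_sketch` is a hypothetical helper, not ehrapy's implementation, which selects its outcome models via `outcome_model`. The returned average is the mean of the per-row CATE vector, mirroring how `value` relates to `params['cate']`.

```python
import numpy as np


def t_learner_sketch(X, treatment, outcome):
    """Minimal T-learner sketch: one OLS outcome model per treatment arm.

    Hypothetical helper for illustration only; it does not reproduce
    ehrapy's outcome models or its NaN handling.
    """
    X = np.asarray(X, dtype=float)
    treatment = np.asarray(treatment)
    outcome = np.asarray(outcome, dtype=float)

    # Design matrix with an intercept column.
    design = np.column_stack([np.ones(len(X)), X])
    treated = treatment == 1

    # mu_1 is fit only on treated rows, mu_0 only on untreated rows.
    beta_1, *_ = np.linalg.lstsq(design[treated], outcome[treated], rcond=None)
    beta_0, *_ = np.linalg.lstsq(design[~treated], outcome[~treated], rcond=None)

    # Evaluate both models on every row and take the difference.
    cate = design @ beta_1 - design @ beta_0
    return cate.mean(), cate


# Noise-free toy data whose true per-row effect is 2 + 0.5 * x.
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 1))
t = rng.integers(0, 2, size=200)
y = 1.0 + x[:, 0] + t * (2.0 + 0.5 * x[:, 0])
ate, cate = t_learner_sketch(x, t, y)
```

Because the toy outcome is exactly linear within each arm, both OLS fits are exact and the recovered CATE equals 2 + 0.5 * x for every row.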

Parameters:
  • edata (EHRData) – Central data object.

  • treatment (str) – Column name of the binary (0/1) treatment variable.

  • outcome (str) – Column name of the outcome variable.

  • covariates (Sequence[str]) – Adjustment set used for the outcome models. Each entry must refer to a name in edata.var_names or edata.obs.columns.

  • outcome_model (str | BaseEstimator, default: 'auto') – Outcome model specification (see g_computation() for the accepted values).

  • key_added (str | None, default: None) – Optional edata.obs column name into which the per-observation CATE vector is written. Observations dropped during NaN filtering are filled with NaN.

  • layer (str | None, default: None) – Layer of edata from which variables in edata.var_names are read. If None, edata.X is used.
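The NaN filling described for key_added can be sketched as follows. This assumes rows are dropped when any of the treatment, outcome, or covariate values is missing; `fill_dropped_with_nan` is a hypothetical helper and the CATE values are placeholders, not ehrapy output.

```python
import numpy as np


def fill_dropped_with_nan(cate_kept, keep_mask):
    """Scatter per-row CATE values back to full length, NaN for dropped rows.

    Hypothetical helper illustrating the documented key_added behaviour;
    not ehrapy's actual code.
    """
    keep_mask = np.asarray(keep_mask, dtype=bool)
    full = np.full(keep_mask.shape[0], np.nan)
    full[keep_mask] = cate_kept
    return full


# Assumption: a row is kept only if treatment, outcome and all covariates are present.
covariates = np.array([[1.0, 2.0], [np.nan, 0.5], [3.0, 1.0], [0.2, 0.1]])
outcome = np.array([0.0, 1.0, np.nan, 1.0])
treatment = np.array([1, 0, 1, 0], dtype=float)
keep = ~(np.isnan(covariates).any(axis=1) | np.isnan(outcome) | np.isnan(treatment))

cate_kept = np.array([0.3, -0.1])  # placeholder CATE values for the two kept rows
column = fill_dropped_with_nan(cate_kept, keep)  # [0.3, nan, nan, -0.1]
```

Filling with NaN keeps the written column aligned with every observation in edata.obs, so dropped rows remain visibly missing rather than silently shifted.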

Return type:

CausalEstimate

Returns:

A CausalEstimate whose value is the average CATE and whose params['cate'] is the per-observation CATE vector.

Examples

>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.mimic_2_preprocessed()
>>> est = ep.tl.t_learner(
...     edata,
...     "aline_flg",
...     "day_28_flg",
...     covariates=["age", "sofa_first", "sapsi_first"],
... )
>>> print(f"average CATE: {est.value:+.4f}  (n={est.n})")
average CATE: -0.0256  (n=1776)