ehrapy.tools.s_learner

ehrapy.tools.s_learner(edata, treatment, outcome, *, covariates, outcome_model='auto', key_added=None, layer=None)

Single-model (S-learner) CATE estimator.

Fits one outcome model μ(T, X) on all data, then predicts τ(X_i) = μ(1, X_i) − μ(0, X_i). Tends to regularise the treatment effect toward zero when the base learner is heavily regularised, so consider a flexible base learner if you suspect heterogeneity.
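The estimator described above can be illustrated with a minimal NumPy sketch on synthetic data. This is not ehrapy's implementation: the linear base learner with a treatment-covariate interaction, the simulated data, and the helper names `design` and `mu` are all assumptions made for illustration.

```python
import numpy as np

# Synthetic data (assumption): one covariate x, a randomised binary treatment t,
# and a treatment effect that varies with x.
rng = np.random.default_rng(0)
n = 2000
x = rng.normal(size=n)
t = rng.integers(0, 2, size=n)
true_cate = 0.5 + x
y = 1.0 + 2.0 * x + t * true_cate + rng.normal(scale=0.1, size=n)


def design(t, x):
    """Design matrix for a single outcome model mu(T, X) with a T*X interaction."""
    return np.column_stack([np.ones_like(x), t, x, t * x])


# Fit ONE outcome model on all observations (the "S" in S-learner).
beta, *_ = np.linalg.lstsq(design(t, x), y, rcond=None)


def mu(t_value, x):
    """Predict the outcome with the treatment fixed to t_value for every row."""
    t_fixed = np.full_like(x, t_value)
    return design(t_fixed, x) @ beta


# Per-observation CATE: tau(X_i) = mu(1, X_i) - mu(0, X_i).
cate = mu(1.0, x) - mu(0.0, x)

# The average CATE is the mean of the per-observation vector.
ate = cate.mean()
```

Without the `t * x` interaction column, this linear model could only produce a constant effect for every observation, which is one simple way an inflexible base learner can hide heterogeneity.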

Parameters:
  • edata (EHRData) – Central data object.

  • treatment (str) – Column name of the binary (0/1) treatment variable.

  • outcome (str) – Column name of the outcome variable.

  • covariates (Sequence[str]) – Adjustment set used for the outcome model. Each entry must refer to a name in edata.var_names or edata.obs.columns.

  • outcome_model (str | BaseEstimator, default: 'auto') – Outcome model specification (see g_computation() for the accepted values).

  • key_added (str | None, default: None) – Optional edata.obs column name into which the per-observation CATE vector is written. Observations dropped during NaN filtering are filled with NaN.

  • layer (str | None, default: None) – Layer of edata to draw the var-side variables from. If None, edata.X is used.
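The NaN fill described for key_added can be pictured with a small NumPy sketch. It assumes that rows with any missing value in the model inputs are dropped before fitting; the exact filtering rule is not stated here, and the array names and CATE values below are hypothetical.

```python
import numpy as np

# Hypothetical model inputs for four observations; rows 1 and 3 contain NaN.
features = np.array(
    [
        [50.0, 3.0],
        [np.nan, 5.0],
        [70.0, 4.0],
        [65.0, np.nan],
    ]
)

# Assumption: an observation is kept only if all of its inputs are present.
complete = ~np.isnan(features).any(axis=1)

# Hypothetical CATE values, one per retained observation.
cate_complete = np.array([-0.02, 0.01])

# Full-length vector aligned with the original observations:
# dropped rows stay NaN, retained rows receive their estimate.
cate_full = np.full(features.shape[0], np.nan)
cate_full[complete] = cate_complete
```

The resulting vector has one entry per original observation, so it can be stored alongside the other per-observation columns without reindexing.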

Return type:

CausalEstimate

Returns:

A CausalEstimate whose value is the average CATE and whose params['cate'] is the per-observation CATE vector.

Examples

>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.mimic_2_preprocessed()
>>> est = ep.tl.s_learner(
...     edata,
...     "aline_flg",
...     "day_28_flg",
...     covariates=["age", "sofa_first", "sapsi_first"],
... )
>>> print(f"average CATE: {est.value:+.4f}  (n={est.n})")
average CATE: -0.0216  (n=1776)