ehrapy.tools.s_learner
- ehrapy.tools.s_learner(edata, treatment, outcome, *, covariates, outcome_model='auto', key_added=None, layer=None)
Single-model (S-learner) CATE estimator.
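Fits a single outcome model μ(T, X) on all data, with the treatment T used as an input alongside the covariates X, then predicts τ(X_i) = μ(1, X_i) − μ(0, X_i). Because the treatment is just another input feature, a heavily regularised base learner can shrink its contribution and bias the estimated effect toward zero; consider a flexible base learner if you suspect heterogeneous treatment effects.
- Parameters: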
  - edata (EHRData) – Central data object.
  - treatment (str) – Column name of the binary (0/1) treatment variable.
  - outcome (str) – Column name of the outcome variable.
  - covariates (Sequence[str]) – Adjustment set used for the outcome model. Each entry must refer to a name in edata.var_names or edata.obs.columns.
  - outcome_model (str | BaseEstimator, default: 'auto') – Outcome model specification (see g_computation() for the accepted values).
  - key_added (str | None, default: None) – Optional edata.obs column name into which the per-observation CATE vector is written. Observations dropped during NaN filtering are filled with NaN.
  - layer (str | None, default: None) – Layer of edata to draw the var-side variables from. If None, edata.X is used.
- Return type:
  CausalEstimate
- Returns:
  A CausalEstimate whose value is the average CATE and whose params['cate'] is the per-observation CATE vector.
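The fit-then-contrast procedure in the summary, together with the averaging described under Returns, can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not ehrapy's implementation: the base learner is a hypothetical ordinary-least-squares model with treatment × covariate interaction terms (ehrapy's actual outcome_model choices are those documented for g_computation()), and the names s_learner_sketch and _design are made up. Dropping rows with a missing treatment, outcome or covariate and reporting them as NaN is an assumption about how the NaN filtering mentioned for key_added behaves.

```python
import numpy as np


def _design(t, X):
    """Hypothetical design matrix [1, T, X, T * X] for a linear mu(T, X)."""
    # Broadcast a scalar treatment value (e.g. 1.0 or 0.0) to one per row.
    t = np.broadcast_to(np.asarray(t, dtype=float), (X.shape[0],))[:, None]
    return np.hstack([np.ones_like(t), t, X, t * X])


def s_learner_sketch(t, y, X):
    """Illustrative S-learner (assumed OLS base learner, not ehrapy's code).

    Returns (average CATE, per-observation CATE vector with NaN for dropped rows).
    """
    t = np.asarray(t, dtype=float)
    y = np.asarray(y, dtype=float)
    X = np.asarray(X, dtype=float).reshape(len(t), -1)

    # Complete-case filtering: drop rows with a NaN in T, Y or any covariate.
    keep = ~(np.isnan(t) | np.isnan(y) | np.isnan(X).any(axis=1))
    tk, yk, Xk = t[keep], y[keep], X[keep]

    # Fit ONE outcome model on all kept rows, with T as an input feature.
    beta, *_ = np.linalg.lstsq(_design(tk, Xk), yk, rcond=None)

    # Contrast predictions with T forced to 1 and to 0:
    # tau(X_i) = mu(1, X_i) - mu(0, X_i).
    cate_kept = _design(1.0, Xk) @ beta - _design(0.0, Xk) @ beta

    # Per-observation vector aligned with the input; dropped rows stay NaN.
    cate = np.full(len(t), np.nan)
    cate[keep] = cate_kept
    return float(cate_kept.mean()), cate
```

The interaction columns are a deliberate choice: without the T × X terms, a linear μ would assign every observation the same CATE, equal to the coefficient on T, so no heterogeneity could be expressed.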
Examples
>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.mimic_2_preprocessed()
>>> est = ep.tl.s_learner(
...     edata,
...     "aline_flg",
...     "day_28_flg",
...     covariates=["age", "sofa_first", "sapsi_first"],
... )
>>> print(f"average CATE: {est.value:+.4f} (n={est.n})")
average CATE: -0.0216 (n=1776)
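The note in the summary about heavily regularised base learners can be illustrated with a hypothetical ridge-penalised linear S-learner. This is an assumption-laden sketch for illustration only: ridge_s_learner_effect and its lam penalty are not ehrapy parameters. Because this model has no T × X interactions, μ(1, X_i) − μ(0, X_i) equals the coefficient on T for every row, so any shrinkage of that coefficient shrinks the estimated effect directly.

```python
import numpy as np


def ridge_s_learner_effect(t, y, X, lam):
    """Treatment effect of a hypothetical ridge-penalised linear S-learner.

    mu(T, X) = b0 + b_t * T + X @ b_x, so tau(X_i) = b_t for every row.
    The intercept is left unpenalised; all other coefficients are shrunk.
    """
    t = np.asarray(t, dtype=float)
    X = np.asarray(X, dtype=float).reshape(len(t), -1)
    A = np.column_stack([np.ones(len(t)), t, X])
    penalty = lam * np.eye(A.shape[1])
    penalty[0, 0] = 0.0  # do not shrink the intercept
    # Closed-form ridge solution: (A^T A + lam * P)^-1 A^T y.
    beta = np.linalg.solve(A.T @ A + penalty, A.T @ np.asarray(y, dtype=float))
    return float(beta[1])  # coefficient on T, i.e. the CATE for every row
```

On noiseless simulated data with a true effect of 2, lam=0.0 should recover that value, while a very large lam such as 1e6 pushes the estimate close to zero even though the true effect is unchanged.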