ehrapy.tools.t_learner
- ehrapy.tools.t_learner(edata, treatment, outcome, *, covariates, outcome_model='auto', key_added=None, layer=None)
Two-model (T-learner) CATE estimator.
Fits a separate outcome model on the treated subset (μ_1) and on the untreated subset (μ_0), then computes τ(X_i) = μ_1(X_i) − μ_0(X_i) for every row. Simple but statistically inefficient when the two groups are imbalanced.
- Parameters:
  - edata (EHRData) – Central data object.
  - treatment (str) – Column name of the binary (0/1) treatment variable.
  - outcome (str) – Column name of the outcome variable.
  - covariates (Sequence[str]) – Adjustment set used for the outcome models. Each entry must refer to a name in edata.var_names or edata.obs.columns.
  - outcome_model (str | BaseEstimator, default: 'auto') – Outcome model specification (see g_computation() for the accepted values).
  - key_added (str | None, default: None) – Optional edata.obs column name into which the per-observation CATE vector is written. Observations dropped during NaN filtering are filled with NaN.
  - layer (str | None, default: None) – Layer of edata to draw the var-side variables from. If None, edata.X is used.
- Return type:
  CausalEstimate
- Returns:
  A CausalEstimate whose value is the average CATE and whose params['cate'] is the per-observation CATE vector.
Examples
>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.mimic_2_preprocessed()
>>> est = ep.tl.t_learner(
...     edata,
...     "aline_flg",
...     "day_28_flg",
...     covariates=["age", "sofa_first", "sapsi_first"],
... )
>>> print(f"average CATE: {est.value:+.4f} (n={est.n})")
average CATE: -0.0256 (n=1776)
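
To make the two-model procedure described above concrete, here is a minimal sketch written against scikit-learn and pandas directly. It is not ehrapy's implementation. The helper name `t_learner_sketch`, the use of `LinearRegression` for both arms, the plain `DataFrame` standing in for `EHRData`, and the synthetic data are all assumptions made for illustration; what `outcome_model='auto'` selects may differ.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression


def t_learner_sketch(df, treatment, outcome, covariates):
    """Illustrative T-learner; hypothetical helper, not ehrapy's code."""
    cols = [treatment, outcome, *covariates]
    complete = df[cols].dropna()  # rows with any NaN are excluded from fitting

    X = complete[covariates].to_numpy(dtype=float)
    y = complete[outcome].to_numpy(dtype=float)
    treated = complete[treatment].to_numpy() == 1

    # One outcome model per arm: mu_1 on treated rows, mu_0 on untreated rows.
    # LinearRegression is an assumed choice for this sketch.
    mu_1 = LinearRegression().fit(X[treated], y[treated])
    mu_0 = LinearRegression().fit(X[~treated], y[~treated])

    # tau(X_i) = mu_1(X_i) - mu_0(X_i), predicted for every complete row.
    tau = mu_1.predict(X) - mu_0.predict(X)

    # Scatter back onto the original index; dropped rows remain NaN.
    cate = pd.Series(np.nan, index=df.index, dtype=float)
    cate.loc[complete.index] = tau
    return float(tau.mean()), cate


# Synthetic data with a known effect of 1 + 0.5 * x1 (average close to 1).
rng = np.random.default_rng(0)
n = 2000
demo = pd.DataFrame({"x1": rng.normal(size=n), "x2": rng.normal(size=n)})
demo["t"] = rng.integers(0, 2, size=n)
demo["y"] = (
    demo["x2"]
    + demo["t"] * (1.0 + 0.5 * demo["x1"])
    + rng.normal(scale=0.1, size=n)
)
demo.loc[0, "x1"] = np.nan  # one incomplete row, to show the NaN fill

avg_cate, cate = t_learner_sketch(demo, "t", "y", ["x1", "x2"])
print(f"average CATE: {avg_cate:+.3f}")
```

Because each arm's model sees only its own rows, a small treated or untreated group leaves that arm's model poorly constrained, which is the inefficiency under imbalance noted in the description.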