Heterogeneous treatment effects on MIMIC-II#
An average treatment effect smooths over patient-level variation: it says ‘on average, A-line placement is associated with a 3 percentage point drop in 28-day mortality’ without telling us which patients drive that average. Clinically, the more useful question is: for which patients does the line actually help, and for which does it not? A continuous monitoring tool that helps sepsis patients but provides no benefit (and some risk) in mild cases would have an ATE that is the patient-weighted average of two qualitatively different effects. If the heterogeneity is real, decision-makers want the conditional average treatment effect (CATE): the expected effect for patients with a given covariate profile, which is the closest estimable stand-in for a per-patient effect.
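As a small worked example of that weighted average, the sketch below uses made-up subgroup shares and effects (they are illustrative assumptions, not MIMIC-II estimates) to show how a sizeable benefit in a minority subgroup and no effect elsewhere collapse into a modest ATE.

```python
import numpy as np

# Hypothetical numbers, chosen only for illustration (not estimated from MIMIC-II).
share = np.array([0.30, 0.70])    # sicker subgroup, milder subgroup
effect = np.array([-0.10, 0.00])  # change in 28-day mortality risk per subgroup

# The ATE is the share-weighted average of the subgroup effects.
ate = float(np.sum(share * effect))
print(f"ATE = {ate:+.3f}")
```

The single number hides the fact that, in this toy setup, all of the benefit comes from 30% of the patients.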
ehrapy ships three meta-learners for CATE estimation. All three work on top of any sklearn-compatible regressor or classifier; differences between them are mostly statistical (how efficiently they use the data) rather than conceptual. In this notebook we use the same MIMIC-II A-line setup as before and look at how the per-patient estimated effect varies with severity at admission.
import ehrapy as ep
import ehrdata as ed
import holoviews as hv
import numpy as np
import pandas as pd
hv.extension("bokeh")
edata = ed.dt.mimic_2_preprocessed()
treatment = "aline_flg"
outcome = "day_28_flg"
covariates = ["age", "gender_num", "sapsi_first", "sofa_first", "weight_first", "bmi", "icu_los_day", "sepsis_flg"]
Choosing a base learner#
Meta-learners are wrappers; the actual heavy lifting is done by whatever sklearn regressor or classifier you plug in. The choice matters more than it usually does in supervised learning: heterogeneous effects live in the interactions between treatment and covariates, and a base learner that cannot represent those interactions will systematically report a flat CATE.
Why we use a tree-based base learner for HTE
A linear regression without explicit treatment-by-covariate interaction terms models only the additive part of the outcome; it has no way to represent an effect that changes with X.
If you fit a vanilla LinearRegression outcome model and difference its predictions at T = 0 and T = 1, the result is constant in X — you’ll get a flat ‘CATE’ even when one truly exists.
Tree-based learners (gradient boosting, random forest) pick up non-linear treatment-by-covariate interactions naturally, which is what you want when the question is whether the effect varies.
If your base learner is linear, the S-learner is the worst offender (it enters treatment as just another column, so the estimated effect cannot vary with X); the T- and X-learners at least separate the two arms, but their within-arm models still need to capture the relevant non-linearities.
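The flat-CATE behaviour described above can be checked on synthetic data. The sketch below is a hand-rolled S-learner (the helper `s_learner_cate` and the simulated data are assumptions made for illustration, not part of ehrapy) fitted once with `LinearRegression` and once with `GradientBoostingRegressor`.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
n = 2000
x = rng.uniform(-1, 1, size=(n, 1))
t = rng.integers(0, 2, size=n)

# Synthetic truth: treatment lowers the outcome by 0.5 only when x > 0.
tau = np.where(x[:, 0] > 0, -0.5, 0.0)
y = x[:, 0] + tau * t + rng.normal(scale=0.1, size=n)


def s_learner_cate(model):
    """Hypothetical helper: fit one model on [x, t], then difference T=1 vs T=0."""
    model.fit(np.column_stack([x, t]), y)
    pred_treated = model.predict(np.column_stack([x, np.ones(n)]))
    pred_control = model.predict(np.column_stack([x, np.zeros(n)]))
    return pred_treated - pred_control


cate_linear = s_learner_cate(LinearRegression())
cate_gbm = s_learner_cate(GradientBoostingRegressor(random_state=0))

# Range of the estimated CATE across rows: ~0 for the linear model.
print("linear CATE range:", np.ptp(cate_linear))
print("boosted CATE range:", np.ptp(cate_gbm))
```

The linear model returns the treatment coefficient for every row, so its CATE range is zero up to floating-point error, while the boosted model is free to assign different effects on either side of x = 0.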
T-learner#
The simplest meta-learner: fit one outcome model on the treated subset, another on the untreated subset, and difference their predictions for every row. Conceptually clean, but statistically inefficient when one arm is much smaller than the other — the smaller-arm model is data-starved exactly where you most need precision.
from sklearn.ensemble import GradientBoostingClassifier
est_t = ep.tl.t_learner(
    edata,
    treatment,
    outcome,
    covariates=covariates,
    outcome_model=GradientBoostingClassifier(),
    key_added="cate_t",
)
print(f"average CATE (T-learner): {est_t.value:+.4f}")
average CATE (T-learner): -0.0106
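The T-learner’s average CATE should land near the ATE we estimated previously (a small protective effect, around −0.03). Here it comes out at −0.0106, which has the same sign but is smaller in magnitude, so the agreement is only approximate.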
The interesting object is not that average but the full vector est_t.params["cate"], which contains the model’s best guess at every patient’s individual effect.
The summary value just tells us we haven’t drifted off the average; the heterogeneity is where the action is.
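To make the two-model recipe concrete, here is a from-scratch sketch on simulated data. It is not ehrapy's implementation; the data-generating process and variable names are assumptions, and only the structure (one model per arm, then a difference of predicted risks) mirrors the description above.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier

rng = np.random.default_rng(0)
n = 4000
x = rng.normal(size=(n, 2))
t = rng.integers(0, 2, size=n)

# Synthetic truth: treatment lowers event risk by 0.2 only when x[:, 0] > 0.
p_event = 0.4 - 0.2 * t * (x[:, 0] > 0)
y = rng.binomial(1, p_event)

# One outcome model per arm.
model_treated = GradientBoostingClassifier(random_state=0).fit(x[t == 1], y[t == 1])
model_control = GradientBoostingClassifier(random_state=0).fit(x[t == 0], y[t == 0])

# Per-row CATE: predicted risk under treatment minus predicted risk under control.
cate = model_treated.predict_proba(x)[:, 1] - model_control.predict_proba(x)[:, 1]
print(f"average CATE: {cate.mean():+.3f}")
```

Each arm's model sees only its own rows, which is why a small arm leaves one of the two predictions noisy.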
X-learner#
Künzel et al.’s X-learner addresses the T-learner’s imbalance weakness by using each arm’s outcome model to impute counterfactuals on the other arm, then fitting second-stage CATE models on those imputed contrasts, and finally combining the two CATE estimates with propensity weights. More machinery, but worth it when treatment arms are imbalanced — which is exactly the EHR observational setup.
est_x = ep.tl.x_learner(
    edata,
    treatment,
    outcome,
    covariates=covariates,
    outcome_model=GradientBoostingClassifier(),
    cate_model="gradient_boosting",
    key_added="cate_x",
)
print(f"average CATE (X-learner): {est_x.value:+.4f}")
average CATE (X-learner): -0.0266
The two meta-learners’ average CATEs should agree closely. Here they differ by about 1.6 percentage points (−0.0106 for the T-learner against −0.0266 for the X-learner), with the X-learner closer to the earlier ATE of around −0.03. A substantially larger gap would make us worry about either the propensity model or the base outcome model: the X-learner relies on more nuisance models than the T-learner does, so its bias profile is different. Their per-patient predictions, however, can differ visibly. The X-learner often produces smoother CATE surfaces because its second-stage models are fitted to the imputed effects directly, which tends to damp noise in regions of poor overlap.
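The three stages described above can be sketched from scratch as follows. This is an illustrative implementation on simulated data with a continuous outcome, not ehrapy's code; the variable names, the logistic propensity model, and the data-generating process are all assumptions.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
n = 4000
x = rng.normal(size=(n, 2))

# Imbalanced arms: roughly one in five rows is treated.
e_true = 1 / (1 + np.exp(-(-1.5 + 0.5 * x[:, 1])))
t = rng.binomial(1, e_true)
tau_true = np.where(x[:, 0] > 0, -0.5, 0.0)
y = x[:, 1] + tau_true * t + rng.normal(scale=0.2, size=n)
treated, control = t == 1, t == 0

# Stage 1: one outcome model per arm.
mu1 = GradientBoostingRegressor(random_state=0).fit(x[treated], y[treated])
mu0 = GradientBoostingRegressor(random_state=0).fit(x[control], y[control])

# Stage 2: impute individual effects using the other arm's model, then model them.
d_treated = y[treated] - mu0.predict(x[treated])
d_control = mu1.predict(x[control]) - y[control]
tau1 = GradientBoostingRegressor(random_state=0).fit(x[treated], d_treated)
tau0 = GradientBoostingRegressor(random_state=0).fit(x[control], d_control)

# Stage 3: combine the two CATE models with estimated propensity weights.
e_hat = LogisticRegression().fit(x, t).predict_proba(x)[:, 1]
cate = e_hat * tau0.predict(x) + (1 - e_hat) * tau1.predict(x)
print(f"average CATE: {cate.mean():+.3f} (simulated truth {tau_true.mean():+.3f})")
```

With few treated rows the propensity is small, so most of the weight falls on `tau1`, whose imputed effects rely on the control-arm model fitted to the larger arm.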
How heterogeneous is the effect?#
Histogram the per-patient CATEs from both learners. A narrow, spike-like histogram around the ATE means there’s no meaningful heterogeneity; a wide histogram with patients on both sides of zero means subgroup-specific decision-making is potentially actionable.
counts_t, edges_t = np.histogram(est_t.params["cate"], bins=40, density=True)
counts_x, edges_x = np.histogram(est_x.params["cate"], bins=40, density=True)
hist_t = hv.Histogram((edges_t, counts_t), kdims="CATE", vdims="density", label="T-learner").opts(
    fill_color="#1f77b4", alpha=0.5, line_alpha=0
)
hist_x = hv.Histogram((edges_x, counts_x), kdims="CATE", vdims="density", label="X-learner").opts(
    fill_color="#d62728", alpha=0.5, line_alpha=0
)
zero = hv.VLine(0).opts(color="black", line_width=1, line_dash="dashed")
(hist_t * hist_x * zero).opts(
    width=560,
    height=320,
    xlabel="Per-patient CATE: P(death | line) − P(death | no line)",
    ylabel="Density",
    show_legend=True,
)
Most patients sit close to zero, but the distributions have visible width: some patients have estimated CATEs well below zero (apparent benefit) while others have CATEs near zero or slightly positive (no benefit, possibly some harm). That spread is the signal we wanted to surface: the ATE we estimated in the previous notebook may be averaging over patients with quite different responses.
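A histogram can be complemented with a few summary shares. The sketch below uses a short made-up CATE vector as a stand-in for `est_t.params["cate"]`, and the tolerance band `tol` is an arbitrary assumption rather than a clinically motivated threshold.

```python
import numpy as np

# Hypothetical CATE values standing in for est_t.params["cate"].
cate = np.array([-0.12, -0.05, -0.03, -0.01, 0.00, 0.01, 0.02, 0.04])

# Arbitrary band around zero inside which we call the effect negligible.
tol = 0.02

share_benefit = float(np.mean(cate < -tol))  # estimated risk reduction beyond the band
share_harm = float(np.mean(cate > tol))      # estimated risk increase beyond the band
share_neutral = 1.0 - share_benefit - share_harm

print(f"benefit: {share_benefit:.3f}, neutral: {share_neutral:.3f}, harm: {share_harm:.3f}")
```

These shares describe the model's point estimates only; they say nothing about how certain any individual estimate is.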
Does the effect track severity?#
If the heterogeneity is clinically meaningful, it should align with variables a clinician would already think about — severity at admission being the most obvious one. Stratifying the per-patient CATE by SOFA quartile tells us whether the ‘benefit’ is concentrated in sicker patients.
cate = est_x.params["cate"]
df = ed.io.to_pandas(edata)
sofa_bins = pd.qcut(
    df.loc[est_x.params["index"], "sofa_first"],
    q=4,
    labels=["Q1 (least sick)", "Q2", "Q3", "Q4 (sickest)"],
)
pd.DataFrame({"sofa_quartile": sofa_bins.values, "cate": cate}).groupby(
    "sofa_quartile", observed=True
)["cate"].agg(["mean", "std", "count"])
| sofa_quartile | mean | std | count |
|---|---|---|---|
| Q1 (least sick) | -0.011394 | 0.085449 | 524 |
| Q2 | -0.021327 | 0.115993 | 641 |
| Q3 | -0.020115 | 0.114909 | 246 |
| Q4 (sickest) | -0.062130 | 0.166724 | 365 |
Read the table top-to-bottom. The mean CATE is least negative in Q1 (−0.011), roughly level across Q2 and Q3 (about −0.02), and most negative in Q4 (−0.062), so the apparent benefit of A-line placement is concentrated in the sickest quartile. That is what we’d clinically expect (continuous BP monitoring matters most for unstable patients). Q4 is also where the standard deviation is largest (0.167). Note that this is the spread of the estimated effects within the quartile, not a standard error: it says the model’s predictions vary most among the sickest patients, and a corresponding subgroup analysis would still need its own uncertainty estimate.
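The unequal counts per quartile in the table are expected rather than a bug: SOFA scores are integers, and `pd.qcut` cannot split patients who share the same score across two bins. The sketch below reproduces the behaviour with made-up integer scores (the score distribution is an assumption chosen for illustration).

```python
import numpy as np
import pandas as pd

# Hypothetical integer severity scores for 100 patients (score 0 occurs 5 times, etc.).
scores = pd.Series(np.repeat(np.arange(10), [5, 15, 20, 20, 15, 10, 6, 4, 3, 2]))

# Quantile-based bins; tied scores always land in the same bin.
quartile = pd.qcut(scores, q=4, labels=["Q1", "Q2", "Q3", "Q4"])
counts = quartile.value_counts(sort=False)
print(counts)
```

When comparing quartile means, the smaller bins carry fewer patients and their averages are correspondingly less stable.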
Important caveats when reading a CATE table
These are exploratory subgroup estimates produced from a single observational dataset. Any specific point estimate (e.g. ‘Q4 benefits by 6 percentage points’) should be treated as a hypothesis, not a clinical claim. Per-patient CATEs are also model-dependent: a different base learner, a different propensity model, or a different adjustment set can shift them visibly, especially in subgroups with sparse data. If a subgroup-specific decision were actually being contemplated, such as an A-line guideline that depends on SOFA, it would need to be confirmed in a pre-registered analysis or, ideally, an RCT stratified on the relevant covariate.
Where the CATE is stored#
Each learner wrote its per-patient CATE into edata.obs under the name passed to key_added, with NaNs for patients dropped during NaN filtering.
From there you can combine the CATE with the rest of ehrapy: colour an existing UMAP by predicted benefit, pass it to rank_features_groups to characterise the high-benefit cluster, or use it in a downstream fairness audit.
edata.obs[["cate_t", "cate_x"]].describe()
| | cate_t | cate_x |
|---|---|---|
| count | 1776.000000 | 1776.000000 |
| mean | -0.010576 | -0.026614 |
| std | 0.179542 | 0.121939 |
| min | -0.851102 | -0.959581 |
| 25% | -0.033686 | -0.052261 |
| 50% | 0.004264 | -0.013188 |
| 75% | 0.022109 | 0.024762 |
| max | 0.718165 | 0.524273 |
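The NaN padding for dropped patients can be pictured with plain pandas. This is a sketch of the general alignment pattern, not ehrapy's internal code; the patient IDs and CATE values are made up.

```python
import numpy as np
import pandas as pd

# Hypothetical full cohort index and the subset that survived missing-value filtering.
obs_index = pd.Index(["p1", "p2", "p3", "p4", "p5"])
kept_index = pd.Index(["p1", "p3", "p4"])
cate_kept = np.array([-0.04, 0.01, -0.10])

# Reindexing onto the full cohort leaves NaN where no CATE was estimated.
cate_full = pd.Series(cate_kept, index=kept_index).reindex(obs_index)
print(cate_full)

# describe() skips NaN, so "count" is the number of patients with an estimate.
print(cate_full.describe()["count"])
```

This is why the `count` row of a `describe()` table reports the patients who received an estimate, which can be fewer than the number of rows in `edata.obs`.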