"""Average treatment effect (ATE) estimators for binary treatments.
All estimators in this module accept an :class:`~ehrdata.EHRData` object and return a :class:`~ehrapy.tools.CausalEstimate`.
They share a common interface: a ``treatment`` and ``outcome`` column name plus a list of ``covariates`` (the adjustment set) that may come from either ``edata.var_names`` or ``edata.obs.columns``.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import array_api_compat
import numpy as np
from ehrapy._compat import function_2D_only
from ehrapy.tools.causal._design import assert_binary_treatment, build_design
from ehrapy.tools.causal._estimate import CausalEstimate
from ehrapy.tools.causal._models import fit_propensity, predict_mean, resolve_outcome_model
if TYPE_CHECKING:
from collections.abc import Sequence
from ehrdata import EHRData
from sklearn.base import BaseEstimator
_DEFAULT_CLIP = (0.01, 0.99)
def _bootstrap_ate(
estimator_fn,
*,
n: int,
n_bootstrap: int,
random_state: int | None,
) -> tuple[float, float, float]:
"""Bootstrap an estimator function to obtain (SE, ci_lower, ci_upper).
``estimator_fn`` is a callable that takes a boolean index array of length ``n`` and returns a scalar ATE.
"""
rng = np.random.default_rng(random_state)
values = np.empty(n_bootstrap, dtype=float)
for b in range(n_bootstrap):
idx = rng.integers(0, n, size=n)
values[b] = estimator_fn(idx)
se = float(np.std(values, ddof=1))
ci_lower, ci_upper = (float(x) for x in np.quantile(values, [0.025, 0.975]))
return se, ci_lower, ci_upper
[docs]
@function_2D_only()
def iptw(
edata: EHRData,
treatment: str,
outcome: str,
*,
covariates: Sequence[str],
propensity_model: str | BaseEstimator = "logistic",
stabilized: bool = True,
clip: tuple[float, float] | None = _DEFAULT_CLIP,
n_bootstrap: int = 200,
random_state: int | None = None,
layer: str | None = None,
) -> CausalEstimate:
"""Estimate the average treatment effect by inverse probability of treatment weighting (IPTW).
Fits a propensity model ``e(X) = P(T=1 | X)`` and forms weights ``w_i = T_i / e_i + (1-T_i) / (1-e_i)``.
With ``stabilized=True`` the weights are multiplied by the marginal treatment probabilities, which typically reduces variance with negligible bias.
The ATE is the difference of weighted means of ``Y`` between treated and untreated groups.
Args:
edata: Central data object.
treatment: Column name of the binary (0/1) treatment variable.
outcome: Column name of the outcome variable.
covariates: Adjustment set used to fit the propensity model.
Each entry must refer to a name in ``edata.var_names`` or ``edata.obs.columns``.
propensity_model: Specification of the propensity model.
Accepts one of the strings ``'logistic'``, ``'gradient_boosting'``, ``'random_forest'``, or any sklearn-compatible classifier (it must implement ``predict_proba``).
stabilized: Whether to use stabilized weights instead of the basic inverse-probability weights.
clip: ``(lo, hi)`` propensity-score clipping range applied before forming weights.
Use ``None`` to disable clipping.
n_bootstrap: Number of bootstrap resamples used for the SE and 95% percentile confidence interval.
Set to ``0`` to skip uncertainty estimation.
random_state: Seed for the bootstrap resampler.
layer: Layer of ``edata`` to draw the var-side variables from.
If ``None``, ``edata.X`` is used.
Returns:
A :class:`~ehrapy.tools.CausalEstimate` whose ``params`` dict contains the fitted ``propensity_scores`` and the IPTW ``weights``.
Examples:
>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.mimic_2_preprocessed()
>>> est = ep.tl.iptw(
... edata,
... "aline_flg",
... "day_28_flg",
... covariates=["age", "sofa_first", "sapsi_first"],
... random_state=0,
... )
>>> print(est.summary())
Causal effect of 'aline_flg' on 'day_28_flg'
method: iptw_stabilized
ATE: -0.0644
SE: 0.0332
95% CI: [-0.1313, -0.0089]
n: 1776
"""
design = build_design(edata, treatment=treatment, outcome=outcome, covariates=covariates, layer=layer)
assert_binary_treatment(design.T, treatment)
ps, _ = fit_propensity(propensity_model, design.X, design.T, clip=clip)
weights = _iptw_weights(design.T, ps, stabilized=stabilized)
ate = _weighted_diff_in_means(design.Y, design.T, weights)
se: float | None = None
ci_lower: float | None = None
ci_upper: float | None = None
if n_bootstrap > 0:
def _refit(idx: np.ndarray) -> float:
X_b, T_b, Y_b = design.X[idx], design.T[idx], design.Y[idx]
if len(np.unique(T_b)) < 2:
return np.nan
ps_b, _ = fit_propensity(propensity_model, X_b, T_b, clip=clip)
w_b = _iptw_weights(T_b, ps_b, stabilized=stabilized)
return _weighted_diff_in_means(Y_b, T_b, w_b)
se, ci_lower, ci_upper = _bootstrap_ate(
_refit, n=len(design.T), n_bootstrap=n_bootstrap, random_state=random_state
)
return CausalEstimate(
method="iptw" + ("_stabilized" if stabilized else ""),
treatment=treatment,
outcome=outcome,
value=float(ate),
se=se,
ci_lower=ci_lower,
ci_upper=ci_upper,
n=int(len(design.T)),
params={
"propensity_scores": ps,
"weights": weights,
"index": design.index,
"feature_names": design.feature_names,
},
)
[docs]
@function_2D_only()
def g_computation(
edata: EHRData,
treatment: str,
outcome: str,
*,
covariates: Sequence[str],
outcome_model: str | BaseEstimator = "auto",
n_bootstrap: int = 200,
random_state: int | None = None,
layer: str | None = None,
) -> CausalEstimate:
"""Estimate the ATE by parametric g-computation (a.k.a. the g-formula or standardisation).
Fits an outcome model ``μ(T, X)`` on the observed data, then predicts counterfactual outcomes by setting ``T`` to 1 and 0 for every row.
The ATE is ``mean(μ(1, X)) − mean(μ(0, X))``.
Args:
edata: Central data object.
treatment: Column name of the binary (0/1) treatment variable.
outcome: Column name of the outcome variable.
covariates: Adjustment set used to fit the outcome model.
Each entry must refer to a name in ``edata.var_names`` or ``edata.obs.columns``.
outcome_model: Specification of the outcome model.
Accepts one of the strings ``'auto'``, ``'linear'``, ``'logistic'``, ``'gradient_boosting'``, ``'random_forest'``, or any sklearn-compatible regressor/classifier.
``'auto'`` picks logistic regression when the outcome is binary 0/1 and linear regression otherwise.
n_bootstrap: Number of bootstrap resamples used for the SE and 95% percentile confidence interval.
Set to ``0`` to skip uncertainty estimation.
random_state: Seed for the bootstrap resampler.
layer: Layer of ``edata`` to draw the var-side variables from.
If ``None``, ``edata.X`` is used.
Returns:
A :class:`~ehrapy.tools.CausalEstimate` whose ``params`` dict contains the counterfactual predictions ``mu1`` and ``mu0``.
Examples:
>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.mimic_2_preprocessed()
>>> est = ep.tl.g_computation(
... edata,
... "aline_flg",
... "day_28_flg",
... covariates=["age", "sofa_first", "sapsi_first"],
... random_state=0,
... )
>>> print(est.summary())
Causal effect of 'aline_flg' on 'day_28_flg'
method: g_computation
ATE: -0.0216
SE: 0.0174
95% CI: [-0.0541, 0.0127]
n: 1776
"""
design = build_design(edata, treatment=treatment, outcome=outcome, covariates=covariates, layer=layer)
assert_binary_treatment(design.T, treatment)
mu1, mu0 = _g_predict(design.X, design.T, design.Y, outcome_model)
ate = float(np.mean(mu1) - np.mean(mu0))
se: float | None = None
ci_lower: float | None = None
ci_upper: float | None = None
if n_bootstrap > 0:
def _refit(idx: np.ndarray) -> float:
X_b, T_b, Y_b = design.X[idx], design.T[idx], design.Y[idx]
if len(np.unique(T_b)) < 2:
return np.nan
mu1_b, mu0_b = _g_predict(X_b, T_b, Y_b, outcome_model)
return float(np.mean(mu1_b) - np.mean(mu0_b))
se, ci_lower, ci_upper = _bootstrap_ate(
_refit, n=len(design.T), n_bootstrap=n_bootstrap, random_state=random_state
)
return CausalEstimate(
method="g_computation",
treatment=treatment,
outcome=outcome,
value=ate,
se=se,
ci_lower=ci_lower,
ci_upper=ci_upper,
n=int(len(design.T)),
params={"mu1": mu1, "mu0": mu0, "index": design.index, "feature_names": design.feature_names},
)
[docs]
@function_2D_only()
def aipw(
edata: EHRData,
treatment: str,
outcome: str,
*,
covariates: Sequence[str],
propensity_model: str | BaseEstimator = "logistic",
outcome_model: str | BaseEstimator = "auto",
clip: tuple[float, float] | None = _DEFAULT_CLIP,
n_bootstrap: int = 0,
random_state: int | None = None,
layer: str | None = None,
) -> CausalEstimate:
"""Estimate the ATE by the augmented inverse-probability-weighted (AIPW) doubly robust estimator.
AIPW is consistent if *either* the propensity model *or* the outcome model is correctly specified.
The point estimate is the mean of the influence function::
ψ_i = μ_1(X_i) − μ_0(X_i)
+ (T_i / e_i) (Y_i − μ_1(X_i))
− ((1 − T_i) / (1 − e_i)) (Y_i − μ_0(X_i))
By default the standard error is computed analytically from the empirical variance of ψ; setting ``n_bootstrap > 0`` switches to a bootstrap SE/CI instead.
Args:
edata: Central data object.
treatment: Column name of the binary (0/1) treatment variable.
outcome: Column name of the outcome variable.
covariates: Adjustment set used for both nuisance models.
Each entry must refer to a name in ``edata.var_names`` or ``edata.obs.columns``.
propensity_model: Propensity model specification (see :func:`iptw` for the accepted values).
outcome_model: Outcome model specification (see :func:`g_computation` for the accepted values).
clip: ``(lo, hi)`` propensity-score clipping range applied before forming the influence function.
Use ``None`` to disable clipping.
n_bootstrap: If positive, use a bootstrap SE/CI instead of the analytic influence-function SE.
Set to ``0`` (the default) to use the influence-function SE.
random_state: Seed for the bootstrap resampler.
layer: Layer of ``edata`` to draw the var-side variables from.
If ``None``, ``edata.X`` is used.
Returns:
A :class:`~ehrapy.tools.CausalEstimate` whose ``params`` dict contains ``propensity_scores``, ``mu1``, ``mu0``, and the per-observation ``influence`` values.
Examples:
>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.mimic_2_preprocessed()
>>> est = ep.tl.aipw(
... edata,
... "aline_flg",
... "day_28_flg",
... covariates=["age", "sofa_first", "sapsi_first"],
... )
>>> print(est.summary())
Causal effect of 'aline_flg' on 'day_28_flg'
method: aipw
ATE: -0.0349
SE: 0.0365
95% CI: [-0.1065, 0.0367]
n: 1776
"""
design = build_design(edata, treatment=treatment, outcome=outcome, covariates=covariates, layer=layer)
assert_binary_treatment(design.T, treatment)
ps, _ = fit_propensity(propensity_model, design.X, design.T, clip=clip)
mu1, mu0 = _g_predict(design.X, design.T, design.Y, outcome_model)
psi = _aipw_influence(design.T, design.Y, ps, mu1, mu0)
ate = float(np.mean(psi))
se: float | None
ci_lower: float | None
ci_upper: float | None
if n_bootstrap > 0:
def _refit(idx: np.ndarray) -> float:
X_b, T_b, Y_b = design.X[idx], design.T[idx], design.Y[idx]
if len(np.unique(T_b)) < 2:
return np.nan
ps_b, _ = fit_propensity(propensity_model, X_b, T_b, clip=clip)
mu1_b, mu0_b = _g_predict(X_b, T_b, Y_b, outcome_model)
psi_b = _aipw_influence(T_b, Y_b, ps_b, mu1_b, mu0_b)
return float(np.mean(psi_b))
se, ci_lower, ci_upper = _bootstrap_ate(
_refit, n=len(design.T), n_bootstrap=n_bootstrap, random_state=random_state
)
else:
se = float(np.std(psi, ddof=1) / np.sqrt(len(psi)))
ci_lower = float(ate - 1.96 * se)
ci_upper = float(ate + 1.96 * se)
return CausalEstimate(
method="aipw",
treatment=treatment,
outcome=outcome,
value=ate,
se=se,
ci_lower=ci_lower,
ci_upper=ci_upper,
n=int(len(design.T)),
params={
"propensity_scores": ps,
"mu1": mu1,
"mu0": mu0,
"influence": psi,
"index": design.index,
"feature_names": design.feature_names,
},
)
[docs]
@function_2D_only()
def propensity_score_matching(
edata: EHRData,
treatment: str,
outcome: str,
*,
covariates: Sequence[str],
propensity_model: str | BaseEstimator = "logistic",
k: int = 1,
caliper: float | None = 0.2,
replacement: bool = True,
target: str = "att",
n_bootstrap: int = 200,
random_state: int | None = None,
layer: str | None = None,
) -> CausalEstimate:
"""Estimate the treatment effect by 1-to-:math:`k` propensity score matching on the logit scale.
For each treated unit, the :math:`k` nearest control units in logit-propensity space are selected as matches (and vice versa when ``target='ate'``).
With ``caliper`` set, candidate matches with logit-propensity distance above ``caliper * SD(logit(e))`` are discarded; treated units with no valid match are dropped from the estimate.
Args:
edata: Central data object.
treatment: Column name of the binary (0/1) treatment variable.
outcome: Column name of the outcome variable.
covariates: Adjustment set used to fit the propensity model.
Each entry must refer to a name in ``edata.var_names`` or ``edata.obs.columns``.
propensity_model: Propensity model specification (see :func:`iptw` for the accepted values).
k: Number of matches per unit.
caliper: Maximum logit-propensity distance for a valid match, in units of ``SD(logit(e))``.
Use ``None`` to disable the caliper.
replacement: Whether matching is performed with replacement.
target: ``'att'`` for the average treatment effect on the treated, or ``'ate'`` for the average treatment effect.
n_bootstrap: Number of bootstrap resamples used for the SE and 95% percentile confidence interval.
Set to ``0`` to skip uncertainty estimation.
random_state: Seed for the bootstrap resampler.
layer: Layer of ``edata`` to draw the var-side variables from.
If ``None``, ``edata.X`` is used.
Returns:
A :class:`~ehrapy.tools.CausalEstimate` whose ``params`` dict contains the propensity scores and the matched-pair indices.
Examples:
>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.mimic_2_preprocessed()
>>> est = ep.tl.propensity_score_matching(
... edata,
... "aline_flg",
... "day_28_flg",
... covariates=["age", "sofa_first", "sapsi_first"],
... random_state=0,
... )
>>> print(est.summary())
Causal effect of 'aline_flg' on 'day_28_flg'
method: propensity_score_matching_att
ATE: -0.0511
SE: 0.0337
95% CI: [-0.1209, 0.0051]
n: 1776
"""
if target not in {"att", "ate"}:
raise ValueError(f"target must be 'att' or 'ate'; got {target!r}.")
design = build_design(edata, treatment=treatment, outcome=outcome, covariates=covariates, layer=layer)
assert_binary_treatment(design.T, treatment)
ps, _ = fit_propensity(propensity_model, design.X, design.T, clip=(1e-6, 1 - 1e-6))
ate, match_info = _ps_match_effect(
design.T, design.Y, ps, k=k, caliper=caliper, replacement=replacement, target=target
)
se: float | None = None
ci_lower: float | None = None
ci_upper: float | None = None
if n_bootstrap > 0:
def _refit(idx: np.ndarray) -> float:
X_b, T_b, Y_b = design.X[idx], design.T[idx], design.Y[idx]
if len(np.unique(T_b)) < 2:
return np.nan
ps_b, _ = fit_propensity(propensity_model, X_b, T_b, clip=(1e-6, 1 - 1e-6))
ate_b, _ = _ps_match_effect(T_b, Y_b, ps_b, k=k, caliper=caliper, replacement=replacement, target=target)
return float(ate_b)
se, ci_lower, ci_upper = _bootstrap_ate(
_refit, n=len(design.T), n_bootstrap=n_bootstrap, random_state=random_state
)
return CausalEstimate(
method=f"propensity_score_matching_{target}",
treatment=treatment,
outcome=outcome,
value=float(ate),
se=se,
ci_lower=ci_lower,
ci_upper=ci_upper,
n=int(len(design.T)),
params={
"propensity_scores": ps,
"matches": match_info,
"index": design.index,
"feature_names": design.feature_names,
},
)
def _iptw_weights(T, ps, *, stabilized: bool):
"""Compute IPTW weights from a treatment vector and propensity scores."""
xp = array_api_compat.array_namespace(T, ps)
if stabilized:
p_t = xp.mean(T)
return T * p_t / ps + (1 - T) * (1 - p_t) / (1 - ps)
return T / ps + (1 - T) / (1 - ps)
def _weighted_diff_in_means(Y, T, w) -> float:
"""Weighted difference in means (Hájek estimator)."""
xp = array_api_compat.array_namespace(Y, T, w)
treated = T == 1
untreated = ~treated
mu1 = xp.sum(w[treated] * Y[treated]) / xp.sum(w[treated])
mu0 = xp.sum(w[untreated] * Y[untreated]) / xp.sum(w[untreated])
return float(mu1 - mu0)
def _aipw_influence(T, Y, ps, mu1, mu0):
"""Compute the AIPW influence-function values; backend-agnostic."""
return mu1 - mu0 + (T / ps) * (Y - mu1) - ((1 - T) / (1 - ps)) * (Y - mu0)
def _g_predict(X: np.ndarray, T: np.ndarray, Y: np.ndarray, outcome_model_spec) -> tuple[np.ndarray, np.ndarray]:
"""Fit μ(T, X) and return (μ(1, X), μ(0, X)) for every row of X.
sklearn currently mandates numpy at the fit boundary, so this helper materialises to numpy.
"""
model = resolve_outcome_model(outcome_model_spec, y=Y)
XT = np.column_stack([T, X])
model.fit(XT, Y if not hasattr(model, "predict_proba") else Y.astype(int))
X1 = np.column_stack([np.ones_like(T), X])
X0 = np.column_stack([np.zeros_like(T), X])
return predict_mean(model, X1), predict_mean(model, X0)
def _logit(p):
"""Logit transform; backend-agnostic."""
xp = array_api_compat.array_namespace(p)
p = xp.clip(p, 1e-12, 1 - 1e-12)
return xp.log(p / (1 - p))
def _ps_match_effect(
T: np.ndarray,
Y: np.ndarray,
ps: np.ndarray,
*,
k: int,
caliper: float | None,
replacement: bool,
target: str,
) -> tuple[float, dict]:
"""Compute the matched effect and return (ATE, match diagnostics)."""
logit_ps = _logit(ps)
sd = float(np.std(logit_ps, ddof=1))
caliper_dist = caliper * sd if caliper is not None else None
treated_idx = np.flatnonzero(T == 1)
control_idx = np.flatnonzero(T == 0)
contribs: list[float] = []
used_treated: list[int] = []
used_control: list[int] = []
matched_pairs: list[tuple[int, list[int]]] = []
dropped = 0
def _match_one(src: int, pool: np.ndarray) -> list[int] | None:
dist = np.abs(logit_ps[pool] - logit_ps[src])
order = np.argsort(dist)
picks: list[int] = []
for j in order:
cand = int(pool[j])
if caliper_dist is not None and dist[j] > caliper_dist:
break
if (not replacement) and cand in picks:
continue
picks.append(cand)
if len(picks) == k:
break
return picks if picks else None
available_controls = list(control_idx)
for i in treated_idx:
pool = np.array(available_controls, dtype=int) if not replacement else control_idx
if len(pool) == 0:
dropped += 1
continue
matches = _match_one(int(i), pool)
if matches is None:
dropped += 1
continue
contribs.append(float(Y[i] - np.mean(Y[matches])))
used_treated.append(int(i))
used_control.extend(matches)
matched_pairs.append((int(i), matches))
if not replacement:
available_controls = [c for c in available_controls if c not in matches]
if target == "att":
ate = float(np.mean(contribs)) if contribs else float("nan")
else:
# Also match controls to treated for the ATE
available_treated = list(treated_idx)
for j in control_idx:
pool = np.array(available_treated, dtype=int) if not replacement else treated_idx
if len(pool) == 0:
dropped += 1
continue
matches = _match_one(int(j), pool)
if matches is None:
dropped += 1
continue
contribs.append(float(np.mean(Y[matches]) - Y[j]))
used_control.append(int(j))
used_treated.extend(matches)
matched_pairs.append((int(j), matches))
if not replacement:
available_treated = [t for t in available_treated if t not in matches]
ate = float(np.mean(contribs)) if contribs else float("nan")
return ate, {
"n_matched_pairs": len(matched_pairs),
"n_dropped": dropped,
"pairs": matched_pairs,
"caliper_distance": caliper_dist,
}