ehrapy.tools.aipw

ehrapy.tools.aipw(edata, treatment, outcome, *, covariates, propensity_model='logistic', outcome_model='auto', clip=(0.01, 0.99), n_bootstrap=0, random_state=None, layer=None)

Estimate the average treatment effect (ATE) with the augmented inverse-probability-weighted (AIPW) doubly robust estimator.

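AIPW is doubly robust: the estimate is consistent if either the propensity model or the outcome model is correctly specified; both need not be. The point estimate is the sample mean of the per-observation influence values ψ_i, where T_i is the treatment indicator, Y_i the outcome, e_i the estimated propensity score P(T = 1 | X_i), and μ_t(X_i) the outcome-model prediction under treatment t: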

ψ_i = μ_1(X_i) − μ_0(X_i)
      + (T_i / e_i) (Y_i − μ_1(X_i))
      − ((1 − T_i) / (1 − e_i)) (Y_i − μ_0(X_i))
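As an informal check of the double-robustness claim, the sketch below evaluates ψ on a hand-made toy population in which the true ATE is 1. The helper names (aipw_scores, mean) and all data are illustrative and not part of ehrapy; the nuisance values are supplied by hand rather than fitted, so this only exercises the formula, not the library.

```python
def aipw_scores(t, y, e, mu1, mu0):
    """Per-observation AIPW scores psi_i, term by term as in the formula."""
    return [
        m1 - m0 + ti / ei * (yi - m1) - (1 - ti) / (1 - ei) * (yi - m0)
        for ti, yi, ei, m1, m0 in zip(t, y, e, mu1, mu0)
    ]


def mean(values):
    return sum(values) / len(values)


# Toy population (made up): X in {0, 1} and Y = T + 2*X, so the true ATE is 1.
# The treated fraction is exactly 0.25 when X = 0 and 0.75 when X = 1.
x = [0, 0, 0, 0, 1, 1, 1, 1]
t = [1, 0, 0, 0, 1, 1, 1, 0]
y = [ti + 2 * xi for ti, xi in zip(t, x)]

# Correct and deliberately wrong nuisance values.
e_true = [0.25 if xi == 0 else 0.75 for xi in x]
e_wrong = [0.5] * len(x)
mu1_true = [1 + 2 * xi for xi in x]
mu0_true = [2 * xi for xi in x]
mu_wrong = [0.0] * len(x)

# Outcome model right, propensity model wrong.
ate_outcome_only = mean(aipw_scores(t, y, e_wrong, mu1_true, mu0_true))
# Propensity model right, outcome model wrong.
ate_propensity_only = mean(aipw_scores(t, y, e_true, mu_wrong, mu_wrong))
# Both wrong: no guarantee remains.
ate_both_wrong = mean(aipw_scores(t, y, e_wrong, mu_wrong, mu_wrong))

print(ate_outcome_only, ate_propensity_only, ate_both_wrong)
```

On this toy population the first two estimates recover the true effect of 1, while the estimate with both nuisance models wrong comes out at 2. Real data add sampling noise and fitted models, so the agreement there is asymptotic rather than exact.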

By default the standard error (SE) is computed analytically from the empirical variance of ψ. Setting n_bootstrap > 0 switches to a bootstrap SE and confidence interval (CI) instead.
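A minimal sketch of the analytic route, assuming the usual convention that the SE of a sample mean is the square root of the empirical variance divided by n. The helper name analytic_se and the ψ values are made up for illustration, and whether ehrapy divides by n or n − 1 inside the variance is not stated here; the sketch uses n − 1.

```python
import math
import statistics


def analytic_se(psi):
    """SE of mean(psi) from the empirical variance of psi.

    Assumption: sample variance with n - 1 in the denominator.
    """
    n = len(psi)
    return math.sqrt(statistics.variance(psi) / n)


psi = [1.0, 2.0, 3.0, 4.0]  # made-up influence values
ate = statistics.fmean(psi)  # point estimate: mean of psi
se = analytic_se(psi)

print(f"ATE = {ate:.4f}, SE = {se:.4f}")
```

For large n the choice between n and n − 1 makes little practical difference.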

Parameters:
  • edata (EHRData) – Central data object.

  • treatment (str) – Column name of the binary (0/1) treatment variable.

  • outcome (str) – Column name of the outcome variable.

  • covariates (Sequence[str]) – Adjustment set used for both nuisance models. Each entry must refer to a name in edata.var_names or edata.obs.columns.

  • propensity_model (str | BaseEstimator, default: 'logistic') – Propensity model specification (see iptw() for the accepted values).

  • outcome_model (str | BaseEstimator, default: 'auto') – Outcome model specification (see g_computation() for the accepted values).

  • clip (tuple[float, float] | None, default: (0.01, 0.99)) – (lo, hi) propensity-score clipping range applied before forming the influence function. Use None to disable clipping.

  • n_bootstrap (int, default: 0) – Number of bootstrap resamples. If positive, the SE and CI are obtained by bootstrap; if 0 (the default), the analytic influence-function SE is used.

  • random_state (int | None, default: None) – Seed for the bootstrap resampler.

  • layer (str | None, default: None) – Layer of edata from which to read the variables named in edata.var_names. If None, edata.X is used.
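To illustrate what the clip parameter guards against, here is a small sketch of propensity-score clipping under the documented default range. The helper clip_scores and the scores are hypothetical; the sketch assumes clipping simply clamps each score into [lo, hi], which is the common reading of the description above but is not confirmed against the ehrapy implementation.

```python
def clip_scores(scores, clip=(0.01, 0.99)):
    """Clamp propensity scores into [lo, hi]; clip=None leaves them unchanged."""
    if clip is None:
        return list(scores)
    lo, hi = clip
    return [min(max(s, lo), hi) for s in scores]


raw = [0.001, 0.2, 0.5, 0.9995]  # made-up propensity scores
clipped = clip_scores(raw)

# Largest weight T_i / e_i a treated unit could receive, before and after.
max_weight_raw = max(1 / s for s in raw)
max_weight_clipped = max(1 / s for s in clipped)

print(clipped)
print(max_weight_raw, max_weight_clipped)
```

Scores near 0 or 1 produce very large weights in the ψ terms, so a single observation can dominate the estimate and its variance. Clipping bounds those weights (here at roughly 100 instead of roughly 1000) at the cost of some bias.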

Return type:

CausalEstimate

Returns:

A CausalEstimate whose params dict contains propensity_scores, mu1, mu0, and the per-observation influence values.

Examples

>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.mimic_2_preprocessed()
>>> est = ep.tl.aipw(
...     edata,
...     "aline_flg",
...     "day_28_flg",
...     covariates=["age", "sofa_first", "sapsi_first"],
... )
>>> print(est.summary())
Causal effect of 'aline_flg' on 'day_28_flg'
  method: aipw
  ATE:    -0.0349
  SE:     0.0365
  95% CI: [-0.1065, 0.0367]
  n:      1776
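The summary does not say how the interval is formed. As a rough plausibility check, the printed numbers are consistent with a normal-approximation interval (estimate ± z · SE); the sketch below recomputes it from the rounded values shown above. This is an inference from the output, not a statement about the library's internals.

```python
import statistics

# Rounded values copied from the summary above.
ate, se = -0.0349, 0.0365

# Two-sided 95% normal quantile (about 1.96).
z = statistics.NormalDist().inv_cdf(0.975)
lower, upper = ate - z * se, ate + z * se

print(f"95% CI: [{lower:.4f}, {upper:.4f}]")
```

The recomputed bounds match the reported interval up to the rounding of the inputs. Since the interval contains 0, this example gives no evidence of an effect at the 5% level.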