ehrapy.tools.g_computation

ehrapy.tools.g_computation(edata, treatment, outcome, *, covariates, outcome_model='auto', n_bootstrap=200, random_state=None, layer=None)

Estimate the average treatment effect (ATE) by parametric g-computation (also known as the g-formula or standardisation).

Fits an outcome model μ(T, X) on the observed data, then predicts counterfactual outcomes for every row with T set to 1 and with T set to 0. The ATE is mean(μ(1, X)) − mean(μ(0, X)).
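The fit-then-standardise procedure described above can be illustrated with a minimal NumPy sketch. This is not ehrapy's implementation: the helper name g_computation_ate, the plain least-squares outcome model, and the simulated data are all assumptions made for illustration.

```python
import numpy as np


def g_computation_ate(t, X, y):
    """Plug-in ATE from a linear outcome model (hypothetical helper, not ehrapy API)."""
    t = np.asarray(t, dtype=float)
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(y)

    # Fit the outcome model mu(T, X) on the observed data: intercept, treatment, covariates.
    design = np.column_stack([np.ones(n), t, X])
    beta, *_ = np.linalg.lstsq(design, y, rcond=None)

    # Predict counterfactual outcomes for every row with T forced to 1 and to 0.
    mu1 = np.column_stack([np.ones(n), np.ones(n), X]) @ beta
    mu0 = np.column_stack([np.ones(n), np.zeros(n), X]) @ beta

    # Standardise: average each counterfactual prediction over the sample and take the difference.
    return float(mu1.mean() - mu0.mean()), mu1, mu0


# Simulated data with confounding: the first covariate drives both treatment and outcome.
rng = np.random.default_rng(0)
n = 2000
X = rng.normal(size=(n, 2))
t = rng.binomial(1, 1.0 / (1.0 + np.exp(-X[:, 0])))
y = 2.0 * t + 1.5 * X[:, 0] - 0.5 * X[:, 1] + rng.normal(scale=0.1, size=n)  # true ATE is 2.0

ate, mu1, mu0 = g_computation_ate(t, X, y)
naive = float(y[t == 1].mean() - y[t == 0].mean())  # unadjusted difference in means
print(f"g-computation ATE: {ate:.3f}, naive difference: {naive:.3f}")
```

In this simulation the unadjusted difference in means is inflated by the confounder, while the standardised estimate recovers the simulated effect. With a linear model and no treatment-covariate interaction the result equals the treatment coefficient; the standardisation step matters once the outcome model is non-linear or includes interactions.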

Parameters:
  • edata (EHRData) – Central data object.

  • treatment (str) – Column name of the binary (0/1) treatment variable.

  • outcome (str) – Column name of the outcome variable.

  • covariates (Sequence[str]) – Adjustment set used to fit the outcome model. Each entry must refer to a name in edata.var_names or edata.obs.columns.

  • outcome_model (str | BaseEstimator, default: 'auto') – Specification of the outcome model. Accepts one of the strings 'auto', 'linear', 'logistic', 'gradient_boosting', 'random_forest', or any sklearn-compatible regressor/classifier. 'auto' picks logistic regression when the outcome is binary 0/1 and linear regression otherwise.

  • n_bootstrap (int, default: 200) – Number of bootstrap resamples used to estimate the standard error (SE) and the 95% percentile confidence interval. Set to 0 to skip uncertainty estimation.

  • random_state (int | None, default: None) – Seed for the bootstrap resampler.

  • layer (str | None, default: None) – Layer of edata from which to read the variables listed in edata.var_names. If None, edata.X is used.
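The uncertainty estimate controlled by n_bootstrap and random_state can be pictured as a percentile bootstrap around the plug-in estimator. The sketch below is an assumption about the general technique, not a copy of ehrapy's code: plug_in_ate and bootstrap_ate are hypothetical helpers, the outcome model is plain least squares, and details such as refitting on every resample and using ddof=1 for the SE are choices made here for illustration.

```python
import numpy as np


def plug_in_ate(t, X, y):
    """G-computation ATE with a linear outcome model (hypothetical helper)."""
    n = len(y)
    design = np.column_stack([np.ones(n), t, X])
    beta, *_ = np.linalg.lstsq(design, y, rcond=None)
    mu1 = np.column_stack([np.ones(n), np.ones(n), X]) @ beta
    mu0 = np.column_stack([np.ones(n), np.zeros(n), X]) @ beta
    return float(mu1.mean() - mu0.mean())


def bootstrap_ate(t, X, y, n_bootstrap=200, random_state=None):
    """Return (SE, 95% percentile CI) from row-level resampling (hypothetical helper)."""
    rng = np.random.default_rng(random_state)
    n = len(y)
    draws = np.empty(n_bootstrap)
    for b in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)  # resample rows with replacement
        draws[b] = plug_in_ate(t[idx], X[idx], y[idx])  # refit the model on each resample
    se = float(draws.std(ddof=1))
    lower, upper = np.percentile(draws, [2.5, 97.5])
    return se, (float(lower), float(upper))


# Simulated data; the true effect of the treatment is 2.0 by construction.
rng = np.random.default_rng(0)
n = 2000
X = rng.normal(size=(n, 2))
t = rng.binomial(1, 1.0 / (1.0 + np.exp(-X[:, 0]))).astype(float)
y = 2.0 * t + 1.5 * X[:, 0] - 0.5 * X[:, 1] + rng.normal(scale=0.1, size=n)

point = plug_in_ate(t, X, y)
se, ci = bootstrap_ate(t, X, y, n_bootstrap=200, random_state=0)
print(f"ATE: {point:.4f}  SE: {se:.4f}  95% CI: [{ci[0]:.4f}, {ci[1]:.4f}]")
```

Fixing random_state makes the resampled indices, and therefore the SE and interval, reproducible across runs. A larger n_bootstrap gives more stable percentile endpoints at the cost of one model refit per resample.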

Return type:

CausalEstimate

Returns:

A CausalEstimate whose params dict contains the counterfactual predictions mu1 and mu0.

Examples

>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.mimic_2_preprocessed()
>>> est = ep.tl.g_computation(
...     edata,
...     "aline_flg",
...     "day_28_flg",
...     covariates=["age", "sofa_first", "sapsi_first"],
...     random_state=0,
... )
>>> print(est.summary())
Causal effect of 'aline_flg' on 'day_28_flg'
  method: g_computation
  ATE:    -0.0216
  SE:     0.0174
  95% CI: [-0.0541, 0.0127]
  n:      1776