from __future__ import annotations
import sys
import warnings
from io import StringIO
from typing import TYPE_CHECKING, Any, Literal
import dowhy
import networkx as nx
import numpy as np
from lamin_utils import logger
if TYPE_CHECKING:
import anndata
# Install a blanket "ignore" filter at import time: no category, module or message
# restriction is given, so every warning raised after this module is imported is
# suppressed for the whole process.
# NOTE(review): presumably meant to quiet noisy output from the causal-inference
# dependencies -- confirm that this global side effect is intended.
warnings.filterwarnings("ignore")
class capture_output(list):
    """Context manager that collects everything printed to ``sys.stdout``.

    While the ``with`` block is active, ``sys.stdout`` is replaced by an in-memory
    text buffer. On exit the buffered text is split into lines and appended to this
    object (which is itself a ``list``), and the previous ``sys.stdout`` is put back.
    """

    def __enter__(self):
        # Remember the stream that is currently installed so it can be restored.
        self._stdout = sys.stdout
        buffer = StringIO()
        self._stringio = buffer
        sys.stdout = buffer
        return self

    def __exit__(self, *args):
        captured_text = self._stringio.getvalue()
        self.extend(captured_text.splitlines())
        # Drop the buffer so the captured text is not kept alive twice.
        del self._stringio
        sys.stdout = self._stdout
# [docs]  (documentation-link marker; appears to be residue from HTML extraction, not code)
def _validate_estimation_method(estimation_method: str) -> None:
    """Raise ``ValueError`` for estimation method names this module does not accept.

    A name must contain at least one ``"."``. A name with more than one ``"."`` is
    only accepted when it also contains ``"dowhy"`` or ``"_estimator"``.
    """
    if "." not in estimation_method:
        raise ValueError(f"Estimation method '{estimation_method}' not supported.")
    has_extra_segments = estimation_method.count(".") > 1
    if has_extra_segments and "dowhy" not in estimation_method and "_estimator" not in estimation_method:
        raise ValueError(f"Estimation method '{estimation_method}' not supported.")


def _summarize_causal_inference(treatment: str, outcome: str, estimate: Any, refute_results: dict[str, Any]) -> str:
    """Render the estimate interpretation and the refutation results as a text tree."""
    lines = [f"Causal inference results for treatment variable '{treatment}' and outcome variable '{outcome}':"]

    # `interpret` prints its text instead of returning it, so capture stdout.
    with capture_output() as interpretation:
        estimate.interpret(method_name="textual_effect_interpreter")
    if interpretation:
        lines.append(f"└- {''.join(interpretation)}")
    else:
        # Nothing was printed: fall back to the raw estimate value.
        lines.append(f"└- Estimated effect: {estimate.value}")

    lines.append("")
    lines.append("Refutation results")

    last_idx = len(refute_results) - 1
    for idx, (method, results) in enumerate(refute_results.items()):
        is_last = idx == last_idx
        left_char = " " if is_last else "|"
        branch_char = "└" if is_last else "├"
        if isinstance(results, str):
            # A string entry is the error message of a refuter that raised.
            lines.append(f"{branch_char}-Refute: {method}")
            lines.append(f"{left_char} └- {results}")
        else:
            lines.append(f"{branch_char}-{method}")
            lines.append(f"{left_char} ├- Estimated effect: {results['Estimated effect']:.2f}")
            lines.append(f"{left_char} ├- New effect: {results['New effect']}")
            lines.append(f"{left_char} ├- p-value: {results['p-value']}")
            lines.append(f"{left_char} └- Test significance: {results['test_significance']:.2f}")

    return "\n".join(lines) + "\n"


def causal_inference(
    adata: anndata.AnnData,
    graph: nx.DiGraph | str,
    treatment: str,
    outcome: str,
    estimation_method: Literal[
        "backdoor.propensity_score_matching",
        "backdoor.propensity_score_stratification",
        "backdoor.propensity_score_weighting",
        "backdoor.linear_regression",
        "backdoor.generalized_linear_model",
        "iv.instrumental_variable",
        "iv.regression_discontinuity",
        "backdoor.econml.linear_model.LinearDML",
        "backdoor.econml.nonparametric_model.NonParamDML",
        "backdoor.econml.causal_forest.CausalForestDML",
        "backdoor.econml.forecast_model.ForestDML",
        "backdoor.econml.dml.DML",
        "backdoor.econml.dml.DMLCate",
        "backdoor.econml.xgboost.XGBTRegressor",
        "backdoor.econml.xgboost.XGBTEstimator",
        "backdoor.econml.metalearners.XLearner",
    ],
    refute_methods: None
    | list[str]
    | (
        list[
            Literal[
                "placebo_treatment_refuter", "random_common_cause", "data_subset_refuter", "add_unobserved_common_cause"
            ]
        ]
    ) = None,
    print_causal_estimate: bool = False,
    print_summary: bool = True,
    return_as: Literal["estimate", "refute", "estimate+refute"] = "estimate",
    show_graph: bool = False,
    show_refute_plots: bool | Literal["colormesh", "contour", "line"] | None = None,
    attempts: int = 10,
    *,
    identify_kwargs: dict[str, Any] | None = None,
    estimate_kwargs: dict[str, Any] | None = None,
    refute_kwargs: dict[str, Any] | None = None,
) -> (
    dowhy.CausalEstimate
    | dict[str, str | dict[str, float]]
    | tuple[dowhy.CausalEstimate, dict[str, str | dict[str, float]]]
):
    """Performs causal inference on an AnnData object and optionally refutes the estimate.

    A ``dowhy.CausalModel`` is built from ``adata.to_df()`` and ``graph``; the effect of
    ``treatment`` on ``outcome`` is identified and estimated, and every requested refuter
    is run against the estimate. If a refuter reports a p-value that is NaN or outside
    ``[0, 1]``, the whole identify/estimate/refute cycle is repeated (at most ``attempts``
    retries) with freshly drawn ``num_simulations`` / ``random_seed`` values.

    Args:
        adata: An AnnData object containing the input data.
        graph: The causal graph, either a networkx DiGraph or a str; it is passed on to ``dowhy.CausalModel``.
        treatment: A str representing the treatment variable in the causal graph.
        outcome: A str representing the outcome variable in the causal graph.
        estimation_method: Name of the estimation method (required). It must contain a ``"."``; names with
            more than one ``"."`` are only accepted if they contain ``"dowhy"`` or ``"_estimator"``.
        refute_methods: Refutation method name(s). A single str is accepted as well as a list.
            Defaults to all of "placebo_treatment_refuter", "random_common_cause", "data_subset_refuter"
            and "add_unobserved_common_cause".
        print_causal_estimate: Whether to print the causal estimate or not, default is False.
        print_summary: Whether to print the causal model summary or not, default is True.
        return_as: What to return: "estimate", "refute" or "estimate+refute". Defaults to "estimate".
        show_graph: Whether to display the graph or not, default is False.
        show_refute_plots: ``None``/``False`` sets the refuters' ``plotmethod`` to ``None``, ``True`` sets it
            to "colormesh", and a str is used as ``plotmethod`` directly. Defaults to None.
        attempts: Maximum number of retries after a refuter returned an invalid p-value, default is 10.
        identify_kwargs: Optional keyword arguments for dowhy.CausalModel.identify_effect().
        estimate_kwargs: Optional keyword arguments for dowhy.CausalModel.estimate_effect().
        refute_kwargs: Optional keyword arguments for dowhy.CausalModel.refute_estimate(). The passed dict is
            not modified. ``num_simulations`` is drawn randomly unless given; ``random_seed`` is drawn
            randomly unless ``random_seed`` or ``random_state`` is given.

    Returns:
        Depending on ``return_as``: the causal estimate ("estimate"), a dictionary of refutation results
        ("refute"), or a tuple of both ("estimate+refute"). In the dictionary, a refuter that raised a
        ``ValueError`` is stored under its method name with the error message as value; a successful
        refuter is stored under its refutation type with a dict of its results.

    Raises:
        TypeError: If graph is neither a networkx DiGraph nor a str, or if refute_methods is a list that
            contains non-str entries.
        ValueError: If refute_methods contains an unknown value, if return_as is an unknown value, or if
            estimation_method is not supported.

    Examples:
        >>> data = dowhy.datasets.linear_dataset(
        ...     beta=10,
        ...     num_common_causes=5,
        ...     num_instruments=2,
        ...     num_samples=1000,
        ...     treatment_is_binary=True,
        ... )
        >>> estimate = ep.tl.causal_inference(
        ...     adata=anndata.AnnData(data["df"]),
        ...     graph=data["gml_graph"],
        ...     treatment="v0",
        ...     outcome="y",
        ...     estimation_method="backdoor.linear_regression",
        ...     return_as="estimate",
        ...     show_graph=True,
        ...     show_refute_plots=True,
        ... )
        >>> ep.tl.plot_causal_effect(estimate)
    """
    # ---- input validation -------------------------------------------------------
    if not isinstance(graph, (nx.DiGraph, str)):
        raise TypeError("Input graph must be a networkx DiGraph or string.")

    valid_refute_methods = [
        "placebo_treatment_refuter",
        "random_common_cause",
        "data_subset_refuter",
        "add_unobserved_common_cause",
    ]
    if refute_methods is None:
        refute_methods = list(valid_refute_methods)
    if isinstance(refute_methods, str):
        refute_methods = [refute_methods]
    if isinstance(refute_methods, list):
        if not all(isinstance(rm, str) for rm in refute_methods):
            raise TypeError("When parameter 'refute_methods' is a list, all of them must be strings.")
    for method in refute_methods:
        if method not in valid_refute_methods:
            raise ValueError(f"Unknown refute method {method}")

    if return_as not in ["estimate", "refute", "estimate+refute"]:
        raise ValueError(f"Unknown value for return_as '{return_as}': {return_as}")

    # Checked here because `dowhy` does not validate the name itself; it is
    # loop-invariant, so it is done once up front rather than on every attempt.
    # NOTE(review): the econml names listed in the signature's Literal have more than
    # one "." and contain neither "dowhy" nor "_estimator", so they are rejected here.
    _validate_estimation_method(estimation_method)

    # Copy the keyword dicts: entries are added below and the caller's dicts must
    # not be modified as a side effect.
    identify_kwargs = dict(identify_kwargs or {})
    estimate_kwargs = dict(estimate_kwargs or {})
    refute_kwargs = dict(refute_kwargs or {})

    if show_refute_plots is None or show_refute_plots is False:
        refute_kwargs["plotmethod"] = None
    elif isinstance(show_refute_plots, str):
        refute_kwargs["plotmethod"] = show_refute_plots
    elif show_refute_plots is True:
        refute_kwargs["plotmethod"] = "colormesh"

    # Only randomise what the user did not pin. Both seed spellings are honoured so
    # that a user-supplied `random_seed` is not overwritten on every attempt.
    user_gave_num_simulations = "num_simulations" in refute_kwargs
    user_gave_random_seed = "random_seed" in refute_kwargs or "random_state" in refute_kwargs

    model = dowhy.CausalModel(data=adata.to_df(), graph=graph, treatment=treatment, outcome=outcome)

    if show_graph:
        model.view_model()

    # For some reason, dowhy sometimes fails to calculate a pval
    # and spits out NaN or values greater than 1. In that case we just try again.
    failed_attempts = 0
    while True:
        if not user_gave_num_simulations:
            refute_kwargs["num_simulations"] = np.random.randint(70, 90)
        if not user_gave_random_seed:
            refute_kwargs["random_seed"] = np.random.randint(0, 100)

        identified_estimand = model.identify_effect(**identify_kwargs)

        # otherwise prints estimation_method
        with capture_output():
            estimate = model.estimate_effect(identified_estimand, method_name=estimation_method, **estimate_kwargs)

        refute_results: dict[str, str | dict[str, str]] = {}
        # A retry is requested only by an invalid p-value. Refuters that raise, or an
        # empty list of refuters, therefore end the loop instead of spinning forever.
        retry = False
        for method in refute_methods:
            try:
                with capture_output():
                    refute = model.refute_estimate(
                        identified_estimand, estimate, method_name=method, verbose=False, **refute_kwargs
                    )
            except ValueError as e:
                refute_results[method] = str(e)  # type: ignore
                logger.warning(f"Refutation '{method}' failed.")
                continue

            # only returns dict when pval should be a number
            if isinstance(refute.refutation_result, dict):
                p_value = refute.refutation_result["p_value"]
                # Written as `not (0 <= p <= 1)` so that NaN counts as invalid too.
                if not 0 <= p_value <= 1:
                    failed_attempts += 1
                    if failed_attempts <= attempts:
                        logger.warning(
                            f"Refutation '{method}' returned invalid pval '{p_value!s}', retrying ({failed_attempts}/{attempts})"
                        )
                        retry = True
                        break
                    # Retry budget exhausted: keep the result as it is.

            # NOTE(review): "test_significance" is filled with the estimated effect,
            # exactly like "Estimated effect" -- confirm this is the intended value.
            test_significance = refute.estimated_effect

            # Try to extract pval, fails for "add_unobserved_common_cause" refuter
            try:
                pval = f"{refute.refutation_result['p_value']:.3f}"
            except TypeError:
                pval = "Not applicable"

            # Format effect, can be list when refuter is "add_unobserved_common_cause"
            if isinstance(refute.new_effect, (list, tuple)):
                new_effect = ", ".join([str(np.round(x, 2)) for x in refute.new_effect])
            else:
                new_effect = f"{refute.new_effect:.3f}"

            refute_results[str(refute.refutation_type)] = {
                "Estimated effect": refute.estimated_effect,
                "New effect": new_effect,
                "p-value": pval,
                "test_significance": test_significance,
            }

        if not retry:
            break

    # Create the summary string
    summary = _summarize_causal_inference(treatment, outcome, estimate, refute_results)

    if print_causal_estimate:
        print(estimate)

    if print_summary:
        print(summary)

    # `return_as` was validated above, so exactly one of the three forms applies.
    if return_as == "estimate":
        return estimate
    if return_as == "refute":
        return refute_results  # type: ignore
    return estimate, refute_results  # type: ignore