Source code for ehrapy.tools.causal._dowhy

from __future__ import annotations

import sys
import warnings
from io import StringIO
from typing import TYPE_CHECKING, Any, Literal

import dowhy
import networkx as nx
import numpy as np
from lamin_utils import logger

if TYPE_CHECKING:
    import anndata

warnings.filterwarnings("ignore")


class capture_output(list):
    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        return self

    def __exit__(self, *args):
        self.extend(self._stringio.getvalue().splitlines())
        del self._stringio  # free up some memory
        sys.stdout = self._stdout


[docs] def causal_inference( adata: anndata.AnnData, graph: nx.DiGraph | str, treatment: str, outcome: str, estimation_method: Literal[ "backdoor.propensity_score_matching", "backdoor.propensity_score_stratification", "backdoor.propensity_score_weighting", "backdoor.linear_regression", "backdoor.generalized_linear_model", "iv.instrumental_variable", "iv.regression_discontinuity", "backdoor.econml.linear_model.LinearDML", "backdoor.econml.nonparametric_model.NonParamDML", "backdoor.econml.causal_forest.CausalForestDML", "backdoor.econml.forecast_model.ForestDML", "backdoor.econml.dml.DML", "backdoor.econml.dml.DMLCate", "backdoor.econml.xgboost.XGBTRegressor", "backdoor.econml.xgboost.XGBTEstimator", "backdoor.econml.metalearners.XLearner", ], refute_methods: None | list[str] | ( list[ Literal[ "placebo_treatment_refuter", "random_common_cause", "data_subset_refuter", "add_unobserved_common_cause" ] ] ) = None, print_causal_estimate: bool = False, print_summary: bool = True, return_as: Literal["estimate", "refute", "estimate+refute"] = "estimate", show_graph: bool = False, show_refute_plots: bool | Literal["colormesh", "contour", "line"] | None = None, attempts: int = 10, *, identify_kwargs: dict[str, Any] | None = None, estimate_kwargs: dict[str, Any] | None = None, refute_kwargs: dict[str, Any] | None = None, ) -> tuple[dowhy.CausalEstimate, dict[str, str | dict[str, float]]]: """ Performs causal inference on an AnnData object using the specified causal model and returns a tuple containing the causal estimate and the results of any refutation tests. Args: adata: An AnnData object containing the input data. graph: A str representing the causal graph to use. treatment: A str representing the treatment variable in the causal graph. outcome: A str representing the outcome variable in the causal graph. estimation_method: An optional Literal specifying the estimation method to use. Defaults to "backdoor.propensity_score_stratification". refute_methods: An optional List of Literal specifying the methods to use for refutation tests. Defaults to ["placebo_treatment_refuter", "random_common_cause", "data_subset_refuter"]. print_causal_estimate: Whether to print the causal estimate or not, default is False. print_summary: Whether to print the causal model summary or not, default is True. return_as: An optional Literal specifying the type of output to return. Defaults to "summary". show_graph: Whether to display the graph or not, default is False. show_refute_plots: Whether to display the refutation plots or not, default is False. attempts: Number of attempts to try to generate a valid causal estimate, default is 10. identify_kwargs: Optional keyword arguments for dowhy.CausalModel.identify_effect(). estimate_kwargs: Optional keyword arguments for dowhy.CausalModel.estimate_effect(). refute_kwargs: Optional keyword arguments for dowhy.CausalModel.refute_estimate(). Returns: A tuple containing the causal estimate and a dictionary of the results of any refutation tests. Raises: TypeError: If adata, graph, treatment, outcome, refute_methods, estimation_method, or return_as is not of the expected type. ValueError: If refute_methods or estimation_method contains an unknown value, or if return_as is an unknown value. Examples: >>> data = dowhy.datasets.linear_dataset( ... beta=10, ... num_common_causes=5, ... num_instruments=2, ... num_samples=1000, ... treatment_is_binary=True, ... ) >>> ci = ep.tl.causal_inference( ... adata=anndata.AnnData(data["df"]), ... graph=data["gml_graph"], ... treatment="v0", ... outcome="y", ... estimation_method="backdoor.propensity_score_stratification", ... ) >>> estimate = ep.tl.causal_inference( ... adata=ci.linear_data, ... graph=ci.linear_graph, ... treatment="treatment", ... outcome="outcome", ... estimation_method="backdoor.linear_regression", ... return_as="estimate", ... show_graph=True, ... show_refute_plots=True, ... ) ... ep.tl.plot_causal_effect(estimate) """ if not isinstance(graph, (nx.DiGraph, str)): raise TypeError("Input graph must be a networkx DiGraph or string.") valid_refute_methods = [ "placebo_treatment_refuter", "random_common_cause", "data_subset_refuter", "add_unobserved_common_cause", ] if refute_methods is None: refute_methods = valid_refute_methods if isinstance(refute_methods, str): refute_methods = [refute_methods] if isinstance(refute_methods, list): if not all(isinstance(rm, str) for rm in refute_methods): raise TypeError("When parameter 'refute_methods' is a list, all of them must be strings.") for method in refute_methods: if method not in valid_refute_methods: raise ValueError(f"Unknown refute method {method}") if return_as not in ["estimate", "refute", "estimate+refute"]: raise ValueError(f"Unknown value for return_as '{return_as}': {return_as}") identify_kwargs = identify_kwargs or {} estimate_kwargs = estimate_kwargs or {} refute_kwargs = refute_kwargs or {} if show_refute_plots is None or show_refute_plots is False: refute_kwargs["plotmethod"] = None elif isinstance(show_refute_plots, str): refute_kwargs["plotmethod"] = show_refute_plots elif show_refute_plots is True: refute_kwargs["plotmethod"] = "colormesh" user_gave_num_simulations = "num_simulations" in refute_kwargs user_gave_random_seed = "random_state" in refute_kwargs found_problematic_pvalues = True model = dowhy.CausalModel(data=adata.to_df(), graph=graph, treatment=treatment, outcome=outcome) if show_graph: model.view_model() # For some reason, dowhy sometimes fails to calculate a pval # and spits out NaN or values greater than 1. In that case we just try again. failed_attempts = 0 while found_problematic_pvalues: if not user_gave_num_simulations: refute_kwargs["num_simulations"] = np.random.randint(70, 90) if not user_gave_random_seed: refute_kwargs["random_seed"] = np.random.randint(0, 100) identified_estimand = model.identify_effect(**identify_kwargs) # otherwise prints estimation_method with capture_output() as _: # input validation since `dowhy` does not do it if "." not in estimation_method: raise ValueError(f"Estimation method '{estimation_method}' not supported.") else: if len(estimation_method.split(".")) > 2: if not any(["dowhy" in estimation_method, "_estimator" in estimation_method]): raise ValueError(f"Estimation method '{estimation_method}' not supported.") estimate = model.estimate_effect(identified_estimand, method_name=estimation_method, **estimate_kwargs) refute_results: dict[str, str | dict[str, str]] = {} for method in refute_methods: try: with capture_output() as _: refute = model.refute_estimate( identified_estimand, estimate, method_name=method, verbose=False, **refute_kwargs ) refute_failed = False except ValueError as e: refute_failed = True refute_results[method] = str(e) # type: ignore if refute_failed: logger.warning(f"Refutation '{method}' failed.") else: # only returns dict when pval should be a number if isinstance(refute.refutation_result, dict): if 0 <= refute.refutation_result["p_value"] <= 1: found_problematic_pvalues = False else: failed_attempts += 1 if failed_attempts <= attempts: found_problematic_pvalues = True logger.warning( f"Refutation '{method}' returned invalid pval '{str(refute.refutation_result['p_value'])}', retrying ({failed_attempts}/{attempts})" ) break else: found_problematic_pvalues = False else: found_problematic_pvalues = False if not refute_failed: test_significance = refute.estimated_effect # Try to extract pval, fails for "add_unobserved_common_cause" refuter try: pval = f"{refute.refutation_result['p_value']:.3f}" except TypeError: pval = "Not applicable" # Format effect, can be list when refuter is "add_unobserved_common_cause" if isinstance(refute.new_effect, (list, tuple)): new_effect = ", ".join([str(np.round(x, 2)) for x in refute.new_effect]) else: new_effect = f"{refute.new_effect:.3f}" refute_results[str(refute.refutation_type)] = { "Estimated effect": refute.estimated_effect, "New effect": new_effect, "p-value": pval, "test_significance": test_significance, } # Create the summary string summary = f"Causal inference results for treatment variable '{treatment}' and outcome variable '{outcome}':\n" with capture_output() as output: estimate.interpret(method_name="textual_effect_interpreter") if output is not None: summary += f"└- {''.join(output)}\n" else: summary += f"└- Estimated effect: {estimate.value}\n" summary += "\nRefutation results\n" for idx, (method, results) in enumerate(refute_results.items()): # type: ignore left_char = "|" if (idx + 1) != len(refute_results.keys()) else " " branch_char = "├" if (idx + 1) != len(refute_results.keys()) else "└" if isinstance(results, str): summary += f"├-Refute: {method}\n" summary += f"{left_char} └- {results}\n" else: summary += f"{branch_char}-{method}\n" summary += f"{left_char} ├- Estimated effect: {results['Estimated effect']:.2f}\n" summary += f"{left_char} ├- New effect: {results['New effect']}\n" summary += f"{left_char} ├- p-value: {results['p-value']}\n" summary += f"{left_char} └- Test significance: {results['test_significance']:.2f}\n" if print_causal_estimate: print(estimate) if print_summary: print(summary) if return_as == "estimate": return estimate elif return_as == "refute": return refute_results # type: ignore elif return_as == "estimate+refute": return estimate, refute_results # type: ignore else: raise ValueError(f"Invalid return_as argument: {return_as}")