import dowhy
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
def _unwrap_singleton(value):
    """Return the lone element of a one-element NumPy array, else ``value`` unchanged."""
    if isinstance(value, np.ndarray) and value.size == 1:
        return value.item()
    return value


def causal_effect(estimate: dowhy.causal_estimator.CausalEstimate, precision: int = 3) -> matplotlib.axes.Axes:
    """Plot a linear-regression causal effect estimate over the observed data.

    The treatment and outcome column names are taken from the first entry of
    the estimand's ``treatment_variable`` / ``outcome_variable`` lists, and the
    observations are read from the data attached to ``estimate``. The
    observations are drawn as gray dots and the fitted line
    ``outcome = intercept + slope * treatment`` is drawn in black, where the
    slope is ``estimate.value`` and the intercept is
    ``estimate.params["intercept"]``.

    The axes start at 0, or at the smallest observed value when that value is
    negative, so that negative observations are not clipped out of view.

    Args:
        estimate: The causal effect estimate to plot.
        precision: The number of decimal places to round the estimate to in
            the plot title. Defaults to 3.

    Returns:
        The matplotlib Axes object containing the plot.

    Raises:
        TypeError: If ``estimate`` is not an instance of
            ``dowhy.causal_estimator.CausalEstimate``.
        ValueError: If the estimation method in ``estimate`` is not supported
            for this plot type (only ``LinearRegressionEstimator`` is).
    """
    if not isinstance(estimate, dowhy.causal_estimator.CausalEstimate):
        raise TypeError("Parameter 'estimate' must be a dowhy.causal_estimator.CausalEstimate object")

    estimator_class = estimate.params["estimator_class"]
    if "LinearRegressionEstimator" not in str(estimator_class):
        raise ValueError(f"Estimation method {estimator_class} is not supported for this plot type.")

    # NOTE(review): relies on private attributes of the estimate/estimator
    # objects; confirm they are still populated in the dowhy version in use.
    target_estimand = estimate.estimator._target_estimand
    treatment_name = target_estimand.treatment_variable[0]
    outcome_name = target_estimand.outcome_variable[0]
    data = estimate._data
    treatment = data[treatment_name].values
    outcome = data[outcome_name]

    # The slope/intercept may arrive wrapped in one-element arrays. Unwrap them
    # once so that both the line end points and the rounded title value are
    # plain scalars (round() does not accept an ndarray).
    slope = _unwrap_singleton(estimate.value)
    intercept = _unwrap_singleton(estimate.params["intercept"])

    # Start the line at 0 unless the data extends below it; for non-negative
    # data this yields the line from (0, intercept) to the largest treatment.
    x_min = min(0, _unwrap_singleton(min(treatment)))
    x_max = _unwrap_singleton(max(treatment))
    y_start = intercept + slope * x_min
    y_end = intercept + slope * x_max

    fig, ax = plt.subplots()
    ax.scatter(treatment, outcome, c="gray", marker="o", label="Observed data")
    ax.plot([x_min, x_max], [y_start, y_end], c="black", ls="solid", lw=4, label="Causal variation")
    # Lower bound follows the data when it is negative; otherwise stays at 0.
    ax.set_ylim(min(0, min(outcome)), max(outcome))
    ax.set_xlim(x_min, x_max)
    ax.set_title(rf"DoWhy estimate $\rho$ (slope) = {round(slope, precision)}")
    ax.legend(loc="upper left")
    ax.set_xlabel(treatment_name)
    ax.set_ylabel(outcome_name)
    # Lay out the figure created here rather than whatever pyplot considers current.
    fig.tight_layout()
    return ax