from __future__ import annotations
from typing import TYPE_CHECKING
import holoviews as hv
import numpy as np
import pandas as pd
from bokeh.palettes import Category10
from numpy import ndarray
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from typing import Any
from xmlrpc.client import Boolean
from ehrdata import EHRData
from lifelines import KaplanMeierFitter
from statsmodels.regression.linear_model import RegressionResults
[docs]
def ols(
    edata: EHRData | None = None,
    *,
    x: str | None = None,
    y: str | None = None,
    scatter_plot: bool | None = True,
    ols_results: list[RegressionResults] | None = None,
    ols_color: list[str | None] | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    width: int | None = 600,
    height: int | None = 400,
    lines: list[tuple[ndarray | float, ndarray | float]] | None = None,
    lines_color: list[str | None] | None = None,
    lines_style: list[str | None] | None = None,
    lines_label: list[str | None] | None = None,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    title: str | None = None,
    **kwds,
) -> hv.Scatter | hv.Curve | hv.Overlay | None:
    """Plots an Ordinary Least Squares (OLS) Model result, scatter plot, and line plot.

    Args:
        edata: Central data object.
        x: x coordinate, for scatter plotting.
        y: y coordinate, for scatter plotting.
        scatter_plot: Whether to show a scatter plot.
        ols_results: List of RegressionResults from ehrapy.tl.ols.
            Only drawn when `edata`, `x` and `y` are all given.
        ols_color: List of colors for each ols_results.
        xlabel: The x-axis label text.
        ylabel: The y-axis label text.
        width: Plot width in pixels.
        height: Plot height in pixels.
        lines: List of Tuples of (slope, intercept) or (x, y). Plot lines by slope and intercept or data points.
            A tuple of two scalars is treated as (slope, intercept); anything else as (x, y) data.
            Example: plot two lines (y = x + 2 and y = 2*x + 1): [(1, 2), (2, 1)]
        lines_color: List of colors for each line.
        lines_style: List of line styles for each line ("-", "--", ":" or "-."; unknown values are drawn solid).
        lines_label: List of line labels for each line.
        xlim: Set the x-axis view limits. Slope/intercept lines are drawn across this range.
            When omitted while `lines` is given, the x-range of the plotted data is used, or (0, 1)
            if nothing was plotted from `edata`; that fallback is applied as the axis limit as well.
        ylim: Set the y-axis view limits.
        title: Set the title of the plot.
        **kwds: Passed to HoloViews Scatter element.

    Returns:
        The scatter plot, a single curve, or an overlay of all drawn elements.
        `None` if there was nothing to draw.

    Raises:
        ValueError: If `ols_color` has fewer entries than `ols_results`, or if `lines_color`,
            `lines_style` or `lines_label` has fewer entries than `lines`.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.mimic_2()
        >>> co2_lm_result = ep.tl.ols(
        ...     edata, var_names=["pco2_first", "tco2_first"], formula="tco2_first ~ pco2_first", missing="drop"
        ... ).fit()
        >>> ep.pl.ols(
        ...     edata,
        ...     x="pco2_first",
        ...     y="tco2_first",
        ...     ols_results=[co2_lm_result],
        ...     ols_color=["red"],
        ...     xlabel="PCO2",
        ...     ylabel="TCO2",
        ... )

        .. image:: /_static/docstring_previews/ols_plot.png
    """
    # Validate the per-item option lists up front so a short list fails with a clear message
    # instead of an IndexError halfway through building the plot.
    n_ols = 0 if ols_results is None else len(ols_results)
    n_lines = 0 if lines is None else len(lines)
    ols_color = _ols_option_list("ols_color", ols_color, n_ols, "ols_results")
    lines_color = _ols_option_list("lines_color", lines_color, n_lines, "lines")
    lines_style = _ols_option_list("lines_style", lines_style, n_lines, "lines")
    lines_label = _ols_option_list("lines_label", lines_label, n_lines, "lines")

    plot = None
    data_xlim = None  # x-range of the cleaned data, used as fallback range for slope/intercept lines

    if edata is not None and x is not None and y is not None:
        x_data = np.array(edata[:, x].X).flatten().astype(float)
        y_data = np.array(edata[:, y].X).flatten().astype(float)
        # Keep only observations where both coordinates are present.
        mask = ~(np.isnan(x_data) | np.isnan(y_data))
        x_clean = x_data[mask]
        y_clean = y_data[mask]
        if x_clean.size:
            data_xlim = (float(x_clean.min()), float(x_clean.max()))

        if scatter_plot:
            scatter_opts = {**kwds}
            scatter_opts.setdefault("tools", ["hover"])
            plot = hv.Scatter((x_clean, y_clean), kdims=x, vdims=y).opts(**scatter_opts)

        if ols_results is not None:
            x_sorted = np.sort(x_clean)
            for ols_result, result_color in zip(ols_results, ols_color):
                y_pred = ols_result.predict(exog={"const": 1, x: x_sorted})
                curve_opts: dict[str, Any] = {"tools": ["hover"]}
                if result_color is not None:
                    curve_opts["color"] = result_color
                ols_curve = hv.Curve((x_sorted, y_pred)).opts(**curve_opts)
                plot = ols_curve if plot is None else plot * ols_curve

    if lines is not None:
        if xlim is None:
            # Take the range straight from the data rather than from the plot: the OLS curves
            # do not carry the `x` dimension name, so a plot without the scatter layer could
            # not report its x-range and the lines collapsed onto (0, 1).
            xlim = data_xlim if (plot is not None and data_xlim is not None) else (0, 1)

        style_map = {"-": "solid", "--": "dashed", ":": "dotted", "-.": "dashdot"}
        for (a, b), line_color, line_style, line_label in zip(lines, lines_color, lines_style, lines_label):
            if np.ndim(a) == 0 and np.ndim(b) == 0:
                # Two scalars: (slope, intercept), drawn across the x-limits.
                line_x = np.array(xlim)
                line_y = a * line_x + b
            else:
                # Otherwise: explicit (x, y) data points.
                line_x, line_y = a, b

            curve_opts = {"tools": ["hover"]}
            if line_color is not None:
                curve_opts["color"] = line_color
            if line_style is not None:
                curve_opts["line_dash"] = style_map.get(line_style, "solid")
            curve_kwargs = {"label": line_label} if line_label else {}
            line_curve = hv.Curve((line_x, line_y), **curve_kwargs).opts(**curve_opts)
            plot = line_curve if plot is None else plot * line_curve

    if plot is None:
        return None

    opts_dict: dict[str, Any] = {}
    if width is not None:
        opts_dict["width"] = width
    if height is not None:
        opts_dict["height"] = height
    if xlabel is not None:
        opts_dict["xlabel"] = xlabel
    if ylabel is not None:
        opts_dict["ylabel"] = ylabel
    if title is not None:
        opts_dict["title"] = title
    if xlim is not None:
        opts_dict["xlim"] = xlim
    if ylim is not None:
        opts_dict["ylim"] = ylim
    plot = plot.opts(**opts_dict)
    return plot


def _ols_option_list(name: str, values: list | None, n_items: int, items_name: str) -> list:
    """Return a per-item option list: all-``None`` when unset, rejected when shorter than the items it styles."""
    if values is None:
        return [None] * n_items
    if len(values) < n_items:
        raise ValueError(f"`{name}` has {len(values)} entries, but `{items_name}` has {n_items}.")
    return values
[docs]
def kaplan_meier(
    kmfs: Sequence[KaplanMeierFitter],
    *,
    display_survival_statistics: bool = False,
    ci_alpha: list[float] | None = None,
    ci_force_lines: list[bool] | None = None,
    ci_show: list[bool] | None = None,
    ci_legend: list[bool] | None = None,
    at_risk_counts: list[bool] | None = None,
    color: list[str | None] | None = None,
    grid: bool | None = False,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    width: int | None = 600,
    height: int | None = 400,
    title: str | None = None,
) -> hv.Layout | hv.Overlay | hv.Curve | None:
    """Plots a pretty figure of the Fitted KaplanMeierFitter model.

    See also: :class:`~lifelines.fitters.kaplan_meier_fitter.KaplanMeierFitter`.

    Args:
        kmfs: Iterables of fitted KaplanMeierFitter objects.
        display_survival_statistics: Whether to show survival statistics in a table below the plot.
        ci_alpha: The transparency level of the confidence interval. If more than one kmfs, this should be a list.
        ci_force_lines: Force the confidence intervals to be line plots (versus default shaded areas).
            If more than one kmfs, this should be a list.
        ci_show: Show confidence intervals. If more than one kmfs, this should be a list.
        ci_legend: If ci_force_lines is True, this is a boolean flag to add the lines' labels to the legend.
            If more than one kmfs, this should be a list.
            Accepted but not read by this implementation.
        at_risk_counts: Show group sizes at time points. If more than one kmfs, this should be a list.
            Accepted but not read by this implementation.
        color: List of colors for each kmf. If more than one kmfs, this should be a list.
        grid: If True, plot grid lines.
        xlim: Set the x-axis view limits. Also the time range covered by the survival statistics table.
        ylim: Set the y-axis view limits.
        xlabel: The x-axis label text.
        ylabel: The y-axis label text.
        width: Plot width in pixels. `None` leaves the width unset.
        height: Plot height in pixels; the statistics table gets 40% of it. `None` leaves both heights unset.
        title: Set the title of the plot.

    Returns:
        The survival curve(s), stacked above a statistics table when `display_survival_statistics` is set.
        `None` if `kmfs` is empty.

    Raises:
        ValueError: If `ci_alpha`, `ci_force_lines`, `ci_show` or `color` has fewer entries than `kmfs`.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> import numpy as np
        >>> edata = ed.dt.mimic_2()
        >>> edata[:, ["censor_flg"]].X = np.where(
        ...     edata[:, ["censor_flg"]].X == 0, 1, 0
        ... )  # MIMIC-II uses 0=death while KaplanMeierFitter expects True=death
        >>> kmf = ep.tl.kaplan_meier(edata, "mort_day_censored", "censor_flg")
        >>> ep.pl.kaplan_meier(
        ...     [kmf], color=["r"], xlim=(0, 700), ylim=(0, 1), xlabel="Days", ylabel="Proportion Survived"
        ... )

        .. image:: /_static/docstring_previews/kaplan_meier.png
    """
    n_curves = len(kmfs)
    ci_alpha = _km_option_list("ci_alpha", ci_alpha, n_curves, 0.3)
    ci_force_lines = _km_option_list("ci_force_lines", ci_force_lines, n_curves, False)
    ci_show = _km_option_list("ci_show", ci_show, n_curves, True)
    color = _km_option_list("color", color, n_curves, None)

    # One label per fitter, shared by the legend and the statistics table so both agree.
    group_labels = [kmf.label if kmf.label else f"Group {i + 1}" for i, kmf in enumerate(kmfs)]

    plot = None
    for i, (kmf, label) in enumerate(zip(kmfs, group_labels)):
        layer = _km_curve_layer(kmf, label, color[i], ci_show[i], ci_force_lines[i], ci_alpha[i])
        plot = layer if plot is None else plot * layer

    if plot is None:
        return None

    opts_dict: dict[str, Any] = {"show_grid": bool(grid), "show_legend": True}
    if width is not None:
        opts_dict["width"] = width
    if height is not None:
        # Previously the height was only used for the table, never for the plot itself.
        opts_dict["height"] = height
    if xlabel:
        opts_dict["xlabel"] = xlabel
    if ylabel:
        opts_dict["ylabel"] = ylabel
    if title:
        opts_dict["title"] = title
    if xlim:
        opts_dict["xlim"] = xlim
    if ylim:
        opts_dict["ylim"] = ylim
    plot = plot.opts(**opts_dict)

    if display_survival_statistics:
        table = _km_statistics_table(kmfs, group_labels, xlim, width, height)
        plot = (plot + table).cols(1)

    return plot


def _km_option_list(name: str, values: list | None, n_curves: int, default: Any) -> list:
    """Return a per-curve option list: filled with `default` when unset, rejected when shorter than `kmfs`."""
    if values is None:
        return [default] * n_curves
    if len(values) < n_curves:
        raise ValueError(f"`{name}` has {len(values)} entries, but {n_curves} fitters were given.")
    return values


def _km_curve_layer(kmf, label: str, color: str | None, show_ci: bool, ci_as_lines: bool, ci_alpha: float):
    """Build the survival curve of one fitter, combined with its confidence interval if requested."""
    sf = kmf.survival_function_
    times = sf.index.values
    survival = sf.iloc[:, 0].values

    curve_opts: dict[str, Any] = {"tools": ["hover"]}
    if color is not None:
        curve_opts["color"] = color
    curve = hv.Curve((times, survival), kdims="Time", vdims="Survival", label=label).opts(**curve_opts)

    if not (show_ci and hasattr(kmf, "confidence_interval_survival_function_")):
        return curve

    ci = kmf.confidence_interval_survival_function_
    ci_lower = ci.iloc[:, 0].values
    ci_upper = ci.iloc[:, 1].values
    ci_color = color if color else "gray"

    if ci_as_lines:
        bound_opts = {"color": ci_color, "alpha": ci_alpha, "line_dash": "dashed"}
        ci_lower_curve = hv.Curve((times, ci_lower)).opts(**bound_opts)
        ci_upper_curve = hv.Curve((times, ci_upper)).opts(**bound_opts)
        return curve * ci_lower_curve * ci_upper_curve

    ci_area = hv.Area((times, ci_lower, ci_upper), vdims=["y", "y2"]).opts(
        color=ci_color, alpha=ci_alpha, line_width=0
    )
    # The band goes first in the overlay, the curve after it.
    return ci_area * curve


def _km_time_labels(time_points) -> list[str]:
    """Format distinct time points as unique column headers, adding decimals only when needed."""
    for decimals in range(10):
        labels = [f"{t:.{decimals}f}" for t in time_points]
        if len(set(labels)) == len(labels):
            return labels
    # repr() of distinct floats is always distinct.
    return [repr(float(t)) for t in time_points]


def _km_statistics_table(kmfs, group_labels: list[str], xlim, width: int | None, height: int | None):
    """Build the wide-format survival table: one row per group, one column per time point."""
    if xlim:
        start, stop = xlim[0], xlim[1]
    else:
        all_times = np.concatenate([kmf.survival_function_.index.values for kmf in kmfs])
        start, stop = all_times.min(), all_times.max()
    # Drop repeated time points (degenerate range) while keeping their order.
    time_points = pd.unique(np.linspace(start, stop, 10))
    # Whole-number headers as before, unless they would collide (e.g. a 0..1 range), which used
    # to produce columns of unequal length and a pandas error.
    column_labels = _km_time_labels(time_points)

    # Create table data in wide format (one row per group, columns are time points)
    table_data: dict[str, list[str]] = {"Group": list(group_labels)}
    for column in column_labels:
        table_data[column] = []
    for kmf in kmfs:
        survival_probs = kmf.survival_function_at_times(time_points).values
        for column, prob in zip(column_labels, survival_probs):
            table_data[column].append(f"{prob:.2f}")

    table_opts: dict[str, Any] = {"fit_columns": True}
    if width is not None:
        table_opts["width"] = width
    if height is not None:
        table_opts["height"] = int(height * 0.4)
    return hv.Table(pd.DataFrame(table_data)).opts(**table_opts)
[docs]
def cox_ph_forestplot(
    edata: EHRData,
    *,
    uns_key: str = "cox_ph",
    labels: Iterable[str] | None = None,
    width: int = 1200,
    height: int = 600,
    ecolor: str = "dimgray",
    size: int = 3,
    marker: str = "o",
    decimal: int = 2,
    text_size: int = 12,
    color: str = "k",
    title: str | None = None,
) -> hv.Overlay | None:
    """Generates a forest plot to visualize the coefficients and confidence intervals of a Cox Proportional Hazards model.

    The `edata` object must first be populated using the :func:`~ehrapy.tools.cox_ph` function.
    This function stores the summary table of the `CoxPHFitter` in the `.uns` attribute of `edata`.
    The summary table is created when the model is fitted using the :func:`~ehrapy.tools.cox_ph` function.
    See also: :class:`~lifelines.fitters.coxph_fitter.CoxPHFitter`

    Args:
        edata: Data object containing the summary table from the CoxPHFitter. This is stored in the `.uns` attribute, after fitting the model using :func:`~ehrapy.tools.cox_ph`.
        uns_key: Key in `.uns` where :func:`~ehrapy.tools.cox_ph` function stored the summary table. See argument `uns_key` in :func:`~ehrapy.tools.cox_ph`.
        labels: List of labels for each coefficient, default uses the index of the summary table.
            Must contain exactly one label per row of the summary table.
        width: Plot width in pixels.
        height: Plot height in pixels.
        ecolor: Color of the error bars.
        size: Size of the markers.
        marker: Marker style.
        decimal: Number of decimal places to display.
        text_size: Font size of the text.
        color: Color of the markers.
        title: Set the title of the plot.

    Returns:
        HoloViews Overlay with forest plot and text annotations.

    Raises:
        ValueError: If `uns_key` is not in `edata.uns`, or if `labels` does not have one entry per coefficient.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.mimic_2()[
        ...     :, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]
        ... ]
        >>> coxph = ep.tl.cox_ph(edata, event_col="censor_flg", duration_col="mort_day_censored")
        >>> ep.pl.cox_ph_forestplot(edata)

        .. image:: /_static/docstring_previews/coxph_forestplot.png
    """
    if uns_key not in edata.uns:
        raise ValueError(f"Key {uns_key} not found in edata.uns. Please provide a valid key.")
    coxph_fitting_summary = edata.uns[uns_key]
    n_rows = len(coxph_fitting_summary)

    labels = list(coxph_fitting_summary.index) if labels is None else list(labels)
    if len(labels) != n_rows:
        # zip() below would otherwise silently drop tick labels or rows.
        raise ValueError(f"`labels` has {len(labels)} entries, but the summary table has {n_rows} coefficients.")

    coefs = coxph_fitting_summary["coef"].values
    lower = coxph_fitting_summary["coef lower 95%"].values
    upper = coxph_fitting_summary["coef upper 95%"].values
    y_positions = np.arange(n_rows)

    # Horizontal layout: [padding | CI range | padding | gap | coef text | CI text].
    x_axis_upper_bound = float(pd.to_numeric(coxph_fitting_summary["coef upper 95%"]).max())
    x_axis_lower_bound = float(pd.to_numeric(coxph_fitting_summary["coef lower 95%"]).min())
    data_range = x_axis_upper_bound - x_axis_lower_bound
    plot_padding = data_range * 0.1
    text_gap = data_range * 0.15
    text_spacing = data_range * 0.4
    plot_x_min = x_axis_lower_bound - plot_padding
    plot_x_max = x_axis_upper_bound + plot_padding
    text_start_x = plot_x_max + text_gap
    ci_text_x = text_start_x + text_spacing
    total_x_max = ci_text_x + (data_range * 0.5)

    # (coefficient, row, distance to lower bound, distance to upper bound); incomplete rows are skipped.
    error_data = [
        (coef, y, coef - lower_val, upper_val - coef)
        for coef, y, lower_val, upper_val in zip(coefs, y_positions, lower, upper, strict=False)
        if not np.isnan(coef) and not np.isnan(lower_val) and not np.isnan(upper_val)
    ]
    # horizontal=True applies the error magnitudes to the first dimension (the coefficient), so
    # the whiskers run along the x-axis; the default would apply them to the row position.
    error_bars = hv.ErrorBars(
        error_data,
        kdims=["Coefficient", "Variable"],
        vdims=["negative_error", "positive_error"],
        horizontal=True,
    ).opts(color=ecolor, line_width=2)

    points = hv.Scatter(
        [(coef, y) for coef, y in zip(coefs, y_positions, strict=False) if not np.isnan(coef)],
        kdims=["Coefficient"],
        vdims=["Variable"],
    ).opts(color=color, size=size * 5, marker=marker, tools=["hover"])

    # The plotted quantity is the coefficient (log hazard ratio), whose "no effect" value is 0;
    # x=1 would only be the reference for hazard ratios, i.e. exp(coef).
    vline = hv.VLine(0).opts(color="gray", line_width=1)

    text_labels = []
    for coef_val, low_val, upp_val, y_pos in zip(coefs, lower, upper, y_positions, strict=False):
        if np.isnan(coef_val):
            continue
        if isinstance(coef_val, float) and isinstance(low_val, float) and isinstance(upp_val, float):
            coef_text = f"{coef_val:.{decimal}f}"
            ci_text = f"({low_val:.{decimal}f}, {upp_val:.{decimal}f})"
        else:
            # Non-float values are shown as-is rather than rounded.
            coef_text = str(coef_val)
            ci_text = f"({low_val}, {upp_val})"
        text_labels.append((text_start_x, y_pos, coef_text))
        text_labels.append((ci_text_x, y_pos, ci_text))

    labels_overlay = hv.Labels(text_labels, kdims=["x", "y"], vdims=["text"]).opts(
        text_font_size=f"{text_size}pt", text_align="left", text_color="black"
    )

    # Column headers sit just beyond the last row position.
    header_y = n_rows - 0.7
    header_labels = hv.Labels(
        [
            (text_start_x, header_y, "coef"),
            (ci_text_x, header_y, "95% CI"),
        ],
        kdims=["x", "y"],
        vdims=["text"],
    ).opts(text_font_size=f"{text_size + 2}pt", text_font_style="bold", text_align="left", text_color="black")

    forest_plot = (vline * error_bars * points * labels_overlay * header_labels).opts(
        width=width,
        height=height,
        xlim=(plot_x_min, total_x_max),
        ylim=(n_rows - 0.5, -0.5),
        invert_yaxis=True,
        yticks=list(zip(y_positions, labels, strict=False)),
        xlabel="Coefficient",
        ylabel="",
        fontsize={"yticks": text_size},
        show_legend=False,
    )

    if title:
        forest_plot = forest_plot.opts(title=title)

    return forest_plot
[docs]
def cox_ph_adjusted_curves(
    edata: EHRData,
    *,
    uns_key: str = "cox_ph_adjusted_curves",
    groups: Sequence[str] | None = None,
    palette: Sequence[str] | None = None,
    show_ci: bool = True,
    ci_alpha: float = 0.15,
    line_width: float = 2.0,
    width: int = 650,
    height: int = 450,
    title: str | None = None,
    xlabel: str = "Time",
    ylabel: str = "Adjusted Survival Probability",
    legend_position: str = "top_right",
) -> hv.Overlay:
    """Survival curve plot to visualize CoxPH-adjusted survival probabilities stratified by a grouping variable.

    The `edata` object must first be populated using :func:`~ehrapy.tools.cox_ph_adjusted_curves`.
    Mirrors the functionality of ggadjustedcurves() from the R survminer package.
    See also: Therneau, Crowson & Atkinson (2015), 'Adjusted Survival Curves':
    https://cran.r-project.org/web/packages/survival/vignettes/adjcurve.pdf

    Args:
        edata: Data object containing the adjusted survival curves in `uns` after having run :func:`~ehrapy.tools.cox_ph_adjusted_curves`.
        uns_key: Key in `.uns` where :func:`~ehrapy.tools.cox_ph_adjusted_curves` stored its output.
            See argument `key_added` in :func:`~ehrapy.tools.cox_ph_adjusted_curves`.
        groups: Subset of group labels to plot.
            If None, all groups are plotted.
        palette: List of hex or named colors, one per group.
            If it holds fewer colors than there are groups, the colors are reused in order.
            If None, a bokeh Category10 palette is used, which supports at most 10 groups.
        show_ci: Whether to draw confidence interval bands.
            Only has an effect when method='average' was used, as method='conditional' does not produce confidence intervals.
        ci_alpha: Transparency of the CI band fill between 0 and 1.
        line_width: Line width for the survival curves.
        width: Plot width in pixels.
        height: Plot height in pixels.
        title: Set the title of the plot.
            If not given, a title is generated from the method and strata stored with the curves.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        legend_position: Position of the legend.

    Returns:
        HoloViews Overlay object representing the adjusted survival curve plot.

    Raises:
        KeyError: If `uns_key` is not in `edata.uns`.
        ValueError: If a requested group is not in the stored results, if `palette` is empty,
            or if no palette is given and there are too many groups for the default palette.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.mimic_2()
        >>> cph = ep.tl.cox_ph(
        ...     edata, "mort_day_censored", "censor_flg", formula="gender_num + afib_flg + day_icu_intime_num"
        ... )
        >>> ep.tl.cox_ph_adjusted_curves(
        ...     edata,
        ...     cph,
        ...     strata="aline_flg",
        ...     duration_col="mort_day_censored",
        ...     event_col="censor_flg",
        ... )
        >>> ep.pl.cox_ph_adjusted_curves(edata)

        .. image:: /_static/docstring_previews/cox_ph_adjusted_curves.png
    """
    if uns_key not in edata.uns:
        raise KeyError(f"No adjusted curves found at edata.uns['{uns_key}']. Run ep.tl.cox_ph_adjusted_curves() first.")

    data = edata.uns[uns_key]
    meta = data.get("_meta", {})
    strata_label = meta.get("strata", "group")
    method = meta.get("method", "average")

    # Every key except the metadata entry is one group's curve.
    all_groups = [k for k in data if k != "_meta"]
    plot_groups = groups if groups is not None else all_groups

    missing = set(plot_groups) - set(all_groups)
    if missing:
        raise ValueError(f"Groups {missing} not found in results. Available: {all_groups}")

    n_groups = len(plot_groups)
    if palette is None:
        try:
            # Category10 is keyed by palette size, starting at 3.
            palette = Category10[max(3, n_groups)]
        except KeyError as err:
            raise ValueError(
                "Too many groups to assign colors automatically. Please provide a custom palette."
            ) from err
    elif n_groups and len(palette) == 0:
        raise ValueError("`palette` must contain at least one color.")

    # Cycle through the palette so that a palette shorter than the number of groups
    # reuses colors instead of silently dropping the remaining groups from the plot.
    colors = [palette[i % len(palette)] for i in range(n_groups)]

    elements = []
    for group_color, group in zip(colors, plot_groups):
        entry = data[group]
        times = entry["times"]
        survival = entry["survival"]
        ci_lower = entry.get("ci_lower")
        ci_upper = entry.get("ci_upper")

        # Survival curve
        curve = hv.Curve(
            (times, survival),
            kdims="time",
            vdims="survival",
            label=str(group),
        ).opts(
            color=group_color,
            line_width=line_width,
            tools=["hover"],
        )
        elements.append(curve)

        # CI band, only when both bounds were stored for this group
        if show_ci and ci_lower is not None and ci_upper is not None:
            band = hv.Area(
                (times, ci_lower, ci_upper),
                kdims="time",
                vdims=["survival_lower", "survival_upper"],
            ).opts(
                color=group_color,
                alpha=ci_alpha,
                line_alpha=0,
            )
            elements.append(band)

    _method_label = "population-averaged" if method == "average" else "conditional"
    _title = title or f"CoxPH-Adjusted Survival Curves ({_method_label}) stratified by {strata_label}"

    overlay = (
        hv.Overlay(elements)
        .opts(
            width=width,
            height=height,
            xlabel=xlabel,
            ylabel=ylabel,
            ylim=(-0.02, 1.05),
            xlim=(0, None),
            legend_position=legend_position,
            title=_title,
            toolbar="above",
        )
        .relabel(_title)
    )
    return overlay