Source code for ehrapy.plot._survival_analysis

from __future__ import annotations

from typing import TYPE_CHECKING

import holoviews as hv
import numpy as np
import pandas as pd
from bokeh.palettes import Category10
from numpy import ndarray

if TYPE_CHECKING:
    from collections.abc import Iterable, Sequence
    from typing import Any
    from xmlrpc.client import Boolean

    from ehrdata import EHRData
    from lifelines import KaplanMeierFitter
    from statsmodels.regression.linear_model import RegressionResults


[docs] def ols( edata: EHRData | None = None, *, x: str | None = None, y: str | None = None, scatter_plot: Boolean | None = True, ols_results: list[RegressionResults] | None = None, ols_color: list[str | None] | None = None, xlabel: str | None = None, ylabel: str | None = None, width: int | None = 600, height: int | None = 400, lines: list[tuple[ndarray | float, ndarray | float]] | None = None, lines_color: list[str | None] | None = None, lines_style: list[str | None] | None = None, lines_label: list[str | None] | None = None, xlim: tuple[float, float] | None = None, ylim: tuple[float, float] | None = None, title: str | None = None, **kwds, ) -> hv.Scatter | hv.Curve | hv.Overlay | None: """Plots an Ordinary Least Squares (OLS) Model result, scatter plot, and line plot. Args: edata: Central data object. x: x coordinate, for scatter plotting. y: y coordinate, for scatter plotting. scatter_plot: Whether to show a scatter plot. ols_results: List of RegressionResults from ehrapy.tl.ols. ols_color: List of colors for each ols_results. xlabel: The x-axis label text. ylabel: The y-axis label text. width: Plot width in pixels. height: Plot height in pixels. lines: List of Tuples of (slope, intercept) or (x, y). Plot lines by slope and intercept or data points. Example: plot two lines (y = x + 2 and y = 2*x + 1): [(1, 2), (2, 1)] lines_color: List of colors for each line. lines_style: List of line styles for each line. lines_label: List of line labels for each line. xlim: Set the x-axis view limits. Required for only plotting lines using slope and intercept. ylim: Set the y-axis view limits. Required for only plotting lines using slope and intercept. title: Set the title of the plot. **kwds: Passed to HoloViews Scatter element. Examples: >>> import ehrdata as ed >>> import ehrapy as ep >>> edata = ed.dt.mimic_2() >>> co2_lm_result = ep.tl.ols( ... edata, var_names=["pco2_first", "tco2_first"], formula="tco2_first ~ pco2_first", missing="drop" ... ).fit() >>> ep.pl.ols( ... 
edata, ... x="pco2_first", ... y="tco2_first", ... ols_results=[co2_lm_result], ... ols_color=["red"], ... xlabel="PCO2", ... ylabel="TCO2", ... ) .. image:: /_static/docstring_previews/ols_plot.png """ if ols_color is None and ols_results is not None: ols_color = [None] * len(ols_results) if lines_color is None and lines is not None: lines_color = [None] * len(lines) if lines_style is None and lines is not None: lines_style = [None] * len(lines) if lines_label is None and lines is not None: lines_label = [None] * len(lines) plot = None if edata is not None and x is not None and y is not None: x_data = np.array(edata[:, x].X).flatten().astype(float) y_data = np.array(edata[:, y].X).flatten().astype(float) mask = ~(np.isnan(x_data) | np.isnan(y_data)) x_clean = x_data[mask] y_clean = y_data[mask] if scatter_plot: scatter_opts = {**kwds} scatter_opts.setdefault("tools", ["hover"]) plot = hv.Scatter((x_clean, y_clean), kdims=x, vdims=y).opts(**scatter_opts) if ols_results is not None: x_sorted = np.sort(x_clean) for i, ols_result in enumerate(ols_results): y_pred = ols_result.predict(exog={"const": 1, x: x_sorted}) curve_opts: dict[str, Any] = {"tools": ["hover"]} if ols_color[i] is not None: curve_opts["color"] = ols_color[i] ols_curve = hv.Curve((x_sorted, y_pred)).opts(**curve_opts) plot = ols_curve if plot is None else plot * ols_curve if lines is not None: if xlim is None and plot is not None: x_range = plot.range(x if x else 0) xlim = (x_range[0], x_range[1]) if x_range[0] is not None else (0, 1) elif xlim is None: xlim = (0, 1) for i, line in enumerate(lines): a, b = line if np.ndim(a) == 0 and np.ndim(b) == 0: line_x = np.array(xlim) line_y = a * line_x + b else: line_x, line_y = a, b curve_opts = {"tools": ["hover"]} if lines_color[i] is not None: curve_opts["color"] = lines_color[i] if lines_style[i] is not None: style_map = {"-": "solid", "--": "dashed", ":": "dotted", "-.": "dashdot"} curve_opts["line_dash"] = style_map.get(lines_style[i], "solid") 
curve_kwargs = {"label": lines_label[i]} if lines_label[i] else {} line_curve = hv.Curve((line_x, line_y), **curve_kwargs).opts(**curve_opts) plot = line_curve if plot is None else plot * line_curve if plot is None: return None opts_dict: dict[str, Any] = {} if width is not None: opts_dict["width"] = width if height is not None: opts_dict["height"] = height if xlabel is not None: opts_dict["xlabel"] = xlabel if ylabel is not None: opts_dict["ylabel"] = ylabel if title is not None: opts_dict["title"] = title if xlim is not None: opts_dict["xlim"] = xlim if ylim is not None: opts_dict["ylim"] = ylim plot = plot.opts(**opts_dict) return plot
def _per_group_option(value: Any, default: Any, n_groups: int, name: str) -> list[Any]:
    """Expand a per-group plotting option to a list holding one entry per fitted model."""
    if value is None:
        return [default] * n_groups
    if not isinstance(value, (list, tuple, np.ndarray)):
        # A single value (e.g. ``ci_alpha=0.3`` or ``color="red"``) applies to every group.
        return [value] * n_groups
    if len(value) < n_groups:
        raise ValueError(f"`{name}` has {len(value)} entries but {n_groups} fitted models were passed.")
    return list(value)


def _unique_time_labels(time_points: ndarray) -> list[str]:
    """Format distinct time points with the fewest decimals that keep the labels distinct."""
    for decimals in range(7):
        labels = [f"{t:.{decimals}f}" for t in time_points]
        if len(set(labels)) == len(labels):
            return labels
    return [repr(float(t)) for t in time_points]


def kaplan_meier(
    kmfs: Sequence[KaplanMeierFitter],
    *,
    display_survival_statistics: bool = False,
    ci_alpha: list[float] | float | None = None,
    ci_force_lines: list[Boolean] | Boolean | None = None,
    ci_show: list[Boolean] | Boolean | None = None,
    ci_legend: list[Boolean] | None = None,
    at_risk_counts: list[Boolean] | None = None,
    color: list[str | None] | str | None = None,
    grid: Boolean | None = False,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    width: int | None = 600,
    height: int | None = 400,
    title: str | None = None,
) -> hv.Layout | hv.Overlay | hv.Curve | None:
    """Plots a pretty figure of the Fitted KaplanMeierFitter model.

    See also: :class:`~lifelines.fitters.kaplan_meier_fitter.KaplanMeierFitter`.

    Args:
        kmfs: Iterables of fitted KaplanMeierFitter objects.
        display_survival_statistics: Whether to show survival statistics in a table below the plot.
            The table has one row per group and one column per time point (10 evenly spaced points
            across `xlim`, or across the observed timeline if `xlim` is not set).
        ci_alpha: The transparency level of the confidence interval. Either a single value applied
            to all kmfs or a list with one value per kmf. Defaults to 0.3.
        ci_force_lines: Force the confidence intervals to be line plots (versus default shaded areas).
            Either a single value or a list with one value per kmf.
        ci_show: Show confidence intervals. Either a single value or a list with one value per kmf.
        ci_legend: Accepted for API compatibility; currently not used by this function.
        at_risk_counts: Accepted for API compatibility; currently not used by this function.
        color: Color for each kmf. Either a single color or a list with one color per kmf.
        grid: If True, plot grid lines.
        xlim: Set the x-axis view limits.
        ylim: Set the y-axis view limits.
        xlabel: The x-axis label text.
        ylabel: The y-axis label text.
        width: Plot width in pixels.
        height: Plot height in pixels. The statistics table uses 40% of this height.
        title: Set the title of the plot.

    Returns:
        The survival curves (stacked above the statistics table if requested),
        or `None` if `kmfs` is empty.

    Raises:
        ValueError: If a per-group option is given as a list with fewer entries than `kmfs`.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> import numpy as np
        >>> edata = ed.dt.mimic_2()
        >>> edata[:, ["censor_flg"]].X = np.where(
        ...     edata[:, ["censor_flg"]].X == 0, 1, 0
        ... )  # MIMIC-II uses 0=death while KaplanMeierFitter expects True=death
        >>> kmf = ep.tl.kaplan_meier(edata, "mort_day_censored", "censor_flg")
        >>> ep.pl.kaplan_meier(
        ...     [kmf], color=["r"], xlim=(0, 700), ylim=(0, 1), xlabel="Days", ylabel="Proportion Survived"
        ... )

        .. image:: /_static/docstring_previews/kaplan_meier.png
    """
    n_groups = len(kmfs)
    ci_alpha = _per_group_option(ci_alpha, 0.3, n_groups, "ci_alpha")
    ci_force_lines = _per_group_option(ci_force_lines, False, n_groups, "ci_force_lines")
    ci_show = _per_group_option(ci_show, True, n_groups, "ci_show")
    color = _per_group_option(color, None, n_groups, "color")

    plot = None
    group_labels: list[str] = []

    for i, kmf in enumerate(kmfs):
        survival_function = kmf.survival_function_
        times = survival_function.index.values
        survival = survival_function.iloc[:, 0].values

        label = kmf.label if kmf.label else f"Group {i + 1}"
        group_labels.append(label)

        curve_opts: dict[str, Any] = {"tools": ["hover"]}
        if color[i] is not None:
            curve_opts["color"] = color[i]
        curve = hv.Curve((times, survival), kdims="Time", vdims="Survival", label=label).opts(**curve_opts)

        if ci_show[i] and hasattr(kmf, "confidence_interval_survival_function_"):
            ci = kmf.confidence_interval_survival_function_
            ci_lower = ci.iloc[:, 0].values
            ci_upper = ci.iloc[:, 1].values
            ci_color = color[i] if color[i] else "gray"

            if ci_force_lines[i]:
                ci_line_opts = {"color": ci_color, "alpha": ci_alpha[i], "line_dash": "dashed"}
                ci_lower_curve = hv.Curve((times, ci_lower)).opts(**ci_line_opts)
                ci_upper_curve = hv.Curve((times, ci_upper)).opts(**ci_line_opts)
                curve = curve * ci_lower_curve * ci_upper_curve
            else:
                ci_area = hv.Area((times, ci_lower, ci_upper), vdims=["y", "y2"]).opts(
                    color=ci_color, alpha=ci_alpha[i], line_width=0
                )
                # Area first so the survival curve is drawn on top of the band.
                curve = ci_area * curve

        plot = curve if plot is None else plot * curve

    if plot is None:
        return None

    opts_dict: dict[str, Any] = {"show_grid": bool(grid), "show_legend": True}
    if width is not None:
        opts_dict["width"] = width
    if height is not None:
        # `height` was previously only used for the table and never reached the plot.
        opts_dict["height"] = height
    if xlabel:
        opts_dict["xlabel"] = xlabel
    if ylabel:
        opts_dict["ylabel"] = ylabel
    if title:
        opts_dict["title"] = title
    if xlim:
        opts_dict["xlim"] = xlim
    if ylim:
        opts_dict["ylim"] = ylim
    plot = plot.opts(**opts_dict)

    if display_survival_statistics:
        if xlim:
            time_points = np.linspace(xlim[0], xlim[1], 10)
        else:
            all_times = np.concatenate([kmf.survival_function_.index.values for kmf in kmfs])
            time_points = np.linspace(all_times.min(), all_times.max(), 10)
        # Identical time points (e.g. xlim=(5, 5)) would only repeat the same column.
        time_points = np.unique(time_points)

        # Column names must be unique: rounding every time point to an integer makes
        # neighbouring points collide on short time ranges, which yields columns of
        # unequal length and a DataFrame construction error.
        column_names = _unique_time_labels(time_points)
        group_probs = [kmf.survival_function_at_times(time_points).values for kmf in kmfs]

        # Wide format: one row per group (same labels as the curves), one column per time point.
        table_data: dict[str, list[str]] = {"Group": group_labels}
        for column_index, column_name in enumerate(column_names):
            table_data[column_name] = [f"{probs[column_index]:.2f}" for probs in group_probs]

        df = pd.DataFrame(table_data)
        table_opts: dict[str, Any] = {"fit_columns": True}
        if width is not None:
            table_opts["width"] = width
        if height is not None:
            table_opts["height"] = int(height * 0.4)
        table = hv.Table(df).opts(**table_opts)
        plot = (plot + table).cols(1)

    return plot
def cox_ph_forestplot(
    edata: EHRData,
    *,
    uns_key: str = "cox_ph",
    labels: Iterable[str] | None = None,
    width: int = 1200,
    height: int = 600,
    ecolor: str = "dimgray",
    size: int = 3,
    marker: str = "o",
    decimal: int = 2,
    text_size: int = 12,
    color: str = "k",
    title: str | None = None,
) -> hv.Overlay | None:
    """Generates a forest plot to visualize the coefficients and confidence intervals of a Cox Proportional Hazards model.

    The `edata` object must first be populated using the :func:`~ehrapy.tools.cox_ph` function.
    This function stores the summary table of the `CoxPHFitter` in the `.uns` attribute of `edata`.
    The summary table is created when the model is fitted using the :func:`~ehrapy.tools.cox_ph` function.

    The plot shows the `coef` column with its `coef lower 95%` / `coef upper 95%` bounds, a vertical
    reference line at 0 (a coefficient of 0 means no effect), and two text columns to the right of
    the data with the coefficient and its confidence interval.

    See also: :class:`~lifelines.fitters.coxph_fitter.CoxPHFitter`

    Args:
        edata: Data object containing the summary table from the CoxPHFitter. This is stored in the `.uns` attribute, after fitting the model using :func:`~ehrapy.tools.cox_ph`.
        uns_key: Key in `.uns` where :func:`~ehrapy.tools.cox_ph` function stored the summary table. See argument `uns_key` in :func:`~ehrapy.tools.cox_ph`.
        labels: List of labels for each coefficient, default uses the index of the summary table.
        width: Plot width in pixels.
        height: Plot height in pixels.
        ecolor: Color of the error bars.
        size: Size of the markers.
        marker: Marker style.
        decimal: Number of decimal places to display.
        text_size: Font size of the text.
        color: Color of the markers.
        title: Set the title of the plot.

    Returns:
        HoloViews Overlay with forest plot and text annotations.

    Raises:
        ValueError: If `uns_key` is not in `edata.uns`, or if `labels` has fewer entries than the
            summary table has rows.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.mimic_2()[
        ...     :, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]
        ... ]
        >>> coxph = ep.tl.cox_ph(edata, event_col="censor_flg", duration_col="mort_day_censored")
        >>> ep.pl.cox_ph_forestplot(edata)

        .. image:: /_static/docstring_previews/coxph_forestplot.png
    """
    if uns_key not in edata.uns:
        raise ValueError(f"Key {uns_key} not found in edata.uns. Please provide a valid key.")

    summary = edata.uns[uns_key]
    n_rows = len(summary)

    # Materialise the labels once: an arbitrary iterable could otherwise be exhausted,
    # and too few labels would silently drop y-ticks when zipped with the positions.
    tick_labels = list(summary.index) if labels is None else list(labels)
    if len(tick_labels) < n_rows:
        raise ValueError(f"`labels` has {len(tick_labels)} entries but the summary table has {n_rows} rows.")

    coef_col = "coef"
    coefs = summary[coef_col].values
    lower = summary["coef lower 95%"].values
    upper = summary["coef upper 95%"].values
    y_positions = np.arange(n_rows)

    x_axis_upper_bound = float(pd.to_numeric(summary["coef upper 95%"]).max())
    x_axis_lower_bound = float(pd.to_numeric(summary["coef lower 95%"]).min())
    data_range = x_axis_upper_bound - x_axis_lower_bound
    if data_range <= 0:
        # A zero-width range would collapse every offset below to 0, stacking the
        # text columns on top of the marker and producing an empty x-range.
        data_range = 1.0

    # Horizontal layout: [padding | data | padding | gap | coef text | CI text].
    plot_padding = data_range * 0.1
    text_gap = data_range * 0.15
    text_spacing = data_range * 0.4

    plot_x_min = x_axis_lower_bound - plot_padding
    plot_x_max = x_axis_upper_bound + plot_padding
    text_start_x = plot_x_max + text_gap
    ci_text_x = text_start_x + text_spacing
    total_x_max = ci_text_x + (data_range * 0.5)

    error_data = [
        (coef, y, coef - lower_val, upper_val - coef)
        for coef, y, lower_val, upper_val in zip(coefs, y_positions, lower, upper, strict=False)
        if not np.isnan(coef) and not np.isnan(lower_val) and not np.isnan(upper_val)
    ]

    # NOTE(review): ErrorBars is built with two kdims and without `horizontal=True`;
    # confirm the bars are rendered along the coefficient axis as intended.
    error_bars = hv.ErrorBars(
        error_data,
        kdims=["Coefficient", "Variable"],
        vdims=["negative_error", "positive_error"],
    ).opts(color=ecolor, line_width=2)

    points = hv.Scatter(
        [(coef, y) for coef, y in zip(coefs, y_positions, strict=False) if not np.isnan(coef)],
        kdims=["Coefficient"],
        vdims=["Variable"],
    ).opts(color=color, size=size * 5, marker=marker, tools=["hover"])

    # `coef` is a log hazard ratio, so the no-effect reference is 0 (hazard ratio exp(0) == 1).
    vline = hv.VLine(0).opts(color="gray", line_width=1)

    float_types = (float, np.floating)
    text_labels = []
    for coef_val, low_val, upp_val, y_pos in zip(coefs, lower, upper, y_positions, strict=False):
        if np.isnan(coef_val):
            continue
        if all(isinstance(value, float_types) for value in (coef_val, low_val, upp_val)):
            coef_text = f"{coef_val:.{decimal}f}"
            ci_text = f"({low_val:.{decimal}f}, {upp_val:.{decimal}f})"
        else:
            coef_text = str(coef_val)
            ci_text = f"({low_val}, {upp_val})"
        text_labels.append((text_start_x, y_pos, coef_text))
        text_labels.append((ci_text_x, y_pos, ci_text))

    labels_overlay = hv.Labels(text_labels, kdims=["x", "y"], vdims=["text"]).opts(
        text_font_size=f"{text_size}pt", text_align="left", text_color="black"
    )

    header_y = n_rows - 0.7
    header_labels = hv.Labels(
        [
            (text_start_x, header_y, "coef"),
            (ci_text_x, header_y, "95% CI"),
        ],
        kdims=["x", "y"],
        vdims=["text"],
    ).opts(text_font_size=f"{text_size + 2}pt", text_font_style="bold", text_align="left", text_color="black")

    forest_plot = (vline * error_bars * points * labels_overlay * header_labels).opts(
        width=width,
        height=height,
        xlim=(plot_x_min, total_x_max),
        ylim=(n_rows - 0.5, -0.5),
        invert_yaxis=True,
        yticks=list(zip(y_positions, tick_labels, strict=False)),
        xlabel="Coefficient",
        ylabel="",
        fontsize={"yticks": text_size},
        show_legend=False,
    )

    if title:
        forest_plot = forest_plot.opts(title=title)

    return forest_plot
def cox_ph_adjusted_curves(
    edata: EHRData,
    *,
    uns_key: str = "cox_ph_adjusted_curves",
    groups: Sequence[str] | None = None,
    palette: Sequence[str] | None = None,
    show_ci: bool = True,
    ci_alpha: float = 0.15,
    line_width: float = 2.0,
    width: int = 650,
    height: int = 450,
    title: str | None = None,
    xlabel: str = "Time",
    ylabel: str = "Adjusted Survival Probability",
    legend_position: str = "top_right",
) -> hv.Overlay:
    """Survival curve plot to visualize CoxPH-adjusted survival probabilities stratified by a grouping variable.

    The `edata` object must first be populated using :func:`~ehrapy.tools.cox_ph_adjusted_curves`.
    Mirrors the functionality of ggadjustedcurves() from the R survminer package.

    See also: Therneau, Crowson & Atkinson (2015), 'Adjusted Survival Curves': https://cran.r-project.org/web/packages/survival/vignettes/adjcurve.pdf

    Args:
        edata: Data object containing the adjusted survival curves in `uns` after having run :func:`~ehrapy.tools.cox_ph_adjusted_curves`.
        uns_key: Key in `.uns` where :func:`~ehrapy.tools.cox_ph_adjusted_curves` stored its output. See argument `key_added` in :func:`~ehrapy.tools.cox_ph_adjusted_curves`.
        groups: Subset of group labels to plot. If None, all groups are plotted.
        palette: List of hex or named colors, one per group. Must contain at least as many colors as
            groups are plotted; extra colors are ignored. If None, the bokeh `Category10` palette is used.
        show_ci: Whether to draw confidence interval bands. Only has an effect when method='average' was used, as method='conditional' does not produce confidence intervals.
        ci_alpha: Transparency of the CI band fill between 0 and 1.
        line_width: Line width for the survival curves.
        width: Plot width in pixels.
        height: Plot height in pixels.
        title: Set the title of the plot. Defaults to a title built from the method and strata
            stored alongside the curves.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        legend_position: Position of the legend.

    Returns:
        HoloViews Overlay object representing the adjusted survival curve plot.

    Raises:
        KeyError: If `uns_key` is not in `edata.uns`.
        ValueError: If a requested group is not in the stored results, if there are too many groups
            for the automatic palette, or if `palette` has fewer colors than groups to plot.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.mimic_2()
        >>> cph = ep.tl.cox_ph(
        ...     edata, "mort_day_censored", "censor_flg", formula="gender_num + afib_flg + day_icu_intime_num"
        ... )
        >>> ep.tl.cox_ph_adjusted_curves(
        ...     edata,
        ...     cph,
        ...     strata="aline_flg",
        ...     duration_col="mort_day_censored",
        ...     event_col="censor_flg",
        ... )
        >>> ep.pl.cox_ph_adjusted_curves(edata)

        .. image:: /_static/docstring_previews/cox_ph_adjusted_curves.png
    """
    if uns_key not in edata.uns:
        raise KeyError(f"No adjusted curves found at edata.uns['{uns_key}']. Run ep.tl.cox_ph_adjusted_curves() first.")

    data = edata.uns[uns_key]
    meta = data.get("_meta", {})
    strata_label = meta.get("strata", "group")
    method = meta.get("method", "average")

    # Every key except the metadata entry is one group's curve data.
    all_groups = [k for k in data if k != "_meta"]
    plot_groups = groups if groups is not None else all_groups

    missing = set(plot_groups) - set(all_groups)
    if missing:
        raise ValueError(f"Groups {missing} not found in results. Available: {all_groups}")

    n_groups = len(plot_groups)
    if palette is None:
        try:
            # The lookup is clamped to a minimum size of 3 and fails with
            # KeyError when there are more groups than the palette offers.
            palette = Category10[max(3, n_groups)]
        except KeyError as err:
            raise ValueError(
                "Too many groups to assign colors automatically. Please provide a custom palette."
            ) from err
    elif len(palette) < n_groups:
        # Zipping a short palette with the groups would silently drop the
        # groups that have no color, so reject it explicitly.
        raise ValueError(
            f"`palette` provides {len(palette)} colors but {n_groups} groups are plotted. "
            "Please provide one color per group."
        )

    colors = palette[:n_groups]

    elements = []
    for group_color, group in zip(colors, plot_groups, strict=False):
        entry = data[group]
        times = entry["times"]
        survival = entry["survival"]
        ci_lower = entry.get("ci_lower")
        ci_upper = entry.get("ci_upper")

        # Survival curve
        curve = hv.Curve(
            (times, survival),
            kdims="time",
            vdims="survival",
            label=str(group),
        ).opts(
            color=group_color,
            line_width=line_width,
            tools=["hover"],
        )
        elements.append(curve)

        # CI band (only drawn when both bounds were stored for this group)
        if show_ci and ci_lower is not None and ci_upper is not None:
            band = hv.Area(
                (times, ci_lower, ci_upper),
                kdims="time",
                vdims=["survival_lower", "survival_upper"],
            ).opts(
                color=group_color,
                alpha=ci_alpha,
                line_alpha=0,
            )
            elements.append(band)

    _method_label = "population-averaged" if method == "average" else "conditional"
    _title = title or f"CoxPH-Adjusted Survival Curves ({_method_label}) stratified by {strata_label}"

    overlay = (
        hv.Overlay(elements)
        .opts(
            width=width,
            height=height,
            xlabel=xlabel,
            ylabel=ylabel,
            ylim=(-0.02, 1.05),
            xlim=(0, None),
            legend_position=legend_position,
            title=_title,
            toolbar="above",
        )
        .relabel(_title)
    )

    return overlay