Source code for ehrapy.plot._survival_analysis

from __future__ import annotations

from typing import TYPE_CHECKING

import holoviews as hv
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
from numpy import ndarray

from ehrapy._compat import use_ehrdata

if TYPE_CHECKING:
    from collections.abc import Iterable, Sequence
    from typing import Any
    from xmlrpc.client import Boolean

    from anndata import AnnData
    from ehrdata import EHRData
    from lifelines import KaplanMeierFitter
    from statsmodels.regression.linear_model import RegressionResults


def _ols_item_or_none(values: list | None, index: int) -> Any:
    """Return ``values[index]``, or ``None`` when the list is missing or too short."""
    if values is None or index >= len(values):
        return None
    return values[index]


@use_ehrdata(deprecated_after="1.0.0", edata_None_allowed=True)
def ols(
    edata: EHRData | AnnData | None = None,
    *,
    x: str | None = None,
    y: str | None = None,
    scatter_plot: Boolean | None = True,
    ols_results: list[RegressionResults] | None = None,
    ols_color: list[str | None] | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    width: int | None = 600,
    height: int | None = 400,
    lines: list[tuple[ndarray | float, ndarray | float]] | None = None,
    lines_color: list[str | None] | None = None,
    lines_style: list[str | None] | None = None,
    lines_label: list[str | None] | None = None,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    title: str | None = None,
    **kwds,
) -> hv.Scatter | hv.Curve | hv.Overlay | None:
    """Plots an Ordinary Least Squares (OLS) Model result, scatter plot, and line plot.

    Args:
        edata: Central data object.
        x: x coordinate, for scatter plotting.
        y: y coordinate, for scatter plotting.
        scatter_plot: Whether to show a scatter plot.
        ols_results: List of RegressionResults from ehrapy.tl.ols.
        ols_color: List of colors for each ols_results. Missing or `None` entries use the default color.
        xlabel: The x-axis label text.
        ylabel: The y-axis label text.
        width: Plot width in pixels.
        height: Plot height in pixels.
        lines: List of Tuples of (slope, intercept) or (x, y). Plot lines by slope and intercept or data points.
            Example: plot two lines (y = x + 2 and y = 2*x + 1): [(1, 2), (2, 1)]
        lines_color: List of colors for each line. Missing or `None` entries use the default color.
        lines_style: List of line styles for each line (`"-"`, `"--"`, `":"`, `"-."`).
            Unknown styles are drawn solid.
        lines_label: List of line labels for each line.
        xlim: Set the x-axis view limits. Required for only plotting lines using slope and intercept.
            If omitted while `lines` is given, the range of the plotted x data is used,
            or `(0, 1)` when no data is plotted.
        ylim: Set the y-axis view limits. Required for only plotting lines using slope and intercept.
        title: Set the title of the plot.
        **kwds: Passed to HoloViews Scatter element.

    Returns:
        The HoloViews plot, or `None` if there is nothing to plot.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.mimic_2()
        >>> co2_lm_result = ep.tl.ols(
        ...     edata, var_names=["pco2_first", "tco2_first"], formula="tco2_first ~ pco2_first", missing="drop"
        ... ).fit()
        >>> ep.pl.ols(
        ...     edata,
        ...     x="pco2_first",
        ...     y="tco2_first",
        ...     ols_results=[co2_lm_result],
        ...     ols_color=["red"],
        ...     xlabel="PCO2",
        ...     ylabel="TCO2",
        ... )

        .. image:: /_static/docstring_previews/ols_plot.png
    """
    # Matplotlib-like shorthand -> line_dash value passed to the curve options.
    line_dash_by_style = {"-": "solid", "--": "dashed", ":": "dotted", "-.": "dashdot"}

    plot = None
    # x-extent of the data that actually ended up in the plot; used to span slope/intercept lines.
    data_xlim: tuple[float, float] | None = None

    if edata is not None and x is not None and y is not None:
        x_data = np.array(edata[:, x].X).flatten().astype(float)
        y_data = np.array(edata[:, y].X).flatten().astype(float)

        # Keep only observations where both coordinates are present.
        mask = ~(np.isnan(x_data) | np.isnan(y_data))
        x_clean = x_data[mask]
        y_clean = y_data[mask]

        if scatter_plot:
            scatter_opts = {**kwds}
            scatter_opts.setdefault("tools", ["hover"])
            plot = hv.Scatter((x_clean, y_clean), kdims=x, vdims=y).opts(**scatter_opts)

        if ols_results is not None:
            x_sorted = np.sort(x_clean)
            for i, ols_result in enumerate(ols_results):
                y_pred = ols_result.predict(exog={"const": 1, x: x_sorted})
                curve_opts: dict[str, Any] = {"tools": ["hover"]}
                fit_color = _ols_item_or_none(ols_color, i)
                if fit_color is not None:
                    curve_opts["color"] = fit_color
                ols_curve = hv.Curve((x_sorted, y_pred)).opts(**curve_opts)
                plot = ols_curve if plot is None else plot * ols_curve

        # Take the extent from the data itself rather than querying the plot by
        # dimension name: the OLS curves are created with default dimension names,
        # so a lookup by `x` is not reliable when no scatter layer is present.
        if plot is not None and x_clean.size:
            data_xlim = (float(x_clean.min()), float(x_clean.max()))

    if lines is not None:
        if xlim is None:
            xlim = data_xlim if data_xlim is not None else (0, 1)

        for i, (a, b) in enumerate(lines):
            if np.ndim(a) == 0 and np.ndim(b) == 0:
                # Two scalars: (slope, intercept), drawn across the x limits.
                line_x = np.array(xlim)
                line_y = a * line_x + b
            else:
                # Otherwise the tuple holds explicit (x, y) data points.
                line_x, line_y = a, b

            curve_opts = {"tools": ["hover"]}
            line_color = _ols_item_or_none(lines_color, i)
            if line_color is not None:
                curve_opts["color"] = line_color
            line_style = _ols_item_or_none(lines_style, i)
            if line_style is not None:
                curve_opts["line_dash"] = line_dash_by_style.get(line_style, "solid")

            line_label = _ols_item_or_none(lines_label, i)
            curve_kwargs = {"label": line_label} if line_label else {}
            line_curve = hv.Curve((line_x, line_y), **curve_kwargs).opts(**curve_opts)
            plot = line_curve if plot is None else plot * line_curve

    if plot is None:
        return None

    # Only forward options that were actually set.
    candidate_opts = {
        "width": width,
        "height": height,
        "xlabel": xlabel,
        "ylabel": ylabel,
        "title": title,
        "xlim": xlim,
        "ylim": ylim,
    }
    opts_dict: dict[str, Any] = {name: value for name, value in candidate_opts.items() if value is not None}

    plot = plot.opts(**opts_dict)

    return plot
[docs] def kaplan_meier( kmfs: Sequence[KaplanMeierFitter], *, display_survival_statistics: bool = False, ci_alpha: list[float] | None = None, ci_force_lines: list[Boolean] | None = None, ci_show: list[Boolean] | None = None, ci_legend: list[Boolean] | None = None, at_risk_counts: list[Boolean] | None = None, color: list[str | None] | None = None, grid: Boolean | None = False, xlim: tuple[float, float] | None = None, ylim: tuple[float, float] | None = None, xlabel: str | None = None, ylabel: str | None = None, width: int | None = 600, height: int | None = 400, title: str | None = None, ) -> hv.Layout | hv.Overlay | hv.Curve | None: """Plots a pretty figure of the Fitted KaplanMeierFitter model. See also: :class:`~lifelines.fitters.kaplan_meier_fitter.KaplanMeierFitter`. Args: kmfs: Iterables of fitted KaplanMeierFitter objects. display_survival_statistics: Whether to show survival statistics in a table below the plot. ci_alpha: The transparency level of the confidence interval. If more than one kmfs, this should be a list. ci_force_lines: Force the confidence intervals to be line plots (versus default shaded areas). If more than one kmfs, this should be a list. ci_show: Show confidence intervals. If more than one kmfs, this should be a list. ci_legend: If ci_force_lines is True, this is a boolean flag to add the lines' labels to the legend. If more than one kmfs, this should be a list. at_risk_counts: Show group sizes at time points. If more than one kmfs, this should be a list. color: List of colors for each kmf. If more than one kmfs, this should be a list. grid: If True, plot grid lines. xlim: Set the x-axis view limits. ylim: Set the y-axis view limits. xlabel: The x-axis label text. ylabel: The y-axis label text. width: Plot width in pixels. height: Plot height in pixels. title: Set the title of the plot. Examples: >>> import ehrdata as ed >>> import ehrapy as ep >>> import numpy as np >>> edata = ed.dt.mimic_2() >>> edata[:, ["censor_flg"]].X = np.where( ... 
edata[:, ["censor_flg"]].X == 0, 1, 0 ... ) # MIMIC-II uses 0=death while KaplanMeierFitter expects True=death >>> kmf = ep.tl.kaplan_meier(edata, "mort_day_censored", "censor_flg") >>> ep.pl.kaplan_meier( ... [kmf], color=["r"], xlim=(0, 700), ylim=(0, 1), xlabel="Days", ylabel="Proportion Survived", show=True ... ) .. image:: /_static/docstring_previews/kaplan_meier.png """ if ci_alpha is None: ci_alpha = [0.3] * len(kmfs) if ci_force_lines is None: ci_force_lines = [False] * len(kmfs) if ci_show is None: ci_show = [True] * len(kmfs) if ci_legend is None: ci_legend = [False] * len(kmfs) if at_risk_counts is None: at_risk_counts = [False] * len(kmfs) if color is None: color = [None] * len(kmfs) plot = None for i, kmf in enumerate(kmfs): sf = kmf.survival_function_ times = sf.index.values survival = sf.iloc[:, 0].values label = kmf.label if kmf.label else f"Group {i + 1}" curve_opts: dict[str, Any] = {"tools": ["hover"]} if color[i] is not None: curve_opts["color"] = color[i] curve = hv.Curve((times, survival), kdims="Time", vdims="Survival", label=label).opts(**curve_opts) if ci_show[i] and hasattr(kmf, "confidence_interval_survival_function_"): ci = kmf.confidence_interval_survival_function_ ci_lower = ci.iloc[:, 0].values ci_upper = ci.iloc[:, 1].values if ci_force_lines[i]: ci_lower_curve = hv.Curve((times, ci_lower)).opts( color=color[i] if color[i] else "gray", alpha=ci_alpha[i], line_dash="dashed" ) ci_upper_curve = hv.Curve((times, ci_upper)).opts( color=color[i] if color[i] else "gray", alpha=ci_alpha[i], line_dash="dashed" ) curve = curve * ci_lower_curve * ci_upper_curve else: ci_area = hv.Area((times, ci_lower, ci_upper), vdims=["y", "y2"]).opts( color=color[i] if color[i] else "gray", alpha=ci_alpha[i], line_width=0 ) curve = ci_area * curve plot = curve if plot is None else plot * curve if plot is None: return None opts_dict: dict[str, Any] = {"width": width, "show_grid": grid} if xlabel: opts_dict["xlabel"] = xlabel if ylabel: opts_dict["ylabel"] = 
ylabel if title: opts_dict["title"] = title if xlim: opts_dict["xlim"] = xlim if ylim: opts_dict["ylim"] = ylim opts_dict["show_legend"] = True plot = plot.opts(**opts_dict) if display_survival_statistics: if xlim: time_points = np.linspace(xlim[0], xlim[1], 10) else: all_times = np.concatenate([kmf.survival_function_.index.values for kmf in kmfs]) time_points = np.linspace(all_times.min(), all_times.max(), 10) # Create table data in wide format (one row per group, columns are time points) table_data: dict[str, list[str]] = {} table_data["Group"] = [] for kmf in kmfs: label = kmf.label if kmf.label else "Group" table_data["Group"].append(label) survival_probs = kmf.survival_function_at_times(time_points).values for _, (t, prob) in enumerate(zip(time_points, survival_probs, strict=False)): col_name = f"{t:.0f}" if col_name not in table_data: table_data[col_name] = [] table_data[col_name].append(f"{prob:.2f}") df = pd.DataFrame(table_data) table = hv.Table(df).opts(width=width, height=int(height * 0.4), fit_columns=True) plot = (plot + table).cols(1) return plot
@use_ehrdata(deprecated_after="1.0.0")
def cox_ph_forestplot(
    edata: EHRData | AnnData,
    *,
    uns_key: str = "cox_ph",
    labels: Iterable[str] | None = None,
    width: int = 1200,
    height: int = 600,
    ecolor: str = "dimgray",
    size: int = 3,
    marker: str = "o",
    decimal: int = 2,
    text_size: int = 12,
    color: str = "k",
    title: str | None = None,
) -> hv.Overlay | None:
    """Generates a forest plot to visualize the coefficients and confidence intervals of a Cox Proportional Hazards model.

    The `edata` object must first be populated using the :func:`~ehrapy.tools.cox_ph` function.
    This function stores the summary table of the `CoxPHFitter` in the `.uns` attribute of `edata`.
    The summary table is created when the model is fitted using the :func:`~ehrapy.tools.cox_ph` function.
    See also: :class:`~lifelines.fitters.coxph_fitter.CoxPHFitter`

    Args:
        edata: Data object containing the summary table from the CoxPHFitter.
            This is stored in the `.uns` attribute, after fitting the model using :func:`~ehrapy.tools.cox_ph`.
        uns_key: Key in `.uns` where :func:`~ehrapy.tools.cox_ph` function stored the summary table.
            See argument `uns_key` in :func:`~ehrapy.tools.cox_ph`.
        labels: List of labels for each coefficient, default uses the index of the summary table.
        width: Plot width in pixels.
        height: Plot height in pixels.
        ecolor: Color of the error bars.
        size: Size of the markers.
        marker: Marker style.
        decimal: Number of decimal places to display.
        text_size: Font size of the text.
        color: Color of the markers.
        title: Set the title of the plot.

    Returns:
        HoloViews Overlay with forest plot and text annotations.

    Raises:
        ValueError: If `uns_key` is not in `edata.uns`, or if the number of `labels`
            differs from the number of rows in the summary table.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.mimic_2()[
        ...     :, ["mort_day_censored", "censor_flg", "gender_num", "afib_flg", "day_icu_intime_num"]
        ... ]
        >>> coxph = ep.tl.cox_ph(edata, event_col="censor_flg", duration_col="mort_day_censored")
        >>> ep.pl.cox_ph_forestplot(edata)

        .. image:: /_static/docstring_previews/coxph_forestplot.png
    """
    if uns_key not in edata.uns:
        raise ValueError(f"Key {uns_key} not found in edata.uns. Please provide a valid key.")

    coxph_fitting_summary = edata.uns[uns_key]
    n_rows = len(coxph_fitting_summary)
    auc_col = "coef"

    # Materialize so any iterable (including a one-shot iterator) can be validated and reused.
    labels = list(coxph_fitting_summary.index) if labels is None else list(labels)
    if len(labels) != n_rows:
        raise ValueError(f"Expected {n_rows} labels (one per coefficient), but got {len(labels)}.")

    coefs = coxph_fitting_summary[auc_col].values
    lower = coxph_fitting_summary["coef lower 95%"].values
    upper = coxph_fitting_summary["coef upper 95%"].values

    y_positions = np.arange(n_rows)

    x_axis_upper_bound = float(pd.to_numeric(coxph_fitting_summary["coef upper 95%"]).max())
    x_axis_lower_bound = float(pd.to_numeric(coxph_fitting_summary["coef lower 95%"]).min())

    # Horizontal layout: [padding | error bars | padding | gap | coef text | 95% CI text].
    # All distances are fractions of the plotted coefficient range.
    data_range = x_axis_upper_bound - x_axis_lower_bound
    plot_padding = data_range * 0.1
    text_gap = data_range * 0.15
    text_spacing = data_range * 0.4

    plot_x_min = x_axis_lower_bound - plot_padding
    plot_x_max = x_axis_upper_bound + plot_padding
    text_start_x = plot_x_max + text_gap
    ci_text_x = text_start_x + text_spacing
    total_x_max = ci_text_x + (data_range * 0.5)

    # Rows as (coef, y, distance to lower bound, distance to upper bound); incomplete rows are skipped.
    error_data = [
        (coef, y, coef - lower_val, upper_val - coef)
        for coef, y, lower_val, upper_val in zip(coefs, y_positions, lower, upper, strict=False)
        if not (np.isnan(coef) or np.isnan(lower_val) or np.isnan(upper_val))
    ]

    # horizontal=True: the errors extend along the x-axis (the coefficient), not along y.
    error_bars = hv.ErrorBars(
        error_data,
        kdims=["Coefficient", "Variable"],
        vdims=["negative_error", "positive_error"],
        horizontal=True,
    ).opts(color=ecolor, line_width=2)

    points = hv.Scatter(
        [(coef, y) for coef, y in zip(coefs, y_positions, strict=False) if not np.isnan(coef)],
        kdims=["Coefficient"],
        vdims=["Variable"],
    ).opts(color=color, size=size * 5, marker=marker, tools=["hover"])

    # The axis shows raw coefficients (log hazard ratios), so "no effect" is at 0.
    vline = hv.VLine(0).opts(color="gray", line_width=1)

    numeric_types = (float, np.floating)
    text_labels = []
    for coef_val, low_val, upp_val, y_pos in zip(coefs, lower, upper, y_positions, strict=False):
        if np.isnan(coef_val):
            continue
        if all(isinstance(value, numeric_types) for value in (coef_val, low_val, upp_val)):
            coef_text = f"{coef_val:.{decimal}f}"
            ci_text = f"({low_val:.{decimal}f}, {upp_val:.{decimal}f})"
        else:
            coef_text = str(coef_val)
            ci_text = f"({low_val}, {upp_val})"

        text_labels.append((text_start_x, y_pos, coef_text))
        text_labels.append((ci_text_x, y_pos, ci_text))

    labels_overlay = hv.Labels(text_labels, kdims=["x", "y"], vdims=["text"]).opts(
        text_font_size=f"{text_size}pt", text_align="left", text_color="black"
    )

    # Column headers for the two text columns, placed just past the last row position.
    header_y = n_rows - 0.7
    header_labels = hv.Labels(
        [
            (text_start_x, header_y, "coef"),
            (ci_text_x, header_y, "95% CI"),
        ],
        kdims=["x", "y"],
        vdims=["text"],
    ).opts(text_font_size=f"{text_size + 2}pt", text_font_style="bold", text_align="left", text_color="black")

    # Axis limits and inversion are kept exactly as in the original layout.
    forest_plot = (vline * error_bars * points * labels_overlay * header_labels).opts(
        width=width,
        height=height,
        xlim=(plot_x_min, total_x_max),
        ylim=(n_rows - 0.5, -0.5),
        invert_yaxis=True,
        yticks=[(int(position), str(label)) for position, label in zip(y_positions, labels, strict=False)],
        xlabel="Coefficient",
        ylabel="",
        fontsize={"yticks": text_size},
        show_legend=False,
    )

    if title:
        forest_plot = forest_plot.opts(title=title)

    return forest_plot