"""Source code for ``ehrapy.plot._survival_analysis``: plots for OLS regression results (``ols``) and Kaplan-Meier survival curves (``kmf``)."""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
from numpy import ndarray

from ehrapy.plot import scatter

if TYPE_CHECKING:
    from collections.abc import Sequence
    from xmlrpc.client import Boolean

    from anndata import AnnData
    from lifelines import KaplanMeierFitter
    from matplotlib.axes import Axes
    from statsmodels.regression.linear_model import RegressionResults


def ols(
    adata: AnnData | None = None,
    x: str | None = None,
    y: str | None = None,
    scatter_plot: bool | None = True,
    ols_results: list[RegressionResults] | None = None,
    ols_color: list[str] | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    figsize: tuple[float, float] | None = None,
    lines: list[tuple[ndarray | float, ndarray | float]] | None = None,
    lines_color: list[str] | None = None,
    lines_style: list[str] | None = None,
    lines_label: list[str] | None = None,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    show: bool | None = None,
    ax: Axes | None = None,
    title: str | None = None,
    **kwds,
):
    """Plots Ordinary Least Squares (OLS) model results together with an optional scatter plot and reference lines.

    Args:
        adata: :class:`~anndata.AnnData` object containing all observations.
        x: Variable name used as x coordinate for the scatter plot and the OLS fit lines.
        y: Variable name used as y coordinate for the scatter plot.
        scatter_plot: If True, draw a scatter plot of ``x`` against ``y``. Defaults to True.
        ols_results: List of fitted results from ehrapy.tl.ols, e.g. ``[result_1, result_2]``.
            Only drawn when ``adata``, ``x`` and ``y`` are all given: each result's ``predict()``
            values are plotted against the non-NaN values of ``x``.
        ols_color: List of colors for each entry of ``ols_results``, e.g. ``['red', 'blue']``.
        xlabel: The x-axis label text. Passing None sets an empty label.
        ylabel: The y-axis label text. Passing None sets an empty label.
        figsize: Width, height in inches. Only used when ``ax`` is not given. Defaults to None.
        lines: List of tuples of ``(slope, intercept)`` or ``(x, y)``. A pair of scalars is drawn as the
            line ``y = slope * x + intercept`` across the current x-axis limits; a pair of arrays is
            plotted as data points.
            Example: plot two lines (y = x + 2 and y = 2*x + 1): ``[(1, 2), (2, 1)]``
        lines_color: List of colors for each line, e.g. ``['red', 'blue']``.
        lines_style: List of styles for each line, e.g. ``['-', '--']``. Used as ``linestyle`` for
            slope/intercept lines and as a matplotlib format string for ``(x, y)`` lines.
        lines_label: List of legend labels for each line, e.g. ``['Line1', 'Line2']``.
            A legend is drawn as soon as any label is given.
        xlim: Set the x-axis view limits. Required for only plotting lines using slope and intercept.
        ylim: Set the y-axis view limits. Required for only plotting lines using slope and intercept.
        show: If True, show the plot with :func:`matplotlib.pyplot.show` and return None.
        ax: Axes to draw into. A new figure and axes are created when omitted.
        title: Set the title of the plot.
        **kwds: Passed on to :func:`ehrapy.plot.scatter`.

    Returns:
        The :class:`~matplotlib.axes.Axes` that was drawn into, unless ``show`` is truthy.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> co2_lm_result = ep.tl.ols(
        ...     adata, var_names=["pco2_first", "tco2_first"], formula="tco2_first ~ pco2_first", missing="drop"
        ... ).fit()
        >>> ep.pl.ols(
        ...     adata,
        ...     x="pco2_first",
        ...     y="tco2_first",
        ...     ols_results=[co2_lm_result],
        ...     ols_color=["red"],
        ...     xlabel="PCO2",
        ...     ylabel="TCO2",
        ... )

        .. image:: /_static/docstring_previews/ols_plot_1.png

        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> ep.pl.ols(adata, x='pco2_first', y='tco2_first', lines=[(0.25, 10), (0.3, 20)],
        ...           lines_color=['red', 'blue'], lines_style=['-', ':'], lines_label=['Line1', 'Line2'])

        .. image:: /_static/docstring_previews/ols_plot_2.png

        >>> import ehrapy as ep
        >>> ep.pl.ols(lines=[(0.25, 10), (0.3, 20)], lines_color=['red', 'blue'], lines_style=['-', ':'],
        ...           lines_label=['Line1', 'Line2'], xlim=(0, 150), ylim=(0, 50))

        .. image:: /_static/docstring_previews/ols_plot_3.png
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # Apply limits to the target axes (not pyplot's "current" axes, which may differ
    # when the caller passes ``ax``) and do it first: slope/intercept lines below are
    # drawn across ``ax.get_xlim()``.
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    if ols_color is None and ols_results is not None:
        ols_color = [None] * len(ols_results)
    if lines is not None:
        n_lines = len(lines)
        if lines_color is None:
            lines_color = [None] * n_lines
        if lines_style is None:
            lines_style = [None] * n_lines
        if lines_label is None:
            lines_label = [None] * n_lines

    if adata is not None and x is not None and y is not None:
        x_processed = np.array(adata[:, x].X).astype(float)
        # Drop NaN x values so they line up with the model's fitted values.
        # NOTE(review): assumes the model dropped exactly the rows where x is NaN
        # (rows with only y missing would make the lengths disagree) — confirm.
        x_processed = x_processed[~np.isnan(x_processed)]
        if scatter_plot is True:
            ax = scatter(adata, x=x, y=y, show=False, ax=ax, **kwds)
        if ols_results is not None:
            for i, ols_result in enumerate(ols_results):
                ax.plot(x_processed, ols_result.predict(), color=ols_color[i])

    if lines is not None:
        for i, (a, b) in enumerate(lines):
            color, style, label = lines_color[i], lines_style[i], lines_label[i]
            if np.ndim(a) == 0 and np.ndim(b) == 0:
                # Scalar pair: slope and intercept, drawn across the visible x range.
                line_x = np.array(ax.get_xlim())
                ax.plot(line_x, a * line_x + b, linestyle=style, color=color, label=label)
            else:
                # matplotlib only treats a *string* in the third positional slot as a
                # format; a None there would be parsed as another data argument and add
                # a phantom line (and a duplicate legend entry). Pass it only when given.
                fmt = () if style is None else (style,)
                ax.plot(a, b, *fmt, color=color, label=label)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)
    # ``any`` also copes with an empty label list (``lines=[]``).
    if lines_label is not None and any(label is not None for label in lines_label):
        ax.legend()

    if not show:
        return ax
    plt.show()
def _per_fitter(value, default, n):
    """Expand a per-fitter plotting option of ``kmf`` into a list of length ``n``.

    None becomes ``n`` copies of ``default``, and a scalar (bool, number or string) becomes
    ``n`` copies of itself. Any other iterable is converted to a list unchanged, so it keeps
    its own length.
    """
    if value is None:
        return [default] * n
    if np.isscalar(value):
        return [value] * n
    return list(value)


def kmf(
    kmfs: Sequence[KaplanMeierFitter],
    ci_alpha: list[float] | None = None,
    ci_force_lines: list[bool] | None = None,
    ci_show: list[bool] | None = None,
    ci_legend: list[bool] | None = None,
    at_risk_counts: list[bool] | None = None,
    color: list[str] | None = None,
    grid: bool | None = False,
    xlim: tuple[float, float] | None = None,
    ylim: tuple[float, float] | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    figsize: tuple[float, float] | None = None,
    show: bool | None = None,
    title: str | None = None,
):
    """Plots the survival functions of fitted KaplanMeierFitter models into one axes.

    See https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html

    Each per-fitter option below accepts either a list with one entry per fitter or a single
    value that is applied to every fitter.

    Args:
        kmfs: Iterable of fitted KaplanMeierFitter objects.
        ci_alpha: The transparency level of the confidence interval. Defaults to 0.3.
        ci_force_lines: Force the confidence intervals to be line plots (versus default shaded areas). Defaults to False.
        ci_show: Show confidence intervals. Defaults to True.
        ci_legend: If ci_force_lines is True, this is a boolean flag to add the lines' labels to the legend. Defaults to False.
        at_risk_counts: Show group sizes at time points. Defaults to False.
        color: Color for each kmf. Defaults to the matplotlib color cycle.
        grid: If True, plot grid lines.
        xlim: Set the x-axis view limits.
        ylim: Set the y-axis view limits.
        xlabel: The x-axis label text. Passing None sets an empty label.
        ylabel: The y-axis label text. Passing None sets an empty label.
        figsize: Width, height in inches. Defaults to None.
        show: If True, show the plot with :func:`matplotlib.pyplot.show` and return None.
        title: Set the title of the plot.

    Returns:
        The :class:`~matplotlib.axes.Axes` holding the survival curves, unless ``show`` is truthy.
        An empty ``kmfs`` yields an empty axes.

    Examples:
        >>> import ehrapy as ep
        >>> import numpy as np
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> # In the MIMIC-II database, `censor_flg` encodes censoring or death (binary: 0 = death, 1 = censored).
        >>> # KaplanMeierFitter's `event_observed` is True if the death was observed and False if the event
        >>> # was lost (right-censored), so `censor_flg` has to be flipped before passing it to KaplanMeierFitter.
        >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
        >>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X)
        >>> ep.pl.kmf(
        ...     [kmf], color=["r"], xlim=[0, 700], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived", show=True
        ... )

        .. image:: /_static/docstring_previews/kmf_plot_1.png

        >>> T = adata[:, ["mort_day_censored"]].X
        >>> E = adata[:, ["censor_flg"]].X
        >>> groups = adata[:, ["service_unit"]].X
        >>> ix1 = groups == "FICU"
        >>> ix2 = groups == "MICU"
        >>> ix3 = groups == "SICU"
        >>> kmf_1 = ep.tl.kmf(T[ix1], E[ix1], label="FICU")
        >>> kmf_2 = ep.tl.kmf(T[ix2], E[ix2], label="MICU")
        >>> kmf_3 = ep.tl.kmf(T[ix3], E[ix3], label="SICU")
        >>> ep.pl.kmf([kmf_1, kmf_2, kmf_3], ci_show=[False, False, False], color=['k', 'r', 'g'],
        ...           xlim=[0, 750], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived")

        .. image:: /_static/docstring_previews/kmf_plot_2.png
    """
    n_fitters = len(kmfs)
    ci_alpha = _per_fitter(ci_alpha, 0.3, n_fitters)
    ci_force_lines = _per_fitter(ci_force_lines, False, n_fitters)
    ci_show = _per_fitter(ci_show, True, n_fitters)
    ci_legend = _per_fitter(ci_legend, False, n_fitters)
    at_risk_counts = _per_fitter(at_risk_counts, False, n_fitters)
    color = _per_fitter(color, None, n_fitters)

    # Create the axes up front and hand it to every fitter. This removes the special
    # case for the first curve and leaves a valid (empty) axes when ``kmfs`` is empty.
    _, ax = plt.subplots(figsize=figsize)
    for i, fitter in enumerate(kmfs):
        ax = fitter.plot_survival_function(
            ax=ax,
            ci_alpha=ci_alpha[i],
            ci_force_lines=ci_force_lines[i],
            ci_show=ci_show[i],
            ci_legend=ci_legend[i],
            at_risk_counts=at_risk_counts[i],
            color=color[i],
        )

    # Style through ``ax`` rather than ``plt.*``: with ``at_risk_counts`` lifelines may
    # add extra (twin) axes, after which pyplot's "current axes" is no longer this plot.
    ax.grid(grid)
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)

    if not show:
        return ax
    plt.show()