from __future__ import annotations
from typing import TYPE_CHECKING, Any
import holoviews as hv
import numpy as np
import pandas as pd
from fast_array_utils.conv import to_dense
from ehrapy._compat import choose_hv_backend
if TYPE_CHECKING:
from collections.abc import Sequence
from ehrdata import EHRData
[docs]
@choose_hv_backend()
def sankey_diagram(
edata: EHRData,
*,
columns: Sequence[str],
node_width: int | float = 20,
node_padding: int | float = 10,
node_color: str | None = None,
edge_color: str | None = None,
label_position: str | None = "right",
show_values: bool = True,
title: str | None = None,
width: int | None = 600,
height: int | None = 400,
**kwargs,
) -> hv.Sankey:
"""Create a Sankey diagram of relationships across the flat observation table.
Args:
edata: Central data object.
columns: Column names from `edata.obs` to visualize
node_width: Width of the nodes in the Sankey diagram.
node_padding: Padding between nodes in the Sankey diagram.
node_color: Color of the nodes. If None, default coloring is used.
edge_color: Color of the edges. If None, default coloring is used.
label_position: Position of the labels on the nodes. Options are 'left', 'right', 'outer', or 'inner'.
show_values: Whether to display the values on the edges.
title: Title of the Sankey diagram.
width: Width of the Sankey diagram.
height: Height of the Sankey diagram.
**kwargs: Additional styling options passed to :class:`holoviews.element.sankey.Sankey`.
Examples:
>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.diabetes_130_fairlearn(columns_obs_only=["gender", "race"])
>>> ep.pl.sankey_diagram(edata, columns=["gender", "race"])
.. image:: /_static/docstring_previews/sankey.png
"""
missing = [c for c in columns if c not in edata.obs.columns]
if missing:
raise KeyError(f"columns not found in edata.obs: {missing}")
if len(columns) < 2:
raise ValueError("columns must contain at least two obs column names.")
df = edata.obs[columns]
# Build links between consecutive columns
sources, targets, values = [], [], []
source_levels, target_levels = [], []
for i in range(len(columns) - 1):
col_from, col_to = columns[i], columns[i + 1]
flows = df.groupby([col_from, col_to], observed=True).size().reset_index(name="count")
sources.extend(col_from + ": " + flows[col_from].astype("string"))
targets.extend(col_to + ": " + flows[col_to].astype("string"))
values.extend(flows["count"].to_numpy())
source_levels.extend([col_from] * len(flows))
target_levels.extend([col_to] * len(flows))
sankey_df = pd.DataFrame(
{
"source": sources,
"target": targets,
"value": values,
"source_level": source_levels,
"target_level": target_levels,
}
)
sankey = hv.Sankey(sankey_df, kdims=["source", "target"], vdims=["value"])
opts_dict: dict[str, Any] = {}
if hv.Store.current_backend == "bokeh":
if width is not None:
opts_dict["width"] = width
if height is not None:
opts_dict["height"] = height
if node_width is not None:
opts_dict["node_width"] = node_width
if node_padding is not None:
opts_dict["node_padding"] = node_padding
if title is not None:
opts_dict["title"] = title
if node_color is not None:
opts_dict["node_color"] = node_color
if edge_color is not None:
opts_dict["edge_color"] = edge_color
if label_position is not None:
opts_dict["label_position"] = label_position
if show_values is not None:
opts_dict["show_values"] = show_values
opts_dict.update(kwargs)
sankey = sankey.opts(**opts_dict)
return sankey
[docs]
@choose_hv_backend()
def sankey_diagram_time(
edata: EHRData,
*,
var_name: str,
layer: str,
state_labels: dict[int, str] | None = None,
node_width: int | float = 20,
node_padding: int | float = 10,
cmap: str | list[str] | None = None,
node_color: str | None = None,
edge_color: str | None = None,
label_position: str | None = "right",
show_values: bool = True,
title: str | None = None,
width: int | None = 600,
height: int | None = 400,
**kwargs,
) -> hv.Sankey:
"""Create a Sankey diagram showing patient state transitions over time.
Each node represents a state at a specific time point, and flows show the
number of patients transitioning between states.
Visualizes how patients transition between different states
(e.g. disease severity, treatment status) across consecutive time points.
Args:
edata: Central data object.
var_name: Variable name from `edata.var_names` to visualize
layer: Name of the layer in `edata.layers` containing the feature data to visualize.
state_labels: Mapping from numeric state values to readable labels.
If None, state values will be displayed as strings of their numeric codes (e.g., "0", "1", "2").
node_width: Width of the nodes in the Sankey diagram.
node_padding: Padding between nodes in the Sankey diagram.
cmap: Colormap to use for coloring states. Can be a string recognized by Holoviews or a list of color hex codes.
node_color: Color of the nodes. If None, default coloring is used.
edge_color: Color of the edges. If None, default coloring is used.
label_position: Position of the labels on the nodes. Options are 'left', 'right', 'outer', or 'inner'.
show_values: Whether to display the values on the edges.
title: Title of the Sankey diagram.
width: Width of the Sankey diagram.
height: Height of the Sankey diagram.
**kwargs: Additional styling options passed to :class:`holoviews.element.sankey.Sankey`.
Examples:
>>> import ehrapy as ep
>>> import ehrdata as ed
>>> edata = ed.dt.ehrdata_blobs(base_timepoints=5, n_variables=1, n_observations=5, random_state=59)
>>> edata.layers["tem_data"] = edata.layers["tem_data"].astype(int)
>>> state_labels = {-2: "no", -3: "mild", -4: "moderate", -5: "severe", -6: "critical"}
>>> ep.pl.sankey_diagram_time(
... edata,
... var_name="feature_0",
... layer="tem_data",
... state_labels=state_labels,
... )
.. image:: /_static/docstring_previews/sankey_time.png
"""
if var_name not in edata.var_names:
raise KeyError(f"{var_name} not found in edata.var_names.")
if layer not in edata.layers:
raise KeyError(f"{layer} not found in edata.layers.")
flare_data = edata[:, edata.var_names == var_name, :].layers[layer][:, 0, :]
mtx = to_dense(flare_data, to_cpu_memory=True)
time_steps = edata.tem.index.tolist()
if np.issubdtype(mtx.dtype, np.floating):
flat = mtx.ravel()
for x in flat:
if not np.isfinite(x):
continue
rx = np.rint(x)
if not np.isclose(x, rx, rtol=0.0, atol=1e-8):
raise ValueError(
"Sankey requires discrete, binned states. "
f"Found non-integer value {float(x)!r}. "
"Bin first (e.g. with np.digitize) and pass integer codes."
)
states = np.unique(mtx[np.isfinite(mtx)])
else:
states = np.unique(mtx)
observed = {int(s) for s in states}
if state_labels is None:
state_labels = {state: str(state) for state in sorted(observed)}
missing = observed - set(state_labels)
if missing:
raise KeyError(f"state_labels missing keys for states: {missing}")
state_values = sorted(state_labels.keys())
state_names = [state_labels[val] for val in state_values]
if cmap is None:
cmap = hv.plotting.util.process_cmap("Category10", ncolors=len(state_names))
elif isinstance(cmap, str):
cmap = hv.plotting.util.process_cmap(cmap, ncolors=len(state_names))
else:
cmap = list(cmap)
if len(cmap) < len(state_names):
raise ValueError(f"Provided cmap has {len(cmap)} colors but {len(state_names)} are needed.")
# map each state label to a fixed color
state_color_map = {name: cmap[i] for i, name in enumerate(state_names)}
def _color_for_label(label: str) -> str:
"""Helper function to get the bare state name without suffixes."""
state = label.rsplit(" (", 1)[0]
return state_color_map.get(state, "#aaaaaa") # fall back to grey if a state name is not in the map
sources, targets, values = [], [], []
for t in range(len(time_steps) - 1):
for s_from_idx, s_from_val in enumerate(state_values):
for s_to_idx, s_to_val in enumerate(state_values):
count = np.sum((mtx[:, t] == s_from_val) & (mtx[:, t + 1] == s_to_val))
if count > 0:
source_label = f"{state_names[s_from_idx]} ({time_steps[t]})"
target_label = f"{state_names[s_to_idx]} ({time_steps[t + 1]})"
sources.append(source_label)
targets.append(target_label)
values.append(int(count))
sankey_df = pd.DataFrame({"source": sources, "target": targets, "value": values})
sankey_df["edge_color"] = sankey_df["source"].apply(_color_for_label)
# node labels in appearance order
all_labels = list(dict.fromkeys(sources + targets))
# explicit node dataset with "label" as the key dimension
node_df = pd.DataFrame({"label": all_labels})
nodes = hv.Dataset(node_df, kdims=["label"])
# map every node label to its state color
label_color_map = {lbl: _color_for_label(lbl) for lbl in all_labels}
sankey = hv.Sankey(
(sankey_df, nodes),
kdims=["source", "target"],
vdims=["value", "edge_color"],
)
opts_dict: dict[str, Any] = {}
if hv.Store.current_backend == "bokeh":
if width is not None:
opts_dict["width"] = width
if height is not None:
opts_dict["height"] = height
if node_width is not None:
opts_dict["node_width"] = node_width
if node_padding is not None:
opts_dict["node_padding"] = node_padding
if title is not None:
opts_dict["title"] = title
if node_color is None:
opts_dict["node_color"] = "label"
opts_dict["cmap"] = label_color_map
else:
opts_dict["node_color"] = node_color
if edge_color is None:
opts_dict["edge_color"] = "edge_color"
else:
opts_dict["edge_color"] = edge_color
opts_dict["node_size"] = 0
if label_position is not None:
opts_dict["label_position"] = label_position
if show_values is not None:
opts_dict["show_values"] = show_values
opts_dict.update(kwargs)
sankey = sankey.opts(**opts_dict)
return sankey