"""Source code for ``ehrapy.plot._ncp``: plots of Non-negative CP (NCP) decomposition results."""

from __future__ import annotations

from typing import TYPE_CHECKING

import holoviews as hv
import numpy as np

if TYPE_CHECKING:
    from ehrdata import EHRData

# Bokeh ``Category10`` palette (ten colours). Component ``r`` is drawn with
# ``_PALETTE[r % len(_PALETTE)]``, so colours repeat after the tenth component.
_PALETTE = [
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
    "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf",
]


def _require_ncp(edata: EHRData, key: str) -> None:
    """Ensure the NCP results stored under ``key`` are present on ``edata``.

    Three entries are looked up, in this order: ``edata.uns[key]``,
    ``edata.obsm[f"X_{key}"]`` and ``edata.varm[f"{key}_loadings"]``.
    Every absent entry is collected so that a single error reports all of them.

    Args:
        edata: Data object expected to hold the NCP outputs.
        key: Key the NCP results were stored under.

    Raises:
        KeyError: If one or more of the entries above is missing. The message
            lists every missing entry and suggests running ``ep.tl.ncp`` first.
    """
    # (container, entry looked up, human-readable location used in the error)
    expected_entries = (
        (edata.uns, key, f"edata.uns[{key!r}]"),
        (edata.obsm, f"X_{key}", f"edata.obsm['X_{key}']"),
        (edata.varm, f"{key}_loadings", f"edata.varm['{key}_loadings']"),
    )
    absent = [where for container, entry, where in expected_entries if entry not in container]
    if absent:
        raise KeyError(f"NCP results not found ({absent}). Run `ep.tl.ncp(edata, ...)` first.")


def ncp(
    edata: EHRData,
    *,
    key: str = "ncp",
    n_top: int = 12,
    width: int = 380,
    height: int = 280,
) -> hv.Layout:
    """Plot the factors from a Non-negative CP decomposition.

    Produces one row of three panels per component, laid out as ``rank × 3`` panels in a
    fixed three-column grid:

    **Panel 1 — Temporal profile** (line chart)
    The normalised temporal factor ``c_r`` for component *r*, plotted against the relative
    time axis. Each value shows how the collective influence of this component rises or
    falls at that time point. A rising curve indicates a condition that worsens (or becomes
    more prevalent) over time; a peaked curve suggests a transient event; a flat curve
    indicates a time-independent pattern.

    **Panel 2 — Top variables** (horizontal bar chart)
    The ``n_top`` clinical variables with the highest normalised loading ``b_r`` for
    component *r*, sorted by loading magnitude. These are the variables that best
    characterise the component — i.e. the diseases, measurements, or features that tend to
    co-occur in the patient sub-group captured by this component. Variable names are
    truncated to 48 characters, and this panel is 180 px wider than ``width`` to make room
    for them. If fewer than ``n_top`` variables exist, all of them are shown and the title
    reports the actual number.

    **Panel 3 — Sample loadings** (histogram)
    Distribution of the patient-level loading ``a_r`` across all observations (40 bins).
    A spike near zero with a heavy right tail means the component is *selective* — only a
    sub-group of patients expresses it. A broad, roughly uniform distribution means the
    component is *diffuse* — relevant to most patients to varying degrees.

    Each factor vector is divided by its maximum before plotting (mapping the non-negative
    NCP factors to ``[0, 1]``), so that components with different absolute scales are
    visually comparable.

    Requires :func:`~ehrapy.tools.ncp` to have been run first.

    Args:
        edata: Central data object containing NCP results.
        key: Key under which NCP results are stored (matches ``key_added`` in
            :func:`~ehrapy.tools.ncp`).
        n_top: Number of top-loaded variables to display per component. Must be >= 1.
        width: Width of each individual panel in pixels.
        height: Height of each individual panel in pixels.

    Returns:
        HoloViews Layout with ``rank × 3`` panels arranged in three columns.

    Raises:
        KeyError: If the NCP results for ``key`` are missing from ``edata``.
        ValueError: If ``n_top`` is smaller than 1.

    Examples:
        >>> import ehrdata as ed, ehrapy as ep
        >>> edata = ed.dt.ehrdata_blobs(n_variables=8, n_centers=3, n_observations=30, base_timepoints=12)
        >>> ep.tl.ncp(edata, layer="tem_data", rank=3, sigmoid_transform=True)
        >>> ep.pl.ncp(edata, n_top=5)

        .. image:: /_static/docstring_previews/ncp.png
    """
    _require_ncp(edata, key)
    # Without this check, ``argsort(...)[-0:]`` would select *every* variable for n_top == 0,
    # and a negative n_top would silently drop the lowest-loaded variables instead.
    if n_top < 1:
        raise ValueError(f"n_top must be a positive integer, got {n_top}.")

    A = np.asarray(edata.obsm[f"X_{key}"])  # (n_obs, rank)
    B = np.asarray(edata.varm[f"{key}_loadings"])  # (n_vars, rank)
    C = np.asarray(edata.uns[key]["temporal_factors"])  # (n_time, rank)
    rank = A.shape[1]
    var_names = list(edata.var_names)
    t = np.arange(C.shape[0])  # time index shared by every temporal-profile panel

    panels = []
    for r in range(rank):
        col = _PALETTE[r % len(_PALETTE)]
        # Scale each factor by its maximum; the epsilon avoids 0/0 for an all-zero factor.
        a_norm = A[:, r] / (A[:, r].max() + 1e-12)
        b_norm = B[:, r] / (B[:, r].max() + 1e-12)
        c_norm = C[:, r] / (C[:, r].max() + 1e-12)

        # ── (1) Temporal profile ──────────────────────────────────────────────
        curve = hv.Curve(
            (t, c_norm),
            kdims=["Year (relative)"],
            vdims=["Norm. loading"],
        ).opts(
            width=width,
            height=height,
            title=f"C{r + 1}: Temporal profile",
            color=col,
            line_width=2.5,
            tools=["hover"],
        )
        dots = hv.Scatter(
            (t, c_norm),
            kdims=["Year (relative)"],
            vdims=["Norm. loading"],
        ).opts(size=7, color=col, tools=["hover"])
        panels.append(curve * dots)

        # ── (2) Top variables (horizontal bars) ───────────────────────────────
        top_idx = np.argsort(b_norm)[-n_top:]  # ascending: highest loading last
        labels = [var_names[i][:48] for i in top_idx]
        vals = b_norm[top_idx]
        bars = hv.Bars(
            list(zip(labels, vals, strict=True)),
            kdims=["Variable"],
            vdims=["Norm. loading"],
        ).opts(
            width=width + 180,  # extra room for the variable names on the category axis
            height=height,
            title=f"C{r + 1}: Top {len(top_idx)} variables",
            color=col,
            invert_axes=True,
            tools=["hover"],
            xrotation=0,
        )
        panels.append(bars)

        # ── (3) Sample-loading histogram ──────────────────────────────────────
        counts, edges = np.histogram(a_norm, bins=40)
        hist = hv.Histogram(
            (edges, counts),
            kdims=["Norm. loading"],
            vdims=["# observations"],
        ).opts(
            width=width,
            height=height,
            title=f"C{r + 1}: Sample loadings",
            color=col,
            tools=["hover"],
        )
        panels.append(hist)

    return hv.Layout(panels).cols(3)
def ncp_cluster_trajectories(
    edata: EHRData,
    *,
    layer: str,
    cluster_key: str,
    key: str = "ncp",
    n_top_diseases: int = 5,
    sigmoid_transform: bool = False,
    width: int = 520,
    height: int = 300,
) -> hv.Layout:
    """Plot mean variable trajectories per cluster, guided by NCP loadings.

    This function bridges unsupervised NCP decomposition and an existing cluster assignment
    (e.g. from ``sc.tl.leiden`` or a clinical grouping). For each cluster it identifies
    which NCP component best represents that cluster, selects the top variables of that
    component, and visualises their mean trajectories over the time axis. All trajectories
    are computed from the raw data, not the low-rank approximation.

    **What each panel shows**

    One panel is drawn per unique non-missing value in ``edata.obs[cluster_key]`` (in sorted
    order), arranged in two columns. Observations with a missing cluster label are ignored.
    The panel title shows the cluster label, the number of observations, and the dominant
    NCP component.

    Within each panel, each line is one variable (name truncated to 40 characters). The
    y-axis is the mean value of that variable across all observations belonging to the
    cluster, plotted at each time point along the x-axis. With ``sigmoid_transform=True``
    it is the mean probability instead. The axis is labelled "Mean value" or
    "Mean probability" accordingly. Lines therefore reveal:

    * **Level** — which variables have the highest absolute values for this cluster
      (higher lines = more pronounced feature).
    * **Shape** — whether a variable rises, falls, peaks, or stays flat over time within
      the cluster.
    * **Co-occurrence** — variables that share a similar trajectory shape are likely driven
      by the same underlying mechanism.

    **How variables are chosen per cluster**

    1. The mean patient loading ``A[mask].mean(axis=0)`` is computed for the cluster,
       giving a score per NCP component.
    2. The component with the highest score is called the *dominant component*.
    3. The ``n_top_diseases`` variables with the highest loading in that component's
       variable factor ``B[:, dominant]`` are selected.

    Each cluster is therefore represented by the clinical variables that the NCP model
    considers most characteristic of it. This provides a direct link between the
    data-driven decomposition and the cluster structure.

    Requires :func:`~ehrapy.tools.ncp` to have been run first.

    Args:
        edata: Central data object.
        layer: Key of the 3D layer holding the raw values (shape ``n_obs × n_vars × n_time``).
            All values must be non-negative (use ``sigmoid_transform=True`` for logit layers,
            or ``np.abs`` / clipping beforehand).
        cluster_key: Column in ``edata.obs`` that contains cluster or group labels
            (any categorical or string column).
        key: Key under which NCP results are stored (matches ``key_added`` in
            :func:`~ehrapy.tools.ncp`).
        n_top_diseases: Number of top-loaded variables to show per cluster. Must be >= 1.
        sigmoid_transform: Apply a sigmoid transformation to the layer values before
            averaging. Set to ``True`` when the layer stores raw logits so that the y-axis
            represents probabilities in ``(0, 1)``.
        width: Width of each panel in pixels.
        height: Height of each panel in pixels.

    Returns:
        HoloViews Layout with one panel per cluster, arranged in two columns.

    Raises:
        KeyError: If the NCP results, ``cluster_key`` or ``layer`` are missing.
        ValueError: If ``n_top_diseases`` is smaller than 1 or the layer is not 3D.

    Examples:
        >>> import ehrdata as ed, ehrapy as ep
        >>> edata = ed.dt.ehrdata_blobs(n_variables=8, n_centers=3, n_observations=30, base_timepoints=12)
        >>> ep.tl.ncp(edata, layer="tem_data", rank=3, sigmoid_transform=True)
        >>> ep.pl.ncp_cluster_trajectories(edata, layer="tem_data", cluster_key="cluster")

        .. image:: /_static/docstring_previews/ncp_cluster_trajectories.png
    """
    _require_ncp(edata, key)
    if cluster_key not in edata.obs:
        raise KeyError(f"Cluster key {cluster_key!r} not found in edata.obs.")
    if layer not in edata.layers:
        raise KeyError(f"Layer {layer!r} not found in edata.layers.")
    # ``argsort(...)[-0:]`` would otherwise select every variable for n_top_diseases == 0.
    if n_top_diseases < 1:
        raise ValueError(f"n_top_diseases must be a positive integer, got {n_top_diseases}.")

    tensor = np.asarray(edata.layers[layer], dtype=np.float64)
    if tensor.ndim != 3:
        raise ValueError(f"Layer {layer!r} must be 3D, got shape {tensor.shape}.")
    if sigmoid_transform:
        from scipy.special import expit

        tensor = expit(tensor)

    # Only after the sigmoid are the averaged values probabilities.
    y_label = "Mean probability" if sigmoid_transform else "Mean value"

    A = np.asarray(edata.obsm[f"X_{key}"])  # (n_obs, rank)
    B = np.asarray(edata.varm[f"{key}_loadings"])  # (n_vars, rank)
    var_names = list(edata.var_names)
    t = np.arange(tensor.shape[2])  # time index shared by every curve
    clusters = edata.obs[cluster_key]

    panels = []
    # Missing labels are skipped: NaN never compares equal to itself (empty "cluster"), and
    # mixing NaN with string labels would make ``sorted`` raise.
    for cluster_id in sorted(clusters.dropna().unique()):
        # na_value=False keeps the mask a plain bool array for nullable (NA-aware) dtypes.
        mask = (clusters == cluster_id).to_numpy(dtype=bool, na_value=False)
        cluster_tensor = tensor[mask]  # (N_cluster, n_vars, n_time)

        # Dominant NCP component = highest mean observation loading within the cluster.
        primary_comp = int(np.argmax(A[mask].mean(axis=0)))
        top_f_idx = np.argsort(B[:, primary_comp])[-n_top_diseases:]

        curves = []
        for f_idx in top_f_idx:
            avg_traj = cluster_tensor[:, f_idx, :].mean(axis=0)  # (n_time,)
            label = var_names[f_idx][:40]
            curve = hv.Curve(
                (t, avg_traj),
                kdims=["Year (relative)"],
                vdims=[y_label],
                label=label,
            ).opts(line_width=2, tools=["hover"])
            dots = hv.Scatter(
                (t, avg_traj),
                kdims=["Year (relative)"],
                vdims=[y_label],
                label=label,
            ).opts(size=5, tools=["hover"])
            curves.append(curve * dots)

        n_cluster = int(mask.sum())
        panel = hv.Overlay(curves).opts(
            width=width,
            height=height,
            legend_position="right",
            title=f"Cluster {cluster_id} (n={n_cluster}, dominant component={primary_comp + 1})",
        )
        panels.append(panel)

    return hv.Layout(panels).cols(2)