from __future__ import annotations
import math
from typing import TYPE_CHECKING, Literal
import numpy as np
if TYPE_CHECKING:
from ehrdata import EHRData
_MTTKRP_EINSUM = ["ijk,jr,kr->ir", "ijk,ir,kr->jr", "ijk,ir,jr->kr"]
def _nonneg_cp(
tensor: np.ndarray,
rank: int,
n_iter_max: int = 300,
tol: float = 1e-8,
random_state: int = 0,
) -> tuple[np.ndarray, list[np.ndarray]]:
"""Non-negative CP decomposition via Multiplicative Updates (Lee & Seung).
Uses three optimisations over a naive implementation:
* **einsum MTTKRP** — the matricised-tensor-times-Khatri–Rao-product is computed via ``np.einsum`` without ever forming the Khatri–Rao matrix.
* **Gram-matrix denominator** — ``kr.T @ kr`` is replaced by the Hadamard product of the other factors' Gram matrices (O(R²) instead of O(JK·R²)).
* **Algebraic error** — the reconstruction error is computed from Gram matrices and the cached MTTKRP, avoiding the full-tensor reconstruction.
Args:
tensor: Non-negative input tensor of shape ``(I, J, K)``.
rank: Number of components.
n_iter_max: Maximum number of iterations.
tol: Convergence tolerance on the relative change in reconstruction error.
random_state: Seed for initialisation.
Returns:
weights
Per-component weights of shape ``(rank,)``.
factors
List of non-negative factor matrices, one per tensor mode.
"""
xp = tensor.__array_namespace__()
rng = np.random.default_rng(random_state)
eps = 1e-12
factors = [xp.asarray(np.abs(rng.standard_normal((s, rank))) + 0.1, dtype=tensor.dtype) for s in tensor.shape]
grams = [f.T @ f for f in factors]
tensor_norm_sq = float(xp.sum(tensor * tensor))
prev_error = math.inf
for _ in range(n_iter_max):
for mode in range(3):
other = [i for i in range(3) if i != mode]
numerator = xp.einsum(
_MTTKRP_EINSUM[mode],
tensor,
factors[other[0]],
factors[other[1]],
optimize=True,
)
gram_prod = grams[other[0]] * grams[other[1]]
factors[mode] = factors[mode] * numerator / (factors[mode] @ gram_prod + eps)
grams[mode] = factors[mode].T @ factors[mode]
inner = float(xp.sum(factors[2] * numerator))
gram_all = grams[0] * grams[1] * grams[2]
err_sq = max(tensor_norm_sq - 2 * inner + float(xp.sum(gram_all)), 0.0)
error = math.sqrt(err_sq) / max(math.sqrt(tensor_norm_sq), eps)
if abs(prev_error - error) < tol:
break
prev_error = error
weights = xp.ones(rank, dtype=tensor.dtype)
for i in range(3):
norms = xp.clip(xp.linalg.norm(factors[i], axis=0), min=eps)
weights = weights * norms
factors[i] = factors[i] / norms[None, :]
return weights, factors
[docs]
def ncp(
edata: EHRData,
*,
layer: str,
rank: int = 4,
n_iter_max: int = 300,
sigmoid_transform: bool = False,
key_added: str = "ncp",
random_state: int = 0,
copy: bool = False,
) -> EHRData | None:
r"""Non-negative CP (PARAFAC) decomposition of a 3D temporal EHR layer.
CP (CANDECOMP/PARAFAC) decomposition factorises a 3-way tensor
:math:`X \in \mathbb{R}^{I \times J \times K}` into a sum of ``rank``
outer products:
.. math::
X \approx \sum_r a_r \otimes b_r \otimes c_r
where each triplet :math:`(a_r, b_r, c_r)` is a *component*:
* :math:`a_r \in \mathbb{R}^I` — **patient factor**: how strongly each observation expresses component *r*.
* :math:`b_r \in \mathbb{R}^J` — **variable factor**: which clinical variables are characteristic of component *r*.
* :math:`c_r \in \mathbb{R}^K` — **temporal factor**: how the pattern of component *r* evolves over the time axis.
The *Non-negative* variant (NCP) constrains all factors to be :math:`\geq 0`,
which is natural for count-like or probability data and yields parts-based, interpretable components (analogous to NMF for matrices).
Factors are estimated by Multiplicative Updates (Lee & Seung, 2001).
Each factor matrix is updated in closed form while the others are held fixed, cycling through the three modes until convergence:
.. math::
F_{\text{mode}} \leftarrow F_{\text{mode}} \odot
\frac{\mathcal{X}_{(\text{mode})} \, \mathrm{KR}(F_{-\text{mode}})}
{F_{\text{mode}} \, \mathrm{KR}(F_{-\text{mode}})^\top
\mathrm{KR}(F_{-\text{mode}}) + \varepsilon}
where :math:`\mathcal{X}_{(\text{mode})}` is the mode-*n* matricisation of the tensor and :math:`\mathrm{KR}` denotes the Khatri–Rao product of the remaining factor matrices.
Args:
edata: Central data object.
layer: Key of the 3D layer to decompose (shape ``n_obs × n_vars × n_time``).
All values must be non-negative (use ``sigmoid_transform=True`` for logit layers, or ``np.abs`` / clipping beforehand).
rank: Number of components (rank of the decomposition).
Each component describes one co-occurring patient sub-group, variable signature, and temporal trajectory.
n_iter_max: Maximum number of multiplicative-update iterations.
300 is sufficient for most datasets; increase if the error has not converged (check ``edata.uns[key_added]["params"]``).
sigmoid_transform: If ``True``, apply a sigmoid transformation to the layer before decomposition.
Useful when the layer contains raw logits.
key_added: Key prefix for storing results. Results are stored as:
* ``edata.obsm["X_{key_added}"]`` — patient factors, shape ``(n_obs, rank)``.
* ``edata.varm["{key_added}_loadings"]`` — variable factors, shape ``(n_vars, rank)``.
* ``edata.uns["{key_added}"]["temporal_factors"]`` — temporal factors, shape ``(n_time, rank)``.
random_state: Random seed for the factor initialisation.
copy: Whether to return a copy rather than modifying in place.
Returns:
``None`` if ``copy=False``, else a modified copy of ``edata``.
Examples:
>>> import ehrdata as ed, ehrapy as ep
>>> edata = ed.dt.ehrdata_blobs(n_variables=8, n_centers=3, n_observations=30, base_timepoints=12)
>>> ep.tl.ncp(edata, layer="tem_data", rank=3, sigmoid_transform=True)
>>> edata.obsm["X_ncp"].shape
(30, 3)
>>> edata.varm["ncp_loadings"].shape
(8, 3)
>>> edata.uns["ncp"]["temporal_factors"].shape
(12, 3)
"""
if layer not in edata.layers:
raise KeyError(f"Layer {layer!r} not found in edata.layers. Available: {list(edata.layers)}")
tensor = np.asarray(edata.layers[layer], dtype=np.float64)
if tensor.ndim != 3:
raise ValueError(f"Layer {layer!r} must be 3D (n_obs × n_vars × n_time), got shape {tensor.shape}.")
if sigmoid_transform:
from scipy.special import expit
tensor = expit(tensor)
edata = edata.copy() if copy else edata
weights, factors = _nonneg_cp(tensor, rank=rank, n_iter_max=n_iter_max, random_state=random_state)
A, B, C = factors
A = A * weights[None, :]
edata.obsm[f"X_{key_added}"] = A # (n_obs, rank)
edata.varm[f"{key_added}_loadings"] = B # (n_vars, rank)
edata.uns[key_added] = {
"params": {
"layer": layer,
"rank": rank,
"n_iter_max": n_iter_max,
"sigmoid_transform": sigmoid_transform,
},
"temporal_factors": C, # (n_time, rank)
}
return edata if copy else None