from __future__ import annotations
from typing import TYPE_CHECKING, Literal
import array_api_compat
import numpy as np
import pandas as pd
from ehrdata.core.constants import FEATURE_TYPE_KEY
from scipy import stats
from statsmodels.stats.multitest import multipletests
from ehrapy._compat import nanmean_array_api
if TYPE_CHECKING:
from collections.abc import Sequence
from ehrdata import EHRData
def _select_first_or_last_valid(xp, mtx_3d, *, last: bool) -> np.ndarray:
    """Pick, per observation and variable, the first (or last) non-NaN value along axis 2 (time).

    Args:
        xp: Array API namespace of ``mtx_3d``.
        mtx_3d: 3D array indexed as (obs, var, time).
        last: If True, select the last non-NaN value instead of the first.

    Returns:
        A float64 NumPy array indexed as (obs, var); NaN where no value was recorded at any timepoint.
    """
    values = xp.astype(mtx_3d, xp.float64)
    valid = ~xp.isnan(values)
    if last:
        # Reverse time so that argmax, which returns the first True, lands on the last valid timepoint.
        values = xp.flip(values, axis=2)
        valid = xp.flip(valid, axis=2)
    first_valid = xp.argmax(valid, axis=2)

    values_np = array_api_compat.numpy.asarray(values)
    first_valid_np = array_api_compat.numpy.asarray(first_valid)
    n_obs, n_var = first_valid_np.shape
    # argmax over an all-False row returns index 0. The value at that position is NaN by construction,
    # so (obs, var) pairs without any recorded value stay NaN without an extra masking pass.
    return values_np[np.arange(n_obs)[:, None], np.arange(n_var)[None, :], first_valid_np]


def _aggregate_variable_values(
    edata: EHRData,
    layer: str | None = None,
    *,
    var_names: Sequence[str] | None = None,
    agg: Literal["mean", "last", "first"] = "mean",
) -> tuple[np.ndarray, Sequence[str]]:
    """Aggregate variable values from an EHRData layer over time with the specified aggregation method.

    Args:
        edata: Central data object.
        layer: Layer to read values from. If None, ``edata.X`` is used.
        var_names: Variables to extract. If None, all numeric variables are used, in ``edata.var_names`` order.
        agg: How to collapse the time axis of 3D data: ``"mean"`` (via ``nanmean_array_api``),
            or ``"last"`` / ``"first"`` (last / first non-NaN value per observation and variable).
            Must be one of these values; it has no effect on 2D data.

    Returns:
        A NumPy array with one row per observation and one column per selected variable, and the selected
        variable names in column order. 3D data is cast to float64; 2D data keeps its dtype.

    Raises:
        KeyError: If ``layer`` or any of ``var_names`` does not exist.
        ValueError: If ``agg`` is unknown, the data is neither 2D nor 3D, or a requested variable is non-numeric.
    """
    if agg not in ("mean", "last", "first"):
        raise ValueError(f"Unknown aggregation method: {agg}")

    if layer is not None:
        if layer not in edata.layers:
            raise KeyError(f"Layer {layer} not found in edata.layers. Available: {edata.layers.keys()}")
        mtx = edata.layers[layer]
    else:
        mtx = edata.X

    if mtx.ndim not in (2, 3):
        raise ValueError(
            f"Expected 2D (obs x var) or 3D (obs x var x time) data, got an array with {mtx.ndim} dimensions."
        )
    xp = array_api_compat.array_namespace(mtx)

    # Only include numeric or encoded variables. The dtype is judged on each variable's column
    # (first timepoint for 3D data).
    numeric_var_names = {
        v
        for i, v in enumerate(edata.var_names)
        if np.issubdtype(np.array(mtx[:, i] if mtx.ndim == 2 else mtx[:, i, 0]).dtype, np.number)
    }

    if var_names is None:
        var_names = [v for v in edata.var_names if v in numeric_var_names]
    else:  # when user provides the var_names
        available_vars = set(edata.var_names)
        missing = set(var_names) - available_vars
        if missing:
            raise KeyError(f"Variables not found: {missing}, {available_vars}")
        non_numeric = set(var_names) - numeric_var_names
        if non_numeric:
            raise ValueError(f"Non-numeric variables were requested {non_numeric}")

    var_name_to_idx = {v: i for i, v in enumerate(edata.var_names)}
    var_indices = [var_name_to_idx[v] for v in var_names]

    if mtx.ndim == 2:
        # No time axis: just select the requested columns.
        return array_api_compat.numpy.asarray(mtx[:, var_indices]), var_names

    mtx_3d = mtx[:, var_indices, :]
    if agg == "mean":
        mtx_2d = nanmean_array_api(xp, xp.astype(mtx_3d, xp.float64), axes=2)
        return array_api_compat.numpy.asarray(mtx_2d), var_names
    return _select_first_or_last_valid(xp, mtx_3d, last=(agg == "last")), var_names
def variable_correlations(
    edata: EHRData,
    *,
    layer: str | None = None,
    var_names: Sequence[str] | None = None,
    method: Literal["spearman", "pearson", "kendall"] = "pearson",
    agg: Literal["mean", "last", "first"] = "mean",
    correction_method: Literal["bonferroni", "fdr_bh", "fdr_tsbh", "holm", "none"] = "bonferroni",
    alpha: float = 0.05,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Compute correlation matrix with statistical testing and multiple testing correction.

    This function computes pairwise correlations between variables in the given EHRData object,
    automatically handling missing values through pairwise deletion.
    For 3D time-series data, values are aggregated across time before computing correlations.

    Variable pairs with fewer than 3 jointly observed values get a NaN correlation and a p-value of 1.
    If SciPy reports an undefined p-value (NaN, e.g. for a constant variable), it is kept in the returned
    p-value matrix. That pair is treated as p-value 1 during multiple testing correction, so it does not
    invalidate the correction of the other pairs.

    Args:
        edata: Central data object.
        layer: Layer to extract data from. If None, `.X` will be used.
        var_names: List of variable names to compute correlation of. If None, uses all numeric variables.
        method: Correlation method, "spearman", "kendall" or "pearson".
        agg: How to aggregate time dimension: "mean", "last" or "first".
        correction_method: Multiple testing correction method:

            * `'bonferroni'` conservative Bonferroni correction.
            * `'fdr_bh'` Benjamini-Hochberg false discovery rate (FDR) control.
            * `'fdr_tsbh'` two-stage Benjamini-Hochberg, better calibrated when many variables are truly correlated.
            * `'holm'` Holm-Bonferroni correction.
            * `'none'` no multiple-testing correction.
        alpha: Significance threshold after correction.

    Returns:
        Correlation coefficient matrix, raw p-value matrix and boolean significance matrix after correction
        for each variable pair. On the diagonal the correlation is 1, the p-value is 0 and the pair is
        marked significant.

    Raises:
        ValueError: If ``method`` is unsupported or fewer than 2 numeric variables are selected.
        KeyError: If ``layer`` or any of ``var_names`` does not exist.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.ehrdata_blobs(n_variables=10, n_centers=5, n_observations=200, base_timepoints=3)
        >>> corr, pval, sig = ep.pp.compute_variable_correlations(
        ...     edata, layer="tem_data", method="pearson", agg="mean", correction_method="fdr_bh", alpha=0.02
        ... )
    """
    correlation_tests = {
        "spearman": stats.spearmanr,
        "kendall": stats.kendalltau,
        "pearson": stats.pearsonr,
    }
    # Validate before the (potentially expensive) aggregation step.
    if method not in correlation_tests:
        raise ValueError(f"Unsupported correlation method: {method}")
    corr_test = correlation_tests[method]

    arr, var_names = _aggregate_variable_values(edata, layer, var_names=var_names, agg=agg)
    if arr.shape[1] < 2:
        raise ValueError("For correlation matrix, at least 2 numeric variables are needed.")
    n_vars = len(var_names)

    # Defaults double as the result for untestable pairs: correlation NaN, p-value 1.
    corr_mtx = np.full((n_vars, n_vars), np.nan)
    np.fill_diagonal(corr_mtx, 1.0)
    pval_mtx = np.ones((n_vars, n_vars))
    np.fill_diagonal(pval_mtx, 0.0)

    observed = ~np.isnan(arr)  # per-column missingness, computed once for pairwise deletion
    for i in range(n_vars):
        for j in range(i + 1, n_vars):
            mask = observed[:, i] & observed[:, j]
            if mask.sum() < 3:
                # There should be at least 3 observations that have a value for variables i and j.
                continue
            corr_val, pval = corr_test(arr[mask, i], arr[mask, j])
            corr_mtx[i, j] = corr_mtx[j, i] = corr_val
            pval_mtx[i, j] = pval_mtx[j, i] = pval

    corr_df = pd.DataFrame(corr_mtx, index=var_names, columns=var_names)
    pval_df = pd.DataFrame(pval_mtx, index=var_names, columns=var_names)

    # Multiple testing correction
    if correction_method != "none":
        rows, cols = np.triu_indices(n_vars, k=1)
        pvals_upper = pval_mtx[rows, cols]
        # multipletests does not handle NaN p-values. With FDR methods, a NaN propagates through the
        # step-up cumulative minimum and can turn every corrected p-value into NaN. Treat undefined tests
        # as p=1, consistent with pairs that have too few observations.
        pvals_upper = np.where(np.isnan(pvals_upper), 1.0, pvals_upper)
        _, pval_corrected, _, _ = multipletests(pvals_upper, alpha=alpha, method=correction_method)
        sig_upper = pval_corrected < alpha
        sig_mtx = np.eye(n_vars, dtype=bool)
        sig_mtx[rows, cols] = sig_upper
        sig_mtx[cols, rows] = sig_upper
    else:
        # NaN p-values compare False here, so undefined tests are never significant.
        sig_mtx = pval_mtx < alpha

    sig_df = pd.DataFrame(sig_mtx, index=var_names, columns=var_names)
    return corr_df, pval_df, sig_df