from __future__ import annotations
from typing import TYPE_CHECKING, Literal
import holoviews as hv
import numpy as np
import pandas as pd
import ehrapy as ep
if TYPE_CHECKING:
from collections.abc import Sequence
from ehrdata import EHRData
[docs]
def variable_correlations(
    edata: EHRData,
    *,
    layer: str,
    var_names: Sequence[str] | None = None,
    method: Literal["spearman", "pearson", "kendall"] = "pearson",
    agg: Literal["mean", "last", "first"] = "mean",
    correction_method: Literal["bonferroni", "fdr_bh", "fdr_tsbh", "holm", "none"] = "bonferroni",
    alpha: float = 0.05,
    width: int = 600,
    height: int = 600,
    cmap: str = "RdBu_r",
    show_values: bool = True,
    title: str | None = None,
) -> hv.HeatMap | hv.Overlay:
    """Plot variable correlations as heatmap.

    Computes a correlation matrix (Pearson, Spearman or Kendall) for the selected variables from the given layer.
    If the layer contains a time dimension, values are first aggregated per variable across time.
    Cells are annotated with the correlation coefficient.
    An asterisk marks statistically significant correlations after correction.

    Args:
        edata: Central data object.
        layer: Layer to extract data from.
        var_names: List of variable names to compute correlation of. If None, uses all numeric variables.
        method: Correlation method: "spearman", "kendall" or "pearson".
        agg: How to aggregate time dimension: "mean", "last" or "first".
        correction_method: Multiple testing correction method:

            * `'bonferroni'` conservative Bonferroni correction.
            * `'fdr_bh'` Benjamini-Hochberg false discovery rate (FDR) control.
            * `'fdr_tsbh'` two-stage Benjamini-Hochberg, better calibrated when many variables are truly correlated.
            * `'holm'` Holm-Bonferroni correction.
            * `'none'` no multiple-testing correction.
        alpha: Significance threshold after correction.
        width: Plot width in pixels.
        height: Plot height in pixels.
        cmap: Colormap for the heatmap.
        show_values: If True, display correlation values on cells.
        title: Set the title of the plot. If None, a default title is built from `method` and,
            unless `correction_method` is `'none'`, from `correction_method` and `alpha`.

    Returns:
        :class:`holoviews.element.HeatMap` (if show_values=False) or :class:`holoviews.core.overlay.Overlay` (if show_values=True).

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.ehrdata_blobs(n_variables=10, n_centers=5, n_observations=200, base_timepoints=3)
        >>> ep.pl.variable_correlations(
        ...     edata, layer="tem_data", method="pearson", agg="mean", correction_method="fdr_bh", width=700
        ... )

        .. image:: /_static/docstring_previews/variable_correlations_heatmap.png
    """
    corr_df, _, sig_df = ep.pp.variable_correlations(
        edata=edata,
        layer=layer,
        var_names=var_names,
        method=method,
        agg=agg,
        correction_method=correction_method,
        alpha=alpha,
    )

    # Reshape both matrices to long format: one row per (variable1, variable2) cell.
    corr_long = corr_df.stack(future_stack=True).rename("correlation")
    sig_long = sig_df.stack(future_stack=True).rename("significant")
    heatmap_df = pd.concat([corr_long, sig_long], axis=1).reset_index()
    heatmap_df.columns = ["variable1", "variable2", "correlation", "significant"]

    is_nan = heatmap_df["correlation"].isna()
    is_diag = heatmap_df["variable1"] == heatmap_df["variable2"]
    # The asterisk is suppressed on the diagonal (a variable paired with itself).
    significance_marker = np.where(heatmap_df["significant"] & ~is_diag, "*", "")
    heatmap_df["label"] = np.where(
        is_nan,
        "N/A",
        heatmap_df["correlation"].map("{:.2f}".format) + significance_marker,
    )
    # for NaN correlations the neutral color will be shown on the colorscale
    heatmap_df["correlation"] = heatmap_df["correlation"].fillna(0)

    if title is None:
        # The separating space belongs to the suffix so the bare title has no trailing whitespace.
        title = f"{method.capitalize()} Correlation Matrix"
        if correction_method != "none":
            title += f" (correction method: {correction_method}, alpha={alpha})"

    heatmap = hv.HeatMap(heatmap_df, kdims=["variable1", "variable2"], vdims=["correlation", "label"])
    heatmap = heatmap.opts(
        width=width,
        height=height,
        cmap=cmap,
        clim=(-1, 1),
        colorbar=True,
        title=title,
        xrotation=45,
        toolbar="above",
        fontscale=1.2,
        xlabel="",
        ylabel="",
    )

    if not show_values:
        return heatmap

    labels = hv.Labels(heatmap_df, kdims=["variable1", "variable2"], vdims="label").opts(
        text_font_size="10pt",
        text_color="black",
        text_align="center",
    )
    return heatmap * labels
[docs]
def variable_dependencies(
    edata: EHRData,
    *,
    layer: str,
    var_names: Sequence[str] | None = None,
    method: Literal["spearman", "pearson", "kendall"] = "pearson",
    agg: Literal["mean", "last", "first"] = "mean",
    correction_method: Literal["bonferroni", "fdr_bh", "fdr_tsbh", "holm", "none"] = "bonferroni",
    alpha: float = 0.05,
    abs_correlation_threshold: float = 0.3,
    only_significant: bool = True,
    width: int = 600,
    height: int = 600,
    cmap: str = "RdBu_r",
    title: str | None = None,
) -> hv.Chord:
    """Plot correlation dependencies as a chord diagram.

    Computes pairwise correlations between selected variables from layer and visualizes them as a chord diagram.
    If the layer contains a time dimension, values are aggregated per variable before correlation is computed.

    Args:
        edata: Central data object.
        layer: Layer to extract data from.
        var_names: List of variable names to compute correlation of. If None, uses all numeric variables.
        method: Correlation method: "spearman", "kendall" or "pearson".
        agg: How to aggregate time dimension: "mean", "last" or "first".
        correction_method: Multiple testing correction method:

            * `'bonferroni'` conservative Bonferroni correction.
            * `'fdr_bh'` Benjamini-Hochberg false discovery rate (FDR) control.
            * `'fdr_tsbh'` two-stage Benjamini-Hochberg, better calibrated when many variables are truly correlated.
            * `'holm'` Holm-Bonferroni correction.
            * `'none'` no multiple-testing correction.
        alpha: Significance threshold after correction.
        abs_correlation_threshold: Minimum absolute correlation to show a chord. Must lie in [0, 1].
        only_significant: If True, only show significant correlations.
        width: Plot width in pixels.
        height: Plot height in pixels.
        cmap: Colormap for the chord diagram.
        title: Set the title of the plot. If None, a default title is built from `method` and,
            unless `correction_method` is `'none'`, from `correction_method` and `alpha`.

    Returns:
        :class:`holoviews.element.Chord` object.

    Raises:
        ValueError: If `abs_correlation_threshold` is outside [0, 1], or if no variable pair
            remains after applying the significance and threshold filters.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.ehrdata_blobs(n_variables=10, n_centers=5, n_observations=200, base_timepoints=3)
        >>> ep.pl.variable_dependencies(
        ...     edata, layer="tem_data", method="pearson", agg="mean", correction_method="fdr_bh"
        ... )

        .. image:: /_static/docstring_previews/variable_dependencies_chord.png
    """
    # Validate before doing any computation; the chained comparison also rejects NaN.
    if not 0 <= abs_correlation_threshold <= 1:
        raise ValueError(f"abs_correlation_threshold must be between 0 and 1, got {abs_correlation_threshold}")

    corr_df, _, sig_df = ep.pp.variable_correlations(
        edata=edata,
        layer=layer,
        var_names=var_names,
        method=method,
        agg=agg,
        correction_method=correction_method,
        alpha=alpha,
    )

    # Reshape both matrices to long format: one row per (variable1, variable2) cell.
    corr_long = corr_df.stack(future_stack=True).rename("correlation")
    sig_long = sig_df.stack(future_stack=True).rename("significant")
    edges_df = pd.concat([corr_long, sig_long], axis=1).reset_index()
    edges_df.columns = ["variable1", "variable2", "correlation", "significant"]

    # Nodes are addressed by their integer position in the correlation matrix columns.
    variables = corr_df.columns.to_list()
    var_to_idx = {var: idx for idx, var in enumerate(variables)}
    edges_df["source"] = edges_df["variable1"].map(var_to_idx)
    edges_df["target"] = edges_df["variable2"].map(var_to_idx)

    # source < target drops self-pairs and keeps only one of the two mirrored (a, b)/(b, a) rows.
    edges_df = edges_df[edges_df["source"] < edges_df["target"]].dropna(subset="correlation")
    # assign() builds a new frame, avoiding assignment into a filtered slice.
    edges_df = edges_df.assign(value=edges_df["correlation"].abs())

    if only_significant:
        edges_df = edges_df[edges_df["significant"]]
    edges_df = edges_df[edges_df["value"] >= abs_correlation_threshold]
    edges_df = edges_df[["source", "target", "value", "correlation"]].reset_index(drop=True)

    if edges_df.empty:
        raise ValueError(
            f"No correlations meet criteria (minimum absolute correlation to plot = {abs_correlation_threshold})."
            f"\nTry lowering abs_correlation_threshold or setting only_significant=False."
        )

    nodes_df = pd.DataFrame({"index": range(len(variables)), "name": variables})

    if title is None:
        # The separating space belongs to the suffix so the bare title has no trailing whitespace.
        title = f"{method.capitalize()} Correlation Chord Diagram"
        if correction_method != "none":
            title += f" ({correction_method}, alpha={alpha})"

    chord = hv.Chord((edges_df, hv.Dataset(nodes_df, "index")))
    chord = chord.opts(
        width=width,
        height=height,
        node_color="index",
        edge_color="correlation",
        colorbar=True,
        labels="name",
        node_size=15,
        clim=(-1, 1),
        title=title,
        cmap=cmap,
    )
    return chord