Source code for ehrapy.preprocessing._balanced_sampling
from __future__ import annotations
from typing import TYPE_CHECKING, Literal
if TYPE_CHECKING:
from anndata import AnnData
from anndata import AnnData
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
[docs]
def balanced_sample(
adata: AnnData,
*,
key: str,
random_state: int = 0,
method: Literal["RandomUnderSampler", "RandomOverSampler"] = "RandomUnderSampler",
sampler_kwargs: dict = None,
copy: bool = False,
) -> AnnData:
"""Balancing groups in the dataset.
Balancing groups in the dataset based on group members in `.obs[key]` using the `imbalanced-learn <https://imbalanced-learn.org/stable/index.html>`_ package.
Currently, supports `RandomUnderSampler` and `RandomOverSampler`.
Note that `RandomOverSampler <https://imbalanced-learn.org/stable/references/generated/imblearn.over_sampling.RandomOverSampler.html>`_
only replicates observations of the minority groups, which distorts several downstream analyses, very prominently neighborhood calculations and downstream analyses depending on that.
The `RandomUnderSampler <https://imbalanced-learn.org/stable/references/generated/imblearn.under_sampling.RandomUnderSampler.html>`_
by default undersamples the majority group without replacement, not causing this issues of replicated observations.
Args:
adata: The annotated data matrix of shape `n_obs` × `n_vars`.
key: The key in `adata.obs` that contains the group information.
random_state: Random seed.
method: The method to use for balancing.
sampler_kwargs: Keyword arguments for the sampler, see the `imbalanced-learn` documentation for options.
copy: If True, return a copy of the balanced data.
Returns:
A new `AnnData` object, with the balanced groups.
Examples:
>>> import ehrapy as ep
>>> adata = ep.data.diabetes_130_fairlearn(columns_obs_only=["age"])
>>> adata.obs.age.value_counts()
age
'Over 60 years' 68541
'30-60 years' 30716
'30 years or younger' 2509
>>> adata_balanced = ep.pp.sample(adata, key="age")
>>> adata_balanced.obs.age.value_counts()
age
'30 years or younger' 2509
'30-60 years' 2509
'Over 60 years' 2509
"""
if not isinstance(adata, AnnData):
raise ValueError(f"Input data is not an AnnData object: type of {adata}, is {type(adata)}")
if sampler_kwargs is None:
sampler_kwargs = {"random_state": random_state}
else:
sampler_kwargs["random_state"] = random_state
if method == "RandomUnderSampler":
sampler = RandomUnderSampler(**sampler_kwargs)
elif method == "RandomOverSampler":
sampler = RandomOverSampler(**sampler_kwargs)
else:
raise ValueError(f"Unknown sampling method: {method}")
if key in adata.obs.keys():
use_label = adata.obs[key]
else:
raise ValueError(f"key not in adata.obs: {key}")
sampler.fit_resample(adata.X, use_label)
if copy:
return adata[sampler.sample_indices_].copy()
else:
adata._inplace_subset_obs(sampler.sample_indices_)