from __future__ import annotations
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Literal
import numpy as np
import pandas as pd
from lamin_utils import logger
from sklearn.experimental import enable_iterative_imputer # noinspection PyUnresolvedReference
from sklearn.impute import SimpleImputer
from ehrapy import settings
from ehrapy._utils_available import _check_module_importable
from ehrapy._utils_rendering import spinner
from ehrapy.anndata import check_feature_types
from ehrapy.anndata.anndata_ext import get_column_indices
if TYPE_CHECKING:
from anndata import AnnData
[docs]
@spinner("Performing explicit impute")
def explicit_impute(
    adata: AnnData,
    replacement: (str | int) | (dict[str, str | int]),
    *,
    impute_empty_strings: bool = True,
    warning_threshold: int = 70,
    copy: bool = False,
) -> AnnData | None:
    """Replaces all missing values in all columns or a subset of columns specified by the user with the passed replacement value.

    There are two scenarios to cover:
    1. Replace all missing values with the specified value.
    2. Replace all missing values in a subset of columns with a specified value per column.

    Args:
        adata: :class:`~anndata.AnnData` object containing X to impute values in.
        replacement: The value to replace missing values with. If a dictionary is provided, the keys represent column
            names and the values represent replacement values for those columns.
        impute_empty_strings: If True, empty strings are also replaced.
        warning_threshold: Threshold of percentage of missing values to display a warning for.
        copy: If True, returns a modified copy of the original AnnData object. If False, modifies the object in place.

    Returns:
        If copy is True, a modified copy of the original AnnData object with imputed X.
        If copy is False, the original AnnData object is modified in place, and None is returned.

    Examples:
        Replace all missing values in adata with the value 0:

        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> ep.pp.explicit_impute(adata, replacement=0)
    """
    if copy:
        adata = adata.copy()

    if isinstance(replacement, int | str):
        _warn_imputation_threshold(adata, var_names=list(adata.var_names), threshold=warning_threshold)
    else:
        _warn_imputation_threshold(adata, var_names=replacement.keys(), threshold=warning_threshold)  # type: ignore

    # 1: Replace all missing values with the specified value
    if isinstance(replacement, int | str):
        _replace_explicit(adata.X, replacement, impute_empty_strings)
    # 2: Replace all missing values in a subset of columns with a specified value per column or a default value, when the column is not explicitly named
    elif isinstance(replacement, dict):
        for idx, column_name in enumerate(adata.var_names):
            imputation_value = _extract_impute_value(replacement, column_name)
            # compare against None (not truthiness) so that falsy replacement values
            # such as 0 or "" are still applied instead of being silently skipped
            if imputation_value is not None:
                _replace_explicit(adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings)
            else:
                logger.warning(f"No replace value passed and found for var [not bold green]{column_name}.")
    else:
        raise ValueError(  # pragma: no cover
            f"Type {type(replacement)} is not a valid datatype for replacement parameter. Either use int, str or a dict!"
        )

    return adata if copy else None
def _replace_explicit(arr: np.ndarray, replacement: str | int, impute_empty_strings: bool) -> None:
"""Replace one column or whole X with a value where missing values are stored."""
if not impute_empty_strings: # pragma: no cover
impute_conditions = pd.isnull(arr)
else:
impute_conditions = np.logical_or(pd.isnull(arr), arr == "")
arr[impute_conditions] = replacement
def _extract_impute_value(replacement: dict[str, str | int], column_name: str) -> str | int | None:
"""Extract the replacement value for a given column in the :class:`~anndata.AnnData` object
Returns:
The value to replace missing values
"""
# try to get a value for the specific column
imputation_value = replacement.get(column_name)
if imputation_value:
return imputation_value
# search for a default value in case no value was specified for that column
imputation_value = replacement.get("default")
if imputation_value: # pragma: no cover
return imputation_value
else:
return None
[docs]
@spinner("Performing simple impute")
def simple_impute(
    adata: AnnData,
    var_names: Iterable[str] | None = None,
    *,
    strategy: Literal["mean", "median", "most_frequent"] = "mean",
    copy: bool = False,
    warning_threshold: int = 70,
) -> AnnData | None:
    """Impute missing values in numerical data using mean/median/most frequent imputation.

    If required and using mean or median strategy, the data needs to be properly encoded as this imputation requires
    numerical data only.

    Args:
        adata: The annotated data matrix to impute missing values on.
        var_names: A list of column names to apply imputation on (if None, impute all columns).
        strategy: Imputation strategy to use. One of {'mean', 'median', 'most_frequent'}.
        warning_threshold: Display a warning message if percentage of missing values exceeds this threshold.
        copy: Whether to return a copy of `adata` or modify it inplace.

    Returns:
        If copy is True, a modified copy of the original AnnData object with imputed X.
        If copy is False, the original AnnData object is modified in place, and None is returned.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> ep.pp.simple_impute(adata, strategy="median")
    """
    if copy:
        adata = adata.copy()

    _warn_imputation_threshold(adata, var_names, threshold=warning_threshold)

    if strategy == "most_frequent":
        # most_frequent imputation works with non-numerical data as well
        _simple_impute(adata, var_names, strategy)
    elif strategy in ("median", "mean"):
        try:
            _simple_impute(adata, var_names, strategy)
        except ValueError:
            raise ValueError(
                f"Can only impute numerical data using {strategy} strategy. Try to restrict imputation "
                "to certain columns using var_names parameter or use a different mode."
            ) from None
    else:
        raise ValueError(
            f"Unknown impute strategy {strategy} for simple Imputation. Choose any of mean, median or most_frequent."
        ) from None

    return adata if copy else None
def _simple_impute(adata: AnnData, var_names: Iterable[str] | None, strategy: str) -> None:
    """Fit a :class:`~sklearn.impute.SimpleImputer` on X (or a named subset of columns) and write back in place."""
    imputer = SimpleImputer(strategy=strategy)
    restrict_to_columns = isinstance(var_names, Iterable) and all(isinstance(name, str) for name in var_names)
    if restrict_to_columns:
        indices = get_column_indices(adata, var_names)
        adata.X[:, indices] = imputer.fit_transform(adata.X[:, indices])
    else:
        adata.X = imputer.fit_transform(adata.X)
[docs]
@spinner("Performing KNN impute")
@check_feature_types
def knn_impute(
    adata: AnnData,
    var_names: Iterable[str] | None = None,
    *,
    n_neighbors: int = 5,
    copy: bool = False,
    backend: Literal["scikit-learn", "faiss"] = "faiss",
    warning_threshold: int = 70,
    backend_kwargs: dict | None = None,
    **kwargs,
) -> AnnData | None:
    """Imputes missing values in the input AnnData object using K-nearest neighbor imputation.

    If required, the data needs to be properly encoded as this imputation requires numerical data only.

    .. warning::
        Currently, both `n_neighbours` and `n_neighbors` are accepted as parameters for the number of neighbors.
        However, in future versions, only `n_neighbors` will be supported. Please update your code accordingly.

    Args:
        adata: An annotated data matrix containing EHR data.
        var_names: A list of variable names indicating which columns to impute.
            If `None`, all columns are imputed. Default is `None`.
        n_neighbors: Number of neighbors to use when performing the imputation.
        copy: Whether to perform the imputation on a copy of the original `AnnData` object.
            If `True`, the original object remains unmodified.
        backend: The implementation to use for the KNN imputation.
            'scikit-learn' is very slow but uses an exact KNN algorithm, whereas 'faiss'
            is drastically faster but uses an approximation for the KNN graph.
            In practice, 'faiss' is close enough to the 'scikit-learn' results.
        warning_threshold: Percentage of missing values above which a warning is issued.
        backend_kwargs: Passed to the backend.
            Pass "mean", "median", or "weighted" for 'strategy' to set the imputation strategy for faiss.
            See `sklearn.impute.KNNImputer <https://scikit-learn.org/stable/modules/generated/sklearn.impute.KNNImputer.html>`_ for more information on the 'scikit-learn' backend.
            See `fknni.faiss.FaissImputer <https://fknni.readthedocs.io/en/latest/>`_ for more information on the 'faiss' backend.
        kwargs: Gathering keyword arguments of earlier ehrapy versions for backwards compatibility. It is encouraged to use the here listed, current arguments.

    Returns:
        If copy is True, a modified copy of the original AnnData object with imputed X.
        If copy is False, the original AnnData object is modified in place, and None is returned.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> ep.ad.infer_feature_types(adata)
        >>> ep.pp.knn_impute(adata)
    """
    if copy:
        adata = adata.copy()

    _warn_imputation_threshold(adata, var_names, threshold=warning_threshold)

    if backend not in {"scikit-learn", "faiss"}:
        raise ValueError(f"Unknown backend '{backend}' for KNN imputation. Choose between 'scikit-learn' and 'faiss'.")

    if backend_kwargs is None:
        backend_kwargs = {}

    # reject any legacy keyword argument except the deprecated 'n_neighbours' alias
    valid_kwargs = {"n_neighbours"}
    unexpected_kwargs = set(kwargs.keys()) - valid_kwargs
    if unexpected_kwargs:
        raise ValueError(f"Unexpected keyword arguments: {unexpected_kwargs}.")

    if "n_neighbours" in kwargs.keys():
        n_neighbors = kwargs["n_neighbours"]
        warnings.warn(
            "ehrapy will use 'n_neighbors' instead of 'n_neighbours'. Please update your code.",
            DeprecationWarning,
            stacklevel=1,
        )

    if _check_module_importable("sklearnex"):  # pragma: no cover
        from sklearnex import patch_sklearn, unpatch_sklearn

        patch_sklearn()

    try:
        if np.issubdtype(adata.X.dtype, np.number):
            _knn_impute(adata, var_names, n_neighbors, backend=backend, **backend_kwargs)
        else:
            # Raise exception since non-numerical data can not be imputed using KNN Imputation
            raise ValueError(
                "Can only impute numerical data. Try to restrict imputation to certain columns using "
                "var_names parameter or perform an encoding of your data."
            )
    except ValueError as e:
        if "Data matrix has wrong shape" in str(e):
            logger.error("Check that your matrix does not contain any NaN only columns!")
        raise
    finally:
        # undo the sklearnex acceleration patch even when imputation failed,
        # so a raised error does not leave sklearn globally patched
        if _check_module_importable("sklearnex"):  # pragma: no cover
            unpatch_sklearn()

    return adata if copy else None
def _knn_impute(
    adata: AnnData,
    var_names: Iterable[str] | None,
    n_neighbors: int,
    backend: Literal["scikit-learn", "faiss"],
    **kwargs,
) -> None:
    """Run the backend-specific KNN imputer on X (or a named subset of columns) and write back in place."""
    if backend == "scikit-learn":
        from sklearn.impute import KNNImputer

        imputer = KNNImputer(n_neighbors=n_neighbors, **kwargs)
    else:
        from fknni import FaissImputer

        imputer = FaissImputer(n_neighbors=n_neighbors, **kwargs)

    restrict_to_columns = isinstance(var_names, Iterable) and all(isinstance(name, str) for name in var_names)
    if restrict_to_columns:
        indices = get_column_indices(adata, var_names)
        adata.X[:, indices] = imputer.fit_transform(adata.X[:, indices])
        # X dtype has to be numerical afterwards in order to correctly round floats
        adata.X = adata.X.astype("float64")
    else:
        adata.X = imputer.fit_transform(adata.X)
[docs]
@spinner("Performing miss-forest impute")
def miss_forest_impute(
    adata: AnnData,
    var_names: Iterable[str] | None = None,
    *,
    num_initial_strategy: Literal["mean", "median", "most_frequent", "constant"] = "mean",
    max_iter: int = 3,
    n_estimators: int = 100,
    random_state: int = 0,
    warning_threshold: int = 70,
    copy: bool = False,
) -> AnnData | None:
    """Impute data using the MissForest strategy.

    This function uses the MissForest strategy to impute missing values in the data matrix of an AnnData object.
    The strategy works by fitting a random forest model on each feature containing missing values,
    and using the trained model to predict the missing values.
    See https://academic.oup.com/bioinformatics/article/28/1/112/219101.

    If required, the data needs to be properly encoded as this imputation requires numerical data only.

    Args:
        adata: The AnnData object to use MissForest Imputation on.
        var_names: Iterable of columns to impute.
        num_initial_strategy: The initial strategy to replace all missing numerical values with.
        max_iter: The maximum number of iterations if the stop criterion has not been met yet.
        n_estimators: The number of trees to fit for every missing variable. Has a big effect on the run time.
            Decrease for faster computations.
        random_state: The random seed for the initialization.
        warning_threshold: Threshold of percentage of missing values to display a warning for.
        copy: Whether to return a copy or act in place.

    Returns:
        If copy is True, a modified copy of the original AnnData object with imputed X.
        If copy is False, the original AnnData object is modified in place, and None is returned.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> ep.pp.miss_forest_impute(adata)
    """
    if copy:
        adata = adata.copy()

    if var_names is None:
        _warn_imputation_threshold(adata, list(adata.var_names), threshold=warning_threshold)
    elif isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
        _warn_imputation_threshold(adata, var_names, threshold=warning_threshold)

    if _check_module_importable("sklearnex"):  # pragma: no cover
        from sklearnex import patch_sklearn, unpatch_sklearn

        patch_sklearn()

    from sklearn.ensemble import ExtraTreesRegressor
    from sklearn.impute import IterativeImputer

    try:
        # NOTE(review): a second IterativeImputer with a RandomForestClassifier used to be
        # instantiated here but was never assigned or used; the dead instantiation was removed.
        imp_num = IterativeImputer(
            estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs),
            initial_strategy=num_initial_strategy,
            max_iter=max_iter,
            random_state=random_state,
        )

        if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):  # type: ignore
            num_indices = get_column_indices(adata, var_names)
        else:
            num_indices = get_column_indices(adata, adata.var_names)

        if set(num_indices).issubset(_get_non_numerical_column_indices(adata.X)):
            raise ValueError(
                "Can only impute numerical data. Try to restrict imputation to certain columns using "
                "var_names parameter."
            )

        # this step is the most expensive one and might extremely slow down the impute process
        if num_indices:
            adata.X[:, num_indices] = imp_num.fit_transform(adata.X[:, num_indices])
        else:
            raise ValueError("Cannot find any feature to perform imputation")
    except ValueError as e:
        if "Data matrix has wrong shape" in str(e):
            logger.error("Check that your matrix does not contain any NaN only columns!")
        raise
    finally:
        # undo the sklearnex acceleration patch even when imputation failed,
        # so a raised error does not leave sklearn globally patched
        if _check_module_importable("sklearnex"):  # pragma: no cover
            unpatch_sklearn()

    return adata if copy else None
[docs]
@spinner("Performing mice-forest impute")
@check_feature_types
def mice_forest_impute(
    adata: AnnData,
    var_names: Iterable[str] | None = None,
    *,
    warning_threshold: int = 70,
    save_all_iterations_data: bool = True,
    random_state: int | None = None,
    inplace: bool = False,
    iterations: int = 5,
    variable_parameters: dict | None = None,
    verbose: bool = False,
    copy: bool = False,
) -> AnnData | None:
    """Impute data using the miceforest.

    See https://github.com/AnotherSamWilson/miceforest
    Fast, memory efficient Multiple Imputation by Chained Equations (MICE) with lightgbm.

    If required, the data needs to be properly encoded as this imputation requires numerical data only.

    Args:
        adata: The AnnData object containing the data to impute.
        var_names: A list of variable names to impute. If None, impute all variables.
        warning_threshold: Threshold of percentage of missing values to display a warning for.
        save_all_iterations_data: Whether to save all imputed values from all iterations or just the latest.
            Saving all iterations allows for additional plotting, but may take more memory.
        random_state: The random state ensures script reproducibility.
        inplace: If True, modify the input AnnData object in-place and return None.
            If False, return a copy of the modified AnnData object. Default is False.
        iterations: The number of iterations to run.
        variable_parameters: Model parameters can be specified by variable here.
            Keys should be variable names or indices, and values should be a dict of parameter which should apply to that variable only.
        verbose: Whether to print information about the imputation process.
        copy: Whether to return a copy of the AnnData object or modify it in-place.

    Returns:
        If copy is True, a modified copy of the original AnnData object with imputed X.
        If copy is False, the original AnnData object is modified in place, and None is returned.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> ep.ad.infer_feature_types(adata)
        >>> ep.pp.mice_forest_impute(adata)
    """
    if copy:
        adata = adata.copy()

    _warn_imputation_threshold(adata, var_names, threshold=warning_threshold)

    try:
        if np.issubdtype(adata.X.dtype, np.number):
            _miceforest_impute(
                adata,
                var_names,
                save_all_iterations_data,
                random_state,
                inplace,
                iterations,
                variable_parameters,
                verbose,
            )
        else:
            raise ValueError(
                "Can only impute numerical data. Try to restrict imputation to certain columns using "
                "var_names parameter."
            )
    except ValueError as e:
        if "Data matrix has wrong shape" in str(e):
            # error level for consistency with knn_impute and miss_forest_impute,
            # which log this same diagnostic via logger.error
            logger.error("Check that your matrix does not contain any NaN only columns!")
        raise

    return adata if copy else None
def _miceforest_impute(
    adata, var_names, save_all_iterations_data, random_state, inplace, iterations, variable_parameters, verbose
) -> None:
    """Run miceforest MICE imputation on ``adata.X``, restricted to ``var_names`` when given; mutates ``adata.X``."""
    import miceforest as mf

    # work on a DataFrame view of X; entries that fail numeric parsing become NaN and are thus imputable
    data_df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names)
    data_df = data_df.apply(pd.to_numeric, errors="coerce")

    if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
        # impute only the selected columns and write them back into the full frame
        column_indices = get_column_indices(adata, var_names)
        selected_columns = data_df.iloc[:, column_indices]
        # NOTE(review): the selection gets a fresh RangeIndex while data_df keeps obs_names;
        # the iloc assignment below mixes the two — confirm pandas alignment behaves as intended here.
        selected_columns = selected_columns.reset_index(drop=True)
        kernel = mf.ImputationKernel(
            selected_columns,
            num_datasets=1,
            save_all_iterations_data=save_all_iterations_data,
            random_state=random_state,
        )
        kernel.mice(iterations=iterations, variable_parameters=variable_parameters or {}, verbose=verbose)
        data_df.iloc[:, column_indices] = kernel.complete_data(dataset=0, inplace=inplace)
    else:
        # impute the whole matrix at once
        data_df = data_df.reset_index(drop=True)
        kernel = mf.ImputationKernel(
            data_df, num_datasets=1, save_all_iterations_data=save_all_iterations_data, random_state=random_state
        )
        kernel.mice(iterations=iterations, variable_parameters=variable_parameters or {}, verbose=verbose)
        data_df = kernel.complete_data(dataset=0, inplace=inplace)

    adata.X = data_df.values
def _warn_imputation_threshold(adata: AnnData, var_names: Iterable[str] | None, threshold: int = 75) -> dict[str, int]:
"""Warns the user if the more than $threshold percent had to be imputed.
Args:
adata: The AnnData object to check
var_names: The var names which were imputed.
threshold: A percentage value from 0 to 100 used as minimum.
"""
try:
adata.var["missing_values_pct"]
except KeyError:
from ehrapy.preprocessing import qc_metrics
qc_metrics(adata)
used_var_names = set(adata.var_names) if var_names is None else set(var_names)
thresholded_var_names = set(adata.var[adata.var["missing_values_pct"] > threshold].index) & set(used_var_names)
var_name_to_pct: dict[str, int] = {}
for var in thresholded_var_names:
var_name_to_pct[var] = adata.var["missing_values_pct"].loc[var]
logger.warning(f"Feature '{var}' had more than {var_name_to_pct[var]:.2f}% missing values!")
return var_name_to_pct
def _get_non_numerical_column_indices(arr: np.ndarray) -> set:
"""Return indices of columns, that contain at least one non-numerical value that is not "Nan"."""
def _is_float_or_nan(val) -> bool: # pragma: no cover
"""Check whether a given item is a float or np.nan"""
try:
_ = float(val)
return not isinstance(val, bool)
except (ValueError, TypeError):
return False
def _is_float_or_nan_row(row) -> list[bool]: # pragma: no cover
return [_is_float_or_nan(val) for val in row]
mask = np.apply_along_axis(_is_float_or_nan_row, 0, arr)
_, column_indices = np.where(~mask)
return set(column_indices)