Source code for ehrapy.preprocessing._imputation

from __future__ import annotations

import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd
from lamin_utils import logger
from rich import print
from rich.progress import Progress, SpinnerColumn
from sklearn.experimental import enable_iterative_imputer  # required to enable IterativeImputer (experimental feature)
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OrdinalEncoder

from ehrapy import settings
from ehrapy.anndata import check_feature_types
from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY
from ehrapy.anndata.anndata_ext import _get_column_indices
from ehrapy.core._tool_available import _check_module_importable

if TYPE_CHECKING:
    from anndata import AnnData


def explicit_impute(
    adata: AnnData,
    replacement: (str | int) | (dict[str, str | int]),
    *,
    impute_empty_strings: bool = True,
    warning_threshold: int = 70,
    copy: bool = False,
) -> AnnData | None:
    """Replaces all missing values in all columns or a subset of columns specified by the user with the passed replacement value.

    There are two scenarios to cover:
    1. Replace all missing values with the specified value.
    2. Replace all missing values in a subset of columns with a specified value per column.

    Args:
        adata: :class:`~anndata.AnnData` object containing X to impute values in.
        replacement: The value to replace missing values with. If a dictionary is provided, the keys represent column names
            and the values represent replacement values for those columns. A ``"default"`` key is looked up
            for columns that have no entry of their own.
        impute_empty_strings: If True, empty strings are also replaced.
        warning_threshold: Threshold of percentage of missing values to display a warning for.
        copy: If True, returns a modified copy of the original AnnData object. If False, modifies the object in place.

    Returns:
        If copy is True, a modified copy of the original AnnData object with imputed X.
        If copy is False, the original AnnData object is modified in place and None is returned.

    Raises:
        ValueError: If `replacement` is neither an int, a str nor a dict.

    Examples:
        Replace all missing values in adata with the value 0:

        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> ep.pp.explicit_impute(adata, replacement=0)
    """
    # Validate up front: otherwise an unsupported type fails on `.keys()` below with an
    # AttributeError and the intended ValueError is never reached.
    if not isinstance(replacement, (int, str, dict)):
        raise ValueError(  # pragma: no cover
            f"Type {type(replacement)} is not a valid datatype for replacement parameter. Either use int, str or a dict!"
        )

    if copy:  # pragma: no cover
        adata = adata.copy()

    if isinstance(replacement, dict):
        _warn_imputation_threshold(adata, var_names=replacement.keys(), threshold=warning_threshold)  # type: ignore
    else:
        _warn_imputation_threshold(adata, var_names=list(adata.var_names), threshold=warning_threshold)

    with Progress(
        "[progress.description]{task.description}",
        SpinnerColumn(),
        refresh_per_second=1500,
    ) as progress:
        progress.add_task("[blue]Running explicit imputation", total=1)

        if isinstance(replacement, dict):
            # 2: Replace missing values column by column with the value named for that column
            # (or the default value, when the column is not explicitly named)
            for idx, column_name in enumerate(adata.var_names):
                imputation_value = _extract_impute_value(replacement, column_name)
                # Compare against None instead of relying on truthiness, so that a falsy
                # replacement value such as 0 is not mistaken for "no value available".
                if imputation_value is not None:
                    # the slice keeps the column two-dimensional and is a view, so X is edited in place
                    _replace_explicit(adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings)
                else:
                    logger.warning(f"No replace value passed and found for var [not bold green]{column_name}.")
        else:
            # 1: Replace all missing values with the specified value
            _replace_explicit(adata.X, replacement, impute_empty_strings)

    return adata if copy else None
def _replace_explicit(arr: np.ndarray, replacement: str | int, impute_empty_strings: bool) -> None:
    """Replace one column or whole X with a value where missing values are stored.

    The array is modified in place.

    Args:
        arr: Array (or array view) whose missing entries are overwritten.
        replacement: Value written into every missing position.
        impute_empty_strings: If True, entries equal to the empty string are treated as missing
            in addition to the entries flagged by :func:`pandas.isnull`.
    """
    impute_conditions = pd.isnull(arr)
    if impute_empty_strings:
        impute_conditions = np.logical_or(impute_conditions, arr == "")
    arr[impute_conditions] = replacement


def _extract_impute_value(replacement: dict[str, str | int], column_name: str) -> str | int | None:
    """Extract the replacement value for a given column in the :class:`~anndata.AnnData` object.

    Args:
        replacement: Mapping of column name to replacement value; may contain a ``"default"`` entry.
        column_name: Name of the column a replacement value is needed for.

    Returns:
        The value stored for `column_name`; if there is none, the value stored under ``"default"``;
        if that is missing too, None.
    """
    # Test for None explicitly: a truthiness check would discard valid falsy
    # replacement values such as 0 or the empty string.
    imputation_value = replacement.get(column_name)
    if imputation_value is None:
        # no value was specified for that column, fall back to the default (None if absent)
        imputation_value = replacement.get("default")
    return imputation_value
def simple_impute(
    adata: AnnData,
    var_names: Iterable[str] | None = None,
    *,
    strategy: Literal["mean", "median", "most_frequent"] = "mean",
    copy: bool = False,
    warning_threshold: int = 70,
) -> AnnData | None:
    """Impute missing values in numerical data using mean/median/most frequent imputation.

    Args:
        adata: The annotated data matrix to impute missing values on.
        var_names: A list of column names to apply imputation on (if None, impute all columns).
        strategy: Imputation strategy to use. One of {'mean', 'median', 'most_frequent'}.
        warning_threshold: Display a warning message if percentage of missing values exceeds this threshold.
        copy: Whether to return a copy of `adata` or modify it inplace.

    Returns:
        If copy is True, the imputed copy of `adata`. Otherwise `adata` is modified in place and None is returned.

    Raises:
        ValueError: If the selected imputation strategy is not applicable to the data.
        ValueError: If an unknown imputation strategy is provided.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> ep.pp.simple_impute(adata, strategy="median")
    """
    if copy:
        adata = adata.copy()

    # var_names is read twice (threshold check and imputation); materialise it so that a
    # one-shot iterable such as a generator is not already exhausted on the second read.
    if var_names is not None and not isinstance(var_names, str):
        var_names = list(var_names)

    _warn_imputation_threshold(adata, var_names, threshold=warning_threshold)

    with Progress(
        "[progress.description]{task.description}",
        SpinnerColumn(),
        refresh_per_second=1500,
    ) as progress:
        progress.add_task(f"[blue]Running simple imputation with {strategy}", total=1)

        if strategy in {"median", "mean"}:
            # Imputation using median and mean strategy works with numerical data only
            try:
                _simple_impute(adata, var_names, strategy)
            except ValueError:
                raise ValueError(
                    f"Can only impute numerical data using {strategy} strategy. Try to restrict imputation "
                    "to certain columns using var_names parameter or use a different mode."
                ) from None
        elif strategy == "most_frequent":
            # most_frequent imputation works with non-numerical data as well
            _simple_impute(adata, var_names, strategy)
        else:
            # unknown simple imputation strategy
            raise ValueError(  # pragma: no cover
                f"Unknown impute strategy {strategy} for simple Imputation. Choose any of mean, median or most_frequent."
            )

    return adata if copy else None
def _simple_impute(adata: AnnData, var_names: Iterable[str] | None, strategy: str) -> None:
    """Run scikit-learn's ``SimpleImputer`` on ``adata.X`` and store the result in place.

    Args:
        adata: Object whose ``X`` matrix is imputed.
        var_names: Columns to restrict the imputation to. Anything that is not an
            ``Iterable`` (in particular None) means the complete matrix is imputed.
        strategy: Strategy name handed to ``SimpleImputer``.
    """
    simple_imputer = SimpleImputer(strategy=strategy)

    # Whole-matrix case first: X is replaced by the transformed matrix.
    if not isinstance(var_names, Iterable):
        adata.X = simple_imputer.fit_transform(adata.X)
        return

    # Column subset: only the selected columns are fitted and written back.
    selected = _get_column_indices(adata, var_names)
    imputed_subset = simple_imputer.fit_transform(adata.X[::, selected])
    adata.X[::, selected] = imputed_subset
@check_feature_types
def knn_impute(
    adata: AnnData,
    var_names: Iterable[str] | None = None,
    *,
    n_neighbors: int = 5,
    copy: bool = False,
    backend: Literal["scikit-learn", "faiss"] = "faiss",
    warning_threshold: int = 70,
    backend_kwargs: dict | None = None,
    **kwargs,
) -> AnnData | None:
    """Imputes missing values in the input AnnData object using K-nearest neighbor imputation.

    When using KNN Imputation with mixed data (non-numerical and numerical), encoding using ordinal encoding is required
    since KNN Imputation can only work on numerical data. The encoding itself is just a utility and will be undone once
    imputation ran successfully.

    .. warning::
        Currently, both `n_neighbours` and `n_neighbors` are accepted as parameters for the number of neighbors.
        However, in future versions, only `n_neighbors` will be supported. Please update your code accordingly.

    Args:
        adata: An annotated data matrix containing EHR data.
        var_names: A list of variable names indicating which columns to impute.
            If `None`, all columns are imputed. Default is `None`.
        n_neighbors: Number of neighbors to use when performing the imputation.
        copy: Whether to perform the imputation on a copy of the original `AnnData` object.
            If `True`, the original object remains unmodified.
        backend: The implementation to use for the KNN imputation.
            'scikit-learn' is very slow but uses an exact KNN algorithm, whereas 'faiss'
            is drastically faster but uses an approximation for the KNN graph.
            In practice, 'faiss' is close enough to the 'scikit-learn' results.
        warning_threshold: Percentage of missing values above which a warning is issued.
        backend_kwargs: Passed to the backend.
            Pass "mean", "median", or "weighted" for 'strategy' to set the imputation strategy for faiss.
            See `sklearn.impute.KNNImputer <https://scikit-learn.org/stable/modules/generated/sklearn.impute.KNNImputer.html>`_ for more information on the 'scikit-learn' backend.
            See `fknni.faiss.FaissImputer <https://fknni.readthedocs.io/en/latest/>`_ for more information on the 'faiss' backend.
        kwargs: Gathering keyword arguments of earlier ehrapy versions for backwards compatibility.
            It is encouraged to use the here listed, current arguments.

    Returns:
        If copy is True, the imputed copy of `adata`. Otherwise `adata` is modified in place and None is returned.

    Raises:
        ValueError: If `backend` is unknown or unexpected keyword arguments are passed.
        ValueError: Re-raised from the underlying imputer, e.g. if the input data matrix contains only
            categorical (non-numeric) values.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> ep.ad.infer_feature_types(adata)
        >>> ep.pp.knn_impute(adata)
    """
    if copy:
        adata = adata.copy()

    _warn_imputation_threshold(adata, var_names, threshold=warning_threshold)

    if backend not in {"scikit-learn", "faiss"}:
        raise ValueError(f"Unknown backend '{backend}' for KNN imputation. Choose between 'scikit-learn' and 'faiss'.")

    if backend_kwargs is None:
        backend_kwargs = {}

    # only the deprecated British spelling is tolerated as an extra keyword argument
    valid_kwargs = {"n_neighbours"}
    unexpected_kwargs = set(kwargs.keys()) - valid_kwargs

    if unexpected_kwargs:
        raise ValueError(f"Unexpected keyword arguments: {unexpected_kwargs}.")

    if "n_neighbours" in kwargs:
        n_neighbors = kwargs["n_neighbours"]
        warnings.warn(
            "ehrapy will use 'n_neighbors' instead of 'n_neighbours'. Please update your code.",
            DeprecationWarning,
            stacklevel=1,
        )

    sklearnex_available = _check_module_importable("sklearnex")
    if sklearnex_available:  # pragma: no cover
        from sklearnex import patch_sklearn, unpatch_sklearn

        patch_sklearn()

    try:
        with Progress(
            "[progress.description]{task.description}",
            SpinnerColumn(),
            refresh_per_second=1500,
        ) as progress:
            progress.add_task("[blue]Running KNN imputation", total=1)
            # numerical only data needs no encoding since KNN Imputation can be applied directly
            if np.issubdtype(adata.X.dtype, np.number):
                _knn_impute(adata, var_names, n_neighbors, backend=backend, **backend_kwargs)
            else:
                # ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation
                enc = OrdinalEncoder()
                column_indices = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG
                adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
                # impute the data using KNN imputation
                _knn_impute(adata, var_names, n_neighbors, backend=backend, **backend_kwargs)
                # imputing on encoded columns might result in float numbers; those can not be decoded
                # cast them to int to ensure they can be decoded
                adata.X[::, column_indices] = np.rint(adata.X[::, column_indices]).astype(int)
                # knn imputer transforms X dtype to numerical (encoded), but object is needed for decoding
                adata.X = adata.X.astype("object")
                # decode ordinal encoding to obtain imputed original data
                adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
    except ValueError as e:
        if "Data matrix has wrong shape" in str(e):
            logger.error("Check that your matrix does not contain any NaN only columns!")
        raise
    finally:
        # Undo the patch on the error path too, so a failed imputation does not leave
        # scikit-learn patched for the rest of the process.
        if sklearnex_available:  # pragma: no cover
            unpatch_sklearn()

    return adata if copy else None
def _knn_impute(
    adata: AnnData,
    var_names: Iterable[str] | None,
    n_neighbors: int,
    backend: Literal["scikit-learn", "faiss"],
    **kwargs,
) -> None:
    """Fit the selected KNN imputer on ``adata.X`` and write the imputed values back in place.

    Args:
        adata: Object whose ``X`` matrix is imputed.
        var_names: Columns to restrict the imputation to. Anything that is not an
            ``Iterable`` (in particular None) means the complete matrix is imputed.
        n_neighbors: Number of neighbors handed to the imputer.
        backend: ``"scikit-learn"`` selects ``sklearn.impute.KNNImputer``; every other
            value selects ``fknni.FaissImputer``.
        **kwargs: Forwarded unchanged to the imputer constructor.
    """
    # The backend packages are imported lazily so only the selected one has to be importable.
    if backend == "scikit-learn":
        from sklearn.impute import KNNImputer as imputer_cls
    else:
        from fknni import FaissImputer as imputer_cls

    knn_imputer = imputer_cls(n_neighbors=n_neighbors, **kwargs)

    if not isinstance(var_names, Iterable):
        adata.X = knn_imputer.fit_transform(adata.X)
        return

    selected = _get_column_indices(adata, var_names)
    adata.X[::, selected] = knn_imputer.fit_transform(adata.X[::, selected])
    # this is required since X dtype has to be numerical in order to correctly round floats
    adata.X = adata.X.astype("float64")
def miss_forest_impute(
    adata: AnnData,
    var_names: dict[str, list[str]] | list[str] | None = None,
    *,
    num_initial_strategy: Literal["mean", "median", "most_frequent", "constant"] = "mean",
    max_iter: int = 3,
    n_estimators: int = 100,
    random_state: int = 0,
    warning_threshold: int = 70,
    copy: bool = False,
) -> AnnData | None:
    """Impute data using the MissForest strategy.

    This function uses the MissForest strategy to impute missing values in the data matrix of an AnnData object.
    The strategy works by fitting a random forest model on each feature containing missing values,
    and using the trained model to predict the missing values.

    See https://academic.oup.com/bioinformatics/article/28/1/112/219101.
    This requires the computation of which columns in X contain numerical only (including NaNs) and which contain non-numerical data.

    Args:
        adata: The AnnData object to use MissForest Imputation on.
        var_names: List of columns to impute or a dict with two keys ('numerical' and 'non_numerical') indicating which var
            contain mixed data and which numerical data only.
        num_initial_strategy: The initial strategy to replace all missing numerical values with.
        max_iter: The maximum number of iterations if the stop criterion has not been met yet.
        n_estimators: The number of trees to fit for every missing variable. Has a big effect on the run time.
            Decrease for faster computations.
        random_state: The random seed for the initialization.
        warning_threshold: Threshold of percentage of missing values to display a warning for.
        copy: Whether to return a copy or act in place.

    Returns:
        If copy is True, the imputed (but unencoded) copy of `adata`.
        Otherwise `adata` is modified in place and None is returned.

    Raises:
        ValueError: If `var_names` is a non-empty dict that lacks the 'numerical' or 'non_numerical' key.
        ValueError: Re-raised from the underlying imputers.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> ep.pp.miss_forest_impute(adata)
    """
    if copy:  # pragma: no cover
        adata = adata.copy()

    if var_names is None:
        _warn_imputation_threshold(adata, list(adata.var_names), threshold=warning_threshold)
    elif isinstance(var_names, dict):
        # The dict keys are the group labels, not column names; the columns to check are the
        # names listed under both groups. Missing keys are tolerated here so that the
        # descriptive ValueError further down is still the one that gets raised.
        grouped_var_names = [*var_names.get("numerical", ()), *var_names.get("non_numerical", ())]
        _warn_imputation_threshold(adata, grouped_var_names, threshold=warning_threshold)
    elif isinstance(var_names, list):
        _warn_imputation_threshold(adata, var_names, threshold=warning_threshold)

    sklearnex_available = _check_module_importable("sklearnex")
    if sklearnex_available:  # pragma: no cover
        from sklearnex import patch_sklearn, unpatch_sklearn

        patch_sklearn()

    from sklearn.ensemble import ExtraTreesRegressor, RandomForestClassifier
    from sklearn.impute import IterativeImputer

    try:
        with Progress(
            "[progress.description]{task.description}",
            SpinnerColumn(),
            refresh_per_second=1500,
        ) as progress:
            progress.add_task("[blue]Running MissForest imputation", total=1)

            if settings.n_jobs == 1:  # pragma: no cover
                logger.warning("The number of jobs is only 1. To decrease the runtime set ep.settings.n_jobs=-1.")

            imp_num = IterativeImputer(
                estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs),
                initial_strategy=num_initial_strategy,
                max_iter=max_iter,
                random_state=random_state,
            )
            # initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data
            imp_cat = IterativeImputer(
                estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs),
                initial_strategy="most_frequent",
                max_iter=max_iter,
                random_state=random_state,
            )

            if isinstance(var_names, list):
                # a plain list is imputed with the numerical imputer only
                var_indices = _get_column_indices(adata, var_names)  # type: ignore
                adata.X[::, var_indices] = imp_num.fit_transform(adata.X[::, var_indices])
            elif isinstance(var_names, dict) or var_names is None:
                if var_names:
                    try:
                        non_num_vars = var_names["non_numerical"]
                        num_vars = var_names["numerical"]
                    except KeyError:  # pragma: no cover
                        raise ValueError(
                            "One or both of your keys provided for var_names are unknown. Only "
                            "numerical and non_numerical are available!"
                        ) from None
                    non_num_indices = _get_column_indices(adata, non_num_vars)
                    num_indices = _get_column_indices(adata, num_vars)

                # infer non numerical and numerical indices automatically
                else:
                    non_num_indices_set = _get_non_numerical_column_indices(adata.X)
                    num_indices = [idx for idx in range(adata.X.shape[1]) if idx not in non_num_indices_set]
                    non_num_indices = list(non_num_indices_set)

                # encode all non numerical columns
                if non_num_indices:
                    enc = OrdinalEncoder()
                    adata.X[::, non_num_indices] = enc.fit_transform(adata.X[::, non_num_indices])
                # this step is the most expensive one and might extremely slow down the impute process
                if num_indices:
                    adata.X[::, num_indices] = imp_num.fit_transform(adata.X[::, num_indices])
                if non_num_indices:
                    adata.X[::, non_num_indices] = imp_cat.fit_transform(adata.X[::, non_num_indices])
                    # undo the ordinal encoding to obtain the original categories again
                    adata.X[::, non_num_indices] = enc.inverse_transform(adata.X[::, non_num_indices])
    except ValueError as e:
        if "Data matrix has wrong shape" in str(e):
            logger.error("Check that your matrix does not contain any NaN only columns!")
        raise
    finally:
        # Undo the patch on the error path too, so a failed imputation does not leave
        # scikit-learn patched for the rest of the process.
        if sklearnex_available:  # pragma: no cover
            unpatch_sklearn()

    return adata if copy else None
@check_feature_types
def mice_forest_impute(
    adata: AnnData,
    var_names: Iterable[str] | None = None,
    *,
    warning_threshold: int = 70,
    save_all_iterations: bool = True,
    random_state: int | None = None,
    inplace: bool = False,
    iterations: int = 5,
    variable_parameters: dict | None = None,
    verbose: bool = False,
    copy: bool = False,
) -> AnnData:
    """Impute data using the miceforest.

    See https://github.com/AnotherSamWilson/miceforest
    Fast, memory efficient Multiple Imputation by Chained Equations (MICE) with lightgbm.

    If ``adata.X`` does not have a numerical dtype, the columns tagged as categorical in
    ``adata.var`` are ordinal encoded before the imputation and decoded again afterwards.

    Args:
        adata: The AnnData object containing the data to impute.
        var_names: A list of variable names to impute. If None, impute all variables.
        warning_threshold: Threshold of percentage of missing values to display a warning for.
        save_all_iterations: Whether to save all imputed values from all iterations or just the latest.
            Saving all iterations allows for additional plotting, but may take more memory.
        random_state: The random state ensures script reproducibility.
        inplace: Forwarded unchanged to the internal miceforest helper. It does not change what
            this function returns.
        iterations: The number of iterations to run.
        variable_parameters: Model parameters can be specified by variable here.
            Keys should be variable names or indices, and values should be a dict of parameter which should apply to that variable only.
        verbose: Whether to print information about the imputation process.
        copy: Whether to impute a copy of the AnnData object instead of the passed object.

    Returns:
        The imputed AnnData object (the copy if `copy` is True, otherwise the passed object).

    Raises:
        ValueError: Re-raised from the encoding or imputation step.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> ep.ad.infer_feature_types(adata)
        >>> ep.pp.mice_forest_impute(adata)
    """
    if copy:
        adata = adata.copy()

    _warn_imputation_threshold(adata, var_names, threshold=warning_threshold)

    try:
        with Progress(
            "[progress.description]{task.description}",
            SpinnerColumn(),
            refresh_per_second=1500,
        ) as progress:
            progress.add_task("[blue]Running miceforest", total=1)

            # non-numerical data can not be imputed using miceforest, so it is ordinal encoded first
            needs_encoding = not np.issubdtype(adata.X.dtype, np.number)
            if needs_encoding:
                enc = OrdinalEncoder()
                categorical_mask = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG
                adata.X[::, categorical_mask] = enc.fit_transform(adata.X[::, categorical_mask])

            # impute the data using miceforest
            _miceforest_impute(
                adata,
                var_names,
                save_all_iterations=save_all_iterations,
                random_state=random_state,
                inplace=inplace,
                iterations=iterations,
                variable_parameters=variable_parameters,
                verbose=verbose,
            )

            if needs_encoding:
                # object dtype is needed to store the decoded values
                adata.X = adata.X.astype("object")
                # decode ordinal encoding to obtain imputed original data
                adata.X[::, categorical_mask] = enc.inverse_transform(adata.X[::, categorical_mask])
    except ValueError as e:
        if "Data matrix has wrong shape" in str(e):
            logger.warning("Check that your matrix does not contain any NaN only columns!")
        raise

    return adata
def _miceforest_impute(
    adata, var_names, save_all_iterations, random_state, inplace, iterations, variable_parameters, verbose
) -> None:
    """Run a single-dataset miceforest kernel on ``adata.X`` and store the completed data.

    Args:
        adata: Object whose ``X`` matrix is imputed.
        var_names: Columns to restrict the imputation to. Anything that is not an
            ``Iterable`` (in particular None) means the complete matrix is imputed.
        save_all_iterations: Forwarded to ``mf.ImputationKernel``.
        random_state: Forwarded to ``mf.ImputationKernel``.
        inplace: Forwarded to ``kernel.complete_data``.
        iterations: Forwarded to ``kernel.mice``.
        variable_parameters: Forwarded to ``kernel.mice``.
        verbose: Forwarded to ``kernel.mice``.
    """
    import miceforest as mf

    restrict_to_subset = isinstance(var_names, Iterable)
    if restrict_to_subset:
        column_indices = _get_column_indices(adata, var_names)
        data = adata.X[::, column_indices]
    else:
        data = adata.X

    # Create kernel.
    kernel = mf.ImputationKernel(
        data, datasets=1, save_all_iterations=save_all_iterations, random_state=random_state
    )
    kernel.mice(iterations=iterations, variable_parameters=variable_parameters, verbose=verbose)

    completed = kernel.complete_data(dataset=0, inplace=inplace)
    if completed is None:
        # Per the miceforest documentation, completing in place imputes the kernel's working
        # data and returns nothing; writing that None into adata.X would discard the data.
        completed = kernel.working_data

    if restrict_to_subset:
        adata.X[::, column_indices] = completed
    else:
        adata.X = completed


def _warn_imputation_threshold(adata: AnnData, var_names: Iterable[str] | None, threshold: int = 75) -> dict[str, int]:
    """Warns the user if more than $threshold percent had to be imputed.

    If ``adata.var`` has no ``missing_values_pct`` column yet, ``qc_metrics`` is run on `adata` first.

    Args:
        adata: The AnnData object to check
        var_names: The var names which were imputed. None means all var names of `adata`.
        threshold: A percentage value from 0 to 100 used as minimum.

    Returns:
        A mapping from every checked var name whose ``missing_values_pct`` exceeds `threshold`
        to that percentage.
    """
    try:
        adata.var["missing_values_pct"]
    except KeyError:
        # imported lazily, only when the metric actually has to be computed
        from ehrapy.preprocessing import qc_metrics

        qc_metrics(adata)

    used_var_names = set(adata.var_names) if var_names is None else set(var_names)

    missing_pct = adata.var["missing_values_pct"]
    thresholded_var_names = set(missing_pct[missing_pct > threshold].index) & used_var_names

    var_name_to_pct: dict[str, int] = {}
    for var in thresholded_var_names:
        var_name_to_pct[var] = missing_pct.loc[var]
        logger.warning(f"Feature '{var}' had more than {var_name_to_pct[var]:.2f}% missing values!")

    return var_name_to_pct


def _get_non_numerical_column_indices(X: np.ndarray) -> set:
    """Return indices of columns, that contain at least one non-numerical value that is not "Nan".

    A value counts as numerical if ``float(value)`` succeeds and it is not a Python ``bool``.
    ``None`` and ``pd.NA`` count as missing values and therefore never mark a column as non-numerical.

    Args:
        X: Two-dimensional array to inspect.

    Returns:
        The set of column indices (plain ints) holding at least one non-numerical value.
    """

    def _is_float_or_nan(val):  # pragma: no cover
        """Check whether a given item is a float or a missing value."""
        if isinstance(val, bool):
            # float(True) succeeds, but booleans are not treated as numerical data
            return False
        try:
            float(val)
        except ValueError:
            # a string that does not spell a number
            return False
        except TypeError:
            # float() rejects None and pd.NA with a TypeError; these are missing values,
            # every other unconvertible object is non-numerical
            return val is None or val is pd.NA
        return True

    is_numeric_numpy = np.vectorize(_is_float_or_nan, otypes=[bool])
    mask = is_numeric_numpy(X)

    _, column_indices = np.where(~mask)

    return set(column_indices.tolist())