# Source code for ehrapy.anndata._feature_specifications

from __future__ import annotations

from functools import wraps
from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd
from anndata import AnnData
from dateutil.parser import isoparse  # type: ignore
from lamin_utils import logger
from rich import print
from rich.tree import Tree

from ehrapy.anndata._constants import CATEGORICAL_TAG, DATE_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG

if TYPE_CHECKING:
    from collections.abc import Iterable


def _detect_feature_type(col: pd.Series) -> tuple[Literal["date", "categorical", "numeric"], bool]:
    """Detect the feature type of a column in a pandas DataFrame.

    Missing values are dropped first. The decision is then based on the most frequent
    Python type among the remaining values:

    * ``pd.Timestamp`` values, or strings that all parse with ``isoparse``, yield 'date'.
    * Strings that cannot all be converted to numbers yield 'categorical'.
    * Any other majority type that is neither ``int`` nor ``float`` yields 'categorical'.
    * Whole-number values whose distinct values are exactly ``0..k-1`` or ``1..k``, and whose
      number of distinct values differs from the original column length (missing values
      included), are guessed to be 'categorical' and flagged as uncertain.
    * Everything else yields 'numeric'.

    Args:
        col: The column to detect the feature type for.

    Returns:
        The detected feature type (one of 'date', 'categorical', or 'numeric') and a boolean, which is True if the feature type is uncertain.

    Raises:
        ValueError: If the column contains only missing values.
    """
    n_elements = len(col)
    col = col.dropna()
    if len(col) == 0:
        raise ValueError(
            f"Feature '{col.name}' contains only NaN values. Please drop this feature to infer the feature type."
        )
    majority_type = col.apply(type).value_counts().idxmax()

    if majority_type == pd.Timestamp:
        return DATE_TAG, False  # type: ignore

    if majority_type is str:
        try:
            col.apply(isoparse)
            return DATE_TAG, False  # type: ignore
        except ValueError:
            try:
                col = pd.to_numeric(col, errors="raise")  # Could be an encoded categorical or a numeric feature
                majority_type = float
            except ValueError:
                # Features stored as Strings that cannot be converted to float are assumed to be categorical
                return CATEGORICAL_TAG, False  # type: ignore

    if majority_type not in [int, float]:
        return CATEGORICAL_TAG, False  # type: ignore

    # Guess categorical if the feature is an integer and the values are 0/1 to n-1/n with no gaps.
    # The whole-number test uses the builtin `all`: `np.all` does not iterate a generator and
    # would always be truthy. `float(...)` keeps the test valid for Python ints as well, since
    # `int.is_integer` only exists from Python 3.12 on. It is evaluated last so the cheaper
    # checks short-circuit first.
    if (
        (n_elements != col.nunique())
        and (
            (col.min() == 0 and np.all(np.sort(col.unique()) == np.arange(col.nunique())))
            or (col.min() == 1 and np.all(np.sort(col.unique()) == np.arange(1, col.nunique() + 1)))
        )
        and (majority_type is int or all(float(value).is_integer() for value in col.unique()))
    ):
        return CATEGORICAL_TAG, True  # type: ignore

    return NUMERIC_TAG, False  # type: ignore


def infer_feature_types(
    adata: AnnData, layer: str | None = None, output: Literal["tree", "dataframe"] | None = "tree", verbose: bool = True
):
    """Infer feature types from AnnData object.

    For each feature in adata.var_names, the method infers one of the following types: 'date', 'categorical', or 'numeric'.
    The inferred types are stored in adata.var['feature_type']. Features that already have a non-missing
    entry there keep it and are not inferred again.
    Please check the inferred types and adjust if necessary using `ep.ad.replace_feature_types`.

    Be aware that not all features stored numerically are of 'numeric' type, as categorical features might be stored in a numerically encoded format.
    For example, a feature with values [0, 1, 2] might be a categorical feature with three categories. This is accounted for in the method, but it is
    recommended to check the inferred types.

    Args:
        adata: :class:`~anndata.AnnData` object storing the EHR data.
        layer: The layer to use from the AnnData object. If None, the X layer is used.
        output: The output format. Choose between 'tree', 'dataframe', or None. If 'tree', the feature types will be printed to the console in a tree format.
            If 'dataframe', a pandas DataFrame with the feature types will be returned. If None, nothing will be returned.
        verbose: Whether to print warnings for uncertain feature types.

    Raises:
        ValueError: If `output` is not one of 'tree', 'dataframe', or None.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> ep.ad.infer_feature_types(adata)
    """
    from ehrapy.anndata.anndata_ext import anndata_to_df

    feature_types = {}
    uncertain_features = []

    df = anndata_to_df(adata, layer=layer)
    # The stored column does not change inside the loop, so look it up once.
    stored_types = adata.var[FEATURE_TYPE_KEY] if FEATURE_TYPE_KEY in adata.var.keys() else None
    for feature in adata.var_names:
        stored = stored_types[feature] if stored_types is not None else None
        if stored is not None and not pd.isna(stored):
            # Respect a type that was already set for this feature.
            feature_types[feature] = stored
        else:
            feature_types[feature], raise_warning = _detect_feature_type(df[feature])
            if raise_warning:
                uncertain_features.append(feature)

    adata.var[FEATURE_TYPE_KEY] = pd.Series(feature_types)[adata.var_names]

    if verbose:
        # Only warn when there is something to warn about; an empty list would otherwise
        # produce a message that names no feature.
        if uncertain_features:
            several = len(uncertain_features) > 1
            logger.warning(
                f"{'Features' if several else 'Feature'} {str(uncertain_features)[1:-1]} {'were' if several else 'was'} detected as categorical features stored numerically."
                f" Please verify and correct using `ep.ad.replace_feature_types` if necessary."
            )

        logger.info(
            f"Stored feature types in adata.var['{FEATURE_TYPE_KEY}']."
            f" Please verify and adjust if necessary using `ep.ad.replace_feature_types`."
        )

    if output == "tree":
        feature_type_overview(adata)
    elif output == "dataframe":
        return adata.var[FEATURE_TYPE_KEY]
    elif output is not None:
        raise ValueError(f"Output format {output} not recognized. Choose between 'tree', 'dataframe', or None.")
def check_feature_types(func):
    """Decorator that ensures feature types are present before ``func`` is called.

    The wrapped callable must receive an AnnData object as its first argument, or as its
    second argument when the first one is not an AnnData object (e.g. ``self`` of a method).
    If the feature type column is missing from ``adata.var``, feature types are inferred
    first and a warning is logged. Afterwards a warning is logged for every feature whose
    stored type is set but is not one of the known tags. The call is then forwarded to
    ``func`` unchanged.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapping callable.
    """

    @wraps(func)
    def wrapper(adata, *args, **kwargs):
        # Methods pass their instance first; in that case the AnnData object is the next argument.
        owner = None
        if not isinstance(adata, AnnData) and len(args) > 0 and isinstance(args[0], AnnData):
            owner, adata, args = adata, args[0], args[1:]

        if FEATURE_TYPE_KEY not in adata.var.keys():
            infer_feature_types(adata, output=None)
            logger.warning(
                f"Feature types were inferred and stored in adata.var[{FEATURE_TYPE_KEY}]. Please verify using `ep.ad.feature_type_overview` and adjust if necessary using `ep.ad.replace_feature_types`."
            )

        known_tags = [CATEGORICAL_TAG, NUMERIC_TAG, DATE_TAG]
        for feature in adata.var_names:
            tag = adata.var[FEATURE_TYPE_KEY][feature]
            # Unset entries and recognised tags need no warning.
            if tag is None or pd.isna(tag) or tag in known_tags:
                continue
            logger.warning(
                f"Feature '{feature}' has an invalid feature type '{tag}'. Please correct using `ep.ad.replace_feature_types`."
            )

        if owner is None:
            return func(adata, *args, **kwargs)
        return func(owner, adata, *args, **kwargs)

    return wrapper
@check_feature_types
def feature_type_overview(adata: AnnData):
    """Print an overview of the feature types and encoding modes in the AnnData object.

    The overview is rendered as a tree with one branch each for date, numerical and
    categorical features. Categorical entries are listed with their number of categories;
    when ``adata.var`` has an "encoding_mode" column, features that are no longer present
    in ``adata.var_names`` are counted from ``adata.obs`` and shown with their encoding mode.

    Args:
        adata: The AnnData object storing the EHR data.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> ep.ad.feature_type_overview(adata)
    """
    from ehrapy.anndata.anndata_ext import anndata_to_df

    overview = Tree(
        f"[b] Detected feature types for AnnData object with {len(adata.obs_names)} obs and {len(adata.var_names)} vars",
        guide_style="underline2",
    )
    feature_tags = adata.var[FEATURE_TYPE_KEY]

    # Date and numerical features are listed by name only.
    for heading, tag in (("📅[b] Date features", DATE_TAG), ("📐[b] Numerical features", NUMERIC_TAG)):
        node = overview.add(heading)
        for name in sorted(adata.var_names[feature_tags == tag]):
            node.add(name)

    node = overview.add("🗂️[b] Categorical features")
    cat_features = adata.var_names[feature_tags == CATEGORICAL_TAG]
    df = anndata_to_df(adata[:, cat_features])

    if "encoding_mode" not in adata.var.keys():
        for name in sorted(cat_features):
            node.add(f"{name} ({df.loc[:, name].nunique()} categories)")
    else:
        original_names = adata.var.loc[cat_features, "unencoded_var_names"].unique().tolist()
        for original in sorted(original_names):
            if original in adata.var_names:
                node.add(f"{original} ({df.loc[:, original].nunique()} categories)")
                continue
            # Not a column of X any more: take the category count from obs and report the encoding.
            from_original = adata.var["unencoded_var_names"] == original
            enc_mode = adata.var.loc[from_original, "encoding_mode"].values[0]
            node.add(f"{original} ({adata.obs[original].nunique()} categories); {enc_mode} encoded")

    print(overview)
def replace_feature_types(adata: AnnData, features: Iterable[str], corrected_type: str) -> None:
    """Correct the feature types for a list of features inplace.

    Args:
        adata: :class:`~anndata.AnnData` object storing the EHR data.
        features: The features to correct. A single feature name may be passed as a plain string;
            any other iterable of names (list, tuple, set, generator, ...) is accepted.
        corrected_type: The corrected feature type. One of 'date', 'categorical', or 'numeric'.

    Raises:
        ValueError: If `corrected_type` is not a recognized feature type, or if no feature types
            are stored in `adata.var` yet.

    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.diabetes_130_fairlearn()
        >>> ep.ad.infer_feature_types(adata)
        >>> ep.ad.replace_feature_types(adata, ["time_in_hospital", "number_diagnoses", "num_procedures"], "numeric")
    """
    if corrected_type not in [CATEGORICAL_TAG, NUMERIC_TAG, DATE_TAG]:
        raise ValueError(
            f"Corrected type {corrected_type} not recognized. Choose between '{DATE_TAG}', '{CATEGORICAL_TAG}', or '{NUMERIC_TAG}'."
        )

    if FEATURE_TYPE_KEY not in adata.var.keys():
        raise ValueError(
            "Feature types were not inferred. Please infer feature types using 'infer_feature_types' before correcting."
        )

    # A bare string is one feature name, not an iterable of characters. Every other iterable is
    # materialised as a list, because `.loc` reads a tuple as a single multi-level key and does
    # not accept sets or generators as row selectors.
    features = [features] if isinstance(features, str) else list(features)

    adata.var.loc[features, FEATURE_TYPE_KEY] = corrected_type