# Source code for ehrapy.plot.feature_ranking._feature_importances

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

if TYPE_CHECKING:
    from ehrdata import EHRData
    from matplotlib.axes import Axes


def rank_features_supervised(
    edata: EHRData,
    key: str = "feature_importances",
    n_features: int = 10,
    ax: Axes | None = None,
    show: bool = True,
    save: str | None = None,
    **kwargs,
) -> Axes | None:
    """Plot features with greatest absolute importances as a barplot.

    Features are ranked by the absolute value of their importance, and the
    signed importance of the top `n_features` is drawn as horizontal bars.

    Args:
        edata: Central data object. A key in edata.var should contain the feature importances, calculated beforehand.
        key: The key in edata.var to use for feature importances.
        n_features: The number of features to plot.
        ax: A matplotlib axes object to plot on. If `None`, a new figure will be created.
        show: If `True`, show the figure. If `False`, return the axes object.
        save: Path to save the figure. If `None`, the figure will not be saved.
        **kwargs: Additional arguments passed to `seaborn.barplot`.

    Returns:
        If `show == False` a `matplotlib.axes.Axes` object, else `None`.

    Raises:
        ValueError: If `key` is not a column of `edata.var`.

    Examples:
        >>> import ehrdata as ed
        >>> import ehrapy as ep
        >>> edata = ed.dt.mimic_2()
        >>> ep.pp.knn_impute(edata, n_neighbors=5)
        >>> input_features = [
        ...     feat for feat in edata.var_names if feat not in {"service_unit", "day_icu_intime", "tco2_first"}
        ... ]
        >>> ep.tl.rank_features_supervised(edata, "tco2_first", "rf", input_features=input_features)
        >>> ep.pl.rank_features_supervised(edata)

        .. image:: /_static/docstring_previews/feature_importances.png
    """
    if key not in edata.var:
        raise ValueError(
            f"Key {key} not found in edata.var. Make sure to calculate feature importances first with ep.tl.rank_features_supervised."
        )

    # Rank by magnitude so strongly negative importances are not pushed to the bottom,
    # but keep the signed value for plotting.
    df = pd.DataFrame({"importance": edata.var[key]}, index=edata.var_names)
    df["absolute_importance"] = df["importance"].abs()
    df = df.sort_values("absolute_importance", ascending=False)
    top = df.head(n_features)

    if ax is None:
        _, ax = plt.subplots()

    ax = sns.barplot(x=top["importance"], y=top.index, orient="h", ax=ax, **kwargs)

    # Use the axes/figure that were actually drawn on: the pyplot state machine
    # (plt.xlabel, plt.savefig, ...) targets the *current* axes, which is not
    # necessarily the one the caller passed in.
    ax.set_ylabel("Feature")
    ax.set_xlabel("Importance")
    fig = ax.get_figure()
    fig.tight_layout()

    if save:
        fig.savefig(save, bbox_inches="tight")

    if show:
        plt.show()
        return None
    return ax