Source code for ehrapy.tools.nlp._medcat

from __future__ import annotations

import pandas as pd
from anndata import AnnData
from thefuzz import process

from ehrapy.core._tool_available import _check_module_importable

try:
    from medcat.cat import CAT
    from medcat.cdb import CDB
    from medcat.cdb_maker import CDBMaker
    from medcat.config import Config
    from medcat.vocab import Vocab
except ModuleNotFoundError:
    pass
from rich import box, print
from rich.console import Console
from rich.table import Table


[docs]class MedCAT: """Wrapper class for Medcat. This class will hold references to the current AnnData object, which holds the data, the current model (with vocab and concept database) and should be passed to all functions exposed to the ehrapy nlp API when required. """ def __init__(self, anndata: AnnData, vocabulary: Vocab = None, concept_db: CDB = None, model_pack_path=None): if not _check_module_importable("medcat"): raise RuntimeError("Package medcat is not importable. Please install via pip install medcat") self.anndata = anndata self.vocabulary = vocabulary self.concept_db = concept_db if self.vocabulary is not None and self.concept_db is not None: self.cat = CAT(cdb=concept_db, config=concept_db.config, vocab=vocabulary) elif model_pack_path is not None: self.cat = CAT.load_model_pack(model_pack_path) # will be initialized as None, but will get updated when running annotate_text self.annotated_results = None
[docs] def update_cat(self, vocabulary: Vocab = None, concept_db: CDB = None): """Updates the current MedCAT instance with new Vocabularies and Concept Databases. Args: vocabulary: Vocabulary to update to. concept_db: Concept Database to update to. """ self.cat = CAT(cdb=concept_db, config=concept_db.config, vocab=vocabulary)
[docs] def update_cat_config(self, concept_db_config: Config) -> None: """Updates the MedCAT configuration. Args: concept_db_config: Concept to update to. """ self.concept_db.config = concept_db_config
[docs] def set_filter_by_tui(self, tuis: list[str]) -> None: """Restrict results of annotation step to certain tui's (type unique identifiers). Note that this will change the MedCat object by updating the concept database config. In every annotation process that will be run afterwards, entities are shown, only if they fall into the tui's type. A full list of tui's can be found at: https://lhncbc.nlm.nih.gov/ii/tools/MetaMap/Docs/SemanticTypes_2018AB.txt As an example: Setting tuis=["T047", "T048"] will only annotate concepts (identified by a CUI (concept unique identifier)) in UMLS that are either diseases or syndroms (T047) or mental/behavioural dysfunctions (T048). Args: tuis: list of TUI's (default is """ # the filtered cui's that fall into the type of the filter tui's cui_filters = set() for type_id in tuis: cui_filters.update(self.cat.cdb.addl_info["type_id2cuis"][type_id]) self.cat.cdb.config.linking["filters"]["cuis"] = cui_filters
[docs] @staticmethod def create_vocabulary(vocabulary_data: str, replace: bool = True) -> Vocab: """Creates a MedCAT Vocab and sets it for the MedCAT object. Args: vocabulary_data: Path to the vocabulary data. It is a tsv file and must look like: <token>\t<word_count>\t<vector_embedding_separated_by_spaces> house 34444 0.3232 0.123213 1.231231 replace: Whether to replace existing words in the vocabulary. Returns: Instance of a MedCAT Vocab """ vocabulary = Vocab() vocabulary.add_words(vocabulary_data, replace=replace) return vocabulary
[docs] @staticmethod def create_concept_db(csv_path: list[str], config: Config = None) -> CDB: """Creates a MedCAT concept database and sets it for the MedCAT object. Args: csv_path: List of paths to one or more csv files containing all concepts. The concept csvs must look like: cui,name 1,kidney failure 7,coronavirus config: Optional MedCAT concept database configuration. If not provided a default configuration with config.general['spacy_model'] = 'en_core_sci_md' is created. Returns: Instance of a MedCAT CDB concept database """ if config is None: config = Config() config.general["spacy_model"] = "en_core_sci_md" maker = CDBMaker(config) concept_db = maker.prepare_csvs(csv_path, full_build=True) return concept_db
[docs] @staticmethod def save_vocabulary(vocab: Vocab, output_path: str) -> None: """Saves a vocabulary. Args: vocab: The vocabulary object output_path: Path to write the vocabulary to. """ vocab.save(output_path)
[docs] @staticmethod def load_vocabulary(vocabulary_path) -> Vocab: """Loads a vocabulary. Args: vocabulary_path: Path to load the vocabulary from. """ return Vocab.load(vocabulary_path)
[docs] @staticmethod def save_concept_db(cdb, output_path: str) -> None: """Saves a concept database. Args: cdb: the concept database object output_path: Path to save the concept database to. """ cdb.save(output_path)
[docs] @staticmethod def load_concept_db(concept_db_path) -> CDB: """Loads the concept database. Args: concept_db_path: Path to load the concept database from. """ return CDB.load(concept_db_path)
[docs] def save_model_pack(self, model_pack_dir: str = ".", name: str = "ehrapy_medcat_model_pack") -> None: """Saves a MedCAT model pack. Args: model_pack_dir: Path to save the model to (defaults to current working directory). name: Name of the new model pack """ _ = self.cat.create_model_pack(name)
class EhrapyMedcat: """Wrapper class to perform feature extraction from free text data using MedCAT with ehrapy. This can be simply called by `ep.tl.mc`. This class is not supposed to be instantiated at any time, it just serves as a wrapper for import. """ @staticmethod def run_unsupervised_training( medcat_obj: MedCAT, text: pd.Series, progress_print: int = 100, print_statistics: bool = False ) -> None: """Performs MedCAT unsupervised training on a provided text column. Args: medcat_obj: ehrapy's custom MedCAT object, that keeps track of the vocab, concept database and the (annotated) results text: Pandas Series of (free) text to annotate. progress_print: print progress after that many training documents print_statistics: Whether to print training statistics after training. """ print(f"[bold blue]Running unsupervised training using {len(text)} documents.") medcat_obj.cat.train(text.values, progress_print=progress_print) if print_statistics: medcat_obj.cat.cdb.print_stats() @staticmethod def annotate_text(medcat_obj: MedCAT, text_column: str, n_proc: int = 2, batch_size_chars: int = 500000) -> None: """Annotate the original free text data. Note this will only annotate non null rows. The result will be a DataFrame. It will be set as the annotated_results attribute for the passed MedCat object. This dataframe will be the base for all further analyses, for example coloring umaps by specific diseases. Args: medcat_obj: Ehrapy's custom MedCAT object. The annotated_results attribute will be set here. text_column: Name of the column that should be annotated n_proc: Number of processors to use batch_size_chars: batch size to control for the variability between document sizes """ non_null_text = EhrapyMedcat._filter_null_values(medcat_obj.anndata.obs, text_column) formatted_text_column = EhrapyMedcat._format_df_column(non_null_text, text_column) results = medcat_obj.cat.multiprocessing(formatted_text_column, batch_size_chars=batch_size_chars, nproc=n_proc) flattened_res = EhrapyMedcat._flatten_annotated_results(results) # sort for row number in ascending order and reset index to keep index updated medcat_obj.annotated_results = ( EhrapyMedcat._annotated_results_to_df(flattened_res).sort_values(by=["row_nr"]).reset_index(drop=True) ) @staticmethod def get_annotation_overview( medcat_obj: MedCAT, n: int = 10, status: str = "Affirmed", save_to_csv: bool = False, save_path: str = "." ) -> None: """Provide an overview for the annotation results. An overview will look like the following: cui (the CUI), nsubjects (from how many rows this one got extracted), type_ids (TUIs), name(name of the entitiy), perc_subjects (how many rows relative to absolute number of rows) Args: medcat_obj: The current MedCAT object which holds all infos on NLP analysis with MedCAT and ehrapy. n: Basically the parameter for head() of pandas Dataframe. How many of the most common entities should be shown? status: One of "Affirmed" (default), "Other" or "Both". Displays stats for either only affirmed entities, negated ones or both. save_to_csv: Whether to save the overview dataframe to a local .csv file in the current working directory or not. save_path: Path to save the overview as .csv file. Defaults to current working directory. Returns: A Pandas DataFrame with the overview stats. """ df = EhrapyMedcat._filter_df_by_status(medcat_obj.annotated_results, status) # group by CUI as this is a unique identifier per entity grouped = df.groupby("cui") # get absolute number of rows with this entity # note for overview, only one TUI and type is shown (there shouldn't be much situations were multiple are even possible or useful) res = grouped.agg( { "pretty_name": (lambda x: next(iter(set(x)))), "type_ids": (lambda x: next(iter(x))[0]), "types": (lambda x: next(iter(x))[0]), "row_nr": "nunique", } ) res = res.rename(columns={"row_nr": "n_patient_visit"}) # relative amount of patient visits with the specific entity to all patient visits (or rows in the original data) res["n_patient_visit_percent"] = (res["n_patient_visit"] / df["row_nr"].nunique()) * 100 res.round({"n_patient_visit_percent": 1}) # save to csv if desired if save_to_csv: res.to_csv(save_path) overview_table = EhrapyMedcat._df_to_rich_table(res.nlargest(n, "n_patient_visit")) console = Console() console.print(overview_table) @staticmethod def add_binary_column_to_obs( medcat_obj: MedCAT, adata: AnnData, name: str, all_names: list[str], add_cols: list[str] | None ) -> None: """Adds a binary column to obs (temporarily) for plotting infos extracted from freetext. Indicates whether the specific entity to color by has been found in that row or not. """ # only extract affirmed entities df = EhrapyMedcat._filter_df_by_status(medcat_obj.annotated_results, "Affirmed") # check whether the name is in the extracted entities to handle possible typos to a certain extend # currently, only the pretty_name column is supported # _list_replace(color, colored_column, colored_column_tmp) if name not in df["pretty_name"].values: new_name, _ = process.extractOne(query=name, choices=df["pretty_name"].unique(), score_cutoff=50) if new_name: print( f"[bold yellow]Did not find [blue]{name} [yellow]in MedCAT's extracted entities. " f"Will use best match {new_name}!" ) def _list_replace(lst, old: str, new: str): """replace list elements (inplace)""" i = -1 try: while True: i = lst.index(old, i + 1) lst[i] = new except ValueError: pass _list_replace(all_names, name, new_name) name = new_name else: raise EntitiyNotFoundError( f"Did not find {name} in MedCAT's extracted entities and could not determine a best matching equivalent." ) # add column to additional to remove it later on if add_cols is not None: add_cols.append(name) adata.obs[name] = ( df.groupby("row_nr").agg({"pretty_name": (lambda x: int(any(x.isin([name]))))}).astype("category") ) adata.obs = adata.obs.replace({name: {1.0: "yes", 0.0: "no"}}) adata.obs[name] = adata.obs[name].fillna("no").astype("category") @staticmethod def _annotated_results_to_df(flattened_results: dict) -> pd.DataFrame: """Turn the flattened annotated results into a pandas DataFrame and remove duplicates.""" df = pd.DataFrame.from_dict(flattened_results, orient="index") # remove duplicate entries; for example when a single entity like a disease is mentioned multiple times without any meaningful context changes # Example: The patient suffers from Diabetes. Cause of the Diabetes, he receives drug X. df.drop_duplicates(subset=["cui", "row_nr", "meta_anns"]) return df @staticmethod def _flatten_annotated_results(annotation_results: dict) -> dict: """Flattens the nested set (usually 5 level nested) of annotation results. annotation_results is just a simple flattened dict with infos on all entities found """ flattened_annotated_dict = {} entry_nr = 0 # row numbers where the text column is located in the original data for row_id in annotation_results.keys(): # all entities extracted from a given row entities = annotation_results[row_id]["entities"] for entity_id in entities.keys(): # tokens are currently ignored, as they will not appear with the current basic model used by ehrapy from MedCAT if entity_id != "tokens": single_entity = {"row_nr": row_id} entity = entities[entity_id] # iterate over all info attributes of a single entity found in a specific row for entity_key in entity.keys(): if entity_key in ["pretty_name", "cui", "type_ids", "types"]: single_entity[entity_key] = entities[entity_id][entity_key] elif entity_key == "meta_anns": single_entity[entity_key] = entities[entity_id][entity_key]["Status"]["value"] flattened_annotated_dict[entry_nr] = single_entity entry_nr += 1 return flattened_annotated_dict @staticmethod def _format_df_column(df: pd.DataFrame, column_name: str) -> list[tuple[int, str]]: """Format the df to match: formatted_data = [(row_id, row_text), (row_id, row_text), ...] as this is required by MedCAT's multiprocessing annotation step """ formatted_data = [] for id, row in df.iterrows(): text = row[column_name] formatted_data.append((id, text)) return formatted_data @staticmethod def _filter_null_values(df: pd.DataFrame, column: str) -> pd.DataFrame: """Filter null values of a given column and return that column without the null values""" return pd.DataFrame(df[column][~df[column].isnull()]) @staticmethod def _filter_df_by_status(df: pd.DataFrame, status: str) -> pd.DataFrame: """Util function to filter passed dataframe by status.""" df_res = df if status != "Both": if status not in {"Affirmed", "Other"}: raise StatusNotSupportedError(f"{status} is not available. Please use either Affirmed, Other or Both!") mask = df["meta_anns"].values == status df_res = df[mask] return df_res @staticmethod def _df_to_rich_table(df: pd.DataFrame) -> Table: """Convert a pandas dataframe to a rich Table""" table = Table(show_header=True, header_style="bold magenta") for column in df.columns: table.add_column(str(column)) for _, value_list in enumerate(df.values.tolist()): row = [] row += [str(x) for x in value_list] table.add_row(*row) # Update the style of the table table.row_styles = ["none", "dim"] table.box = box.SIMPLE_HEAD return table class StatusNotSupportedError(Exception): pass class EntitiyNotFoundError(Exception): pass