From 8ec0d9ecc863e9b043c19b6a7ba6e1aa3d4b4bd9 Mon Sep 17 00:00:00 2001
From: Abdelghani Belgaid <72890326+abdelghanibelgaid@users.noreply.github.com>
Date: Thu, 28 Sep 2023 00:42:08 +0100
Subject: [PATCH] Update evaluation.py: add a function to evaluate the
 performance of EM models on validation data using a specified metric

---
 causalnex/evaluation/evaluation.py | 50 ++++++++++++++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/causalnex/evaluation/evaluation.py b/causalnex/evaluation/evaluation.py
index e0dd2aa..b2b1a03 100644
--- a/causalnex/evaluation/evaluation.py
+++ b/causalnex/evaluation/evaluation.py
@@ -32,6 +32,7 @@
 import pandas as pd
 from sklearn import metrics
 
+from causalnex.estimator import EMSingleLatentVariable
 from causalnex.network import BayesianNetwork
 
@@ -201,3 +202,52 @@ def classification_report(bn: BayesianNetwork, data: pd.DataFrame, node: str) ->
     )
 
     return report
+
+def evaluate_em_model(
+    em_model: EMSingleLatentVariable,
+    validation_data: pd.DataFrame,
+    metric: callable,
+    lv_column: str = None,
+) -> float:
+    """
+    Evaluate the performance of a trained EM model on validation data using a specified metric.
+
+    Args:
+        em_model (EMSingleLatentVariable): The trained EM model.
+        validation_data (pd.DataFrame): Validation dataset with the same structure as the training data.
+        metric (callable): A function taking two arguments, ground-truth and predicted values, and returning a score.
+        lv_column (str): Name of the latent-variable column in the dataset, if it differs from the name used during training.
+
+    Returns:
+        float: The computed evaluation metric value.
+
+    Raises:
+        ValueError: If the validation dataset is empty or is missing required columns.
+
+    Example:
+        >>> em = EMSingleLatentVariable(sm=sm, data=train_data, lv_name=lv_name, node_states=node_states)
+        >>> em.run()  # train the model
+        >>> validation_metric = evaluate_em_model(em, validation_data, custom_metric_function)
+    """
+    if validation_data.empty:
+        raise ValueError("Validation dataset is empty.")
+
+    if lv_column is None:
+        lv_column = em_model.lv_name
+
+    # The structure model always refers to the latent variable by its training name
+    relevant_columns = [lv_column] + list(em_model.sm.successors(em_model.lv_name))
+    missing_columns = set(relevant_columns) - set(validation_data.columns)
+    if missing_columns:
+        raise ValueError(f"Validation dataset is missing columns: {sorted(missing_columns)}.")
+    validation_data = validation_data[relevant_columns]
+
+    # Compute the likelihood of each record in the validation data under the model
+    likelihoods = [
+        em_model.compute_likelihood(record.to_dict())
+        for _, record in validation_data.iterrows()
+    ]
+
+    # Score the likelihoods against the true latent-variable values
+    true_values = validation_data[lv_column]
+    return metric(true_values, likelihoods)
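
Usage notes (not part of the patch):

The docstring's `custom_metric_function` is left abstract. Below is a minimal sketch of one
plausible metric under the contract the patch defines, i.e. a callable receiving the
ground-truth latent values and the per-record likelihoods. The name and the averaging choice
are illustrative, not part of CausalNex; this particular metric scores the model only by how
probable it finds the validation records, so `y_true` is accepted but unused.

    import numpy as np

    def mean_negative_log_likelihood(y_true, likelihoods):
        # Hypothetical metric: mean negative log-likelihood of the validation
        # records under the model (lower is better); ignores y_true.
        return float(-np.mean(np.log(np.asarray(likelihoods, dtype=float))))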
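
And a sketch of the end-to-end flow, mirroring the docstring example. The structure, node
states, and data frames are invented for illustration, and the sketch inherits the patch's
own assumption that `EMSingleLatentVariable` exposes a `compute_likelihood` method taking a
record dict.

    import numpy as np
    import pandas as pd

    from causalnex.estimator import EMSingleLatentVariable
    from causalnex.structure import StructureModel

    # Hypothetical structure: latent variable "z" with two observed children
    sm = StructureModel()
    sm.add_edges_from([("z", "x1"), ("z", "x2")])
    node_states = {"z": [0, 1], "x1": [0, 1], "x2": [0, 1]}

    # The latent column is unobserved during training ...
    train_data = pd.DataFrame({
        "z": [np.nan] * 6,
        "x1": [0, 1, 0, 1, 1, 0],
        "x2": [0, 1, 1, 1, 0, 0],
    })
    # ... but labelled in the validation set so the metric has ground truth
    validation_data = pd.DataFrame({
        "z": [0, 1, 0, 1],
        "x1": [0, 1, 0, 1],
        "x2": [0, 1, 1, 1],
    })

    em = EMSingleLatentVariable(sm=sm, data=train_data, lv_name="z", node_states=node_states)
    em.run()  # fit the CPDs by expectation-maximisation

    score = evaluate_em_model(em, validation_data, mean_negative_log_likelihood)

Because the metric receives raw likelihoods rather than predicted states, a label-aware
metric (accuracy, AUC, etc.) would need the latent states predicted explicitly first.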