import json
import os
import time

import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score

from utils.logger import logger


class Metrics:
    """Compare a prediction spreadsheet against a ground-truth spreadsheet and
    produce precision / recall / F1 metrics per data point.

    Results are written to an Excel workbook with two sheets:
    ``Missing_Error`` (per-document missing/erroneous values) and ``Metrics``
    (per-data-point scores plus an average row).
    """

    # Fallback location for the generated workbook when the caller does not
    # supply an output folder.
    DEFAULT_OUTPUT_FOLDER = r"/data/emea_ar/output/metrics/"

    def __init__(
        self,
        data_type: str,
        prediction_file: str,
        prediction_sheet_name: str = "Sheet1",
        ground_truth_file: str = None,
        output_folder: str = None,
    ) -> None:
        """Store configuration and compute a timestamped output file path.

        Args:
            data_type: Kind of evaluation to run ("page_filter" or "datapoint").
            prediction_file: Path to the Excel file holding predictions.
            prediction_sheet_name: Sheet in the prediction file to read.
            ground_truth_file: Path to the Excel file holding ground truth.
            output_folder: Folder for the metrics workbook; created if absent.
                Defaults to DEFAULT_OUTPUT_FOLDER when None or empty.
        """
        self.data_type = data_type
        self.prediction_file = prediction_file
        self.prediction_sheet_name = prediction_sheet_name
        self.ground_truth_file = ground_truth_file
        if not output_folder:
            output_folder = self.DEFAULT_OUTPUT_FOLDER
        os.makedirs(output_folder, exist_ok=True)
        time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
        self.output_file = os.path.join(
            output_folder,
            f"metrics_{data_type}_{time_stamp}.xlsx",
        )

    def get_metrics(self):
        """Run the evaluation selected by ``self.data_type`` and persist it.

        Returns:
            (missing_error_list, metrics_list, output_file) on success.
            NOTE(review): on invalid input files this historically returns a
            bare [] — kept as-is for backward compatibility with callers that
            truth-test the result.
        """
        if (
            self.prediction_file is None
            or len(self.prediction_file) == 0
            or not os.path.exists(self.prediction_file)
        ):
            logger.error(f"Invalid prediction file: {self.prediction_file}")
            return []
        if (
            self.ground_truth_file is None
            or len(self.ground_truth_file) == 0
            or not os.path.exists(self.ground_truth_file)
        ):
            logger.error(f"Invalid ground truth file: {self.ground_truth_file}")
            return []
        # Placeholders so an unknown data_type still produces a valid (empty)
        # report. BUGFIX: missing_error_list was previously left undefined on
        # that path, raising NameError at the DataFrame construction below.
        missing_error_list = []
        metrics_list = [
            {"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0}
        ]
        if self.data_type == "page_filter":
            missing_error_list, metrics_list = self.get_page_filter_metrics()
        elif self.data_type == "datapoint":
            missing_error_list, metrics_list = self.get_datapoint_metrics()
        else:
            logger.error(f"Invalid data type: {self.data_type}")
        missing_error_df = pd.DataFrame(missing_error_list)
        missing_error_df.reset_index(drop=True, inplace=True)
        metrics_df = pd.DataFrame(metrics_list)
        metrics_df.reset_index(drop=True, inplace=True)
        with pd.ExcelWriter(self.output_file) as writer:
            missing_error_df.to_excel(writer, sheet_name="Missing_Error", index=False)
            metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
        return missing_error_list, metrics_list, self.output_file

    def get_page_filter_metrics(self):
        """Score page-filter predictions for each data point (tor, ter, ogc,
        performance_fee) across all human-checked ground-truth rows.

        Returns:
            Tuple (missing_error_list, metrics_list):
            * missing_error_list — one dict per (doc, data_point) with the
              values missed or wrongly predicted.
            * metrics_list — per-data-point precision/recall/F1/support rows
              plus a final macro-average row.
        """
        prediction_df = pd.read_excel(
            self.prediction_file, sheet_name=self.prediction_sheet_name
        )
        ground_truth_df = pd.read_excel(self.ground_truth_file, sheet_name="Sheet1")
        # Only rows a human has verified count as ground truth.
        ground_truth_df = ground_truth_df[ground_truth_df["Checked"] == 1]

        data_point_list = ["tor", "ter", "ogc", "performance_fee"]
        # Accumulate parallel binary label vectors per data point instead of
        # the previous four pairs of copy-pasted lists.
        true_labels = {dp: [] for dp in data_point_list}
        pred_labels = {dp: [] for dp in data_point_list}
        missing_error_list = []

        for _, row in ground_truth_df.iterrows():
            doc_id = row["doc_id"]
            candidates = prediction_df[prediction_df["doc_id"] == doc_id]
            if candidates.empty:
                # BUGFIX: .iloc[0] on an empty selection raised IndexError and
                # aborted the whole run; skip the document with a log instead.
                logger.error(f"No prediction row found for doc_id: {doc_id}")
                continue
            # First prediction row for the document wins, as before.
            prediction_data = candidates.iloc[0]
            for data_point in data_point_list:
                true_data, pred_data, missing_error_data = self.get_true_pred_data(
                    doc_id, row, prediction_data, data_point
                )
                true_labels[data_point].extend(true_data)
                pred_labels[data_point].extend(pred_data)
                missing_error_list.append(missing_error_data)

        # Human-readable names used in the log lines (unchanged wording).
        display_names = {
            "tor": "TOR",
            "ter": "TER",
            "ogc": "OGC",
            "performance_fee": "Performance Fee",
        }
        metrics_list = []
        for data_point in data_point_list:
            precision, recall, f1 = self.get_specific_metrics(
                true_labels[data_point], pred_labels[data_point]
            )
            support = self.get_support_number(true_labels[data_point])
            metrics_list.append(
                {
                    "Data_Point": data_point,
                    "Precision": precision,
                    "Recall": recall,
                    "F1": f1,
                    "Support": support,
                }
            )
            logger.info(
                f"{display_names[data_point]} Precision: {precision}, "
                f"Recall: {recall}, F1: {f1}, Support: {support}"
            )

        # Append a macro-average row; Support is the total across data points
        # (computed before the average row itself is appended).
        precision_list = [metric["Precision"] for metric in metrics_list]
        recall_list = [metric["Recall"] for metric in metrics_list]
        f1_list = [metric["F1"] for metric in metrics_list]
        metrics_list.append(
            {
                "Data_Point": "Average",
                "Precision": sum(precision_list) / len(precision_list),
                "Recall": sum(recall_list) / len(recall_list),
                "F1": sum(f1_list) / len(f1_list),
                "Support": sum(metric["Support"] for metric in metrics_list),
            }
        )
        return missing_error_list, metrics_list

    def get_support_number(self, true_data: list):
        """Return the number of positive (1) entries in ``true_data``."""
        return sum(true_data)

    def get_true_pred_data(
        self,
        doc_id,
        ground_truth_data: pd.Series,
        prediction_data: pd.Series,
        data_point: str,
    ):
        """Align one document's predicted vs ground-truth value lists into
        parallel binary label vectors suitable for sklearn scorers.

        Encoding per value:
            * correct prediction  -> true=1, pred=1
            * spurious prediction -> true=0, pred=1 (recorded under "error")
            * missed ground truth -> true=1, pred=0 (recorded under "missing")
            * both lists empty    -> a single (1, 1) pair so an all-empty
              document counts as a correct match.

        Returns:
            (true_data, pred_data, missing_error_data) where
            missing_error_data is {"doc_id", "data_point", "missing", "error"}.
        """
        ground_truth_list = ground_truth_data[data_point]
        if isinstance(ground_truth_list, str):
            # Spreadsheet cells may hold JSON-encoded lists; decode them.
            ground_truth_list = json.loads(ground_truth_list)
        prediction_list = prediction_data[data_point]
        if isinstance(prediction_list, str):
            prediction_list = json.loads(prediction_list)

        true_data = []
        pred_data = []
        missing_data = []
        error_data = []
        if len(ground_truth_list) == 0 and len(prediction_list) == 0:
            true_data.append(1)
            pred_data.append(1)
            return (
                true_data,
                pred_data,
                {"doc_id": doc_id, "data_point": data_point, "missing": [], "error": []},
            )
        for prediction in prediction_list:
            if prediction in ground_truth_list:
                true_data.append(1)
                pred_data.append(1)
            else:
                true_data.append(0)
                pred_data.append(1)
                error_data.append(prediction)
        for ground_truth in ground_truth_list:
            if ground_truth not in prediction_list:
                true_data.append(1)
                pred_data.append(0)
                missing_data.append(ground_truth)
        missing_error_data = {
            "doc_id": doc_id,
            "data_point": data_point,
            "missing": missing_data,
            "error": error_data,
        }
        return true_data, pred_data, missing_error_data

    def get_specific_metrics(self, true_data: list, pred_data: list):
        """Return (precision, recall, f1) for parallel binary label lists.

        zero_division=0 keeps the historical 0.0 result when a class has no
        predicted/true samples, but without sklearn's UndefinedMetricWarning.
        """
        precision = precision_score(true_data, pred_data, zero_division=0)
        recall = recall_score(true_data, pred_data, zero_division=0)
        f1 = f1_score(true_data, pred_data, zero_division=0)
        return precision, recall, f1

    def get_datapoint_metrics(self):
        """Placeholder for datapoint-level evaluation. TODO: implement.

        Returns empty (missing_error_list, metrics_list) so get_metrics() can
        still unpack the result. BUGFIX: the previous bare ``pass`` returned
        None, which crashed the tuple unpack in get_metrics().
        """
        logger.warning("get_datapoint_metrics is not implemented yet")
        return [], []