diff --git a/core/metrics.py b/core/metrics.py
index a6a3146..be31215 100644
--- a/core/metrics.py
+++ b/core/metrics.py
@@ -4,6 +4,7 @@ import time
 import json
 from sklearn.metrics import precision_score, recall_score, f1_score
 from utils.biz_utils import get_unique_words_text, get_beginning_common_words, remove_special_characters
+from utils.sql_query_util import query_document_fund_mapping
 from utils.logger import logger
@@ -33,7 +34,7 @@ class Metrics:
             f"metrics_{data_type}_{time_stamp}.xlsx",
         )

-    def get_metrics(self):
+    def get_metrics(self, strict_mode: bool = False):
         if (
             self.prediction_file is None
             or len(self.prediction_file) == 0
@@ -53,7 +54,7 @@ class Metrics:
             {"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0}
         ]

-        missing_error_list, metrics_list = self.calculate_metrics()
+        missing_error_list, metrics_list = self.calculate_metrics(strict_mode=strict_mode)
         missing_error_df = pd.DataFrame(missing_error_list)
         missing_error_df.reset_index(drop=True, inplace=True)
@@ -66,7 +67,7 @@ class Metrics:
             metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
         return missing_error_list, metrics_list, self.output_file

-    def calculate_metrics(self):
+    def calculate_metrics(self, strict_mode: bool = False):
         prediction_df = pd.read_excel(
             self.prediction_file, sheet_name=self.prediction_sheet_name
         )
@@ -77,7 +78,7 @@ class Metrics:
             ground_truth_df = ground_truth_df[ground_truth_df["Checked"] == 1]
         elif self.data_type == "data_extraction":
             ground_truth_df = ground_truth_df[ground_truth_df["rawname_checked"] == 1]
-        elif self.data_type == "investment_mapping":
+        elif self.data_type in ["investment_mapping", "document_mapping_in_db"]:
             ground_truth_df = ground_truth_df[ground_truth_df["mapping_checked"] == 1]
         else:
             logger.error(f"Invalid data type: {self.data_type}")
@@ -179,9 +180,28 @@ class Metrics:
             investment_mapping_true.extend(true_data)
             investment_mapping_pred.extend(pred_data)
             missing_error_list.extend(missing_error_data)
+        elif self.data_type == "document_mapping_in_db":
+            prediction_doc_id_list = prediction_df["doc_id"].unique().tolist()
+            ground_truth_doc_id_list = ground_truth_df["doc_id"].unique().tolist()
+            # evaluate only doc_ids present in both prediction and ground truth
+            doc_id_list = list(
+                set(prediction_doc_id_list) & set(ground_truth_doc_id_list)
+            )
+            # sort for a deterministic processing order
+            doc_id_list.sort()
+
+            for doc_id in doc_id_list:
+                prediction_data = prediction_df[prediction_df["doc_id"] == doc_id]
+                ground_truth_data = ground_truth_df[ground_truth_df["doc_id"] == doc_id]
+                true_data, pred_data, missing_error_data = self.get_document_mapping_in_db_true_pred_data(
+                    doc_id, ground_truth_data, prediction_data, strict_mode=strict_mode
+                )
+                investment_mapping_true.extend(true_data)
+                investment_mapping_pred.extend(pred_data)
+                missing_error_list.extend(missing_error_data)

         metrics_list = []
-        if self.data_type == "investment_mapping":
+        if self.data_type in ["investment_mapping", "document_mapping_in_db"]:
             if len(investment_mapping_true) == 0 and len(investment_mapping_pred) == 0:
                 investment_mapping_true.append(1)
                 investment_mapping_pred.append(1)
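The `document_mapping_in_db` branch reuses the existing investment-mapping scoring: each document contributes binary labels to `investment_mapping_true` / `investment_mapping_pred`, which later feed the `precision_score` / `recall_score` / `f1_score` imports at the top of the file. A minimal sketch with made-up labels (not real output) of how the (true, pred) pairs described in the new method translate into scores:

```python
from sklearn.metrics import precision_score, recall_score

# Hypothetical label lists for one document, following the scoring scheme:
# a correct mapping contributes (1, 1); a missing or wrong mapping
# contributes (1, 0), hurting recall; a wrong id that is itself in the
# document mapping additionally contributes (0, 1), hurting precision.
y_true = [1, 1, 1, 0]  # correct, correct, wrong id, extra (0, 1) for the wrong id
y_pred = [1, 1, 0, 1]

print(precision_score(y_true, y_pred))  # 2 TP / (2 TP + 1 FP) = 0.667
print(recall_score(y_true, y_pred))     # 2 TP / (2 TP + 1 FN) = 0.667
```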
@@ -669,6 +689,208 @@ class Metrics:
             missing_error_data.append(error_data)
         return true_data, pred_data, missing_error_data

+    def get_document_mapping_in_db_true_pred_data(
+        self,
+        doc_id,
+        ground_truth_data: pd.DataFrame,
+        prediction_data: pd.DataFrame,
+        strict_mode: bool = False,
+    ):
+        """
+        EMEA AR mapping metrics based on the document mapping stored in the DB.
+
+        1. The ground truth is built manually: for each fund/share name in the
+           document mapping, find the matching records in the document data
+           extraction output and fill in the mapping id.
+        2. Metrics are calculated per document:
+           a. Ground truth data: the document data extraction records filtered
+              by the mapping ids in the document mapping.
+           b. Prediction data: the document mapping looked up by fund/share
+              raw name from the PDF document.
+           Scoring:
+           - mapping correct:              true 1, pred 1
+           - mapping empty:                true 1, pred 0 --- hurts recall
+           - mapping incorrect (other id): true 1, pred 0 --- hurts recall;
+             if the incorrect id is itself in the document mapping, also add
+                                           true 0, pred 1 --- hurts precision
+        """
+        document_mapping_data = query_document_fund_mapping(doc_id)
+        if len(document_mapping_data) == 0:
+            return [1], [1], []
+        fund_id_list = document_mapping_data["FundId"].unique().tolist()
+        share_id_list = document_mapping_data["SecId"].unique().tolist()
+        id_list = fund_id_list + share_id_list
+
+        # keep only the ground truth rows whose investment_id is in the document mapping
+        dp_ground_truth = ground_truth_data[
+            ground_truth_data["investment_id"].isin(id_list)
+        ]
+
+        dp_ground_truth = self.modify_data(dp_ground_truth)
+        # keep only the columns needed for the comparison
+        dp_ground_truth = dp_ground_truth[["doc_id", "page_index", "raw_name", "simple_raw_name",
+                                           "simple_name_unique_words", "investment_type",
+                                           "investment_id", "investment_name"]]
+        dp_ground_truth.drop_duplicates(inplace=True)
+        dp_ground_truth.reset_index(drop=True, inplace=True)
+
+        # fill NaN investment_id/investment_name in the prediction with ""
+        prediction_data["investment_id"].fillna("", inplace=True)
+        prediction_data["investment_name"].fillna("", inplace=True)
+        dp_prediction = self.modify_data(prediction_data)
+        dp_prediction = dp_prediction[["doc_id", "page_index", "raw_name", "simple_raw_name",
+                                       "simple_name_unique_words", "investment_type",
+                                       "investment_id", "investment_name"]]
+        dp_prediction.drop_duplicates(inplace=True)
+        dp_prediction.reset_index(drop=True, inplace=True)
+
+        compare_data_list = []
+        gt_investment_id_list = []
+        for index, ground_truth in dp_ground_truth.iterrows():
+            gt_page_index = ground_truth["page_index"]
+            gt_raw_name = ground_truth["raw_name"]
+            gt_simple_raw_name = ground_truth["simple_raw_name"]
+            gt_simple_name_unique_words = ground_truth["simple_name_unique_words"]
+            gt_investment_type = ground_truth["investment_type"]
+            gt_investment_id = ground_truth["investment_id"]
+            gt_investment_name = ground_truth["investment_name"]
+
+            # restrict candidate predictions to the ground truth row's page
+            pred_page_data = dp_prediction[dp_prediction["page_index"] == gt_page_index]
+            if len(pred_page_data) > 0:
+                pred_simple_raw_names = pred_page_data["simple_raw_name"].unique().tolist()
+                pred_simple_name_unique_words_list = (
+                    pred_page_data["simple_name_unique_words"].unique().tolist()
+                )
+            else:
+                pred_simple_raw_names = []
+                pred_simple_name_unique_words_list = []
+
+            # score each ground truth investment_id at most once
+            if gt_investment_id in gt_investment_id_list:
+                continue
+            # a prediction name matches when one name contains the other and
+            # the prediction ends with the last word of the ground truth name
+            find_raw_name_in_pred = [
+                pred_raw_name
+                for pred_raw_name in pred_simple_raw_names
+                if (
+                    gt_simple_raw_name in pred_raw_name
+                    or pred_raw_name in gt_simple_raw_name
+                )
+                and pred_raw_name.endswith(gt_simple_raw_name.split()[-1])
+            ]
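+            # Example with hypothetical names: gt "global equity a acc" matches
+            # pred "abc global equity a acc" (containment holds and the pred
+            # ends with "acc"), but not pred "global equity a inc", whose last
+            # word differs -- this guards against crediting the wrong share class.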
+            if (
+                gt_simple_name_unique_words in pred_simple_name_unique_words_list
+                or len(find_raw_name_in_pred) > 0
+            ):
+                # prefer the unique-words match, fall back to the raw-name match
+                if gt_simple_name_unique_words in pred_simple_name_unique_words_list:
+                    pred_data_df = dp_prediction[
+                        dp_prediction["simple_name_unique_words"]
+                        == gt_simple_name_unique_words
+                    ]
+                else:
+                    pred_data_df = dp_prediction[
+                        dp_prediction["simple_raw_name"] == find_raw_name_in_pred[0]
+                    ]
+                # if several rows match, prefer one on the same page
+                if len(pred_data_df) > 1:
+                    same_page_df = pred_data_df[pred_data_df["page_index"] == gt_page_index]
+                    if len(same_page_df) == 0:
+                        pred_data = pred_data_df.iloc[0]
+                    else:
+                        pred_data = same_page_df.iloc[0]
+                elif len(pred_data_df) == 1:
+                    pred_data = pred_data_df.iloc[0]
+                else:
+                    pred_data = None
+                if pred_data is not None:
+                    compare_data = {"raw_name": gt_raw_name,
+                                    "investment_type": gt_investment_type,
+                                    "gt_investment_id": gt_investment_id,
+                                    "gt_investment_name": gt_investment_name,
+                                    "pred_investment_id": pred_data["investment_id"],
+                                    "pred_investment_name": pred_data["investment_name"]}
+                    gt_investment_id_list.append(gt_investment_id)
+                    compare_data_list.append(compare_data)
+            elif strict_mode:
+                # in strict mode an unmatched ground truth name is still
+                # scored, as a miss; otherwise it is skipped entirely
+                compare_data = {"raw_name": gt_raw_name,
+                                "investment_type": gt_investment_type,
+                                "gt_investment_id": gt_investment_id,
+                                "gt_investment_name": gt_investment_name,
+                                "pred_investment_id": "",
+                                "pred_investment_name": ""}
+                compare_data_list.append(compare_data)
+
+        true_data = []
+        pred_data = []
+        missing_error_data = []
+
+        for compare_data in compare_data_list:
+            gt_investment_id = compare_data["gt_investment_id"]
+            pred_investment_id = compare_data["pred_investment_id"]
+            if gt_investment_id == pred_investment_id:
+                true_data.append(1)
+                pred_data.append(1)
+            else:
+                true_data.append(1)
+                pred_data.append(0)
+                if pred_investment_id is not None and len(pred_investment_id) > 0:
+                    if pred_investment_id in id_list:
+                        # a wrong id taken from the document mapping also costs precision
+                        true_data.append(0)
+                        pred_data.append(1)
+                    error_data = {
+                        "doc_id": doc_id,
+                        "raw_name": compare_data["raw_name"],
+                        "investment_type": compare_data["investment_type"],
+                        "error_type": "mapping incorrect",
+                        "error_id": pred_investment_id,
+                        "error_name": compare_data["pred_investment_name"],
+                        "correct_id": compare_data["gt_investment_id"],
+                        "correct_name": compare_data["gt_investment_name"]
+                    }
+                else:
+                    error_data = {
+                        "doc_id": doc_id,
+                        "raw_name": compare_data["raw_name"],
+                        "investment_type": compare_data["investment_type"],
+                        "error_type": "mapping missing",
+                        "error_id": "",
+                        "error_name": "",
+                        "correct_id": compare_data["gt_investment_id"],
+                        "correct_name": compare_data["gt_investment_name"]
+                    }
+                missing_error_data.append(error_data)
+
+        return true_data, pred_data, missing_error_data
+
     def modify_data(self, data: pd.DataFrame):
         data["simple_raw_name"] = ""
         data["simple_name_unique_words"] = ""
diff --git a/main.py b/main.py
index d3891f2..7e0caaf 100644
--- a/main.py
+++ b/main.py
@@ -345,9 +345,19 @@ def batch_start_job(
         metrics_output_folder,
     )

-    logger.info(f"Calculating metrics for investment mapping")
+    # logger.info("Calculating metrics for investment mapping by actual document mapping")
+    # missing_error_list, metrics_list, metrics_file = get_metrics(
+    #     "investment_mapping",
+    #     output_file,
+    #     prediction_sheet_name,
+    #     ground_truth_file,
+    #     ground_truth_sheet_name,
+    #     metrics_output_folder,
+    # )
+
+    logger.info("Calculating metrics for investment mapping by database document mapping")
     missing_error_list, metrics_list, metrics_file = get_metrics(
-        "investment_mapping",
+        "document_mapping_in_db",
         output_file,
         prediction_sheet_name,
         ground_truth_file,
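The `strict_mode` flag threaded through `get_metrics` / `calculate_metrics` controls whether a ground-truth name that never matches any prediction row is skipped or scored as a miss. A simplified sketch with hypothetical rows (ignoring the precision-side double count) of its effect on the label lists:

```python
# Hypothetical comparison rows for one document: the third ground-truth
# name was never found in the prediction output at all.
rows = [
    {"gt_id": "F1", "pred_id": "F1"},  # matched and correct
    {"gt_id": "F2", "pred_id": "F9"},  # matched but wrong id
    {"gt_id": "F3", "pred_id": None},  # no prediction row found
]

def label(rows, strict_mode):
    true_data, pred_data = [], []
    for row in rows:
        if row["pred_id"] is None:
            if strict_mode:        # unmatched name becomes a miss
                true_data.append(1)
                pred_data.append(0)
            continue               # otherwise it is skipped entirely
        true_data.append(1)
        pred_data.append(1 if row["gt_id"] == row["pred_id"] else 0)
    return true_data, pred_data

print(label(rows, strict_mode=False))  # ([1, 1], [1, 0])       -> recall 1/2
print(label(rows, strict_mode=True))   # ([1, 1, 1], [1, 0, 0]) -> recall 1/3
```

This is why `main.py` now passes `strict_mode=True`: misses that previously went uncounted are reflected in recall.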
actual document mapping") + # missing_error_list, metrics_list, metrics_file = get_metrics( + # "investment_mapping", + # output_file, + # prediction_sheet_name, + # ground_truth_file, + # ground_truth_sheet_name, + # metrics_output_folder, + # ) + + logger.info(f"Calculating metrics for investment mapping by database document mapping") missing_error_list, metrics_list, metrics_file = get_metrics( - "investment_mapping", + "document_mapping_in_db", output_file, prediction_sheet_name, ground_truth_file, @@ -436,7 +446,7 @@ def get_metrics( ground_truth_sheet_name=ground_truth_sheet_name, output_folder=output_folder, ) - missing_error_list, metrics_list, metrics_file = metrics.get_metrics() + missing_error_list, metrics_list, metrics_file = metrics.get_metrics(strict_model=True) return missing_error_list, metrics_list, metrics_file @@ -657,13 +667,38 @@ if __name__ == "__main__": "479793787", "471641628", ] - special_doc_id_list = check_mapping_doc_id_list - special_doc_id_list = ["402113224"] + check_db_mapping_doc_id_list = [ + "292989214", + "316237292", + "321733631", + "323390570", + "327956364", + "332223498", + "333207452", + "334718372", + "344636875", + "349679479", + "362246081", + "366179419", + "380945052", + "382366116", + "387202452", + "389171486", + "391456740", + "391736837", + "394778487", + "401684600", + "402113224", + "402181770" + ] + # special_doc_id_list = check_mapping_doc_id_list + special_doc_id_list = check_db_mapping_doc_id_list + # special_doc_id_list = ["382366116"] output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/" output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/" re_run_extract_data = False - re_run_mapping_data = True - force_save_total_data = False + re_run_mapping_data = False + force_save_total_data = True extract_ways = ["text"] for extract_way in extract_ways: diff --git a/utils/biz_utils.py b/utils/biz_utils.py index 346fd1d..789cbb2 100644 --- a/utils/biz_utils.py +++ b/utils/biz_utils.py @@ -119,10 +119,15 @@ def get_most_similar_name(text: str, name_list: list, pre_common_word_list: list for i in range(len(copy_name_list)): temp_splits = copy_name_list[i].split() copy_name_list[i] = ' '.join([split for split in temp_splits - if remove_special_characters(split).lower() not in ['fund', 'portfolio', 'class', 'share', 'shares']]) + if remove_special_characters(split).lower() + not in ['fund', "funds", 'portfolio', + 'class', 'classes', + 'share', 'shares']]) final_splits = [] for split in new_splits: - if split.lower() not in ['fund', 'portfolio', 'class', 'share', 'shares']: + if split.lower() not in ['fund', "funds", 'portfolio', + 'class', 'classes', + 'share', 'shares']: final_splits.append(split) text = ' '.join(final_splits)