diff --git a/core/data_extraction.py b/core/data_extraction.py index aab9b7a..3bbc0e6 100644 --- a/core/data_extraction.py +++ b/core/data_extraction.py @@ -253,7 +253,7 @@ class DataExtraction: exclude_data: list) -> list: """ If occur error, split the context to two parts and try to get data from the two parts - Relevant document: 503194284 + Relevant document: 503194284, page index 147 """ try: logger.info(f"Split context to get data to fix issue which output length is over 4K tokens") diff --git a/core/metrics.py b/core/metrics.py index be9f939..8bbd539 100644 --- a/core/metrics.py +++ b/core/metrics.py @@ -53,7 +53,7 @@ class Metrics: {"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0} ] - missing_error_list, metrics_list = self.get_metrics() + missing_error_list, metrics_list = self.calculate_metrics() missing_error_df = pd.DataFrame(missing_error_list) missing_error_df.reset_index(drop=True, inplace=True) @@ -66,7 +66,7 @@ class Metrics: metrics_df.to_excel(writer, sheet_name="Metrics", index=False) return missing_error_list, metrics_list, self.output_file - def get_metrics(self): + def calculate_metrics(self): prediction_df = pd.read_excel( self.prediction_file, sheet_name=self.prediction_sheet_name ) @@ -302,11 +302,13 @@ class Metrics: get_unique_words_text ) ground_truth_unique_words = dp_ground_truth["unique_words"].unique().tolist() + ground_truth_raw_names = dp_ground_truth["raw_name"].unique().tolist() # add new column to store unique words for dp_prediction dp_prediction["unique_words"] = dp_prediction["raw_name"].apply( get_unique_words_text ) pred_unique_words = dp_prediction["unique_words"].unique().tolist() + pred_raw_names = dp_prediction["raw_name"].unique().tolist() true_data = [] pred_data = [] @@ -325,11 +327,18 @@ class Metrics: pred_data_point_value = prediction["value"] pred_investment_type = prediction["investment_type"] - if pred_unique_words in ground_truth_unique_words: + find_raw_name_in_gt = [gt_raw_name for gt_raw_name in ground_truth_raw_names + if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name] + if pred_unique_words in ground_truth_unique_words or len(find_raw_name_in_gt) > 0: # get the ground truth data with the same unique words - gt_data = dp_ground_truth[ - dp_ground_truth["unique_words"] == pred_unique_words - ].iloc[0] + if pred_unique_words in ground_truth_unique_words: + gt_data = dp_ground_truth[ + dp_ground_truth["unique_words"] == pred_unique_words + ].iloc[0] + else: + gt_data = dp_ground_truth[ + dp_ground_truth["raw_name"] == find_raw_name_in_gt[0] + ].iloc[0] gt_data_point_value = gt_data["value"] if pred_data_point_value == gt_data_point_value: true_data.append(1) @@ -370,7 +379,11 @@ class Metrics: gt_data_point_value = ground_truth["value"] gt_investment_type = ground_truth["investment_type"] - if gt_unique_words not in pred_unique_words: + find_raw_name_in_pred = [pred_raw_name for pred_raw_name in pred_raw_names + if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name] + + if gt_unique_words not in pred_unique_words and \ + len(find_raw_name_in_pred) == 0: true_data.append(1) pred_data.append(0) error_data = { diff --git a/main.py b/main.py index 24d3428..e71e647 100644 --- a/main.py +++ b/main.py @@ -355,6 +355,7 @@ def get_metrics( prediction_file: str, prediction_sheet_name: str, ground_truth_file: str, + ground_truth_sheet_name: str = None, output_folder: str = None, ) -> None: metrics = Metrics( @@ -362,6 +363,7 @@ def get_metrics( prediction_file=prediction_file, prediction_sheet_name=prediction_sheet_name, ground_truth_file=ground_truth_file, + ground_truth_sheet_name=ground_truth_sheet_name, output_folder=output_folder, ) missing_error_list, metrics_list, metrics_file = metrics.get_metrics() @@ -472,6 +474,22 @@ def test_auto_generate_instructions(): f.write(ogc_ter_performance_fee_instructions_text) +def test_data_extraction_metrics(): + data_type = "data_extraction" + prediction_file = r"/data/emea_ar/output/mapping_data/docs/excel/292989214.xlsx" + prediction_sheet_name = "mapping_data" + ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_292989214.xlsx" + ground_truth_sheet_name = "mapping_data" + metrics_output_folder = r"/data/emea_ar/output/metrics/" + missing_error_list, metrics_list, metrics_file = get_metrics( + data_type, + prediction_file, + prediction_sheet_name, + ground_truth_file, + ground_truth_sheet_name, + metrics_output_folder + ) + if __name__ == "__main__": pdf_folder = r"/data/emea_ar/small_pdf/" page_filter_ground_truth_file = ( @@ -506,23 +524,23 @@ if __name__ == "__main__": # doc_id = "476492237" # extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run) - special_doc_id_list = [ - "525574973", - ] + special_doc_id_list = [] output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/" output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/" re_run_mapping_data = True - force_save_total_data = False - batch_start_job( - pdf_folder, - page_filter_ground_truth_file, - output_extract_data_child_folder, - output_mapping_child_folder, - output_extract_data_total_folder, - output_mapping_total_folder, - special_doc_id_list, - re_run_extract_data, - re_run_mapping_data, - force_save_total_data=force_save_total_data, - ) + force_save_total_data = True + # batch_start_job( + # pdf_folder, + # page_filter_ground_truth_file, + # output_extract_data_child_folder, + # output_mapping_child_folder, + # output_extract_data_total_folder, + # output_mapping_total_folder, + # special_doc_id_list, + # re_run_extract_data, + # re_run_mapping_data, + # force_save_total_data=force_save_total_data, + # ) + + test_data_extraction_metrics()