From 98e86a6cfd89bfd2a940daa9b889783bdc194631 Mon Sep 17 00:00:00 2001 From: Blade He Date: Wed, 18 Sep 2024 17:10:54 -0500 Subject: [PATCH] realize to calculate data extraction metrics. --- core/metrics.py | 237 ++++++++++++++---- .../data_extraction_prompts_config.json | 18 +- main.py | 16 +- prepare_data.py | 90 +++++-- utils/biz_utils.py | 10 + 5 files changed, 305 insertions(+), 66 deletions(-) diff --git a/core/metrics.py b/core/metrics.py index 815333c..be9f939 100644 --- a/core/metrics.py +++ b/core/metrics.py @@ -3,6 +3,7 @@ import pandas as pd import time import json from sklearn.metrics import precision_score, recall_score, f1_score +from utils.biz_utils import get_unique_words_text from utils.logger import logger @@ -13,12 +14,14 @@ class Metrics: prediction_file: str, prediction_sheet_name: str = "Sheet1", ground_truth_file: str = None, + ground_truth_sheet_name: str = "Sheet1", output_folder: str = None, ) -> None: self.data_type = data_type self.prediction_file = prediction_file self.prediction_sheet_name = prediction_sheet_name self.ground_truth_file = ground_truth_file + self.ground_truth_sheet_name = ground_truth_sheet_name if output_folder is None or len(output_folder) == 0: output_folder = r"/data/emea_ar/output/metrics/" @@ -49,27 +52,27 @@ class Metrics: metrics_list = [ {"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0} ] - if self.data_type == "page_filter": - missing_error_list, metrics_list = self.get_page_filter_metrics() - elif self.data_type == "datapoint": - missing_error_list, metrics_list = self.get_datapoint_metrics() - else: - logger.error(f"Invalid data type: {self.data_type}") + + missing_error_list, metrics_list = self.get_metrics() missing_error_df = pd.DataFrame(missing_error_list) missing_error_df.reset_index(drop=True, inplace=True) - + metrics_df = pd.DataFrame(metrics_list) metrics_df.reset_index(drop=True, inplace=True) - + with pd.ExcelWriter(self.output_file) as writer: missing_error_df.to_excel(writer, sheet_name="Missing_Error", index=False) metrics_df.to_excel(writer, sheet_name="Metrics", index=False) return missing_error_list, metrics_list, self.output_file - def get_page_filter_metrics(self): - prediction_df = pd.read_excel(self.prediction_file, sheet_name=self.prediction_sheet_name) - ground_truth_df = pd.read_excel(self.ground_truth_file, sheet_name="Sheet1") + def get_metrics(self): + prediction_df = pd.read_excel( + self.prediction_file, sheet_name=self.prediction_sheet_name + ) + ground_truth_df = pd.read_excel( + self.ground_truth_file, sheet_name=self.ground_truth_sheet_name + ) ground_truth_df = ground_truth_df[ground_truth_df["Checked"] == 1] tor_true = [] @@ -87,27 +90,61 @@ class Metrics: missing_error_list = [] data_point_list = ["tor", "ter", "ogc", "performance_fee"] - for index, row in ground_truth_df.iterrows(): - doc_id = row["doc_id"] - # get first row with the same doc_id - prediction_data = prediction_df[prediction_df["doc_id"] == doc_id].iloc[0] - for data_point in data_point_list: - true_data, pred_data, missing_error_data = self.get_true_pred_data( - doc_id, row, prediction_data, data_point - ) - if data_point == "tor": - tor_true.extend(true_data) - tor_pred.extend(pred_data) - elif data_point == "ter": - ter_true.extend(true_data) - ter_pred.extend(pred_data) - elif data_point == "ogc": - ogc_true.extend(true_data) - ogc_pred.extend(pred_data) - elif data_point == "performance_fee": - performance_fee_true.extend(true_data) - performance_fee_pred.extend(pred_data) - 
missing_error_list.append(missing_error_data) + if self.data_type == "page_filter": + for index, row in ground_truth_df.iterrows(): + doc_id = row["doc_id"] + # get first row with the same doc_id + prediction_data = prediction_df[prediction_df["doc_id"] == doc_id].iloc[ + 0 + ] + for data_point in data_point_list: + true_data, pred_data, missing_error_data = ( + self.get_page_filter_true_pred_data( + doc_id, row, prediction_data, data_point + ) + ) + if data_point == "tor": + tor_true.extend(true_data) + tor_pred.extend(pred_data) + elif data_point == "ter": + ter_true.extend(true_data) + ter_pred.extend(pred_data) + elif data_point == "ogc": + ogc_true.extend(true_data) + ogc_pred.extend(pred_data) + elif data_point == "performance_fee": + performance_fee_true.extend(true_data) + performance_fee_pred.extend(pred_data) + missing_error_list.append(missing_error_data) + else: + prediction_doc_id_list = prediction_df["doc_id"].unique().tolist() + ground_truth_doc_id_list = ground_truth_df["doc_id"].unique().tolist() + doc_id_list = list(set(prediction_doc_id_list + ground_truth_doc_id_list)) + # order by doc_id + doc_id_list.sort() + + for doc_id in doc_id_list: + prediction_data = prediction_df[prediction_df["doc_id"] == doc_id] + ground_truth_data = ground_truth_df[ground_truth_df["doc_id"] == doc_id] + for data_point in data_point_list: + true_data, pred_data, missing_error_data = ( + self.get_data_extraction_true_pred_data( + doc_id, ground_truth_data, prediction_data, data_point + ) + ) + if data_point == "tor": + tor_true.extend(true_data) + tor_pred.extend(pred_data) + elif data_point == "ter": + ter_true.extend(true_data) + ter_pred.extend(pred_data) + elif data_point == "ogc": + ogc_true.extend(true_data) + ogc_pred.extend(pred_data) + elif data_point == "performance_fee": + performance_fee_true.extend(true_data) + performance_fee_pred.extend(pred_data) + missing_error_list.extend(missing_error_data) metrics_list = [] for data_point in data_point_list: @@ -188,14 +225,17 @@ class Metrics: } ) return missing_error_list, metrics_list - + def get_support_number(self, true_data: list): # get the count which true_data is 1 return sum(true_data) - - def get_true_pred_data( - self, doc_id, ground_truth_data: pd.Series, prediction_data: pd.Series, data_point: str + def get_page_filter_true_pred_data( + self, + doc_id, + ground_truth_data: pd.Series, + prediction_data: pd.Series, + data_point: str, ): ground_truth_list = ground_truth_data[data_point] if isinstance(ground_truth_list, str): @@ -206,10 +246,14 @@ class Metrics: true_data = [] pred_data = [] - - - missing_error_data = {"doc_id": doc_id, "data_point": data_point, "missing": [], "error": []} - + + missing_error_data = { + "doc_id": doc_id, + "data_point": data_point, + "missing": [], + "error": [], + } + missing_data = [] error_data = [] @@ -217,7 +261,7 @@ class Metrics: true_data.append(1) pred_data.append(1) return true_data, pred_data, missing_error_data - + for prediction in prediction_list: if prediction in ground_truth_list: true_data.append(1) @@ -232,8 +276,115 @@ class Metrics: true_data.append(1) pred_data.append(0) missing_data.append(ground_truth) - missing_error_data = {"doc_id": doc_id, "data_point": data_point, "missing": missing_data, "error": error_data} - + missing_error_data = { + "doc_id": doc_id, + "data_point": data_point, + "missing": missing_data, + "error": error_data, + } + + return true_data, pred_data, missing_error_data + + def get_data_extraction_true_pred_data( + self, + doc_id, + ground_truth_data: 
pd.DataFrame, + prediction_data: pd.DataFrame, + data_point: str, + ): + dp_ground_truth = ground_truth_data[ + ground_truth_data["datapoint"] == data_point + ] + dp_prediction = prediction_data[prediction_data["datapoint"] == data_point] + + # add new column to store unique words for dp_ground_truth + dp_ground_truth["unique_words"] = dp_ground_truth["raw_name"].apply( + get_unique_words_text + ) + ground_truth_unique_words = dp_ground_truth["unique_words"].unique().tolist() + # add new column to store unique words for dp_prediction + dp_prediction["unique_words"] = dp_prediction["raw_name"].apply( + get_unique_words_text + ) + pred_unique_words = dp_prediction["unique_words"].unique().tolist() + + true_data = [] + pred_data = [] + + missing_error_data = [] + + if len(dp_ground_truth) == 0 and len(dp_prediction) == 0: + true_data.append(1) + pred_data.append(1) + return true_data, pred_data, missing_error_data + + for index, prediction in dp_prediction.iterrows(): + pred_page_index = prediction["page_index"] + pred_raw_name = prediction["raw_name"] + pred_unique_words = prediction["unique_words"] + pred_data_point_value = prediction["value"] + pred_investment_type = prediction["investment_type"] + + if pred_unique_words in ground_truth_unique_words: + # get the ground truth data with the same unique words + gt_data = dp_ground_truth[ + dp_ground_truth["unique_words"] == pred_unique_words + ].iloc[0] + gt_data_point_value = gt_data["value"] + if pred_data_point_value == gt_data_point_value: + true_data.append(1) + pred_data.append(1) + else: + true_data.append(0) + pred_data.append(1) + error_data = { + "doc_id": doc_id, + "data_point": data_point, + "page_index": pred_page_index, + "pred_raw_name": pred_raw_name, + "investment_type": pred_investment_type, + "error_type": "data value incorrect", + "error_value": pred_data_point_value, + "correct_value": gt_data_point_value, + } + missing_error_data.append(error_data) + else: + true_data.append(0) + pred_data.append(1) + error_data = { + "doc_id": doc_id, + "data_point": data_point, + "page_index": pred_page_index, + "pred_raw_name": pred_raw_name, + "investment_type": pred_investment_type, + "error_type": "raw name incorrect", + "error_value": pred_raw_name, + "correct_value": "", + } + missing_error_data.append(error_data) + + for index, ground_truth in dp_ground_truth.iterrows(): + gt_page_index = ground_truth["page_index"] + gt_raw_name = ground_truth["raw_name"] + gt_unique_words = ground_truth["unique_words"] + gt_data_point_value = ground_truth["value"] + gt_investment_type = ground_truth["investment_type"] + + if gt_unique_words not in pred_unique_words: + true_data.append(1) + pred_data.append(0) + error_data = { + "doc_id": doc_id, + "data_point": data_point, + "page_index": gt_page_index, + "pred_raw_name": "", + "investment_type": gt_investment_type, + "error_type": "raw name missing", + "error_value": pred_data_point_value, + "correct_value": gt_raw_name, + } + missing_error_data.append(error_data) + return true_data, pred_data, missing_error_data def get_specific_metrics(self, true_data: list, pred_data: list): diff --git a/instructions/data_extraction_prompts_config.json b/instructions/data_extraction_prompts_config.json index d5ce07d..094d9b4 100644 --- a/instructions/data_extraction_prompts_config.json +++ b/instructions/data_extraction_prompts_config.json @@ -8,10 +8,20 @@ }, "data_business_features": { "common": [ - "Most of cases, the data is in the table(s) of context.", - "Fund name: a. 
The full fund name should be main fund name + sub-fund name, e,g, main fund name is Black Rock European, sub-fund name is Growth, the full fund name is: Black Rock European Growth.\nb. The sub-fund name may be as the first column values in the table.", + "General rules:", + "- Most of cases, the data is in the table(s) of context.", + "- Fund name: ", + "a. The full fund name should be main fund name + sub-fund name, e,g, main fund name is Black Rock European, sub-fund name is Growth, the full fund name is: Black Rock European Growth.", + "b. The sub-fund name may be as the first column or first row values in the table.", + "b.1 fund name example:", + "- context:", + "Summary information\nCapital International Fund Audited Annual Report 2023 | 15\nFootnotes are on page 17.\nCapital Group Multi-Sector \nIncome Fund (LUX) \n(CGMSILU)\nCapital Group US High Yield \nFund (LUX) (CGUSHYLU)\nCapital Group Emerging \nMarkets Debt Fund (LUX) \n(CGEMDLU)", + "fund names: Capital International Group Multi-Sector Income Fund (LUX), Capital International Group US High Yield Fund (LUX), Capital International Group Emerging Markets Debt Fund (LUX)", + "- Only extract the latest data from context:", "If with multiple data values in same row, please extract the latest.", - "Only output the values which with significant reported names.\nPlease exclude below reported names and relevant values: \"Management Fees\", \"Management\", \"Management Fees p.a.\", \"Taxe d Abonnement in % p.a.\".\nDON'T EXTRACT MANAGEMENT FEES!", + "- Reported names:", + "Only output the values which with significant reported names.", + "Please exclude below reported names and relevant values: \"Management Fees\", \"Management\", \"Management Fees p.a.\", \"Taxe d Abonnement in % p.a.\".\nDON'T EXTRACT MANAGEMENT FEES!", "One fund could be with multiple share classes and relevant share class level data values." ], "investment_level": { @@ -106,7 +116,7 @@ "Only output the data point which with relevant value.", "Don't ignore the data point which with negative value, e.g. -0.12, -1.13", "Don't ignore the data point which with explicit zero value, e.g. 
0, 0.00", - "Ignore the data point which with -, N/A, N/A%, N/A %, NONE, etc.", + "Ignore the data point which value with -, *, **, N/A, N/A%, N/A %, NONE, etc.", "Fund level data: (\"fund name\" and \"TOR\") and share level data: (\"fund name\", \"share name\", \"ter\", \"performance fees\", \"ogc\") should be output separately.", "The output should be JSON format, the format is like below example(s):" ], diff --git a/main.py b/main.py index 8659cdb..24d3428 100644 --- a/main.py +++ b/main.py @@ -217,6 +217,7 @@ def batch_start_job( special_doc_id_list: list = None, re_run_extract_data: bool = False, re_run_mapping_data: bool = False, + force_save_total_data: bool = False, ): pdf_files = glob(pdf_folder + "*.pdf") doc_list = [] @@ -249,14 +250,14 @@ def batch_start_job( ) result_extract_data_list.extend(doc_data_from_gpt) result_mapping_data_list.extend(doc_mapping_data_list) - - if special_doc_id_list is None or len(special_doc_id_list) == 0: + + if force_save_total_data or (special_doc_id_list is None or len(special_doc_id_list) == 0): result_extract_data_df = pd.DataFrame(result_extract_data_list) result_extract_data_df.reset_index(drop=True, inplace=True) result_mappingdata_df = pd.DataFrame(result_mapping_data_list) result_mappingdata_df.reset_index(drop=True, inplace=True) - + logger.info(f"Saving extract data to {output_extract_data_total_folder}") os.makedirs(output_extract_data_total_folder, exist_ok=True) time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime()) @@ -283,7 +284,7 @@ def batch_start_job( result_extract_data_df.to_excel( writer, index=False, sheet_name="extract_data" ) - + def batch_filter_pdf_files( pdf_folder: str, @@ -505,10 +506,14 @@ if __name__ == "__main__": # doc_id = "476492237" # extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run) - special_doc_id_list = ["508854243"] + special_doc_id_list = [ + "525574973", + ] output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/" output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/" re_run_mapping_data = True + + force_save_total_data = False batch_start_job( pdf_folder, page_filter_ground_truth_file, @@ -519,4 +524,5 @@ if __name__ == "__main__": special_doc_id_list, re_run_extract_data, re_run_mapping_data, + force_save_total_data=force_save_total_data, ) diff --git a/prepare_data.py b/prepare_data.py index bb45188..95a7cf5 100644 --- a/prepare_data.py +++ b/prepare_data.py @@ -113,7 +113,11 @@ def analyze_json_error(): def statistics_document( - pdf_folder: str, doc_mapping_file_path: str, output_folder: str + pdf_folder: str, + doc_mapping_file_path: str, + sheet_name: str = "all_data", + output_folder: str = "/data/emea_ar/basic_information/English/", + output_file: str = "doc_mapping_statistics_data.xlsx" ): if pdf_folder is None or len(pdf_folder) == 0 or not os.path.exists(pdf_folder): logger.error(f"Invalid pdf_folder: {pdf_folder}") @@ -132,7 +136,7 @@ def statistics_document( describe_stat_df_list = [] # statistics document mapping information - doc_mapping_data = pd.read_excel(doc_mapping_file_path, sheet_name="all_data") + doc_mapping_data = pd.read_excel(doc_mapping_file_path, sheet_name=sheet_name) # statistics doc_mapping_data for counting FundId count based on DocumentId logger.info( @@ -172,15 +176,15 @@ def statistics_document( ) describe_stat_df_list.append(doc_share_class_count_stat_df) - # statistics doc_mapping_data for counting FundId count based on ProviderCompanyId and CompanyName + # statistics doc_mapping_data for counting 
FundId count based on CompanyId and CompanyName logger.info( - "statistics doc_mapping_data for counting FundId count based on ProviderCompanyId and CompanyName" + "statistics doc_mapping_data for counting FundId count based on CompanyId and CompanyName" ) provider_fund_id_df = doc_mapping_data[ - ["ProviderCompanyId", "CompanyName", "FundId"] + ["CompanyId", "CompanyName", "FundId"] ].drop_duplicates() provider_fund_count = ( - provider_fund_id_df.groupby(["ProviderCompanyId", "CompanyName"]) + provider_fund_id_df.groupby(["CompanyId", "CompanyName"]) .size() .reset_index(name="fund_count") ) @@ -194,15 +198,15 @@ def statistics_document( ) describe_stat_df_list.append(provider_fund_count_stat_df) - # statistics doc_mapping_data for counting FundClassId count based on ProviderCompanyId + # statistics doc_mapping_data for counting FundClassId count based on CompanyId logger.info( - "statistics doc_mapping_data for counting FundClassId count based on ProviderCompanyId" + "statistics doc_mapping_data for counting FundClassId count based on CompanyId" ) provider_share_class_id_df = doc_mapping_data[ - ["ProviderCompanyId", "CompanyName", "FundClassId"] + ["CompanyId", "CompanyName", "FundClassId"] ].drop_duplicates() provider_share_class_count = ( - provider_share_class_id_df.groupby(["ProviderCompanyId", "CompanyName"]) + provider_share_class_id_df.groupby(["CompanyId", "CompanyName"]) .size() .reset_index(name="share_class_count") ) @@ -238,13 +242,18 @@ def statistics_document( ) describe_stat_df_list.append(fund_share_class_count_stat_df) - stat_file = os.path.join(output_folder, "doc_mapping_statistics_data.xlsx") + stat_file = os.path.join(output_folder, output_file) + + doc_id_list = [str(docid) for docid in doc_mapping_data["DocumentId"].unique().tolist()] # statistics document page number pdf_files = glob(os.path.join(pdf_folder, "*.pdf")) logger.info(f"Total {len(pdf_files)} pdf files found in {pdf_folder}") logger.info("statistics document page number") doc_page_num_list = [] for pdf_file in tqdm(pdf_files): + pdf_base_name = os.path.basename(pdf_file).replace(".pdf", "") + if pdf_base_name not in doc_id_list: + continue docid = os.path.basename(pdf_file).split(".")[0] doc = fitz.open(pdf_file) page_num = doc.page_count @@ -829,6 +838,46 @@ def pickup_document_from_top_100_providers(): ) +def compare_records_count_by_document_id(): + data_from_document = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx" + sheet_name = "mapping_data" + data_from_document_df = pd.read_excel(data_from_document, sheet_name=sheet_name) + data_from_document_df.rename( + columns={"doc_id": "DocumentId"}, inplace=True + ) + # get the count of records by DocumentId + document_records_count = data_from_document_df.groupby("DocumentId").size().reset_index(name="records_count") + + data_from_database = r"/data/emea_ar/basic_information/English/lux_english_ar_top_100_provider_random_small_document.xlsx" + sheet_name = "random_small_document_all_data" + data_from_database_df = pd.read_excel(data_from_database, sheet_name=sheet_name) + database_records_count = data_from_database_df.groupby("DocumentId").size().reset_index(name="records_count") + + # merge document_records_count with database_records_count + records_count_compare = pd.merge( + document_records_count, + database_records_count, + on=["DocumentId"], + how="left", + ) + records_count_compare["records_count_diff"] = records_count_compare["records_count_x"] - records_count_compare["records_count_y"] + 
records_count_compare = records_count_compare.sort_values(by="records_count_diff", ascending=False) + # rename records_count_x to records_count_document, records_count_y to records_count_database + records_count_compare.rename( + columns={"records_count_x": "records_count_document", + "records_count_y": "records_count_database"}, inplace=True + ) + records_count_compare.reset_index(drop=True, inplace=True) + + records_count_compare_file = ( + r"/data/emea_ar/basic_information/English/records_count_compare_between_document_database.xlsx" + ) + with pd.ExcelWriter(records_count_compare_file) as writer: + records_count_compare.to_excel( + writer, sheet_name="records_count_compare", index=False + ) + + if __name__ == "__main__": doc_provider_file_path = ( r"/data/emea_ar/basic_information/English/latest_provider_ar_document.xlsx" @@ -845,22 +894,35 @@ if __name__ == "__main__": output_folder = r"/data/emea_ar/output/" # get_unique_docids_from_doc_provider_data(doc_provider_file_path) # download_pdf(doc_provider_file_path, 'doc_provider_count', pdf_folder) - pdf_folder = r"/data/emea_ar/small_pdf/" + # pdf_folder = r"/data/emea_ar/small_pdf/" output_folder = r"/data/emea_ar/small_pdf_txt/" random_small_document_data_file = ( r"/data/emea_ar/basic_information/English/lux_english_ar_top_100_provider_random_small_document.xlsx" ) + + # download_pdf(random_small_document_data_file, 'random_small_document', pdf_folder) # output_pdf_page_text(pdf_folder, output_folder) # extract_pdf_table(pdf_folder, output_folder) # analyze_json_error() - # statistics_document(pdf_folder, doc_mapping_file_path, basic_info_folder) + latest_top_100_provider_ar_data_file = r"/data/emea_ar/basic_information/English/top_100_provider_latest_document_most_mapping/lux_english_ar_from_top_100_provider_latest_document_with_most_mappings.xlsx" + # download_pdf(latest_top_100_provider_ar_data_file, + # 'latest_ar_document_most_mapping', + # pdf_folder) + + output_data_folder = r"/data/emea_ar/basic_information/English/top_100_provider_latest_document_most_mapping/" + statistics_document(pdf_folder=pdf_folder, + doc_mapping_file_path=latest_top_100_provider_ar_data_file, + sheet_name="latest_doc_ar_data", + output_folder=output_data_folder, + output_file="latest_doc_ar_mapping_statistics.xlsx") # statistics_provider_mapping( # provider_mapping_data_file=provider_mapping_data_file, # output_folder=basic_info_folder, # ) # statistics_document_fund_share_count(doc_mapping_from_top_100_provider_file) - pickup_document_from_top_100_providers() + # pickup_document_from_top_100_providers() + # compare_records_count_by_document_id() diff --git a/utils/biz_utils.py b/utils/biz_utils.py index 05d7dd5..073aefe 100644 --- a/utils/biz_utils.py +++ b/utils/biz_utils.py @@ -165,6 +165,16 @@ def remove_special_characters(text): text = text.strip() return text +def get_unique_words_text(text): + text = remove_special_characters(text) + text = text.lower() + text_split = text.split() + text_split = list(set(text_split)) + # sort the list + text_split.sort() + return_text = ' '.join(text_split) + return return_text + def remove_numeric_characters(text): # remove numeric characters
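
Note on the matching scheme used above (illustrative, not part of the patch itself): get_data_extraction_true_pred_data matches each predicted extraction row to a ground-truth row on a normalized "unique words" key built from the raw fund/share-class name, then feeds binary true/pred vectors into sklearn's precision, recall and F1. The sketch below is a minimal, self-contained rendering of that idea. The sample records and the unique_words_key helper are hypothetical, and the normalization is a simplified stand-in for utils.biz_utils.get_unique_words_text, which additionally strips special characters before splitting.

    from sklearn.metrics import precision_score, recall_score, f1_score

    def unique_words_key(text: str) -> str:
        # Lowercase, split into words, de-duplicate, sort and re-join, so that
        # "Global Growth Fund A Acc" and "A Acc Global Growth Fund" share a key.
        return " ".join(sorted(set(text.lower().split())))

    # Hypothetical ground-truth and predicted extraction rows for one document.
    ground_truth = [
        {"raw_name": "Global Growth Fund A Acc", "value": 0.85},
        {"raw_name": "Global Growth Fund I Inc", "value": 0.45},
    ]
    prediction = [
        {"raw_name": "A Acc Global Growth Fund", "value": 0.85},  # name and value match
        {"raw_name": "Global Growth Fund I Inc", "value": 0.40},  # value incorrect
    ]

    gt_by_key = {unique_words_key(r["raw_name"]): r for r in ground_truth}
    pred_keys = {unique_words_key(r["raw_name"]) for r in prediction}

    true_data, pred_data = [], []
    for row in prediction:
        # Predicted rows: correct name and value -> (1, 1); wrong name or value -> (0, 1).
        key = unique_words_key(row["raw_name"])
        ok = key in gt_by_key and gt_by_key[key]["value"] == row["value"]
        true_data.append(1 if ok else 0)
        pred_data.append(1)
    for key in gt_by_key:
        # Ground-truth rows never predicted count as misses -> (1, 0).
        if key not in pred_keys:
            true_data.append(1)
            pred_data.append(0)

    print(precision_score(true_data, pred_data),
          recall_score(true_data, pred_data),
          f1_score(true_data, pred_data))

Keying on the sorted set of lowercase words makes the comparison tolerant of word order, casing and spacing differences between the extracted raw names and the ground truth, at the cost of treating reordered but otherwise identical names as the same fund or share class.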