diff --git a/core/data_extraction.py b/core/data_extraction.py
index 18309b4..7b34530 100644
--- a/core/data_extraction.py
+++ b/core/data_extraction.py
@@ -191,6 +191,7 @@ class DataExtraction:
                 page_datapoints,
                 need_exclude=False,
                 exclude_data=None,
+                previous_page_last_fund=previous_page_fund_name
             )
             data_list.append(extract_data)
@@ -239,6 +240,7 @@ class DataExtraction:
                     next_datapoints,
                     need_exclude=True,
                     exclude_data=page_data_list,
+                    previous_page_last_fund=previous_page_fund_name
                 )
                 next_page_data_list = next_page_extract_data.get(
                     "extract_data", {}
@@ -373,7 +375,8 @@ class DataExtraction:
         page_text: str,
         page_datapoints: list,
         need_exclude: bool = False,
-        exclude_data: list = None,) -> dict:
+        exclude_data: list = None,
+        previous_page_last_fund: str = None) -> dict:
         # If can't find numberic value, e.g. 1.25 or 3,88
         # apply Vision ChatGPT to extract data
         numeric_regex = r"\d+(\.|\,)\d+"
@@ -383,7 +386,12 @@ class DataExtraction:
                 page_num, page_datapoints, need_exclude, exclude_data)
         else:
             return self.extract_data_by_page_text(
-                page_num, page_text, page_datapoints, need_exclude, exclude_data
+                page_num=page_num,
+                page_text=page_text,
+                page_datapoints=page_datapoints,
+                need_exclude=need_exclude,
+                exclude_data=exclude_data,
+                previous_page_last_fund=previous_page_last_fund
             )

     def extract_data_by_page_text(
@@ -393,6 +401,7 @@ class DataExtraction:
         page_datapoints: list,
         need_exclude: bool = False,
         exclude_data: list = None,
+        previous_page_last_fund: str = None
     ) -> dict:
         """
         keys are
@@ -435,7 +444,7 @@ class DataExtraction:
             data = json_repair.loads(response)
         except:
             data = {"data": []}
-        data = self.validate_data(data)
+        data = self.validate_data(data, previous_page_last_fund)

         data_dict = {"doc_id": self.doc_id}
         data_dict["page_index"] = page_num
@@ -488,7 +497,7 @@ class DataExtraction:
             data = json_repair.loads(response)
         except:
             data = {"data": []}
-        data = self.validate_data(data)
+        data = self.validate_data(data, None)

         data_dict = {"doc_id": self.doc_id}
         data_dict["page_index"] = page_num
@@ -500,7 +509,7 @@ class DataExtraction:
             data_dict["extract_way"] = "image"
         return data_dict

-    def validate_data(self, extract_data_info: dict) -> dict:
+    def validate_data(self, extract_data_info: dict, previous_page_last_fund: str = None) -> dict:
         """
         Validate data by the rules
         1. Each data should be with fund name
@@ -511,10 +520,18 @@ class DataExtraction:
             return extract_data_info
         remove_list = []
         for data in data_list:
-            fund_name = data.get("fund name", "")
+            fund_name = data.get("fund name", "").strip()
             if fund_name == "":
                 remove_list.append(data)
+            if previous_page_last_fund is not None and len(previous_page_last_fund) > 0:
+                previous_page_last_fund = previous_page_last_fund.strip()
+                if fund_name.startswith(previous_page_last_fund) and fund_name != previous_page_last_fund:
+                    modified_fund_name = fund_name.replace(previous_page_last_fund, "").strip()
+                    if len(modified_fund_name.split()) > 1:
+                        fund_name = modified_fund_name
+            fund_name = self.get_fund_name(fund_name, "Fund")
+            fund_name = self.get_fund_name(fund_name, "Bond")
             data["fund name"] = fund_name
             keys = list(data.keys())
             for key in keys:
@@ -838,7 +855,25 @@ class DataExtraction:
                     instructions.append(contents)
                     instructions.append("\n\n")
         instructions.append("\n")
-
+
+        # extreme_complex_config_list = special_cases.get("extreme_complex", [])
+        # if len(extreme_complex_config_list) > 0:
+        #     for extreme_complex_config in extreme_complex_config_list:
+        #         regex = extreme_complex_config.get("regex", "")
+        #         if len(regex) == 0:
+        #             continue
+        #         search = re.search(regex, page_text)
+        #         if search is not None:
+        #             title = extreme_complex_config.get("title", "")
+        #             title = f"{special_cases_number}. {title} "
+        #             special_cases_number += 1
+        #             instructions.append(title)
+        #             instructions.append("\n")
+        #             contents_list = extreme_complex_config.get("contents", [])
+        #             contents = "\n".join(contents_list)
+        #             instructions.append(contents)
+        #             instructions.append("\n\n")
+
         instructions.append("Output requirement:\n")
         output_requirement = self.instructions_config.get("output_requirement", {})
         output_requirement_common_list = output_requirement.get("common", [])
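Note on the validate_data change above: the new previous_page_last_fund argument strips a duplicated fund-name prefix when a table continues from the previous page. A minimal standalone sketch of that rule, for illustration only; the function name and the sample fund names are hypothetical, and the class's get_fund_name suffix normalisation is not reproduced here:

from typing import Optional


def trim_previous_page_prefix(fund_name: str, previous_page_last_fund: Optional[str]) -> str:
    """Drop a prefix inherited from the previous page's last fund name, as in validate_data."""
    fund_name = fund_name.strip()
    if previous_page_last_fund:
        previous_page_last_fund = previous_page_last_fund.strip()
        if fund_name.startswith(previous_page_last_fund) and fund_name != previous_page_last_fund:
            # str.replace removes every occurrence of the prefix, mirroring the diff.
            trimmed = fund_name.replace(previous_page_last_fund, "").strip()
            # Keep the trimmed name only if more than one word remains.
            if len(trimmed.split()) > 1:
                fund_name = trimmed
    return fund_name


# Illustrative input/output (hypothetical names):
# trim_previous_page_prefix("BlackRock Global Funds European Growth Fund", "BlackRock Global Funds")
# -> "European Growth Fund"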
diff --git a/instructions/data_extraction_prompts_config.json b/instructions/data_extraction_prompts_config.json
index cc6b238..8da7b39 100644
--- a/instructions/data_extraction_prompts_config.json
+++ b/instructions/data_extraction_prompts_config.json
@@ -35,16 +35,21 @@
         "a. The full fund name should be main fund name + sub-fund name, e,g, main fund name is Black Rock European, sub-fund name is Growth, the full fund name is: Black Rock European Growth.",
         "b. The sub-fund name may be as the first column or first row values in the table.",
         "b.1 fund name example:",
-        "---- context:",
+        "---- Example Start ----",
         "Summary information\nCapital International Fund Audited Annual Report 2023 | 15\nFootnotes are on page 17.\nCapital Group Multi-Sector \nIncome Fund (LUX) \n(CGMSILU)\nCapital Group US High Yield \nFund (LUX) (CGUSHYLU)\nCapital Group Emerging \nMarkets Debt Fund (LUX) \n(CGEMDLU)",
-        "---- fund names: Capital International Group Multi-Sector Income Fund (LUX), Capital International Group US High Yield Fund (LUX), Capital International Group Emerging Markets Debt Fund (LUX)",
+        "---- Example End ----",
+        "Fund names: Capital International Group Multi-Sector Income Fund (LUX), Capital International Group US High Yield Fund (LUX), Capital International Group Emerging Markets Debt Fund (LUX)",
+        "\n",
         "c. If with multiple fund names in context, please retrieve the fund name closest above the numerical value.",
         "c.1 fund name example:",
-        "---- context:",
+        "---- Example Start ----",
        "AXA World Funds ACT Emerging Markets Bonds\nAXA World Funds \n \nAdditional Unaudited Appendix \n\nƒ$GGLWLRQDO8QDXGLWHG$SSHQGL[$118$/5(3257$;$:RUOG)XQGV\nExpense Ratios (continued) \n \nCalculated TER (1) \nSwiss method \nApplied\nService Fee (2)\nOngoing \nCharges (3) \n \nwith performance \nfees \nwithout performance \nfees \n \nAXA World Funds - ACT Emerging Markets Short Duration Bonds Low Carbon \nA Capitalisation CHF Hedged \n1.26% \n1.26% \n0.26% \n1.29%",
-        "---- correct fund name: AXA World Funds - ACT Emerging Markets Short Duration Bonds Low Carbon",
+        "---- Example End ----",
+        "Correct fund name: AXA World Funds - ACT Emerging Markets Short Duration Bonds Low Carbon",
+        "\n",
         "- Only extract the latest data from context:",
         "If with multiple data values in same row, please extract the latest.",
+        "\n",
         "- Reported names:",
         "Only output the values which with significant reported names.",
         "Please exclude below reported names and relevant values: \"Management Fees\", \"Management\", \"Management Fees p.a.\", \"Taxe d Abonnement in % p.a.\".\nDON'T EXTRACT MANAGEMENT FEES!",
@@ -245,6 +250,41 @@
                     "The performance fees value is Ongoing Charges inkl. Performance-Fee in % **) - Ongoing Charges exkl. Performance-Fee in % **) = 1.20 - 1.15 = 0.05"
                 ]
             }
+        ],
+        "extreme_complex": [
+            {
+                "title": "Complex Data Table Structure",
+                "regex": "([A-Z]{1,2}\\,\\s?){3,}",
+                "contents": [
+                    "Complex Data Table Structure",
+                    "Table structure: the first column is fund name, for each table title, there are a lot of share class names in it.",
+                    "Please split these share class names and extract all of relevant data as fund name, share name, data point and value one by one from the table.",
+                    "-----Example Start-----",
+                    "Charges and expenses (continued) ",
+                    "d) Operating, Administrative and Servicing Expenses / Operating Currency Hedged Share Class Fees (continued)",
+                    "The following table shows the rates of Operating, Administrative and Servicing Expenses:",
+                    "Class A, B, E, ",
+                    "M,O ",
+                    "EQUITY SUB-FUNDS ",
+                    "a) Equity sub-funds ",
+                    "Fund 1",
+                    "0.35",
+                    "Fund 2",
+                    "0.26",
+                    "-----Example End-----",
+                    "The output should be:",
+                    "{\"data\": [{\"fund name\": \"Fund 1\", \"share name\": \"A\", \"ogc\": 0.35},",
+                    "{\"fund name\": \"Fund 1\", \"share name\": \"B\", \"ogc\": 0.35},",
+                    "{\"fund name\": \"Fund 1\", \"share name\": \"E\", \"ogc\": 0.35},",
+                    "{\"fund name\": \"Fund 1\", \"share name\": \"M\", \"ogc\": 0.35},",
+                    "{\"fund name\": \"Fund 1\", \"share name\": \"O\", \"ogc\": 0.35},",
+                    "{\"fund name\": \"Fund 2\", \"share name\": \"A\", \"ogc\": 0.26},",
+                    "{\"fund name\": \"Fund 2\", \"share name\": \"B\", \"ogc\": 0.26},",
+                    "{\"fund name\": \"Fund 2\", \"share name\": \"E\", \"ogc\": 0.26},",
+                    "{\"fund name\": \"Fund 2\", \"share name\": \"M\", \"ogc\": 0.26},",
+                    "{\"fund name\": \"Fund 2\", \"share name\": \"O\", \"ogc\": 0.26}]}"
+                ]
+            }
         ]
     },
     "output_requirement": {
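For reference, the new extreme_complex entry is keyed off the regex ([A-Z]{1,2}\,\s?){3,}, which fires on a run of short comma-separated share-class codes such as "A, B, E, M,". A small sketch of how that trigger behaves on the example text taken from the config (the page_text value below is just that example, trimmed):

import re

# Trigger pattern from the new "extreme_complex" config entry.
SHARE_CLASS_RUN = re.compile(r"([A-Z]{1,2}\,\s?){3,}")

page_text = "Class A, B, E, \nM,O \nEQUITY SUB-FUNDS \nFund 1\n0.35\nFund 2\n0.26"

if SHARE_CLASS_RUN.search(page_text):
    # When the pattern matches, the "Complex Data Table Structure" instructions
    # (including the worked Fund 1 / Fund 2 example) would be appended to the prompt,
    # which is what the currently commented-out branch in data_extraction.py would do
    # once enabled. The model is then asked to emit one record per share class.
    print("extreme_complex instructions apply to this page")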
__name__ == "__main__": - # new_data_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_15_documents_by_text_20241121154243.xlsx" - # original_data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx" - # replace_rerun_data(new_data_file, original_data_file) - # test_calculate_metrics() - # test_replace_abbrevation() - # test_translate_pdf() - # test_mapping_raw_name() - pdf_folder = r"/data/emea_ar/pdf/" - page_filter_ground_truth_file = ( - r"/data/emea_ar/ground_truth/page_filter/datapoint_page_info_88_documents.xlsx" - ) - prediction_output_folder = r"/data/emea_ar/output/filter_pages/" - metrics_output_folder = r"/data/emea_ar/output/metrics/" - special_doc_id_list = [] - # batch_filter_pdf_files( - # pdf_folder, page_filter_ground_truth_file, prediction_output_folder, special_doc_id_list - # ) - - # data_type = "page_filter" - # prediction_file = r"/data/emea_ar/output/filter_pages/datapoint_page_info_73_documents_20240903145002.xlsx" - # missing_error_list, metrics_list, metrics_file = get_metrics( - # data_type, prediction_file, page_filter_ground_truth_file, metrics_output_folder - # ) - - # test_auto_generate_instructions() - - output_extract_data_child_folder = r"/data/emea_ar/output/extract_data/docs/" - output_extract_data_total_folder = r"/data/emea_ar/output/extract_data/total/" - - # batch_extract_data( - # pdf_folder, - # page_filter_ground_truth_file, - # output_extract_data_child_folder, - # output_extract_data_total_folder, - # special_doc_id_list, - # re_run, - # ) - - # doc_id = "476492237" - # extract_way = "image" - # extract_data(doc_id, - # pdf_folder, - # output_extract_data_child_folder, - # extract_way, - # re_run_extract_data) +def batch_run_documents(): # special_doc_id_list = ["505174428", "510326848", "349679479"] # check_mapping_doc_id_list = [ # "327956364", @@ -1197,7 +1151,13 @@ if __name__ == "__main__": "534535767" ] special_doc_id_list = check_db_mapping_doc_id_list - special_doc_id_list = ["532998065"] + special_doc_id_list = ["534535767"] + pdf_folder = r"/data/emea_ar/pdf/" + page_filter_ground_truth_file = ( + r"/data/emea_ar/ground_truth/page_filter/datapoint_page_info_88_documents.xlsx" + ) + output_extract_data_child_folder = r"/data/emea_ar/output/extract_data/docs/" + output_extract_data_total_folder = r"/data/emea_ar/output/extract_data/total/" output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/" output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/" re_run_extract_data = True @@ -1206,8 +1166,6 @@ if __name__ == "__main__": calculate_metrics = False extract_ways = ["text"] - # pdf_folder = r"/data/emea_ar/small_pdf/" - pdf_folder = r"/data/emea_ar/pdf/" for extract_way in extract_ways: batch_start_job( pdf_folder, @@ -1223,6 +1181,50 @@ if __name__ == "__main__": force_save_total_data=force_save_total_data, calculate_metrics=calculate_metrics, ) + +if __name__ == "__main__": + batch_run_documents() + + # new_data_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_15_documents_by_text_20241121154243.xlsx" + # original_data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx" + # replace_rerun_data(new_data_file, original_data_file) + # test_calculate_metrics() + # test_replace_abbrevation() + # test_translate_pdf() + # test_mapping_raw_name() # test_data_extraction_metrics() + # batch_filter_pdf_files( + # 
pdf_folder, page_filter_ground_truth_file, prediction_output_folder, special_doc_id_list + # ) + + # data_type = "page_filter" + # prediction_file = r"/data/emea_ar/output/filter_pages/datapoint_page_info_73_documents_20240903145002.xlsx" + # missing_error_list, metrics_list, metrics_file = get_metrics( + # data_type, prediction_file, page_filter_ground_truth_file, metrics_output_folder + # ) + + # test_auto_generate_instructions() + + # batch_extract_data( + # pdf_folder, + # page_filter_ground_truth_file, + # output_extract_data_child_folder, + # output_extract_data_total_folder, + # special_doc_id_list, + # re_run, + # ) + + # doc_id = "476492237" + # extract_way = "image" + # extract_data(doc_id, + # pdf_folder, + # output_extract_data_child_folder, + # extract_way, + # re_run_extract_data) + + + + + diff --git a/specific_calc_metrics.py b/specific_calc_metrics.py new file mode 100644 index 0000000..735fc71 --- /dev/null +++ b/specific_calc_metrics.py @@ -0,0 +1,143 @@ +from tqdm import tqdm +from glob import glob +import json +import pandas as pd +import os +from traceback import print_exc +from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score + +from utils.logger import logger + + +def calculate_complex_document_metrics(verify_file_path: str, document_list: list = []): + data_df = pd.read_excel(verify_file_path, sheet_name="data_in_doc_mapping") + # convert doc_id column to string + data_df["doc_id"] = data_df["doc_id"].astype(str) + data_df = data_df[data_df["raw_check"].isin([0, 1])] + + if document_list is not None and len(document_list) > 0: + data_df = data_df[data_df["doc_id"].isin(document_list)] + + data_df.fillna("", inplace=True) + data_df.reset_index(drop=True, inplace=True) + + # tor data + tor_data_df = data_df[data_df["datapoint"] == "tor"] + tor_metrics = get_sub_metrics(tor_data_df, "tor") + logger.info(f"TOR metrics: {tor_metrics}") + + # ter data + ter_data_df = data_df[data_df["datapoint"] == "ter"] + ter_metrics = get_sub_metrics(ter_data_df, "ter") + logger.info(f"TER metrics: {ter_metrics}") + + # ogc data + ogc_data_df = data_df[data_df["datapoint"] == "ogc"] + ogc_metrics = get_sub_metrics(ogc_data_df, "ogc") + logger.info(f"OGC metrics: {ogc_metrics}") + + # performance_fee data + performance_fee_data_df = data_df[data_df["datapoint"] == "performance_fee"] + performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee") + logger.info(f"Performance fee metrics: {performance_fee_metrics}") + + metrics_df = pd.DataFrame([tor_metrics, ter_metrics, ogc_metrics, performance_fee_metrics]) + # add average metrics + avg_metrics = { + "DataPoint": "average", + "F1": metrics_df["F1"].mean(), + "Precision": metrics_df["Precision"].mean(), + "Recall": metrics_df["Recall"].mean(), + "Accuracy": metrics_df["Accuracy"].mean(), + "Support": metrics_df["Support"].sum() + } + + metrics_df = pd.DataFrame([tor_metrics, ter_metrics, + ogc_metrics, performance_fee_metrics, + avg_metrics]) + metrics_df.reset_index(drop=True, inplace=True) + + output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/" + + document_count = len(document_list) \ + if document_list is not None and len(document_list) > 0 \ + else len(data_df["doc_id"].unique()) + + output_metrics_file = os.path.join(output_folder, + f"complex_document_{document_count}_metrics.xlsx") + with pd.ExcelWriter(output_metrics_file) as writer: + metrics_df.to_excel(writer, index=False, sheet_name="metrics") + + +def get_sub_metrics(data_df: pd.DataFrame, 
diff --git a/specific_calc_metrics.py b/specific_calc_metrics.py
new file mode 100644
index 0000000..735fc71
--- /dev/null
+++ b/specific_calc_metrics.py
@@ -0,0 +1,143 @@
+from tqdm import tqdm
+from glob import glob
+import json
+import pandas as pd
+import os
+from traceback import print_exc
+from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
+
+from utils.logger import logger
+
+
+def calculate_complex_document_metrics(verify_file_path: str, document_list: list = []):
+    data_df = pd.read_excel(verify_file_path, sheet_name="data_in_doc_mapping")
+    # convert doc_id column to string
+    data_df["doc_id"] = data_df["doc_id"].astype(str)
+    data_df = data_df[data_df["raw_check"].isin([0, 1])]
+
+    if document_list is not None and len(document_list) > 0:
+        data_df = data_df[data_df["doc_id"].isin(document_list)]
+
+    data_df.fillna("", inplace=True)
+    data_df.reset_index(drop=True, inplace=True)
+
+    # tor data
+    tor_data_df = data_df[data_df["datapoint"] == "tor"]
+    tor_metrics = get_sub_metrics(tor_data_df, "tor")
+    logger.info(f"TOR metrics: {tor_metrics}")
+
+    # ter data
+    ter_data_df = data_df[data_df["datapoint"] == "ter"]
+    ter_metrics = get_sub_metrics(ter_data_df, "ter")
+    logger.info(f"TER metrics: {ter_metrics}")
+
+    # ogc data
+    ogc_data_df = data_df[data_df["datapoint"] == "ogc"]
+    ogc_metrics = get_sub_metrics(ogc_data_df, "ogc")
+    logger.info(f"OGC metrics: {ogc_metrics}")
+
+    # performance_fee data
+    performance_fee_data_df = data_df[data_df["datapoint"] == "performance_fee"]
+    performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee")
+    logger.info(f"Performance fee metrics: {performance_fee_metrics}")
+
+    metrics_df = pd.DataFrame([tor_metrics, ter_metrics, ogc_metrics, performance_fee_metrics])
+    # add average metrics
+    avg_metrics = {
+        "DataPoint": "average",
+        "F1": metrics_df["F1"].mean(),
+        "Precision": metrics_df["Precision"].mean(),
+        "Recall": metrics_df["Recall"].mean(),
+        "Accuracy": metrics_df["Accuracy"].mean(),
+        "Support": metrics_df["Support"].sum()
+    }
+
+    metrics_df = pd.DataFrame([tor_metrics, ter_metrics,
+                               ogc_metrics, performance_fee_metrics,
+                               avg_metrics])
+    metrics_df.reset_index(drop=True, inplace=True)
+
+    output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
+
+    document_count = len(document_list) \
+        if document_list is not None and len(document_list) > 0 \
+        else len(data_df["doc_id"].unique())
+
+    output_metrics_file = os.path.join(output_folder,
+                                       f"complex_document_{document_count}_metrics.xlsx")
+    with pd.ExcelWriter(output_metrics_file) as writer:
+        metrics_df.to_excel(writer, index=False, sheet_name="metrics")
+
+
+def get_sub_metrics(data_df: pd.DataFrame,
+                    data_point: str) -> dict:
+    data_df_raw_check_1 = data_df[data_df["raw_check"] == 1]
+    gt_list = [1] * len(data_df_raw_check_1)
+    pre_list = [1] * len(data_df_raw_check_1)
+
+    data_df_raw_check_0 = data_df[data_df["raw_check"] == 0]
+    for index, row in data_df_raw_check_0.iterrows():
+        if row["raw_check_comment"] == "modify":
+            gt_list.append(0)
+            pre_list.append(1)
+
+            gt_list.append(1)
+            pre_list.append(0)
+        elif row["raw_check_comment"] == "incorrect":
+            gt_list.append(0)
+            pre_list.append(1)
+        elif row["raw_check_comment"] == "supplement":
+            gt_list.append(1)
+            pre_list.append(0)
+        else:
+            pass
+
+    # calculate metrics
+    accuracy = accuracy_score(gt_list, pre_list)
+    precision = precision_score(gt_list, pre_list)
+    recall = recall_score(gt_list, pre_list)
+    f1 = f1_score(gt_list, pre_list)
+    support = sum(gt_list)
+
+    metrics = {
+        "DataPoint": data_point,
+        "F1": f1,
+        "Precision": precision,
+        "Recall": recall,
+        "Accuracy": accuracy,
+        "Support": support
+    }
+    return metrics
+
+
+if __name__ == "__main__":
+    file_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
+    verify_file = "mapping_data_info_31_documents_by_text_first_round.xlsx"
+    verify_file_path = os.path.join(file_folder, verify_file)
+    document_list = [
+        "334584772",
+        "337293427",
+        "337937633",
+        "404712928",
+        "406913630",
+        "407275419",
+        "422686965",
+        "422760148",
+        "422760156",
+        "422761666",
+        "423364758",
+        "423365707",
+        "423395975",
+        "423418395",
+        "423418540",
+        "425595958",
+        "451063582",
+        "451878128",
+        "466580448",
+        "481482392",
+        "508704368",
+        "532998065",
+        "536344026",
+        "540307575"
+    ]
+    calculate_complex_document_metrics(verify_file_path=verify_file_path,
+                                       document_list=document_list)
\ No newline at end of file
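A quick, illustrative way to sanity-check the annotation encoding in get_sub_metrics (raw_check == 1 counts as a true positive; a "modify" comment contributes one false positive plus one false negative; "incorrect" only a false positive; "supplement" only a false negative). The toy rows below are made up, and the snippet assumes it runs from the repository root so utils.logger resolves when the module is imported:

import pandas as pd

from specific_calc_metrics import get_sub_metrics

# Hypothetical verification rows: two confirmed values, one modified value,
# one incorrect extraction and one missed ("supplement") value.
toy_df = pd.DataFrame(
    {
        "raw_check": [1, 1, 0, 0, 0],
        "raw_check_comment": ["", "", "modify", "incorrect", "supplement"],
    }
)

# Expected: TP=2, FP=2, FN=2 -> precision 0.5, recall 0.5, F1 0.5, support 4.
print(get_sub_metrics(toy_df, "ogc"))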