Fix issue for "The last fund name of previous PDF page" logic:

If current page fund name starts with "The last fund name of previous PDF page" and with more contents below, then remove "The last fund name of previous PDF page".
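In effect (a minimal sketch of the rule; fund names here are invented):

previous_page_last_fund = "BlackRock Global Funds"    # last fund name seen on the prior page
fund_name = "BlackRock Global Funds World Gold Fund"  # first fund name on the current page
if fund_name.startswith(previous_page_last_fund) and fund_name != previous_page_last_fund:
    remainder = fund_name.replace(previous_page_last_fund, "").strip()
    if len(remainder.split()) > 1:  # only strip when real content follows the prefix
        fund_name = remainder       # -> "World Gold Fund"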
Blade He 2024-12-04 16:57:52 -06:00
parent 36fbaa946e
commit 70362b554f
4 changed files with 281 additions and 61 deletions


@@ -191,6 +191,7 @@ class DataExtraction:
page_datapoints,
need_exclude=False,
exclude_data=None,
previous_page_last_fund=previous_page_fund_name
)
data_list.append(extract_data)
@@ -239,6 +240,7 @@ class DataExtraction:
next_datapoints,
need_exclude=True,
exclude_data=page_data_list,
previous_page_last_fund=previous_page_fund_name
)
next_page_data_list = next_page_extract_data.get(
"extract_data", {}
@@ -373,7 +375,8 @@ class DataExtraction:
page_text: str,
page_datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,) -> dict:
exclude_data: list = None,
previous_page_last_fund: str = None) -> dict:
# If no numeric value (e.g. 1.25 or 3,88) can be found in the page text,
# apply Vision ChatGPT to extract the data
numeric_regex = r"\d+(\.|\,)\d+"
@@ -383,7 +386,12 @@ class DataExtraction:
page_num, page_datapoints, need_exclude, exclude_data)
else:
return self.extract_data_by_page_text(
page_num, page_text, page_datapoints, need_exclude, exclude_data
page_num=page_num,
page_text=page_text,
page_datapoints=page_datapoints,
need_exclude=need_exclude,
exclude_data=exclude_data,
previous_page_last_fund=previous_page_last_fund
)
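For reference, the branch above keys off the numeric regex; a standalone sketch of the check (sample strings invented):

import re

numeric_regex = r"\d+(\.|\,)\d+"  # matches 1.25 as well as 3,88
re.search(numeric_regex, "TER 1.25%") is not None      # True  -> extract from page text
re.search(numeric_regex, "see next page") is not None  # False -> fall back to Vision ChatGPT on the page image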
def extract_data_by_page_text(
@@ -393,6 +401,7 @@ class DataExtraction:
page_datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,
previous_page_last_fund: str = None
) -> dict:
"""
keys are
@@ -435,7 +444,7 @@ class DataExtraction:
data = json_repair.loads(response)
except:
data = {"data": []}
data = self.validate_data(data)
data = self.validate_data(data, previous_page_last_fund)
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
@@ -488,7 +497,7 @@ class DataExtraction:
data = json_repair.loads(response)
except:
data = {"data": []}
data = self.validate_data(data)
data = self.validate_data(data, None)
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
@@ -500,7 +509,7 @@ class DataExtraction:
data_dict["extract_way"] = "image"
return data_dict
def validate_data(self, extract_data_info: dict) -> dict:
def validate_data(self, extract_data_info: dict, previous_page_last_fund: str=None) -> dict:
"""
Validate data by the rules
1. Each record should include a fund name
@@ -511,10 +520,18 @@
return extract_data_info
remove_list = []
for data in data_list:
fund_name = data.get("fund name", "")
fund_name = data.get("fund name", "").strip()
if fund_name == "":
remove_list.append(data)
if previous_page_last_fund is not None and len(previous_page_last_fund) > 0:
previous_page_last_fund = previous_page_last_fund.strip()
if fund_name.startswith(previous_page_last_fund) and fund_name != previous_page_last_fund:
modified_fund_name = fund_name.replace(previous_page_last_fund, "").strip()
if len(modified_fund_name.split()) > 1:
fund_name = modified_fund_name
fund_name = self.get_fund_name(fund_name, "Fund")
fund_name = self.get_fund_name(fund_name, "Bond")
data["fund name"] = fund_name
keys = list(data.keys())
for key in keys:
@@ -839,6 +856,24 @@ class DataExtraction:
instructions.append("\n\n")
instructions.append("\n")
# extreme_complex_config_list = special_cases.get("extreme_complex", [])
# if len(extreme_complex_config_list) > 0:
# for extreme_complex_config in extreme_complex_config_list:
# regex = extreme_complex_config.get("regex", "")
# if len(regex) == 0:
# continue
# search = re.search(regex, page_text)
# if search is not None:
# title = extreme_complex_config.get("title", "")
# title = f"{special_cases_number}. {title} "
# special_cases_number += 1
# instructions.append(title)
# instructions.append("\n")
# contents_list = extreme_complex_config.get("contents", [])
# contents = "\n".join(contents_list)
# instructions.append(contents)
# instructions.append("\n\n")
instructions.append("Output requirement:\n")
output_requirement = self.instructions_config.get("output_requirement", {})
output_requirement_common_list = output_requirement.get("common", [])
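Taken together, the new parameter threads from extract_data down to validate_data. A hypothetical call (the extractor instance and record values are assumptions):

# `extractor` is assumed to be an already-configured DataExtraction instance
extract_data_info = {"data": [
    {"fund name": "", "ter": 0.95},  # rule 1: no fund name, queued for removal
    {"fund name": "Alpha Fund Beta Bond Series", "ter": 1.10},
]}
validated = extractor.validate_data(extract_data_info, previous_page_last_fund="Alpha Fund")
# the second record's fund name becomes "Beta Bond Series": the prefix matches
# and more than one word remains after stripping, so the guard permits the rewrite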


@@ -35,16 +35,21 @@
"a. The full fund name should be main fund name + sub-fund name, e,g, main fund name is Black Rock European, sub-fund name is Growth, the full fund name is: Black Rock European Growth.",
"b. The sub-fund name may be as the first column or first row values in the table.",
"b.1 fund name example:",
"---- context:",
"---- Example Start ----",
"Summary information\nCapital International Fund Audited Annual Report 2023 | 15\nFootnotes are on page 17.\nCapital Group Multi-Sector \nIncome Fund (LUX) \n(CGMSILU)\nCapital Group US High Yield \nFund (LUX) (CGUSHYLU)\nCapital Group Emerging \nMarkets Debt Fund (LUX) \n(CGEMDLU)",
"---- fund names: Capital International Group Multi-Sector Income Fund (LUX), Capital International Group US High Yield Fund (LUX), Capital International Group Emerging Markets Debt Fund (LUX)",
"---- Example End ----",
"Fund names: Capital International Group Multi-Sector Income Fund (LUX), Capital International Group US High Yield Fund (LUX), Capital International Group Emerging Markets Debt Fund (LUX)",
"\n",
"c. If with multiple fund names in context, please retrieve the fund name closest above the numerical value.",
"c.1 fund name example:",
"---- context:",
"---- Example Start ----",
"AXA World Funds ACT Emerging Markets Bonds\nAXA World Funds \n \nAdditional Unaudited Appendix \n\nƒ$GGLWLRQDO8QDXGLWHG$SSHQGL[$118$/5(3257$;$:RUOG)XQGV\nExpense Ratios (continued) \n \nCalculated TER (1) \nSwiss method \nApplied\nService Fee (2)\nOngoing \nCharges (3) \n \nwith performance \nfees \nwithout performance \nfees \n \nAXA World Funds - ACT Emerging Markets Short Duration Bonds Low Carbon \nA Capitalisation CHF Hedged \n1.26% \n1.26% \n0.26% \n1.29%",
"---- correct fund name: AXA World Funds - ACT Emerging Markets Short Duration Bonds Low Carbon",
"---- Example End ----",
"Correct fund name: AXA World Funds - ACT Emerging Markets Short Duration Bonds Low Carbon",
"\n",
"- Only extract the latest data from context:",
"If with multiple data values in same row, please extract the latest.",
"\n",
"- Reported names:",
"Only output the values which with significant reported names.",
"Please exclude below reported names and relevant values: \"Management Fees\", \"Management\", \"Management Fees p.a.\", \"Taxe d Abonnement in % p.a.\".\nDON'T EXTRACT MANAGEMENT FEES!",
@@ -245,6 +250,41 @@
"The performance fees value is Ongoing Charges inkl. Performance-Fee in % **) - Ongoing Charges exkl. Performance-Fee in % **) = 1.20 - 1.15 = 0.05"
]
}
],
"extreme_complex": [
{
"title": "Complex Data Table Structure",
"regex": "([A-Z]{1,2}\\,\\s?){3,}",
"contents": [
"Complex Data Table Structure",
"Table structure: the first column is fund name, for each table title, there are a lot of share class names in it.",
"Please split these share class names and extract all of relevant data as fund name, share name, data point and value one by one from the table.",
"-----Example Start-----",
"Charges and expenses (continued) ",
"d) Operating, Administrative and Servicing Expenses / Operating Currency Hedged Share Class Fees (continued)",
"The following table shows the rates of Operating, Administrative and Servicing Expenses:",
"Class A, B, E, ",
"M,O ",
"EQUITY SUB-FUNDS ",
"a) Equity sub-funds ",
"Fund 1",
"0.35",
"Fund 2",
"0.26",
"-----Example End-----",
"The output should be:",
"{\"data\": [{\"fund name\": \"Fund 1\", \"share name\": \"A\", \"ogc\": 0.35},",
"{\"fund name\": \"Fund 1\", \"share name\": \"B\", \"ogc\": 0.35},",
"{\"fund name\": \"Fund 1\", \"share name\": \"E\", \"ogc\": 0.35},",
"{\"fund name\": \"Fund 1\", \"share name\": \"M\", \"ogc\": 0.35},",
"{\"fund name\": \"Fund 1\", \"share name\": \"O\", \"ogc\": 0.35}",
"{\"fund name\": \"Fund 2\", \"share name\": \"A\", \"ogc\": 0.26},",
"{\"fund name\": \"Fund 2\", \"share name\": \"B\", \"ogc\": 0.26},",
"{\"fund name\": \"Fund 2\", \"share name\": \"E\", \"ogc\": 0.26},",
"{\"fund name\": \"Fund 2\", \"share name\": \"M\", \"ogc\": 0.26},",
"{\"fund name\": \"Fund 2\", \"share name\": \"O\", \"ogc\": 0.26}]}"
]
}
]
},
"output_requirement": {

main.py

@@ -857,53 +857,7 @@ def replace_rerun_data(new_data_file: str, original_data_file: str):
new_extract_data.to_excel(writer, index=False, sheet_name=extract_data_sheet)
if __name__ == "__main__":
# new_data_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_15_documents_by_text_20241121154243.xlsx"
# original_data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
# replace_rerun_data(new_data_file, original_data_file)
# test_calculate_metrics()
# test_replace_abbrevation()
# test_translate_pdf()
# test_mapping_raw_name()
pdf_folder = r"/data/emea_ar/pdf/"
page_filter_ground_truth_file = (
r"/data/emea_ar/ground_truth/page_filter/datapoint_page_info_88_documents.xlsx"
)
prediction_output_folder = r"/data/emea_ar/output/filter_pages/"
metrics_output_folder = r"/data/emea_ar/output/metrics/"
special_doc_id_list = []
# batch_filter_pdf_files(
# pdf_folder, page_filter_ground_truth_file, prediction_output_folder, special_doc_id_list
# )
# data_type = "page_filter"
# prediction_file = r"/data/emea_ar/output/filter_pages/datapoint_page_info_73_documents_20240903145002.xlsx"
# missing_error_list, metrics_list, metrics_file = get_metrics(
# data_type, prediction_file, page_filter_ground_truth_file, metrics_output_folder
# )
# test_auto_generate_instructions()
output_extract_data_child_folder = r"/data/emea_ar/output/extract_data/docs/"
output_extract_data_total_folder = r"/data/emea_ar/output/extract_data/total/"
# batch_extract_data(
# pdf_folder,
# page_filter_ground_truth_file,
# output_extract_data_child_folder,
# output_extract_data_total_folder,
# special_doc_id_list,
# re_run,
# )
# doc_id = "476492237"
# extract_way = "image"
# extract_data(doc_id,
# pdf_folder,
# output_extract_data_child_folder,
# extract_way,
# re_run_extract_data)
def batch_run_documents():
# special_doc_id_list = ["505174428", "510326848", "349679479"]
# check_mapping_doc_id_list = [
# "327956364",
@@ -1197,7 +1151,13 @@ if __name__ == "__main__":
"534535767"
]
special_doc_id_list = check_db_mapping_doc_id_list
special_doc_id_list = ["532998065"]
special_doc_id_list = ["534535767"]
pdf_folder = r"/data/emea_ar/pdf/"
page_filter_ground_truth_file = (
r"/data/emea_ar/ground_truth/page_filter/datapoint_page_info_88_documents.xlsx"
)
output_extract_data_child_folder = r"/data/emea_ar/output/extract_data/docs/"
output_extract_data_total_folder = r"/data/emea_ar/output/extract_data/total/"
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
re_run_extract_data = True
@@ -1206,8 +1166,6 @@ if __name__ == "__main__":
calculate_metrics = False
extract_ways = ["text"]
# pdf_folder = r"/data/emea_ar/small_pdf/"
pdf_folder = r"/data/emea_ar/pdf/"
for extract_way in extract_ways:
batch_start_job(
pdf_folder,
@@ -1224,5 +1182,49 @@ if __name__ == "__main__":
calculate_metrics=calculate_metrics,
)
if __name__ == "__main__":
batch_run_documents()
# new_data_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_15_documents_by_text_20241121154243.xlsx"
# original_data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
# replace_rerun_data(new_data_file, original_data_file)
# test_calculate_metrics()
# test_replace_abbrevation()
# test_translate_pdf()
# test_mapping_raw_name()
# test_data_extraction_metrics()
# batch_filter_pdf_files(
# pdf_folder, page_filter_ground_truth_file, prediction_output_folder, special_doc_id_list
# )
# data_type = "page_filter"
# prediction_file = r"/data/emea_ar/output/filter_pages/datapoint_page_info_73_documents_20240903145002.xlsx"
# missing_error_list, metrics_list, metrics_file = get_metrics(
# data_type, prediction_file, page_filter_ground_truth_file, metrics_output_folder
# )
# test_auto_generate_instructions()
# batch_extract_data(
# pdf_folder,
# page_filter_ground_truth_file,
# output_extract_data_child_folder,
# output_extract_data_total_folder,
# special_doc_id_list,
# re_run,
# )
# doc_id = "476492237"
# extract_way = "image"
# extract_data(doc_id,
# pdf_folder,
# output_extract_data_child_folder,
# extract_way,
# re_run_extract_data)

specific_calc_metrics.py (new file)

@@ -0,0 +1,143 @@
from tqdm import tqdm
from glob import glob
import json
import pandas as pd
import os
from traceback import print_exc
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from utils.logger import logger
def calculate_complex_document_metrics(verify_file_path: str, document_list: list = []):
data_df = pd.read_excel(verify_file_path, sheet_name="data_in_doc_mapping")
# convert doc_id column to string
data_df["doc_id"] = data_df["doc_id"].astype(str)
data_df = data_df[data_df["raw_check"].isin([0, 1])]
if document_list is not None and len(document_list) > 0:
data_df = data_df[data_df["doc_id"].isin(document_list)]
data_df.fillna("", inplace=True)
data_df.reset_index(drop=True, inplace=True)
# tor data
tor_data_df = data_df[data_df["datapoint"] == "tor"]
tor_metrics = get_sub_metrics(tor_data_df, "tor")
logger.info(f"TOR metrics: {tor_metrics}")
# ter data
ter_data_df = data_df[data_df["datapoint"] == "ter"]
ter_metrics = get_sub_metrics(ter_data_df, "ter")
logger.info(f"TER metrics: {ter_metrics}")
# ogc data
ogc_data_df = data_df[data_df["datapoint"] == "ogc"]
ogc_metrics = get_sub_metrics(ogc_data_df, "ogc")
logger.info(f"OGC metrics: {ogc_metrics}")
# performance_fee data
performance_fee_data_df = data_df[data_df["datapoint"] == "performance_fee"]
performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee")
logger.info(f"Performance fee metrics: {performance_fee_metrics}")
metrics_df = pd.DataFrame([tor_metrics, ter_metrics, ogc_metrics, performance_fee_metrics])
# add average metrics
avg_metrics = {
"DataPoint": "average",
"F1": metrics_df["F1"].mean(),
"Precision": metrics_df["Precision"].mean(),
"Recall": metrics_df["Recall"].mean(),
"Accuracy": metrics_df["Accuracy"].mean(),
"Support": metrics_df["Support"].sum()
}
metrics_df = pd.DataFrame([tor_metrics, ter_metrics,
ogc_metrics, performance_fee_metrics,
avg_metrics])
metrics_df.reset_index(drop=True, inplace=True)
output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
document_count = len(document_list) \
if document_list is not None and len(document_list) > 0 \
else len(data_df["doc_id"].unique())
output_metrics_file = os.path.join(output_folder,
f"complex_document_{document_count}_metrics.xlsx")
with pd.ExcelWriter(output_metrics_file) as writer:
metrics_df.to_excel(writer, index=False, sheet_name="metrics")
def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict:
data_df_raw_check_1 = data_df[data_df["raw_check"] == 1]
gt_list = [1] * len(data_df_raw_check_1)
pre_list = [1] * len(data_df_raw_check_1)
data_df_raw_check_0 = data_df[data_df["raw_check"] == 0]
for index, row in data_df_raw_check_0.iterrows():
if row["raw_check_comment"] == "modify":
gt_list.append(0)
pre_list.append(1)
gt_list.append(1)
pre_list.append(0)
elif row["raw_check_comment"] == "incorrect":
gt_list.append(0)
pre_list.append(1)
elif row["raw_check_comment"] == "supplement":
gt_list.append(1)
pre_list.append(0)
else:
pass
# calculate metrics
accuracy = accuracy_score(gt_list, pre_list)
precision = precision_score(gt_list, pre_list)
recall = recall_score(gt_list, pre_list)
f1 = f1_score(gt_list, pre_list)
support = sum(gt_list)
metrics = {
"DataPoint": data_point,
"F1": f1,
"Precision": precision,
"Recall": recall,
"Accuracy": accuracy,
"Support": support
}
return metrics
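To make the annotation mapping concrete, a worked example under the semantics encoded above (row counts invented):

from sklearn.metrics import accuracy_score, precision_score, recall_score

# 8 rows verified correct (raw_check == 1), then one "modify", one "incorrect", one "supplement".
# A "modify" contributes both a false positive (wrong value extracted) and a false
# negative (true value missed); "incorrect" is a false positive; "supplement" a false negative.
gt = [1] * 8 + [0, 1, 0, 1]
pre = [1] * 8 + [1, 0, 1, 0]
precision_score(gt, pre)  # 8 / 10 = 0.80
recall_score(gt, pre)     # 8 / 10 = 0.80
accuracy_score(gt, pre)   # 8 / 12 ~= 0.667; support = sum(gt) = 10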
if __name__ == "__main__":
file_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
verify_file = "mapping_data_info_31_documents_by_text_first_round.xlsx"
verify_file_path = os.path.join(file_folder, verify_file)
document_list = [
"334584772",
"337293427",
"337937633",
"404712928",
"406913630",
"407275419",
"422686965",
"422760148",
"422760156",
"422761666",
"423364758",
"423365707",
"423395975",
"423418395",
"423418540",
"425595958",
"451063582",
"451878128",
"466580448",
"481482392",
"508704368",
"532998065",
"536344026",
"540307575"
]
calculate_complex_document_metrics(verify_file_path=verify_file_path,
document_list=document_list)