Fix issue for "The last fund name of previous PDF page" logic:
If current page fund name starts with "The last fund name of previous PDF page" and with more contents below, then remove "The last fund name of previous PDF page".
This commit is contained in:
parent
36fbaa946e
commit
70362b554f
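In short: a fund-name heading that bleeds over from the previous PDF page is stripped, but only when something meaningful remains. A minimal sketch of the rule as a standalone helper (the helper name is illustrative; in the commit the logic lives inline in DataExtraction.validate_data, see the diff below):

def strip_previous_page_fund(fund_name: str, previous_page_last_fund: str) -> str:
    # Carried-over heading: the current fund name starts with the previous
    # page's last fund name but is not identical to it.
    fund_name = fund_name.strip()
    previous_page_last_fund = (previous_page_last_fund or "").strip()
    if (previous_page_last_fund
            and fund_name.startswith(previous_page_last_fund)
            and fund_name != previous_page_last_fund):
        remainder = fund_name.replace(previous_page_last_fund, "").strip()
        # Only drop the prefix when more than one word remains, so the
        # stripped result still looks like a real fund name.
        if len(remainder.split()) > 1:
            return remainder
    return fund_name

For example, with previous-page fund "AXA World Funds", a current-page name "AXA World Funds Global Income Bonds" becomes "Global Income Bonds"; an exact duplicate is left unchanged.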
@@ -191,6 +191,7 @@ class DataExtraction:
                 page_datapoints,
                 need_exclude=False,
                 exclude_data=None,
+                previous_page_last_fund=previous_page_fund_name
             )
             data_list.append(extract_data)
@@ -239,6 +240,7 @@ class DataExtraction:
                 next_datapoints,
                 need_exclude=True,
                 exclude_data=page_data_list,
+                previous_page_last_fund=previous_page_fund_name
             )
             next_page_data_list = next_page_extract_data.get(
                 "extract_data", {}
@@ -373,7 +375,8 @@ class DataExtraction:
         page_text: str,
         page_datapoints: list,
         need_exclude: bool = False,
-        exclude_data: list = None,) -> dict:
+        exclude_data: list = None,
+        previous_page_last_fund: str = None) -> dict:
         # If can't find numberic value, e.g. 1.25 or 3,88
         # apply Vision ChatGPT to extract data
         numeric_regex = r"\d+(\.|\,)\d+"
@@ -383,7 +386,12 @@ class DataExtraction:
                 page_num, page_datapoints, need_exclude, exclude_data)
         else:
             return self.extract_data_by_page_text(
-                page_num, page_text, page_datapoints, need_exclude, exclude_data
+                page_num=page_num,
+                page_text=page_text,
+                page_datapoints=page_datapoints,
+                need_exclude=need_exclude,
+                exclude_data=exclude_data,
+                previous_page_last_fund=previous_page_last_fund
             )

     def extract_data_by_page_text(
@@ -393,6 +401,7 @@ class DataExtraction:
         page_datapoints: list,
         need_exclude: bool = False,
         exclude_data: list = None,
+        previous_page_last_fund: str = None
     ) -> dict:
         """
         keys are
@@ -435,7 +444,7 @@ class DataExtraction:
             data = json_repair.loads(response)
         except:
             data = {"data": []}
-        data = self.validate_data(data)
+        data = self.validate_data(data, previous_page_last_fund)

         data_dict = {"doc_id": self.doc_id}
         data_dict["page_index"] = page_num
@@ -488,7 +497,7 @@ class DataExtraction:
             data = json_repair.loads(response)
         except:
             data = {"data": []}
-        data = self.validate_data(data)
+        data = self.validate_data(data, None)

         data_dict = {"doc_id": self.doc_id}
         data_dict["page_index"] = page_num
@@ -500,7 +509,7 @@ class DataExtraction:
         data_dict["extract_way"] = "image"
         return data_dict

-    def validate_data(self, extract_data_info: dict) -> dict:
+    def validate_data(self, extract_data_info: dict, previous_page_last_fund: str = None) -> dict:
         """
         Validate data by the rules
         1. Each data should be with fund name
@@ -511,10 +520,18 @@ class DataExtraction:
             return extract_data_info
         remove_list = []
         for data in data_list:
-            fund_name = data.get("fund name", "")
+            fund_name = data.get("fund name", "").strip()
             if fund_name == "":
                 remove_list.append(data)
+            if previous_page_last_fund is not None and len(previous_page_last_fund) > 0:
+                previous_page_last_fund = previous_page_last_fund.strip()
+                if fund_name.startswith(previous_page_last_fund) and fund_name != previous_page_last_fund:
+                    modified_fund_name = fund_name.replace(previous_page_last_fund, "").strip()
+                    if len(modified_fund_name.split()) > 1:
+                        fund_name = modified_fund_name
+
             fund_name = self.get_fund_name(fund_name, "Fund")
+            fund_name = self.get_fund_name(fund_name, "Bond")
             data["fund name"] = fund_name
             keys = list(data.keys())
             for key in keys:
@@ -839,6 +856,24 @@ class DataExtraction:
         instructions.append("\n\n")
         instructions.append("\n")

+        # extreme_complex_config_list = special_cases.get("extreme_complex", [])
+        # if len(extreme_complex_config_list) > 0:
+        #     for extreme_complex_config in extreme_complex_config_list:
+        #         regex = extreme_complex_config.get("regex", "")
+        #         if len(regex) == 0:
+        #             continue
+        #         search = re.search(regex, page_text)
+        #         if search is not None:
+        #             title = extreme_complex_config.get("title", "")
+        #             title = f"{special_cases_number}. {title} "
+        #             special_cases_number += 1
+        #             instructions.append(title)
+        #             instructions.append("\n")
+        #             contents_list = extreme_complex_config.get("contents", [])
+        #             contents = "\n".join(contents_list)
+        #             instructions.append(contents)
+        #             instructions.append("\n\n")
+
         instructions.append("Output requirement:\n")
         output_requirement = self.instructions_config.get("output_requirement", {})
         output_requirement_common_list = output_requirement.get("common", [])
(instructions config JSON, file name not shown in the extracted page)
@@ -35,16 +35,21 @@
     "a. The full fund name should be main fund name + sub-fund name, e,g, main fund name is Black Rock European, sub-fund name is Growth, the full fund name is: Black Rock European Growth.",
     "b. The sub-fund name may be as the first column or first row values in the table.",
     "b.1 fund name example:",
-    "---- context:",
+    "---- Example Start ----",
     "Summary information\nCapital International Fund Audited Annual Report 2023 | 15\nFootnotes are on page 17.\nCapital Group Multi-Sector \nIncome Fund (LUX) \n(CGMSILU)\nCapital Group US High Yield \nFund (LUX) (CGUSHYLU)\nCapital Group Emerging \nMarkets Debt Fund (LUX) \n(CGEMDLU)",
-    "---- fund names: Capital International Group Multi-Sector Income Fund (LUX), Capital International Group US High Yield Fund (LUX), Capital International Group Emerging Markets Debt Fund (LUX)",
+    "---- Example End ----",
+    "Fund names: Capital International Group Multi-Sector Income Fund (LUX), Capital International Group US High Yield Fund (LUX), Capital International Group Emerging Markets Debt Fund (LUX)",
+    "\n",
     "c. If with multiple fund names in context, please retrieve the fund name closest above the numerical value.",
     "c.1 fund name example:",
-    "---- context:",
+    "---- Example Start ----",
     "AXA World Funds ACT Emerging Markets Bonds\nAXA World Funds \n \nAdditional Unaudited Appendix \n\nƒ$GGLWLRQDO8QDXGLWHG$SSHQGL[$118$/5(3257$;$:RUOG)XQGV\nExpense Ratios (continued) \n \nCalculated TER (1) \nSwiss method \nApplied\nService Fee (2)\nOngoing \nCharges (3) \n \nwith performance \nfees \nwithout performance \nfees \n \nAXA World Funds - ACT Emerging Markets Short Duration Bonds Low Carbon \nA Capitalisation CHF Hedged \n1.26% \n1.26% \n0.26% \n1.29%",
-    "---- correct fund name: AXA World Funds - ACT Emerging Markets Short Duration Bonds Low Carbon",
+    "---- Example End ----",
+    "Correct fund name: AXA World Funds - ACT Emerging Markets Short Duration Bonds Low Carbon",
+    "\n",
     "- Only extract the latest data from context:",
     "If with multiple data values in same row, please extract the latest.",
+    "\n",
     "- Reported names:",
     "Only output the values which with significant reported names.",
     "Please exclude below reported names and relevant values: \"Management Fees\", \"Management\", \"Management Fees p.a.\", \"Taxe d Abonnement in % p.a.\".\nDON'T EXTRACT MANAGEMENT FEES!",
@@ -245,6 +250,41 @@
       "The performance fees value is Ongoing Charges inkl. Performance-Fee in % **) - Ongoing Charges exkl. Performance-Fee in % **) = 1.20 - 1.15 = 0.05"
     ]
   }
+  ],
+  "extreme_complex": [
+    {
+      "title": "Complex Data Table Structure",
+      "regex": "([A-Z]{1,2}\\,\\s?){3,}",
+      "contents": [
+        "Complex Data Table Structure",
+        "Table structure: the first column is fund name, for each table title, there are a lot of share class names in it.",
+        "Please split these share class names and extract all of relevant data as fund name, share name, data point and value one by one from the table.",
+        "-----Example Start-----",
+        "Charges and expenses (continued) ",
+        "d) Operating, Administrative and Servicing Expenses / Operating Currency Hedged Share Class Fees (continued)",
+        "The following table shows the rates of Operating, Administrative and Servicing Expenses:",
+        "Class A, B, E, ",
+        "M,O ",
+        "EQUITY SUB-FUNDS ",
+        "a) Equity sub-funds ",
+        "Fund 1",
+        "0.35",
+        "Fund 2",
+        "0.26",
+        "-----Example End-----",
+        "The output should be:",
+        "{\"data\": [{\"fund name\": \"Fund 1\", \"share name\": \"A\", \"ogc\": 0.35},",
+        "{\"fund name\": \"Fund 1\", \"share name\": \"B\", \"ogc\": 0.35},",
+        "{\"fund name\": \"Fund 1\", \"share name\": \"E\", \"ogc\": 0.35},",
+        "{\"fund name\": \"Fund 1\", \"share name\": \"M\", \"ogc\": 0.35},",
+        "{\"fund name\": \"Fund 1\", \"share name\": \"O\", \"ogc\": 0.35}",
+        "{\"fund name\": \"Fund 2\", \"share name\": \"A\", \"ogc\": 0.26},",
+        "{\"fund name\": \"Fund 2\", \"share name\": \"B\", \"ogc\": 0.26},",
+        "{\"fund name\": \"Fund 2\", \"share name\": \"E\", \"ogc\": 0.26},",
+        "{\"fund name\": \"Fund 2\", \"share name\": \"M\", \"ogc\": 0.26},",
+        "{\"fund name\": \"Fund 2\", \"share name\": \"O\", \"ogc\": 0.26}]}"
+      ]
+    }
   ]
   },
   "output_requirement": {
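For reference, the regex in this new config block is what triggers the extra instructions: it matches a run of three or more 1-2 letter upper-case tokens, each followed by a comma and optional whitespace, i.e. the shape of a share-class list such as "Class A, B, E, M, O". A quick standalone check (not part of the commit):

import re

# Regex copied from the "extreme_complex" config above.
pattern = r"([A-Z]{1,2}\,\s?){3,}"

page_text = "Class A, B, E, \nM,O \nEQUITY SUB-FUNDS"
print(bool(re.search(pattern, page_text)))  # True -> the instructions get appended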
main.py
@@ -857,53 +857,7 @@ def replace_rerun_data(new_data_file: str, original_data_file: str):
         new_extract_data.to_excel(writer, index=False, sheet_name=extract_data_sheet)


-if __name__ == "__main__":
-    # new_data_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_15_documents_by_text_20241121154243.xlsx"
-    # original_data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
-    # replace_rerun_data(new_data_file, original_data_file)
-    # test_calculate_metrics()
-    # test_replace_abbrevation()
-    # test_translate_pdf()
-    # test_mapping_raw_name()
-    pdf_folder = r"/data/emea_ar/pdf/"
-    page_filter_ground_truth_file = (
-        r"/data/emea_ar/ground_truth/page_filter/datapoint_page_info_88_documents.xlsx"
-    )
-    prediction_output_folder = r"/data/emea_ar/output/filter_pages/"
-    metrics_output_folder = r"/data/emea_ar/output/metrics/"
-    special_doc_id_list = []
-    # batch_filter_pdf_files(
-    #     pdf_folder, page_filter_ground_truth_file, prediction_output_folder, special_doc_id_list
-    # )
-
-    # data_type = "page_filter"
-    # prediction_file = r"/data/emea_ar/output/filter_pages/datapoint_page_info_73_documents_20240903145002.xlsx"
-    # missing_error_list, metrics_list, metrics_file = get_metrics(
-    #     data_type, prediction_file, page_filter_ground_truth_file, metrics_output_folder
-    # )
-
-    # test_auto_generate_instructions()
-
-    output_extract_data_child_folder = r"/data/emea_ar/output/extract_data/docs/"
-    output_extract_data_total_folder = r"/data/emea_ar/output/extract_data/total/"
-
-    # batch_extract_data(
-    #     pdf_folder,
-    #     page_filter_ground_truth_file,
-    #     output_extract_data_child_folder,
-    #     output_extract_data_total_folder,
-    #     special_doc_id_list,
-    #     re_run,
-    # )
-
-    # doc_id = "476492237"
-    # extract_way = "image"
-    # extract_data(doc_id,
-    #              pdf_folder,
-    #              output_extract_data_child_folder,
-    #              extract_way,
-    #              re_run_extract_data)
-
+def batch_run_documents():
     # special_doc_id_list = ["505174428", "510326848", "349679479"]
     # check_mapping_doc_id_list = [
     #     "327956364",
@@ -1197,7 +1151,13 @@ if __name__ == "__main__":
         "534535767"
     ]
     special_doc_id_list = check_db_mapping_doc_id_list
-    special_doc_id_list = ["532998065"]
+    special_doc_id_list = ["534535767"]
+    pdf_folder = r"/data/emea_ar/pdf/"
+    page_filter_ground_truth_file = (
+        r"/data/emea_ar/ground_truth/page_filter/datapoint_page_info_88_documents.xlsx"
+    )
+    output_extract_data_child_folder = r"/data/emea_ar/output/extract_data/docs/"
+    output_extract_data_total_folder = r"/data/emea_ar/output/extract_data/total/"
     output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
     output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
     re_run_extract_data = True
@@ -1206,8 +1166,6 @@ if __name__ == "__main__":
     calculate_metrics = False

     extract_ways = ["text"]
-    # pdf_folder = r"/data/emea_ar/small_pdf/"
-    pdf_folder = r"/data/emea_ar/pdf/"
     for extract_way in extract_ways:
         batch_start_job(
             pdf_folder,
@@ -1224,5 +1182,49 @@ if __name__ == "__main__":
         calculate_metrics=calculate_metrics,
     )


+if __name__ == "__main__":
+    batch_run_documents()
+
+    # new_data_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_15_documents_by_text_20241121154243.xlsx"
+    # original_data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
+    # replace_rerun_data(new_data_file, original_data_file)
+    # test_calculate_metrics()
+    # test_replace_abbrevation()
+    # test_translate_pdf()
+    # test_mapping_raw_name()
     # test_data_extraction_metrics()
+
+    # batch_filter_pdf_files(
+    #     pdf_folder, page_filter_ground_truth_file, prediction_output_folder, special_doc_id_list
+    # )
+
+    # data_type = "page_filter"
+    # prediction_file = r"/data/emea_ar/output/filter_pages/datapoint_page_info_73_documents_20240903145002.xlsx"
+    # missing_error_list, metrics_list, metrics_file = get_metrics(
+    #     data_type, prediction_file, page_filter_ground_truth_file, metrics_output_folder
+    # )
+
+    # test_auto_generate_instructions()
+
+    # batch_extract_data(
+    #     pdf_folder,
+    #     page_filter_ground_truth_file,
+    #     output_extract_data_child_folder,
+    #     output_extract_data_total_folder,
+    #     special_doc_id_list,
+    #     re_run,
+    # )
+
+    # doc_id = "476492237"
+    # extract_way = "image"
+    # extract_data(doc_id,
+    #              pdf_folder,
+    #              output_extract_data_child_folder,
+    #              extract_way,
+    #              re_run_extract_data)
(new file, name not shown in the extracted page)
@@ -0,0 +1,143 @@
+from tqdm import tqdm
+from glob import glob
+import json
+import pandas as pd
+import os
+from traceback import print_exc
+from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
+
+from utils.logger import logger
+
+
+def calculate_complex_document_metrics(verify_file_path: str, document_list: list = []):
+    data_df = pd.read_excel(verify_file_path, sheet_name="data_in_doc_mapping")
+    # convert doc_id column to string
+    data_df["doc_id"] = data_df["doc_id"].astype(str)
+    data_df = data_df[data_df["raw_check"].isin([0, 1])]
+
+    if document_list is not None and len(document_list) > 0:
+        data_df = data_df[data_df["doc_id"].isin(document_list)]
+
+    data_df.fillna("", inplace=True)
+    data_df.reset_index(drop=True, inplace=True)
+
+    # tor data
+    tor_data_df = data_df[data_df["datapoint"] == "tor"]
+    tor_metrics = get_sub_metrics(tor_data_df, "tor")
+    logger.info(f"TOR metrics: {tor_metrics}")
+
+    # ter data
+    ter_data_df = data_df[data_df["datapoint"] == "ter"]
+    ter_metrics = get_sub_metrics(ter_data_df, "ter")
+    logger.info(f"TER metrics: {ter_metrics}")
+
+    # ogc data
+    ogc_data_df = data_df[data_df["datapoint"] == "ogc"]
+    ogc_metrics = get_sub_metrics(ogc_data_df, "ogc")
+    logger.info(f"OGC metrics: {ogc_metrics}")
+
+    # performance_fee data
+    performance_fee_data_df = data_df[data_df["datapoint"] == "performance_fee"]
+    performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee")
+    logger.info(f"Performance fee metrics: {performance_fee_metrics}")
+
+    metrics_df = pd.DataFrame([tor_metrics, ter_metrics, ogc_metrics, performance_fee_metrics])
+    # add average metrics
+    avg_metrics = {
+        "DataPoint": "average",
+        "F1": metrics_df["F1"].mean(),
+        "Precision": metrics_df["Precision"].mean(),
+        "Recall": metrics_df["Recall"].mean(),
+        "Accuracy": metrics_df["Accuracy"].mean(),
+        "Support": metrics_df["Support"].sum()
+    }
+
+    metrics_df = pd.DataFrame([tor_metrics, ter_metrics,
+                               ogc_metrics, performance_fee_metrics,
+                               avg_metrics])
+    metrics_df.reset_index(drop=True, inplace=True)
+
+    output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
+
+    document_count = len(document_list) \
+        if document_list is not None and len(document_list) > 0 \
+        else len(data_df["doc_id"].unique())
+
+    output_metrics_file = os.path.join(output_folder,
+                                       f"complex_document_{document_count}_metrics.xlsx")
+    with pd.ExcelWriter(output_metrics_file) as writer:
+        metrics_df.to_excel(writer, index=False, sheet_name="metrics")
+
+
+def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict:
+    data_df_raw_check_1 = data_df[data_df["raw_check"] == 1]
+    gt_list = [1] * len(data_df_raw_check_1)
+    pre_list = [1] * len(data_df_raw_check_1)
+
+    data_df_raw_check_0 = data_df[data_df["raw_check"] == 0]
+    for index, row in data_df_raw_check_0.iterrows():
+        if row["raw_check_comment"] == "modify":
+            gt_list.append(0)
+            pre_list.append(1)
+
+            gt_list.append(1)
+            pre_list.append(0)
+        elif row["raw_check_comment"] == "incorrect":
+            gt_list.append(0)
+            pre_list.append(1)
+        elif row["raw_check_comment"] == "supplement":
+            gt_list.append(1)
+            pre_list.append(0)
+        else:
+            pass
+
+    # calculate metrics
+    accuracy = accuracy_score(gt_list, pre_list)
+    precision = precision_score(gt_list, pre_list)
+    recall = recall_score(gt_list, pre_list)
+    f1 = f1_score(gt_list, pre_list)
+    support = sum(gt_list)
+
+    metrics = {
+        "DataPoint": data_point,
+        "F1": f1,
+        "Precision": precision,
+        "Recall": recall,
+        "Accuracy": accuracy,
+        "Support": support
+    }
+    return metrics
+
+
+if __name__ == "__main__":
+    file_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
+    verify_file = "mapping_data_info_31_documents_by_text_first_round.xlsx"
+    verify_file_path = os.path.join(file_folder, verify_file)
+    document_list = [
+        "334584772",
+        "337293427",
+        "337937633",
+        "404712928",
+        "406913630",
+        "407275419",
+        "422686965",
+        "422760148",
+        "422760156",
+        "422761666",
+        "423364758",
+        "423365707",
+        "423395975",
+        "423418395",
+        "423418540",
+        "425595958",
+        "451063582",
+        "451878128",
+        "466580448",
+        "481482392",
+        "508704368",
+        "532998065",
+        "536344026",
+        "540307575"
+    ]
+    calculate_complex_document_metrics(verify_file_path=verify_file_path,
+                                       document_list=document_list)
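For context on how get_sub_metrics encodes the reviewer verdicts: a row with raw_check == 1 is a true positive; "modify" contributes both a false positive (the wrong extracted value) and a false negative (the missed true value); "incorrect" is a false positive; "supplement" a false negative. A toy run, assuming get_sub_metrics is importable from this new module (the file name is not shown in the diff):

import pandas as pd

# Hypothetical verify rows: two correct, one "modify", one "supplement".
toy_df = pd.DataFrame({
    "datapoint": ["tor", "tor", "tor", "tor"],
    "raw_check": [1, 1, 0, 0],
    "raw_check_comment": ["", "", "modify", "supplement"],
})

# Encoded lists: gt = [1, 1, 0, 1, 1], pre = [1, 1, 1, 0, 0]
metrics = get_sub_metrics(toy_df, "tor")
print(metrics["Precision"], metrics["Recall"])  # 0.666..., 0.5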