From d96f77fe0084ff201cdf1636fec2f1a01daddd06 Mon Sep 17 00:00:00 2001
From: Blade He
Date: Fri, 6 Dec 2024 16:31:42 -0600
Subject: [PATCH] Split share class names when multiple share classes appear on the same line

---
 core/data_extraction.py    | 81 +++++++++++++++++++++++++++++---------
 main.py                    |  8 ++--
 test_specific_biz_logic.py | 69 ++++++++++++++++++++++++++++++++
 3 files changed, 136 insertions(+), 22 deletions(-)
 create mode 100644 test_specific_biz_logic.py
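Note for reviewers (illustration only, not part of the patch): the splitting behaviour introduced below can be sketched as a small standalone helper. `expand_share_classes` is a hypothetical name used here to mirror what `DataExtraction.split_multi_share_name` does with the combined header from document 481482392; the regex and prefix handling follow the patched code.

    import re

    def expand_share_classes(raw: str) -> list:
        # A combined header such as "Class A, B, E, M, N, P, R, U" lists several
        # single-letter class codes after a shared prefix ("Class").
        if re.search(r"([A-Z]+,\s){2,}", raw) is None:
            return [raw]
        parts = [p.strip() for p in raw.split(",") if p.strip()]
        prefix_tokens = parts[0].split()
        if len(prefix_tokens) != 2:
            # No obvious "<prefix> <code>" shape; return the comma-split pieces as-is.
            return parts
        prefix = prefix_tokens[0]
        # Keep the first entry unchanged and prepend the prefix to the remaining codes.
        return [parts[0]] + [f"{prefix} {p}" for p in parts[1:]]

    print(expand_share_classes("Class A, B, E, M, N, P, R, U"))
    # ['Class A', 'Class B', 'Class E', 'Class M', 'Class N', 'Class P', 'Class R', 'Class U']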
"Class A, B, E, M, N, P, R, U" + For this case, need split the share name to be ["Class A", "Class B", "Class E", + "Class M", "Class N", "Class P", "Class R", "Class U"] + """ + multi_over_2_share_regex = r"([A-Z]{1,}\,\s){2,}" + multi_over_2_share_search = re.search(multi_over_2_share_regex, raw_share_name) + share_name_list = [raw_share_name] + if multi_over_2_share_search is not None: + multi_share_splits = [share_name.strip() for share_name in raw_share_name.split(",") + if len(share_name.strip()) > 0] + first_share_name = multi_share_splits[0] + first_share_name_split = first_share_name.split() + share_name_prefix = None + if len(first_share_name_split) == 2: + share_name_prefix = first_share_name_split[0] + if share_name_prefix is not None and len(share_name_prefix) > 0: + new_share_name_list = [] + for split in multi_share_splits: + if split == first_share_name: + new_share_name_list.append(split) + else: + new_share_name_list.append(f"{share_name_prefix} {split}") + share_name_list = new_share_name_list + else: + share_name_list = multi_share_splits + else: + share_name_list = multi_share_splits + return share_name_list def get_fund_name(self, fund_name: str, fund_feature: str): if not fund_name.endswith(fund_feature): diff --git a/main.py b/main.py index d965cfb..8b152bc 100644 --- a/main.py +++ b/main.py @@ -1150,7 +1150,7 @@ def batch_run_documents(): "534535767" ] special_doc_id_list = check_db_mapping_doc_id_list - # special_doc_id_list = ["407275419", "425595958", "451063582", "451878128"] + special_doc_id_list = ["481482392"] pdf_folder = r"/data/emea_ar/pdf/" page_filter_ground_truth_file = ( r"/data/emea_ar/ground_truth/page_filter/datapoint_page_info_88_documents.xlsx" @@ -1159,9 +1159,9 @@ def batch_run_documents(): output_extract_data_total_folder = r"/data/emea_ar/output/extract_data/total/" output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/" output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/" - re_run_extract_data = False - re_run_mapping_data = False - force_save_total_data = True + re_run_extract_data = True + re_run_mapping_data = True + force_save_total_data = False calculate_metrics = False extract_ways = ["text"] diff --git a/test_specific_biz_logic.py b/test_specific_biz_logic.py new file mode 100644 index 0000000..32f5d2d --- /dev/null +++ b/test_specific_biz_logic.py @@ -0,0 +1,69 @@ +import os +import json +import pandas as pd +from glob import glob +from tqdm import tqdm +from utils.logger import logger +from utils.sql_query_util import query_document_fund_mapping +from core.page_filter import FilterPages +from core.data_extraction import DataExtraction + + +def test_validate_extraction_data(): + document_id = "481482392" + pdf_file = f"/data/emea_ar/pdf/481482392.pdf" + output_extract_data_child_folder = r"/data/emea_ar/output/extract_data/docs/" + output_extract_data_total_folder = r"/data/emea_ar/output/extract_data/total/" + document_mapping_info_df = query_document_fund_mapping(document_id, rerun=False) + filter_pages = FilterPages( + document_id, pdf_file, document_mapping_info_df + ) + page_text_dict = filter_pages.page_text_dict + datapoint_page_info, result_details = get_datapoint_page_info(filter_pages) + datapoints = get_datapoints_from_datapoint_page_info(datapoint_page_info) + data_extraction = DataExtraction( + doc_id=document_id, + pdf_file=pdf_file, + output_data_folder=output_extract_data_child_folder, + page_text_dict=page_text_dict, + datapoint_page_info=datapoint_page_info, + 
+        datapoints=datapoints,
+        document_mapping_info_df=document_mapping_info_df,
+        extract_way="text",
+        output_image_folder=None
+    )
+    output_data_json_folder = os.path.join(
+        r"/data/emea_ar/output/extract_data/docs/by_text/", "json/"
+    )
+    os.makedirs(output_data_json_folder, exist_ok=True)
+    json_file = os.path.join(output_data_json_folder, f"{document_id}.json")
+    data_from_gpt = None
+    if os.path.exists(json_file):
+        logger.info(
+            f"The document: {document_id} has been parsed, loading data from {json_file}"
+        )
+        with open(json_file, "r", encoding="utf-8") as f:
+            data_from_gpt = json.load(f)
+    for extract_data in data_from_gpt:
+        page_index = extract_data["page_index"]
+        if page_index == 451:
+            logger.info(f"Page index: {page_index}")
+            raw_answer = extract_data["raw_answer"]
+            raw_answer_json = json.loads(raw_answer)
+            extract_data_info = data_extraction.validate_data(raw_answer_json)
+            print(extract_data_info)
+
+def get_datapoint_page_info(filter_pages) -> tuple:
+    datapoint_page_info, result_details = filter_pages.start_job()
+    return datapoint_page_info, result_details
+
+
+def get_datapoints_from_datapoint_page_info(datapoint_page_info) -> list:
+    datapoints = list(datapoint_page_info.keys())
+    if "doc_id" in datapoints:
+        datapoints.remove("doc_id")
+    return datapoints
+
+
+if __name__ == "__main__":
+    test_validate_extraction_data()
\ No newline at end of file
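For reference, a minimal sketch (not captured output) of the record fan-out that the validate_data change introduces. The fund name, TER value, and the raw row below are invented for illustration; the expansion assumes the splitting behaviour shown earlier.

    # One raw extracted row whose "share name" is a combined class header.
    raw_row = {
        "fund name": "Global Equity Fund",
        "share name": "Class A, B, E, M",
        "ter": 0.85,
        "performance fees": None,
    }

    # After the patched validate_data cleanup, the row is expected to fan out into
    # one record per share class, with keys renamed to fund_name / share_name / ter
    # (performance_fee is omitted because "performance fees" is None):
    expected = [
        {"fund_name": "Global Equity Fund", "share_name": "Class A", "ter": 0.85},
        {"fund_name": "Global Equity Fund", "share_name": "Class B", "ter": 0.85},
        {"fund_name": "Global Equity Fund", "share_name": "Class E", "ter": 0.85},
        {"fund_name": "Global Equity Fund", "share_name": "Class M", "ter": 0.85},
    ]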