import os
import json
import re

import json_repair
import fitz
import pandas as pd

from utils.gpt_utils import chat
from utils.pdf_util import PDFUtil
from utils.sql_query_util import query_document_fund_mapping
from utils.logger import logger
from utils.biz_utils import add_slash_to_text_as_regex, clean_text


class DataExtraction:
    def __init__(
        self,
        doc_id: str,
        pdf_file: str,
        output_data_folder: str,
        page_text_dict: dict,
        datapoint_page_info: dict,
        document_mapping_info_df: pd.DataFrame,
    ) -> None:
        self.doc_id = doc_id
        self.pdf_file = pdf_file
        if output_data_folder is None or len(output_data_folder) == 0:
            output_data_folder = r"/data/emea_ar/output/extract_data/docs/"
        os.makedirs(output_data_folder, exist_ok=True)
        self.output_data_json_folder = os.path.join(output_data_folder, "json/")
        os.makedirs(self.output_data_json_folder, exist_ok=True)
        self.output_data_excel_folder = os.path.join(output_data_folder, "excel/")
        os.makedirs(self.output_data_excel_folder, exist_ok=True)
        # Fall back to extracting page text from the PDF when no cache is passed in.
        if page_text_dict is None or len(page_text_dict.keys()) == 0:
            self.page_text_dict = self.get_pdf_page_text_dict()
        else:
            self.page_text_dict = page_text_dict
        # Fall back to querying the fund mapping when no DataFrame is passed in.
        if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
            self.document_mapping_info_df = query_document_fund_mapping(doc_id)
        else:
            self.document_mapping_info_df = document_mapping_info_df
        self.datapoint_page_info = datapoint_page_info
        self.page_nums_with_datapoints = self.get_page_nums_from_datapoint_page_info()
        self.datapoints = self.get_datapoints_from_datapoint_page_info()
        self.instructions_config = self.get_instructions_config()
        self.datapoint_level_config = self.get_datapoint_level()
        self.datapoint_name_config = self.get_datapoint_name()

    def get_instructions_config(self) -> dict:
        instructions_config_file = r"./instructions/data_extraction_prompts_config.json"
        with open(instructions_config_file, "r", encoding="utf-8") as f:
            instructions_config = json.load(f)
        return instructions_config

    def get_datapoint_level(self) -> dict:
        datapoint_level_file = r"./configuration/datapoint_level.json"
        with open(datapoint_level_file, "r", encoding="utf-8") as f:
            datapoint_level = json.load(f)
        return datapoint_level

    def get_datapoint_name(self) -> dict:
        datapoint_name_file = r"./configuration/datapoint_name.json"
        with open(datapoint_name_file, "r", encoding="utf-8") as f:
            datapoint_name = json.load(f)
        return datapoint_name

    def get_pdf_page_text_dict(self) -> dict:
        pdf_util = PDFUtil(self.pdf_file)
        success, text, page_text_dict = pdf_util.extract_text()
        return page_text_dict

    def get_datapoints_from_datapoint_page_info(self) -> list:
        datapoints = list(self.datapoint_page_info.keys())
        if "doc_id" in datapoints:
            datapoints.remove("doc_id")
        return datapoints

    def get_page_nums_from_datapoint_page_info(self) -> list:
        page_nums_with_datapoints = []
        for datapoint, page_nums in self.datapoint_page_info.items():
            if datapoint == "doc_id":
                continue
            page_nums_with_datapoints.extend(page_nums)
        page_nums_with_datapoints = list(set(page_nums_with_datapoints))
        # sort the page numbers
        page_nums_with_datapoints.sort()
        return page_nums_with_datapoints

    def extract_data(self) -> list:
        """
        Extract data for every page that has detected datapoints.

        Each extracted record's keys are doc_id, page_index, datapoint, value,
        raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name.
        """
        data_list = []
        pdf_page_count = len(self.page_text_dict.keys())
        handled_page_num_list = []
        for page_num, page_text in self.page_text_dict.items():
            if page_num in handled_page_num_list:
                continue
            page_datapoints = self.get_datapoints_by_page_num(page_num)
            if len(page_datapoints) == 0:
                continue
            extract_data = self.extract_data_by_page(
                page_num,
                page_text,
                page_datapoints,
                need_exclude=False,
                exclude_data=None,
            )
            data_list.append(extract_data)
            page_data_list = extract_data.get("extract_data", {}).get("data", [])
            current_page_data_count = len(page_data_list)
            if current_page_data_count > 0:
                count = 1
                # Some PDF documents spread the same data across several pages,
                # and a continuation page may lack the table header with the
                # datapoint keywords. Look ahead (at most two pages) and try to
                # pick up that data from the following pages.
                current_text = page_text
                while count < 3:
                    try:
                        next_page_num = page_num + count
                        if next_page_num >= pdf_page_count:
                            break
                        next_datapoints = page_datapoints
                        if next_page_num in self.page_nums_with_datapoints:
                            should_continue = False
                            next_datapoints = self.get_datapoints_by_page_num(
                                next_page_num
                            )
                            if len(next_datapoints) == 0:
                                should_continue = True
                            else:
                                for next_datapoint in next_datapoints:
                                    if next_datapoint not in page_datapoints:
                                        should_continue = True
                                        break
                                next_datapoints.extend(page_datapoints)
                                # remove duplicate datapoints
                                next_datapoints = list(set(next_datapoints))
                            if not should_continue:
                                break
                        next_page_text = self.page_text_dict.get(next_page_num, "")
                        target_text = current_text + next_page_text
                        # try to get data by the current page_datapoints
                        next_page_extract_data = self.extract_data_by_page(
                            next_page_num,
                            target_text,
                            next_datapoints,
                            need_exclude=True,
                            exclude_data=page_data_list,
                        )
                        next_page_data_list = next_page_extract_data.get(
                            "extract_data", {}
                        ).get("data", [])
                        if next_page_data_list is not None and len(next_page_data_list) > 0:
                            # Drop records already captured on the current page.
                            for current_page_data in page_data_list:
                                if current_page_data in next_page_data_list:
                                    next_page_data_list.remove(current_page_data)
                            next_page_extract_data["extract_data"][
                                "data"
                            ] = next_page_data_list
                            data_list.append(next_page_extract_data)
                            handled_page_num_list.append(next_page_num)
                        else:
                            break
                        count += 1
                    except Exception as e:
                        logger.error(f"Error in extracting data from next page: {e}")
                        break
        json_data_file = os.path.join(
            self.output_data_json_folder, f"{self.doc_id}.json"
        )
        with open(json_data_file, "w", encoding="utf-8") as f:
            json.dump(data_list, f, ensure_ascii=False, indent=4)
        data_df = pd.DataFrame(data_list)
        data_df.reset_index(drop=True, inplace=True)
        excel_data_file = os.path.join(
            self.output_data_excel_folder, f"{self.doc_id}.xlsx"
        )
        with pd.ExcelWriter(excel_data_file) as writer:
            data_df.to_excel(writer, sheet_name="extract_data", index=False)
        return data_list

    def extract_data_by_page(
        self,
        page_num: int,
        page_text: str,
        page_datapoints: list,
        need_exclude: bool = False,
        exclude_data: list = None,
    ) -> dict:
        """
        Extract data from a single page via the LLM.

        The extracted records carry the keys doc_id, page_index, datapoint,
        value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id,
        share_name.
        """
        logger.info(f"Extracting data from page {page_num}")
        instructions = self.get_instructions_by_datapoints(
            page_text, page_datapoints, need_exclude, exclude_data
        )
        response, with_error = chat(
            instructions, response_format={"type": "json_object"}
        )
        if with_error:
            logger.error(f"Error in extracting data from page {page_num}")
            # Return an empty payload (not a string) so that callers can still
            # call .get("extract_data", {}) safely.
            return {
                "doc_id": self.doc_id,
                "page_index": page_num,
                "extract_data": {"data": []},
            }
        try:
            data = json.loads(response)
        except Exception:
            try:
                data = json_repair.loads(response)
            except Exception:
                data = {"data": []}
        data_dict = {"doc_id": self.doc_id}
        data_dict["page_index"] = page_num
        data_dict["datapoints"] = ", ".join(page_datapoints)
        data_dict["page_text"] = page_text
        data_dict["instructions"] = instructions
        data_dict["raw_answer"] = response
        data_dict["extract_data"] = data
        return data_dict

    def get_datapoints_by_page_num(self, page_num: int) -> list:
        datapoints = []
        for datapoint in self.datapoints:
            if page_num in self.datapoint_page_info[datapoint]:
                datapoints.append(datapoint)
        return datapoints

    def get_instructions_by_datapoints(
        self,
        page_text: str,
        datapoints: list,
        need_exclude: bool = False,
        exclude_data: list = None,
    ) -> str:
        """
        Get instructions to extract data from the page by the datapoints.

        The instructions are assembled from these config sections:
            summary: string
            reported_name (by datapoint): dict
            data_business_features: dict
                common: list
                investment_level (by datapoint): dict
                data_value_range (by datapoint): dict
                special_rule (by datapoint): dict
            special_cases: dict
                common: list
                    title
                    contents
                special_case (by datapoint): list
                    title
                    contents
            output_requirement
                common: list
                fund_level: list
                share_level: dict
                    fund_name: list
                    share_name: list
                    ogc_value: list
                    ter_value: list
                    performance_fee_value: list
            end
        """
        instructions = [f"Context:\n{page_text}\n\nInstructions:\n"]
        datapoint_name_list = []
        for datapoint in datapoints:
            datapoint_name = self.datapoint_name_config.get(datapoint, "")
            datapoint_name_list.append(datapoint_name)
        summary = self.instructions_config.get("summary", "\n")
        instructions.append(summary.format(", ".join(datapoint_name_list)))
        instructions.append("\n")
        instructions.append("Datapoints Reported name:\n")
        reported_name_info = self.instructions_config.get("reported_name", {})
        for datapoint in datapoints:
            reported_name = reported_name_info.get(datapoint, "")
            instructions.append(reported_name)
            instructions.append("\n")
        instructions.append("\n")
        instructions.append("Data business features:\n")
        data_business_features = self.instructions_config.get(
            "data_business_features", {}
        )
        common = "\n".join(data_business_features.get("common", []))
        instructions.append(common)
        instructions.append("\n")
        instructions.append("Datapoints investment level:\n")
        investment_level_info = data_business_features.get("investment_level", {})
        for datapoint in datapoints:
            investment_level = investment_level_info.get(datapoint, "")
            instructions.append(investment_level)
            instructions.append("\n")
        instructions.append("\n")
        instructions.append("Datapoints value range:\n")
        data_value_range_info = data_business_features.get("data_value_range", {})
        for datapoint in datapoints:
            data_value_range = data_value_range_info.get(datapoint, "")
            instructions.append(data_value_range)
            instructions.append("\n")
        instructions.append("\n")
        special_rule_info = data_business_features.get("special_rule", {})
        with_special_rule_title = False
        for datapoint in datapoints:
            special_rule_list = special_rule_info.get(datapoint, [])
            if len(special_rule_list) > 0:
                if not with_special_rule_title:
                    instructions.append("Special rule:\n")
                    with_special_rule_title = True
                special_rule = "\n".join(special_rule_list)
                instructions.append(special_rule)
                instructions.append("\n\n")
        instructions.append("\n")
        instructions.append("Special cases:\n")
        special_cases = self.instructions_config.get("special_cases", {})
        special_cases_common_list = special_cases.get("common", [])
        for special_cases_common in special_cases_common_list:
            title = special_cases_common.get("title", "")
            instructions.append(title)
            instructions.append("\n")
            contents_list = special_cases_common.get("contents", [])
            contents = "\n".join(contents_list)
            instructions.append(contents)
            instructions.append("\n\n")
        for datapoint in datapoints:
            special_case_list = special_cases.get(datapoint, [])
            for special_case in special_case_list:
                title = special_case.get("title", "")
                instructions.append(title)
                instructions.append("\n")
                contents_list = special_case.get("contents", [])
                contents = "\n".join(contents_list)
"\n".join(contents_list) instructions.append(contents) instructions.append("\n\n") instructions.append("\n") instructions.append("Output requirement:\n") output_requirement = self.instructions_config.get("output_requirement", {}) output_requirement_common_list = output_requirement.get("common", []) instructions.append("\n".join(output_requirement_common_list)) instructions.append("\n") share_datapoint_value_example = {} share_level_config = output_requirement.get("share_level", {}) example_list = [] for datapoint in datapoints: investment_level = self.datapoint_level_config.get(datapoint, "") if investment_level == "fund_level": fund_level_example_list = output_requirement.get("fund_level", []) for example in fund_level_example_list: try: sub_example_list = json.loads(example) except: sub_example_list = json_repair.loads(example) example_list.extend(sub_example_list) elif investment_level == "share_level": share_datapoint_value_example[datapoint] = share_level_config.get( f"{datapoint}_value", [] ) share_datapoint_list = list(share_datapoint_value_example.keys()) instructions.append(f"Example:\n") if len(share_datapoint_list) > 0: fund_name_example_list = share_level_config.get("fund_name", []) share_name_example_list = share_level_config.get("share_name", []) for index in range(len(fund_name_example_list)): example_dict = { "fund name": fund_name_example_list[index], "share name": share_name_example_list[index], } for share_datapoint in share_datapoint_list: share_datapoint_values = share_datapoint_value_example[ share_datapoint ] if index < len(share_datapoint_values): example_dict[share_datapoint] = share_datapoint_values[index] example_list.append(example_dict) example_data = {"data": example_list} instructions.append(json.dumps(example_data, ensure_ascii=False, indent=4)) instructions.append("\n") instructions.append("\n") end_list = self.instructions_config.get("end", []) instructions.append("\n".join(end_list)) instructions.append("\n") if need_exclude and exclude_data is not None and isinstance(exclude_data, list): instructions.append("Please exclude below data from output:\n") instructions.append(json.dumps(exclude_data, ensure_ascii=False, indent=4)) instructions.append("\n") instructions.append("\n") instructions.append("Answer:\n") instructions_text = "".join(instructions) return instructions_text