# dc-ml-emea-ar/core/data_extraction.py

import os
import json
import json_repair
import re
import fitz
import pandas as pd
from utils.gpt_utils import chat
from utils.pdf_util import PDFUtil
from utils.sql_query_util import query_document_fund_mapping
from utils.logger import logger
from utils.biz_utils import add_slash_to_text_as_regex, clean_text
class DataExtraction:
    """Drives LLM-based datapoint extraction for one annual-report PDF.

    Gathers per-page text, the datapoint->pages routing info, the prompt
    configuration files, and the document/fund mapping, then builds prompts
    per page and persists the extracted results as JSON and Excel.
    """

    def __init__(
        self,
        doc_id: str,
        pdf_file: str,
        output_data_folder: str,
        page_text_dict: dict,
        datapoint_page_info: dict,
        document_mapping_info_df: pd.DataFrame
    ) -> None:
        self.doc_id = doc_id
        self.pdf_file = pdf_file
        # Fall back to the default output location when none was supplied.
        if not output_data_folder:
            output_data_folder = r"/data/emea_ar/output/extract_data/docs/"
        os.makedirs(output_data_folder, exist_ok=True)
        self.output_data_json_folder = os.path.join(output_data_folder, "json/")
        os.makedirs(self.output_data_json_folder, exist_ok=True)
        self.output_data_excel_folder = os.path.join(output_data_folder, "excel/")
        os.makedirs(self.output_data_excel_folder, exist_ok=True)
        # Use the caller-provided page text when available; otherwise
        # re-extract it from the PDF.
        if page_text_dict:
            self.page_text_dict = page_text_dict
        else:
            self.page_text_dict = self.get_pdf_page_text_dict()
        # Explicit None/len check: a DataFrame does not support plain truthiness.
        if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
            self.document_mapping_info_df = query_document_fund_mapping(doc_id)
        else:
            self.document_mapping_info_df = document_mapping_info_df
        self.datapoint_page_info = datapoint_page_info
        self.datapoints = self.get_datapoints_from_datapoint_page_info()
        self.instructions_config = self.get_instructions_config()
        self.datapoint_level_config = self.get_datapoint_level()
        self.datapoint_name_config = self.get_datapoint_name()
def get_instructions_config(self) -> dict:
instructions_config_file = r"./instructions/data_extraction_prompts_config.json"
with open(instructions_config_file, "r", encoding="utf-8") as f:
instructions_config = json.load(f)
return instructions_config
def get_datapoint_level(self) -> dict:
datapoint_level_file = r"./configuration/datapoint_level.json"
with open(datapoint_level_file, "r", encoding="utf-8") as f:
datapoint_level = json.load(f)
return datapoint_level
def get_datapoint_name(self) -> dict:
datapoint_name_file = r"./configuration/datapoint_name.json"
with open(datapoint_name_file, "r", encoding="utf-8") as f:
datapoint_name = json.load(f)
return datapoint_name
def get_pdf_page_text_dict(self) -> dict:
pdf_util = PDFUtil(self.pdf_file)
success, text, page_text_dict = pdf_util.extract_text()
return page_text_dict
def get_datapoints_from_datapoint_page_info(self) -> list:
datapoints = list(self.datapoint_page_info.keys())
if "doc_id" in datapoints:
datapoints.remove("doc_id")
return datapoints
def extract_data(self) -> dict:
"""
keys are
doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
"""
data_list = []
for page_num, page_text in self.page_text_dict.items():
page_datapoints = self.get_datapoints_by_page_num(page_num)
if len(page_datapoints) == 0:
continue
instructions = self.get_instructions_by_datapoints(page_text, page_datapoints)
response, with_error = chat(instructions)
if with_error:
logger.error(f"Error in extracting tables from page")
return ""
try:
data = json.loads(response)
except:
try:
data = json_repair.loads(response)
except:
data = {}
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = page_text
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["data"] = data
data_list.append(data_dict)
json_data_file = os.path.join(self.output_data_json_folder, f"{self.doc_id}.json")
with open(json_data_file, "w", encoding="utf-8") as f:
json.dump(data_list, f, ensure_ascii=False, indent=4)
data_df = pd.DataFrame(data_list)
data_df.reset_index(drop=True, inplace=True)
excel_data_file = os.path.join(self.output_data_excel_folder, f"{self.doc_id}.xlsx")
with pd.ExcelWriter(excel_data_file) as writer:
data_df.to_excel(writer, sheet_name="extract_data", index=False)
return data_list
def get_datapoints_by_page_num(self, page_num: int) -> list:
datapoints = []
for datapoint in self.datapoints:
if page_num in self.datapoint_page_info[datapoint]:
datapoints.append(datapoint)
return datapoints
    def get_instructions_by_datapoints(self, page_text: str, datapoints: list) -> str:
        """
        Get instructions to extract data from the page by the datapoints.

        The prompt is assembled as an ordered list of text fragments and
        joined at the end; the fragment order defines the prompt layout.

        Below is the instructions sections:
            summary: string
            reported_name by datapoints: dict
            data_business_features: dict
                common: list
                investment_level by datapoints: dict
                data_value_range by datapoints: dict
                special_rule by datapoints: dict
            special_cases: dict
                common: list
                    title
                    contents
                special_case by datapoints: list
                    title
                    contents
            output_requirement
                common: list
                fund_level: list
                share_level: dict
                    fund_name: list
                    share_name: list
                    ogc_value: list
                    ter_value: list
                    performance_fee_value: list
            end
        """
        # The page text goes first, followed by the instruction sections.
        instructions = [f"Context:\n{page_text}\n\nInstructions:\n"]
        # Resolve display names for the requested datapoints; missing entries
        # become empty strings rather than raising.
        datapoint_name_list = []
        for datapoint in datapoints:
            datapoint_name = self.datapoint_name_config.get(datapoint, "")
            datapoint_name_list.append(datapoint_name)
        # Summary template is expected to contain one positional {} slot for
        # the comma-joined datapoint names.
        summary = self.instructions_config.get("summary", "\n")
        instructions.append(summary.format(', '.join(datapoint_name_list)))
        instructions.append("\n")
        # --- Reported-name section: one line per datapoint. ---
        instructions.append("Datapoints Reported name:\n")
        reported_name_info = self.instructions_config.get("reported_name", {})
        for datapoint in datapoints:
            reported_name = reported_name_info.get(datapoint, "")
            instructions.append(reported_name)
            instructions.append("\n")
        instructions.append("\n")
        # --- Business-features section: shared notes, then per-datapoint
        # investment level and value range. ---
        instructions.append("Data business features:\n")
        data_business_features = self.instructions_config.get("data_business_features", {})
        common = '\n'.join(data_business_features.get("common", []))
        instructions.append(common)
        instructions.append("\n")
        instructions.append("Datapoints investment level:\n")
        investment_level_info = data_business_features.get("investment_level", {})
        for datapoint in datapoints:
            investment_level = investment_level_info.get(datapoint, "")
            instructions.append(investment_level)
            instructions.append("\n")
        instructions.append("\n")
        instructions.append("Datapoints value range:\n")
        data_value_range_info = data_business_features.get("data_value_range", {})
        for datapoint in datapoints:
            data_value_range = data_value_range_info.get(datapoint, "")
            instructions.append(data_value_range)
            instructions.append("\n")
        instructions.append("\n")
        # --- Special rules: the section title is emitted lazily, only once
        # the first datapoint with rules is found. ---
        special_rule_info = data_business_features.get("special_rule", {})
        with_special_rule_title = False
        for datapoint in datapoints:
            special_rule_list = special_rule_info.get(datapoint, [])
            if len(special_rule_list) > 0:
                if not with_special_rule_title:
                    instructions.append("Special rule:\n")
                    with_special_rule_title = True
                special_rule = '\n'.join(special_rule_list)
                instructions.append(special_rule)
                instructions.append("\n\n")
        instructions.append("\n")
        # --- Special cases: shared cases first, then per-datapoint cases;
        # each case is a title followed by its joined contents. ---
        instructions.append("Special cases:\n")
        special_cases = self.instructions_config.get("special_cases", {})
        special_cases_common_list = special_cases.get("common", [])
        for special_cases_common in special_cases_common_list:
            title = special_cases_common.get("title", "")
            instructions.append(title)
            instructions.append("\n")
            contents_list = special_cases_common.get("contents", [])
            contents = '\n'.join(contents_list)
            instructions.append(contents)
            instructions.append("\n\n")
        for datapoint in datapoints:
            special_case_list = special_cases.get(datapoint, [])
            for special_case in special_case_list:
                title = special_case.get("title", "")
                instructions.append(title)
                instructions.append("\n")
                contents_list = special_case.get("contents", [])
                contents = '\n'.join(contents_list)
                instructions.append(contents)
                instructions.append("\n\n")
        instructions.append("\n")
        # --- Output requirement: shared rules, then worked examples. ---
        instructions.append("Output requirement:\n")
        output_requirement = self.instructions_config.get("output_requirement", {})
        output_requirement_common_list = output_requirement.get("common", [])
        instructions.append("\n".join(output_requirement_common_list))
        instructions.append("\n")
        # Fund-level datapoints emit their examples immediately; share-level
        # datapoints are collected so their values can be merged into
        # combined fund/share example rows below.
        share_datapoint_value_example = {}
        share_level_config = output_requirement.get("share_level", {})
        for datapoint in datapoints:
            investment_level = self.datapoint_level_config.get(datapoint, "")
            if investment_level == "fund_level":
                fund_level_example_list = output_requirement.get("fund_level", [])
                for example in fund_level_example_list:
                    instructions.append(example)
                    instructions.append("\n")
                instructions.append("\n")
            elif investment_level == "share_level":
                share_datapoint_value_example[datapoint] = share_level_config.get(f"{datapoint}_value", [])
        share_datapoint_list = list(share_datapoint_value_example.keys())
        if len(share_datapoint_list) > 0:
            # NOTE(review): fund_name and share_name example lists are assumed
            # to be the same length — a shorter share_name list would raise
            # IndexError here; confirm against the config file.
            fund_name_example_list = share_level_config.get("fund_name", [])
            share_name_example_list = share_level_config.get("share_name", [])
            for index in range(len(fund_name_example_list)):
                example_dict = {"fund name": fund_name_example_list[index],
                                "share name": share_name_example_list[index]}
                for share_datapoint in share_datapoint_list:
                    share_datapoint_values = share_datapoint_value_example[share_datapoint]
                    # Datapoints with fewer example values than names are
                    # simply omitted from that example row.
                    if index < len(share_datapoint_values):
                        example_dict[share_datapoint] = share_datapoint_values[index]
                instructions.append(f"Example {index + 1}:\n")
                instructions.append(json.dumps(example_dict, ensure_ascii=False))
                instructions.append("\n")
            instructions.append("\n")
        # --- Closing section and answer marker. ---
        end_list = self.instructions_config.get("end", [])
        instructions.append('\n'.join(end_list))
        instructions.append("\n")
        instructions.append("Answer:\n")
        instructions_text = ''.join(instructions)
        return instructions_text