dc-ml-emea-ar/core/data_extraction.py

1126 lines
52 KiB
Python
Raw Normal View History

import os
import json
import json_repair
import re
import fitz
import pandas as pd
from utils.gpt_utils import chat
from utils.pdf_util import PDFUtil
from utils.sql_query_util import query_document_fund_mapping, query_investment_by_provider
from utils.logger import logger
2024-12-02 23:16:56 +00:00
from utils.biz_utils import add_slash_to_text_as_regex, clean_text, get_most_similar_name, remove_abundant_data
class DataExtraction:
def __init__(
self,
doc_id: str,
pdf_file: str,
output_data_folder: str,
page_text_dict: dict,
datapoint_page_info: dict,
datapoints: list,
document_mapping_info_df: pd.DataFrame,
2024-09-19 21:29:26 +00:00
extract_way: str = "text",
output_image_folder: str = None,
) -> None:
self.doc_id = doc_id
self.pdf_file = pdf_file
if output_data_folder is None or len(output_data_folder) == 0:
output_data_folder = r"/data/emea_ar/output/extract_data/docs/"
os.makedirs(output_data_folder, exist_ok=True)
self.output_data_json_folder = os.path.join(output_data_folder, "json/")
os.makedirs(self.output_data_json_folder, exist_ok=True)
self.output_data_excel_folder = os.path.join(output_data_folder, "excel/")
os.makedirs(self.output_data_excel_folder, exist_ok=True)
if page_text_dict is None or len(page_text_dict.keys()) == 0:
self.page_text_dict = self.get_pdf_page_text_dict()
else:
self.page_text_dict = page_text_dict
if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
else:
self.document_mapping_info_df = document_mapping_info_df
self.provider_mapping_df = self.get_provider_mapping()
if len(self.provider_mapping_df) == 0:
self.provider_fund_name_list = []
else:
self.provider_fund_name_list = (
self.provider_mapping_df["FundName"].unique().tolist()
)
self.datapoint_page_info = datapoint_page_info
self.page_nums_with_datapoints = self.get_page_nums_from_datapoint_page_info()
self.datapoints = datapoints
self.instructions_config = self.get_instructions_config()
self.datapoint_level_config = self.get_datapoint_level()
self.datapoint_name_config = self.get_datapoint_name()
self.datapoint_reported_name_config, self.non_english_reported_name_config = \
self.get_datapoint_reported_name()
2024-09-19 21:29:26 +00:00
self.extract_way = extract_way
self.output_image_folder = output_image_folder
def get_datapoint_reported_name(self):
language_config_file = r"./configuration/language.json"
self.language_config = {}
with open(language_config_file, "r", encoding="utf-8") as file:
self.language_config = json.load(file)
self.language_id = self.document_mapping_info_df["Language"].iloc[0]
self.language = self.language_config.get(self.language_id, None)
datapoint_reported_name_config_file = r"./configuration/datapoint_reported_name.json"
all_datapoint_reported_name = {}
with open(datapoint_reported_name_config_file, "r", encoding="utf-8") as file:
all_datapoint_reported_name = json.load(file)
non_english_reported_name_config = {}
datapoint_reported_name_config = {}
common_language = "english"
for datapoint, language_reported_name in all_datapoint_reported_name.items():
reported_name_list = language_reported_name.get(common_language, [])
if self.language != "english":
reported_name_list.extend(language_reported_name.get(self.language, []))
non_english_reported_name_config[datapoint] = language_reported_name.get(self.language, [])
# remove duplicate reported name
reported_name_list = list(set(reported_name_list))
# sort the reported name
reported_name_list.sort()
datapoint_reported_name_config[datapoint] = reported_name_list
return datapoint_reported_name_config, non_english_reported_name_config
2024-09-19 21:29:26 +00:00
def get_provider_mapping(self):
if len(self.document_mapping_info_df) == 0:
return pd.DataFrame()
provider_id_list = (
self.document_mapping_info_df["ProviderId"].unique().tolist()
)
provider_mapping_list = []
for provider_id in provider_id_list:
provider_mapping_list.append(query_investment_by_provider(provider_id, rerun=False))
provider_mapping_df = pd.concat(provider_mapping_list)
provider_mapping_df = provider_mapping_df.drop_duplicates()
provider_mapping_df.reset_index(drop=True, inplace=True)
return provider_mapping_df
2024-09-19 21:29:26 +00:00
def get_pdf_image_base64(self, page_index: int) -> dict:
pdf_util = PDFUtil(self.pdf_file)
return pdf_util.extract_image_from_page(page_index=page_index,
output_folder=self.output_image_folder)
def get_instructions_config(self) -> dict:
instructions_config_file = r"./instructions/data_extraction_prompts_config.json"
with open(instructions_config_file, "r", encoding="utf-8") as f:
instructions_config = json.load(f)
return instructions_config
def get_datapoint_level(self) -> dict:
datapoint_level_file = r"./configuration/datapoint_level.json"
with open(datapoint_level_file, "r", encoding="utf-8") as f:
datapoint_level = json.load(f)
return datapoint_level
def get_datapoint_name(self) -> dict:
datapoint_name_file = r"./configuration/datapoint_name.json"
with open(datapoint_name_file, "r", encoding="utf-8") as f:
datapoint_name = json.load(f)
return datapoint_name
def get_pdf_page_text_dict(self) -> dict:
pdf_util = PDFUtil(self.pdf_file)
success, text, page_text_dict = pdf_util.extract_text()
return page_text_dict
def get_page_nums_from_datapoint_page_info(self) -> list:
page_nums_with_datapoints = []
for datapoint, page_nums in self.datapoint_page_info.items():
if datapoint == "doc_id":
continue
page_nums_with_datapoints.extend(page_nums)
page_nums_with_datapoints = list(set(page_nums_with_datapoints))
# sort the page numbers
page_nums_with_datapoints.sort()
return page_nums_with_datapoints
2024-09-19 21:29:26 +00:00
def extract_data(self) -> dict:
2024-09-19 21:29:26 +00:00
logger.info(f"Extracting data from document {self.doc_id}, extract way: {self.extract_way}")
if self.extract_way == "text":
2024-12-02 23:16:56 +00:00
data_list = self.extract_data_by_text()
2024-09-19 21:29:26 +00:00
elif self.extract_way == "image":
2024-12-02 23:16:56 +00:00
data_list = self.extract_data_by_image()
2024-09-19 21:29:26 +00:00
else:
2024-12-02 23:16:56 +00:00
data_list = self.extract_data_by_text()
data_list = remove_abundant_data(data_list)
self.output_data_to_file(data_list)
return data_list
2024-09-19 21:29:26 +00:00
def extract_data_by_text(self) -> dict:
"""
keys are
doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
"""
data_list = []
pdf_page_count = len(self.page_text_dict.keys())
handled_page_num_list = []
previous_page_num = -1
previous_page_datapoints = []
previous_page_fund_name = None
for page_num, page_text in self.page_text_dict.items():
if page_num in handled_page_num_list:
continue
page_datapoints = self.get_datapoints_by_page_num(page_num)
if len(page_datapoints) == 0:
continue
if previous_page_num == page_num - 1 and \
previous_page_datapoints == page_datapoints and \
previous_page_fund_name is not None:
# Transfer previous page fund name to be the pre-fix of page text
# The purpose is to get fund name if the first records without fund name
# example document: 431073795, page index 1727 to 1728
logger.info(f"Transfer previous page fund name: {previous_page_fund_name} to be the pre-fix of page text")
page_text = f"\nThe last fund name of previous PDF page: {previous_page_fund_name}\n{page_text}"
extract_data = self.extract_data_by_page(
page_num,
page_text,
page_datapoints,
need_exclude=False,
exclude_data=None,
previous_page_last_fund=previous_page_fund_name
)
data_list.append(extract_data)
page_data_list = extract_data.get("extract_data", {}).get("data", [])
if len(page_data_list) > 0:
previous_page_num = page_num
previous_page_fund_name = page_data_list[-1].get("fund_name", "")
previous_page_datapoints = page_datapoints
current_page_data_count = len(page_data_list)
if current_page_data_count > 0:
count = 1
# some pdf documents have multiple pages for the same data
# and the next page may without table header with data point keywords.
# the purpose is try to get data from the next page
current_text = page_text
while count < 3:
try:
next_page_num = page_num + count
if next_page_num >= pdf_page_count:
break
next_datapoints = page_datapoints
if next_page_num in self.page_nums_with_datapoints:
should_continue = False
next_datapoints = self.get_datapoints_by_page_num(next_page_num)
if len(next_datapoints) == 0:
should_continue = True
else:
for next_datapoint in next_datapoints:
if next_datapoint not in page_datapoints:
should_continue = True
break
next_datapoints.extend(page_datapoints)
# remove duplicate datapoints
next_datapoints = list(set(next_datapoints))
if not should_continue:
break
next_page_text = self.page_text_dict.get(next_page_num, "")
target_text = current_text + next_page_text
# try to get data by current page_datapoints
next_page_extract_data = self.extract_data_by_page(
next_page_num,
target_text,
next_datapoints,
need_exclude=True,
exclude_data=page_data_list,
previous_page_last_fund=previous_page_fund_name
)
next_page_data_list = next_page_extract_data.get(
"extract_data", {}
).get("data", [])
if next_page_data_list is not None and len(next_page_data_list) > 0:
for current_page_data in page_data_list:
if current_page_data in next_page_data_list:
next_page_data_list.remove(current_page_data)
next_page_extract_data["extract_data"][
"data"
] = next_page_data_list
data_list.append(next_page_extract_data)
handled_page_num_list.append(next_page_num)
exist_current_page_datapoint = False
for next_page_data in next_page_data_list:
for page_datapoint in page_datapoints:
if page_datapoint in list(next_page_data.keys()):
exist_current_page_datapoint = True
break
if exist_current_page_datapoint:
break
if not exist_current_page_datapoint:
break
else:
break
count += 1
except Exception as e:
logger.error(f"Error in extracting data from next page: {e}")
break
2024-12-02 23:16:56 +00:00
# self.output_data_to_file(data_list)
2024-09-19 21:29:26 +00:00
return data_list
def extract_data_by_image(self) -> dict:
"""
keys are
doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
"""
data_list = []
pdf_page_count = len(self.page_text_dict.keys())
handled_page_num_list = []
for page_num, page_text in self.page_text_dict.items():
if page_num in handled_page_num_list:
continue
page_datapoints = self.get_datapoints_by_page_num(page_num)
if len(page_datapoints) == 0:
continue
extract_data = self.extract_data_by_page_image(page_num=page_num,
page_datapoints=page_datapoints)
data_list.append(extract_data)
page_data_list = extract_data.get("extract_data", {}).get("data", [])
current_page_data_count = len(page_data_list)
if current_page_data_count > 0:
count = 1
while count < 3:
try:
next_page_num = page_num + count
if next_page_num >= pdf_page_count:
break
next_datapoints = page_datapoints
if next_page_num in self.page_nums_with_datapoints:
should_continue = False
next_datapoints = self.get_datapoints_by_page_num(next_page_num)
if len(next_datapoints) == 0:
should_continue = True
else:
for next_datapoint in next_datapoints:
if next_datapoint not in page_datapoints:
should_continue = True
break
next_datapoints.extend(page_datapoints)
# remove duplicate datapoints
next_datapoints = list(set(next_datapoints))
if not should_continue:
break
# try to get data by current page_datapoints
next_page_extract_data = self.extract_data_by_page_image(
page_num=next_page_num,
page_datapoints=next_datapoints
)
next_page_data_list = next_page_extract_data.get(
"extract_data", {}
).get("data", [])
if next_page_data_list is not None and len(next_page_data_list) > 0:
data_list.append(next_page_extract_data)
handled_page_num_list.append(next_page_num)
exist_current_page_datapoint = False
for next_page_data in next_page_data_list:
for page_datapoint in page_datapoints:
if page_datapoint in list(next_page_data.keys()):
exist_current_page_datapoint = True
break
if exist_current_page_datapoint:
break
if not exist_current_page_datapoint:
break
else:
break
count += 1
except Exception as e:
logger.error(f"Error in extracting data from next page: {e}")
break
2024-12-02 23:16:56 +00:00
# self.output_data_to_file(data_list)
2024-09-19 21:29:26 +00:00
return data_list
def output_data_to_file(self, data_list: list) -> None:
json_data_file = os.path.join(
self.output_data_json_folder, f"{self.doc_id}.json"
)
with open(json_data_file, "w", encoding="utf-8") as f:
json.dump(data_list, f, ensure_ascii=False, indent=4)
data_df = pd.DataFrame(data_list)
data_df.reset_index(drop=True, inplace=True)
excel_data_file = os.path.join(
self.output_data_excel_folder, f"{self.doc_id}.xlsx"
)
with pd.ExcelWriter(excel_data_file) as writer:
data_df.to_excel(writer, sheet_name="extract_data", index=False)
def extract_data_by_page(
self,
page_num: int,
page_text: str,
page_datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,
previous_page_last_fund: str = None) -> dict:
# If can't find numberic value, e.g. 1.25 or 3,88
# apply Vision ChatGPT to extract data
numeric_regex = r"\d+(\.|\,)\d+"
if not re.search(numeric_regex, page_text):
logger.info(f"Can't find numberic value in page {page_num}, apply Vision ChatGPT to extract data")
return self.extract_data_by_page_image(
page_num, page_datapoints, need_exclude, exclude_data)
else:
return self.extract_data_by_page_text(
page_num=page_num,
page_text=page_text,
page_datapoints=page_datapoints,
need_exclude=need_exclude,
exclude_data=exclude_data,
previous_page_last_fund=previous_page_last_fund
)
2024-09-19 21:29:26 +00:00
def extract_data_by_page_text(
self,
page_num: int,
page_text: str,
page_datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,
previous_page_last_fund: str = None
) -> dict:
"""
keys are
doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
"""
logger.info(f"Extracting data from page {page_num}")
instructions = self.get_instructions_by_datapoints(
page_text,
page_datapoints,
need_exclude,
exclude_data,
extract_way="text"
)
response, with_error = chat(
instructions, response_format={"type": "json_object"}
)
if with_error:
logger.error(f"Error in extracting tables from page")
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = page_text
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = {"data": []}
data_dict["extract_way"] = "text"
return data_dict
try:
data = json.loads(response)
except:
try:
# if occur error, perhaps the output length is over 4K tokens
# split the context to two parts and try to get data from the two parts
# Attention: after deploy ChatGPT4o 2024-08-16 version, the max token length is 16K,
# need not to split the context.
# data = self.chat_by_split_context(
# page_text, page_datapoints, need_exclude, exclude_data
# )
if len(data.get("data", [])) == 0:
data = json_repair.loads(response)
except:
data = {"data": []}
data = self.validate_data(data, previous_page_last_fund)
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = page_text
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = data
data_dict["extract_way"] = "text"
return data_dict
2024-09-19 21:29:26 +00:00
def extract_data_by_page_image(
self,
page_num: int,
page_datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,
2024-09-19 21:29:26 +00:00
) -> dict:
"""
keys are
doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
"""
logger.info(f"Extracting data from page {page_num}")
image_base64 = self.get_pdf_image_base64(page_num)
instructions = self.get_instructions_by_datapoints(
"",
page_datapoints,
need_exclude=need_exclude,
exclude_data=exclude_data,
extract_way="image"
2024-09-19 21:29:26 +00:00
)
response, with_error = chat(
instructions, response_format={"type": "json_object"}, image_base64=image_base64
)
if with_error:
logger.error(f"Error in extracting tables from page")
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = ""
2024-09-19 21:29:26 +00:00
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = {"data": []}
data_dict["extract_way"] = "image"
2024-09-19 21:29:26 +00:00
return data_dict
try:
data = json.loads(response)
except:
try:
data = json_repair.loads(response)
except:
data = {"data": []}
data = self.validate_data(data, None)
2024-09-19 21:29:26 +00:00
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = ""
2024-09-19 21:29:26 +00:00
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = data
data_dict["extract_way"] = "image"
2024-09-19 21:29:26 +00:00
return data_dict
def validate_data(self, extract_data_info: dict, previous_page_last_fund: str=None) -> dict:
"""
Validate data by the rules
1. Each data should be with fund name
2. For share level data, it should be with share name
"""
data_list = extract_data_info.get("data", [])
if len(data_list) == 0:
return extract_data_info
remove_list = []
for data in data_list:
fund_name = data.get("fund name", "").strip()
if fund_name == "":
remove_list.append(data)
# Clean fund name start
if previous_page_last_fund is not None and len(previous_page_last_fund) > 0:
previous_page_last_fund = previous_page_last_fund.strip()
if fund_name.startswith(previous_page_last_fund) and fund_name != previous_page_last_fund:
modified_fund_name = fund_name.replace(previous_page_last_fund, "").strip()
if len(modified_fund_name.split()) > 1:
fund_name = modified_fund_name
fund_name = self.get_fund_name(fund_name, "Fund")
fund_name = self.get_fund_name(fund_name, "Bond")
remove_list = ["Market Specific Equity Sub-Funds",
"International and Regional Equity Sub-Funds",
"Equity Sub-Funds"]
for remove_item in remove_list:
if fund_name.startswith(remove_item):
fund_name = fund_name.replace(remove_item, "").strip()
2024-12-02 23:16:56 +00:00
data["fund name"] = fund_name
# Clean fund name end
keys = list(data.keys())
for key in keys:
if self.datapoint_level_config.get(key, "") == "share_level":
if data.get("share name", "") == "":
is_share_name = self.check_fund_name_as_share(fund_name)
if not is_share_name:
remove_list.append(data)
break
else:
data["share name"] = fund_name
if data.get(key, "") == "":
data.pop(key)
for remove_data in remove_list:
if remove_data in data_list:
data_list.remove(remove_data)
# check performance_fee
for data in data_list:
performance_fee = data.get("performance_fee", None)
if performance_fee is not None:
try:
performance_fee = float(performance_fee)
if (performance_fee > 3 and performance_fee % 2.5 == 0) or \
performance_fee > 10:
data.pop("performance_fee")
except:
data.pop("performance_fee")
remove_list = []
for data in data_list:
keys = [key for key in list(data.keys())
if key not in ["fund name", "share name"]]
if len(keys) == 0:
remove_list.append(data)
for remove_data in remove_list:
if remove_data in data_list:
data_list.remove(remove_data)
# update "fund name" to be "fund_name"
# update "share name" to be "share_name"
new_data_list = []
multi_over_3_share_regex = r"([A-Z]{1,}\,\s){3,}"
exist_multi_over_3_share = False
for data in data_list:
fund_name = data.get("fund name", "").strip()
if len(fund_name) == 0:
continue
raw_share_name = data.get("share name", "")
if not exist_multi_over_3_share:
multi_over_3_share_search = re.search(multi_over_3_share_regex, raw_share_name)
if multi_over_3_share_search is not None:
exist_multi_over_3_share = True
if exist_multi_over_3_share:
share_name_list = self.split_multi_share_name(raw_share_name)
else:
share_name_list = [raw_share_name]
if len(share_name_list) > 0:
for share_name in share_name_list:
new_data = {}
new_data["fund_name"] = fund_name
if share_name != "":
new_data["share_name"] = share_name
ter = data.get("ter", None)
if ter is not None:
new_data["ter"] = ter
performance_fee = data.get("performance fees", None)
if performance_fee is not None:
new_data["performance_fee"] = performance_fee
for key, value in data.items():
if key not in ["fund name", "share name", "ter", "performance fees"]:
new_data[key] = value
new_data_list.append(new_data)
extract_data_info["data"] = new_data_list
return extract_data_info
def split_multi_share_name(self, raw_share_name: str) -> list:
"""
Some document, e.g. 481482392
Exist multi share name as table header, e.g. "Class A, B, E, M, N, P, R, U"
For this case, need split the share name to be ["Class A", "Class B", "Class E",
"Class M", "Class N", "Class P", "Class R", "Class U"]
"""
multi_over_2_share_regex = r"([A-Z]{1,}\,\s){2,}"
multi_over_2_share_search = re.search(multi_over_2_share_regex, raw_share_name)
share_name_list = [raw_share_name]
if multi_over_2_share_search is not None:
multi_share_splits = [share_name.strip() for share_name in raw_share_name.split(",")
if len(share_name.strip()) > 0]
first_share_name = multi_share_splits[0]
first_share_name_split = first_share_name.split()
share_name_prefix = None
if len(first_share_name_split) == 2:
share_name_prefix = first_share_name_split[0]
if share_name_prefix is not None and len(share_name_prefix) > 0:
new_share_name_list = []
for split in multi_share_splits:
if split == first_share_name:
new_share_name_list.append(split)
else:
new_share_name_list.append(f"{share_name_prefix} {split}")
share_name_list = new_share_name_list
else:
share_name_list = multi_share_splits
else:
share_name_list = multi_share_splits
return share_name_list
2024-12-02 23:16:56 +00:00
def get_fund_name(self, fund_name: str, fund_feature: str):
if not fund_name.endswith(fund_feature):
return fund_name
# to avoid split funds to fund s
fund_feature = fund_feature + " "
2024-12-02 23:16:56 +00:00
fund_name_split = fund_name.split(fund_feature)
if len(fund_name_split) > 1:
last_fund = fund_name_split[-1].strip()
if len(last_fund) == 0:
last_fund = fund_name_split[-2].strip()
fund_name = f"{last_fund} {fund_feature}"
return fund_name
def check_fund_name_as_share(self, fund_name: str) -> bool:
"""
Check if the fund name is the same as share name
"""
if len(fund_name) == 0 == 0:
return False
share_name_list = self.document_mapping_info_df["ShareClassName"].unique().tolist()
if len(share_name_list) == 0:
return False
max_similarity_name, max_similarity = get_most_similar_name(
text=fund_name,
name_list=share_name_list,
share_name=None,
fund_name=None,
matching_type="share",
process_cache=None)
if max_similarity >= 0.8:
return True
return False
def get_datapoints_by_page_num(self, page_num: int) -> list:
datapoints = []
for datapoint in self.datapoints:
if page_num in self.datapoint_page_info[datapoint]:
datapoints.append(datapoint)
return datapoints
def get_instructions_by_datapoints(
self,
page_text: str,
datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,
extract_way: str = "text",
) -> str:
"""
Get instructions to extract data from the page by the datapoints
Below is the instructions sections:
summary: string
reported_name by datapoints: dict
data_business_features: dict
common: list
investment_level by datapoints: dict
data_value_range by datapoints: dict
special_rule by datapoints: dict
special_cases: dict
common: list
title
contents
special_case by datapoints: list
title
contents
output_requirement
common: list
fund_level: list
share_level: dict
fund_name: list
share_name: list
ogc_value: list
ter_value: list
performance_fee_value: list
end
"""
2024-09-19 21:29:26 +00:00
instructions = []
if extract_way == "text":
2024-09-19 21:29:26 +00:00
instructions = [f"Context:\n{page_text}\n\nInstructions:\n"]
datapoint_name_list = []
for datapoint in datapoints:
datapoint_name = self.datapoint_name_config.get(datapoint, "")
datapoint_name_list.append(datapoint_name)
if extract_way == "text":
2024-09-19 21:29:26 +00:00
summary = self.instructions_config.get("summary", "\n")
elif extract_way == "image":
2024-09-19 21:29:26 +00:00
summary = self.instructions_config.get("summary_image", "\n")
else:
summary = self.instructions_config.get("summary", "\n")
instructions.append(summary.format(", ".join(datapoint_name_list)))
instructions.append("\n")
2024-09-19 21:29:26 +00:00
if extract_way == "image":
2024-09-19 21:29:26 +00:00
image_features = self.instructions_config.get("image_features", [])
instructions.extend(image_features)
instructions.append("\n")
instructions.append("Datapoints Reported name:\n")
instructions.append("Please look for relevant reported names and similar variations in the context.\n")
reported_name_info_in_instructions = self.instructions_config.get("reported_name", {})
for datapoint in datapoints:
reported_name_list = self.datapoint_reported_name_config.get(datapoint, [])
if len(reported_name_list) == 0:
reported_name = reported_name_info_in_instructions.get(datapoint, "")
else:
joined_reported_name = ", ".join(reported_name_list)
datapoint_name = datapoint
if datapoint_name == "performance_fee":
datapoint_name = "performance fees"
else:
datapoint_name = datapoint_name.upper()
reported_name = f"The {datapoint_name} reported name could be:\n{joined_reported_name}"
instructions.append(reported_name)
instructions.append("\n")
instructions.append("\n")
if self.language != "english":
"""
"multilingual_reported_name": {
"describe": "Please be careful to extract relevant data by different reported names from multilingual Context.",
"regular_example_template": "{datapoint} Example {number}:\nLanguage: {language}\n---Context Start-----\n{fund_name}\n{share_name}\n{reported_name}\n{value}\n---Context End-----\nAnswer: {answer}",
"special_example_template_none": "{datapoint} Example {number}:\nLanguage: {language}\nValue is belong to \"-, *, **, N/A, N/A%, N/A %, NONE\", ignore it\n---Context Start-----\n{fund_name}\n{share_name}\n{reported_name} 2)\n-\n---Context End-----\nAnswer: {answer}",
"value_examples": ["1,98", "3.25", "2.16", "1,73", "4,53"]
"fund_example": "Fund 1",
"share_example": "Share 1"
}
"""
multilingual_reported_name_config = self.instructions_config.get("multilingual_reported_name", {})
describe = multilingual_reported_name_config.get("describe", "")
regular_example_template = multilingual_reported_name_config.get("regular_example_template", "")
special_example_template_none = multilingual_reported_name_config.get("special_example_template_none", "")
value_examples = multilingual_reported_name_config.get("value_examples", [])
fund_example = multilingual_reported_name_config.get("fund_example", "")
share_example = multilingual_reported_name_config.get("share_example", "")
instructions.append("Multilingual reported name:\n")
instructions.append(f"{describe}\n")
# set language the first char to be upper
language = self.language[0].upper() + self.language[1:]
for datapoint in datapoints:
mul_reported_name_list = self.non_english_reported_name_config.get(datapoint, [])
# shuffle the reported name list
mul_reported_name_list = list(set(mul_reported_name_list))
if len(mul_reported_name_list) == 0:
continue
datapoint_name = datapoint
if datapoint_name == "performance_fee":
datapoint_name = "performance fees"
else:
datapoint_name = datapoint_name.upper()
example_count = 1
none_value_example_count = 0
for mul_reported_name in mul_reported_name_list:
if datapoint in ["ter", "performance_fee"] and example_count == 3:
break
value = value_examples[example_count % len(value_examples)]
answer = {"fund name": fund_example,
"share name": share_example,
datapoint: float(value.replace(",", "."))}
# transfer answer to string
answer = json.dumps(answer, ensure_ascii=False)
example = regular_example_template.format(
datapoint=datapoint_name,
number=example_count,
language=language,
fund_name=fund_example,
share_name=share_example,
reported_name=mul_reported_name,
value=value,
answer=answer,
)
instructions.append(example)
instructions.append("\n")
instructions.append("\n")
example_count += 1
if len(mul_reported_name.split()) > 1:
if none_value_example_count != 2:
none_value_example = special_example_template_none.format(
datapoint=datapoint_name,
number=example_count,
language=language,
fund_name=fund_example,
share_name=share_example,
reported_name=mul_reported_name,
answer = json.dumps({}, ensure_ascii=False)
)
instructions.append(none_value_example)
instructions.append("\n")
instructions.append("\n")
example_count += 1
none_value_example_count += 1
instructions.append("\n")
instructions.append("Data business features:\n")
data_business_features = self.instructions_config.get(
"data_business_features", {}
)
common = "\n".join(data_business_features.get("common", []))
instructions.append(common)
instructions.append("\n")
instructions.append("Datapoints investment level:\n")
investment_level_info = data_business_features.get("investment_level", {})
for datapoint in datapoints:
investment_level = investment_level_info.get(datapoint, "")
instructions.append(investment_level)
instructions.append("\n")
instructions.append("\n")
instructions.append("Datapoints value range:\n")
data_value_range_info = data_business_features.get("data_value_range", {})
for datapoint in datapoints:
data_value_range = data_value_range_info.get(datapoint, "")
instructions.append(data_value_range)
instructions.append("\n")
instructions.append("\n")
special_rule_info = data_business_features.get("special_rule", {})
with_special_rule_title = False
for datapoint in datapoints:
special_rule_list = special_rule_info.get(datapoint, [])
if len(special_rule_list) > 0:
if not with_special_rule_title:
instructions.append("Special rule:\n")
with_special_rule_title = True
special_rule = "\n".join(special_rule_list)
instructions.append(special_rule)
instructions.append("\n\n")
instructions.append("\n")
instructions.append("Special cases:\n")
special_cases = self.instructions_config.get("special_cases", {})
special_cases_common_list = special_cases.get("common", [])
special_cases_number = 1
for special_cases_common in special_cases_common_list:
title = special_cases_common.get("title", "")
title = f"{special_cases_number}. {title} "
special_cases_number += 1
instructions.append(title)
instructions.append("\n")
contents_list = special_cases_common.get("contents", [])
contents = "\n".join(contents_list)
instructions.append(contents)
instructions.append("\n\n")
for datapoint in datapoints:
special_case_list = special_cases.get(datapoint, [])
for special_case in special_case_list:
title = special_case.get("title", "")
title = f"{special_cases_number}. {title} "
special_cases_number += 1
instructions.append(title)
instructions.append("\n")
contents_list = special_case.get("contents", [])
contents = "\n".join(contents_list)
instructions.append(contents)
instructions.append("\n\n")
instructions.append("\n")
# extreme_complex_config_list = special_cases.get("extreme_complex", [])
# if len(extreme_complex_config_list) > 0:
# for extreme_complex_config in extreme_complex_config_list:
# regex = extreme_complex_config.get("regex", "")
# if len(regex) == 0:
# continue
# search = re.search(regex, page_text)
# if search is not None:
# title = extreme_complex_config.get("title", "")
# title = f"{special_cases_number}. {title} "
# special_cases_number += 1
# instructions.append(title)
# instructions.append("\n")
# contents_list = extreme_complex_config.get("contents", [])
# contents = "\n".join(contents_list)
# instructions.append(contents)
# instructions.append("\n\n")
instructions.append("Output requirement:\n")
output_requirement = self.instructions_config.get("output_requirement", {})
output_requirement_common_list = output_requirement.get("common", [])
instructions.append("\n".join(output_requirement_common_list))
instructions.append("\n")
share_datapoint_value_example = {}
share_level_config = output_requirement.get("share_level", {})
example_list = []
dp_reported_name_config = output_requirement.get("dp_reported_name", {})
dp_reported_name = {}
for datapoint in datapoints:
investment_level = self.datapoint_level_config.get(datapoint, "")
if investment_level == "fund_level":
fund_level_example_list = output_requirement.get("fund_level", [])
for example in fund_level_example_list:
try:
sub_example_list = json.loads(example)
except:
sub_example_list = json_repair.loads(example)
example_list.extend(sub_example_list)
elif investment_level == "share_level":
share_datapoint_value_example[datapoint] = share_level_config.get(
f"{datapoint}_value", []
)
dp_reported_name[datapoint] = dp_reported_name_config.get(datapoint, "")
share_datapoint_list = list(share_datapoint_value_example.keys())
instructions.append(f"Example:\n")
if len(share_datapoint_list) > 0:
fund_name_example_list = share_level_config.get("fund_name", [])
share_name_example_list = share_level_config.get("share_name", [])
for index in range(len(fund_name_example_list)):
example_dict = {
"fund name": fund_name_example_list[index],
"share name": share_name_example_list[index],
}
for share_datapoint in share_datapoint_list:
share_datapoint_values = share_datapoint_value_example[
share_datapoint
]
if index < len(share_datapoint_values):
example_dict[share_datapoint] = share_datapoint_values[index]
example_list.append(example_dict)
example_data = {"data": example_list, "dp_reported_name": dp_reported_name}
instructions.append(json.dumps(example_data, ensure_ascii=False, indent=4))
instructions.append("\n")
instructions.append("\n")
end_list = self.instructions_config.get("end", [])
instructions.append("\n".join(end_list))
instructions.append("\n")
if need_exclude and exclude_data is not None and isinstance(exclude_data, list):
instructions.append("Please exclude below data from output:\n")
instructions.append(json.dumps(exclude_data, ensure_ascii=False, indent=4))
instructions.append("\n")
instructions.append("\n")
instructions.append("Answer:\n")
instructions_text = "".join(instructions)
return instructions_text
# def chat_by_split_context(self,
# page_text: str,
# page_datapoints: list,
# need_exclude: bool,
# exclude_data: list) -> list:
# """
# If occur error, split the context to two parts and try to get data from the two parts
# Relevant document: 503194284, page index 147
# """
# try:
# logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
# split_context = re.split(r"\n", page_text)
# split_context = [text.strip() for text in split_context
# if len(text.strip()) > 0]
# if len(split_context) < 10:
# return {"data": []}
# split_context_len = len(split_context)
# top_10_context = split_context[:10]
# rest_context = split_context[10:]
# header = "\n".join(top_10_context)
# half_len = split_context_len // 2
# # the member of half_len should not start with number
# # reverse iterate the list by half_len
# half_len_list = [i for i in range(half_len)]
# fund_name_line = ""
# half_line = rest_context[half_len].strip()
# max_similarity_fund_name, max_similarity = get_most_similar_name(
# half_line, self.provider_fund_name_list, matching_type="fund"
# )
# if max_similarity < 0.2:
# # get the fund name line text from the first half
# for index in reversed(half_len_list):
# line_text = rest_context[index].strip()
# if len(line_text) == 0:
# continue
# line_text_split = line_text.split()
# if len(line_text_split) < 3:
# continue
# first_word = line_text_split[0]
# if first_word.lower() == "class":
# continue
# max_similarity_fund_name, max_similarity = get_most_similar_name(
# line_text, self.provider_fund_name_list, matching_type="fund"
# )
# if max_similarity >= 0.2:
# fund_name_line = line_text
# break
# else:
# fund_name_line = half_line
# half_len += 1
# if fund_name_line == "":
# return {"data": []}
# logger.info(f"Split first part from 0 to {half_len}")
# split_first_part = "\n".join(split_context[:half_len])
# first_part = '\n'.join(split_first_part)
# first_instructions = self.get_instructions_by_datapoints(
# first_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
# )
# response, with_error = chat(
# first_instructions, response_format={"type": "json_object"}
# )
# first_part_data = {"data": []}
# if not with_error:
# try:
# first_part_data = json.loads(response)
# except:
# first_part_data = json_repair.loads(response)
# logger.info(f"Split second part from {half_len} to {split_context_len}")
# split_second_part = "\n".join(split_context[half_len:])
# second_part = header + "\n" + fund_name_line + "\n" + split_second_part
# second_instructions = self.get_instructions_by_datapoints(
# second_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
# )
# response, with_error = chat(
# second_instructions, response_format={"type": "json_object"}
# )
# second_part_data = {"data": []}
# if not with_error:
# try:
# second_part_data = json.loads(response)
# except:
# second_part_data = json_repair.loads(response)
# first_part_data_list = first_part_data.get("data", [])
# logger.info(f"First part data count: {len(first_part_data_list)}")
# second_part_data_list = second_part_data.get("data", [])
# logger.info(f"Second part data count: {len(second_part_data_list)}")
# for first_data in first_part_data_list:
# if first_data in second_part_data_list:
# second_part_data_list.remove(first_data)
# else:
# # if the first part data is with same fund name and share name,
# # remove the second part data
# first_data_dp = [key for key in list(first_data.keys())
# if key not in ["fund name", "share name"]]
# # order the data points
# first_data_dp.sort()
# first_fund_name = first_data.get("fund name", "")
# first_share_name = first_data.get("share name", "")
# if len(first_fund_name) > 0 and len(first_share_name) > 0:
# remove_second_list = []
# for second_data in second_part_data_list:
# second_fund_name = second_data.get("fund name", "")
# second_share_name = second_data.get("share name", "")
# if first_fund_name == second_fund_name and \
# first_share_name == second_share_name:
# second_data_dp = [key for key in list(second_data.keys())
# if key not in ["fund name", "share name"]]
# second_data_dp.sort()
# if first_data_dp == second_data_dp:
# remove_second_list.append(second_data)
# for remove_second in remove_second_list:
# if remove_second in second_part_data_list:
# second_part_data_list.remove(remove_second)
# data_list = first_part_data_list + second_part_data_list
# extract_data = {"data": data_list}
# return extract_data
# except Exception as e:
# logger.error(f"Error in split context: {e}")
# return {"data": []}