dc-ml-emea-ar/core/data_extraction.py


import os
import json
import json_repair
import re
import fitz
import pandas as pd
from utils.gpt_utils import chat
from utils.pdf_util import PDFUtil
from utils.sql_query_util import query_document_fund_mapping, query_investment_by_provider
from utils.logger import logger
from utils.biz_utils import add_slash_to_text_as_regex, clean_text, get_most_similar_name
class DataExtraction:
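    """
    Extract fund- and share-level datapoints (e.g. TER, OGC, performance fees) from a PDF document.

    For each page flagged in `datapoint_page_info`, the class builds prompt instructions from the
    configuration files under ./configuration and ./instructions, sends the page text (or the page
    rendered as an image) to the chat model, validates the returned JSON, and writes the results
    to JSON and Excel files under the output folders.
    """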
def __init__(
self,
doc_id: str,
pdf_file: str,
output_data_folder: str,
page_text_dict: dict,
datapoint_page_info: dict,
datapoints: list,
document_mapping_info_df: pd.DataFrame,
extract_way: str = "text",
output_image_folder: str = None,
) -> None:
self.doc_id = doc_id
self.pdf_file = pdf_file
if output_data_folder is None or len(output_data_folder) == 0:
output_data_folder = r"/data/emea_ar/output/extract_data/docs/"
os.makedirs(output_data_folder, exist_ok=True)
self.output_data_json_folder = os.path.join(output_data_folder, "json/")
os.makedirs(self.output_data_json_folder, exist_ok=True)
self.output_data_excel_folder = os.path.join(output_data_folder, "excel/")
os.makedirs(self.output_data_excel_folder, exist_ok=True)
if page_text_dict is None or len(page_text_dict.keys()) == 0:
self.page_text_dict = self.get_pdf_page_text_dict()
else:
self.page_text_dict = page_text_dict
if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
else:
self.document_mapping_info_df = document_mapping_info_df
self.provider_mapping_df = self.get_provider_mapping()
if len(self.provider_mapping_df) == 0:
self.provider_fund_name_list = []
else:
self.provider_fund_name_list = (
self.provider_mapping_df["FundName"].unique().tolist()
)
self.datapoint_page_info = datapoint_page_info
self.page_nums_with_datapoints = self.get_page_nums_from_datapoint_page_info()
self.datapoints = datapoints
self.instructions_config = self.get_instructions_config()
self.datapoint_level_config = self.get_datapoint_level()
self.datapoint_name_config = self.get_datapoint_name()
self.datapoint_reported_name_config, self.non_english_reported_name_config = \
self.get_datapoint_reported_name()
self.extract_way = extract_way
self.output_image_folder = output_image_folder
def get_datapoint_reported_name(self):
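        """
        Build per-datapoint lists of reported names from ./configuration/datapoint_reported_name.json.

        English names are always included; if the document language (looked up via
        ./configuration/language.json) is not English, the names for that language are added as well.
        Returns (datapoint_reported_name_config, non_english_reported_name_config).
        """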
language_config_file = r"./configuration/language.json"
self.language_config = {}
with open(language_config_file, "r", encoding="utf-8") as file:
self.language_config = json.load(file)
self.language_id = self.document_mapping_info_df["Language"].iloc[0]
self.language = self.language_config.get(self.language_id, None)
datapoint_reported_name_config_file = r"./configuration/datapoint_reported_name.json"
all_datapoint_reported_name = {}
with open(datapoint_reported_name_config_file, "r", encoding="utf-8") as file:
all_datapoint_reported_name = json.load(file)
non_english_reported_name_config = {}
datapoint_reported_name_config = {}
common_language = "english"
for datapoint, language_reported_name in all_datapoint_reported_name.items():
            # copy so that extending below does not mutate the loaded configuration
            reported_name_list = list(language_reported_name.get(common_language, []))
if self.language != "english":
reported_name_list.extend(language_reported_name.get(self.language, []))
non_english_reported_name_config[datapoint] = language_reported_name.get(self.language, [])
# remove duplicate reported name
reported_name_list = list(set(reported_name_list))
# sort the reported name
reported_name_list.sort()
datapoint_reported_name_config[datapoint] = reported_name_list
return datapoint_reported_name_config, non_english_reported_name_config
def get_provider_mapping(self):
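        """
        Query the investment mapping for every ProviderId referenced in the document mapping
        and return the concatenated, de-duplicated DataFrame.
        """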
if len(self.document_mapping_info_df) == 0:
return pd.DataFrame()
provider_id_list = (
self.document_mapping_info_df["ProviderId"].unique().tolist()
)
provider_mapping_list = []
for provider_id in provider_id_list:
provider_mapping_list.append(query_investment_by_provider(provider_id, rerun=False))
provider_mapping_df = pd.concat(provider_mapping_list)
provider_mapping_df = provider_mapping_df.drop_duplicates()
provider_mapping_df.reset_index(drop=True, inplace=True)
return provider_mapping_df
def get_pdf_image_base64(self, page_index: int) -> dict:
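        """Render the given PDF page to an image via PDFUtil and return its base64 payload for the vision model."""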
pdf_util = PDFUtil(self.pdf_file)
return pdf_util.extract_image_from_page(page_index=page_index,
output_folder=self.output_image_folder)
def get_instructions_config(self) -> dict:
instructions_config_file = r"./instructions/data_extraction_prompts_config.json"
with open(instructions_config_file, "r", encoding="utf-8") as f:
instructions_config = json.load(f)
return instructions_config
def get_datapoint_level(self) -> dict:
datapoint_level_file = r"./configuration/datapoint_level.json"
with open(datapoint_level_file, "r", encoding="utf-8") as f:
datapoint_level = json.load(f)
return datapoint_level
def get_datapoint_name(self) -> dict:
datapoint_name_file = r"./configuration/datapoint_name.json"
with open(datapoint_name_file, "r", encoding="utf-8") as f:
datapoint_name = json.load(f)
return datapoint_name
def get_pdf_page_text_dict(self) -> dict:
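        """Extract the text of every page via PDFUtil and return it as a {page_num: page_text} dict."""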
pdf_util = PDFUtil(self.pdf_file)
success, text, page_text_dict = pdf_util.extract_text()
return page_text_dict
def get_page_nums_from_datapoint_page_info(self) -> list:
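        """
        Collect the sorted, de-duplicated list of page numbers that contain at least one datapoint,
        skipping the special "doc_id" entry in datapoint_page_info.
        """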
page_nums_with_datapoints = []
for datapoint, page_nums in self.datapoint_page_info.items():
if datapoint == "doc_id":
continue
page_nums_with_datapoints.extend(page_nums)
page_nums_with_datapoints = list(set(page_nums_with_datapoints))
# sort the page numbers
page_nums_with_datapoints.sort()
return page_nums_with_datapoints
    def extract_data(self) -> list:
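        """Dispatch to text-based or image-based extraction according to `extract_way`; unknown values fall back to text."""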
logger.info(f"Extracting data from document {self.doc_id}, extract way: {self.extract_way}")
if self.extract_way == "text":
return self.extract_data_by_text()
elif self.extract_way == "image":
return self.extract_data_by_image()
else:
return self.extract_data_by_text()
    def extract_data_by_text(self) -> list:
        """
        Extract data page by page from the PDF text. Returns a list of per-page result dicts with keys:
        doc_id, page_index, datapoints, page_text, instructions, raw_answer, extract_data, extract_way.
        """
data_list = []
pdf_page_count = len(self.page_text_dict.keys())
handled_page_num_list = []
previous_page_num = -1
previous_page_datapoints = []
previous_page_fund_name = None
for page_num, page_text in self.page_text_dict.items():
if page_num in handled_page_num_list:
continue
page_datapoints = self.get_datapoints_by_page_num(page_num)
if len(page_datapoints) == 0:
continue
if previous_page_num == page_num - 1 and \
previous_page_datapoints == page_datapoints and \
previous_page_fund_name is not None:
                # Prepend the previous page's fund name to the page text so that records at the top of a
                # continuation page (which often omit the fund name) can still be attributed to a fund.
                # Example document: 431073795, page index 1727 to 1728.
                logger.info(f"Prefixing page text with previous page fund name: {previous_page_fund_name}")
page_text = f"\n{previous_page_fund_name}\n{page_text}"
extract_data = self.extract_data_by_page(
page_num,
page_text,
page_datapoints,
need_exclude=False,
exclude_data=None,
)
data_list.append(extract_data)
page_data_list = extract_data.get("extract_data", {}).get("data", [])
if len(page_data_list) > 0:
previous_page_num = page_num
previous_page_fund_name = page_data_list[-1].get("fund_name", "")
previous_page_datapoints = page_datapoints
current_page_data_count = len(page_data_list)
if current_page_data_count > 0:
count = 1
                # Some PDF documents spread the same table across multiple pages, and the continuation
                # page may lack the table header with the datapoint keywords.
                # Try to pick up the remaining data from the following page(s).
current_text = page_text
while count < 3:
try:
next_page_num = page_num + count
if next_page_num >= pdf_page_count:
break
next_datapoints = page_datapoints
if next_page_num in self.page_nums_with_datapoints:
should_continue = False
next_datapoints = self.get_datapoints_by_page_num(next_page_num)
if len(next_datapoints) == 0:
should_continue = True
else:
for next_datapoint in next_datapoints:
if next_datapoint not in page_datapoints:
should_continue = True
break
next_datapoints.extend(page_datapoints)
# remove duplicate datapoints
next_datapoints = list(set(next_datapoints))
if not should_continue:
break
next_page_text = self.page_text_dict.get(next_page_num, "")
target_text = current_text + next_page_text
# try to get data by current page_datapoints
next_page_extract_data = self.extract_data_by_page(
next_page_num,
target_text,
next_datapoints,
need_exclude=True,
exclude_data=page_data_list,
)
next_page_data_list = next_page_extract_data.get(
"extract_data", {}
).get("data", [])
if next_page_data_list is not None and len(next_page_data_list) > 0:
for current_page_data in page_data_list:
if current_page_data in next_page_data_list:
next_page_data_list.remove(current_page_data)
next_page_extract_data["extract_data"][
"data"
] = next_page_data_list
data_list.append(next_page_extract_data)
handled_page_num_list.append(next_page_num)
exist_current_page_datapoint = False
for next_page_data in next_page_data_list:
for page_datapoint in page_datapoints:
if page_datapoint in list(next_page_data.keys()):
exist_current_page_datapoint = True
break
if exist_current_page_datapoint:
break
if not exist_current_page_datapoint:
break
else:
break
count += 1
except Exception as e:
logger.error(f"Error in extracting data from next page: {e}")
break
self.output_data_to_file(data_list)
return data_list
    def extract_data_by_image(self) -> list:
        """
        Extract data page by page from rendered page images. Returns a list of per-page result dicts with keys:
        doc_id, page_index, datapoints, page_text, instructions, raw_answer, extract_data, extract_way.
        """
data_list = []
pdf_page_count = len(self.page_text_dict.keys())
handled_page_num_list = []
for page_num, page_text in self.page_text_dict.items():
if page_num in handled_page_num_list:
continue
page_datapoints = self.get_datapoints_by_page_num(page_num)
if len(page_datapoints) == 0:
continue
extract_data = self.extract_data_by_page_image(page_num=page_num,
page_datapoints=page_datapoints)
data_list.append(extract_data)
page_data_list = extract_data.get("extract_data", {}).get("data", [])
current_page_data_count = len(page_data_list)
if current_page_data_count > 0:
count = 1
while count < 3:
try:
next_page_num = page_num + count
if next_page_num >= pdf_page_count:
break
next_datapoints = page_datapoints
if next_page_num in self.page_nums_with_datapoints:
should_continue = False
next_datapoints = self.get_datapoints_by_page_num(next_page_num)
if len(next_datapoints) == 0:
should_continue = True
else:
for next_datapoint in next_datapoints:
if next_datapoint not in page_datapoints:
should_continue = True
break
next_datapoints.extend(page_datapoints)
# remove duplicate datapoints
next_datapoints = list(set(next_datapoints))
if not should_continue:
break
# try to get data by current page_datapoints
next_page_extract_data = self.extract_data_by_page_image(
page_num=next_page_num,
page_datapoints=next_datapoints
)
next_page_data_list = next_page_extract_data.get(
"extract_data", {}
).get("data", [])
if next_page_data_list is not None and len(next_page_data_list) > 0:
data_list.append(next_page_extract_data)
handled_page_num_list.append(next_page_num)
exist_current_page_datapoint = False
for next_page_data in next_page_data_list:
for page_datapoint in page_datapoints:
if page_datapoint in list(next_page_data.keys()):
exist_current_page_datapoint = True
break
if exist_current_page_datapoint:
break
if not exist_current_page_datapoint:
break
else:
break
count += 1
except Exception as e:
logger.error(f"Error in extracting data from next page: {e}")
break
self.output_data_to_file(data_list)
return data_list
def output_data_to_file(self, data_list: list) -> None:
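        """Write the per-page extraction results to <doc_id>.json and <doc_id>.xlsx under the configured output folders."""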
json_data_file = os.path.join(
self.output_data_json_folder, f"{self.doc_id}.json"
)
with open(json_data_file, "w", encoding="utf-8") as f:
json.dump(data_list, f, ensure_ascii=False, indent=4)
data_df = pd.DataFrame(data_list)
data_df.reset_index(drop=True, inplace=True)
excel_data_file = os.path.join(
self.output_data_excel_folder, f"{self.doc_id}.xlsx"
)
with pd.ExcelWriter(excel_data_file) as writer:
data_df.to_excel(writer, sheet_name="extract_data", index=False)
def extract_data_by_page(
self,
page_num: int,
page_text: str,
page_datapoints: list,
need_exclude: bool = False,
        exclude_data: list = None,
    ) -> dict:
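        """
        Extract data from a single page: use the page text when it contains at least one numeric value,
        otherwise fall back to image-based (Vision ChatGPT) extraction.
        """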
        # If no numeric value (e.g. 1.25 or 3,88) can be found in the page text,
        # apply Vision ChatGPT on the page image to extract the data instead.
        numeric_regex = r"\d+(\.|,)\d+"
        if not re.search(numeric_regex, page_text):
            logger.info(f"Can't find numeric value in page {page_num}, applying Vision ChatGPT to extract data")
return self.extract_data_by_page_image(
page_num, page_datapoints, need_exclude, exclude_data)
else:
return self.extract_data_by_page_text(
page_num, page_text, page_datapoints, need_exclude, exclude_data
)
def extract_data_by_page_text(
self,
page_num: int,
page_text: str,
page_datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,
) -> dict:
"""
keys are
doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
"""
logger.info(f"Extracting data from page {page_num}")
instructions = self.get_instructions_by_datapoints(
page_text,
page_datapoints,
need_exclude,
exclude_data,
extract_way="text"
)
response, with_error = chat(
instructions, response_format={"type": "json_object"}
)
if with_error:
logger.error(f"Error in extracting tables from page")
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = page_text
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = {"data": []}
data_dict["extract_way"] = "text"
return data_dict
        try:
            data = json.loads(response)
        except Exception:
            try:
                # If parsing fails, the response may be truncated or slightly malformed JSON;
                # json_repair can usually recover it.
                # Note: when the output exceeded 4K tokens the context used to be split in two
                # (see chat_by_split_context below). After deploying the ChatGPT4o 2024-08-16 version
                # the max output length is 16K tokens, so splitting the context is no longer needed.
                # data = self.chat_by_split_context(
                #     page_text, page_datapoints, need_exclude, exclude_data
                # )
                data = json_repair.loads(response)
            except Exception:
                data = {"data": []}
data = self.validate_data(data)
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = page_text
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = data
data_dict["extract_way"] = "text"
return data_dict
def extract_data_by_page_image(
self,
page_num: int,
page_datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,
) -> dict:
"""
keys are
doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
"""
logger.info(f"Extracting data from page {page_num}")
image_base64 = self.get_pdf_image_base64(page_num)
instructions = self.get_instructions_by_datapoints(
"",
page_datapoints,
need_exclude=need_exclude,
exclude_data=exclude_data,
extract_way="image"
)
response, with_error = chat(
instructions, response_format={"type": "json_object"}, image_base64=image_base64
)
if with_error:
logger.error(f"Error in extracting tables from page")
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = ""
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = {"data": []}
data_dict["extract_way"] = "image"
return data_dict
        try:
            data = json.loads(response)
        except Exception:
            try:
                # fall back to json_repair for truncated or slightly malformed JSON
                data = json_repair.loads(response)
            except Exception:
                data = {"data": []}
data = self.validate_data(data)
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = ""
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = data
data_dict["extract_way"] = "image"
return data_dict
def validate_data(self, extract_data_info: dict) -> dict:
"""
Validate data by the rules
1. Each data should be with fund name
2. For share level data, it should be with share name
"""
data_list = extract_data_info.get("data", [])
if len(data_list) == 0:
return extract_data_info
remove_list = []
for data in data_list:
fund_name = data.get("fund name", "")
if fund_name == "":
remove_list.append(data)
keys = list(data.keys())
for key in keys:
if self.datapoint_level_config.get(key, "") == "share_level":
if data.get("share name", "") == "":
is_share_name = self.check_fund_name_as_share(fund_name)
if not is_share_name:
remove_list.append(data)
break
else:
data["share name"] = fund_name
if data.get(key, "") == "":
data.pop(key)
for remove_data in remove_list:
if remove_data in data_list:
data_list.remove(remove_data)
        # Heuristic check on performance_fee: drop values greater than 10, or values above 3 that are
        # exact multiples of 2.5, as implausible; non-numeric values are dropped as well.
for data in data_list:
performance_fee = data.get("performance_fee", None)
if performance_fee is not None:
try:
performance_fee = float(performance_fee)
if (performance_fee > 3 and performance_fee % 2.5 == 0) or \
performance_fee > 10:
data.pop("performance_fee")
                except Exception:
data.pop("performance_fee")
remove_list = []
for data in data_list:
keys = [key for key in list(data.keys())
if key not in ["fund name", "share name"]]
if len(keys) == 0:
remove_list.append(data)
for remove_data in remove_list:
if remove_data in data_list:
data_list.remove(remove_data)
# update "fund name" to be "fund_name"
# update "share name" to be "share_name"
new_data_list = []
for data in data_list:
new_data = {}
fund_name = data.get("fund name", "")
if fund_name != "":
new_data["fund_name"] = fund_name
share_name = data.get("share name", "")
if share_name != "":
new_data["share_name"] = share_name
ter = data.get("ter", None)
if ter is not None:
new_data["ter"] = ter
performance_fee = data.get("performance fees", None)
if performance_fee is not None:
new_data["performance_fee"] = performance_fee
for key, value in data.items():
if key not in ["fund name", "share name", "ter", "performance fees"]:
new_data[key] = value
new_data_list.append(new_data)
extract_data_info["data"] = new_data_list
return extract_data_info
    def check_fund_name_as_share(self, fund_name: str) -> bool:
        """
        Check whether the extracted fund name is actually a share class name,
        i.e. it closely matches one of the document's known share class names.
        """
        if len(fund_name) == 0:
            return False
share_name_list = self.document_mapping_info_df["ShareClassName"].unique().tolist()
if len(share_name_list) == 0:
return False
max_similarity_name, max_similarity = get_most_similar_name(
text=fund_name,
name_list=share_name_list,
share_name=None,
fund_name=None,
matching_type="share",
process_cache=None)
if max_similarity >= 0.8:
return True
return False
def get_datapoints_by_page_num(self, page_num: int) -> list:
datapoints = []
for datapoint in self.datapoints:
            if page_num in self.datapoint_page_info.get(datapoint, []):
datapoints.append(datapoint)
return datapoints
def get_instructions_by_datapoints(
self,
page_text: str,
datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,
extract_way: str = "text",
) -> str:
"""
Get instructions to extract data from the page by the datapoints
Below is the instructions sections:
summary: string
reported_name by datapoints: dict
data_business_features: dict
common: list
investment_level by datapoints: dict
data_value_range by datapoints: dict
special_rule by datapoints: dict
special_cases: dict
common: list
title
contents
special_case by datapoints: list
title
contents
output_requirement
common: list
fund_level: list
share_level: dict
fund_name: list
share_name: list
ogc_value: list
ter_value: list
performance_fee_value: list
end
"""
instructions = []
if extract_way == "text":
instructions = [f"Context:\n{page_text}\n\nInstructions:\n"]
datapoint_name_list = []
for datapoint in datapoints:
datapoint_name = self.datapoint_name_config.get(datapoint, "")
datapoint_name_list.append(datapoint_name)
if extract_way == "text":
summary = self.instructions_config.get("summary", "\n")
elif extract_way == "image":
summary = self.instructions_config.get("summary_image", "\n")
else:
summary = self.instructions_config.get("summary", "\n")
instructions.append(summary.format(", ".join(datapoint_name_list)))
instructions.append("\n")
if extract_way == "image":
image_features = self.instructions_config.get("image_features", [])
instructions.extend(image_features)
instructions.append("\n")
instructions.append("Datapoints Reported name:\n")
instructions.append("Please look for relevant reported names and similar variations in the context.\n")
reported_name_info_in_instructions = self.instructions_config.get("reported_name", {})
for datapoint in datapoints:
reported_name_list = self.datapoint_reported_name_config.get(datapoint, [])
if len(reported_name_list) == 0:
reported_name = reported_name_info_in_instructions.get(datapoint, "")
else:
joined_reported_name = ", ".join(reported_name_list)
datapoint_name = datapoint
if datapoint_name == "performance_fee":
datapoint_name = "performance fees"
else:
datapoint_name = datapoint_name.upper()
reported_name = f"The {datapoint_name} reported name could be:\n{joined_reported_name}"
instructions.append(reported_name)
instructions.append("\n")
instructions.append("\n")
        if self.language is not None and self.language != "english":
"""
"multilingual_reported_name": {
"describe": "Please be careful to extract relevant data by different reported names from multilingual Context.",
"regular_example_template": "{datapoint} Example {number}:\nLanguage: {language}\n---Context Start-----\n{fund_name}\n{share_name}\n{reported_name}\n{value}\n---Context End-----\nAnswer: {answer}",
"special_example_template_none": "{datapoint} Example {number}:\nLanguage: {language}\nValue is belong to \"-, *, **, N/A, N/A%, N/A %, NONE\", ignore it\n---Context Start-----\n{fund_name}\n{share_name}\n{reported_name} 2)\n-\n---Context End-----\nAnswer: {answer}",
"value_examples": ["1,98", "3.25", "2.16", "1,73", "4,53"]
"fund_example": "Fund 1",
"share_example": "Share 1"
}
"""
multilingual_reported_name_config = self.instructions_config.get("multilingual_reported_name", {})
describe = multilingual_reported_name_config.get("describe", "")
regular_example_template = multilingual_reported_name_config.get("regular_example_template", "")
special_example_template_none = multilingual_reported_name_config.get("special_example_template_none", "")
value_examples = multilingual_reported_name_config.get("value_examples", [])
fund_example = multilingual_reported_name_config.get("fund_example", "")
share_example = multilingual_reported_name_config.get("share_example", "")
instructions.append("Multilingual reported name:\n")
instructions.append(f"{describe}\n")
            # capitalise the first letter of the language name
language = self.language[0].upper() + self.language[1:]
for datapoint in datapoints:
mul_reported_name_list = self.non_english_reported_name_config.get(datapoint, [])
                # remove duplicate reported names
mul_reported_name_list = list(set(mul_reported_name_list))
if len(mul_reported_name_list) == 0:
continue
datapoint_name = datapoint
if datapoint_name == "performance_fee":
datapoint_name = "performance fees"
else:
datapoint_name = datapoint_name.upper()
example_count = 1
none_value_example_count = 0
for mul_reported_name in mul_reported_name_list:
if datapoint in ["ter", "performance_fee"] and example_count == 3:
break
value = value_examples[example_count % len(value_examples)]
answer = {"fund name": fund_example,
"share name": share_example,
datapoint: float(value.replace(",", "."))}
# transfer answer to string
answer = json.dumps(answer, ensure_ascii=False)
example = regular_example_template.format(
datapoint=datapoint_name,
number=example_count,
language=language,
fund_name=fund_example,
share_name=share_example,
reported_name=mul_reported_name,
value=value,
answer=answer,
)
instructions.append(example)
instructions.append("\n")
instructions.append("\n")
example_count += 1
if len(mul_reported_name.split()) > 1:
if none_value_example_count != 2:
none_value_example = special_example_template_none.format(
datapoint=datapoint_name,
number=example_count,
language=language,
fund_name=fund_example,
share_name=share_example,
reported_name=mul_reported_name,
                                answer=json.dumps({}, ensure_ascii=False)
)
instructions.append(none_value_example)
instructions.append("\n")
instructions.append("\n")
example_count += 1
none_value_example_count += 1
instructions.append("\n")
instructions.append("Data business features:\n")
data_business_features = self.instructions_config.get(
"data_business_features", {}
)
common = "\n".join(data_business_features.get("common", []))
instructions.append(common)
instructions.append("\n")
instructions.append("Datapoints investment level:\n")
investment_level_info = data_business_features.get("investment_level", {})
for datapoint in datapoints:
investment_level = investment_level_info.get(datapoint, "")
instructions.append(investment_level)
instructions.append("\n")
instructions.append("\n")
instructions.append("Datapoints value range:\n")
data_value_range_info = data_business_features.get("data_value_range", {})
for datapoint in datapoints:
data_value_range = data_value_range_info.get(datapoint, "")
instructions.append(data_value_range)
instructions.append("\n")
instructions.append("\n")
special_rule_info = data_business_features.get("special_rule", {})
with_special_rule_title = False
for datapoint in datapoints:
special_rule_list = special_rule_info.get(datapoint, [])
if len(special_rule_list) > 0:
if not with_special_rule_title:
instructions.append("Special rule:\n")
with_special_rule_title = True
special_rule = "\n".join(special_rule_list)
instructions.append(special_rule)
instructions.append("\n\n")
instructions.append("\n")
instructions.append("Special cases:\n")
special_cases = self.instructions_config.get("special_cases", {})
special_cases_common_list = special_cases.get("common", [])
for special_cases_common in special_cases_common_list:
title = special_cases_common.get("title", "")
instructions.append(title)
instructions.append("\n")
contents_list = special_cases_common.get("contents", [])
contents = "\n".join(contents_list)
instructions.append(contents)
instructions.append("\n\n")
for datapoint in datapoints:
special_case_list = special_cases.get(datapoint, [])
for special_case in special_case_list:
title = special_case.get("title", "")
instructions.append(title)
instructions.append("\n")
contents_list = special_case.get("contents", [])
contents = "\n".join(contents_list)
instructions.append(contents)
instructions.append("\n\n")
instructions.append("\n")
instructions.append("Output requirement:\n")
output_requirement = self.instructions_config.get("output_requirement", {})
output_requirement_common_list = output_requirement.get("common", [])
instructions.append("\n".join(output_requirement_common_list))
instructions.append("\n")
share_datapoint_value_example = {}
share_level_config = output_requirement.get("share_level", {})
example_list = []
dp_reported_name_config = output_requirement.get("dp_reported_name", {})
dp_reported_name = {}
for datapoint in datapoints:
investment_level = self.datapoint_level_config.get(datapoint, "")
if investment_level == "fund_level":
fund_level_example_list = output_requirement.get("fund_level", [])
for example in fund_level_example_list:
try:
sub_example_list = json.loads(example)
except:
sub_example_list = json_repair.loads(example)
example_list.extend(sub_example_list)
elif investment_level == "share_level":
share_datapoint_value_example[datapoint] = share_level_config.get(
f"{datapoint}_value", []
)
dp_reported_name[datapoint] = dp_reported_name_config.get(datapoint, "")
share_datapoint_list = list(share_datapoint_value_example.keys())
instructions.append(f"Example:\n")
if len(share_datapoint_list) > 0:
fund_name_example_list = share_level_config.get("fund_name", [])
share_name_example_list = share_level_config.get("share_name", [])
for index in range(len(fund_name_example_list)):
example_dict = {
"fund name": fund_name_example_list[index],
"share name": share_name_example_list[index],
}
for share_datapoint in share_datapoint_list:
share_datapoint_values = share_datapoint_value_example[
share_datapoint
]
if index < len(share_datapoint_values):
example_dict[share_datapoint] = share_datapoint_values[index]
example_list.append(example_dict)
example_data = {"data": example_list, "dp_reported_name": dp_reported_name}
instructions.append(json.dumps(example_data, ensure_ascii=False, indent=4))
instructions.append("\n")
instructions.append("\n")
end_list = self.instructions_config.get("end", [])
instructions.append("\n".join(end_list))
instructions.append("\n")
if need_exclude and exclude_data is not None and isinstance(exclude_data, list):
instructions.append("Please exclude below data from output:\n")
instructions.append(json.dumps(exclude_data, ensure_ascii=False, indent=4))
instructions.append("\n")
instructions.append("\n")
instructions.append("Answer:\n")
instructions_text = "".join(instructions)
return instructions_text
# def chat_by_split_context(self,
# page_text: str,
# page_datapoints: list,
# need_exclude: bool,
# exclude_data: list) -> list:
# """
# If occur error, split the context to two parts and try to get data from the two parts
# Relevant document: 503194284, page index 147
# """
# try:
# logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
# split_context = re.split(r"\n", page_text)
# split_context = [text.strip() for text in split_context
# if len(text.strip()) > 0]
# if len(split_context) < 10:
# return {"data": []}
# split_context_len = len(split_context)
# top_10_context = split_context[:10]
# rest_context = split_context[10:]
# header = "\n".join(top_10_context)
# half_len = split_context_len // 2
# # the member of half_len should not start with number
# # reverse iterate the list by half_len
# half_len_list = [i for i in range(half_len)]
# fund_name_line = ""
# half_line = rest_context[half_len].strip()
# max_similarity_fund_name, max_similarity = get_most_similar_name(
# half_line, self.provider_fund_name_list, matching_type="fund"
# )
# if max_similarity < 0.2:
# # get the fund name line text from the first half
# for index in reversed(half_len_list):
# line_text = rest_context[index].strip()
# if len(line_text) == 0:
# continue
# line_text_split = line_text.split()
# if len(line_text_split) < 3:
# continue
# first_word = line_text_split[0]
# if first_word.lower() == "class":
# continue
# max_similarity_fund_name, max_similarity = get_most_similar_name(
# line_text, self.provider_fund_name_list, matching_type="fund"
# )
# if max_similarity >= 0.2:
# fund_name_line = line_text
# break
# else:
# fund_name_line = half_line
# half_len += 1
# if fund_name_line == "":
# return {"data": []}
# logger.info(f"Split first part from 0 to {half_len}")
# split_first_part = "\n".join(split_context[:half_len])
# first_part = '\n'.join(split_first_part)
# first_instructions = self.get_instructions_by_datapoints(
# first_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
# )
# response, with_error = chat(
# first_instructions, response_format={"type": "json_object"}
# )
# first_part_data = {"data": []}
# if not with_error:
# try:
# first_part_data = json.loads(response)
# except:
# first_part_data = json_repair.loads(response)
# logger.info(f"Split second part from {half_len} to {split_context_len}")
# split_second_part = "\n".join(split_context[half_len:])
# second_part = header + "\n" + fund_name_line + "\n" + split_second_part
# second_instructions = self.get_instructions_by_datapoints(
# second_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
# )
# response, with_error = chat(
# second_instructions, response_format={"type": "json_object"}
# )
# second_part_data = {"data": []}
# if not with_error:
# try:
# second_part_data = json.loads(response)
# except:
# second_part_data = json_repair.loads(response)
# first_part_data_list = first_part_data.get("data", [])
# logger.info(f"First part data count: {len(first_part_data_list)}")
# second_part_data_list = second_part_data.get("data", [])
# logger.info(f"Second part data count: {len(second_part_data_list)}")
# for first_data in first_part_data_list:
# if first_data in second_part_data_list:
# second_part_data_list.remove(first_data)
# else:
# # if the first part data is with same fund name and share name,
# # remove the second part data
# first_data_dp = [key for key in list(first_data.keys())
# if key not in ["fund name", "share name"]]
# # order the data points
# first_data_dp.sort()
# first_fund_name = first_data.get("fund name", "")
# first_share_name = first_data.get("share name", "")
# if len(first_fund_name) > 0 and len(first_share_name) > 0:
# remove_second_list = []
# for second_data in second_part_data_list:
# second_fund_name = second_data.get("fund name", "")
# second_share_name = second_data.get("share name", "")
# if first_fund_name == second_fund_name and \
# first_share_name == second_share_name:
# second_data_dp = [key for key in list(second_data.keys())
# if key not in ["fund name", "share name"]]
# second_data_dp.sort()
# if first_data_dp == second_data_dp:
# remove_second_list.append(second_data)
# for remove_second in remove_second_list:
# if remove_second in second_part_data_list:
# second_part_data_list.remove(remove_second)
# data_list = first_part_data_list + second_part_data_list
# extract_data = {"data": data_list}
# return extract_data
# except Exception as e:
# logger.error(f"Error in split context: {e}")
# return {"data": []}