support splitting text for the case where outputs exceed 4K tokens.

Blade He 2024-09-16 12:03:13 -05:00
parent 0f6dbd27eb
commit 932870f406
4 changed files with 103 additions and 5 deletions
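In short: when a page's answer is truncated at the completion limit, the page text is split into two halves at a line that does not start with a number (so a table row is never cut), the first ten lines are repeated as a header for the second half, and the two result sets are merged. A minimal standalone sketch of that split heuristic (names here are illustrative, not the committed API):

def pick_split_index(lines: list, header_size: int = 10) -> int:
    """Return an index near the middle whose line does not start with a
    number or number-like punctuation, so a table row is not cut in half."""
    half = len(lines) // 2
    for index in reversed(range(header_size, half)):
        first = lines[index][0]
        if not first.isnumeric() and first not in [".", "(", ")", "-"]:
            return index
    return half  # no safe boundary found; fall back to the midpoint

def split_page_text(page_text: str, header_size: int = 10) -> tuple:
    """Split page text into two chat-sized parts, repeating the first
    header_size lines in the second part so column headers stay in context."""
    lines = [line.strip() for line in page_text.split("\n") if line.strip()]
    half = pick_split_index(lines, header_size)
    header = "\n".join(lines[:header_size])
    return "\n".join(lines[:half]), header + "\n" + "\n".join(lines[half:])

The committed chat_by_split_context method below applies the same idea, sending each part through get_instructions_by_datapoints and chat, then de-duplicating the merged "data" lists.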

View File

@@ -160,6 +160,16 @@ class DataExtraction:
            ] = next_page_data_list
            data_list.append(next_page_extract_data)
            handled_page_num_list.append(next_page_num)
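            # stop paging once the next page no longer mentions any of the
            # current page's datapoints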
            exist_current_page_datapoint = False
            for next_page_data in next_page_data_list:
                for page_datapoint in page_datapoints:
                    if page_datapoint in next_page_data:
                        exist_current_page_datapoint = True
                        break
                if exist_current_page_datapoint:
                    break
            if not exist_current_page_datapoint:
                break
        else:
            break
        count += 1
@@ -204,12 +214,25 @@ class DataExtraction:
        )
        if with_error:
            logger.error("Error in extracting tables from page")
-            return ""
            data_dict = {"doc_id": self.doc_id}
            data_dict["page_index"] = page_num
            data_dict["datapoints"] = ", ".join(page_datapoints)
            data_dict["page_text"] = page_text
            data_dict["instructions"] = instructions
            data_dict["raw_answer"] = response
            data_dict["extract_data"] = {"data": []}
            return data_dict
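        # parse the JSON answer; fall back to json_repair and, if that is
        # not enough, to splitting the page in two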
        try:
            data = json.loads(response)
        except Exception:
            try:
-                data = json_repair.loads(response)
                # if parsing fails here, the output was probably truncated at
                # the 4K-token limit: split the context into two parts and try
                # to get data from each part
                data = self.chat_by_split_context(
                    page_text, page_datapoints, need_exclude, exclude_data
                )
                if len(data.get("data", [])) == 0:
                    data = json_repair.loads(response)
            except Exception:
                data = {"data": []}
        data = self.validate_data(data)
@@ -223,6 +246,80 @@ class DataExtraction:
data_dict["extract_data"] = data
return data_dict
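    # Fallback for answers truncated at the completion limit: extract from
    # each half of the page separately, then merge the two result sets.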
    def chat_by_split_context(self,
                              page_text: str,
                              page_datapoints: list,
                              need_exclude: bool,
                              exclude_data: list) -> dict:
        """
        If a parsing error occurs, split the context into two parts and
        try to extract data from each part separately.
        """
        try:
            logger.info("Splitting context to work around outputs longer than 4K tokens")
            split_context = re.split(r"\n", page_text)
            split_context = [text.strip() for text in split_context
                             if len(text.strip()) > 0]
            split_context_len = len(split_context)
            top_10_context = split_context[:10]
            header = "\n".join(top_10_context)
            half_len = split_context_len // 2
            # the line at the split point should not start with a number or
            # number-like punctuation, so the split never cuts a table row;
            # search backwards from the midpoint, skipping the 10 header lines
            for index in reversed(range(10, half_len)):
                first_letter = split_context[index][0]
                if not first_letter.isnumeric() and first_letter not in [".", "(", ")", "-"]:
                    half_len = index
                    break
            logger.info(f"Split first part from 0 to {half_len}")
            first_part = "\n".join(split_context[:half_len])
            first_instructions = self.get_instructions_by_datapoints(
                first_part, page_datapoints, need_exclude, exclude_data
            )
            response, with_error = chat(
                first_instructions, response_format={"type": "json_object"}
            )
            first_part_data = {"data": []}
            if not with_error:
                try:
                    first_part_data = json.loads(response)
                except Exception:
                    first_part_data = json_repair.loads(response)
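            # the second half re-uses the top-10-line header so table column
            # names from the top of the page stay in the model's context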
logger.info(f"Split second part from {half_len} to {split_context_len}")
split_second_part = "\n".join(split_context[half_len:])
second_part = header + '\n' + split_second_part
second_instructions = self.get_instructions_by_datapoints(
second_part, page_datapoints, need_exclude, exclude_data
)
response, with_error = chat(
second_instructions, response_format={"type": "json_object"}
)
second_part_data = {"data": []}
if not with_error:
try:
second_part_data = json.loads(response)
except:
second_part_data = json_repair.loads(response)
first_part_data_list = first_part_data.get("data", [])
logger.info(f"First part data count: {len(first_part_data_list)}")
second_part_data_list = second_part_data.get("data", [])
logger.info(f"Second part data count: {len(second_part_data_list)}")
for first_data in first_part_data_list:
if first_data in second_part_data_list:
second_part_data_list.remove(first_data)
data_list = first_part_data_list + second_part_data_list
extract_data = {"data": data_list}
return extract_data
except Exception as e:
logger.error(f"Error in split context: {e}")
return {"data": []}
    def validate_data(self, extract_data_info: dict) -> dict:
        """
        Validate data by the rules

View File

@@ -102,6 +102,7 @@
"common": [
"If possible, please extract fund name, share name, TOR, TER, performance fees, OGC values as the output.",
"If find share name, and exist relevant currency, please output share name + currency, e.g. share name is \"Class A\", currency is \"USD\", the output share name should be: \"Class A USD\".",
"If find fund name, and exist sub fund name, please output fund name + sub fund name, e.g. fund name is \"Black Rock European\", sub fund name is \"Growth\", the output fund name should be: \"Black Rock European Growth\".",
"Only output the data point which with relevant value.",
"Don't ignore the data point which with negative value, e.g. -0.12, -1.13",
"Don't ignore the data point which with explicit zero value, e.g. 0, 0.00",
@@ -110,7 +111,7 @@
"The output should be JSON format, the format is like below example(s):"
],
"fund_level": [
"[{\"fund name\": \"fund 1\",\"tor\": 35.26}, {\"fund name\": \"fund 2\",\"tor\": -28.26}, {\"fund name\": \"fund 3\",\"tor\": 115.52,}]"
"[{\"fund name\": \"fund 1 - sub fund name 1\",\"tor\": 35.26}, {\"fund name\": \"fund 2 - sub fund name 2\",\"tor\": -28.26}, {\"fund name\": \"fund 3\",\"tor\": 115.52,}]"
],
"share_level": {
"fund_name": [

View File

@@ -505,7 +505,7 @@ if __name__ == "__main__":
# doc_id = "476492237"
# extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run)
special_doc_id_list = ["491593469"]
special_doc_id_list = ["503194284"]
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
re_run_mapping_data = True

View File

@@ -103,7 +103,7 @@ def chat(
    count = 0
    error = ""
-    max_tokens = 4000
+    max_tokens = 4096
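+    # request the full 4K-token completion budget; answers longer than this
+    # are handled by the split-context fallback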
    request_timeout = 120
    while count < 8:
        try: