support splitting the text for the case where the output exceeds 4K tokens.

Blade He 2024-09-16 12:03:13 -05:00
parent 0f6dbd27eb
commit 932870f406
4 changed files with 103 additions and 5 deletions
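
The idea behind the change, sketched in isolation: when the model's JSON answer fails to parse because the output ran past the token limit and was truncated, split the page text roughly in half, query each half separately, and merge the two result lists. The snippet below is only a minimal sketch of that pattern, not the repository's code; `extract_fn` is a hypothetical stand-in for the project's chat call. The real implementation in the diff also prepends the first ten lines of the page as a header to the second half, avoids splitting on lines that start with numbers, and falls back to json_repair.

import json

def extract_with_split(page_text: str, extract_fn) -> dict:
    """Sketch: parse the model answer; on failure, split the text and merge."""
    raw = extract_fn(page_text)  # hypothetical helper: returns the model's raw JSON string
    try:
        return json.loads(raw)
    except ValueError:
        pass  # likely truncated output; fall back to splitting
    lines = [ln for ln in page_text.split("\n") if ln.strip()]
    mid = len(lines) // 2
    merged = []
    for part in ("\n".join(lines[:mid]), "\n".join(lines[mid:])):
        try:
            rows = json.loads(extract_fn(part)).get("data", [])
        except ValueError:
            rows = []
        for row in rows:
            if row not in merged:  # drop rows found in both halves
                merged.append(row)
    return {"data": merged}
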

View File

@@ -160,6 +160,16 @@ class DataExtraction:
                 ] = next_page_data_list
                 data_list.append(next_page_extract_data)
                 handled_page_num_list.append(next_page_num)
+                exist_current_page_datapoint = False
+                for next_page_data in next_page_data_list:
+                    for page_datapoint in page_datapoints:
+                        if page_datapoint in list(next_page_data.keys()):
+                            exist_current_page_datapoint = True
+                            break
+                    if exist_current_page_datapoint:
+                        break
+                if not exist_current_page_datapoint:
+                    break
             else:
                 break
             count += 1
@@ -204,12 +214,25 @@
         )
         if with_error:
             logger.error(f"Error in extracting tables from page")
-            return ""
+            data_dict = {"doc_id": self.doc_id}
+            data_dict["page_index"] = page_num
+            data_dict["datapoints"] = ", ".join(page_datapoints)
+            data_dict["page_text"] = page_text
+            data_dict["instructions"] = instructions
+            data_dict["raw_answer"] = response
+            data_dict["extract_data"] = {"data": []}
+            return data_dict
         try:
             data = json.loads(response)
         except:
             try:
-                data = json_repair.loads(response)
+                # if an error occurs here, the output length may exceed 4K tokens;
+                # split the context into two parts and try to get data from each part
+                data = self.chat_by_split_context(
+                    page_text, page_datapoints, need_exclude, exclude_data
+                )
+                if len(data.get("data", [])) == 0:
+                    data = json_repair.loads(response)
             except:
                 data = {"data": []}
         data = self.validate_data(data)
@@ -223,6 +246,80 @@
         data_dict["extract_data"] = data
         return data_dict
 
+    def chat_by_split_context(self,
+                              page_text: str,
+                              page_datapoints: list,
+                              need_exclude: bool,
+                              exclude_data: list) -> dict:
+        """
+        If an error occurs, split the context into two parts and try to get data from each part.
+        """
+        try:
+            logger.info(f"Split context to get data, to fix the issue where the output length exceeds 4K tokens")
+            split_context = re.split(r"\n", page_text)
+            split_context = [text.strip() for text in split_context
+                             if len(text.strip()) > 0]
+            split_context_len = len(split_context)
+            top_10_context = split_context[:10]
+            rest_context = split_context[10:]
+            header = "\n".join(top_10_context)
+            half_len = split_context_len // 2
+            # the split line should not start with a number or numeric punctuation;
+            # walk backwards from half_len to find a suitable boundary
+            half_len_list = list(range(half_len))
+            for index in reversed(half_len_list):
+                first_letter = rest_context[index].strip()[0]
+                if not first_letter.isnumeric() and first_letter not in [".", "(", ")", "-"]:
+                    half_len = index
+                    break
+            logger.info(f"Split first part from 0 to {half_len}")
+            first_part = "\n".join(split_context[:half_len])
+            first_instructions = self.get_instructions_by_datapoints(
+                first_part, page_datapoints, need_exclude, exclude_data
+            )
+            response, with_error = chat(
+                first_instructions, response_format={"type": "json_object"}
+            )
+            first_part_data = {"data": []}
+            if not with_error:
+                try:
+                    first_part_data = json.loads(response)
+                except:
+                    first_part_data = json_repair.loads(response)
+            logger.info(f"Split second part from {half_len} to {split_context_len}")
+            split_second_part = "\n".join(split_context[half_len:])
+            second_part = header + "\n" + split_second_part
+            second_instructions = self.get_instructions_by_datapoints(
+                second_part, page_datapoints, need_exclude, exclude_data
+            )
+            response, with_error = chat(
+                second_instructions, response_format={"type": "json_object"}
+            )
+            second_part_data = {"data": []}
+            if not with_error:
+                try:
+                    second_part_data = json.loads(response)
+                except:
+                    second_part_data = json_repair.loads(response)
+            first_part_data_list = first_part_data.get("data", [])
+            logger.info(f"First part data count: {len(first_part_data_list)}")
+            second_part_data_list = second_part_data.get("data", [])
+            logger.info(f"Second part data count: {len(second_part_data_list)}")
+            # drop rows that appear in both halves before merging
+            for first_data in first_part_data_list:
+                if first_data in second_part_data_list:
+                    second_part_data_list.remove(first_data)
+            data_list = first_part_data_list + second_part_data_list
+            extract_data = {"data": data_list}
+            return extract_data
+        except Exception as e:
+            logger.error(f"Error in split context: {e}")
+            return {"data": []}
+
     def validate_data(self, extract_data_info: dict) -> dict:
         """
         Validate data by the rules

View File

@@ -102,6 +102,7 @@
   "common": [
     "If possible, please extract fund name, share name, TOR, TER, performance fees, OGC values as the output.",
     "If find share name, and exist relevant currency, please output share name + currency, e.g. share name is \"Class A\", currency is \"USD\", the output share name should be: \"Class A USD\".",
+    "If find fund name, and exist sub fund name, please output fund name + sub fund name, e.g. fund name is \"Black Rock European\", sub fund name is \"Growth\", the output fund name should be: \"Black Rock European Growth\".",
     "Only output the data point which with relevant value.",
     "Don't ignore the data point which with negative value, e.g. -0.12, -1.13",
     "Don't ignore the data point which with explicit zero value, e.g. 0, 0.00",
@@ -110,7 +111,7 @@
     "The output should be JSON format, the format is like below example(s):"
   ],
   "fund_level": [
-    "[{\"fund name\": \"fund 1\",\"tor\": 35.26}, {\"fund name\": \"fund 2\",\"tor\": -28.26}, {\"fund name\": \"fund 3\",\"tor\": 115.52,}]"
+    "[{\"fund name\": \"fund 1 - sub fund name 1\",\"tor\": 35.26}, {\"fund name\": \"fund 2 - sub fund name 2\",\"tor\": -28.26}, {\"fund name\": \"fund 3\",\"tor\": 115.52,}]"
   ],
   "share_level": {
     "fund_name": [

View File

@@ -505,7 +505,7 @@ if __name__ == "__main__":
     # doc_id = "476492237"
     # extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run)
-    special_doc_id_list = ["491593469"]
+    special_doc_id_list = ["503194284"]
     output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
     output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
     re_run_mapping_data = True

View File

@@ -103,7 +103,7 @@ def chat(
     count = 0
     error = ""
-    max_tokens = 4000
+    max_tokens = 4096
     request_timeout = 120
     while count < 8:
         try: