Support splitting the page text for the case where model outputs exceed 4K tokens.
parent 0f6dbd27eb
commit 932870f406
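For context on the failure mode this commit addresses: when a completion hits the max_tokens cap, the returned JSON is cut off mid-structure, so json.loads raises, and even repair can only salvage the leading records. A minimal illustration (the truncated string below is fabricated, not from the commit):

import json

import json_repair  # the repair library this codebase already uses

# A response truncated at the token limit: the closing brackets never arrive.
truncated = '{"data": [{"fund name": "fund 1", "tor": 35.26}, {"fund name": "fu'

try:
    json.loads(truncated)
except json.JSONDecodeError as err:
    print(f"json.loads fails on truncated output: {err}")

# json_repair closes the open structures and salvages what it can, but
# whatever was cut off mid-record is unreliable -- hence the
# split-and-retry fallback this commit adds.
print(json_repair.loads(truncated))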
@@ -160,6 +160,16 @@ class DataExtraction:
                 ] = next_page_data_list
                 data_list.append(next_page_extract_data)
                 handled_page_num_list.append(next_page_num)
+                exist_current_page_datapoint = False
+                for next_page_data in next_page_data_list:
+                    for page_datapoint in page_datapoints:
+                        if page_datapoint in list(next_page_data.keys()):
+                            exist_current_page_datapoint = True
+                            break
+                    if exist_current_page_datapoint:
+                        break
+                if not exist_current_page_datapoint:
+                    break
             else:
                 break
         count += 1
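The block added above stops the page look-ahead as soon as a following page shares none of the current page's datapoints. A standalone sketch of the same check (the sample dicts are made up for illustration):

# Hypothetical records shaped like the extraction output above.
page_datapoints = ["tor", "ter"]
next_page_data_list = [
    {"fund name": "fund 1", "ogc": 0.45},   # no shared datapoint
    {"fund name": "fund 2", "tor": 35.26},  # "tor" carries over
]

# Equivalent to the nested loops in the hunk: flag whether any next-page
# record contains one of the current page's datapoints.
exist_current_page_datapoint = any(
    page_datapoint in next_page_data
    for next_page_data in next_page_data_list
    for page_datapoint in page_datapoints
)
print(exist_current_page_datapoint)  # True: keep scanning following pages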
@@ -204,12 +214,25 @@
             )
             if with_error:
                 logger.error(f"Error in extracting tables from page")
-                return ""
+                data_dict = {"doc_id": self.doc_id}
+                data_dict["page_index"] = page_num
+                data_dict["datapoints"] = ", ".join(page_datapoints)
+                data_dict["page_text"] = page_text
+                data_dict["instructions"] = instructions
+                data_dict["raw_answer"] = response
+                data_dict["extract_data"] = {"data": []}
+                return data_dict
         try:
             data = json.loads(response)
         except:
             try:
-                data = json_repair.loads(response)
+                # if parsing fails, the output length may be over 4K tokens:
+                # split the context into two parts and get data from each part
+                data = self.chat_by_split_context(
+                    page_text, page_datapoints, need_exclude, exclude_data
+                )
+                if len(data.get("data", [])) == 0:
+                    data = json_repair.loads(response)
             except:
                 data = {"data": []}
         data = self.validate_data(data)
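The parse order established above is: strict json.loads, then the split-context retry, then json_repair on the raw answer as a last resort. A condensed sketch of that chain as a standalone helper (retry_split stands in for self.chat_by_split_context; the helper name is invented for illustration):

import json
from typing import Callable

import json_repair

def parse_with_fallback(response: str, retry_split: Callable[[], dict]) -> dict:
    # 1) a strict parse succeeds for well-formed answers
    try:
        return json.loads(response)  # JSONDecodeError is a ValueError
    except ValueError:
        pass
    # 2) re-ask over two halves of the page; 3) repair the raw answer
    try:
        data = retry_split()
        if len(data.get("data", [])) == 0:
            data = json_repair.loads(response)
        return data
    except Exception:
        return {"data": []}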
@@ -223,6 +246,80 @@ class DataExtraction:
         data_dict["extract_data"] = data
         return data_dict
 
+    def chat_by_split_context(self,
+                              page_text: str,
+                              page_datapoints: list,
+                              need_exclude: bool,
+                              exclude_data: list) -> dict:
+        """
+        If an error occurs, split the context into two parts and try to get data from each part
+        """
+        try:
+            logger.info(f"Split context to get data; the output length may be over 4K tokens")
+            split_context = re.split(r"\n", page_text)
+            split_context = [text.strip() for text in split_context
+                             if len(text.strip()) > 0]
+            split_context_len = len(split_context)
+            top_10_context = split_context[:10]
+            rest_context = split_context[10:]
+            header = "\n".join(top_10_context)
+            half_len = split_context_len // 2
+            # the line at the split point should not start with a number
+            # walk the candidate indices in reverse from the midpoint
+            half_len_list = [i for i in range(half_len)]
+            for index in reversed(half_len_list):
+                first_letter = rest_context[index].strip()[0]
+                if not first_letter.isnumeric() and first_letter not in [".", "(", ")", "-"]:
+                    half_len = index
+                    break
+
+            logger.info(f"Split first part from 0 to {half_len}")
+            split_first_part = "\n".join(split_context[:half_len])
+            first_part = split_first_part
+            first_instructions = self.get_instructions_by_datapoints(
+                first_part, page_datapoints, need_exclude, exclude_data
+            )
+            response, with_error = chat(
+                first_instructions, response_format={"type": "json_object"}
+            )
+            first_part_data = {"data": []}
+            if not with_error:
+                try:
+                    first_part_data = json.loads(response)
+                except:
+                    first_part_data = json_repair.loads(response)
+
+            logger.info(f"Split second part from {half_len} to {split_context_len}")
+            split_second_part = "\n".join(split_context[half_len:])
+            second_part = header + '\n' + split_second_part
+            second_instructions = self.get_instructions_by_datapoints(
+                second_part, page_datapoints, need_exclude, exclude_data
+            )
+            response, with_error = chat(
+                second_instructions, response_format={"type": "json_object"}
+            )
+            second_part_data = {"data": []}
+            if not with_error:
+                try:
+                    second_part_data = json.loads(response)
+                except:
+                    second_part_data = json_repair.loads(response)
+
+            first_part_data_list = first_part_data.get("data", [])
+            logger.info(f"First part data count: {len(first_part_data_list)}")
+            second_part_data_list = second_part_data.get("data", [])
+            logger.info(f"Second part data count: {len(second_part_data_list)}")
+            for first_data in first_part_data_list:
+                if first_data in second_part_data_list:
+                    second_part_data_list.remove(first_data)
+
+            data_list = first_part_data_list + second_part_data_list
+            extract_data = {"data": data_list}
+            return extract_data
+        except Exception as e:
+            logger.error(f"Error in split context: {e}")
+            return {"data": []}
+
     def validate_data(self, extract_data_info: dict) -> dict:
         """
         Validate data by the rules
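The split-point selection above walks backwards from the midpoint so the cut does not land on a line starting with a digit or table punctuation, which would separate numeric rows from the header that names them. The heuristic, isolated for illustration (the function name and sample lines are invented):

def pick_split_index(lines: list, start: int) -> int:
    # Walk backwards from `start` to the first line that does not begin
    # with a digit or table punctuation; mirrors the loop in
    # chat_by_split_context.
    for index in reversed(range(start)):
        first_letter = lines[index].strip()[0]
        if not first_letter.isnumeric() and first_letter not in [".", "(", ")", "-"]:
            return index
    return start  # fall back to the midpoint if every candidate looks numeric

lines = ["Fund overview", "fund 1", "35.26", "-28.26", "fund 2"]
print(pick_split_index(lines, len(lines) // 2))  # 1: the cut lands before "35.26"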
@@ -102,6 +102,7 @@
     "common": [
         "If possible, please extract fund name, share name, TOR, TER, performance fees, OGC values as the output.",
         "If find share name, and exist relevant currency, please output share name + currency, e.g. share name is \"Class A\", currency is \"USD\", the output share name should be: \"Class A USD\".",
         "If find fund name, and exist sub fund name, please output fund name + sub fund name, e.g. fund name is \"Black Rock European\", sub fund name is \"Growth\", the output fund name should be: \"Black Rock European Growth\".",
         "Only output the data point which with relevant value.",
         "Don't ignore the data point which with negative value, e.g. -0.12, -1.13",
+        "Don't ignore the data point which with explicit zero value, e.g. 0, 0.00",
@@ -110,7 +111,7 @@
         "The output should be JSON format, the format is like below example(s):"
     ],
     "fund_level": [
-        "[{\"fund name\": \"fund 1\",\"tor\": 35.26}, {\"fund name\": \"fund 2\",\"tor\": -28.26}, {\"fund name\": \"fund 3\",\"tor\": 115.52,}]"
+        "[{\"fund name\": \"fund 1 - sub fund name 1\",\"tor\": 35.26}, {\"fund name\": \"fund 2 - sub fund name 2\",\"tor\": -28.26}, {\"fund name\": \"fund 3\",\"tor\": 115.52,}]"
     ],
     "share_level": {
         "fund_name": [
main.py
@@ -505,7 +505,7 @@ if __name__ == "__main__":
 
     # doc_id = "476492237"
     # extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run)
-    special_doc_id_list = ["491593469"]
+    special_doc_id_list = ["503194284"]
     output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
     output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
     re_run_mapping_data = True
@@ -103,7 +103,7 @@ def chat(
 
     count = 0
     error = ""
-    max_tokens = 4000
+    max_tokens = 4096
     request_timeout = 120
     while count < 8:
         try:
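max_tokens moves from 4000 to 4096, the full completion budget here; the split-context path exists because even that cap can truncate answers for dense pages. If the chat helper exposes the provider's finish reason, truncation can be detected directly instead of being inferred from a parse failure. A hedged sketch assuming an OpenAI-style response object (this repo's chat() internals are not shown in the diff):

def is_truncated(response) -> bool:
    # "length" means generation stopped at max_tokens, so the tail of the
    # JSON answer is missing and a split-and-retry is worth attempting.
    return response.choices[0].finish_reason == "length"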