Support splitting the page text when the output exceeds 4K tokens.
parent 0f6dbd27eb
commit 932870f406
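In outline, the change works around the 4K-token completion cap as follows: if the model's JSON answer for a page cannot be parsed, the page text is split roughly in half, each half is sent as its own request, and the two result lists are merged with duplicates removed. A minimal sketch of that flow is below; extract_with_split and the caller-supplied ask_model function are hypothetical stand-ins for the real DataExtraction plumbing shown in the hunks that follow.

    from typing import Callable

    def extract_with_split(page_text: str, ask_model: Callable[[str], dict]) -> dict:
        # try the whole page first
        try:
            return ask_model(page_text)
        except Exception:
            # the answer was probably truncated at the completion-token limit:
            # split the page roughly in half and query each half separately
            lines = [line.strip() for line in page_text.splitlines() if line.strip()]
            half = len(lines) // 2
            first = ask_model("\n".join(lines[:half])).get("data", [])
            second = ask_model("\n".join(lines[half:])).get("data", [])
            # merge, keeping rows from the second half only if the first half missed them
            merged = first + [row for row in second if row not in first]
            return {"data": merged}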
@@ -160,6 +160,16 @@ class DataExtraction:
                     ] = next_page_data_list
                     data_list.append(next_page_extract_data)
                     handled_page_num_list.append(next_page_num)
+                    exist_current_page_datapoint = False
+                    for next_page_data in next_page_data_list:
+                        for page_datapoint in page_datapoints:
+                            if page_datapoint in list(next_page_data.keys()):
+                                exist_current_page_datapoint = True
+                                break
+                        if exist_current_page_datapoint:
+                            break
+                    if not exist_current_page_datapoint:
+                        break
                 else:
                     break
                 count += 1
@@ -204,12 +214,25 @@ class DataExtraction:
         )
         if with_error:
             logger.error(f"Error in extracting tables from page")
-            return ""
+            data_dict = {"doc_id": self.doc_id}
+            data_dict["page_index"] = page_num
+            data_dict["datapoints"] = ", ".join(page_datapoints)
+            data_dict["page_text"] = page_text
+            data_dict["instructions"] = instructions
+            data_dict["raw_answer"] = response
+            data_dict["extract_data"] = {"data": []}
+            return data_dict
         try:
             data = json.loads(response)
         except:
             try:
-                data = json_repair.loads(response)
+                # if parsing fails, the output length is perhaps over 4K tokens;
+                # split the context into two parts and try to get data from each part
+                data = self.chat_by_split_context(
+                    page_text, page_datapoints, need_exclude, exclude_data
+                )
+                if len(data.get("data", [])) == 0:
+                    data = json_repair.loads(response)
             except:
                 data = {"data": []}
         data = self.validate_data(data)
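For context, json_repair.loads is used here as a tolerant parser: when a completion is cut off mid-object, json.loads raises, while json_repair closes the open structures and returns a best-effort result. A small illustration (the truncated string is made up):

    import json
    import json_repair

    # a completion cut off mid-array, as happens when the 4K-token limit is hit
    truncated = '{"data": [{"fund name": "fund 1", "tor": 35.26}, {"fund name": "fund 2", "to'

    try:
        json.loads(truncated)              # raises json.JSONDecodeError
    except json.JSONDecodeError:
        repaired = json_repair.loads(truncated)
        # best-effort recovery, roughly {"data": [{"fund name": "fund 1", "tor": 35.26}, ...]}
        print(repaired)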
@@ -223,6 +246,80 @@ class DataExtraction:
         data_dict["extract_data"] = data
         return data_dict
 
+    def chat_by_split_context(self,
+                              page_text: str,
+                              page_datapoints: list,
+                              need_exclude: bool,
+                              exclude_data: list) -> dict:
+        """
+        If an error occurs, split the context into two parts and try to get data from each part.
+        """
+        try:
+            logger.info("Split context to get data; the output length may be over 4K tokens")
+            split_context = re.split(r"\n", page_text)
+            split_context = [text.strip() for text in split_context
+                             if len(text.strip()) > 0]
+            split_context_len = len(split_context)
+            top_10_context = split_context[:10]
+            rest_context = split_context[10:]
+            header = "\n".join(top_10_context)
+            half_len = split_context_len // 2
+            # the line at the split point should not start with a number;
+            # iterate backwards from half_len to find a suitable split point
+            half_len_list = [i for i in range(half_len)]
+            for index in reversed(half_len_list):
+                first_letter = rest_context[index].strip()[0]
+                if not first_letter.isnumeric() and first_letter not in [".", "(", ")", "-"]:
+                    half_len = index
+                    break
+
+            logger.info(f"Split first part from 0 to {half_len}")
+            first_part = "\n".join(split_context[:half_len])
+            first_instructions = self.get_instructions_by_datapoints(
+                first_part, page_datapoints, need_exclude, exclude_data
+            )
+            response, with_error = chat(
+                first_instructions, response_format={"type": "json_object"}
+            )
+            first_part_data = {"data": []}
+            if not with_error:
+                try:
+                    first_part_data = json.loads(response)
+                except:
+                    first_part_data = json_repair.loads(response)
+
+            logger.info(f"Split second part from {half_len} to {split_context_len}")
+            split_second_part = "\n".join(split_context[half_len:])
+            second_part = header + '\n' + split_second_part
+            second_instructions = self.get_instructions_by_datapoints(
+                second_part, page_datapoints, need_exclude, exclude_data
+            )
+            response, with_error = chat(
+                second_instructions, response_format={"type": "json_object"}
+            )
+            second_part_data = {"data": []}
+            if not with_error:
+                try:
+                    second_part_data = json.loads(response)
+                except:
+                    second_part_data = json_repair.loads(response)
+
+            first_part_data_list = first_part_data.get("data", [])
+            logger.info(f"First part data count: {len(first_part_data_list)}")
+            second_part_data_list = second_part_data.get("data", [])
+            logger.info(f"Second part data count: {len(second_part_data_list)}")
+            for first_data in first_part_data_list:
+                if first_data in second_part_data_list:
+                    second_part_data_list.remove(first_data)
+
+            data_list = first_part_data_list + second_part_data_list
+            extract_data = {"data": data_list}
+            return extract_data
+        except Exception as e:
+            logger.error(f"Error in split context: {e}")
+            return {"data": []}
+
     def validate_data(self, extract_data_info: dict) -> dict:
         """
         Validate data by the rules
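The merge step at the end of chat_by_split_context relies on plain dict membership tests, so rows that both halves return (typically rows extracted from the 10-line header that is prepended to the second half) are only counted once. A standalone illustration with made-up rows, using the same dedup loop as the method:

    first_part_data_list = [
        {"fund name": "fund 1", "tor": 35.26},
        {"fund name": "fund 2", "tor": -28.26},
    ]
    second_part_data_list = [
        {"fund name": "fund 2", "tor": -28.26},   # duplicate row from the shared header
        {"fund name": "fund 3", "tor": 115.52},
    ]

    # drop rows from the second part that the first part already returned
    for first_data in first_part_data_list:
        if first_data in second_part_data_list:
            second_part_data_list.remove(first_data)

    data_list = first_part_data_list + second_part_data_list
    print({"data": data_list})
    # fund 1, fund 2 and fund 3 each appear exactly once in the merged list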
@@ -102,6 +102,7 @@
         "common": [
             "If possible, please extract fund name, share name, TOR, TER, performance fees, OGC values as the output.",
             "If find share name, and exist relevant currency, please output share name + currency, e.g. share name is \"Class A\", currency is \"USD\", the output share name should be: \"Class A USD\".",
+            "If find fund name, and exist sub fund name, please output fund name + sub fund name, e.g. fund name is \"Black Rock European\", sub fund name is \"Growth\", the output fund name should be: \"Black Rock European Growth\".",
             "Only output the data point which with relevant value.",
             "Don't ignore the data point which with negative value, e.g. -0.12, -1.13",
             "Don't ignore the data point which with explicit zero value, e.g. 0, 0.00",
@@ -110,7 +111,7 @@
             "The output should be JSON format, the format is like below example(s):"
         ],
         "fund_level": [
-            "[{\"fund name\": \"fund 1\",\"tor\": 35.26}, {\"fund name\": \"fund 2\",\"tor\": -28.26}, {\"fund name\": \"fund 3\",\"tor\": 115.52,}]"
+            "[{\"fund name\": \"fund 1 - sub fund name 1\",\"tor\": 35.26}, {\"fund name\": \"fund 2 - sub fund name 2\",\"tor\": -28.26}, {\"fund name\": \"fund 3\",\"tor\": 115.52,}]"
         ],
         "share_level": {
             "fund_name": [
main.py
@@ -505,7 +505,7 @@ if __name__ == "__main__":
     # doc_id = "476492237"
     # extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run)
-    special_doc_id_list = ["491593469"]
+    special_doc_id_list = ["503194284"]
     output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
     output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
     re_run_mapping_data = True
@@ -103,7 +103,7 @@ def chat(
     count = 0
     error = ""
-    max_tokens = 4000
+    max_tokens = 4096
     request_timeout = 120
     while count < 8:
         try: