optimize data extraction algorithm: if can't find cost numeric value from PDF page text, then extract data by Vision ChatGPT

This commit is contained in:
Blade He 2024-10-15 15:57:54 -05:00
parent 8b651f374c
commit f166e73362
3 changed files with 177 additions and 141 deletions

View File

@ -136,7 +136,7 @@ class DataExtraction:
page_datapoints = self.get_datapoints_by_page_num(page_num)
if len(page_datapoints) == 0:
continue
extract_data = self.extract_data_by_page_text(
extract_data = self.extract_data_by_page(
page_num,
page_text,
page_datapoints,
@ -179,7 +179,7 @@ class DataExtraction:
next_page_text = self.page_text_dict.get(next_page_num, "")
target_text = current_text + next_page_text
# try to get data by current page_datapoints
next_page_extract_data = self.extract_data_by_page_text(
next_page_extract_data = self.extract_data_by_page(
next_page_num,
target_text,
next_datapoints,
@ -314,6 +314,25 @@ class DataExtraction:
with pd.ExcelWriter(excel_data_file) as writer:
data_df.to_excel(writer, sheet_name="extract_data", index=False)
def extract_data_by_page(
self,
page_num: int,
page_text: str,
page_datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,) -> dict:
# If can't find numberic value, e.g. 1.25 or 3,88
# apply Vision ChatGPT to extract data
numeric_regex = r"\d+(\.|\,)\d+"
if not re.search(numeric_regex, page_text):
logger.info(f"Can't find numberic value in page {page_num}, apply Vision ChatGPT to extract data")
return self.extract_data_by_page_image(
page_num, page_datapoints, need_exclude, exclude_data)
else:
return self.extract_data_by_page_text(
page_num, page_text, page_datapoints, need_exclude, exclude_data
)
def extract_data_by_page_text(
self,
page_num: int,
@ -328,7 +347,11 @@ class DataExtraction:
"""
logger.info(f"Extracting data from page {page_num}")
instructions = self.get_instructions_by_datapoints(
page_text, page_datapoints, need_exclude, exclude_data
page_text,
page_datapoints,
need_exclude,
exclude_data,
extract_way="text"
)
response, with_error = chat(
instructions, response_format={"type": "json_object"}
@ -342,6 +365,7 @@ class DataExtraction:
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = {"data": []}
data_dict["extract_way"] = "text"
return data_dict
try:
data = json.loads(response)
@ -367,12 +391,15 @@ class DataExtraction:
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = data
data_dict["extract_way"] = "text"
return data_dict
def extract_data_by_page_image(
self,
page_num: int,
page_datapoints: list
page_datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,
) -> dict:
"""
keys are
@ -381,7 +408,11 @@ class DataExtraction:
logger.info(f"Extracting data from page {page_num}")
image_base64 = self.get_pdf_image_base64(page_num)
instructions = self.get_instructions_by_datapoints(
"", page_datapoints, need_exclude=False, exclude_data=None
"",
page_datapoints,
need_exclude=need_exclude,
exclude_data=exclude_data,
extract_way="image"
)
response, with_error = chat(
instructions, response_format={"type": "json_object"}, image_base64=image_base64
@ -391,9 +422,11 @@ class DataExtraction:
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = ""
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = {"data": []}
data_dict["extract_way"] = "image"
return data_dict
try:
data = json.loads(response)
@ -407,138 +440,13 @@ class DataExtraction:
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = ""
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = data
data_dict["extract_way"] = "image"
return data_dict
def chat_by_split_context(self,
page_text: str,
page_datapoints: list,
need_exclude: bool,
exclude_data: list) -> list:
"""
If occur error, split the context to two parts and try to get data from the two parts
Relevant document: 503194284, page index 147
"""
try:
logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
split_context = re.split(r"\n", page_text)
split_context = [text.strip() for text in split_context
if len(text.strip()) > 0]
if len(split_context) < 10:
return {"data": []}
split_context_len = len(split_context)
top_10_context = split_context[:10]
rest_context = split_context[10:]
header = "\n".join(top_10_context)
half_len = split_context_len // 2
# the member of half_len should not start with number
# reverse iterate the list by half_len
half_len_list = [i for i in range(half_len)]
fund_name_line = ""
half_line = rest_context[half_len].strip()
max_similarity_fund_name, max_similarity = get_most_similar_name(
half_line, self.provider_fund_name_list, matching_type="fund"
)
if max_similarity < 0.2:
# get the fund name line text from the first half
for index in reversed(half_len_list):
line_text = rest_context[index].strip()
if len(line_text) == 0:
continue
line_text_split = line_text.split()
if len(line_text_split) < 3:
continue
first_word = line_text_split[0]
if first_word.lower() == "class":
continue
max_similarity_fund_name, max_similarity = get_most_similar_name(
line_text, self.provider_fund_name_list, matching_type="fund"
)
if max_similarity >= 0.2:
fund_name_line = line_text
break
else:
fund_name_line = half_line
half_len += 1
if fund_name_line == "":
return {"data": []}
logger.info(f"Split first part from 0 to {half_len}")
split_first_part = "\n".join(split_context[:half_len])
first_part = '\n'.join(split_first_part)
first_instructions = self.get_instructions_by_datapoints(
first_part, page_datapoints, need_exclude, exclude_data
)
response, with_error = chat(
first_instructions, response_format={"type": "json_object"}
)
first_part_data = {"data": []}
if not with_error:
try:
first_part_data = json.loads(response)
except:
first_part_data = json_repair.loads(response)
logger.info(f"Split second part from {half_len} to {split_context_len}")
split_second_part = "\n".join(split_context[half_len:])
second_part = header + "\n" + fund_name_line + "\n" + split_second_part
second_instructions = self.get_instructions_by_datapoints(
second_part, page_datapoints, need_exclude, exclude_data
)
response, with_error = chat(
second_instructions, response_format={"type": "json_object"}
)
second_part_data = {"data": []}
if not with_error:
try:
second_part_data = json.loads(response)
except:
second_part_data = json_repair.loads(response)
first_part_data_list = first_part_data.get("data", [])
logger.info(f"First part data count: {len(first_part_data_list)}")
second_part_data_list = second_part_data.get("data", [])
logger.info(f"Second part data count: {len(second_part_data_list)}")
for first_data in first_part_data_list:
if first_data in second_part_data_list:
second_part_data_list.remove(first_data)
else:
# if the first part data is with same fund name and share name,
# remove the second part data
first_data_dp = [key for key in list(first_data.keys())
if key not in ["fund name", "share name"]]
# order the data points
first_data_dp.sort()
first_fund_name = first_data.get("fund name", "")
first_share_name = first_data.get("share name", "")
if len(first_fund_name) > 0 and len(first_share_name) > 0:
remove_second_list = []
for second_data in second_part_data_list:
second_fund_name = second_data.get("fund name", "")
second_share_name = second_data.get("share name", "")
if first_fund_name == second_fund_name and \
first_share_name == second_share_name:
second_data_dp = [key for key in list(second_data.keys())
if key not in ["fund name", "share name"]]
second_data_dp.sort()
if first_data_dp == second_data_dp:
remove_second_list.append(second_data)
for remove_second in remove_second_list:
if remove_second in second_part_data_list:
second_part_data_list.remove(remove_second)
data_list = first_part_data_list + second_part_data_list
extract_data = {"data": data_list}
return extract_data
except Exception as e:
logger.error(f"Error in split context: {e}")
return {"data": []}
def validate_data(self, extract_data_info: dict) -> dict:
"""
Validate data by the rules
@ -634,7 +542,6 @@ class DataExtraction:
return True
return False
def get_datapoints_by_page_num(self, page_num: int) -> list:
datapoints = []
for datapoint in self.datapoints:
@ -648,6 +555,7 @@ class DataExtraction:
datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,
extract_way: str = "text",
) -> str:
"""
Get instructions to extract data from the page by the datapoints
@ -678,7 +586,7 @@ class DataExtraction:
end
"""
instructions = []
if self.extract_way == "text":
if extract_way == "text":
instructions = [f"Context:\n{page_text}\n\nInstructions:\n"]
datapoint_name_list = []
@ -686,9 +594,9 @@ class DataExtraction:
datapoint_name = self.datapoint_name_config.get(datapoint, "")
datapoint_name_list.append(datapoint_name)
if self.extract_way == "text":
if extract_way == "text":
summary = self.instructions_config.get("summary", "\n")
elif self.extract_way == "image":
elif extract_way == "image":
summary = self.instructions_config.get("summary_image", "\n")
else:
summary = self.instructions_config.get("summary", "\n")
@ -696,7 +604,7 @@ class DataExtraction:
instructions.append(summary.format(", ".join(datapoint_name_list)))
instructions.append("\n")
if self.extract_way == "image":
if extract_way == "image":
image_features = self.instructions_config.get("image_features", [])
instructions.extend(image_features)
instructions.append("\n")
@ -831,3 +739,130 @@ class DataExtraction:
instructions_text = "".join(instructions)
return instructions_text
# def chat_by_split_context(self,
# page_text: str,
# page_datapoints: list,
# need_exclude: bool,
# exclude_data: list) -> list:
# """
# If occur error, split the context to two parts and try to get data from the two parts
# Relevant document: 503194284, page index 147
# """
# try:
# logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
# split_context = re.split(r"\n", page_text)
# split_context = [text.strip() for text in split_context
# if len(text.strip()) > 0]
# if len(split_context) < 10:
# return {"data": []}
# split_context_len = len(split_context)
# top_10_context = split_context[:10]
# rest_context = split_context[10:]
# header = "\n".join(top_10_context)
# half_len = split_context_len // 2
# # the member of half_len should not start with number
# # reverse iterate the list by half_len
# half_len_list = [i for i in range(half_len)]
# fund_name_line = ""
# half_line = rest_context[half_len].strip()
# max_similarity_fund_name, max_similarity = get_most_similar_name(
# half_line, self.provider_fund_name_list, matching_type="fund"
# )
# if max_similarity < 0.2:
# # get the fund name line text from the first half
# for index in reversed(half_len_list):
# line_text = rest_context[index].strip()
# if len(line_text) == 0:
# continue
# line_text_split = line_text.split()
# if len(line_text_split) < 3:
# continue
# first_word = line_text_split[0]
# if first_word.lower() == "class":
# continue
# max_similarity_fund_name, max_similarity = get_most_similar_name(
# line_text, self.provider_fund_name_list, matching_type="fund"
# )
# if max_similarity >= 0.2:
# fund_name_line = line_text
# break
# else:
# fund_name_line = half_line
# half_len += 1
# if fund_name_line == "":
# return {"data": []}
# logger.info(f"Split first part from 0 to {half_len}")
# split_first_part = "\n".join(split_context[:half_len])
# first_part = '\n'.join(split_first_part)
# first_instructions = self.get_instructions_by_datapoints(
# first_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
# )
# response, with_error = chat(
# first_instructions, response_format={"type": "json_object"}
# )
# first_part_data = {"data": []}
# if not with_error:
# try:
# first_part_data = json.loads(response)
# except:
# first_part_data = json_repair.loads(response)
# logger.info(f"Split second part from {half_len} to {split_context_len}")
# split_second_part = "\n".join(split_context[half_len:])
# second_part = header + "\n" + fund_name_line + "\n" + split_second_part
# second_instructions = self.get_instructions_by_datapoints(
# second_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
# )
# response, with_error = chat(
# second_instructions, response_format={"type": "json_object"}
# )
# second_part_data = {"data": []}
# if not with_error:
# try:
# second_part_data = json.loads(response)
# except:
# second_part_data = json_repair.loads(response)
# first_part_data_list = first_part_data.get("data", [])
# logger.info(f"First part data count: {len(first_part_data_list)}")
# second_part_data_list = second_part_data.get("data", [])
# logger.info(f"Second part data count: {len(second_part_data_list)}")
# for first_data in first_part_data_list:
# if first_data in second_part_data_list:
# second_part_data_list.remove(first_data)
# else:
# # if the first part data is with same fund name and share name,
# # remove the second part data
# first_data_dp = [key for key in list(first_data.keys())
# if key not in ["fund name", "share name"]]
# # order the data points
# first_data_dp.sort()
# first_fund_name = first_data.get("fund name", "")
# first_share_name = first_data.get("share name", "")
# if len(first_fund_name) > 0 and len(first_share_name) > 0:
# remove_second_list = []
# for second_data in second_part_data_list:
# second_fund_name = second_data.get("fund name", "")
# second_share_name = second_data.get("share name", "")
# if first_fund_name == second_fund_name and \
# first_share_name == second_share_name:
# second_data_dp = [key for key in list(second_data.keys())
# if key not in ["fund name", "share name"]]
# second_data_dp.sort()
# if first_data_dp == second_data_dp:
# remove_second_list.append(second_data)
# for remove_second in remove_second_list:
# if remove_second in second_part_data_list:
# second_part_data_list.remove(remove_second)
# data_list = first_part_data_list + second_part_data_list
# extract_data = {"data": data_list}
# return extract_data
# except Exception as e:
# logger.error(f"Error in split context: {e}")
# return {"data": []}

View File

@ -809,10 +809,10 @@ if __name__ == "__main__":
]
# special_doc_id_list = check_mapping_doc_id_list
special_doc_id_list = check_db_mapping_doc_id_list
special_doc_id_list = ["423395975"]
special_doc_id_list = ["514213638"]
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
re_run_extract_data = True
re_run_extract_data = False
re_run_mapping_data = True
force_save_total_data = False
calculate_metrics = False

View File

@ -74,7 +74,7 @@ def chat(
image_file: str = None,
image_base64: str = None,
):
if engine != "gpt-4o-2024-08-06-research":
if not engine.startswith("gpt-4o"):
max_tokens = 4096
client = AzureOpenAI(
@ -138,6 +138,7 @@ def chat(
messages=messages,
response_format=response_format,
)
sleep(1)
return response.choices[0].message.content, False
except Exception as e:
error = str(e)
@ -145,7 +146,7 @@ def chat(
if "maximum context length" in error:
return error, True
count += 1
sleep(3)
sleep(2)
return error, True