optimize data extraction algorithm: if can't find cost numeric value from PDF page text, then extract data by Vision ChatGPT
This commit is contained in:
parent
8b651f374c
commit
f166e73362
|
|
@ -136,7 +136,7 @@ class DataExtraction:
|
|||
page_datapoints = self.get_datapoints_by_page_num(page_num)
|
||||
if len(page_datapoints) == 0:
|
||||
continue
|
||||
extract_data = self.extract_data_by_page_text(
|
||||
extract_data = self.extract_data_by_page(
|
||||
page_num,
|
||||
page_text,
|
||||
page_datapoints,
|
||||
|
|
@ -179,7 +179,7 @@ class DataExtraction:
|
|||
next_page_text = self.page_text_dict.get(next_page_num, "")
|
||||
target_text = current_text + next_page_text
|
||||
# try to get data by current page_datapoints
|
||||
next_page_extract_data = self.extract_data_by_page_text(
|
||||
next_page_extract_data = self.extract_data_by_page(
|
||||
next_page_num,
|
||||
target_text,
|
||||
next_datapoints,
|
||||
|
|
@ -313,6 +313,25 @@ class DataExtraction:
|
|||
)
|
||||
with pd.ExcelWriter(excel_data_file) as writer:
|
||||
data_df.to_excel(writer, sheet_name="extract_data", index=False)
|
||||
|
||||
def extract_data_by_page(
|
||||
self,
|
||||
page_num: int,
|
||||
page_text: str,
|
||||
page_datapoints: list,
|
||||
need_exclude: bool = False,
|
||||
exclude_data: list = None,) -> dict:
|
||||
# If can't find numberic value, e.g. 1.25 or 3,88
|
||||
# apply Vision ChatGPT to extract data
|
||||
numeric_regex = r"\d+(\.|\,)\d+"
|
||||
if not re.search(numeric_regex, page_text):
|
||||
logger.info(f"Can't find numberic value in page {page_num}, apply Vision ChatGPT to extract data")
|
||||
return self.extract_data_by_page_image(
|
||||
page_num, page_datapoints, need_exclude, exclude_data)
|
||||
else:
|
||||
return self.extract_data_by_page_text(
|
||||
page_num, page_text, page_datapoints, need_exclude, exclude_data
|
||||
)
|
||||
|
||||
def extract_data_by_page_text(
|
||||
self,
|
||||
|
|
@ -328,7 +347,11 @@ class DataExtraction:
|
|||
"""
|
||||
logger.info(f"Extracting data from page {page_num}")
|
||||
instructions = self.get_instructions_by_datapoints(
|
||||
page_text, page_datapoints, need_exclude, exclude_data
|
||||
page_text,
|
||||
page_datapoints,
|
||||
need_exclude,
|
||||
exclude_data,
|
||||
extract_way="text"
|
||||
)
|
||||
response, with_error = chat(
|
||||
instructions, response_format={"type": "json_object"}
|
||||
|
|
@ -342,6 +365,7 @@ class DataExtraction:
|
|||
data_dict["instructions"] = instructions
|
||||
data_dict["raw_answer"] = response
|
||||
data_dict["extract_data"] = {"data": []}
|
||||
data_dict["extract_way"] = "text"
|
||||
return data_dict
|
||||
try:
|
||||
data = json.loads(response)
|
||||
|
|
@ -367,12 +391,15 @@ class DataExtraction:
|
|||
data_dict["instructions"] = instructions
|
||||
data_dict["raw_answer"] = response
|
||||
data_dict["extract_data"] = data
|
||||
data_dict["extract_way"] = "text"
|
||||
return data_dict
|
||||
|
||||
def extract_data_by_page_image(
|
||||
self,
|
||||
page_num: int,
|
||||
page_datapoints: list
|
||||
page_datapoints: list,
|
||||
need_exclude: bool = False,
|
||||
exclude_data: list = None,
|
||||
) -> dict:
|
||||
"""
|
||||
keys are
|
||||
|
|
@ -381,7 +408,11 @@ class DataExtraction:
|
|||
logger.info(f"Extracting data from page {page_num}")
|
||||
image_base64 = self.get_pdf_image_base64(page_num)
|
||||
instructions = self.get_instructions_by_datapoints(
|
||||
"", page_datapoints, need_exclude=False, exclude_data=None
|
||||
"",
|
||||
page_datapoints,
|
||||
need_exclude=need_exclude,
|
||||
exclude_data=exclude_data,
|
||||
extract_way="image"
|
||||
)
|
||||
response, with_error = chat(
|
||||
instructions, response_format={"type": "json_object"}, image_base64=image_base64
|
||||
|
|
@ -391,9 +422,11 @@ class DataExtraction:
|
|||
data_dict = {"doc_id": self.doc_id}
|
||||
data_dict["page_index"] = page_num
|
||||
data_dict["datapoints"] = ", ".join(page_datapoints)
|
||||
data_dict["page_text"] = ""
|
||||
data_dict["instructions"] = instructions
|
||||
data_dict["raw_answer"] = response
|
||||
data_dict["extract_data"] = {"data": []}
|
||||
data_dict["extract_way"] = "image"
|
||||
return data_dict
|
||||
try:
|
||||
data = json.loads(response)
|
||||
|
|
@ -407,137 +440,12 @@ class DataExtraction:
|
|||
data_dict = {"doc_id": self.doc_id}
|
||||
data_dict["page_index"] = page_num
|
||||
data_dict["datapoints"] = ", ".join(page_datapoints)
|
||||
data_dict["page_text"] = ""
|
||||
data_dict["instructions"] = instructions
|
||||
data_dict["raw_answer"] = response
|
||||
data_dict["extract_data"] = data
|
||||
data_dict["extract_way"] = "image"
|
||||
return data_dict
|
||||
|
||||
def chat_by_split_context(self,
|
||||
page_text: str,
|
||||
page_datapoints: list,
|
||||
need_exclude: bool,
|
||||
exclude_data: list) -> list:
|
||||
"""
|
||||
If occur error, split the context to two parts and try to get data from the two parts
|
||||
Relevant document: 503194284, page index 147
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
|
||||
split_context = re.split(r"\n", page_text)
|
||||
split_context = [text.strip() for text in split_context
|
||||
if len(text.strip()) > 0]
|
||||
if len(split_context) < 10:
|
||||
return {"data": []}
|
||||
|
||||
split_context_len = len(split_context)
|
||||
top_10_context = split_context[:10]
|
||||
rest_context = split_context[10:]
|
||||
header = "\n".join(top_10_context)
|
||||
half_len = split_context_len // 2
|
||||
# the member of half_len should not start with number
|
||||
# reverse iterate the list by half_len
|
||||
half_len_list = [i for i in range(half_len)]
|
||||
|
||||
fund_name_line = ""
|
||||
half_line = rest_context[half_len].strip()
|
||||
max_similarity_fund_name, max_similarity = get_most_similar_name(
|
||||
half_line, self.provider_fund_name_list, matching_type="fund"
|
||||
)
|
||||
if max_similarity < 0.2:
|
||||
# get the fund name line text from the first half
|
||||
for index in reversed(half_len_list):
|
||||
line_text = rest_context[index].strip()
|
||||
if len(line_text) == 0:
|
||||
continue
|
||||
line_text_split = line_text.split()
|
||||
if len(line_text_split) < 3:
|
||||
continue
|
||||
first_word = line_text_split[0]
|
||||
if first_word.lower() == "class":
|
||||
continue
|
||||
|
||||
max_similarity_fund_name, max_similarity = get_most_similar_name(
|
||||
line_text, self.provider_fund_name_list, matching_type="fund"
|
||||
)
|
||||
if max_similarity >= 0.2:
|
||||
fund_name_line = line_text
|
||||
break
|
||||
else:
|
||||
fund_name_line = half_line
|
||||
half_len += 1
|
||||
if fund_name_line == "":
|
||||
return {"data": []}
|
||||
|
||||
logger.info(f"Split first part from 0 to {half_len}")
|
||||
split_first_part = "\n".join(split_context[:half_len])
|
||||
first_part = '\n'.join(split_first_part)
|
||||
first_instructions = self.get_instructions_by_datapoints(
|
||||
first_part, page_datapoints, need_exclude, exclude_data
|
||||
)
|
||||
response, with_error = chat(
|
||||
first_instructions, response_format={"type": "json_object"}
|
||||
)
|
||||
first_part_data = {"data": []}
|
||||
if not with_error:
|
||||
try:
|
||||
first_part_data = json.loads(response)
|
||||
except:
|
||||
first_part_data = json_repair.loads(response)
|
||||
|
||||
logger.info(f"Split second part from {half_len} to {split_context_len}")
|
||||
split_second_part = "\n".join(split_context[half_len:])
|
||||
second_part = header + "\n" + fund_name_line + "\n" + split_second_part
|
||||
second_instructions = self.get_instructions_by_datapoints(
|
||||
second_part, page_datapoints, need_exclude, exclude_data
|
||||
)
|
||||
response, with_error = chat(
|
||||
second_instructions, response_format={"type": "json_object"}
|
||||
)
|
||||
second_part_data = {"data": []}
|
||||
if not with_error:
|
||||
try:
|
||||
second_part_data = json.loads(response)
|
||||
except:
|
||||
second_part_data = json_repair.loads(response)
|
||||
|
||||
first_part_data_list = first_part_data.get("data", [])
|
||||
logger.info(f"First part data count: {len(first_part_data_list)}")
|
||||
second_part_data_list = second_part_data.get("data", [])
|
||||
logger.info(f"Second part data count: {len(second_part_data_list)}")
|
||||
for first_data in first_part_data_list:
|
||||
if first_data in second_part_data_list:
|
||||
second_part_data_list.remove(first_data)
|
||||
else:
|
||||
# if the first part data is with same fund name and share name,
|
||||
# remove the second part data
|
||||
first_data_dp = [key for key in list(first_data.keys())
|
||||
if key not in ["fund name", "share name"]]
|
||||
# order the data points
|
||||
first_data_dp.sort()
|
||||
first_fund_name = first_data.get("fund name", "")
|
||||
first_share_name = first_data.get("share name", "")
|
||||
if len(first_fund_name) > 0 and len(first_share_name) > 0:
|
||||
remove_second_list = []
|
||||
for second_data in second_part_data_list:
|
||||
second_fund_name = second_data.get("fund name", "")
|
||||
second_share_name = second_data.get("share name", "")
|
||||
if first_fund_name == second_fund_name and \
|
||||
first_share_name == second_share_name:
|
||||
second_data_dp = [key for key in list(second_data.keys())
|
||||
if key not in ["fund name", "share name"]]
|
||||
second_data_dp.sort()
|
||||
if first_data_dp == second_data_dp:
|
||||
remove_second_list.append(second_data)
|
||||
for remove_second in remove_second_list:
|
||||
if remove_second in second_part_data_list:
|
||||
second_part_data_list.remove(remove_second)
|
||||
|
||||
data_list = first_part_data_list + second_part_data_list
|
||||
extract_data = {"data": data_list}
|
||||
return extract_data
|
||||
except Exception as e:
|
||||
logger.error(f"Error in split context: {e}")
|
||||
return {"data": []}
|
||||
|
||||
def validate_data(self, extract_data_info: dict) -> dict:
|
||||
"""
|
||||
|
|
@ -634,7 +542,6 @@ class DataExtraction:
|
|||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_datapoints_by_page_num(self, page_num: int) -> list:
|
||||
datapoints = []
|
||||
for datapoint in self.datapoints:
|
||||
|
|
@ -648,6 +555,7 @@ class DataExtraction:
|
|||
datapoints: list,
|
||||
need_exclude: bool = False,
|
||||
exclude_data: list = None,
|
||||
extract_way: str = "text",
|
||||
) -> str:
|
||||
"""
|
||||
Get instructions to extract data from the page by the datapoints
|
||||
|
|
@ -678,7 +586,7 @@ class DataExtraction:
|
|||
end
|
||||
"""
|
||||
instructions = []
|
||||
if self.extract_way == "text":
|
||||
if extract_way == "text":
|
||||
instructions = [f"Context:\n{page_text}\n\nInstructions:\n"]
|
||||
|
||||
datapoint_name_list = []
|
||||
|
|
@ -686,9 +594,9 @@ class DataExtraction:
|
|||
datapoint_name = self.datapoint_name_config.get(datapoint, "")
|
||||
datapoint_name_list.append(datapoint_name)
|
||||
|
||||
if self.extract_way == "text":
|
||||
if extract_way == "text":
|
||||
summary = self.instructions_config.get("summary", "\n")
|
||||
elif self.extract_way == "image":
|
||||
elif extract_way == "image":
|
||||
summary = self.instructions_config.get("summary_image", "\n")
|
||||
else:
|
||||
summary = self.instructions_config.get("summary", "\n")
|
||||
|
|
@ -696,7 +604,7 @@ class DataExtraction:
|
|||
instructions.append(summary.format(", ".join(datapoint_name_list)))
|
||||
instructions.append("\n")
|
||||
|
||||
if self.extract_way == "image":
|
||||
if extract_way == "image":
|
||||
image_features = self.instructions_config.get("image_features", [])
|
||||
instructions.extend(image_features)
|
||||
instructions.append("\n")
|
||||
|
|
@ -831,3 +739,130 @@ class DataExtraction:
|
|||
|
||||
instructions_text = "".join(instructions)
|
||||
return instructions_text
|
||||
|
||||
# def chat_by_split_context(self,
|
||||
# page_text: str,
|
||||
# page_datapoints: list,
|
||||
# need_exclude: bool,
|
||||
# exclude_data: list) -> list:
|
||||
# """
|
||||
# If occur error, split the context to two parts and try to get data from the two parts
|
||||
# Relevant document: 503194284, page index 147
|
||||
# """
|
||||
# try:
|
||||
# logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
|
||||
# split_context = re.split(r"\n", page_text)
|
||||
# split_context = [text.strip() for text in split_context
|
||||
# if len(text.strip()) > 0]
|
||||
# if len(split_context) < 10:
|
||||
# return {"data": []}
|
||||
|
||||
# split_context_len = len(split_context)
|
||||
# top_10_context = split_context[:10]
|
||||
# rest_context = split_context[10:]
|
||||
# header = "\n".join(top_10_context)
|
||||
# half_len = split_context_len // 2
|
||||
# # the member of half_len should not start with number
|
||||
# # reverse iterate the list by half_len
|
||||
# half_len_list = [i for i in range(half_len)]
|
||||
|
||||
# fund_name_line = ""
|
||||
# half_line = rest_context[half_len].strip()
|
||||
# max_similarity_fund_name, max_similarity = get_most_similar_name(
|
||||
# half_line, self.provider_fund_name_list, matching_type="fund"
|
||||
# )
|
||||
# if max_similarity < 0.2:
|
||||
# # get the fund name line text from the first half
|
||||
# for index in reversed(half_len_list):
|
||||
# line_text = rest_context[index].strip()
|
||||
# if len(line_text) == 0:
|
||||
# continue
|
||||
# line_text_split = line_text.split()
|
||||
# if len(line_text_split) < 3:
|
||||
# continue
|
||||
# first_word = line_text_split[0]
|
||||
# if first_word.lower() == "class":
|
||||
# continue
|
||||
|
||||
# max_similarity_fund_name, max_similarity = get_most_similar_name(
|
||||
# line_text, self.provider_fund_name_list, matching_type="fund"
|
||||
# )
|
||||
# if max_similarity >= 0.2:
|
||||
# fund_name_line = line_text
|
||||
# break
|
||||
# else:
|
||||
# fund_name_line = half_line
|
||||
# half_len += 1
|
||||
# if fund_name_line == "":
|
||||
# return {"data": []}
|
||||
|
||||
# logger.info(f"Split first part from 0 to {half_len}")
|
||||
# split_first_part = "\n".join(split_context[:half_len])
|
||||
# first_part = '\n'.join(split_first_part)
|
||||
# first_instructions = self.get_instructions_by_datapoints(
|
||||
# first_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
|
||||
# )
|
||||
# response, with_error = chat(
|
||||
# first_instructions, response_format={"type": "json_object"}
|
||||
# )
|
||||
# first_part_data = {"data": []}
|
||||
# if not with_error:
|
||||
# try:
|
||||
# first_part_data = json.loads(response)
|
||||
# except:
|
||||
# first_part_data = json_repair.loads(response)
|
||||
|
||||
# logger.info(f"Split second part from {half_len} to {split_context_len}")
|
||||
# split_second_part = "\n".join(split_context[half_len:])
|
||||
# second_part = header + "\n" + fund_name_line + "\n" + split_second_part
|
||||
# second_instructions = self.get_instructions_by_datapoints(
|
||||
# second_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
|
||||
# )
|
||||
# response, with_error = chat(
|
||||
# second_instructions, response_format={"type": "json_object"}
|
||||
# )
|
||||
# second_part_data = {"data": []}
|
||||
# if not with_error:
|
||||
# try:
|
||||
# second_part_data = json.loads(response)
|
||||
# except:
|
||||
# second_part_data = json_repair.loads(response)
|
||||
|
||||
# first_part_data_list = first_part_data.get("data", [])
|
||||
# logger.info(f"First part data count: {len(first_part_data_list)}")
|
||||
# second_part_data_list = second_part_data.get("data", [])
|
||||
# logger.info(f"Second part data count: {len(second_part_data_list)}")
|
||||
# for first_data in first_part_data_list:
|
||||
# if first_data in second_part_data_list:
|
||||
# second_part_data_list.remove(first_data)
|
||||
# else:
|
||||
# # if the first part data is with same fund name and share name,
|
||||
# # remove the second part data
|
||||
# first_data_dp = [key for key in list(first_data.keys())
|
||||
# if key not in ["fund name", "share name"]]
|
||||
# # order the data points
|
||||
# first_data_dp.sort()
|
||||
# first_fund_name = first_data.get("fund name", "")
|
||||
# first_share_name = first_data.get("share name", "")
|
||||
# if len(first_fund_name) > 0 and len(first_share_name) > 0:
|
||||
# remove_second_list = []
|
||||
# for second_data in second_part_data_list:
|
||||
# second_fund_name = second_data.get("fund name", "")
|
||||
# second_share_name = second_data.get("share name", "")
|
||||
# if first_fund_name == second_fund_name and \
|
||||
# first_share_name == second_share_name:
|
||||
# second_data_dp = [key for key in list(second_data.keys())
|
||||
# if key not in ["fund name", "share name"]]
|
||||
# second_data_dp.sort()
|
||||
# if first_data_dp == second_data_dp:
|
||||
# remove_second_list.append(second_data)
|
||||
# for remove_second in remove_second_list:
|
||||
# if remove_second in second_part_data_list:
|
||||
# second_part_data_list.remove(remove_second)
|
||||
|
||||
# data_list = first_part_data_list + second_part_data_list
|
||||
# extract_data = {"data": data_list}
|
||||
# return extract_data
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in split context: {e}")
|
||||
# return {"data": []}
|
||||
|
|
|
|||
4
main.py
4
main.py
|
|
@ -809,10 +809,10 @@ if __name__ == "__main__":
|
|||
]
|
||||
# special_doc_id_list = check_mapping_doc_id_list
|
||||
special_doc_id_list = check_db_mapping_doc_id_list
|
||||
special_doc_id_list = ["423395975"]
|
||||
special_doc_id_list = ["514213638"]
|
||||
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
|
||||
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
|
||||
re_run_extract_data = True
|
||||
re_run_extract_data = False
|
||||
re_run_mapping_data = True
|
||||
force_save_total_data = False
|
||||
calculate_metrics = False
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ def chat(
|
|||
image_file: str = None,
|
||||
image_base64: str = None,
|
||||
):
|
||||
if engine != "gpt-4o-2024-08-06-research":
|
||||
if not engine.startswith("gpt-4o"):
|
||||
max_tokens = 4096
|
||||
|
||||
client = AzureOpenAI(
|
||||
|
|
@ -138,6 +138,7 @@ def chat(
|
|||
messages=messages,
|
||||
response_format=response_format,
|
||||
)
|
||||
sleep(1)
|
||||
return response.choices[0].message.content, False
|
||||
except Exception as e:
|
||||
error = str(e)
|
||||
|
|
@ -145,7 +146,7 @@ def chat(
|
|||
if "maximum context length" in error:
|
||||
return error, True
|
||||
count += 1
|
||||
sleep(3)
|
||||
sleep(2)
|
||||
return error, True
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue