support extract the continuous page(s) for not missing next page data which without table header.

This commit is contained in:
Blade He 2024-09-06 16:29:35 -05:00
parent 1caf552065
commit 878383a72c
5 changed files with 267 additions and 99 deletions

View File

@ -13,26 +13,26 @@ from utils.biz_utils import add_slash_to_text_as_regex, clean_text
class DataExtraction:
def __init__(
self,
doc_id: str,
self,
doc_id: str,
pdf_file: str,
output_data_folder: str,
page_text_dict: dict,
datapoint_page_info: dict,
document_mapping_info_df: pd.DataFrame
output_data_folder: str,
page_text_dict: dict,
datapoint_page_info: dict,
document_mapping_info_df: pd.DataFrame,
) -> None:
self.doc_id = doc_id
self.pdf_file = pdf_file
if output_data_folder is None or len(output_data_folder) == 0:
output_data_folder = r"/data/emea_ar/output/extract_data/docs/"
os.makedirs(output_data_folder, exist_ok=True)
self.output_data_json_folder = os.path.join(output_data_folder, "json/")
os.makedirs(self.output_data_json_folder, exist_ok=True)
self.output_data_excel_folder = os.path.join(output_data_folder, "excel/")
os.makedirs(self.output_data_excel_folder, exist_ok=True)
if page_text_dict is None or len(page_text_dict.keys()) == 0:
self.page_text_dict = self.get_pdf_page_text_dict()
else:
@ -42,17 +42,19 @@ class DataExtraction:
else:
self.document_mapping_info_df = document_mapping_info_df
self.datapoint_page_info = datapoint_page_info
self.page_nums_with_datapoints = self.get_page_nums_from_datapoint_page_info()
self.datapoints = self.get_datapoints_from_datapoint_page_info()
self.instructions_config = self.get_instructions_config()
self.datapoint_level_config = self.get_datapoint_level()
self.datapoint_name_config = self.get_datapoint_name()
def get_instructions_config(self) -> dict:
instructions_config_file = r"./instructions/data_extraction_prompts_config.json"
with open(instructions_config_file, "r", encoding="utf-8") as f:
instructions_config = json.load(f)
return instructions_config
def get_datapoint_level(self) -> dict:
datapoint_level_file = r"./configuration/datapoint_level.json"
with open(datapoint_level_file, "r", encoding="utf-8") as f:
@ -64,68 +66,181 @@ class DataExtraction:
with open(datapoint_name_file, "r", encoding="utf-8") as f:
datapoint_name = json.load(f)
return datapoint_name
def get_pdf_page_text_dict(self) -> dict:
pdf_util = PDFUtil(self.pdf_file)
success, text, page_text_dict = pdf_util.extract_text()
return page_text_dict
def get_datapoints_from_datapoint_page_info(self) -> list:
datapoints = list(self.datapoint_page_info.keys())
if "doc_id" in datapoints:
datapoints.remove("doc_id")
return datapoints
def get_page_nums_from_datapoint_page_info(self) -> list:
page_nums_with_datapoints = []
for datapoint, page_nums in self.datapoint_page_info.items():
if datapoint == "doc_id":
continue
page_nums_with_datapoints.extend(page_nums)
page_nums_with_datapoints = list(set(page_nums_with_datapoints))
# sort the page numbers
page_nums_with_datapoints.sort()
return page_nums_with_datapoints
def extract_data(self) -> dict:
"""
keys are
doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
"""
data_list = []
pdf_page_count = len(self.page_text_dict.keys())
handled_page_num_list = []
for page_num, page_text in self.page_text_dict.items():
if page_num in handled_page_num_list:
continue
page_datapoints = self.get_datapoints_by_page_num(page_num)
if len(page_datapoints) == 0:
continue
instructions = self.get_instructions_by_datapoints(page_text, page_datapoints)
response, with_error = chat(instructions)
if with_error:
logger.error(f"Error in extracting tables from page")
return ""
try:
data = json.loads(response)
except:
try:
data = json_repair.loads(response)
except:
data = {}
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = page_text
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["data"] = data
data_list.append(data_dict)
json_data_file = os.path.join(self.output_data_json_folder, f"{self.doc_id}.json")
extract_data = self.extract_data_by_page(
page_num,
page_text,
page_datapoints,
need_exclude=False,
exclude_data=None,
)
data_list.append(extract_data)
page_data_list = extract_data.get("extract_data", {}).get("data", [])
current_page_data_count = len(page_data_list)
if current_page_data_count > 0:
count = 1
# some pdf documents have multiple pages for the same data
# and the next page may without table header with data point keywords.
# the purpose is try to get data from the next page
current_text = page_text
while count < 3:
try:
next_page_num = page_num + count
if next_page_num >= pdf_page_count:
break
next_datapoints = page_datapoints
if next_page_num in self.page_nums_with_datapoints:
should_continue = False
next_datapoints = self.get_datapoints_by_page_num(next_page_num)
if len(next_datapoints) == 0:
should_continue = True
else:
for next_datapoint in next_datapoints:
if next_datapoint not in page_datapoints:
should_continue = True
break
next_datapoints.extend(page_datapoints)
# remove duplicate datapoints
next_datapoints = list(set(next_datapoints))
if not should_continue:
break
next_page_text = self.page_text_dict.get(next_page_num, "")
target_text = current_text + next_page_text
# try to get data by current page_datapoints
next_page_extract_data = self.extract_data_by_page(
next_page_num,
target_text,
next_datapoints,
need_exclude=True,
exclude_data=page_data_list,
)
next_page_data_list = next_page_extract_data.get(
"extract_data", {}
).get("data", [])
if next_page_data_list is not None and len(next_page_data_list) > 0:
for current_page_data in page_data_list:
if current_page_data in next_page_data_list:
next_page_data_list.remove(current_page_data)
next_page_extract_data["extract_data"][
"data"
] = next_page_data_list
data_list.append(next_page_extract_data)
handled_page_num_list.append(next_page_num)
else:
break
count += 1
except Exception as e:
logger.error(f"Error in extracting data from next page: {e}")
break
json_data_file = os.path.join(
self.output_data_json_folder, f"{self.doc_id}.json"
)
with open(json_data_file, "w", encoding="utf-8") as f:
json.dump(data_list, f, ensure_ascii=False, indent=4)
data_df = pd.DataFrame(data_list)
data_df.reset_index(drop=True, inplace=True)
excel_data_file = os.path.join(self.output_data_excel_folder, f"{self.doc_id}.xlsx")
excel_data_file = os.path.join(
self.output_data_excel_folder, f"{self.doc_id}.xlsx"
)
with pd.ExcelWriter(excel_data_file) as writer:
data_df.to_excel(writer, sheet_name="extract_data", index=False)
return data_list
def extract_data_by_page(
self,
page_num: int,
page_text: str,
page_datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,
) -> dict:
"""
keys are
doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
"""
logger.info(f"Extracting data from page {page_num}")
instructions = self.get_instructions_by_datapoints(
page_text, page_datapoints, need_exclude, exclude_data
)
response, with_error = chat(
instructions, response_format={"type": "json_object"}
)
if with_error:
logger.error(f"Error in extracting tables from page")
return ""
try:
data = json.loads(response)
except:
try:
data = json_repair.loads(response)
except:
data = {"data": []}
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = page_text
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = data
return data_dict
def get_datapoints_by_page_num(self, page_num: int) -> list:
datapoints = []
for datapoint in self.datapoints:
if page_num in self.datapoint_page_info[datapoint]:
datapoints.append(datapoint)
return datapoints
def get_instructions_by_datapoints(self, page_text: str, datapoints: list) -> str:
def get_instructions_by_datapoints(
self,
page_text: str,
datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,
) -> str:
"""
Get instructions to extract data from the page by the datapoints
Below is the instructions sections:
@ -159,11 +274,11 @@ class DataExtraction:
for datapoint in datapoints:
datapoint_name = self.datapoint_name_config.get(datapoint, "")
datapoint_name_list.append(datapoint_name)
summary = self.instructions_config.get("summary", "\n")
instructions.append(summary.format(', '.join(datapoint_name_list)))
instructions.append(summary.format(", ".join(datapoint_name_list)))
instructions.append("\n")
instructions.append("Datapoints Reported name:\n")
reported_name_info = self.instructions_config.get("reported_name", {})
for datapoint in datapoints:
@ -171,13 +286,15 @@ class DataExtraction:
instructions.append(reported_name)
instructions.append("\n")
instructions.append("\n")
instructions.append("Data business features:\n")
data_business_features = self.instructions_config.get("data_business_features", {})
common = '\n'.join(data_business_features.get("common", []))
data_business_features = self.instructions_config.get(
"data_business_features", {}
)
common = "\n".join(data_business_features.get("common", []))
instructions.append(common)
instructions.append("\n")
instructions.append("Datapoints investment level:\n")
investment_level_info = data_business_features.get("investment_level", {})
for datapoint in datapoints:
@ -185,7 +302,7 @@ class DataExtraction:
instructions.append(investment_level)
instructions.append("\n")
instructions.append("\n")
instructions.append("Datapoints value range:\n")
data_value_range_info = data_business_features.get("data_value_range", {})
for datapoint in datapoints:
@ -193,7 +310,7 @@ class DataExtraction:
instructions.append(data_value_range)
instructions.append("\n")
instructions.append("\n")
special_rule_info = data_business_features.get("special_rule", {})
with_special_rule_title = False
for datapoint in datapoints:
@ -202,11 +319,11 @@ class DataExtraction:
if not with_special_rule_title:
instructions.append("Special rule:\n")
with_special_rule_title = True
special_rule = '\n'.join(special_rule_list)
special_rule = "\n".join(special_rule_list)
instructions.append(special_rule)
instructions.append("\n\n")
instructions.append("\n")
instructions.append("Special cases:\n")
special_cases = self.instructions_config.get("special_cases", {})
special_cases_common_list = special_cases.get("common", [])
@ -215,10 +332,10 @@ class DataExtraction:
instructions.append(title)
instructions.append("\n")
contents_list = special_cases_common.get("contents", [])
contents = '\n'.join(contents_list)
contents = "\n".join(contents_list)
instructions.append(contents)
instructions.append("\n\n")
for datapoint in datapoints:
special_case_list = special_cases.get(datapoint, [])
for special_case in special_case_list:
@ -226,51 +343,69 @@ class DataExtraction:
instructions.append(title)
instructions.append("\n")
contents_list = special_case.get("contents", [])
contents = '\n'.join(contents_list)
contents = "\n".join(contents_list)
instructions.append(contents)
instructions.append("\n\n")
instructions.append("\n")
instructions.append("Output requirement:\n")
output_requirement = self.instructions_config.get("output_requirement", {})
output_requirement_common_list = output_requirement.get("common", [])
instructions.append("\n".join(output_requirement_common_list))
instructions.append("\n")
share_datapoint_value_example = {}
share_level_config = output_requirement.get("share_level", {})
example_list = []
for datapoint in datapoints:
investment_level = self.datapoint_level_config.get(datapoint, "")
if investment_level == "fund_level":
fund_level_example_list = output_requirement.get("fund_level", [])
for example in fund_level_example_list:
instructions.append(example)
instructions.append("\n")
instructions.append("\n")
try:
sub_example_list = json.loads(example)
except:
sub_example_list = json_repair.loads(example)
example_list.extend(sub_example_list)
elif investment_level == "share_level":
share_datapoint_value_example[datapoint] = share_level_config.get(f"{datapoint}_value", [])
share_datapoint_value_example[datapoint] = share_level_config.get(
f"{datapoint}_value", []
)
share_datapoint_list = list(share_datapoint_value_example.keys())
instructions.append(f"Example:\n")
if len(share_datapoint_list) > 0:
fund_name_example_list = share_level_config.get("fund_name", [])
share_name_example_list = share_level_config.get("share_name", [])
for index in range(len(fund_name_example_list)):
example_dict = {"fund name": fund_name_example_list[index],
"share name": share_name_example_list[index]}
example_dict = {
"fund name": fund_name_example_list[index],
"share name": share_name_example_list[index],
}
for share_datapoint in share_datapoint_list:
share_datapoint_values = share_datapoint_value_example[share_datapoint]
share_datapoint_values = share_datapoint_value_example[
share_datapoint
]
if index < len(share_datapoint_values):
example_dict[share_datapoint] = share_datapoint_values[index]
instructions.append(f"Example {index + 1}:\n")
instructions.append(json.dumps(example_dict, ensure_ascii=False))
instructions.append("\n")
instructions.append("\n")
end_list = self.instructions_config.get("end", [])
instructions.append('\n'.join(end_list))
example_list.append(example_dict)
example_data = {"data": example_list}
instructions.append(json.dumps(example_data, ensure_ascii=False, indent=4))
instructions.append("\n")
instructions.append("\n")
end_list = self.instructions_config.get("end", [])
instructions.append("\n".join(end_list))
instructions.append("\n")
if need_exclude and exclude_data is not None and isinstance(exclude_data, list):
instructions.append("Please exclude below data from output:\n")
instructions.append(json.dumps(exclude_data, ensure_ascii=False, indent=4))
instructions.append("\n")
instructions.append("\n")
instructions.append("Answer:\n")
instructions_text = ''.join(instructions)
return instructions_text
instructions_text = "".join(instructions)
return instructions_text

View File

@ -135,7 +135,7 @@
},
"end": [
"Only output JSON data.",
"Don't output the value which not exist in context, especiall for fund level datapoint: TOR.",
"If can't find share class name in context, please output empty JSON data: []"
"Don't output the value which not exist in context, especially for fund level datapoint: TOR.",
"If can't find share class name in context, please output empty JSON data: {\"data\": []}"
]
}

View File

@ -335,15 +335,16 @@ if __name__ == "__main__":
# )
# test_auto_generate_instructions()
# doc_id = "294132333"
# extract_data(doc_id, pdf_folder)
output_child_folder = r"/data/emea_ar/output/extract_data/docs/"
output_total_folder = r"/data/emea_ar/output/extract_data/total/"
re_run = False
re_run = True
batch_extract_data(pdf_folder,
page_filter_ground_truth_file,
output_child_folder,
output_total_folder,
special_doc_id_list,
re_run)
# doc_id = "476492237"
# extract_data(doc_id, pdf_folder, output_child_folder, re_run)

View File

@ -726,13 +726,31 @@ def pickup_document_from_top_100_providers():
top_100_provider_document_file, sheet_name="all_data"
)
top_100_provider_document_fund_count = pd.read_excel(
top_100_provider_document_file, sheet_name="doc_fund_count"
)
top_100_provider_document_fund_count.reset_index(drop=True, inplace=True)
top_100_provider_document_share_count = pd.read_excel(
top_100_provider_document_file, sheet_name="doc_share_class_count"
)
top_100_provider_document_share_count = \
top_100_provider_document_share_count[top_100_provider_document_share_count["with_ar_data"] == True]
top_100_provider_document_share_count.reset_index(drop=True, inplace=True)
top_100_provider_document_share_count = pd.merge(
top_100_provider_document_share_count,
top_100_provider_document_fund_count,
on=["DocumentId"],
how="left",
)
top_100_provider_document_share_count = top_100_provider_document_share_count[
["DocumentId", "CompanyId_x", "CompanyName_x", "fund_count", "share_class_count"]
]
top_100_provider_document_share_count.rename(
columns={"CompanyId_x": "CompanyId"}, inplace=True
)
# add a new column with name share_count_rank to top_100_provider_document_share_count by merge with provider_share_count
top_100_provider_document_share_count = pd.merge(
top_100_provider_document_share_count,
@ -742,12 +760,11 @@ def pickup_document_from_top_100_providers():
)
# Keep columns: DocumentId, CompanyId, CompanyName, share_class_count_x, share_count_rank
top_100_provider_document_share_count = top_100_provider_document_share_count[
["DocumentId", "CompanyId", "CompanyName_x", "share_class_count_x", "share_count_rank"]
["DocumentId", "CompanyId", "CompanyName", "fund_count", "share_class_count_x", "share_count_rank"]
]
# rename column share_class_count_x to share_class_count
top_100_provider_document_share_count.rename(
columns={"share_class_count_x": "share_class_count",
"CompanyName_x": "Company_Name",
"share_count_rank": "provider_share_count_rank"}, inplace=True
)
top_100_provider_document_share_count = top_100_provider_document_share_count.sort_values(
@ -833,8 +850,8 @@ if __name__ == "__main__":
random_small_document_data_file = (
r"/data/emea_ar/basic_information/English/lux_english_ar_top_100_provider_random_small_document.xlsx"
)
download_pdf(random_small_document_data_file, 'random_small_document', pdf_folder)
output_pdf_page_text(pdf_folder, output_folder)
# download_pdf(random_small_document_data_file, 'random_small_document', pdf_folder)
# output_pdf_page_text(pdf_folder, output_folder)
# extract_pdf_table(pdf_folder, output_folder)
# analyze_json_error()
@ -846,4 +863,4 @@ if __name__ == "__main__":
# output_folder=basic_info_folder,
# )
# statistics_document_fund_share_count(doc_mapping_from_top_100_provider_file)
# pickup_document_from_top_100_providers()
pickup_document_from_top_100_providers()

View File

@ -69,6 +69,7 @@ def chat(
api_key=os.getenv("OPENAI_API_KEY_GPT4o"),
api_version=os.getenv("OPENAI_API_VERSION_GPT4o"),
temperature: float = 0.0,
response_format: dict = None,
image_file: str = None,
image_base64: str = None,
):
@ -108,18 +109,32 @@ def chat(
try:
if count > 0:
print(f"retrying the {count} time...")
response = client.chat.completions.create(
model=engine,
temperature=temperature,
max_tokens=max_tokens,
top_p=0.95,
frequency_penalty=0,
presence_penalty=0,
timeout=request_timeout,
stop=None,
messages=messages,
response_format={"type": "json_object"},
)
if response_format is None:
response = client.chat.completions.create(
model=engine,
temperature=temperature,
max_tokens=max_tokens,
top_p=0.95,
frequency_penalty=0,
presence_penalty=0,
timeout=request_timeout,
stop=None,
messages=messages,
)
else:
# response_format={"type": "json_object"}
response = client.chat.completions.create(
model=engine,
temperature=temperature,
max_tokens=max_tokens,
top_p=0.95,
frequency_penalty=0,
presence_penalty=0,
timeout=request_timeout,
stop=None,
messages=messages,
response_format=response_format,
)
return response.choices[0].message.content, False
except Exception as e:
error = str(e)