Switch back to extracting data directly from the image stream, instead of first extracting text from the image stream and then extracting data from that text.

The reason: the quality of the text obtained from the image stream is not good enough.
Blade He 2024-12-10 16:17:47 -06:00
parent f71e2968cc
commit d673a99e21
5 changed files with 302 additions and 99 deletions
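For orientation, the change keeps the old two-step path behind a need_extract_text flag (default False) on extract_data_by_page_image, while the new default goes straight from the page image to data via extract_data_by_pure_image. Below is a condensed, standalone sketch of that dispatch; the helper stubs are hypothetical stand-ins for the real methods shown in the diff (get_image_text, extract_data_by_page_text, extract_data_by_pure_image).

def get_image_text(page_num: int) -> str:
    # Stub: in the real code this sends the page image to the model with the
    # "get_image_text" instructions and returns the extracted text.
    return ""

def extract_data_from_text(page_num: int, page_text: str, datapoints: list) -> dict:
    # Stub for the text-based extraction path (extract_data_by_page_text).
    return {"page_index": page_num, "extract_data": {"data": []}, "extract_way": "image"}

def extract_data_from_image(page_num: int, datapoints: list) -> dict:
    # Stub for the single-step image path (extract_data_by_pure_image).
    return {"page_index": page_num, "extract_data": {"data": []}, "extract_way": "image"}

def extract_from_page(page_num: int, datapoints: list, need_extract_text: bool = False) -> dict:
    if need_extract_text:
        # Old two-step path, kept behind the flag: image -> text -> data.
        page_text = get_image_text(page_num)
        if not page_text:
            return {"page_index": page_num, "extract_data": {"data": []}, "extract_way": "image"}
        return extract_data_from_text(page_num, page_text, datapoints)
    # New default: single-step path, image -> data.
    return extract_data_from_image(page_num, datapoints)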

View File

@@ -171,6 +171,8 @@ class DataExtraction:
previous_page_datapoints = []
previous_page_fund_name = None
for page_num, page_text in self.page_text_dict.items():
# if page_num > 640 or page_num < 610:
# continue
if page_num in handled_page_num_list:
continue
page_datapoints = self.get_datapoints_by_page_num(page_num)
@@ -278,6 +280,7 @@ class DataExtraction:
if not exist_current_page_datapoint:
break
else:
data_list.append(next_page_extract_data)
break
count += 1
except Exception as e:
@@ -336,7 +339,8 @@ class DataExtraction:
# try to get data by current page_datapoints
next_page_extract_data = self.extract_data_by_page_image(
page_num=next_page_num,
page_datapoints=next_datapoints
page_datapoints=next_datapoints,
need_extract_text=False
)
next_page_data_list = next_page_extract_data.get(
"extract_data", {}
@@ -403,7 +407,8 @@ class DataExtraction:
page_datapoints=page_datapoints,
need_exclude=False,
exclude_data=None,
previous_page_last_fund=previous_page_last_fund)
previous_page_last_fund=previous_page_last_fund,
need_extract_text=False)
else:
return self.extract_data_by_page_text(
page_num=page_num,
@@ -480,6 +485,55 @@ class DataExtraction:
return data_dict
def extract_data_by_page_image(
self,
page_num: int,
page_datapoints: list,
need_exclude: bool = False,
exclude_data: list = None,
previous_page_last_fund: str = None,
need_extract_text: bool = False
) -> dict:
"""
keys are
doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
"""
if need_extract_text:
logger.info(f"Extracting data from page {page_num} with extracting text as single step.")
page_text = self.get_image_text(page_num)
if page_text is None or len(page_text) == 0:
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = ""
data_dict["instructions"] = ""
data_dict["raw_answer"] = ""
data_dict["extract_data"] = {"data": []}
data_dict["extract_way"] = "image"
return data_dict
else:
if previous_page_last_fund is not None and len(previous_page_last_fund) > 0:
logger.info(f"Transfer previous page fund name: {previous_page_last_fund} to be the pre-fix of page text")
page_text = f"\nThe last fund name of previous PDF page: {previous_page_last_fund}\n{page_text}"
return self.extract_data_by_page_text(
page_num=page_num,
page_text=page_text,
page_datapoints=page_datapoints,
need_exclude=need_exclude,
exclude_data=exclude_data,
previous_page_last_fund=previous_page_last_fund,
original_way="image"
)
else:
logger.info(f"Extracting data from page {page_num} without extracting text as single step.")
return self.extract_data_by_pure_image(
page_num=page_num,
page_datapoints=page_datapoints,
need_exclude=need_exclude,
exclude_data=exclude_data,
previous_page_last_fund=previous_page_last_fund
)
def extract_data_by_pure_image(
self,
page_num: int,
page_datapoints: list,
@@ -491,32 +545,46 @@ class DataExtraction:
keys are
doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
"""
logger.info(f"Extracting data from page {page_num}")
# image_base64 = self.get_pdf_image_base64(page_num)
page_text = self.get_image_text(page_num)
if page_text is None or len(page_text) == 0:
image_base64 = self.get_pdf_image_base64(page_num)
instructions = self.get_instructions_by_datapoints(
previous_page_last_fund,
page_datapoints,
need_exclude=need_exclude,
exclude_data=exclude_data,
extract_way="image"
)
response, with_error = chat(
instructions, response_format={"type": "json_object"}, image_base64=image_base64
)
if with_error:
logger.error(f"Error in extracting tables from page")
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = ""
data_dict["instructions"] = ""
data_dict["raw_answer"] = ""
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = {"data": []}
data_dict["extract_way"] = "image"
return data_dict
else:
if previous_page_last_fund is not None and len(previous_page_last_fund) > 0:
logger.info(f"Transfer previous page fund name: {previous_page_last_fund} to be the pre-fix of page text")
page_text = f"\nThe last fund name of previous PDF page: {previous_page_last_fund}\n{page_text}"
return self.extract_data_by_page_text(
page_num=page_num,
page_text=page_text,
page_datapoints=page_datapoints,
need_exclude=need_exclude,
exclude_data=exclude_data,
previous_page_last_fund=previous_page_last_fund,
original_way="image"
)
try:
data = json.loads(response)
except:
try:
data = json_repair.loads(response)
except:
data = {"data": []}
data = self.validate_data(data, None, previous_page_last_fund)
data_dict = {"doc_id": self.doc_id}
data_dict["page_index"] = page_num
data_dict["datapoints"] = ", ".join(page_datapoints)
data_dict["page_text"] = ""
data_dict["instructions"] = instructions
data_dict["raw_answer"] = response
data_dict["extract_data"] = data
data_dict["extract_way"] = "image"
return data_dict
def get_image_text(self, page_num: int) -> str:
image_base64 = self.get_pdf_image_base64(page_num)
@@ -536,6 +604,7 @@ class DataExtraction:
except:
pass
text = data.get("text", "")
# print(text)
return text
def validate_data(self,
@@ -790,6 +859,7 @@ class DataExtraction:
elif extract_way == "image":
summary = self.instructions_config.get("summary_image", "\n")
if page_text is not None and len(page_text) > 0:
logger.info(f"Transfer previous page fund name: {page_text} to be the pre-fix of page text")
summary += f"\nThe last fund name of previous PDF page: {page_text}\n"
else:
summary = self.instructions_config.get("summary", "\n")

View File

@@ -157,4 +157,4 @@ def calculate_metrics():
if __name__ == "__main__":
drilldown_documents()
calculate_metrics()
# calculate_metrics()

View File

@@ -1,7 +1,7 @@
{
"summary": "Read the context carefully.\nMaybe exists {} data in the context.\n",
"summary_image": "Read the image carefully.\nMaybe exists {} data in the image.\n",
"get_image_text": "Instructions: Please extract the text from the image. output the result as a JSON, the JSON format is like below example(s): {\"text\": \"Text from image\"} \n\nAnswer:\n",
"get_image_text": "Instructions:\nYou are given an image of a page from a PDF document. Extract **all visible text** from the image while preserving the original order, structure, and any associated context as closely as possible. Ensure that:\n\n1. **All textual elements are included**, such as headings, body text, tables, and labels.\n2. **Numerical data, symbols, and special characters** are preserved accurately.\n3. Text in structured formats (e.g., tables, lists) is retained in a logical and readable format.\n4. Any text embedded in graphical elements, if clearly readable, is also included.\n5. The text is clean, readable, and free of formatting artifacts or errors.\n\nDo not include non-textual elements such as images or graphics unless they contain text that can be meaningfully extracted.\n\n### Output Format:\nOutput the result as JSON format, here is the example: \n{\"text\": \"Text from image\"}\n\nAnswer: \n[Extracted Text Here, retaining logical structure and all content]",
"image_features":
[
"1. Identify the text in the PDF page image.",

View File

@@ -887,10 +887,11 @@ def batch_run_documents():
calculate_metrics = False
extract_way = "text"
special_doc_id_list = []
special_doc_id_list = ["435128656"]
if len(special_doc_id_list) == 0:
force_save_total_data = True
file_base_name_candidates = ["sample_document_complex", "emea_case_from_word_complex"]
# file_base_name_candidates = ["sample_document_complex", "emea_case_from_word_complex"]
file_base_name_candidates = ["sample_document_complex"]
for document_list_file in document_list_files:
file_base_name = os.path.basename(document_list_file).replace(".txt", "")
if (file_base_name_candidates is not None and

View File

@@ -10,68 +10,117 @@ from utils.logger import logger
def calculate_complex_document_metrics(verify_file_path: str, document_list: list = []):
data_df = pd.read_excel(verify_file_path, sheet_name="data_in_doc_mapping")
data_df_1 = pd.read_excel(verify_file_path, sheet_name="data_in_doc_mapping")
# convert doc_id column to string
data_df["doc_id"] = data_df["doc_id"].astype(str)
data_df = data_df[data_df["raw_check"].isin([0, 1])]
data_df_1["doc_id"] = data_df_1["doc_id"].astype(str)
data_df_1 = data_df_1[data_df_1["raw_check"].isin([0, 1])]
exclude_documents = ["532422548"]
# remove data by doc_id not in exclude_documents
data_df_1 = data_df_1[~data_df_1["doc_id"].isin(exclude_documents)]
if document_list is not None and len(document_list) > 0:
data_df = data_df[data_df["doc_id"].isin(document_list)]
data_df_1 = data_df_1[data_df_1["doc_id"].isin(document_list)]
data_df_2 = pd.read_excel(verify_file_path, sheet_name="total_mapping_data")
data_df_2["doc_id"] = data_df_2["doc_id"].astype(str)
data_df_2 = data_df_2[data_df_2["raw_check"].isin([0, 1])]
data_df = pd.concat([data_df_1, data_df_2], ignore_index=True)
data_df.fillna("", inplace=True)
data_df.reset_index(drop=True, inplace=True)
metrics_df_list = []
doc_id_list = data_df["doc_id"].unique().tolist()
for doc_id in tqdm(doc_id_list):
try:
document_data_df = data_df[data_df["doc_id"] == doc_id]
document_metrics_df = calc_metrics(document_data_df, doc_id)
metrics_df_list.append(document_metrics_df)
except Exception as e:
logger.error(f"Error when calculating metrics for document {doc_id}")
print_exc()
total_metrics_df = calc_metrics(data_df, doc_id=None)
metrics_df_list.append(total_metrics_df)
all_metrics_df = pd.concat(metrics_df_list, ignore_index=True)
all_metrics_df.reset_index(drop=True, inplace=True)
output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
verify_file_name = os.path.basename(verify_file_path).replace(".xlsx", "")
output_metrics_file = os.path.join(output_folder,
f"complex_{verify_file_name}_metrics_all.xlsx")
with pd.ExcelWriter(output_metrics_file) as writer:
all_metrics_df.to_excel(writer, index=False, sheet_name="metrics")
def calc_metrics(data_df: pd.DataFrame, doc_id: str = None):
# tor data
tor_data_df = data_df[data_df["datapoint"] == "tor"]
tor_metrics = get_sub_metrics(tor_data_df, "tor")
logger.info(f"TOR metrics: {tor_metrics}")
if len(tor_data_df) > 0:
tor_metrics = get_sub_metrics(tor_data_df, "tor", doc_id)
logger.info(f"TOR metrics: {tor_metrics}")
else:
tor_metrics = None
# ter data
ter_data_df = data_df[data_df["datapoint"] == "ter"]
ter_metrics = get_sub_metrics(ter_data_df, "ter")
logger.info(f"TER metrics: {ter_metrics}")
if len(ter_data_df) > 0:
ter_metrics = get_sub_metrics(ter_data_df, "ter", doc_id)
logger.info(f"TER metrics: {ter_metrics}")
else:
ter_metrics = None
# ogc data
ogc_data_df = data_df[data_df["datapoint"] == "ogc"]
ogc_metrics = get_sub_metrics(ogc_data_df, "ogc")
logger.info(f"OGC metrics: {ogc_metrics}")
if len(ogc_data_df) > 0:
ogc_metrics = get_sub_metrics(ogc_data_df, "ogc", doc_id)
logger.info(f"OGC metrics: {ogc_metrics}")
else:
ogc_metrics = None
# performance_fee data
performance_fee_data_df = data_df[data_df["datapoint"] == "performance_fee"]
performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee")
logger.info(f"Performance fee metrics: {performance_fee_metrics}")
if len(performance_fee_data_df) > 0:
performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee", doc_id)
logger.info(f"Performance fee metrics: {performance_fee_metrics}")
else:
performance_fee_metrics = None
metrics_df = pd.DataFrame([tor_metrics, ter_metrics, ogc_metrics, performance_fee_metrics])
metrics_candidates = [tor_metrics, ter_metrics, ogc_metrics, performance_fee_metrics]
metrics_list = [metrics for metrics in metrics_candidates if metrics is not None]
metrics_df = pd.DataFrame(metrics_list)
# add average metrics
avg_metrics = {
"DataPoint": "average",
"F1": metrics_df["F1"].mean(),
"Precision": metrics_df["Precision"].mean(),
"Recall": metrics_df["Recall"].mean(),
"Accuracy": metrics_df["Accuracy"].mean(),
"Support": metrics_df["Support"].sum()
}
if doc_id is not None and len(doc_id) > 0:
avg_metrics = {
"DocumentId": doc_id,
"DataPoint": "average",
"F1": metrics_df["F1"].mean(),
"Precision": metrics_df["Precision"].mean(),
"Recall": metrics_df["Recall"].mean(),
"Accuracy": metrics_df["Accuracy"].mean(),
"Support": metrics_df["Support"].sum()
}
else:
avg_metrics = {
"DocumentId": "All",
"DataPoint": "average",
"F1": metrics_df["F1"].mean(),
"Precision": metrics_df["Precision"].mean(),
"Recall": metrics_df["Recall"].mean(),
"Accuracy": metrics_df["Accuracy"].mean(),
"Support": metrics_df["Support"].sum()
}
metrics_df = pd.DataFrame([tor_metrics, ter_metrics,
ogc_metrics, performance_fee_metrics,
avg_metrics])
metrics_list.append(avg_metrics)
metrics_df = pd.DataFrame(metrics_list)
metrics_df.reset_index(drop=True, inplace=True)
output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
document_count = len(document_list) \
if document_list is not None and len(document_list) > 0 \
else len(data_df["doc_id"].unique())
verify_file_name = os.path.basename(verify_file_path).replace(".xlsx", "")
output_metrics_file = os.path.join(output_folder,
f"complex_{verify_file_name}_metrics.xlsx")
with pd.ExcelWriter(output_metrics_file) as writer:
metrics_df.to_excel(writer, index=False, sheet_name="metrics")
return metrics_df
def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict:
def get_sub_metrics(data_df: pd.DataFrame, data_point: str, doc_id: str = None) -> dict:
data_df_raw_check_1 = data_df[data_df["raw_check"] == 1]
gt_list = [1] * len(data_df_raw_check_1)
pre_list = [1] * len(data_df_raw_check_1)
@@ -99,47 +148,130 @@ def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict:
recall = recall_score(gt_list, pre_list)
f1 = f1_score(gt_list, pre_list)
support = sum(gt_list)
metrics = {
"DataPoint": data_point,
"F1": f1,
"Precision": precision,
"Recall": recall,
"Accuracy": accuracy,
"Support": support
}
if doc_id is not None and len(doc_id) > 0:
metrics = {
"DocumentId": doc_id,
"DataPoint": data_point,
"F1": f1,
"Precision": precision,
"Recall": recall,
"Accuracy": accuracy,
"Support": support
}
else:
metrics = {
"DocumentId": "All",
"DataPoint": data_point,
"F1": f1,
"Precision": precision,
"Recall": recall,
"Accuracy": accuracy,
"Support": support
}
return metrics
def get_metrics_based_documents(metrics_file: str, document_list: list):
metrics_df = pd.read_excel(metrics_file, sheet_name="metrics")
metrics_df_list = []
for doc_id in tqdm(document_list):
try:
document_metrics_df = metrics_df[metrics_df["DocumentId"] == doc_id]
metrics_df_list.append(document_metrics_df)
except Exception as e:
logger.error(f"Error when calculating metrics for document {doc_id}")
print_exc()
metrics_document_df = pd.concat(metrics_df_list, ignore_index=True)
stats_metrics_list = []
tor_df = metrics_document_df[metrics_document_df["DataPoint"] == "tor"]
if len(tor_df) > 0:
tor_metrics = {
"DocumentId": "All",
"DataPoint": "tor",
"F1": tor_df["F1"].mean(),
"Precision": tor_df["Precision"].mean(),
"Recall": tor_df["Recall"].mean(),
"Accuracy": tor_df["Accuracy"].mean(),
"Support": tor_df["Support"].sum()
}
stats_metrics_list.append(tor_metrics)
ter_df = metrics_document_df[metrics_document_df["DataPoint"] == "ter"]
if len(ter_df) > 0:
ter_metrics = {
"DocumentId": "All",
"DataPoint": "ter",
"F1": ter_df["F1"].mean(),
"Precision": ter_df["Precision"].mean(),
"Recall": ter_df["Recall"].mean(),
"Accuracy": ter_df["Accuracy"].mean(),
"Support": ter_df["Support"].sum()
}
stats_metrics_list.append(ter_metrics)
ogc_df = metrics_document_df[metrics_document_df["DataPoint"] == "ogc"]
if len(ogc_df) > 0:
ogc_metrics = {
"DocumentId": "All",
"DataPoint": "ogc",
"F1": ogc_df["F1"].mean(),
"Precision": ogc_df["Precision"].mean(),
"Recall": ogc_df["Recall"].mean(),
"Accuracy": ogc_df["Accuracy"].mean(),
"Support": ogc_df["Support"].sum()
}
stats_metrics_list.append(ogc_metrics)
performance_fee_df = metrics_document_df[metrics_document_df["DataPoint"] == "performance_fee"]
if len(performance_fee_df) > 0:
performance_fee_metrics = {
"DocumentId": "All",
"DataPoint": "performance_fee",
"F1": performance_fee_df["F1"].mean(),
"Precision": performance_fee_df["Precision"].mean(),
"Recall": performance_fee_df["Recall"].mean(),
"Accuracy": performance_fee_df["Accuracy"].mean(),
"Support": performance_fee_df["Support"].sum()
}
stats_metrics_list.append(performance_fee_metrics)
average_df = metrics_document_df[metrics_document_df["DataPoint"] == "average"]
if len(average_df) > 0:
avg_metrics = {
"DocumentId": "All",
"DataPoint": "average",
"F1": average_df["F1"].mean(),
"Precision": average_df["Precision"].mean(),
"Recall": average_df["Recall"].mean(),
"Accuracy": average_df["Accuracy"].mean(),
"Support": average_df["Support"].sum()
}
stats_metrics_list.append(avg_metrics)
stats_metrics_df = pd.DataFrame(stats_metrics_list)
metrics_df_list.append(stats_metrics_df)
all_metrics_df = pd.concat(metrics_df_list, ignore_index=True)
all_metrics_df.reset_index(drop=True, inplace=True)
output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
verify_file_name = "complex_mapping_data_info_31_documents_by_text_second_round_metrics_remain_7.xlsx"
output_metrics_file = os.path.join(output_folder, verify_file_name)
with pd.ExcelWriter(output_metrics_file) as writer:
all_metrics_df.to_excel(writer, index=False, sheet_name="metrics")
return all_metrics_df
if __name__ == "__main__":
file_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
verify_file = "mapping_data_info_31_documents_by_text_second_round.xlsx"
verify_file_path = os.path.join(file_folder, verify_file)
document_list = [
"334584772",
"337293427",
"337937633",
"404712928",
"406913630",
"407275419",
"422686965",
"422760148",
"422760156",
"422761666",
"423364758",
"423365707",
"423395975",
"423418395",
"423418540",
"425595958",
"451063582",
"451878128",
"466580448",
"481482392",
"508704368",
"532998065",
"536344026",
"540307575"
]
calculate_complex_document_metrics(verify_file_path=verify_file_path,
document_list=document_list)
document_list=None)
document_list = ["492029971",
"510300817",
"512745032",
"514213638",
"527525440",
"534535767"]
metrics_file = "complex_mapping_data_info_31_documents_by_text_second_round_metrics_all.xlsx"
metrics_file_path = os.path.join(file_folder, metrics_file)
# get_metrics_based_documents(metrics_file=metrics_file_path,
# document_list=document_list)
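The repeated per-datapoint blocks in get_metrics_based_documents above all follow the same pattern: average the per-document F1/Precision/Recall/Accuracy and sum Support. A minimal sketch of a helper that expresses that pattern once; summarize_datapoint is a hypothetical name, not part of the commit.

from typing import Optional
import pandas as pd

def summarize_datapoint(metrics_df: pd.DataFrame, data_point: str) -> Optional[dict]:
    # Average the per-document scores for one datapoint and sum its Support,
    # returning None when the datapoint does not appear in the selection.
    sub_df = metrics_df[metrics_df["DataPoint"] == data_point]
    if len(sub_df) == 0:
        return None
    return {
        "DocumentId": "All",
        "DataPoint": data_point,
        "F1": sub_df["F1"].mean(),
        "Precision": sub_df["Precision"].mean(),
        "Recall": sub_df["Recall"].mean(),
        "Accuracy": sub_df["Accuracy"].mean(),
        "Support": sub_df["Support"].sum(),
    }

# Example use: build the stats rows for every datapoint plus the average row.
# stats_metrics_list = [m for dp in ["tor", "ter", "ogc", "performance_fee", "average"]
#                       for m in [summarize_datapoint(metrics_document_df, dp)] if m is not None]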