From d673a99e2185e009e6de2a285457aa291febe141 Mon Sep 17 00:00:00 2001
From: Blade He
Date: Tue, 10 Dec 2024 16:17:47 -0600
Subject: [PATCH] Switch back to extracting data directly from the page image,
 instead of first extracting text from the image and then extracting data
 from that text. Reason: the quality of the text extracted from the image
 stream is not good enough.

---
 core/data_extraction.py                     | 112 +++++--
 drilldown_practice.py                       |   2 +-
 .../data_extraction_prompts_config.json     |   2 +-
 main.py                                     |   5 +-
 specific_calc_metrics.py                    | 280 +++++++++++++-----
 5 files changed, 302 insertions(+), 99 deletions(-)

diff --git a/core/data_extraction.py b/core/data_extraction.py
index 6a992df..a8a9cd0 100644
--- a/core/data_extraction.py
+++ b/core/data_extraction.py
@@ -171,6 +171,8 @@ class DataExtraction:
         previous_page_datapoints = []
         previous_page_fund_name = None
         for page_num, page_text in self.page_text_dict.items():
+            # if page_num > 640 or page_num < 610:
+            #     continue
             if page_num in handled_page_num_list:
                 continue
             page_datapoints = self.get_datapoints_by_page_num(page_num)
@@ -278,6 +280,7 @@ class DataExtraction:
                     if not exist_current_page_datapoint:
                         break
                 else:
+                    data_list.append(next_page_extract_data)
                     break
                 count += 1
             except Exception as e:
@@ -336,7 +339,8 @@ class DataExtraction:
                 # try to get data by current page_datapoints
                 next_page_extract_data = self.extract_data_by_page_image(
                     page_num=next_page_num,
-                    page_datapoints=next_datapoints
+                    page_datapoints=next_datapoints,
+                    need_extract_text=False
                 )
                 next_page_data_list = next_page_extract_data.get(
                     "extract_data", {}
@@ -403,7 +407,8 @@ class DataExtraction:
                 page_datapoints=page_datapoints,
                 need_exclude=False,
                 exclude_data=None,
-                previous_page_last_fund=previous_page_last_fund)
+                previous_page_last_fund=previous_page_last_fund,
+                need_extract_text=False)
         else:
             return self.extract_data_by_page_text(
                 page_num=page_num,
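The hunks above thread a new need_extract_text flag through the call sites; the method itself is reworked in the next hunk. For reference, a self-contained toy of the resulting dispatch; ocr_page, extract_from_text, and vision_extract below are hypothetical stand-ins for get_image_text(), extract_data_by_page_text(), and extract_data_by_pure_image(), not repo code:

    def ocr_page(page_num: int) -> str:
        """Stand-in for DataExtraction.get_image_text(): image -> OCR text."""
        return f"text of page {page_num}"

    def extract_from_text(page_text: str, datapoints: list) -> dict:
        """Stand-in for extract_data_by_page_text(): text -> structured data."""
        return {"extract_data": {"data": []}, "extract_way": "image"}

    def vision_extract(page_num: int, datapoints: list) -> dict:
        """Stand-in for extract_data_by_pure_image(): image -> structured data."""
        return {"extract_data": {"data": []}, "extract_way": "image"}

    def extract_data_by_page_image(page_num: int, datapoints: list,
                                   need_extract_text: bool = False) -> dict:
        if need_extract_text:
            # old two-step route: image -> text -> data
            return extract_from_text(ocr_page(page_num), datapoints)
        # new default: one direct image -> data call
        return vision_extract(page_num, datapoints)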
@@ -480,6 +485,55 @@ class DataExtraction:
         return data_dict

     def extract_data_by_page_image(
+        self,
+        page_num: int,
+        page_datapoints: list,
+        need_exclude: bool = False,
+        exclude_data: list = None,
+        previous_page_last_fund: str = None,
+        need_extract_text: bool = False
+    ) -> dict:
+        """
+        keys are
+        doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
+        """
+        if need_extract_text:
+            logger.info(f"Extracting data from page {page_num} with text extraction as a separate first step.")
+            page_text = self.get_image_text(page_num)
+            if page_text is None or len(page_text) == 0:
+                data_dict = {"doc_id": self.doc_id}
+                data_dict["page_index"] = page_num
+                data_dict["datapoints"] = ", ".join(page_datapoints)
+                data_dict["page_text"] = ""
+                data_dict["instructions"] = ""
+                data_dict["raw_answer"] = ""
+                data_dict["extract_data"] = {"data": []}
+                data_dict["extract_way"] = "image"
+                return data_dict
+            else:
+                if previous_page_last_fund is not None and len(previous_page_last_fund) > 0:
+                    logger.info(f"Prepend previous page fund name: {previous_page_last_fund} as the prefix of the page text")
+                    page_text = f"\nThe last fund name of previous PDF page: {previous_page_last_fund}\n{page_text}"
+                return self.extract_data_by_page_text(
+                    page_num=page_num,
+                    page_text=page_text,
+                    page_datapoints=page_datapoints,
+                    need_exclude=need_exclude,
+                    exclude_data=exclude_data,
+                    previous_page_last_fund=previous_page_last_fund,
+                    original_way="image"
+                )
+        else:
+            logger.info(f"Extracting data from page {page_num} directly, without a separate text-extraction step.")
+            return self.extract_data_by_pure_image(
+                page_num=page_num,
+                page_datapoints=page_datapoints,
+                need_exclude=need_exclude,
+                exclude_data=exclude_data,
+                previous_page_last_fund=previous_page_last_fund
+            )
+
+    def extract_data_by_pure_image(
         self,
         page_num: int,
         page_datapoints: list,
@@ -491,32 +545,46 @@ class DataExtraction:
         keys are
         doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
         """
-        logger.info(f"Extracting data from page {page_num}")
-        # image_base64 = self.get_pdf_image_base64(page_num)
-        page_text = self.get_image_text(page_num)
-        if page_text is None or len(page_text) == 0:
+        image_base64 = self.get_pdf_image_base64(page_num)
+        instructions = self.get_instructions_by_datapoints(
+            previous_page_last_fund,
+            page_datapoints,
+            need_exclude=need_exclude,
+            exclude_data=exclude_data,
+            extract_way="image"
+        )
+        response, with_error = chat(
+            instructions, response_format={"type": "json_object"}, image_base64=image_base64
+        )
+        if with_error:
+            logger.error(f"Error in extracting data from page {page_num}")
             data_dict = {"doc_id": self.doc_id}
             data_dict["page_index"] = page_num
             data_dict["datapoints"] = ", ".join(page_datapoints)
             data_dict["page_text"] = ""
-            data_dict["instructions"] = ""
-            data_dict["raw_answer"] = ""
+            data_dict["instructions"] = instructions
+            data_dict["raw_answer"] = response
             data_dict["extract_data"] = {"data": []}
             data_dict["extract_way"] = "image"
             return data_dict
-        else:
-            if previous_page_last_fund is not None and len(previous_page_last_fund) > 0:
-                logger.info(f"Transfer previous page fund name: {previous_page_last_fund} to be the pre-fix of page text")
-                page_text = f"\nThe last fund name of previous PDF page: {previous_page_last_fund}\n{page_text}"
-            return self.extract_data_by_page_text(
-                page_num=page_num,
-                page_text=page_text,
-                page_datapoints=page_datapoints,
-                need_exclude=need_exclude,
-                exclude_data=exclude_data,
-                previous_page_last_fund=previous_page_last_fund,
-                original_way="image"
-            )
+        try:
+            data = json.loads(response)
+        except Exception:
+            try:
+                data = json_repair.loads(response)
+            except Exception:
+                data = {"data": []}
+        data = self.validate_data(data, None, previous_page_last_fund)
+
+        data_dict = {"doc_id": self.doc_id}
+        data_dict["page_index"] = page_num
+        data_dict["datapoints"] = ", ".join(page_datapoints)
+        data_dict["page_text"] = ""
+        data_dict["instructions"] = instructions
+        data_dict["raw_answer"] = response
+        data_dict["extract_data"] = data
+        data_dict["extract_way"] = "image"
+        return data_dict

     def get_image_text(self, page_num: int) -> str:
         image_base64 = self.get_pdf_image_base64(page_num)
@@ -536,6 +604,7 @@ class DataExtraction:
         except:
             pass
         text = data.get("text", "")
+        # print(text)
         return text

     def validate_data(self,
@@ -790,6 +859,7 @@ class DataExtraction:
         elif extract_way == "image":
             summary = self.instructions_config.get("summary_image", "\n")
             if page_text is not None and len(page_text) > 0:
+                logger.info(f"Prepend previous page fund name: {page_text} as the prefix of the page text")
                 summary += f"\nThe last fund name of previous PDF page: {page_text}\n"
         else:
             summary = self.instructions_config.get("summary", "\n")
diff --git a/drilldown_practice.py b/drilldown_practice.py
index 3869179..a536a32 100644
--- a/drilldown_practice.py
+++ b/drilldown_practice.py
@@ -157,4 +157,4 @@ def calculate_metrics():

 if __name__ == "__main__":
     drilldown_documents()
-    calculate_metrics()
\ No newline at end of file
+    # calculate_metrics()
\ No newline at end of file
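The new extract_data_by_pure_image() above parses the model's JSON answer with a repair fallback before validation. A minimal standalone sketch of that parsing chain, using the same third-party json_repair package the module already imports; the sample strings are illustrative:

    import json

    import json_repair  # third-party fallback parser, as used in the patch

    def parse_model_json(response: str) -> dict:
        # Strict parse first, then best-effort repair, then give up with the
        # empty payload shape the extractor expects downstream.
        try:
            return json.loads(response)
        except json.JSONDecodeError:
            try:
                return json_repair.loads(response)
            except Exception:
                return {"data": []}

    print(parse_model_json('{"data": [{"datapoint": "ter"}]}'))  # strict parse
    print(parse_model_json('{"data": [{"datapoint": "ter"}'))    # repaired parse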
diff --git a/instructions/data_extraction_prompts_config.json b/instructions/data_extraction_prompts_config.json
index 2e4eaa7..f56caab 100644
--- a/instructions/data_extraction_prompts_config.json
+++ b/instructions/data_extraction_prompts_config.json
@@ -1,7 +1,7 @@
 {
     "summary": "Read the context carefully.\nMaybe exists {} data in the context.\n",
     "summary_image": "Read the image carefully.\nMaybe exists {} data in the image.\n",
-    "get_image_text": "Instructions: Please extract the text from the image. output the result as a JSON, the JSON format is like below example(s): {\"text\": \"Text from image\"} \n\nAnswer:\n",
+    "get_image_text": "Instructions:\nYou are given an image of a page from a PDF document. Extract **all visible text** from the image while preserving the original order, structure, and any associated context as closely as possible. Ensure that:\n\n1. **All textual elements are included**, such as headings, body text, tables, and labels.\n2. **Numerical data, symbols, and special characters** are preserved accurately.\n3. Text in structured formats (e.g., tables, lists) is retained in a logical and readable format.\n4. Any text embedded in graphical elements, if clearly readable, is also included.\n5. The text is clean, readable, and free of formatting artifacts or errors.\n\nDo not include non-textual elements such as images or graphics unless they contain text that can be meaningfully extracted.\n\n### Output Format:\nOutput the result in JSON format; here is an example: \n{\"text\": \"Text from image\"}\n\nAnswer: \n[Extracted Text Here, retaining logical structure and all content]",
     "image_features": [
         "1. Identify the text in the PDF page image.",
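For context, a hedged sketch of how the rewritten get_image_text prompt is consumed: the prompt is sent as instructions alongside the page image, and the JSON answer is parsed for its "text" field. The inlined config and the model answer below are stand-ins, not repo data; in the repo the call goes through chat(instructions, response_format={"type": "json_object"}, image_base64=...):

    import json

    # Trimmed, inlined stand-in for instructions/data_extraction_prompts_config.json.
    prompts = {
        "get_image_text": "Instructions:\nYou are given an image of a page from a "
                          "PDF document. Extract **all visible text** ..."
    }
    instructions = prompts["get_image_text"]

    # Stand-in for the model's JSON answer to the prompt above.
    response = '{"text": "Fund A\\nTotal expense ratio 0.85%"}'
    page_text = json.loads(response).get("text", "")
    print(page_text)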
diff --git a/main.py b/main.py
index 7fd97fd..f3862cb 100644
--- a/main.py
+++ b/main.py
@@ -887,10 +887,11 @@ def batch_run_documents():
     calculate_metrics = False
     extract_way = "text"

-    special_doc_id_list = []
+    special_doc_id_list = ["435128656"]
     if len(special_doc_id_list) == 0:
         force_save_total_data = True
-    file_base_name_candidates = ["sample_document_complex", "emea_case_from_word_complex"]
+    # file_base_name_candidates = ["sample_document_complex", "emea_case_from_word_complex"]
+    file_base_name_candidates = ["sample_document_complex"]
     for document_list_file in document_list_files:
         file_base_name = os.path.basename(document_list_file).replace(".txt", "")
         if (file_base_name_candidates is not None and
diff --git a/specific_calc_metrics.py b/specific_calc_metrics.py
index 620de34..8254fd8 100644
--- a/specific_calc_metrics.py
+++ b/specific_calc_metrics.py
@@ -10,68 +10,117 @@ from utils.logger import logger


 def calculate_complex_document_metrics(verify_file_path: str, document_list: list = []):
-    data_df = pd.read_excel(verify_file_path, sheet_name="data_in_doc_mapping")
+    data_df_1 = pd.read_excel(verify_file_path, sheet_name="data_in_doc_mapping")
     # convert doc_id column to string
-    data_df["doc_id"] = data_df["doc_id"].astype(str)
-    data_df = data_df[data_df["raw_check"].isin([0, 1])]
+    data_df_1["doc_id"] = data_df_1["doc_id"].astype(str)
+    data_df_1 = data_df_1[data_df_1["raw_check"].isin([0, 1])]
+
+    exclude_documents = ["532422548"]
+    # drop rows whose doc_id is in exclude_documents
+    data_df_1 = data_df_1[~data_df_1["doc_id"].isin(exclude_documents)]
     if document_list is not None and len(document_list) > 0:
-        data_df = data_df[data_df["doc_id"].isin(document_list)]
+        data_df_1 = data_df_1[data_df_1["doc_id"].isin(document_list)]
+
+    data_df_2 = pd.read_excel(verify_file_path, sheet_name="total_mapping_data")
+    data_df_2["doc_id"] = data_df_2["doc_id"].astype(str)
+    data_df_2 = data_df_2[data_df_2["raw_check"].isin([0, 1])]
+
+    data_df = pd.concat([data_df_1, data_df_2], ignore_index=True)
     data_df.fillna("", inplace=True)
     data_df.reset_index(drop=True, inplace=True)
+    metrics_df_list = []
+    doc_id_list = data_df["doc_id"].unique().tolist()
+    for doc_id in tqdm(doc_id_list):
+        try:
+            document_data_df = data_df[data_df["doc_id"] == doc_id]
+            document_metrics_df = calc_metrics(document_data_df, doc_id)
+            metrics_df_list.append(document_metrics_df)
+        except Exception as e:
+            logger.error(f"Error when calculating metrics for document {doc_id}")
+            print_exc()
+
+    total_metrics_df = calc_metrics(data_df, doc_id=None)
+    metrics_df_list.append(total_metrics_df)
+
+    all_metrics_df = pd.concat(metrics_df_list, ignore_index=True)
+    all_metrics_df.reset_index(drop=True, inplace=True)
+
+    output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
+    verify_file_name = os.path.basename(verify_file_path).replace(".xlsx", "")
+    output_metrics_file = os.path.join(output_folder,
+                                       f"complex_{verify_file_name}_metrics_all.xlsx")
+    with pd.ExcelWriter(output_metrics_file) as writer:
+        all_metrics_df.to_excel(writer, index=False, sheet_name="metrics")
+
+
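calculate_complex_document_metrics() now builds one metrics frame per document and then appends an overall frame computed on the combined data, delegating to the new calc_metrics() helper that this hunk continues with below. A runnable miniature of that aggregation shape; calc_metrics is stubbed here and the values are illustrative only:

    import pandas as pd

    def calc_metrics_stub(df: pd.DataFrame, doc_id: str = None) -> pd.DataFrame:
        # Stub for the patch's calc_metrics(): one summary row per frame.
        return pd.DataFrame([{"DocumentId": doc_id or "All", "Support": len(df)}])

    data_df = pd.DataFrame({"doc_id": ["111", "111", "222"],
                            "datapoint": ["ter", "ogc", "ter"]})
    frames = [calc_metrics_stub(data_df[data_df["doc_id"] == d], d)
              for d in data_df["doc_id"].unique()]
    frames.append(calc_metrics_stub(data_df, doc_id=None))  # the "All" rows
    all_metrics_df = pd.concat(frames, ignore_index=True)
    print(all_metrics_df)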
metrics_df["Recall"].mean(), + "Accuracy": metrics_df["Accuracy"].mean(), + "Support": metrics_df["Support"].sum() + } + else: + avg_metrics = { + "DocumentId": "All", + "DataPoint": "average", + "F1": metrics_df["F1"].mean(), + "Precision": metrics_df["Precision"].mean(), + "Recall": metrics_df["Recall"].mean(), + "Accuracy": metrics_df["Accuracy"].mean(), + "Support": metrics_df["Support"].sum() + } - metrics_df = pd.DataFrame([tor_metrics, ter_metrics, - ogc_metrics, performance_fee_metrics, - avg_metrics]) + metrics_list.append(avg_metrics) + metrics_df = pd.DataFrame(metrics_list) metrics_df.reset_index(drop=True, inplace=True) - - output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/" - - document_count = len(document_list) \ - if document_list is not None and len(document_list) > 0 \ - else len(data_df["doc_id"].unique()) - - verify_file_name = os.path.basename(verify_file_path).replace(".xlsx", "") - - output_metrics_file = os.path.join(output_folder, - f"complex_{verify_file_name}_metrics.xlsx") - with pd.ExcelWriter(output_metrics_file) as writer: - metrics_df.to_excel(writer, index=False, sheet_name="metrics") + return metrics_df -def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict: +def get_sub_metrics(data_df: pd.DataFrame, data_point: str, doc_id: str = None) -> dict: data_df_raw_check_1 = data_df[data_df["raw_check"] == 1] gt_list = [1] * len(data_df_raw_check_1) pre_list = [1] * len(data_df_raw_check_1) @@ -99,47 +148,130 @@ def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict: recall = recall_score(gt_list, pre_list) f1 = f1_score(gt_list, pre_list) support = sum(gt_list) - - metrics = { - "DataPoint": data_point, - "F1": f1, - "Precision": precision, - "Recall": recall, - "Accuracy": accuracy, - "Support": support - } + if doc_id is not None and len(doc_id) > 0: + metrics = { + "DocumentId": doc_id, + "DataPoint": data_point, + "F1": f1, + "Precision": precision, + "Recall": recall, + "Accuracy": accuracy, + "Support": support + } + else: + metrics = { + "DocumentId": "All", + "DataPoint": data_point, + "F1": f1, + "Precision": precision, + "Recall": recall, + "Accuracy": accuracy, + "Support": support + } return metrics +def get_metrics_based_documents(metrics_file: str, document_list: list): + metrics_df = pd.read_excel(metrics_file, sheet_name="metrics") + metrics_df_list = [] + for doc_id in tqdm(document_list): + try: + document_metrics_df = metrics_df[metrics_df["DocumentId"] == doc_id] + metrics_df_list.append(document_metrics_df) + except Exception as e: + logger.error(f"Error when calculating metrics for document {doc_id}") + print_exc() + metrics_document_df = pd.concat(metrics_df_list, ignore_index=True) + + stats_metrics_list = [] + tor_df = metrics_document_df[metrics_document_df["DataPoint"] == "tor"] + if len(tor_df) > 0: + tor_metrics = { + "DocumentId": "All", + "DataPoint": "tor", + "F1": tor_df["F1"].mean(), + "Precision": tor_df["Precision"].mean(), + "Recall": tor_df["Recall"].mean(), + "Accuracy": tor_df["Accuracy"].mean(), + "Support": tor_df["Support"].sum() + } + stats_metrics_list.append(tor_metrics) + ter_df = metrics_document_df[metrics_document_df["DataPoint"] == "ter"] + if len(ter_df) > 0: + ter_metrics = { + "DocumentId": "All", + "DataPoint": "ter", + "F1": ter_df["F1"].mean(), + "Precision": ter_df["Precision"].mean(), + "Recall": ter_df["Recall"].mean(), + "Accuracy": ter_df["Accuracy"].mean(), + "Support": ter_df["Support"].sum() + } + stats_metrics_list.append(ter_metrics) + ogc_df 
+def get_metrics_based_documents(metrics_file: str, document_list: list):
+    metrics_df = pd.read_excel(metrics_file, sheet_name="metrics")
+    metrics_df_list = []
+    for doc_id in tqdm(document_list):
+        try:
+            document_metrics_df = metrics_df[metrics_df["DocumentId"] == doc_id]
+            metrics_df_list.append(document_metrics_df)
+        except Exception as e:
+            logger.error(f"Error when calculating metrics for document {doc_id}")
+            print_exc()
+    metrics_document_df = pd.concat(metrics_df_list, ignore_index=True)
+
+    stats_metrics_list = []
+    tor_df = metrics_document_df[metrics_document_df["DataPoint"] == "tor"]
+    if len(tor_df) > 0:
+        tor_metrics = {
+            "DocumentId": "All",
+            "DataPoint": "tor",
+            "F1": tor_df["F1"].mean(),
+            "Precision": tor_df["Precision"].mean(),
+            "Recall": tor_df["Recall"].mean(),
+            "Accuracy": tor_df["Accuracy"].mean(),
+            "Support": tor_df["Support"].sum()
+        }
+        stats_metrics_list.append(tor_metrics)
+    ter_df = metrics_document_df[metrics_document_df["DataPoint"] == "ter"]
+    if len(ter_df) > 0:
+        ter_metrics = {
+            "DocumentId": "All",
+            "DataPoint": "ter",
+            "F1": ter_df["F1"].mean(),
+            "Precision": ter_df["Precision"].mean(),
+            "Recall": ter_df["Recall"].mean(),
+            "Accuracy": ter_df["Accuracy"].mean(),
+            "Support": ter_df["Support"].sum()
+        }
+        stats_metrics_list.append(ter_metrics)
+    ogc_df = metrics_document_df[metrics_document_df["DataPoint"] == "ogc"]
+    if len(ogc_df) > 0:
+        ogc_metrics = {
+            "DocumentId": "All",
+            "DataPoint": "ogc",
+            "F1": ogc_df["F1"].mean(),
+            "Precision": ogc_df["Precision"].mean(),
+            "Recall": ogc_df["Recall"].mean(),
+            "Accuracy": ogc_df["Accuracy"].mean(),
+            "Support": ogc_df["Support"].sum()
+        }
+        stats_metrics_list.append(ogc_metrics)
+    performance_fee_df = metrics_document_df[metrics_document_df["DataPoint"] == "performance_fee"]
+    if len(performance_fee_df) > 0:
+        performance_fee_metrics = {
+            "DocumentId": "All",
+            "DataPoint": "performance_fee",
+            "F1": performance_fee_df["F1"].mean(),
+            "Precision": performance_fee_df["Precision"].mean(),
+            "Recall": performance_fee_df["Recall"].mean(),
+            "Accuracy": performance_fee_df["Accuracy"].mean(),
+            "Support": performance_fee_df["Support"].sum()
+        }
+        stats_metrics_list.append(performance_fee_metrics)
+    average_df = metrics_document_df[metrics_document_df["DataPoint"] == "average"]
+    if len(average_df) > 0:
+        avg_metrics = {
+            "DocumentId": "All",
+            "DataPoint": "average",
+            "F1": average_df["F1"].mean(),
+            "Precision": average_df["Precision"].mean(),
+            "Recall": average_df["Recall"].mean(),
+            "Accuracy": average_df["Accuracy"].mean(),
+            "Support": average_df["Support"].sum()
+        }
+        stats_metrics_list.append(avg_metrics)
+
+    stats_metrics_df = pd.DataFrame(stats_metrics_list)
+    metrics_df_list.append(stats_metrics_df)
+    all_metrics_df = pd.concat(metrics_df_list, ignore_index=True)
+    all_metrics_df.reset_index(drop=True, inplace=True)
+
+    output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
+    verify_file_name = "complex_mapping_data_info_31_documents_by_text_second_round_metrics_remain_7.xlsx"
+    output_metrics_file = os.path.join(output_folder, verify_file_name)
+    with pd.ExcelWriter(output_metrics_file) as writer:
+        all_metrics_df.to_excel(writer, index=False, sheet_name="metrics")
+
+    return all_metrics_df
+
+
 if __name__ == "__main__":
     file_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
     verify_file = "mapping_data_info_31_documents_by_text_second_round.xlsx"
     verify_file_path = os.path.join(file_folder, verify_file)
-    document_list = [
-        "334584772",
-        "337293427",
-        "337937633",
-        "404712928",
-        "406913630",
-        "407275419",
-        "422686965",
-        "422760148",
-        "422760156",
-        "422761666",
-        "423364758",
-        "423365707",
-        "423395975",
-        "423418395",
-        "423418540",
-        "425595958",
-        "451063582",
-        "451878128",
-        "466580448",
-        "481482392",
-        "508704368",
-        "532998065",
-        "536344026",
-        "540307575"
-    ]
     calculate_complex_document_metrics(verify_file_path=verify_file_path,
-                                       document_list=document_list)
\ No newline at end of file
+                                       document_list=None)
+    document_list = ["492029971",
+                     "510300817",
+                     "512745032",
+                     "514213638",
+                     "527525440",
+                     "534535767"]
+    metrics_file = "complex_mapping_data_info_31_documents_by_text_second_round_metrics_all.xlsx"
+    metrics_file_path = os.path.join(file_folder, metrics_file)
+    # get_metrics_based_documents(metrics_file=metrics_file_path,
+    #                             document_list=document_list)
\ No newline at end of file
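A side note on get_metrics_based_documents(): the four near-identical per-datapoint blocks could likely collapse into a single groupby/agg pass. A hedged equivalent sketch, using the patch's column names with illustrative values:

    import pandas as pd

    metrics_document_df = pd.DataFrame({
        "DataPoint": ["tor", "tor", "ter"],
        "F1": [0.9, 0.8, 0.7], "Precision": [0.9, 0.8, 0.7],
        "Recall": [0.9, 0.8, 0.7], "Accuracy": [0.9, 0.8, 0.7],
        "Support": [10, 12, 8],
    })
    agg = metrics_document_df.groupby("DataPoint", as_index=False).agg(
        {"F1": "mean", "Precision": "mean", "Recall": "mean",
         "Accuracy": "mean", "Support": "sum"})
    agg.insert(0, "DocumentId", "All")  # mirrors the "All" rows built by hand
    print(agg)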