diff --git a/core/data_extraction.py b/core/data_extraction.py
index 3bbc0e6..fa230a8 100644
--- a/core/data_extraction.py
+++ b/core/data_extraction.py
@@ -21,6 +21,8 @@ class DataExtraction:
         datapoint_page_info: dict,
         datapoints: list,
         document_mapping_info_df: pd.DataFrame,
+        extract_way: str = "text",
+        output_image_folder: str = None,
     ) -> None:
         self.doc_id = doc_id
         self.pdf_file = pdf_file
@@ -48,7 +50,14 @@ class DataExtraction:
         self.instructions_config = self.get_instructions_config()
         self.datapoint_level_config = self.get_datapoint_level()
         self.datapoint_name_config = self.get_datapoint_name()
-
+        self.extract_way = extract_way
+        self.output_image_folder = output_image_folder
+
+    def get_pdf_image_base64(self, page_index: int) -> str:
+        pdf_util = PDFUtil(self.pdf_file)
+        return pdf_util.extract_image_from_page(page_index=page_index,
+                                                output_folder=self.output_image_folder)
+
     def get_instructions_config(self) -> dict:
         instructions_config_file = r"./instructions/data_extraction_prompts_config.json"
         with open(instructions_config_file, "r", encoding="utf-8") as f:
@@ -82,8 +91,17 @@ class DataExtraction:
         # sort the page numbers
         page_nums_with_datapoints.sort()
         return page_nums_with_datapoints
-
+    def extract_data(self) -> list:
+        logger.info(f"Extracting data from document {self.doc_id}, extract way: {self.extract_way}")
+        if self.extract_way == "text":
+            return self.extract_data_by_text()
+        elif self.extract_way == "image":
+            return self.extract_data_by_image()
+        else:
+            return self.extract_data_by_text()
+
+    def extract_data_by_text(self) -> list:
         """
         keys are
         doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
@@ -97,7 +115,7 @@ class DataExtraction:
             page_datapoints = self.get_datapoints_by_page_num(page_num)
             if len(page_datapoints) == 0:
                 continue
-            extract_data = self.extract_data_by_page(
+            extract_data = self.extract_data_by_page_text(
                 page_num,
                 page_text,
                 page_datapoints,
@@ -140,7 +158,7 @@ class DataExtraction:
                     next_page_text = self.page_text_dict.get(next_page_num, "")
                     target_text = current_text + next_page_text
                     # try to get data by current page_datapoints
-                    next_page_extract_data = self.extract_data_by_page(
+                    next_page_extract_data = self.extract_data_by_page_text(
                         next_page_num,
                         target_text,
                         next_datapoints,
@@ -177,6 +195,90 @@ class DataExtraction:
                     logger.error(f"Error in extracting data from next page: {e}")
                     break
 
+        self.output_data_to_file(data_list)
+
+        return data_list
+
+    def extract_data_by_image(self) -> list:
+        """
+        keys are
+        doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
+        """
+        data_list = []
+        pdf_page_count = len(self.page_text_dict.keys())
+        handled_page_num_list = []
+        for page_num, page_text in self.page_text_dict.items():
+            if page_num in handled_page_num_list:
+                continue
+            page_datapoints = self.get_datapoints_by_page_num(page_num)
+            if len(page_datapoints) == 0:
+                continue
+
+            extract_data = self.extract_data_by_page_image(page_num=page_num,
+                                                           page_datapoints=page_datapoints)
+            data_list.append(extract_data)
+
+            page_data_list = extract_data.get("extract_data", {}).get("data", [])
+
+            current_page_data_count = len(page_data_list)
+            if current_page_data_count > 0:
+                count = 1
+
+                while count < 3:
+                    try:
+                        next_page_num = page_num + count
+                        if next_page_num >= pdf_page_count:
+                            break
+                        next_datapoints = page_datapoints
+                        if next_page_num in self.page_nums_with_datapoints:
+                            should_continue = False
+                            next_datapoints = self.get_datapoints_by_page_num(next_page_num)
+                            if len(next_datapoints) == 0:
+                                should_continue = True
+                            else:
+                                for next_datapoint in next_datapoints:
+                                    if next_datapoint not in page_datapoints:
+                                        should_continue = True
+                                        break
+                                next_datapoints.extend(page_datapoints)
+                                # remove duplicate datapoints
+                                next_datapoints = list(set(next_datapoints))
+                            if not should_continue:
+                                break
+                        # try to get data by current page_datapoints
+                        next_page_extract_data = self.extract_data_by_page_image(
+                            page_num=next_page_num,
+                            page_datapoints=next_datapoints
+                        )
+                        next_page_data_list = next_page_extract_data.get(
+                            "extract_data", {}
+                        ).get("data", [])
+
+                        if next_page_data_list is not None and len(next_page_data_list) > 0:
+                            data_list.append(next_page_extract_data)
+                            handled_page_num_list.append(next_page_num)
+                            exist_current_page_datapoint = False
+                            for next_page_data in next_page_data_list:
+                                for page_datapoint in page_datapoints:
+                                    if page_datapoint in list(next_page_data.keys()):
+                                        exist_current_page_datapoint = True
+                                        break
+                                if exist_current_page_datapoint:
+                                    break
+                            if not exist_current_page_datapoint:
+                                break
+                        else:
+                            break
+                        count += 1
+                    except Exception as e:
+                        logger.error(f"Error in extracting data from next page: {e}")
+                        break
+
+        self.output_data_to_file(data_list)
+
+        return data_list
+
+    def output_data_to_file(self, data_list: list) -> None:
         json_data_file = os.path.join(
             self.output_data_json_folder, f"{self.doc_id}.json"
         )
@@ -190,10 +292,8 @@ class DataExtraction:
         )
         with pd.ExcelWriter(excel_data_file) as writer:
             data_df.to_excel(writer, sheet_name="extract_data", index=False)
-
-        return data_list
-
-    def extract_data_by_page(
+
+    def extract_data_by_page_text(
         self,
         page_num: int,
         page_text: str,
@@ -246,6 +346,49 @@ class DataExtraction:
         data_dict["extract_data"] = data
         return data_dict
 
+    def extract_data_by_page_image(
+        self,
+        page_num: int,
+        page_datapoints: list
+    ) -> dict:
+        """
+        keys are
+        doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
+        """
+        logger.info(f"Extracting data from page {page_num}")
+        image_base64 = self.get_pdf_image_base64(page_num)
+        instructions = self.get_instructions_by_datapoints(
+            "", page_datapoints, need_exclude=False, exclude_data=None
+        )
+        response, with_error = chat(
+            instructions, response_format={"type": "json_object"}, image_base64=image_base64
+        )
+        if with_error:
+            logger.error(f"Error in extracting data from page {page_num}")
+            data_dict = {"doc_id": self.doc_id}
+            data_dict["page_index"] = page_num
+            data_dict["datapoints"] = ", ".join(page_datapoints)
+            data_dict["instructions"] = instructions
+            data_dict["raw_answer"] = response
+            data_dict["extract_data"] = {"data": []}
+            return data_dict
+        try:
+            data = json.loads(response)
+        except Exception:
+            try:
+                data = json_repair.loads(response)
+            except Exception:
+                data = {"data": []}
+        data = self.validate_data(data)
+
+        data_dict = {"doc_id": self.doc_id}
+        data_dict["page_index"] = page_num
+        data_dict["datapoints"] = ", ".join(page_datapoints)
+        data_dict["instructions"] = instructions
+        data_dict["raw_answer"] = response
+        data_dict["extract_data"] = data
+        return data_dict
+
     def chat_by_split_context(self,
                               page_text: str,
                               page_datapoints: list,
@@ -412,15 +555,29 @@ class DataExtraction:
             performance_fee_value: list
         end
         """
-        instructions = [f"Context:\n{page_text}\n\nInstructions:\n"]
+        instructions = []
+        if self.extract_way == "text":
+            instructions = [f"Context:\n{page_text}\n\nInstructions:\n"]
+
         datapoint_name_list = []
         for datapoint in datapoints:
            datapoint_name = self.datapoint_name_config.get(datapoint, "")
             datapoint_name_list.append(datapoint_name)
-        summary = self.instructions_config.get("summary", "\n")
+        if self.extract_way == "text":
+            summary = self.instructions_config.get("summary", "\n")
+        elif self.extract_way == "image":
+            summary = self.instructions_config.get("summary_image", "\n")
+        else:
+            summary = self.instructions_config.get("summary", "\n")
+
         instructions.append(summary.format(", ".join(datapoint_name_list)))
         instructions.append("\n")
+
+        if self.extract_way == "image":
+            image_features = self.instructions_config.get("image_features", [])
+            instructions.extend(image_features)
+            instructions.append("\n")
         instructions.append("Datapoints Reported name:\n")
         reported_name_info = self.instructions_config.get("reported_name", {})
diff --git a/core/data_mapping.py b/core/data_mapping.py
index 6845ce1..155a9a0 100644
--- a/core/data_mapping.py
+++ b/core/data_mapping.py
@@ -104,6 +104,7 @@ class DataMapping:
         raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
         """
+        logger.info(f"Mapping raw data for document {self.doc_id}")
         mapped_data_list = []
         mapped_fund_cache = {}
         mapped_share_cache = {}
diff --git a/core/metrics.py b/core/metrics.py
index 895f587..3fac16d 100644
--- a/core/metrics.py
+++ b/core/metrics.py
@@ -302,13 +302,13 @@ class Metrics:
         dp_ground_truth["unique_words"] = dp_ground_truth["raw_name"].apply(
             get_unique_words_text
         )
-        ground_truth_unique_words = dp_ground_truth["unique_words"].unique().tolist()
+        ground_truth_unique_words_list = dp_ground_truth["unique_words"].unique().tolist()
         ground_truth_raw_names = dp_ground_truth["raw_name"].unique().tolist()
         # add new column to store unique words for dp_prediction
         dp_prediction["unique_words"] = dp_prediction["raw_name"].apply(
             get_unique_words_text
         )
-        pred_unique_words = dp_prediction["unique_words"].unique().tolist()
+        pred_unique_words_list = dp_prediction["unique_words"].unique().tolist()
         pred_raw_names = dp_prediction["raw_name"].unique().tolist()
 
         true_data = []
@@ -330,9 +330,9 @@ class Metrics:
             find_raw_name_in_gt = [gt_raw_name for gt_raw_name in ground_truth_raw_names
                                    if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
-            if pred_unique_words in ground_truth_unique_words or len(find_raw_name_in_gt) > 0:
+            if pred_unique_words in ground_truth_unique_words_list or len(find_raw_name_in_gt) > 0:
                 # get the ground truth data with the same unique words
-                if pred_unique_words in ground_truth_unique_words:
+                if pred_unique_words in ground_truth_unique_words_list:
                     gt_data = dp_ground_truth[
                         dp_ground_truth["unique_words"] == pred_unique_words
                     ].iloc[0]
@@ -383,7 +383,7 @@ class Metrics:
             find_raw_name_in_pred = [pred_raw_name for pred_raw_name in pred_raw_names
                                      if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
-            if gt_unique_words not in pred_unique_words and \
+            if gt_unique_words not in pred_unique_words_list and \
                 len(find_raw_name_in_pred) == 0:
                 true_data.append(1)
                 pred_data.append(0)
@@ -394,7 +394,7 @@ class Metrics:
                     "pred_raw_name": "",
                     "investment_type": gt_investment_type,
                     "error_type": "raw name missing",
-                    "error_value": pred_data_point_value,
+                    "error_value": "",
                     "correct_value": gt_raw_name,
                 }
                 missing_error_data.append(error_data)
diff --git a/instructions/data_extraction_prompts_config.json b/instructions/data_extraction_prompts_config.json
index 094d9b4..e012faf 100644
--- a/instructions/data_extraction_prompts_config.json
+++ b/instructions/data_extraction_prompts_config.json
@@ -1,5 +1,18 @@
 {
    "summary": "Read the context carefully.\nMaybe exists {} data in the context.\n",
+    "summary_image": "Read the image carefully.\nThe image may contain {} data.\n",
+    "image_features":
+    [
+        "1. Identify the text in the PDF page image.",
+        "2. Identify and format all of the tables in the PDF page image.",
+        "Table contents should be in markdown format,",
+        "ensuring the table structure and contents are exactly as in the PDF page image.",
+        "The format should be: |Column1|Column2|\n|---|---|\n|Row1Col1|Row1Col2|",
+        "Each cell in the table(s) should be in the proper position of the relevant row and column.",
+        "3. Extract data from the parsed text and table(s) contents above.",
+        "3.1 Use the parsed text and table(s) contents above as the context.",
+        "3.2 Please extract data from the context."
+    ],
    "reported_name": {
        "tor": "The TOR reported name could be:\nTOR, Turnover Ratio, Portfolio Turnover, Portfolio turnover ratio, PTR, etc.",
        "ogc": "The OGC reported name could be:\nOGC, OGF, Ongoing Charge, Operation Charge, Ongoing charges in per cent, Ongoing charges in percent, Ongoing charges as a percentage, On Going Charges, Operating Charge, Ongoing Fund Charge, etc.",
diff --git a/main.py b/main.py
index 125cd57..a566652 100644
--- a/main.py
+++ b/main.py
@@ -20,20 +20,37 @@ class EMEA_AR_Parsing:
         pdf_folder: str = r"/data/emea_ar/pdf/",
         output_extract_data_folder: str = r"/data/emea_ar/output/extract_data/docs/",
         output_mapping_data_folder: str = r"/data/emea_ar/output/mapping_data/docs/",
+        extract_way: str = "text",
     ) -> None:
         self.doc_id = doc_id
         self.pdf_folder = pdf_folder
         os.makedirs(self.pdf_folder, exist_ok=True)
         self.pdf_file = self.download_pdf()
         self.document_mapping_info_df = query_document_fund_mapping(doc_id)
+
+        if extract_way is None or len(extract_way) == 0:
+            extract_way = "text"
+        self.extract_way = extract_way
+        self.output_extract_image_folder = None
+        if self.extract_way == "image":
+            self.output_extract_image_folder = r"/data/emea_ar/output/extract_data/images/"
+            os.makedirs(self.output_extract_image_folder, exist_ok=True)
         if output_extract_data_folder is None or len(output_extract_data_folder) == 0:
             output_extract_data_folder = r"/data/emea_ar/output/extract_data/docs/"
+        if not output_extract_data_folder.endswith("/"):
+            output_extract_data_folder = f"{output_extract_data_folder}/"
+        if extract_way is not None and len(extract_way) > 0:
+            output_extract_data_folder = f"{output_extract_data_folder}by_{extract_way}/"
         self.output_extract_data_folder = output_extract_data_folder
         os.makedirs(self.output_extract_data_folder, exist_ok=True)
         if output_mapping_data_folder is None or len(output_mapping_data_folder) == 0:
             output_mapping_data_folder = r"/data/emea_ar/output/mapping_data/docs/"
+        if not output_mapping_data_folder.endswith("/"):
+            output_mapping_data_folder = f"{output_mapping_data_folder}/"
+        if extract_way is not None and len(extract_way) > 0:
+            output_mapping_data_folder = f"{output_mapping_data_folder}by_{extract_way}/"
         self.output_mapping_data_folder = output_mapping_data_folder
         os.makedirs(self.output_mapping_data_folder, exist_ok=True)
@@ -58,7 +75,8 @@ class EMEA_AR_Parsing:
             datapoints.remove("doc_id")
         return datapoints
 
-    def extract_data(self, re_run: bool = False) -> list:
+    def extract_data(self,
+                     re_run: bool = False) -> list:
         if not re_run:
             output_data_json_folder = os.path.join(
                 self.output_extract_data_folder, "json/"
@@ -81,6 +99,8 @@ class EMEA_AR_Parsing:
             self.datapoint_page_info,
             self.datapoints,
             self.document_mapping_info_df,
+            extract_way=self.extract_way,
+            output_image_folder=self.output_extract_image_folder,
         )
         data_from_gpt = data_extraction.extract_data()
         return data_from_gpt
@@ -124,11 +144,18 @@ def filter_pages(doc_id: str, pdf_folder: str) -> None:
 
 
 def extract_data(
-    doc_id: str, pdf_folder: str, output_data_folder: str, re_run: bool = False
+    doc_id: str,
+    pdf_folder: str,
+    output_data_folder: str,
+    extract_way: str = "text",
+    re_run: bool = False
 ) -> None:
     logger.info(f"Extract EMEA AR data for doc_id: {doc_id}")
     emea_ar_parsing = EMEA_AR_Parsing(
-        doc_id, pdf_folder, output_extract_data_folder=output_data_folder
+        doc_id,
+        pdf_folder,
+        output_extract_data_folder=output_data_folder,
+        extract_way=extract_way
     )
     data_from_gpt = emea_ar_parsing.extract_data(re_run)
     return data_from_gpt
@@ -139,6 +166,7 @@ def mapping_data(
     pdf_folder: str,
     output_extract_data_folder: str,
     output_mapping_folder: str,
+    extract_way: str = "text",
     re_run_extract_data: bool = False,
     re_run_mapping_data: bool = False,
 ) -> None:
@@ -148,6 +176,7 @@ def mapping_data(
         pdf_folder,
         output_extract_data_folder=output_extract_data_folder,
         output_mapping_data_folder=output_mapping_folder,
+        extract_way=extract_way,
     )
     doc_data_from_gpt = emea_ar_parsing.extract_data(re_run=re_run_extract_data)
     doc_mapping_data = emea_ar_parsing.mapping_data(
@@ -161,6 +190,7 @@ def batch_extract_data(
     doc_data_excel_file: str = None,
     output_child_folder: str = r"/data/emea_ar/output/extract_data/docs/",
     output_total_folder: str = r"/data/emea_ar/output/extract_data/total/",
+    extract_way: str = "text",
     special_doc_id_list: list = None,
     re_run: bool = False,
 ) -> None:
@@ -188,6 +218,7 @@ def batch_extract_data(
             doc_id=doc_id,
             pdf_folder=pdf_folder,
             output_data_folder=output_child_folder,
+            extract_way=extract_way,
             re_run=re_run,
         )
         result_list.extend(data_from_gpt)
@@ -214,6 +245,7 @@ def batch_start_job(
     output_mapping_child_folder: str = r"/data/emea_ar/output/mapping_data/docs/",
     output_extract_data_total_folder: str = r"/data/emea_ar/output/extract_data/total/",
     output_mapping_total_folder: str = r"/data/emea_ar/output/mapping_data/total/",
+    extract_way: str = "text",
     special_doc_id_list: list = None,
     re_run_extract_data: bool = False,
     re_run_mapping_data: bool = False,
@@ -245,6 +277,7 @@ def batch_start_job(
             pdf_folder=pdf_folder,
             output_extract_data_folder=output_extract_data_child_folder,
             output_mapping_folder=output_mapping_child_folder,
+            extract_way=extract_way,
             re_run_extract_data=re_run_extract_data,
             re_run_mapping_data=re_run_mapping_data,
         )
@@ -263,7 +296,7 @@ def batch_start_job(
         time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
         output_file = os.path.join(
             output_extract_data_total_folder,
-            f"extract_data_info_{len(pdf_files)}_documents_{time_stamp}.xlsx",
+            f"extract_data_info_{len(pdf_files)}_documents_by_{extract_way}_{time_stamp}.xlsx",
         )
         with pd.ExcelWriter(output_file) as writer:
             result_extract_data_df.to_excel(
@@ -275,7 +308,7 @@ def batch_start_job(
         time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
         output_file = os.path.join(
             output_mapping_total_folder,
-            f"mapping_data_info_{len(pdf_files)}_documents_{time_stamp}.xlsx",
+            f"mapping_data_info_{len(pdf_files)}_documents_by_{extract_way}_{time_stamp}.xlsx",
         )
         with pd.ExcelWriter(output_file) as writer:
             result_mappingdata_df.to_excel(
@@ -489,7 +522,8 @@ def test_auto_generate_instructions():
 
 def test_data_extraction_metrics():
     data_type = "data_extraction"
-    prediction_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_88_documents_20240917121708.xlsx"
+    prediction_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_88_documents_20240919120502.xlsx"
+    # prediction_file = r"/data/emea_ar/output/mapping_data/docs/by_text/excel/321733631.xlsx"
     prediction_sheet_name = "mapping_data"
     ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx"
     ground_truth_sheet_name = "mapping_data"
@@ -536,24 +570,33 @@ if __name__ == "__main__":
     # )
 
     # doc_id = "476492237"
-    # extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run)
-    special_doc_id_list = []
+    # extract_way = "image"
+    # extract_data(doc_id,
+    #              pdf_folder,
+    #              output_extract_data_child_folder,
+    #              extract_way,
+    #              re_run_extract_data)
+
+    special_doc_id_list = ["476492237"]
     output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
     output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
     re_run_mapping_data = True
+    force_save_total_data = False
 
-    force_save_total_data = True
-    # batch_start_job(
-    #     pdf_folder,
-    #     page_filter_ground_truth_file,
-    #     output_extract_data_child_folder,
-    #     output_mapping_child_folder,
-    #     output_extract_data_total_folder,
-    #     output_mapping_total_folder,
-    #     special_doc_id_list,
-    #     re_run_extract_data,
-    #     re_run_mapping_data,
-    #     force_save_total_data=force_save_total_data,
-    # )
+    extract_ways = ["text"]
+    # for extract_way in extract_ways:
+    #     batch_start_job(
+    #         pdf_folder,
+    #         page_filter_ground_truth_file,
+    #         output_extract_data_child_folder,
+    #         output_mapping_child_folder,
+    #         output_extract_data_total_folder,
+    #         output_mapping_total_folder,
+    #         extract_way,
+    #         special_doc_id_list,
+    #         re_run_extract_data,
+    #         re_run_mapping_data,
+    #         force_save_total_data=force_save_total_data,
+    #     )
     test_data_extraction_metrics()
diff --git a/utils/pdf_util.py b/utils/pdf_util.py
index 4d0af21..1221b29 100644
--- a/utils/pdf_util.py
+++ b/utils/pdf_util.py
@@ -146,6 +146,36 @@ class PDFUtil:
             logger.error(f"Error extracting images: {e}")
             print_exc()
             return {}
+
+    def extract_image_from_page(self,
+                                page_index: int,
+                                zoom: float = 2.0,
+                                output_folder: str = None) -> str:
+        try:
+            pdf_doc = fitz.open(self.pdf_file)
+            try:
+                pdf_encrypted = pdf_doc.isEncrypted
+            except AttributeError:
+                pdf_encrypted = pdf_doc.is_encrypted
+            if pdf_encrypted:
+                pdf_doc.authenticate("")
+            pdf_base_name = os.path.basename(self.pdf_file).replace(".pdf", "")
+            mat = fitz.Matrix(zoom, zoom)
+            page = pdf_doc[page_index]
+            pix = page.get_pixmap(matrix=mat)
+            img_buffer = pix.tobytes(output='png')
+            img_base64 = base64.b64encode(img_buffer).decode('utf-8')
+            if output_folder and len(output_folder) > 0:
+                os.makedirs(output_folder, exist_ok=True)
+                image_file = os.path.join(output_folder, f"{pdf_base_name}_{page_index}.png")
+                pix.save(image_file)
+            pdf_doc.close()
+            return img_base64
+        except Exception as e:
+            logger.error(f"Error extracting image from page: {e}")
+            print_exc()
+            return None
+
     def parse_blocks_page(self, page: fitz.Page):
         blocks = page.get_text("blocks")