From f166e73362c233d7cbf475afb0656800376f3863 Mon Sep 17 00:00:00 2001
From: Blade He
Date: Tue, 15 Oct 2024 15:57:54 -0500
Subject: [PATCH] optimize data extraction algorithm: if no numeric cost value
 can be found in the PDF page text, extract the data with Vision ChatGPT
 instead

---
 core/data_extraction.py | 309 ++++++++++++++++++++++------------------
 main.py                 |   4 +-
 utils/gpt_utils.py      |   5 +-
 3 files changed, 177 insertions(+), 141 deletions(-)

diff --git a/core/data_extraction.py b/core/data_extraction.py
index 9fa0541..77fd8a1 100644
--- a/core/data_extraction.py
+++ b/core/data_extraction.py
@@ -136,7 +136,7 @@ class DataExtraction:
             page_datapoints = self.get_datapoints_by_page_num(page_num)
             if len(page_datapoints) == 0:
                 continue
-            extract_data = self.extract_data_by_page_text(
+            extract_data = self.extract_data_by_page(
                 page_num,
                 page_text,
                 page_datapoints,
@@ -179,7 +179,7 @@ class DataExtraction:
             next_page_text = self.page_text_dict.get(next_page_num, "")
             target_text = current_text + next_page_text
             # try to get data by current page_datapoints
-            next_page_extract_data = self.extract_data_by_page_text(
+            next_page_extract_data = self.extract_data_by_page(
                 next_page_num,
                 target_text,
                 next_datapoints,
@@ -313,6 +313,25 @@ class DataExtraction:
         )
         with pd.ExcelWriter(excel_data_file) as writer:
             data_df.to_excel(writer, sheet_name="extract_data", index=False)
+
+    def extract_data_by_page(
+        self,
+        page_num: int,
+        page_text: str,
+        page_datapoints: list,
+        need_exclude: bool = False,
+        exclude_data: list = None) -> dict:
+        # If no numeric value (e.g. 1.25 or 3,88) can be found in the page text,
+        # apply Vision ChatGPT to extract the data from the page image
+        numeric_regex = r"\d+(\.|\,)\d+"
+        if not re.search(numeric_regex, page_text):
+            logger.info(f"Can't find numeric value in page {page_num}, apply Vision ChatGPT to extract data")
+            return self.extract_data_by_page_image(
+                page_num, page_datapoints, need_exclude, exclude_data)
+        else:
+            return self.extract_data_by_page_text(
+                page_num, page_text, page_datapoints, need_exclude, exclude_data
+            )
 
     def extract_data_by_page_text(
         self,
@@ -328,7 +347,11 @@ class DataExtraction:
         """
         logger.info(f"Extracting data from page {page_num}")
         instructions = self.get_instructions_by_datapoints(
-            page_text, page_datapoints, need_exclude, exclude_data
+            page_text,
+            page_datapoints,
+            need_exclude,
+            exclude_data,
+            extract_way="text"
         )
         response, with_error = chat(
             instructions, response_format={"type": "json_object"}
@@ -342,6 +365,7 @@ class DataExtraction:
             data_dict["instructions"] = instructions
             data_dict["raw_answer"] = response
             data_dict["extract_data"] = {"data": []}
+            data_dict["extract_way"] = "text"
             return data_dict
         try:
             data = json.loads(response)
@@ -367,12 +391,15 @@ class DataExtraction:
         data_dict["instructions"] = instructions
         data_dict["raw_answer"] = response
         data_dict["extract_data"] = data
+        data_dict["extract_way"] = "text"
         return data_dict
 
     def extract_data_by_page_image(
         self,
         page_num: int,
-        page_datapoints: list
+        page_datapoints: list,
+        need_exclude: bool = False,
+        exclude_data: list = None,
     ) -> dict:
         """
         keys are
@@ -381,7 +408,11 @@ class DataExtraction:
         logger.info(f"Extracting data from page {page_num}")
         image_base64 = self.get_pdf_image_base64(page_num)
         instructions = self.get_instructions_by_datapoints(
-            "", page_datapoints, need_exclude=False, exclude_data=None
+            "",
+            page_datapoints,
+            need_exclude=need_exclude,
+            exclude_data=exclude_data,
+            extract_way="image"
         )
         response, with_error = chat(
             instructions, response_format={"type": "json_object"}, image_base64=image_base64
@@ -391,9 +422,11 @@ class DataExtraction:
             data_dict = {"doc_id": self.doc_id}
             data_dict["page_index"] = page_num
             data_dict["datapoints"] = ", ".join(page_datapoints)
+            data_dict["page_text"] = ""
             data_dict["instructions"] = instructions
             data_dict["raw_answer"] = response
             data_dict["extract_data"] = {"data": []}
+            data_dict["extract_way"] = "image"
             return data_dict
         try:
             data = json.loads(response)
@@ -407,137 +440,12 @@ class DataExtraction:
         data_dict = {"doc_id": self.doc_id}
         data_dict["page_index"] = page_num
         data_dict["datapoints"] = ", ".join(page_datapoints)
+        data_dict["page_text"] = ""
         data_dict["instructions"] = instructions
         data_dict["raw_answer"] = response
         data_dict["extract_data"] = data
+        data_dict["extract_way"] = "image"
         return data_dict
-
-    def chat_by_split_context(self,
-                              page_text: str,
-                              page_datapoints: list,
-                              need_exclude: bool,
-                              exclude_data: list) -> list:
-        """
-        If occur error, split the context to two parts and try to get data from the two parts
-        Relevant document: 503194284, page index 147
-        """
-        try:
-            logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
-            split_context = re.split(r"\n", page_text)
-            split_context = [text.strip() for text in split_context
-                             if len(text.strip()) > 0]
-            if len(split_context) < 10:
-                return {"data": []}
-
-            split_context_len = len(split_context)
-            top_10_context = split_context[:10]
-            rest_context = split_context[10:]
-            header = "\n".join(top_10_context)
-            half_len = split_context_len // 2
-            # the member of half_len should not start with number
-            # reverse iterate the list by half_len
-            half_len_list = [i for i in range(half_len)]
-
-            fund_name_line = ""
-            half_line = rest_context[half_len].strip()
-            max_similarity_fund_name, max_similarity = get_most_similar_name(
-                half_line, self.provider_fund_name_list, matching_type="fund"
-            )
-            if max_similarity < 0.2:
-                # get the fund name line text from the first half
-                for index in reversed(half_len_list):
-                    line_text = rest_context[index].strip()
-                    if len(line_text) == 0:
-                        continue
-                    line_text_split = line_text.split()
-                    if len(line_text_split) < 3:
-                        continue
-                    first_word = line_text_split[0]
-                    if first_word.lower() == "class":
-                        continue
-
-                    max_similarity_fund_name, max_similarity = get_most_similar_name(
-                        line_text, self.provider_fund_name_list, matching_type="fund"
-                    )
-                    if max_similarity >= 0.2:
-                        fund_name_line = line_text
-                        break
-            else:
-                fund_name_line = half_line
-                half_len += 1
-            if fund_name_line == "":
-                return {"data": []}
-
-            logger.info(f"Split first part from 0 to {half_len}")
-            split_first_part = "\n".join(split_context[:half_len])
-            first_part = '\n'.join(split_first_part)
-            first_instructions = self.get_instructions_by_datapoints(
-                first_part, page_datapoints, need_exclude, exclude_data
-            )
-            response, with_error = chat(
-                first_instructions, response_format={"type": "json_object"}
-            )
-            first_part_data = {"data": []}
-            if not with_error:
-                try:
-                    first_part_data = json.loads(response)
-                except:
-                    first_part_data = json_repair.loads(response)
-
-            logger.info(f"Split second part from {half_len} to {split_context_len}")
-            split_second_part = "\n".join(split_context[half_len:])
-            second_part = header + "\n" + fund_name_line + "\n" + split_second_part
-            second_instructions = self.get_instructions_by_datapoints(
-                second_part, page_datapoints, need_exclude, exclude_data
-            )
-            response, with_error = chat(
-                second_instructions, response_format={"type": "json_object"}
-            )
-            second_part_data = {"data": []}
-            if not with_error:
-                try:
-                    second_part_data = json.loads(response)
-                except:
-                    second_part_data = json_repair.loads(response)
-
-            first_part_data_list = first_part_data.get("data", [])
-            logger.info(f"First part data count: {len(first_part_data_list)}")
-            second_part_data_list = second_part_data.get("data", [])
-            logger.info(f"Second part data count: {len(second_part_data_list)}")
-            for first_data in first_part_data_list:
-                if first_data in second_part_data_list:
-                    second_part_data_list.remove(first_data)
-                else:
-                    # if the first part data is with same fund name and share name,
-                    # remove the second part data
-                    first_data_dp = [key for key in list(first_data.keys())
-                                     if key not in ["fund name", "share name"]]
-                    # order the data points
-                    first_data_dp.sort()
-                    first_fund_name = first_data.get("fund name", "")
-                    first_share_name = first_data.get("share name", "")
-                    if len(first_fund_name) > 0 and len(first_share_name) > 0:
-                        remove_second_list = []
-                        for second_data in second_part_data_list:
-                            second_fund_name = second_data.get("fund name", "")
-                            second_share_name = second_data.get("share name", "")
-                            if first_fund_name == second_fund_name and \
-                                    first_share_name == second_share_name:
-                                second_data_dp = [key for key in list(second_data.keys())
-                                                  if key not in ["fund name", "share name"]]
-                                second_data_dp.sort()
-                                if first_data_dp == second_data_dp:
-                                    remove_second_list.append(second_data)
-                        for remove_second in remove_second_list:
-                            if remove_second in second_part_data_list:
-                                second_part_data_list.remove(remove_second)
-
-            data_list = first_part_data_list + second_part_data_list
-            extract_data = {"data": data_list}
-            return extract_data
-        except Exception as e:
-            logger.error(f"Error in split context: {e}")
-            return {"data": []}
 
     def validate_data(self, extract_data_info: dict) -> dict:
         """
@@ -634,7 +542,6 @@ class DataExtraction:
                 return True
         return False
 
-
     def get_datapoints_by_page_num(self, page_num: int) -> list:
         datapoints = []
         for datapoint in self.datapoints:
@@ -648,6 +555,7 @@ class DataExtraction:
         datapoints: list,
         need_exclude: bool = False,
        exclude_data: list = None,
+        extract_way: str = "text",
     ) -> str:
         """
         Get instructions to extract data from the page by the datapoints
@@ -678,7 +586,7 @@ class DataExtraction:
         end
         """
         instructions = []
-        if self.extract_way == "text":
+        if extract_way == "text":
            instructions = [f"Context:\n{page_text}\n\nInstructions:\n"]
 
         datapoint_name_list = []
@@ -686,9 +594,9 @@ class DataExtraction:
             datapoint_name = self.datapoint_name_config.get(datapoint, "")
             datapoint_name_list.append(datapoint_name)
 
-        if self.extract_way == "text":
+        if extract_way == "text":
             summary = self.instructions_config.get("summary", "\n")
-        elif self.extract_way == "image":
+        elif extract_way == "image":
             summary = self.instructions_config.get("summary_image", "\n")
         else:
             summary = self.instructions_config.get("summary", "\n")
@@ -696,7 +604,7 @@ class DataExtraction:
         instructions.append(summary.format(", ".join(datapoint_name_list)))
         instructions.append("\n")
 
-        if self.extract_way == "image":
+        if extract_way == "image":
             image_features = self.instructions_config.get("image_features", [])
             instructions.extend(image_features)
             instructions.append("\n")
@@ -831,3 +739,130 @@ class DataExtraction:
 
         instructions_text = "".join(instructions)
         return instructions_text
+
+    # def chat_by_split_context(self,
+    #                           page_text: str,
+    #                           page_datapoints: list,
+    #                           need_exclude: bool,
+    #                           exclude_data: list) -> list:
+    #     """
+    #     If an error occurs, split the context into two parts and try to get data from the two parts
+    #     Relevant document: 503194284, page index 147
+    #     """
+    #     try:
+    #         logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
+    #         split_context = re.split(r"\n", page_text)
+    #         split_context = [text.strip() for text in split_context
+    #                          if len(text.strip()) > 0]
+    #         if len(split_context) < 10:
+    #             return {"data": []}
+
+    #         split_context_len = len(split_context)
+    #         top_10_context = split_context[:10]
+    #         rest_context = split_context[10:]
+    #         header = "\n".join(top_10_context)
+    #         half_len = split_context_len // 2
+    #         # the member of half_len should not start with number
+    #         # reverse iterate the list by half_len
+    #         half_len_list = [i for i in range(half_len)]
+
+    #         fund_name_line = ""
+    #         half_line = rest_context[half_len].strip()
+    #         max_similarity_fund_name, max_similarity = get_most_similar_name(
+    #             half_line, self.provider_fund_name_list, matching_type="fund"
+    #         )
+    #         if max_similarity < 0.2:
+    #             # get the fund name line text from the first half
+    #             for index in reversed(half_len_list):
+    #                 line_text = rest_context[index].strip()
+    #                 if len(line_text) == 0:
+    #                     continue
+    #                 line_text_split = line_text.split()
+    #                 if len(line_text_split) < 3:
+    #                     continue
+    #                 first_word = line_text_split[0]
+    #                 if first_word.lower() == "class":
+    #                     continue
+
+    #                 max_similarity_fund_name, max_similarity = get_most_similar_name(
+    #                     line_text, self.provider_fund_name_list, matching_type="fund"
+    #                 )
+    #                 if max_similarity >= 0.2:
+    #                     fund_name_line = line_text
+    #                     break
+    #         else:
+    #             fund_name_line = half_line
+    #             half_len += 1
+    #         if fund_name_line == "":
+    #             return {"data": []}
+
+    #         logger.info(f"Split first part from 0 to {half_len}")
+    #         split_first_part = "\n".join(split_context[:half_len])
+    #         first_part = '\n'.join(split_first_part)
+    #         first_instructions = self.get_instructions_by_datapoints(
+    #             first_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
+    #         )
+    #         response, with_error = chat(
+    #             first_instructions, response_format={"type": "json_object"}
+    #         )
+    #         first_part_data = {"data": []}
+    #         if not with_error:
+    #             try:
+    #                 first_part_data = json.loads(response)
+    #             except:
+    #                 first_part_data = json_repair.loads(response)
+
+    #         logger.info(f"Split second part from {half_len} to {split_context_len}")
+    #         split_second_part = "\n".join(split_context[half_len:])
+    #         second_part = header + "\n" + fund_name_line + "\n" + split_second_part
+    #         second_instructions = self.get_instructions_by_datapoints(
+    #             second_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
+    #         )
+    #         response, with_error = chat(
+    #             second_instructions, response_format={"type": "json_object"}
+    #         )
+    #         second_part_data = {"data": []}
+    #         if not with_error:
+    #             try:
+    #                 second_part_data = json.loads(response)
+    #             except:
+    #                 second_part_data = json_repair.loads(response)
+
+    #         first_part_data_list = first_part_data.get("data", [])
+    #         logger.info(f"First part data count: {len(first_part_data_list)}")
+    #         second_part_data_list = second_part_data.get("data", [])
+    #         logger.info(f"Second part data count: {len(second_part_data_list)}")
+    #         for first_data in first_part_data_list:
+    #             if first_data in second_part_data_list:
+    #                 second_part_data_list.remove(first_data)
+    #             else:
+    #                 # if the first part data is with same fund name and share name,
+    #                 # remove the second part data
+    #                 first_data_dp = [key for key in list(first_data.keys())
+    #                                  if key not in ["fund name", "share name"]]
+    #                 # order the data points
+    #                 first_data_dp.sort()
+    #                 first_fund_name = first_data.get("fund name", "")
+    #                 first_share_name = first_data.get("share name", "")
+    #                 if len(first_fund_name) > 0 and len(first_share_name) > 0:
+    #                     remove_second_list = []
+    #                     for second_data in second_part_data_list:
+    #                         second_fund_name = second_data.get("fund name", "")
+    #                         second_share_name = second_data.get("share name", "")
+    #                         if first_fund_name == second_fund_name and \
+    #                                 first_share_name == second_share_name:
+    #                             second_data_dp = [key for key in list(second_data.keys())
+    #                                               if key not in ["fund name", "share name"]]
+    #                             second_data_dp.sort()
+    #                             if first_data_dp == second_data_dp:
+    #                                 remove_second_list.append(second_data)
+    #                     for remove_second in remove_second_list:
+    #                         if remove_second in second_part_data_list:
+    #                             second_part_data_list.remove(remove_second)
+
+    #         data_list = first_part_data_list + second_part_data_list
+    #         extract_data = {"data": data_list}
+    #         return extract_data
+    #     except Exception as e:
+    #         logger.error(f"Error in split context: {e}")
+    #         return {"data": []}
diff --git a/main.py b/main.py
index a042ed4..1e34bbb 100644
--- a/main.py
+++ b/main.py
@@ -809,10 +809,10 @@ if __name__ == "__main__":
     ]
     # special_doc_id_list = check_mapping_doc_id_list
     special_doc_id_list = check_db_mapping_doc_id_list
-    special_doc_id_list = ["423395975"]
+    special_doc_id_list = ["514213638"]
     output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
     output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
-    re_run_extract_data = True
+    re_run_extract_data = False
     re_run_mapping_data = True
     force_save_total_data = False
     calculate_metrics = False
diff --git a/utils/gpt_utils.py b/utils/gpt_utils.py
index 253a179..70cf148 100644
--- a/utils/gpt_utils.py
+++ b/utils/gpt_utils.py
@@ -74,7 +74,7 @@ def chat(
     image_file: str = None,
     image_base64: str = None,
 ):
-    if engine != "gpt-4o-2024-08-06-research":
+    if not engine.startswith("gpt-4o"):
         max_tokens = 4096
 
     client = AzureOpenAI(
@@ -138,6 +138,7 @@ def chat(
                 messages=messages,
                 response_format=response_format,
             )
+            sleep(1)
             return response.choices[0].message.content, False
         except Exception as e:
             error = str(e)
@@ -145,7 +146,7 @@ def chat(
             if "maximum context length" in error:
                 return error, True
             count += 1
-            sleep(3)
+            sleep(2)
 
     return error, True
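
A minimal standalone sketch of the text-vs-image dispatch rule that extract_data_by_page introduces, for illustration only: the helper name and sample strings below are not part of the repository, and the character class [.,] is an equivalent rewrite of the patch's numeric_regex.

import re

# Matches a decimal value with either "." or "," as the separator, e.g. 1.25 or 3,88,
# mirroring the numeric_regex check in extract_data_by_page.
NUMERIC_REGEX = re.compile(r"\d+[.,]\d+")


def choose_extraction_way(page_text: str) -> str:
    """Return "text" when the page text contains a numeric value, otherwise "image"."""
    return "text" if NUMERIC_REGEX.search(page_text) else "image"


if __name__ == "__main__":
    print(choose_extraction_way("Ongoing charges 1.25%"))   # -> text
    print(choose_extraction_way("Ongoing charges figure"))  # -> image (fall back to Vision ChatGPT)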