diff --git a/core/data_extraction.py b/core/data_extraction.py
index e3412ce..a73e5af 100644
--- a/core/data_extraction.py
+++ b/core/data_extraction.py
@@ -160,6 +160,17 @@ class DataExtraction:
                     ] = next_page_data_list
                     data_list.append(next_page_extract_data)
                     handled_page_num_list.append(next_page_num)
+                    # Stop paging forward once the next page yields none of the current page's datapoints.
+                    exist_current_page_datapoint = False
+                    for next_page_data in next_page_data_list:
+                        for page_datapoint in page_datapoints:
+                            if page_datapoint in next_page_data:
+                                exist_current_page_datapoint = True
+                                break
+                        if exist_current_page_datapoint:
+                            break
+                    if not exist_current_page_datapoint:
+                        break
                 else:
                     break
             count += 1
@@ -204,12 +215,25 @@ class DataExtraction:
             )
         if with_error:
             logger.error(f"Error in extracting tables from page")
-            return ""
+            data_dict = {"doc_id": self.doc_id}
+            data_dict["page_index"] = page_num
+            data_dict["datapoints"] = ", ".join(page_datapoints)
+            data_dict["page_text"] = page_text
+            data_dict["instructions"] = instructions
+            data_dict["raw_answer"] = response
+            data_dict["extract_data"] = {"data": []}
+            return data_dict
         try:
             data = json.loads(response)
         except:
             try:
-                data = json_repair.loads(response)
+                # A parse failure here often means the output was truncated at the
+                # 4K-token limit: split the context in two and extract from each half.
+                data = self.chat_by_split_context(
+                    page_text, page_datapoints, need_exclude, exclude_data
+                )
+                if len(data.get("data", [])) == 0:
+                    data = json_repair.loads(response)
             except:
                 data = {"data": []}
         data = self.validate_data(data)
@@ -223,6 +247,80 @@ class DataExtraction:
         data_dict["extract_data"] = data
         return data_dict
 
+    def chat_by_split_context(self,
+                              page_text: str,
+                              page_datapoints: list,
+                              need_exclude: bool,
+                              exclude_data: list) -> dict:
+        """
+        Fallback for unparseable responses (likely truncated at the 4K-token
+        limit): split the context in two, extract from each part, then merge.
+        """
+        try:
+            logger.info("Splitting context in two to work around outputs over the 4K-token limit")
+            split_context = re.split(r"\n", page_text)
+            split_context = [text.strip() for text in split_context
+                             if len(text.strip()) > 0]
+            split_context_len = len(split_context)
+            top_10_context = split_context[:10]
+            header = "\n".join(top_10_context)
+            half_len = split_context_len // 2
+            # Walk backwards from the midpoint to a line that does not start with
+            # a digit or table punctuation, so a table row is not cut in half.
+            for index in reversed(range(10, half_len)):
+                first_letter = split_context[index].strip()[0]
+                if not first_letter.isnumeric() and first_letter not in [".", "(", ")", "-"]:
+                    half_len = index
+                    break
+
+            logger.info(f"Split first part from 0 to {half_len}")
+            first_part = "\n".join(split_context[:half_len])
+            first_instructions = self.get_instructions_by_datapoints(
+                first_part, page_datapoints, need_exclude, exclude_data
+            )
+            response, with_error = chat(
+                first_instructions, response_format={"type": "json_object"}
+            )
+            first_part_data = {"data": []}
+            if not with_error:
+                try:
+                    first_part_data = json.loads(response)
+                except:
+                    first_part_data = json_repair.loads(response)
+
+            logger.info(f"Split second part from {half_len} to {split_context_len}")
+            # Prepend the first 10 lines so the second part keeps any table header.
+            split_second_part = "\n".join(split_context[half_len:])
+            second_part = header + "\n" + split_second_part
+            second_instructions = self.get_instructions_by_datapoints(
+                second_part, page_datapoints, need_exclude, exclude_data
+            )
+            response, with_error = chat(
+                second_instructions, response_format={"type": "json_object"}
+            )
+            second_part_data = {"data": []}
+            if not with_error:
+                try:
+                    second_part_data = json.loads(response)
+                except:
+                    second_part_data = json_repair.loads(response)
+
+            first_part_data_list = first_part_data.get("data", [])
+            logger.info(f"First part data count: {len(first_part_data_list)}")
+            second_part_data_list = second_part_data.get("data", [])
+            logger.info(f"Second part data count: {len(second_part_data_list)}")
+            # Drop duplicates: rows near the split point can be extracted twice.
+            for first_data in first_part_data_list:
+                if first_data in second_part_data_list:
+                    second_part_data_list.remove(first_data)
+
+            data_list = first_part_data_list + second_part_data_list
+            extract_data = {"data": data_list}
+            return extract_data
+        except Exception as e:
+            logger.error(f"Error in split context: {e}")
+            return {"data": []}
+
     def validate_data(self, extract_data_info: dict) -> dict:
         """
         Validate data by the rules
diff --git a/instructions/data_extraction_prompts_config.json b/instructions/data_extraction_prompts_config.json
index 7ad6713..d5ce07d 100644
--- a/instructions/data_extraction_prompts_config.json
+++ b/instructions/data_extraction_prompts_config.json
@@ -102,6 +102,7 @@
         "common": [
             "If possible, please extract fund name, share name, TOR, TER, performance fees, OGC values as the output.",
             "If find share name, and exist relevant currency, please output share name + currency, e.g. share name is \"Class A\", currency is \"USD\", the output share name should be: \"Class A USD\".",
+            "If a fund name is found and a sub fund name exists, please output fund name + sub fund name, e.g. fund name is \"Black Rock European\", sub fund name is \"Growth\", the output fund name should be: \"Black Rock European Growth\".",
             "Only output the data point which with relevant value.",
             "Don't ignore the data point which with negative value, e.g. -0.12, -1.13",
             "Don't ignore the data point which with explicit zero value, e.g. 0, 0.00",
@@ -110,7 +111,7 @@
             "The output should be JSON format, the format is like below example(s):"
         ],
         "fund_level": [
-            "[{\"fund name\": \"fund 1\",\"tor\": 35.26}, {\"fund name\": \"fund 2\",\"tor\": -28.26}, {\"fund name\": \"fund 3\",\"tor\": 115.52,}]"
+            "[{\"fund name\": \"fund 1 - sub fund name 1\",\"tor\": 35.26}, {\"fund name\": \"fund 2 - sub fund name 2\",\"tor\": -28.26}, {\"fund name\": \"fund 3\",\"tor\": 115.52}]"
         ],
         "share_level": {
             "fund_name": [
diff --git a/main.py b/main.py
index 016e570..4159e8c 100644
--- a/main.py
+++ b/main.py
@@ -505,7 +505,7 @@ if __name__ == "__main__":
     # doc_id = "476492237"
     # extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run)
 
-    special_doc_id_list = ["491593469"]
+    special_doc_id_list = ["503194284"]
     output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
     output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
     re_run_mapping_data = True
diff --git a/utils/gpt_utils.py b/utils/gpt_utils.py
index 1f22ecd..b8d0e9f 100644
--- a/utils/gpt_utils.py
+++ b/utils/gpt_utils.py
@@ -103,7 +103,7 @@ def chat(
     count = 0
     error = ""
 
-    max_tokens = 4000
+    max_tokens = 4096
     request_timeout = 120
     while count < 8:
         try:
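Reviewer note (not part of the patch): a minimal standalone sketch of the split-and-merge fallback this diff adds in `chat_by_split_context`. The helper names `choose_split_point` and `merge_parts` are illustrative only and do not exist in the codebase; they isolate the two non-obvious steps, choosing a split line that is not a table row and de-duplicating rows extracted in both halves.

```python
def choose_split_point(lines: list[str], header_len: int = 10) -> int:
    """Walk back from the midpoint to a line that does not start with a digit
    or table punctuation, so a table row is not cut between the two halves."""
    half_len = len(lines) // 2
    for index in reversed(range(header_len, half_len)):
        first_letter = lines[index].strip()[0]
        if not first_letter.isnumeric() and first_letter not in [".", "(", ")", "-"]:
            return index
    return half_len

def merge_parts(first: dict, second: dict) -> dict:
    """Concatenate both halves' records, dropping rows that were extracted
    twice (rows near the split point can appear in both halves)."""
    first_list = first.get("data", [])
    second_list = [row for row in second.get("data", []) if row not in first_list]
    return {"data": first_list + second_list}

if __name__ == "__main__":
    header = [f"Column {i}" for i in range(10)]  # first 10 lines act as the table header
    body = ["9.10", "Subtotal", "8.20", "7.30", "6.40", "5.50",
            "4.60", "3.70", "2.80", "1.90", "0.15", "0.25", "0.35", "0.45"]
    print(choose_split_point(header + body))  # -> 11 ("Subtotal", not a numeric row)
    print(merge_parts({"data": [{"fund name": "A", "tor": 1.2}]},
                      {"data": [{"fund name": "A", "tor": 1.2},
                                {"fund name": "B", "tor": 3.4}]}))
    # -> {'data': [{'fund name': 'A', 'tor': 1.2}, {'fund name': 'B', 'tor': 3.4}]}
```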