diff --git a/core/data_extraction.py b/core/data_extraction.py
index 2f060e6..5abef3e 100644
--- a/core/data_extraction.py
+++ b/core/data_extraction.py
@@ -13,26 +13,26 @@ from utils.biz_utils import add_slash_to_text_as_regex, clean_text
 
 class DataExtraction:
     def __init__(
-        self, 
-        doc_id: str, 
+        self,
+        doc_id: str,
         pdf_file: str,
-        output_data_folder: str, 
-        page_text_dict: dict, 
-        datapoint_page_info: dict, 
-        document_mapping_info_df: pd.DataFrame
+        output_data_folder: str,
+        page_text_dict: dict,
+        datapoint_page_info: dict,
+        document_mapping_info_df: pd.DataFrame,
     ) -> None:
         self.doc_id = doc_id
         self.pdf_file = pdf_file
         if output_data_folder is None or len(output_data_folder) == 0:
             output_data_folder = r"/data/emea_ar/output/extract_data/docs/"
         os.makedirs(output_data_folder, exist_ok=True)
-        
+
         self.output_data_json_folder = os.path.join(output_data_folder, "json/")
         os.makedirs(self.output_data_json_folder, exist_ok=True)
-        
+
         self.output_data_excel_folder = os.path.join(output_data_folder, "excel/")
         os.makedirs(self.output_data_excel_folder, exist_ok=True)
-        
+
         if page_text_dict is None or len(page_text_dict.keys()) == 0:
             self.page_text_dict = self.get_pdf_page_text_dict()
         else:
@@ -42,17 +42,19 @@ class DataExtraction:
         else:
             self.document_mapping_info_df = document_mapping_info_df
         self.datapoint_page_info = datapoint_page_info
+        self.page_nums_with_datapoints = self.get_page_nums_from_datapoint_page_info()
+        self.datapoints = self.get_datapoints_from_datapoint_page_info()
         self.instructions_config = self.get_instructions_config()
         self.datapoint_level_config = self.get_datapoint_level()
         self.datapoint_name_config = self.get_datapoint_name()
-        
+
     def get_instructions_config(self) -> dict:
         instructions_config_file = r"./instructions/data_extraction_prompts_config.json"
         with open(instructions_config_file, "r", encoding="utf-8") as f:
             instructions_config = json.load(f)
         return instructions_config
-        
+
     def get_datapoint_level(self) -> dict:
         datapoint_level_file = r"./configuration/datapoint_level.json"
         with open(datapoint_level_file, "r", encoding="utf-8") as f:
@@ -64,68 +66,181 @@ class DataExtraction:
         with open(datapoint_name_file, "r", encoding="utf-8") as f:
             datapoint_name = json.load(f)
         return datapoint_name
-    
+
     def get_pdf_page_text_dict(self) -> dict:
         pdf_util = PDFUtil(self.pdf_file)
         success, text, page_text_dict = pdf_util.extract_text()
         return page_text_dict
-    
+
     def get_datapoints_from_datapoint_page_info(self) -> list:
         datapoints = list(self.datapoint_page_info.keys())
         if "doc_id" in datapoints:
             datapoints.remove("doc_id")
         return datapoints
-    
+
+    def get_page_nums_from_datapoint_page_info(self) -> list:
+        page_nums_with_datapoints = []
+        for datapoint, page_nums in self.datapoint_page_info.items():
+            if datapoint == "doc_id":
+                continue
+            page_nums_with_datapoints.extend(page_nums)
+        page_nums_with_datapoints = list(set(page_nums_with_datapoints))
+        # sort the page numbers
+        page_nums_with_datapoints.sort()
+        return page_nums_with_datapoints
+
     def extract_data(self) -> dict:
         """
         keys are
         doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
         """
         data_list = []
+        pdf_page_count = len(self.page_text_dict.keys())
+        handled_page_num_list = []
         for page_num, page_text in self.page_text_dict.items():
+            if page_num in handled_page_num_list:
+                continue
             page_datapoints = self.get_datapoints_by_page_num(page_num)
             if len(page_datapoints) == 0:
                 continue
-            instructions = self.get_instructions_by_datapoints(page_text, page_datapoints)
-            response, with_error = chat(instructions)
-            if with_error:
-                logger.error(f"Error in extracting tables from page")
-                return ""
-            try:
-                data = json.loads(response)
-            except:
-                try:
-                    data = json_repair.loads(response)
-                except:
-                    data = {}
-            data_dict = {"doc_id": self.doc_id}
-            data_dict["page_index"] = page_num
-            data_dict["datapoints"] = ", ".join(page_datapoints)
-            data_dict["page_text"] = page_text
-            data_dict["instructions"] = instructions
-            data_dict["raw_answer"] = response
-            data_dict["data"] = data
-            data_list.append(data_dict)
-        json_data_file = os.path.join(self.output_data_json_folder, f"{self.doc_id}.json")
+            extract_data = self.extract_data_by_page(
+                page_num,
+                page_text,
+                page_datapoints,
+                need_exclude=False,
+                exclude_data=None,
+            )
+            data_list.append(extract_data)
+
+            page_data_list = extract_data.get("extract_data", {}).get("data", [])
+
+            current_page_data_count = len(page_data_list)
+            if current_page_data_count > 0:
+                count = 1
+                # some pdf documents have multiple pages for the same data,
+                # and the next page may have no table header with datapoint keywords.
+                # the purpose is to try to get data from the next page as well
+                current_text = page_text
+
+                while count < 3:
+                    try:
+                        next_page_num = page_num + count
+                        if next_page_num >= pdf_page_count:
+                            break
+                        next_datapoints = page_datapoints
+                        if next_page_num in self.page_nums_with_datapoints:
+                            should_continue = False
+                            next_datapoints = self.get_datapoints_by_page_num(next_page_num)
+                            if len(next_datapoints) == 0:
+                                should_continue = True
+                            else:
+                                for next_datapoint in next_datapoints:
+                                    if next_datapoint not in page_datapoints:
+                                        should_continue = True
+                                        break
+                                next_datapoints.extend(page_datapoints)
+                                # remove duplicate datapoints
+                                next_datapoints = list(set(next_datapoints))
+                            if not should_continue:
+                                break
+                        next_page_text = self.page_text_dict.get(next_page_num, "")
+                        target_text = current_text + next_page_text
+                        # try to get data by current page_datapoints
+                        next_page_extract_data = self.extract_data_by_page(
+                            next_page_num,
+                            target_text,
+                            next_datapoints,
+                            need_exclude=True,
+                            exclude_data=page_data_list,
+                        )
+                        next_page_data_list = next_page_extract_data.get(
+                            "extract_data", {}
+                        ).get("data", [])
+
+                        if next_page_data_list is not None and len(next_page_data_list) > 0:
+                            for current_page_data in page_data_list:
+                                if current_page_data in next_page_data_list:
+                                    next_page_data_list.remove(current_page_data)
+                            next_page_extract_data["extract_data"][
+                                "data"
+                            ] = next_page_data_list
+                            data_list.append(next_page_extract_data)
+                            handled_page_num_list.append(next_page_num)
+                        else:
+                            break
+                        count += 1
+                    except Exception as e:
+                        logger.error(f"Error in extracting data from next page: {e}")
+                        break
+
+        json_data_file = os.path.join(
+            self.output_data_json_folder, f"{self.doc_id}.json"
+        )
         with open(json_data_file, "w", encoding="utf-8") as f:
             json.dump(data_list, f, ensure_ascii=False, indent=4)
-        
+
         data_df = pd.DataFrame(data_list)
         data_df.reset_index(drop=True, inplace=True)
-        excel_data_file = os.path.join(self.output_data_excel_folder, f"{self.doc_id}.xlsx")
+        excel_data_file = os.path.join(
+            self.output_data_excel_folder, f"{self.doc_id}.xlsx"
+        )
         with pd.ExcelWriter(excel_data_file) as writer:
             data_df.to_excel(writer, sheet_name="extract_data", index=False)
-        
+
         return data_list
-    
+
+    def extract_data_by_page(
+        self,
+        page_num: int,
+        page_text: str,
+        page_datapoints: list,
+        need_exclude: bool = False,
+        exclude_data: list = None,
+    ) -> dict:
+        """
+        keys are
+        doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
+        """
+        logger.info(f"Extracting data from page {page_num}")
+        instructions = self.get_instructions_by_datapoints(
+            page_text, page_datapoints, need_exclude, exclude_data
+        )
+        response, with_error = chat(
+            instructions, response_format={"type": "json_object"}
+        )
+        if with_error:
+            logger.error(f"Error in extracting data from page {page_num}")
+            return {}
+        try:
+            data = json.loads(response)
+        except:
+            try:
+                data = json_repair.loads(response)
+            except:
+                data = {"data": []}
+        data_dict = {"doc_id": self.doc_id}
+        data_dict["page_index"] = page_num
+        data_dict["datapoints"] = ", ".join(page_datapoints)
+        data_dict["page_text"] = page_text
+        data_dict["instructions"] = instructions
+        data_dict["raw_answer"] = response
+        data_dict["extract_data"] = data
+        return data_dict
+
     def get_datapoints_by_page_num(self, page_num: int) -> list:
         datapoints = []
         for datapoint in self.datapoints:
             if page_num in self.datapoint_page_info[datapoint]:
                 datapoints.append(datapoint)
         return datapoints
-    
-    def get_instructions_by_datapoints(self, page_text: str, datapoints: list) -> str:
+
+    def get_instructions_by_datapoints(
+        self,
+        page_text: str,
+        datapoints: list,
+        need_exclude: bool = False,
+        exclude_data: list = None,
+    ) -> str:
         """
         Get instructions to extract data from the page by the datapoints
         Below is the instructions sections:
@@ -159,11 +274,11 @@ class DataExtraction:
         for datapoint in datapoints:
             datapoint_name = self.datapoint_name_config.get(datapoint, "")
             datapoint_name_list.append(datapoint_name)
-        
+
         summary = self.instructions_config.get("summary", "\n")
-        instructions.append(summary.format(', '.join(datapoint_name_list)))
+        instructions.append(summary.format(", ".join(datapoint_name_list)))
         instructions.append("\n")
-        
+
         instructions.append("Datapoints Reported name:\n")
         reported_name_info = self.instructions_config.get("reported_name", {})
         for datapoint in datapoints:
@@ -171,13 +286,15 @@ class DataExtraction:
             instructions.append(reported_name)
             instructions.append("\n")
         instructions.append("\n")
-        
+
         instructions.append("Data business features:\n")
-        data_business_features = self.instructions_config.get("data_business_features", {})
-        common = '\n'.join(data_business_features.get("common", []))
+        data_business_features = self.instructions_config.get(
+            "data_business_features", {}
+        )
+        common = "\n".join(data_business_features.get("common", []))
         instructions.append(common)
         instructions.append("\n")
-        
+
         instructions.append("Datapoints investment level:\n")
         investment_level_info = data_business_features.get("investment_level", {})
         for datapoint in datapoints:
@@ -185,7 +302,7 @@ class DataExtraction:
             instructions.append(investment_level)
             instructions.append("\n")
         instructions.append("\n")
-        
+
         instructions.append("Datapoints value range:\n")
         data_value_range_info = data_business_features.get("data_value_range", {})
         for datapoint in datapoints:
@@ -193,7 +310,7 @@ class DataExtraction:
             instructions.append(data_value_range)
             instructions.append("\n")
         instructions.append("\n")
-        
+
         special_rule_info = data_business_features.get("special_rule", {})
         with_special_rule_title = False
         for datapoint in datapoints:
@@ -202,11 +319,11 @@ class DataExtraction:
                 if not with_special_rule_title:
                     instructions.append("Special rule:\n")
                     with_special_rule_title = True
-                special_rule = '\n'.join(special_rule_list)
+                special_rule = "\n".join(special_rule_list)
                 instructions.append(special_rule)
                 instructions.append("\n\n")
         instructions.append("\n")
-        
+
         instructions.append("Special cases:\n")
         special_cases = self.instructions_config.get("special_cases", {})
         special_cases_common_list = special_cases.get("common", [])
@@ -215,10 +332,10 @@ class DataExtraction:
             instructions.append(title)
             instructions.append("\n")
             contents_list = special_cases_common.get("contents", [])
-            contents = '\n'.join(contents_list)
+            contents = "\n".join(contents_list)
             instructions.append(contents)
             instructions.append("\n\n")
-        
+
         for datapoint in datapoints:
             special_case_list = special_cases.get(datapoint, [])
             for special_case in special_case_list:
@@ -226,51 +343,69 @@ class DataExtraction:
                 instructions.append(title)
                 instructions.append("\n")
                 contents_list = special_case.get("contents", [])
-                contents = '\n'.join(contents_list)
+                contents = "\n".join(contents_list)
                 instructions.append(contents)
                 instructions.append("\n\n")
         instructions.append("\n")
-        
-        
+
         instructions.append("Output requirement:\n")
         output_requirement = self.instructions_config.get("output_requirement", {})
         output_requirement_common_list = output_requirement.get("common", [])
         instructions.append("\n".join(output_requirement_common_list))
         instructions.append("\n")
-        
+
         share_datapoint_value_example = {}
         share_level_config = output_requirement.get("share_level", {})
+
+        example_list = []
         for datapoint in datapoints:
             investment_level = self.datapoint_level_config.get(datapoint, "")
             if investment_level == "fund_level":
                 fund_level_example_list = output_requirement.get("fund_level", [])
                 for example in fund_level_example_list:
-                    instructions.append(example)
-                    instructions.append("\n")
-                instructions.append("\n")
+                    try:
+                        sub_example_list = json.loads(example)
+                    except:
+                        sub_example_list = json_repair.loads(example)
+                    example_list.extend(sub_example_list)
             elif investment_level == "share_level":
-                share_datapoint_value_example[datapoint] = share_level_config.get(f"{datapoint}_value", [])
-        
+                share_datapoint_value_example[datapoint] = share_level_config.get(
+                    f"{datapoint}_value", []
+                )
+
         share_datapoint_list = list(share_datapoint_value_example.keys())
+        instructions.append(f"Example:\n")
         if len(share_datapoint_list) > 0:
             fund_name_example_list = share_level_config.get("fund_name", [])
             share_name_example_list = share_level_config.get("share_name", [])
+
             for index in range(len(fund_name_example_list)):
-                example_dict = {"fund name": fund_name_example_list[index],
-                                "share name": share_name_example_list[index]}
+                example_dict = {
+                    "fund name": fund_name_example_list[index],
+                    "share name": share_name_example_list[index],
+                }
                 for share_datapoint in share_datapoint_list:
-                    share_datapoint_values = share_datapoint_value_example[share_datapoint]
+                    share_datapoint_values = share_datapoint_value_example[
+                        share_datapoint
+                    ]
                     if index < len(share_datapoint_values):
                         example_dict[share_datapoint] = share_datapoint_values[index]
-                instructions.append(f"Example {index + 1}:\n")
-                instructions.append(json.dumps(example_dict, ensure_ascii=False))
-                instructions.append("\n")
-                instructions.append("\n")
-        
-        end_list = self.instructions_config.get("end", [])
-        instructions.append('\n'.join(end_list))
+                example_list.append(example_dict)
+        example_data = {"data": example_list}
+        instructions.append(json.dumps(example_data, ensure_ascii=False, indent=4))
         instructions.append("\n")
+        instructions.append("\n")
+
+        end_list = self.instructions_config.get("end", [])
+        instructions.append("\n".join(end_list))
+        instructions.append("\n")
+
+        if need_exclude and exclude_data is not None and isinstance(exclude_data, list):
+            instructions.append("Please exclude the following data from the output:\n")
+            instructions.append(json.dumps(exclude_data, ensure_ascii=False, indent=4))
+            instructions.append("\n")
+            instructions.append("\n")
         instructions.append("Answer:\n")
-        
-        instructions_text = ''.join(instructions)
-        return instructions_text
\ No newline at end of file
+
+        instructions_text = "".join(instructions)
+        return instructions_text
diff --git a/instructions/data_extraction_prompts_config.json b/instructions/data_extraction_prompts_config.json
index e361c84..5cbe99c 100644
--- a/instructions/data_extraction_prompts_config.json
+++ b/instructions/data_extraction_prompts_config.json
@@ -135,7 +135,7 @@
     },
     "end": [
         "Only output JSON data.",
-        "Don't output the value which not exist in context, especiall for fund level datapoint: TOR.",
-        "If can't find share class name in context, please output empty JSON data: []"
+        "Don't output a value which does not exist in the context, especially for the fund level datapoint: TOR.",
+        "If you can't find the share class name in the context, please output empty JSON data: {\"data\": []}"
     ]
 }
\ No newline at end of file
diff --git a/main.py b/main.py
index 7b2a038..96d436f 100644
--- a/main.py
+++ b/main.py
@@ -335,15 +335,16 @@ if __name__ == "__main__":
     # )
 
     # test_auto_generate_instructions()
 
-    # doc_id = "294132333"
-    # extract_data(doc_id, pdf_folder)
     output_child_folder = r"/data/emea_ar/output/extract_data/docs/"
     output_total_folder = r"/data/emea_ar/output/extract_data/total/"
-    re_run = False
+    re_run = True
     batch_extract_data(pdf_folder, page_filter_ground_truth_file, output_child_folder, output_total_folder, special_doc_id_list, re_run)
+
+    # doc_id = "476492237"
+    # extract_data(doc_id, pdf_folder, output_child_folder, re_run)
diff --git a/prepare_data.py b/prepare_data.py
index 14d8c5e..bb45188 100644
--- a/prepare_data.py
+++ b/prepare_data.py
@@ -726,13 +726,31 @@ def pickup_document_from_top_100_providers():
         top_100_provider_document_file, sheet_name="all_data"
     )
+    top_100_provider_document_fund_count = pd.read_excel(
+        top_100_provider_document_file, sheet_name="doc_fund_count"
+    )
+    top_100_provider_document_fund_count.reset_index(drop=True, inplace=True)
+
     top_100_provider_document_share_count = pd.read_excel(
         top_100_provider_document_file, sheet_name="doc_share_class_count"
     )
     top_100_provider_document_share_count = \
         top_100_provider_document_share_count[top_100_provider_document_share_count["with_ar_data"] == True]
     top_100_provider_document_share_count.reset_index(drop=True, inplace=True)
-    
+
+    top_100_provider_document_share_count = pd.merge(
+        top_100_provider_document_share_count,
+        top_100_provider_document_fund_count,
+        on=["DocumentId"],
+        how="left",
+    )
+    top_100_provider_document_share_count = top_100_provider_document_share_count[
+        ["DocumentId", "CompanyId_x", "CompanyName_x", "fund_count", "share_class_count"]
+    ]
+    top_100_provider_document_share_count.rename(
+        columns={"CompanyId_x": "CompanyId"}, inplace=True
+    )
+
     # add a new column with name share_count_rank to top_100_provider_document_share_count by merge with provider_share_count
     top_100_provider_document_share_count = pd.merge(
         top_100_provider_document_share_count,
@@ -742,12 +760,11 @@ def pickup_document_from_top_100_providers():
     )
     # Keep columns: DocumentId, CompanyId, CompanyName, share_class_count_x, share_count_rank
     top_100_provider_document_share_count = top_100_provider_document_share_count[
-        ["DocumentId", "CompanyId", "CompanyName_x", "share_class_count_x", "share_count_rank"]
+        ["DocumentId", "CompanyId", "CompanyName", "fund_count", "share_class_count_x", "share_count_rank"]
     ]
     # rename column share_class_count_x to share_class_count
     top_100_provider_document_share_count.rename(
         columns={"share_class_count_x": "share_class_count",
-                 "CompanyName_x": "Company_Name",
                  "share_count_rank": "provider_share_count_rank"}, inplace=True
     )
     top_100_provider_document_share_count = top_100_provider_document_share_count.sort_values(
@@ -833,8 +850,8 @@ if __name__ == "__main__":
     random_small_document_data_file = (
         r"/data/emea_ar/basic_information/English/lux_english_ar_top_100_provider_random_small_document.xlsx"
     )
-    download_pdf(random_small_document_data_file, 'random_small_document', pdf_folder)
-    output_pdf_page_text(pdf_folder, output_folder)
+    # download_pdf(random_small_document_data_file, 'random_small_document', pdf_folder)
+    # output_pdf_page_text(pdf_folder, output_folder)
 
     # extract_pdf_table(pdf_folder, output_folder)
     # analyze_json_error()
@@ -846,4 +863,4 @@ if __name__ == "__main__":
     #     output_folder=basic_info_folder,
     #     )
     # statistics_document_fund_share_count(doc_mapping_from_top_100_provider_file)
-    # pickup_document_from_top_100_providers()
+    pickup_document_from_top_100_providers()
diff --git a/utils/gpt_utils.py b/utils/gpt_utils.py
index 628cd50..1f22ecd 100644
--- a/utils/gpt_utils.py
+++ b/utils/gpt_utils.py
@@ -69,6 +69,7 @@ def chat(
     api_key=os.getenv("OPENAI_API_KEY_GPT4o"),
     api_version=os.getenv("OPENAI_API_VERSION_GPT4o"),
     temperature: float = 0.0,
+    response_format: dict = None,
     image_file: str = None,
     image_base64: str = None,
 ):
@@ -108,18 +109,32 @@ def chat(
         try:
             if count > 0:
                 print(f"retrying the {count} time...")
-            response = client.chat.completions.create(
-                model=engine,
-                temperature=temperature,
-                max_tokens=max_tokens,
-                top_p=0.95,
-                frequency_penalty=0,
-                presence_penalty=0,
-                timeout=request_timeout,
-                stop=None,
-                messages=messages,
-                response_format={"type": "json_object"},
-            )
+            if response_format is None:
+                response = client.chat.completions.create(
+                    model=engine,
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                    top_p=0.95,
+                    frequency_penalty=0,
+                    presence_penalty=0,
+                    timeout=request_timeout,
+                    stop=None,
+                    messages=messages,
+                )
+            else:
+                # response_format={"type": "json_object"}
+                response = client.chat.completions.create(
+                    model=engine,
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                    top_p=0.95,
+                    frequency_penalty=0,
+                    presence_penalty=0,
+                    timeout=request_timeout,
+                    stop=None,
+                    messages=messages,
+                    response_format=response_format,
+                )
             return response.choices[0].message.content, False
         except Exception as e:
            error = str(e)