diff --git a/configuration/aus_prospectus/misc_config.json b/configuration/aus_prospectus/misc_config.json
index b576cc0..ab1750e 100644
--- a/configuration/aus_prospectus/misc_config.json
+++ b/configuration/aus_prospectus/misc_config.json
@@ -1,4 +1,4 @@
 {
-    "apply_pdf2html": true,
+    "apply_pdf2html": false,
     "apply_drilldown": false
 }
\ No newline at end of file
diff --git a/configuration/emea_ar/misc_config.json b/configuration/emea_ar/misc_config.json
index 3948e5e..ab1750e 100644
--- a/configuration/emea_ar/misc_config.json
+++ b/configuration/emea_ar/misc_config.json
@@ -1,4 +1,4 @@
 {
     "apply_pdf2html": false,
-    "apply_drilldown": true
+    "apply_drilldown": false
 }
\ No newline at end of file
diff --git a/core/data_extraction.py b/core/data_extraction.py
index 989c1e9..685ff15 100644
--- a/core/data_extraction.py
+++ b/core/data_extraction.py
@@ -5,9 +5,8 @@ import re
 import fitz
 import pandas as pd
 from traceback import print_exc
-from utils.gpt_utils import chat
+from utils.qwen_utils import chat
 from utils.pdf_util import PDFUtil
-from utils.sql_query_util import query_document_fund_mapping, query_investment_by_provider
 from utils.logger import logger
 from utils.biz_utils import add_slash_to_text_as_regex, clean_text, \
     get_most_similar_name, remove_abundant_data, replace_special_table_header
@@ -23,11 +22,20 @@ class DataExtraction:
         page_text_dict: dict,
         datapoint_page_info: dict,
         datapoints: list,
-        document_mapping_info_df: pd.DataFrame,
         extract_way: str = "text",
         output_image_folder: str = None,
+        text_model: str = "qwen-plus",
+        image_model: str = "qwen-vl-plus",
     ) -> None:
         self.doc_source = doc_source
+        if self.doc_source == "aus_prospectus":
+            self.document_type = 1
+        elif self.doc_source == "emea_ar":
+            self.document_type = 2
+        else:
+            raise ValueError(f"Invalid document source: {self.doc_source}")
+        self.text_model = text_model
+        self.image_model = image_model
         self.doc_id = doc_id
         self.pdf_file = pdf_file
         self.configuration_folder = f"./configuration/{doc_source}/"
@@ -46,26 +54,7 @@ class DataExtraction:
             self.page_text_dict = self.get_pdf_page_text_dict()
         else:
             self.page_text_dict = page_text_dict
-        if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
-            self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
-        else:
-            self.document_mapping_info_df = document_mapping_info_df
-        self.fund_name_list = self.document_mapping_info_df["FundName"].unique().tolist()
-
-        # get document type by DocumentType in self.document_mapping_info_df
-        self.document_type = int(self.document_mapping_info_df["DocumentType"].iloc[0])
-        self.investment_objective_pages = []
-        if self.document_type == 1:
-            self.investment_objective_pages = self.get_investment_objective_pages()
-
-        self.provider_mapping_df = self.get_provider_mapping()
-        if len(self.provider_mapping_df) == 0:
-            self.provider_fund_name_list = []
-        else:
-            self.provider_fund_name_list = (
-                self.provider_mapping_df["FundName"].unique().tolist()
-            )
 
         self.document_category, self.document_production = self.get_document_category_production()
         self.datapoint_page_info = self.get_datapoint_page_info(datapoint_page_info)
         self.page_nums_with_datapoints = self.get_page_nums_from_datapoint_page_info()
@@ -77,11 +66,14 @@ class DataExtraction:
         self.replace_table_header_config = self.get_replace_table_header_config()
         self.special_datapoint_feature_config = self.get_special_datapoint_feature_config()
         self.special_datapoint_feature = self.init_special_datapoint_feature()
+        self.investment_objective_pages = self.get_investment_objective_pages()
         self.datapoint_reported_name_config, self.non_english_reported_name_config = \
             self.get_datapoint_reported_name()
         self.extract_way = extract_way
         self.output_image_folder = output_image_folder
+
+
 
     def get_special_datapoint_feature_config(self) -> dict:
         special_datapoint_feature_config_file = os.path.join(self.configuration_folder,
                                                              "special_datapoint_feature.json")
@@ -118,17 +110,27 @@ class DataExtraction:
         if len(document_category_prompt) > 0:
             prompts = f"Context: \n{first_4_page_text}\n\Instructions: \n{document_category_prompt}"
             result, with_error = chat(
-                prompt=prompts, response_format={"type": "json_object"}, max_tokens=1000
+                prompt=prompts, text_model=self.text_model, image_model=self.image_model
             )
             response = result.get("response", "")
             if not with_error:
                 try:
+                    if response.startswith("```json"):
+                        response = response.replace("```json", "").replace("```", "").strip()
+                    if response.startswith("```JSON"):
+                        response = response.replace("```JSON", "").replace("```", "").strip()
+                    if response.startswith("```"):
+                        response = response.replace("```", "").strip()
                     data = json.loads(response)
                     document_category = data.get("document_category", None)
                     document_production = data.get("document_production", None)
                 except:
                     pass
-
+        if document_category is None or len(document_category) == 0:
+            print("Document category is None or empty, using default value: Super")
+            document_category = "Super"
+        if document_production is None or len(document_production) == 0:
+            document_production = "AUS"
         return document_category, document_production
 
     def get_objective_fund_name(self, page_text: str) -> str:
@@ -142,11 +144,17 @@ class DataExtraction:
         if len(objective_fund_name_prompt) > 0:
             prompts = f"Context: \n{page_text}\n\Instructions: \n{objective_fund_name_prompt}"
             result, with_error = chat(
-                prompt=prompts, response_format={"type": "json_object"}, max_tokens=1000
+                prompt=prompts, text_model=self.text_model, image_model=self.image_model
            )
             response = result.get("response", "")
             if not with_error:
                 try:
+                    if response.startswith("```json"):
+                        response = response.replace("```json", "").replace("```", "").strip()
+                    if response.startswith("```JSON"):
+                        response = response.replace("```JSON", "").replace("```", "").strip()
+                    if response.startswith("```"):
+                        response = response.replace("```", "").strip()
                     data = json.loads(response)
                     fund_name = data.get("fund_name", "")
                 except:
@@ -187,8 +195,8 @@ class DataExtraction:
 
         with open(language_config_file, "r", encoding="utf-8") as file:
             self.language_config = json.load(file)
-        self.language_id = self.document_mapping_info_df["Language"].iloc[0]
-        self.language = self.language_config.get(self.language_id, None)
+        self.language_id = "0L00000122"
+        self.language = "english"
         datapoint_reported_name_config_file = os.path.join(self.configuration_folder,
                                                            "datapoint_reported_name.json")
         all_datapoint_reported_name = {}
@@ -210,20 +218,6 @@ class DataExtraction:
                 reported_name_list.sort()
             datapoint_reported_name_config[datapoint] = reported_name_list
         return datapoint_reported_name_config, non_english_reported_name_config
-
-    def get_provider_mapping(self):
-        if len(self.document_mapping_info_df) == 0:
-            return pd.DataFrame()
-        provider_id_list = (
-            self.document_mapping_info_df["ProviderId"].unique().tolist()
-        )
-        provider_mapping_list = []
-        for provider_id in provider_id_list:
-            provider_mapping_list.append(query_investment_by_provider(provider_id, rerun=False))
-        provider_mapping_df = pd.concat(provider_mapping_list)
-        provider_mapping_df = provider_mapping_df.drop_duplicates()
-        provider_mapping_df.reset_index(drop=True, inplace=True)
-        return provider_mapping_df
 
     def get_pdf_image_base64(self, page_index: int) -> dict:
         pdf_util = PDFUtil(self.pdf_file)
@@ -557,8 +551,6 @@ class DataExtraction:
         """
         If some datapoint with production name, then each fund/ share class in the same document for the datapoint should be with same value.
         """
-        if len(self.fund_name_list) < 3:
-            return data_list, []
         raw_name_dict = self.get_raw_name_dict(data_list)
         raw_name_list = list(raw_name_dict.keys())
         if len(raw_name_list) < 3:
@@ -1125,11 +1117,17 @@ class DataExtraction:
         if len(compare_table_structure_prompts) > 0:
             prompts = f"Context: \ncurrent page contents:\n{current_page_text}\nnext page contents:\n{next_page_text}\nInstructions:\n{compare_table_structure_prompts}\n"
             result, with_error = chat(
-                prompt=prompts, response_format={"type": "json_object"}, max_tokens=100
+                prompt=prompts, text_model="qwen-plus", image_model="qwen-vl-plus"
             )
             response = result.get("response", "")
             if not with_error:
                 try:
+                    if response.startswith("```json"):
+                        response = response.replace("```json", "").replace("```", "").strip()
+                    if response.startswith("```JSON"):
+                        response = response.replace("```JSON", "").replace("```", "").strip()
+                    if response.startswith("```"):
+                        response = response.replace("```", "").strip()
                     data = json.loads(response)
                     answer = data.get("answer", "No")
                     if answer.lower() == "yes":
@@ -1300,9 +1298,6 @@ class DataExtraction:
         """
         logger.info(f"Extracting data from page {page_num}")
         if self.document_type == 1:
-            # pre_context = f"The document type is prospectus. \nThe fund names in this document are {', '.join(self.fund_name_list)}."
-            # if pre_context in page_text:
-            #     page_text = page_text.replace(pre_context, "\n").strip()
             pre_context = ""
             if len(self.investment_objective_pages) > 0:
                 # Get the page number of the most recent investment objective at the top of the current page.
@@ -1330,8 +1325,9 @@ class DataExtraction:
                 extract_way="text"
             )
             result, with_error = chat(
-                prompt=instructions, response_format={"type": "json_object"}
+                prompt=instructions, text_model=self.text_model, image_model=self.image_model
             )
+
             response = result.get("response", "")
             if with_error:
                 logger.error(f"Error in extracting tables from page")
@@ -1346,8 +1342,15 @@ class DataExtraction:
                 data_dict["prompt_token"] = result.get("prompt_token", 0)
                 data_dict["completion_token"] = result.get("completion_token", 0)
                 data_dict["total_token"] = result.get("total_token", 0)
+                data_dict["model"] = result.get("model", "")
                 return data_dict
             try:
+                if response.startswith("```json"):
+                    response = response.replace("```json", "").replace("```", "").strip()
+                if response.startswith("```JSON"):
+                    response = response.replace("```JSON", "").replace("```", "").strip()
+                if response.startswith("```"):
+                    response = response.replace("```", "").strip()
                 data = json.loads(response)
             except:
                 try:
@@ -1388,6 +1391,7 @@ class DataExtraction:
         data_dict["prompt_token"] = result.get("prompt_token", 0)
         data_dict["completion_token"] = result.get("completion_token", 0)
         data_dict["total_token"] = result.get("total_token", 0)
+        data_dict["model"] = result.get("model", "")
         return data_dict
 
     def extract_data_by_page_image(
@@ -1418,6 +1422,7 @@ class DataExtraction:
             data_dict["prompt_token"] = 0
             data_dict["completion_token"] = 0
             data_dict["total_token"] = 0
+            data_dict["model"] = self.image_model
             return data_dict
         else:
             if previous_page_last_fund is not None and len(previous_page_last_fund) > 0:
@@ -1463,7 +1468,7 @@ class DataExtraction:
                 extract_way="image"
             )
             result, with_error = chat(
-                prompt=instructions, response_format={"type": "json_object"}, image_base64=image_base64
+                prompt=instructions, text_model=self.text_model, image_model=self.image_model, image_base64=image_base64
             )
             response = result.get("response", "")
             if with_error:
@@ -1479,8 +1484,15 @@ class DataExtraction:
                 data_dict["prompt_token"] = result.get("prompt_token", 0)
                 data_dict["completion_token"] = result.get("completion_token", 0)
                 data_dict["total_token"] = result.get("total_token", 0)
+                data_dict["model"] = result.get("model", "")
                 return data_dict
             try:
+                if response.startswith("```json"):
+                    response = response.replace("```json", "").replace("```", "").strip()
+                if response.startswith("```JSON"):
+                    response = response.replace("```JSON", "").replace("```", "").strip()
+                if response.startswith("```"):
+                    response = response.replace("```", "").strip()
                 data = json.loads(response)
             except:
                 try:
@@ -1508,6 +1520,7 @@ class DataExtraction:
             data_dict["prompt_token"] = result.get("prompt_token", 0)
             data_dict["completion_token"] = result.get("completion_token", 0)
             data_dict["total_token"] = result.get("total_token", 0)
+            data_dict["model"] = result.get("model", "")
             return data_dict
 
     def get_image_text(self, page_num: int) -> str:
@@ -1515,13 +1528,19 @@ class DataExtraction:
         instructions = self.instructions_config.get("get_image_text", "\n")
         logger.info(f"Get text from image of page {page_num}")
         result, with_error = chat(
-            prompt=instructions, response_format={"type": "json_object"}, image_base64=image_base64
+            prompt=instructions, text_model=self.text_model, image_model=self.image_model, image_base64=image_base64
         )
         response = result.get("response", "")
         text = ""
         if with_error:
             logger.error(f"Can't get text from current image")
         try:
+            if response.startswith("```json"):
+                response = response.replace("```json", "").replace("```", "").strip()
+            if response.startswith("```JSON"):
+                response = response.replace("```JSON", "").replace("```", "").strip()
+            if response.startswith("```"):
+                response = response.replace("```", "").strip()
             data = json.loads(response)
         except:
             try:
@@ -1599,11 +1618,11 @@ class DataExtraction:
                 ter_search = re.search(ter_regex, page_text)
                 if ter_search is not None:
                     include_key_words = True
-                if not include_key_words:
-                    is_share_name = self.check_fund_name_as_share(raw_fund_name)
-                    if not is_share_name:
-                        remove_list.append(data)
-                        break
+                # if not include_key_words:
+                #     is_share_name = self.check_fund_name_as_share(raw_fund_name)
+                #     if not is_share_name:
+                #         remove_list.append(data)
+                #         break
                 data["share name"] = raw_fund_name
                 if data.get(key, "") == "":
                     data.pop(key)
@@ -1723,73 +1742,12 @@ class DataExtraction:
                     new_data[key] = value
                 new_data_list.append(new_data)
             extract_data_info["data"] = new_data_list
-        if page_text is not None and len(page_text) > 0:
-            try:
-                self.set_datapoint_feature_properties(new_data_list, page_text, page_num)
-            except Exception as e:
-                logger.error(f"Error in setting datapoint feature properties: {e}")
+        # if page_text is not None and len(page_text) > 0:
+        #     try:
+        #         self.set_datapoint_feature_properties(new_data_list, page_text, page_num)
+        #     except Exception as e:
+        #         logger.error(f"Error in setting datapoint feature properties: {e}")
         return extract_data_info
-
-    def set_datapoint_feature_properties(self, data_list: list, page_text: str, page_num: int) -> None:
-        for feature, properties in self.special_datapoint_feature_config.items():
-            if self.special_datapoint_feature.get(feature, {}).get("page_index", None) is not None:
-                continue
-            provider_ids = properties.get("provider_ids", [])
-            if len(provider_ids) > 0:
-                is_current_provider = False
-                doc_provider_list = self.document_mapping_info_df["ProviderId"].unique().tolist()
-                if len(doc_provider_list) > 0:
-                    for provider in provider_ids:
-                        if provider in doc_provider_list:
-                            is_current_provider = True
-                            break
-                if not is_current_provider:
-                    continue
-            detail_list = properties.get("details", [])
-            if len(detail_list) == 0:
-                continue
-            set_feature_property = False
-            for detail in detail_list:
-                regex_text_list = detail.get("regex_text", [])
-                if len(regex_text_list) == 0:
-                    continue
-                effective_datapoints = detail.get("effective_datapoints", [])
-                if len(effective_datapoints) == 0:
-                    continue
-                exclude_datapoints = detail.get("exclude_datapoints", [])
-
-                exist_effective_datapoints = False
-                exist_exclude_datapoints = False
-                for data_item in data_list:
-                    datapoints = [datapoint for datapoint in list(data_item.keys())
-                                  if datapoint in effective_datapoints]
-                    if len(datapoints) > 0:
-                        exist_effective_datapoints = True
-                    datapoints = [datapoint for datapoint in list(data_item.keys())
-                                  if datapoint in exclude_datapoints]
-                    if len(datapoints) > 0:
-                        exist_exclude_datapoints = True
-                    if exist_effective_datapoints and exist_exclude_datapoints:
-                        break
-
-                if not exist_effective_datapoints:
-                    continue
-                if exist_exclude_datapoints:
-                    continue
-                found_regex_text = False
-                for regex_text in regex_text_list:
-                    regex_search = re.search(regex_text, page_text, re.IGNORECASE)
-                    if regex_search is not None:
-                        found_regex_text = True
-                        break
-                if found_regex_text:
-                    if self.special_datapoint_feature[feature].get("page_index", None) is None:
-                        self.special_datapoint_feature[feature]["page_index"] = []
-                    self.special_datapoint_feature[feature]["datapoint"] = effective_datapoints[0]
-                    self.special_datapoint_feature[feature]["page_index"].append(page_num)
-                    set_feature_property = True
-            if set_feature_property:
-                break
 
     def split_multi_share_name(self, raw_share_name: str) -> list:
         """
@@ -1836,25 +1794,25 @@ class DataExtraction:
                 fund_name = f"{last_fund} {fund_feature}"
         return fund_name
 
-    def check_fund_name_as_share(self, fund_name: str) -> bool:
-        """
-        Check if the fund name is the same as share name
-        """
-        if len(fund_name) == 0 == 0:
-            return False
-        share_name_list = self.document_mapping_info_df["ShareClassName"].unique().tolist()
-        if len(share_name_list) == 0:
-            return False
-        max_similarity_name, max_similarity = get_most_similar_name(
-            text=fund_name,
-            name_list=share_name_list,
-            share_name=None,
-            fund_name=None,
-            matching_type="share",
-            process_cache=None)
-        if max_similarity >= 0.8:
-            return True
-        return False
+    # def check_fund_name_as_share(self, fund_name: str) -> bool:
+    #     """
+    #     Check if the fund name is the same as share name
+    #     """
+    #     if len(fund_name) == 0 == 0:
+    #         return False
+    #     share_name_list = self.document_mapping_info_df["ShareClassName"].unique().tolist()
+    #     if len(share_name_list) == 0:
+    #         return False
+    #     max_similarity_name, max_similarity = get_most_similar_name(
+    #         text=fund_name,
+    #         name_list=share_name_list,
+    #         share_name=None,
+    #         fund_name=None,
+    #         matching_type="share",
+    #         process_cache=None)
+    #     if max_similarity >= 0.8:
+    #         return True
+    #     return False
 
     def get_datapoints_by_page_num(self, page_num: int) -> list:
         datapoints = []
@@ -2165,13 +2123,6 @@ class DataExtraction:
         for datapoint in datapoints:
             investment_level = self.datapoint_level_config.get(datapoint, "")
             if investment_level == "fund_level":
-                # fund_level_example_list = output_requirement.get("fund_level", [])
-                # for example in fund_level_example_list:
-                #     try:
-                #         sub_example_list = json.loads(example)
-                #     except:
-                #         sub_example_list = json_repair.loads(example)
-                #     example_list.extend(sub_example_list)
                 fund_datapoint_value_example[datapoint] = fund_level_config.get(
                     f"{datapoint}_value", []
                 )
@@ -2228,131 +2179,4 @@ class DataExtraction:
             instructions.append("Answer:\n")
 
         instructions_text = "".join(instructions)
-        return instructions_text
-
-    # def chat_by_split_context(self,
-    #                           page_text: str,
-    #                           page_datapoints: list,
-    #                           need_exclude: bool,
-    #                           exclude_data: list) -> list:
-    #     """
-    #     If occur error, split the context to two parts and try to get data from the two parts
-    #     Relevant document: 503194284, page index 147
-    #     """
-    #     try:
-    #         logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
-    #         split_context = re.split(r"\n", page_text)
-    #         split_context = [text.strip() for text in split_context
-    #                          if len(text.strip()) > 0]
-    #         if len(split_context) < 10:
-    #             return {"data": []}
-
-    #         split_context_len = len(split_context)
-    #         top_10_context = split_context[:10]
-    #         rest_context = split_context[10:]
-    #         header = "\n".join(top_10_context)
-    #         half_len = split_context_len // 2
-    #         # the member of half_len should not start with number
-    #         # reverse iterate the list by half_len
-    #         half_len_list = [i for i in range(half_len)]
-
-    #         fund_name_line = ""
-    #         half_line = rest_context[half_len].strip()
-    #         max_similarity_fund_name, max_similarity = get_most_similar_name(
-    #             half_line, self.provider_fund_name_list, matching_type="fund"
-    #         )
-    #         if max_similarity < 0.2:
-    #             # get the fund name line text from the first half
-    #             for index in reversed(half_len_list):
-    #                 line_text = rest_context[index].strip()
-    #                 if len(line_text) == 0:
-    #                     continue
-    #                 line_text_split = line_text.split()
-    #                 if len(line_text_split) < 3:
-    #                     continue
-    #                 first_word = line_text_split[0]
-    #                 if first_word.lower() == "class":
-    #                     continue
-
-    #                 max_similarity_fund_name, max_similarity = get_most_similar_name(
-    #                     line_text, self.provider_fund_name_list, matching_type="fund"
-    #                 )
-    #                 if max_similarity >= 0.2:
-    #                     fund_name_line = line_text
-    #                     break
-    #         else:
-    #             fund_name_line = half_line
-    #             half_len += 1
-    #         if fund_name_line == "":
-    #             return {"data": []}
-
-    #         logger.info(f"Split first part from 0 to {half_len}")
-    #         split_first_part = "\n".join(split_context[:half_len])
-    #         first_part = '\n'.join(split_first_part)
-    #         first_instructions = self.get_instructions_by_datapoints(
-    #             first_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
-    #         )
-    #         response, with_error = chat(
-    #             first_instructions, response_format={"type": "json_object"}
-    #         )
-    #         first_part_data = {"data": []}
-    #         if not with_error:
-    #             try:
-    #                 first_part_data = json.loads(response)
-    #             except:
-    #                 first_part_data = json_repair.loads(response)
-
-    #         logger.info(f"Split second part from {half_len} to {split_context_len}")
-    #         split_second_part = "\n".join(split_context[half_len:])
-    #         second_part = header + "\n" + fund_name_line + "\n" + split_second_part
-    #         second_instructions = self.get_instructions_by_datapoints(
-    #             second_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
-    #         )
-    #         response, with_error = chat(
-    #             second_instructions, response_format={"type": "json_object"}
-    #         )
-    #         second_part_data = {"data": []}
-    #         if not with_error:
-    #             try:
-    #                 second_part_data = json.loads(response)
-    #             except:
-    #                 second_part_data = json_repair.loads(response)
-
-    #         first_part_data_list = first_part_data.get("data", [])
-    #         logger.info(f"First part data count: {len(first_part_data_list)}")
-    #         second_part_data_list = second_part_data.get("data", [])
-    #         logger.info(f"Second part data count: {len(second_part_data_list)}")
-    #         for first_data in first_part_data_list:
-    #             if first_data in second_part_data_list:
-    #                 second_part_data_list.remove(first_data)
-    #             else:
-    #                 # if the first part data is with same fund name and share name,
-    #                 # remove the second part data
-    #                 first_data_dp = [key for key in list(first_data.keys())
-    #                                  if key not in ["fund name", "share name"]]
-    #                 # order the data points
-    #                 first_data_dp.sort()
-    #                 first_fund_name = first_data.get("fund name", "")
-    #                 first_share_name = first_data.get("share name", "")
-    #                 if len(first_fund_name) > 0 and len(first_share_name) > 0:
-    #                     remove_second_list = []
-    #                     for second_data in second_part_data_list:
-    #                         second_fund_name = second_data.get("fund name", "")
-    #                         second_share_name = second_data.get("share name", "")
-    #                         if first_fund_name == second_fund_name and \
-    #                             first_share_name == second_share_name:
-    #                             second_data_dp = [key for key in list(second_data.keys())
-    #                                               if key not in ["fund name", "share name"]]
-    #                             second_data_dp.sort()
-    #                             if first_data_dp == second_data_dp:
-    #                                 remove_second_list.append(second_data)
-    #                     for remove_second in remove_second_list:
-    #                         if remove_second in second_part_data_list:
-    #                             second_part_data_list.remove(remove_second)
-
-    #         data_list = first_part_data_list + second_part_data_list
-    #         extract_data = {"data": data_list}
-    #         return extract_data
-    #     except Exception as e:
-    #         logger.error(f"Error in split context: {e}")
-    #         return {"data": []}
+        return instructions_text
\ No newline at end of file
diff --git a/core/page_filter.py b/core/page_filter.py
index ecffccc..5814a18 100644
--- a/core/page_filter.py
+++ b/core/page_filter.py
@@ -15,7 +15,6 @@ class FilterPages:
         self,
         doc_id: str,
         pdf_file: str,
-        document_mapping_info_df: pd.DataFrame,
         doc_source: str = "emea_ar",
         output_pdf_text_folder: str = r"/data/emea_ar/output/pdf_text/",
     ) -> None:
@@ -32,10 +31,7 @@ class FilterPages:
         else:
             self.apply_pdf2html = False
         self.page_text_dict = self.get_pdf_page_text_dict()
-        if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
-            self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
-        else:
-            self.document_mapping_info_df = document_mapping_info_df
+
         self.get_configuration_from_file()
         self.doc_info = self.get_doc_info()
         self.datapoint_config, self.datapoint_exclude_config = (
@@ -138,7 +134,7 @@ class FilterPages:
             self.datapoint_type_config = json.load(file)
 
     def get_doc_info(self) -> dict:
-        if len(self.document_mapping_info_df) == 0:
+        if self.doc_source == "emea_ar":
             return {
                 "effective_date": None,
                 "document_type": "ar",
@@ -146,22 +142,16 @@ class FilterPages:
                 "language": "english",
                 "domicile": "LUX",
             }
-        effective_date = self.document_mapping_info_df["EffectiveDate"].iloc[0]
-        document_type = self.document_mapping_info_df["DocumentType"].iloc[0]
-        if document_type in [4, 5] or self.doc_source == "emea_ar":
-            document_type = "ar"
-        elif document_type == 1 or self.doc_source == "aus_prospectus":
-            document_type = "prospectus"
-        language_id = self.document_mapping_info_df["Language"].iloc[0]
-        language = self.language_config.get(language_id, None)
-        domicile = self.document_mapping_info_df["Domicile"].iloc[0]
-        return {
-            "effective_date": effective_date,
-            "document_type": document_type,
-            "language_id": language_id,
-            "language": language,
-            "domicile": domicile,
-        }
+        elif self.doc_source == "aus_prospectus":
+            return {
+                "effective_date": None,
+                "document_type": "prospectus",
+                "language_id": "0L00000122",
+                "language": "english",
+                "domicile": "AUS",
+            }
+        else:
+            raise ValueError(f"Invalid doc_source: {self.doc_source}")
 
     def get_datapoint_config(self) -> dict:
         domicile = self.doc_info.get("domicile", None)
diff --git a/mini_main.py b/mini_main.py
index a078573..ee81adb 100644
--- a/mini_main.py
+++ b/mini_main.py
@@ -29,19 +29,17 @@ class EMEA_AR_Parsing:
         pdf_folder: str = r"/data/emea_ar/pdf/",
         output_pdf_text_folder: str = r"/data/emea_ar/output/pdf_text/",
         output_extract_data_folder: str = r"/data/emea_ar/output/extract_data/docs/",
-        output_mapping_data_folder: str = r"/data/emea_ar/output/mapping_data/docs/",
         extract_way: str = "text",
         drilldown_folder: str = r"/data/emea_ar/output/drilldown/",
-        compare_with_provider: bool = True
+        text_model: str = "qwen-plus",
+        image_model: str = "qwen-vl-plus",
     ) -> None:
         self.doc_id = doc_id
         self.doc_source = doc_source
         self.pdf_folder = pdf_folder
         os.makedirs(self.pdf_folder, exist_ok=True)
-        self.compare_with_provider = compare_with_provider
-
+
         self.pdf_file = self.download_pdf()
-        self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
 
         if extract_way is None or len(extract_way) == 0:
             extract_way = "text"
@@ -64,21 +62,9 @@ class EMEA_AR_Parsing:
         self.output_extract_data_folder = output_extract_data_folder
         os.makedirs(self.output_extract_data_folder, exist_ok=True)
 
-        if output_mapping_data_folder is None or len(output_mapping_data_folder) == 0:
-            output_mapping_data_folder = r"/data/emea_ar/output/mapping_data/docs/"
-        if not output_mapping_data_folder.endswith("/"):
-            output_mapping_data_folder = f"{output_mapping_data_folder}/"
-        if extract_way is not None and len(extract_way) > 0:
-            output_mapping_data_folder = (
-                f"{output_mapping_data_folder}by_{extract_way}/"
-            )
-        self.output_mapping_data_folder = output_mapping_data_folder
-        os.makedirs(self.output_mapping_data_folder, exist_ok=True)
-
         self.filter_pages = FilterPages(
             self.doc_id,
             self.pdf_file,
-            self.document_mapping_info_df,
             self.doc_source,
             output_pdf_text_folder,
         )
@@ -100,6 +86,8 @@ class EMEA_AR_Parsing:
             self.apply_drilldown = misc_config.get("apply_drilldown", False)
         else:
             self.apply_drilldown = False
+        self.text_model = text_model
+        self.image_model = image_model
 
     def download_pdf(self) -> str:
         pdf_file = download_pdf_from_documents_warehouse(self.pdf_folder, self.doc_id)
@@ -144,9 +132,10 @@ class EMEA_AR_Parsing:
                 self.page_text_dict,
                 self.datapoint_page_info,
                 self.datapoints,
-                self.document_mapping_info_df,
                 extract_way=self.extract_way,
                 output_image_folder=self.output_extract_image_folder,
+                text_model=self.text_model,
+                image_model=self.image_model,
             )
             data_from_gpt = data_extraction.extract_data()
         except Exception as e:
@@ -266,70 +255,6 @@ class EMEA_AR_Parsing:
             logger.error(f"Error: {e}")
         return annotation_list
 
-    def mapping_data(self, data_from_gpt: list, re_run: bool = False) -> list:
-        if not re_run:
-            output_data_json_folder = os.path.join(
-                self.output_mapping_data_folder, "json/"
-            )
-            os.makedirs(output_data_json_folder, exist_ok=True)
-            json_file = os.path.join(output_data_json_folder, f"{self.doc_id}.json")
-            if os.path.exists(json_file):
-                logger.info(
-                    f"The fund/ share of this document: {self.doc_id} has been mapped, loading data from {json_file}"
-                )
-                with open(json_file, "r", encoding="utf-8") as f:
-                    doc_mapping_data = json.load(f)
-                if self.doc_source == "aus_prospectus":
-                    output_data_folder_splits = output_data_json_folder.split("output")
-                    if len(output_data_folder_splits) == 2:
-                        merged_data_folder = f'{output_data_folder_splits[0]}output/merged_data/docs/'
-                        os.makedirs(merged_data_folder, exist_ok=True)
-
-                        merged_data_json_folder = os.path.join(merged_data_folder, "json/")
-                        os.makedirs(merged_data_json_folder, exist_ok=True)
-
-                        merged_data_excel_folder = os.path.join(merged_data_folder, "excel/")
-                        os.makedirs(merged_data_excel_folder, exist_ok=True)
-
-                        merged_data_file = os.path.join(merged_data_json_folder, f"merged_{self.doc_id}.json")
-                        if os.path.exists(merged_data_file):
-                            with open(merged_data_file, "r", encoding="utf-8") as f:
-                                merged_data_list = json.load(f)
-                            return merged_data_list
-                        else:
-                            data_mapping = DataMapping(
-                                self.doc_id,
-                                self.datapoints,
-                                data_from_gpt,
-                                self.document_mapping_info_df,
-                                self.output_mapping_data_folder,
-                                self.doc_source,
-                                compare_with_provider=self.compare_with_provider
-                            )
-                            merged_data_list = data_mapping.merge_output_data_aus_prospectus(doc_mapping_data,
-                                                                                             merged_data_json_folder,
-                                                                                             merged_data_excel_folder)
-                            return merged_data_list
-                else:
-                    return doc_mapping_data
-        """
-        doc_id,
-        datapoints: list,
-        raw_document_data_list: list,
-        document_mapping_info_df: pd.DataFrame,
-        output_data_folder: str,
-        """
-        data_mapping = DataMapping(
-            self.doc_id,
-            self.datapoints,
-            data_from_gpt,
-            self.document_mapping_info_df,
-            self.output_mapping_data_folder,
-            self.doc_source,
-            compare_with_provider=self.compare_with_provider
-        )
-        return data_mapping.mapping_raw_data_entrance()
 
 def filter_pages(doc_id: str, pdf_folder: str, doc_source: str) -> None:
     logger.info(f"Filter EMEA AR PDF pages for doc_id: {doc_id}")
@@ -347,6 +272,8 @@ def extract_data(
     output_data_folder: str,
     extract_way: str = "text",
     re_run: bool = False,
+    text_model: str = "qwen-plus",
+    image_model: str = "qwen-vl-plus",
 ) -> None:
     logger.info(f"Extract EMEA AR data for doc_id: {doc_id}")
     emea_ar_parsing = EMEA_AR_Parsing(
@@ -355,6 +282,8 @@ def extract_data(
         pdf_folder=pdf_folder,
         output_extract_data_folder=output_data_folder,
         extract_way=extract_way,
+        text_model=text_model,
+        image_model=image_model,
     )
     data_from_gpt, annotation_list = emea_ar_parsing.extract_data(re_run)
     return data_from_gpt, annotation_list
@@ -368,6 +297,8 @@ def batch_extract_data(
     extract_way: str = "text",
     special_doc_id_list: list = None,
     re_run: bool = False,
+    text_model: str = "qwen-plus",
+    image_model: str = "qwen-vl-plus",
 ) -> None:
     pdf_files = glob(pdf_folder + "*.pdf")
     doc_list = []
@@ -391,6 +322,8 @@ def batch_extract_data(
                 output_data_folder=output_child_folder,
                 extract_way=extract_way,
                 re_run=re_run,
+                text_model=text_model,
+                image_model=image_model,
             )
             result_list.extend(data_from_gpt)
 
@@ -421,31 +354,35 @@ def test_translate_pdf():
 
 
 if __name__ == "__main__":
     os.environ["SSL_CERT_FILE"] = certifi.where()
-    doc_source = "aus_prospectus"
+    # doc_source = "aus_prospectus"
+    doc_source = "emea_ar"
     re_run = True
     extract_way = "text"
     if doc_source == "aus_prospectus":
-        special_doc_id_list = ["539266874"]
-        pdf_folder: str = r"/data/aus_prospectus/pdf/"
-        output_pdf_text_folder: str = r"/data/aus_prospectus/output/pdf_text/"
+        special_doc_id_list = ["412778803", "539266874"]
+        pdf_folder: str = r"./data/aus_prospectus/pdf/"
+        output_pdf_text_folder: str = r"./data/aus_prospectus/output/pdf_text/"
         output_child_folder: str = (
-            r"/data/aus_prospectus/output/extract_data/docs/"
+            r"./data/aus_prospectus/output/extract_data/docs/"
        )
         output_total_folder: str = (
-            r"/data/aus_prospectus/output/extract_data/total/"
+            r"./data/aus_prospectus/output/extract_data/total/"
        )
     elif doc_source == "emea_ar":
         special_doc_id_list = ["514636993"]
-        pdf_folder: str = r"/data/emea_ar/pdf/"
+        pdf_folder: str = r"./data/emea_ar/pdf/"
         output_child_folder: str = (
-            r"/data/emea_ar/output/extract_data/docs/"
+            r"./data/emea_ar/output/extract_data/docs/"
        )
         output_total_folder: str = (
-            r"/data/emea_ar/output/extract_data/total/"
+            r"./data/emea_ar/output/extract_data/total/"
        )
     else:
         raise ValueError(f"Invalid doc_source: {doc_source}")
+    # text_model = "qwen-plus"
+    text_model = "qwen-max"
+    image_model = "qwen-vl-plus"
     batch_extract_data(
         pdf_folder=pdf_folder,
         doc_source=doc_source,
@@ -454,6 +391,8 @@ if __name__ == "__main__":
         extract_way=extract_way,
         special_doc_id_list=special_doc_id_list,
         re_run=re_run,
+        text_model=text_model,
+        image_model=image_model,
     )
 
 
diff --git a/utils/qwen_utils.py b/utils/qwen_utils.py
new file mode 100644
index 0000000..610a89f
--- /dev/null
+++ b/utils/qwen_utils.py
@@ -0,0 +1,152 @@
+import requests
+import json
+import os
+from bs4 import BeautifulSoup
+import time
+from time import sleep
+from datetime import datetime
+import pytz
+import pandas as pd
+import dashscope
+import dotenv
+import base64
+dotenv.load_dotenv()
+
+
+ali_api_key = os.getenv("ALI_API_KEY_QWEN")
+
+
+def chat(
+    prompt: str,
+    text_model: str = "qwen-plus",
+    image_model: str = "qwen-vl-plus",
+    image_file: str = None,
+    image_base64: str = None,
+    enable_search: bool = False,
+):
+    try:
+        token = 0
+        if (
+            image_base64 is None
+            and image_file is not None
+            and len(image_file) > 0
+            and os.path.exists(image_file)
+        ):
+            image_base64 = encode_image(image_file)
+
+        use_image_model = False
+        response = None
+        if image_base64 is not None and len(image_base64) > 0:
+            use_image_model = True
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"text": prompt},
+                        {
+                            "image": f"data:image/png;base64,{image_base64}",
+                        },
+                    ],
+                }
+            ]
+            count = 0
+            while count < 3:
+                try:
+                    print(f"Calling Aliyun Qwen model, attempt: {count + 1}")
+                    response = dashscope.MultiModalConversation.call(
+                        api_key=ali_api_key,
+                        model=image_model,
+                        messages=messages,
+                    )
+                    if response.status_code == 200:
+                        break
+                    else:
+                        print(f"Failed to call Aliyun Qwen model: {response.code} {response.message}")
+                        count += 1
+                        sleep(2)
+                except Exception as e:
+                    print(f"Failed to call Aliyun Qwen model: {e}")
+                    count += 1
+                    sleep(2)
+            if response is not None and response.status_code == 200:
+                image_text = (
+                    response.get("output", {})
+                    .get("choices", [])[0]
+                    .get("message", {})
+                    .get("content", "")
+                )
+                temp_image_text = ""
+                if isinstance(image_text, list):
+                    for item in image_text:
+                        if isinstance(item, dict):
+                            temp_image_text += item.get("text", "") + "\n\n"
+                        elif isinstance(item, str):
+                            temp_image_text += item + "\n\n"
+                        else:
+                            pass
+                response_contents = temp_image_text.strip()
+                token = response.get("usage", {}).get("total_tokens", 0)
+            else:
+                error_info = f"{response.code} {response.message}" if response is not None else "no response"
+                response_contents = f"{error_info} Unable to analyze the image"
+                token = 0
+        else:
+            messages = [{"role": "user", "content": prompt}]
+            count = 0
+            while count < 3:
+                try:
+                    print(f"Calling Aliyun Qwen model, attempt: {count + 1}")
+                    response = dashscope.Generation.call(
+                        api_key=ali_api_key,
+                        model=text_model,
+                        messages=messages,
+                        enable_search=enable_search,
+                        search_options={"forced_search": enable_search},  # force web search
+                        result_format="message",
+                    )
+                    if response.status_code == 200:
+                        break
+                    else:
+                        print(f"Failed to call Aliyun Qwen model: {response.code} {response.message}")
+                        count += 1
+                        sleep(2)
+                except Exception as e:
+                    print(f"Failed to call Aliyun Qwen model: {e}")
+                    count += 1
+                    sleep(2)
+
+            # Get the token usage from the response
+            if response is not None and response.status_code == 200:
+                response_contents = (
+                    response.get("output", {})
+                    .get("choices", [])[0]
+                    .get("message", {})
+                    .get("content", "")
+                )
+                token = response.get("usage", {}).get("total_tokens", 0)
+            else:
+                error_info = f"{response.code} {response.message}" if response is not None else "no response"
+                response_contents = error_info
+                token = 0
+        result = {}
+        if use_image_model:
+            result["model"] = image_model
+        else:
+            result["model"] = text_model
+        result["response"] = response_contents
+        usage = response.get("usage", {}) if response is not None else {}
+        result["prompt_token"] = usage.get("input_tokens", 0)
+        result["completion_token"] = usage.get("output_tokens", 0)
+        result["total_token"] = token
+        sleep(2)
+        return result, False
+    except Exception as e:
+        print(f"Failed to call Aliyun Qwen model: {e}")
+        return {}, True
+
+
+def encode_image(image_path: str):
+    if image_path is None or len(image_path) == 0 or not os.path.exists(image_path):
+        return None
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode("utf-8")
\ No newline at end of file
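
Review note: the same six-line fence-stripping block is now repeated ahead of every `json.loads(response)` call in core/data_extraction.py (six call sites in this diff). A follow-up could fold it into a single helper — a minimal sketch under that assumption is below; the name `strip_code_fences` and its placement in utils/qwen_utils.py are hypothetical, not part of this change.

````python
def strip_code_fences(response: str) -> str:
    """Strip a leading ```json / ```JSON / ``` fence and any closing fence.

    Hypothetical helper mirroring the blocks repeated in DataExtraction:
    it drops the opening marker, removes remaining ``` fences, and trims
    surrounding whitespace.
    """
    response = response.strip()
    for fence in ("```json", "```JSON", "```"):
        if response.startswith(fence):
            return response.replace(fence, "").replace("```", "").strip()
    return response
````

Each call site could then shrink to `data = json.loads(strip_code_fences(response))`.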
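Review note: for anyone exercising the new wrapper locally, a minimal smoke test of `utils.qwen_utils.chat` might look like the sketch below. It assumes `ALI_API_KEY_QWEN` is set in `.env`, and the image path is illustrative; the return shape (a `response` string plus `model` and token fields, and an error flag) is the one DataExtraction now records per page.

```python
# Minimal smoke test for the new utils/qwen_utils.chat wrapper.
from utils.qwen_utils import chat, encode_image

# Text-only call: routed to dashscope.Generation using text_model.
result, with_error = chat(
    prompt="Reply with one word: what is 2 + 2?",
    text_model="qwen-plus",
)
if not with_error:
    # The wrapper reports which model served the request plus token usage.
    print(result["model"], result["total_token"])
    print(result["response"])

# Passing image_base64 (or image_file) switches the call to
# dashscope.MultiModalConversation using image_model instead.
image_base64 = encode_image("./data/aus_prospectus/sample_page.png")  # illustrative path
if image_base64:
    result, with_error = chat(
        prompt="Describe the table on this page.",
        image_model="qwen-vl-plus",
        image_base64=image_base64,
    )
```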
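Review note: `FilterPages.get_doc_info` no longer consults `document_mapping_info_df`, so its result is fully determined by `doc_source`. A quick duck-typed check can pin down the new hard-coded paths without running the heavy `FilterPages` constructor — the stub class below is hypothetical, and works only because `get_doc_info` now reads nothing but `self.doc_source`:

```python
from core.page_filter import FilterPages

class _DocSourceStub:
    """Hypothetical stand-in: get_doc_info only reads self.doc_source now."""
    def __init__(self, doc_source: str):
        self.doc_source = doc_source

# The two supported sources return fixed metadata...
assert FilterPages.get_doc_info(_DocSourceStub("emea_ar"))["domicile"] == "LUX"
assert FilterPages.get_doc_info(_DocSourceStub("aus_prospectus"))["document_type"] == "prospectus"

# ...and anything else now raises instead of falling back to the mapping table.
try:
    FilterPages.get_doc_info(_DocSourceStub("unknown"))
except ValueError as e:
    print(e)  # Invalid doc_source: unknown
```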