Update to apply Ali Qwen as demo

This commit is contained in:
blade 2025-11-11 13:33:57 +08:00
parent 255752c848
commit ea81197bcd
6 changed files with 291 additions and 390 deletions

View File

@@ -1,4 +1,4 @@
{
-    "apply_pdf2html": true,
+    "apply_pdf2html": false,
    "apply_drilldown": false
}

View File

@@ -1,4 +1,4 @@
{
    "apply_pdf2html": false,
-    "apply_drilldown": true
+    "apply_drilldown": false
}
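Both files above are the per-source misc configuration whose toggles the parsing entry point reads later in this commit (apply_pdf2html and apply_drilldown via misc_config.get(...)). A minimal sketch of how such a toggle file might be loaded, assuming a file named misc_config.json under the configuration folder (the file name and path are assumptions, not shown in this diff):

import json
import os


def load_misc_config(doc_source: str, config_root: str = "./configuration") -> dict:
    # Hypothetical helper for illustration; the repo may load this file differently.
    config_file = os.path.join(config_root, doc_source, "misc_config.json")
    defaults = {"apply_pdf2html": False, "apply_drilldown": False}
    if not os.path.exists(config_file):
        return defaults
    with open(config_file, "r", encoding="utf-8") as f:
        return {**defaults, **json.load(f)}


misc_config = load_misc_config("emea_ar")
apply_pdf2html = misc_config.get("apply_pdf2html", False)
apply_drilldown = misc_config.get("apply_drilldown", False)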

View File

@@ -5,9 +5,8 @@ import re
import fitz
import pandas as pd
from traceback import print_exc
-from utils.gpt_utils import chat
+from utils.qwen_utils import chat
from utils.pdf_util import PDFUtil
-from utils.sql_query_util import query_document_fund_mapping, query_investment_by_provider
from utils.logger import logger
from utils.biz_utils import add_slash_to_text_as_regex, clean_text, \
    get_most_similar_name, remove_abundant_data, replace_special_table_header
@@ -23,11 +22,20 @@ class DataExtraction:
        page_text_dict: dict,
        datapoint_page_info: dict,
        datapoints: list,
-        document_mapping_info_df: pd.DataFrame,
        extract_way: str = "text",
        output_image_folder: str = None,
+        text_model: str = "qwen-plus",
+        image_model: str = "qwen-vl-plus",
    ) -> None:
        self.doc_source = doc_source
+        if self.doc_source == "aus_prospectus":
+            self.document_type = 1
+        elif self.doc_source == "emea_ar":
+            self.document_type = 2
+        else:
+            raise ValueError(f"Invalid document source: {self.doc_source}")
+        self.text_model = text_model
+        self.image_model = image_model
        self.doc_id = doc_id
        self.pdf_file = pdf_file
        self.configuration_folder = f"./configuration/{doc_source}/"
@@ -46,26 +54,7 @@ class DataExtraction:
            self.page_text_dict = self.get_pdf_page_text_dict()
        else:
            self.page_text_dict = page_text_dict
-        if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
-            self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
-        else:
-            self.document_mapping_info_df = document_mapping_info_df
-        self.fund_name_list = self.document_mapping_info_df["FundName"].unique().tolist()
-        # get document type by DocumentType in self.document_mapping_info_df
-        self.document_type = int(self.document_mapping_info_df["DocumentType"].iloc[0])
-        self.investment_objective_pages = []
-        if self.document_type == 1:
-            self.investment_objective_pages = self.get_investment_objective_pages()
-        self.provider_mapping_df = self.get_provider_mapping()
-        if len(self.provider_mapping_df) == 0:
-            self.provider_fund_name_list = []
-        else:
-            self.provider_fund_name_list = (
-                self.provider_mapping_df["FundName"].unique().tolist()
-            )
        self.document_category, self.document_production = self.get_document_category_production()
        self.datapoint_page_info = self.get_datapoint_page_info(datapoint_page_info)
        self.page_nums_with_datapoints = self.get_page_nums_from_datapoint_page_info()
@@ -77,11 +66,14 @@ class DataExtraction:
        self.replace_table_header_config = self.get_replace_table_header_config()
        self.special_datapoint_feature_config = self.get_special_datapoint_feature_config()
        self.special_datapoint_feature = self.init_special_datapoint_feature()
+        self.investment_objective_pages = self.get_investment_objective_pages()
        self.datapoint_reported_name_config, self.non_english_reported_name_config = \
            self.get_datapoint_reported_name()
        self.extract_way = extract_way
        self.output_image_folder = output_image_folder

    def get_special_datapoint_feature_config(self) -> dict:
        special_datapoint_feature_config_file = os.path.join(self.configuration_folder, "special_datapoint_feature.json")
@@ -118,17 +110,27 @@ class DataExtraction:
        if len(document_category_prompt) > 0:
            prompts = f"Context: \n{first_4_page_text}\n\Instructions: \n{document_category_prompt}"
            result, with_error = chat(
-                prompt=prompts, response_format={"type": "json_object"}, max_tokens=1000
+                prompt=prompts, text_model=self.text_model, image_model=self.image_model
            )
            response = result.get("response", "")
            if not with_error:
                try:
+                    if response.startswith("```json"):
+                        response = response.replace("```json", "").replace("```", "").strip()
+                    if response.startswith("```JSON"):
+                        response = response.replace("```JSON", "").replace("```", "").strip()
+                    if response.startswith("```"):
+                        response = response.replace("```", "").strip()
                    data = json.loads(response)
                    document_category = data.get("document_category", None)
                    document_production = data.get("document_production", None)
                except:
                    pass
+        if document_category is None or len(document_category) == 0:
+            print(f"Document category is None or empty, use default value: Super")
+            document_category = "Super"
+        if document_production is None or len(document_production) == 0:
+            document_production = "AUS"
        return document_category, document_production
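The same fenced-JSON cleanup (```json, ```JSON, bare ```) is now repeated before every json.loads call in this file, because the Qwen calls no longer use the response_format={"type": "json_object"} guarantee and may wrap their JSON in a markdown fence. A small helper could factor the pattern out; this is only a sketch (strip_code_fence and parse_model_json are hypothetical names; json_repair already appears elsewhere in this file as a fallback parser):

import json

import json_repair


def strip_code_fence(response: str) -> str:
    # Remove a leading markdown fence such as ```json, ```JSON or a bare ``` from a model reply.
    for fence in ("```json", "```JSON", "```"):
        if response.startswith(fence):
            return response.replace(fence, "").replace("```", "").strip()
    return response


def parse_model_json(response: str) -> dict:
    # Parse the reply as JSON, tolerating fences and minor syntax damage.
    cleaned = strip_code_fence(response)
    try:
        return json.loads(cleaned)
    except Exception:
        return json_repair.loads(cleaned)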
    def get_objective_fund_name(self, page_text: str) -> str:

@@ -142,11 +144,17 @@ class DataExtraction:
        if len(objective_fund_name_prompt) > 0:
            prompts = f"Context: \n{page_text}\n\Instructions: \n{objective_fund_name_prompt}"
            result, with_error = chat(
-                prompt=prompts, response_format={"type": "json_object"}, max_tokens=1000
+                prompt=prompts, text_model=self.text_model, image_model=self.image_model
            )
            response = result.get("response", "")
            if not with_error:
                try:
+                    if response.startswith("```json"):
+                        response = response.replace("```json", "").replace("```", "").strip()
+                    if response.startswith("```JSON"):
+                        response = response.replace("```JSON", "").replace("```", "").strip()
+                    if response.startswith("```"):
+                        response = response.replace("```", "").strip()
                    data = json.loads(response)
                    fund_name = data.get("fund_name", "")
                except:
@@ -187,8 +195,8 @@ class DataExtraction:
        with open(language_config_file, "r", encoding="utf-8") as file:
            self.language_config = json.load(file)
-        self.language_id = self.document_mapping_info_df["Language"].iloc[0]
-        self.language = self.language_config.get(self.language_id, None)
+        self.language_id = "0L00000122"
+        self.language = "english"
        datapoint_reported_name_config_file = os.path.join(self.configuration_folder, "datapoint_reported_name.json")
        all_datapoint_reported_name = {}
@@ -210,20 +218,6 @@ class DataExtraction:
            reported_name_list.sort()
            datapoint_reported_name_config[datapoint] = reported_name_list
        return datapoint_reported_name_config, non_english_reported_name_config

-    def get_provider_mapping(self):
-        if len(self.document_mapping_info_df) == 0:
-            return pd.DataFrame()
-        provider_id_list = (
-            self.document_mapping_info_df["ProviderId"].unique().tolist()
-        )
-        provider_mapping_list = []
-        for provider_id in provider_id_list:
-            provider_mapping_list.append(query_investment_by_provider(provider_id, rerun=False))
-        provider_mapping_df = pd.concat(provider_mapping_list)
-        provider_mapping_df = provider_mapping_df.drop_duplicates()
-        provider_mapping_df.reset_index(drop=True, inplace=True)
-        return provider_mapping_df

    def get_pdf_image_base64(self, page_index: int) -> dict:
        pdf_util = PDFUtil(self.pdf_file)
@@ -557,8 +551,6 @@ class DataExtraction:
        """
        If some datapoint with production name, then each fund/ share class in the same document for the datapoint should be with same value.
        """
-        if len(self.fund_name_list) < 3:
-            return data_list, []
        raw_name_dict = self.get_raw_name_dict(data_list)
        raw_name_list = list(raw_name_dict.keys())
        if len(raw_name_list) < 3:
@@ -1125,11 +1117,17 @@ class DataExtraction:
        if len(compare_table_structure_prompts) > 0:
            prompts = f"Context: \ncurrent page contents:\n{current_page_text}\nnext page contents:\n{next_page_text}\nInstructions:\n{compare_table_structure_prompts}\n"
            result, with_error = chat(
-                prompt=prompts, response_format={"type": "json_object"}, max_tokens=100
+                prompt=prompts, text_model="qwen-plus", image_model="qwen-vl-plus"
            )
            response = result.get("response", "")
            if not with_error:
                try:
+                    if response.startswith("```json"):
+                        response = response.replace("```json", "").replace("```", "").strip()
+                    if response.startswith("```JSON"):
+                        response = response.replace("```JSON", "").replace("```", "").strip()
+                    if response.startswith("```"):
+                        response = response.replace("```", "").strip()
                    data = json.loads(response)
                    answer = data.get("answer", "No")
                    if answer.lower() == "yes":
@@ -1300,9 +1298,6 @@ class DataExtraction:
        """
        logger.info(f"Extracting data from page {page_num}")
        if self.document_type == 1:
-            # pre_context = f"The document type is prospectus. \nThe fund names in this document are {', '.join(self.fund_name_list)}."
-            # if pre_context in page_text:
-            #     page_text = page_text.replace(pre_context, "\n").strip()
            pre_context = ""
            if len(self.investment_objective_pages) > 0:
                # Get the page number of the most recent investment objective at the top of the current page.
@@ -1330,8 +1325,9 @@ class DataExtraction:
                extract_way="text"
            )
            result, with_error = chat(
-                prompt=instructions, response_format={"type": "json_object"}
+                prompt=instructions, text_model=self.text_model, image_model=self.image_model
            )
            response = result.get("response", "")
            if with_error:
                logger.error(f"Error in extracting tables from page")
@@ -1346,8 +1342,15 @@ class DataExtraction:
                data_dict["prompt_token"] = result.get("prompt_token", 0)
                data_dict["completion_token"] = result.get("completion_token", 0)
                data_dict["total_token"] = result.get("total_token", 0)
+                data_dict["model"] = result.get("model", "")
                return data_dict
            try:
+                if response.startswith("```json"):
+                    response = response.replace("```json", "").replace("```", "").strip()
+                if response.startswith("```JSON"):
+                    response = response.replace("```JSON", "").replace("```", "").strip()
+                if response.startswith("```"):
+                    response = response.replace("```", "").strip()
                data = json.loads(response)
            except:
                try:
@@ -1388,6 +1391,7 @@ class DataExtraction:
        data_dict["prompt_token"] = result.get("prompt_token", 0)
        data_dict["completion_token"] = result.get("completion_token", 0)
        data_dict["total_token"] = result.get("total_token", 0)
+        data_dict["model"] = result.get("model", "")
        return data_dict

    def extract_data_by_page_image(
@@ -1418,6 +1422,7 @@ class DataExtraction:
            data_dict["prompt_token"] = 0
            data_dict["completion_token"] = 0
            data_dict["total_token"] = 0
+            data_dict["model"] = self.image_model
            return data_dict
        else:
            if previous_page_last_fund is not None and len(previous_page_last_fund) > 0:
@@ -1463,7 +1468,7 @@ class DataExtraction:
                extract_way="image"
            )
            result, with_error = chat(
-                prompt=instructions, response_format={"type": "json_object"}, image_base64=image_base64
+                prompt=instructions, text_model=self.text_model, image_model=self.image_model, image_base64=image_base64
            )
            response = result.get("response", "")
            if with_error:
@@ -1479,8 +1484,15 @@ class DataExtraction:
                data_dict["prompt_token"] = result.get("prompt_token", 0)
                data_dict["completion_token"] = result.get("completion_token", 0)
                data_dict["total_token"] = result.get("total_token", 0)
+                data_dict["model"] = result.get("model", "")
                return data_dict
            try:
+                if response.startswith("```json"):
+                    response = response.replace("```json", "").replace("```", "").strip()
+                if response.startswith("```JSON"):
+                    response = response.replace("```JSON", "").replace("```", "").strip()
+                if response.startswith("```"):
+                    response = response.replace("```", "").strip()
                data = json.loads(response)
            except:
                try:
@@ -1508,6 +1520,7 @@ class DataExtraction:
        data_dict["prompt_token"] = result.get("prompt_token", 0)
        data_dict["completion_token"] = result.get("completion_token", 0)
        data_dict["total_token"] = result.get("total_token", 0)
+        data_dict["model"] = result.get("model", "")
        return data_dict

    def get_image_text(self, page_num: int) -> str:
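Each extraction path now records which model produced the answer alongside the existing token counters. A hedged sketch of rolling those per-page dicts up into a per-model usage summary (the aggregation helper below is illustrative only, not part of the commit):

from collections import defaultdict


def summarize_token_usage(page_results: list) -> dict:
    # Sum prompt/completion/total tokens per model across the per-page data_dict results.
    usage = defaultdict(lambda: {"prompt_token": 0, "completion_token": 0, "total_token": 0})
    for data_dict in page_results:
        model = data_dict.get("model", "unknown")
        for key in ("prompt_token", "completion_token", "total_token"):
            usage[model][key] += data_dict.get(key, 0)
    return dict(usage)


pages = [
    {"model": "qwen-plus", "prompt_token": 1200, "completion_token": 300, "total_token": 1500},
    {"model": "qwen-vl-plus", "prompt_token": 900, "completion_token": 250, "total_token": 1150},
]
print(summarize_token_usage(pages))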
@@ -1515,13 +1528,19 @@ class DataExtraction:
        instructions = self.instructions_config.get("get_image_text", "\n")
        logger.info(f"Get text from image of page {page_num}")
        result, with_error = chat(
-            prompt=instructions, response_format={"type": "json_object"}, image_base64=image_base64
+            prompt=instructions, text_model=self.text_model, image_model=self.image_model, image_base64=image_base64
        )
        response = result.get("response", "")
        text = ""
        if with_error:
            logger.error(f"Can't get text from current image")
        try:
+            if response.startswith("```json"):
+                response = response.replace("```json", "").replace("```", "").strip()
+            if response.startswith("```JSON"):
+                response = response.replace("```JSON", "").replace("```", "").strip()
+            if response.startswith("```"):
+                response = response.replace("```", "").strip()
            data = json.loads(response)
        except:
            try:
@@ -1599,11 +1618,11 @@ class DataExtraction:
                    ter_search = re.search(ter_regex, page_text)
                    if ter_search is not None:
                        include_key_words = True
-                    if not include_key_words:
-                        is_share_name = self.check_fund_name_as_share(raw_fund_name)
-                        if not is_share_name:
-                            remove_list.append(data)
-                            break
+                    # if not include_key_words:
+                    #     is_share_name = self.check_fund_name_as_share(raw_fund_name)
+                    #     if not is_share_name:
+                    #         remove_list.append(data)
+                    #         break
                data["share name"] = raw_fund_name
                if data.get(key, "") == "":
                    data.pop(key)
@@ -1723,73 +1742,12 @@ class DataExtraction:
                    new_data[key] = value
                new_data_list.append(new_data)
            extract_data_info["data"] = new_data_list
-        if page_text is not None and len(page_text) > 0:
-            try:
-                self.set_datapoint_feature_properties(new_data_list, page_text, page_num)
-            except Exception as e:
-                logger.error(f"Error in setting datapoint feature properties: {e}")
+        # if page_text is not None and len(page_text) > 0:
+        #     try:
+        #         self.set_datapoint_feature_properties(new_data_list, page_text, page_num)
+        #     except Exception as e:
+        #         logger.error(f"Error in setting datapoint feature properties: {e}")
        return extract_data_info
def set_datapoint_feature_properties(self, data_list: list, page_text: str, page_num: int) -> None:
for feature, properties in self.special_datapoint_feature_config.items():
if self.special_datapoint_feature.get(feature, {}).get("page_index", None) is not None:
continue
provider_ids = properties.get("provider_ids", [])
if len(provider_ids) > 0:
is_current_provider = False
doc_provider_list = self.document_mapping_info_df["ProviderId"].unique().tolist()
if len(doc_provider_list) > 0:
for provider in provider_ids:
if provider in doc_provider_list:
is_current_provider = True
break
if not is_current_provider:
continue
detail_list = properties.get("details", [])
if len(detail_list) == 0:
continue
set_feature_property = False
for detail in detail_list:
regex_text_list = detail.get("regex_text", [])
if len(regex_text_list) == 0:
continue
effective_datapoints = detail.get("effective_datapoints", [])
if len(effective_datapoints) == 0:
continue
exclude_datapoints = detail.get("exclude_datapoints", [])
exist_effective_datapoints = False
exist_exclude_datapoints = False
for data_item in data_list:
datapoints = [datapoint for datapoint in list(data_item.keys())
if datapoint in effective_datapoints]
if len(datapoints) > 0:
exist_effective_datapoints = True
datapoints = [datapoint for datapoint in list(data_item.keys())
if datapoint in exclude_datapoints]
if len(datapoints) > 0:
exist_exclude_datapoints = True
if exist_effective_datapoints and exist_exclude_datapoints:
break
if not exist_effective_datapoints:
continue
if exist_exclude_datapoints:
continue
found_regex_text = False
for regex_text in regex_text_list:
regex_search = re.search(regex_text, page_text, re.IGNORECASE)
if regex_search is not None:
found_regex_text = True
break
if found_regex_text:
if self.special_datapoint_feature[feature].get("page_index", None) is None:
self.special_datapoint_feature[feature]["page_index"] = []
self.special_datapoint_feature[feature]["datapoint"] = effective_datapoints[0]
self.special_datapoint_feature[feature]["page_index"].append(page_num)
set_feature_property = True
if set_feature_property:
break
    def split_multi_share_name(self, raw_share_name: str) -> list:
        """
@@ -1836,25 +1794,25 @@ class DataExtraction:
            fund_name = f"{last_fund} {fund_feature}"
        return fund_name

-    def check_fund_name_as_share(self, fund_name: str) -> bool:
-        """
-        Check if the fund name is the same as share name
-        """
-        if len(fund_name) == 0 == 0:
-            return False
-        share_name_list = self.document_mapping_info_df["ShareClassName"].unique().tolist()
-        if len(share_name_list) == 0:
-            return False
-        max_similarity_name, max_similarity = get_most_similar_name(
-            text=fund_name,
-            name_list=share_name_list,
-            share_name=None,
-            fund_name=None,
-            matching_type="share",
-            process_cache=None)
-        if max_similarity >= 0.8:
-            return True
-        return False
+    # def check_fund_name_as_share(self, fund_name: str) -> bool:
+    #     """
+    #     Check if the fund name is the same as share name
+    #     """
+    #     if len(fund_name) == 0 == 0:
+    #         return False
+    #     share_name_list = self.document_mapping_info_df["ShareClassName"].unique().tolist()
+    #     if len(share_name_list) == 0:
+    #         return False
+    #     max_similarity_name, max_similarity = get_most_similar_name(
+    #         text=fund_name,
+    #         name_list=share_name_list,
+    #         share_name=None,
+    #         fund_name=None,
+    #         matching_type="share",
+    #         process_cache=None)
+    #     if max_similarity >= 0.8:
+    #         return True
+    #     return False

    def get_datapoints_by_page_num(self, page_num: int) -> list:
        datapoints = []
@@ -2165,13 +2123,6 @@ class DataExtraction:
        for datapoint in datapoints:
            investment_level = self.datapoint_level_config.get(datapoint, "")
            if investment_level == "fund_level":
-                # fund_level_example_list = output_requirement.get("fund_level", [])
-                # for example in fund_level_example_list:
-                #     try:
-                #         sub_example_list = json.loads(example)
-                #     except:
-                #         sub_example_list = json_repair.loads(example)
-                #     example_list.extend(sub_example_list)
                fund_datapoint_value_example[datapoint] = fund_level_config.get(
                    f"{datapoint}_value", []
                )
@@ -2228,131 +2179,4 @@ class DataExtraction:
        instructions.append("Answer:\n")
        instructions_text = "".join(instructions)
        return instructions_text
# def chat_by_split_context(self,
# page_text: str,
# page_datapoints: list,
# need_exclude: bool,
# exclude_data: list) -> list:
# """
# If occur error, split the context to two parts and try to get data from the two parts
# Relevant document: 503194284, page index 147
# """
# try:
# logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
# split_context = re.split(r"\n", page_text)
# split_context = [text.strip() for text in split_context
# if len(text.strip()) > 0]
# if len(split_context) < 10:
# return {"data": []}
# split_context_len = len(split_context)
# top_10_context = split_context[:10]
# rest_context = split_context[10:]
# header = "\n".join(top_10_context)
# half_len = split_context_len // 2
# # the member of half_len should not start with number
# # reverse iterate the list by half_len
# half_len_list = [i for i in range(half_len)]
# fund_name_line = ""
# half_line = rest_context[half_len].strip()
# max_similarity_fund_name, max_similarity = get_most_similar_name(
# half_line, self.provider_fund_name_list, matching_type="fund"
# )
# if max_similarity < 0.2:
# # get the fund name line text from the first half
# for index in reversed(half_len_list):
# line_text = rest_context[index].strip()
# if len(line_text) == 0:
# continue
# line_text_split = line_text.split()
# if len(line_text_split) < 3:
# continue
# first_word = line_text_split[0]
# if first_word.lower() == "class":
# continue
# max_similarity_fund_name, max_similarity = get_most_similar_name(
# line_text, self.provider_fund_name_list, matching_type="fund"
# )
# if max_similarity >= 0.2:
# fund_name_line = line_text
# break
# else:
# fund_name_line = half_line
# half_len += 1
# if fund_name_line == "":
# return {"data": []}
# logger.info(f"Split first part from 0 to {half_len}")
# split_first_part = "\n".join(split_context[:half_len])
# first_part = '\n'.join(split_first_part)
# first_instructions = self.get_instructions_by_datapoints(
# first_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
# )
# response, with_error = chat(
# first_instructions, response_format={"type": "json_object"}
# )
# first_part_data = {"data": []}
# if not with_error:
# try:
# first_part_data = json.loads(response)
# except:
# first_part_data = json_repair.loads(response)
# logger.info(f"Split second part from {half_len} to {split_context_len}")
# split_second_part = "\n".join(split_context[half_len:])
# second_part = header + "\n" + fund_name_line + "\n" + split_second_part
# second_instructions = self.get_instructions_by_datapoints(
# second_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
# )
# response, with_error = chat(
# second_instructions, response_format={"type": "json_object"}
# )
# second_part_data = {"data": []}
# if not with_error:
# try:
# second_part_data = json.loads(response)
# except:
# second_part_data = json_repair.loads(response)
# first_part_data_list = first_part_data.get("data", [])
# logger.info(f"First part data count: {len(first_part_data_list)}")
# second_part_data_list = second_part_data.get("data", [])
# logger.info(f"Second part data count: {len(second_part_data_list)}")
# for first_data in first_part_data_list:
# if first_data in second_part_data_list:
# second_part_data_list.remove(first_data)
# else:
# # if the first part data is with same fund name and share name,
# # remove the second part data
# first_data_dp = [key for key in list(first_data.keys())
# if key not in ["fund name", "share name"]]
# # order the data points
# first_data_dp.sort()
# first_fund_name = first_data.get("fund name", "")
# first_share_name = first_data.get("share name", "")
# if len(first_fund_name) > 0 and len(first_share_name) > 0:
# remove_second_list = []
# for second_data in second_part_data_list:
# second_fund_name = second_data.get("fund name", "")
# second_share_name = second_data.get("share name", "")
# if first_fund_name == second_fund_name and \
# first_share_name == second_share_name:
# second_data_dp = [key for key in list(second_data.keys())
# if key not in ["fund name", "share name"]]
# second_data_dp.sort()
# if first_data_dp == second_data_dp:
# remove_second_list.append(second_data)
# for remove_second in remove_second_list:
# if remove_second in second_part_data_list:
# second_part_data_list.remove(remove_second)
# data_list = first_part_data_list + second_part_data_list
# extract_data = {"data": data_list}
# return extract_data
# except Exception as e:
# logger.error(f"Error in split context: {e}")
# return {"data": []}

View File

@@ -15,7 +15,6 @@ class FilterPages:
        self,
        doc_id: str,
        pdf_file: str,
-        document_mapping_info_df: pd.DataFrame,
        doc_source: str = "emea_ar",
        output_pdf_text_folder: str = r"/data/emea_ar/output/pdf_text/",
    ) -> None:
@@ -32,10 +31,7 @@ class FilterPages:
        else:
            self.apply_pdf2html = False
        self.page_text_dict = self.get_pdf_page_text_dict()
-        if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
-            self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
-        else:
-            self.document_mapping_info_df = document_mapping_info_df
        self.get_configuration_from_file()
        self.doc_info = self.get_doc_info()
        self.datapoint_config, self.datapoint_exclude_config = (
@@ -138,7 +134,7 @@ class FilterPages:
            self.datapoint_type_config = json.load(file)

    def get_doc_info(self) -> dict:
-        if len(self.document_mapping_info_df) == 0:
+        if self.doc_source == "emea_ar":
            return {
                "effective_date": None,
                "document_type": "ar",
@@ -146,22 +142,16 @@ class FilterPages:
                "language": "english",
                "domicile": "LUX",
            }
-        effective_date = self.document_mapping_info_df["EffectiveDate"].iloc[0]
-        document_type = self.document_mapping_info_df["DocumentType"].iloc[0]
-        if document_type in [4, 5] or self.doc_source == "emea_ar":
-            document_type = "ar"
-        elif document_type == 1 or self.doc_source == "aus_prospectus":
-            document_type = "prospectus"
-        language_id = self.document_mapping_info_df["Language"].iloc[0]
-        language = self.language_config.get(language_id, None)
-        domicile = self.document_mapping_info_df["Domicile"].iloc[0]
-        return {
-            "effective_date": effective_date,
-            "document_type": document_type,
-            "language_id": language_id,
-            "language": language,
-            "domicile": domicile,
-        }
+        elif self.doc_source == "aus_prospectus":
+            return {
+                "effective_date": None,
+                "document_type": "prospectus",
+                "language_id": "0L00000122",
+                "language": "english",
+                "domicile": "AUS",
+            }
+        else:
+            raise ValueError(f"Invalid doc_source: {self.doc_source}")

    def get_datapoint_config(self) -> dict:
        domicile = self.doc_info.get("domicile", None)

View File

@@ -29,19 +29,17 @@ class EMEA_AR_Parsing:
        pdf_folder: str = r"/data/emea_ar/pdf/",
        output_pdf_text_folder: str = r"/data/emea_ar/output/pdf_text/",
        output_extract_data_folder: str = r"/data/emea_ar/output/extract_data/docs/",
-        output_mapping_data_folder: str = r"/data/emea_ar/output/mapping_data/docs/",
        extract_way: str = "text",
        drilldown_folder: str = r"/data/emea_ar/output/drilldown/",
-        compare_with_provider: bool = True
+        text_model: str = "qwen-plus",
+        image_model: str = "qwen-vl-plus",
    ) -> None:
        self.doc_id = doc_id
        self.doc_source = doc_source
        self.pdf_folder = pdf_folder
        os.makedirs(self.pdf_folder, exist_ok=True)
-        self.compare_with_provider = compare_with_provider
        self.pdf_file = self.download_pdf()
-        self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
        if extract_way is None or len(extract_way) == 0:
            extract_way = "text"
@@ -64,21 +62,9 @@ class EMEA_AR_Parsing:
        self.output_extract_data_folder = output_extract_data_folder
        os.makedirs(self.output_extract_data_folder, exist_ok=True)
-        if output_mapping_data_folder is None or len(output_mapping_data_folder) == 0:
-            output_mapping_data_folder = r"/data/emea_ar/output/mapping_data/docs/"
-        if not output_mapping_data_folder.endswith("/"):
-            output_mapping_data_folder = f"{output_mapping_data_folder}/"
-        if extract_way is not None and len(extract_way) > 0:
-            output_mapping_data_folder = (
-                f"{output_mapping_data_folder}by_{extract_way}/"
-            )
-        self.output_mapping_data_folder = output_mapping_data_folder
-        os.makedirs(self.output_mapping_data_folder, exist_ok=True)
        self.filter_pages = FilterPages(
            self.doc_id,
            self.pdf_file,
-            self.document_mapping_info_df,
            self.doc_source,
            output_pdf_text_folder,
        )
@@ -100,6 +86,8 @@ class EMEA_AR_Parsing:
            self.apply_drilldown = misc_config.get("apply_drilldown", False)
        else:
            self.apply_drilldown = False
+        self.text_model = text_model
+        self.image_model = image_model

    def download_pdf(self) -> str:
        pdf_file = download_pdf_from_documents_warehouse(self.pdf_folder, self.doc_id)
@@ -144,9 +132,10 @@ class EMEA_AR_Parsing:
                self.page_text_dict,
                self.datapoint_page_info,
                self.datapoints,
-                self.document_mapping_info_df,
                extract_way=self.extract_way,
                output_image_folder=self.output_extract_image_folder,
+                text_model=self.text_model,
+                image_model=self.image_model,
            )
            data_from_gpt = data_extraction.extract_data()
        except Exception as e:
@@ -266,70 +255,6 @@ class EMEA_AR_Parsing:
            logger.error(f"Error: {e}")
        return annotation_list
def mapping_data(self, data_from_gpt: list, re_run: bool = False) -> list:
if not re_run:
output_data_json_folder = os.path.join(
self.output_mapping_data_folder, "json/"
)
os.makedirs(output_data_json_folder, exist_ok=True)
json_file = os.path.join(output_data_json_folder, f"{self.doc_id}.json")
if os.path.exists(json_file):
logger.info(
f"The fund/ share of this document: {self.doc_id} has been mapped, loading data from {json_file}"
)
with open(json_file, "r", encoding="utf-8") as f:
doc_mapping_data = json.load(f)
if self.doc_source == "aus_prospectus":
output_data_folder_splits = output_data_json_folder.split("output")
if len(output_data_folder_splits) == 2:
merged_data_folder = f'{output_data_folder_splits[0]}output/merged_data/docs/'
os.makedirs(merged_data_folder, exist_ok=True)
merged_data_json_folder = os.path.join(merged_data_folder, "json/")
os.makedirs(merged_data_json_folder, exist_ok=True)
merged_data_excel_folder = os.path.join(merged_data_folder, "excel/")
os.makedirs(merged_data_excel_folder, exist_ok=True)
merged_data_file = os.path.join(merged_data_json_folder, f"merged_{self.doc_id}.json")
if os.path.exists(merged_data_file):
with open(merged_data_file, "r", encoding="utf-8") as f:
merged_data_list = json.load(f)
return merged_data_list
else:
data_mapping = DataMapping(
self.doc_id,
self.datapoints,
data_from_gpt,
self.document_mapping_info_df,
self.output_mapping_data_folder,
self.doc_source,
compare_with_provider=self.compare_with_provider
)
merged_data_list = data_mapping.merge_output_data_aus_prospectus(doc_mapping_data,
merged_data_json_folder,
merged_data_excel_folder)
return merged_data_list
else:
return doc_mapping_data
"""
doc_id,
datapoints: list,
raw_document_data_list: list,
document_mapping_info_df: pd.DataFrame,
output_data_folder: str,
"""
data_mapping = DataMapping(
self.doc_id,
self.datapoints,
data_from_gpt,
self.document_mapping_info_df,
self.output_mapping_data_folder,
self.doc_source,
compare_with_provider=self.compare_with_provider
)
return data_mapping.mapping_raw_data_entrance()
def filter_pages(doc_id: str, pdf_folder: str, doc_source: str) -> None:
    logger.info(f"Filter EMEA AR PDF pages for doc_id: {doc_id}")

@@ -347,6 +272,8 @@ def extract_data(
    output_data_folder: str,
    extract_way: str = "text",
    re_run: bool = False,
+    text_model: str = "qwen-plus",
+    image_model: str = "qwen-vl-plus",
) -> None:
    logger.info(f"Extract EMEA AR data for doc_id: {doc_id}")
    emea_ar_parsing = EMEA_AR_Parsing(
@@ -355,6 +282,8 @@ def extract_data(
        pdf_folder=pdf_folder,
        output_extract_data_folder=output_data_folder,
        extract_way=extract_way,
+        text_model=text_model,
+        image_model=image_model,
    )
    data_from_gpt, annotation_list = emea_ar_parsing.extract_data(re_run)
    return data_from_gpt, annotation_list
@@ -368,6 +297,8 @@ def batch_extract_data(
    extract_way: str = "text",
    special_doc_id_list: list = None,
    re_run: bool = False,
+    text_model: str = "qwen-plus",
+    image_model: str = "qwen-vl-plus",
) -> None:
    pdf_files = glob(pdf_folder + "*.pdf")
    doc_list = []
@@ -391,6 +322,8 @@ def batch_extract_data(
            output_data_folder=output_child_folder,
            extract_way=extract_way,
            re_run=re_run,
+            text_model=text_model,
+            image_model=image_model,
        )
        result_list.extend(data_from_gpt)
@@ -421,31 +354,35 @@ def test_translate_pdf():

if __name__ == "__main__":
    os.environ["SSL_CERT_FILE"] = certifi.where()
-    doc_source = "aus_prospectus"
+    # doc_source = "aus_prospectus"
+    doc_source = "emea_ar"
    re_run = True
    extract_way = "text"
    if doc_source == "aus_prospectus":
-        special_doc_id_list = ["539266874"]
-        pdf_folder: str = r"/data/aus_prospectus/pdf/"
-        output_pdf_text_folder: str = r"/data/aus_prospectus/output/pdf_text/"
+        special_doc_id_list = ["412778803", "539266874"]
+        pdf_folder: str = r"./data/aus_prospectus/pdf/"
+        output_pdf_text_folder: str = r"./data/aus_prospectus/output/pdf_text/"
        output_child_folder: str = (
-            r"/data/aus_prospectus/output/extract_data/docs/"
+            r"./data/aus_prospectus/output/extract_data/docs/"
        )
        output_total_folder: str = (
-            r"/data/aus_prospectus/output/extract_data/total/"
+            r"./data/aus_prospectus/output/extract_data/total/"
        )
    elif doc_source == "emea_ar":
        special_doc_id_list = ["514636993"]
-        pdf_folder: str = r"/data/emea_ar/pdf/"
+        pdf_folder: str = r"./data/emea_ar/pdf/"
        output_child_folder: str = (
-            r"/data/emea_ar/output/extract_data/docs/"
+            r"./data/emea_ar/output/extract_data/docs/"
        )
        output_total_folder: str = (
-            r"/data/emea_ar/output/extract_data/total/"
+            r"./data/emea_ar/output/extract_data/total/"
        )
    else:
        raise ValueError(f"Invalid doc_source: {doc_source}")
+    # text_model = "qwen-plus"
+    text_model = "qwen-max"
+    image_model = "qwen-vl-plus"
    batch_extract_data(
        pdf_folder=pdf_folder,
        doc_source=doc_source,
@@ -454,6 +391,8 @@ if __name__ == "__main__":
        extract_way=extract_way,
        special_doc_id_list=special_doc_id_list,
        re_run=re_run,
+        text_model=text_model,
+        image_model=image_model,
    )
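For a single document the same model selection flows through extract_data; a sketch of an equivalent one-off run, assuming the leading parameters are doc_id, doc_source and pdf_folder (only the trailing parameters are visible in this diff):

# Hypothetical single-document run mirroring the batch settings above.
data_from_gpt, annotation_list = extract_data(
    doc_id="514636993",
    doc_source="emea_ar",
    pdf_folder=r"./data/emea_ar/pdf/",
    output_data_folder=r"./data/emea_ar/output/extract_data/docs/",
    extract_way="text",
    re_run=True,
    text_model="qwen-max",
    image_model="qwen-vl-plus",
)
print(f"Extracted {len(data_from_gpt)} rows from document 514636993")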

utils/qwen_utils.py (new file, +148 lines)

View File

@@ -0,0 +1,148 @@
import requests
import json
import os
from bs4 import BeautifulSoup
import time
from time import sleep
from datetime import datetime
import pytz
import pandas as pd
import dashscope
import dotenv
import base64
dotenv.load_dotenv()
ali_api_key = os.getenv("ALI_API_KEY_QWEN")
def chat(
prompt: str,
text_model: str = "qwen-plus",
image_model: str = "qwen-vl-plus",
image_file: str = None,
image_base64: str = None,
enable_search: bool = False,
):
try:
token = 0
if (
image_base64 is None
and image_file is not None
and len(image_file) > 0
and os.path.exists(image_file)
):
image_base64 = encode_image(image_file)
use_image_model = False
if image_base64 is not None and len(image_base64) > 0:
use_image_model = True
messages = [
{
"role": "user",
"content": [
{"text": prompt},
{
"image": f"data:image/png;base64,{image_base64}",
},
],
}
]
count = 0
while count < 3:
try:
print(f"调用阿里云Qwen模型, 次数: {count + 1}")
response = dashscope.MultiModalConversation.call(
api_key=ali_api_key,
model=image_model,
messages=messages,
)
if response.status_code == 200:
break
else:
print(f"调用阿里云Qwen模型失败: {response.code} {response.message}")
count += 1
sleep(2)
except Exception as e:
print(f"调用阿里云Qwen模型失败: {e}")
count += 1
sleep(2)
if response.status_code == 200:
image_text = (
response.get("output", {})
.get("choices", [])[0]
.get("message", {})
.get("content", "")
)
temp_image_text = ""
if isinstance(image_text, list):
for item in image_text:
if isinstance(item, dict):
temp_image_text += item.get("text", "") + "\n\n"
elif isinstance(item, str):
temp_image_text += item + "\n\n"
else:
pass
response_contents = temp_image_text.strip()
token = response.get("usage", {}).get("total_tokens", 0)
else:
response_contents = f"{response.code} {response.message} 无法分析图片"
token = 0
else:
messages = [{"role": "user", "content": prompt}]
count = 0
while count < 3:
try:
print(f"调用阿里云Qwen模型, 次数: {count + 1}")
response = dashscope.Generation.call(
api_key=ali_api_key,
model=text_model,
messages=messages,
enable_search=enable_search,
search_options={"forced_search": enable_search},  # force web search when enable_search is True
result_format="message",
)
if response.status_code == 200:
break
else:
print(f"调用阿里云Qwen模型失败: {response.code} {response.message}")
count += 1
sleep(2)
except Exception as e:
print(f"调用阿里云Qwen模型失败: {e}")
count += 1
sleep(2)
# Get the token usage from the response
if response.status_code == 200:
response_contents = (
response.get("output", {})
.get("choices", [])[0]
.get("message", {})
.get("content", "")
)
token = response.get("usage", {}).get("total_tokens", 0)
else:
response_contents = f"{response.code} {response.message}"
token = 0
result = {}
if use_image_model:
result["model"] = image_model
else:
result["model"] = text_model
result["response"] = response_contents
result["prompt_token"] = response.get("usage", {}).get("input_tokens", 0)
result["completion_token"] = response.get("usage", {}).get("output_tokens", 0)
result["total_token"] = token
sleep(2)
return result, False
except Exception as e:
print(f"调用阿里云Qwen模型失败: {e}")
return {}, True
def encode_image(image_path: str):
if image_path is None or len(image_path) == 0 or not os.path.exists(image_path):
return None
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
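A short usage sketch for the new helper: a plain prompt is routed to dashscope.Generation with text_model, while passing image_base64 (or image_file) switches to dashscope.MultiModalConversation with image_model, as implemented above. It requires ALI_API_KEY_QWEN in the environment; the prompt and image path below are placeholders:

from utils.qwen_utils import chat, encode_image

# Text-only call: handled by the Generation endpoint with the text model.
result, with_error = chat(
    prompt="Return a JSON object with a single key 'answer' whose value is 'yes'.",
    text_model="qwen-plus",
)
if not with_error:
    print(result["model"], result["total_token"])
    print(result["response"])

# Image call: supplying image_base64 triggers the multimodal endpoint and image model.
image_base64 = encode_image("./sample_page.png")  # placeholder path
if image_base64:
    result, with_error = chat(
        prompt="Extract all table text from this page.",
        image_model="qwen-vl-plus",
        image_base64=image_base64,
    )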