Update to apply Ali Qwen as demo
parent 255752c848
commit ea81197bcd
@@ -1,4 +1,4 @@
{
"apply_pdf2html": true,
"apply_pdf2html": false,
"apply_drilldown": false
}

@@ -1,4 +1,4 @@
{
"apply_pdf2html": false,
"apply_drilldown": true
"apply_drilldown": false
}
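Note: these two configuration snapshots turn apply_pdf2html off and keep apply_drilldown disabled for the Qwen demo run. A minimal sketch of how such a flags file can be read follows; the misc.json file name and the loader function are assumptions, while the keys and the ./configuration/<doc_source>/ folder appear in the diff below.

import json
import os

def load_misc_config(configuration_folder: str) -> dict:
    # Hypothetical helper: read the per-source misc flags used by the pipeline.
    misc_file = os.path.join(configuration_folder, "misc.json")  # assumed file name
    if not os.path.exists(misc_file):
        return {}
    with open(misc_file, "r", encoding="utf-8") as f:
        return json.load(f)

misc_config = load_misc_config("./configuration/emea_ar/")
apply_pdf2html = misc_config.get("apply_pdf2html", False)
apply_drilldown = misc_config.get("apply_drilldown", False)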
@@ -5,9 +5,8 @@ import re
import fitz
import pandas as pd
from traceback import print_exc
from utils.gpt_utils import chat
from utils.qwen_utils import chat
from utils.pdf_util import PDFUtil
from utils.sql_query_util import query_document_fund_mapping, query_investment_by_provider
from utils.logger import logger
from utils.biz_utils import add_slash_to_text_as_regex, clean_text, \
get_most_similar_name, remove_abundant_data, replace_special_table_header

@@ -23,11 +22,20 @@ class DataExtraction:
page_text_dict: dict,
datapoint_page_info: dict,
datapoints: list,
document_mapping_info_df: pd.DataFrame,
extract_way: str = "text",
output_image_folder: str = None,
text_model: str = "qwen-plus",
image_model: str = "qwen-vl-plus",
) -> None:
self.doc_source = doc_source
if self.doc_source == "aus_prospectus":
self.document_type = 1
elif self.doc_source == "emea_ar":
self.document_type = 2
else:
raise ValueError(f"Invalid document source: {self.doc_source}")
self.text_model = text_model
self.image_model = image_model
self.doc_id = doc_id
self.pdf_file = pdf_file
self.configuration_folder = f"./configuration/{doc_source}/"

@@ -46,26 +54,7 @@ class DataExtraction:
self.page_text_dict = self.get_pdf_page_text_dict()
else:
self.page_text_dict = page_text_dict
if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
else:
self.document_mapping_info_df = document_mapping_info_df
self.fund_name_list = self.document_mapping_info_df["FundName"].unique().tolist()
# get document type by DocumentType in self.document_mapping_info_df
self.document_type = int(self.document_mapping_info_df["DocumentType"].iloc[0])
self.investment_objective_pages = []
if self.document_type == 1:
self.investment_objective_pages = self.get_investment_objective_pages()
self.provider_mapping_df = self.get_provider_mapping()
if len(self.provider_mapping_df) == 0:
self.provider_fund_name_list = []
else:
self.provider_fund_name_list = (
self.provider_mapping_df["FundName"].unique().tolist()
)
self.document_category, self.document_production = self.get_document_category_production()
self.datapoint_page_info = self.get_datapoint_page_info(datapoint_page_info)
self.page_nums_with_datapoints = self.get_page_nums_from_datapoint_page_info()

@@ -77,11 +66,14 @@ class DataExtraction:
self.replace_table_header_config = self.get_replace_table_header_config()
self.special_datapoint_feature_config = self.get_special_datapoint_feature_config()
self.special_datapoint_feature = self.init_special_datapoint_feature()
self.investment_objective_pages = self.get_investment_objective_pages()
self.datapoint_reported_name_config, self.non_english_reported_name_config = \
self.get_datapoint_reported_name()
self.extract_way = extract_way
self.output_image_folder = output_image_folder

def get_special_datapoint_feature_config(self) -> dict:
special_datapoint_feature_config_file = os.path.join(self.configuration_folder, "special_datapoint_feature.json")

@@ -118,17 +110,27 @@ class DataExtraction:
if len(document_category_prompt) > 0:
prompts = f"Context: \n{first_4_page_text}\n\Instructions: \n{document_category_prompt}"
result, with_error = chat(
prompt=prompts, response_format={"type": "json_object"}, max_tokens=1000
prompt=prompts, text_model=self.text_model, image_model=self.image_model
)
response = result.get("response", "")
if not with_error:
try:
if response.startswith("```json"):
response = response.replace("```json", "").replace("```", "").strip()
if response.startswith("```JSON"):
response = response.replace("```JSON", "").replace("```", "").strip()
if response.startswith("```"):
response = response.replace("```", "").strip()
data = json.loads(response)
document_category = data.get("document_category", None)
document_production = data.get("document_production", None)
except:
pass
if document_category is None or len(document_category) == 0:
print(f"Document category is None or empty, use default value: Super")
document_category = "Super"
if document_production is None or len(document_production) == 0:
document_production = "AUS"
return document_category, document_production

def get_objective_fund_name(self, page_text: str) -> str:
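Note: every call site in this file follows the same migration pattern, from the gpt_utils-style keywords (response_format, max_tokens) to the model-selection keywords accepted by utils.qwen_utils.chat. An illustrative before/after sketch, not part of the commit; JSON output is now presumably requested through the prompt text rather than a response_format argument:

# Before: OpenAI-backed helper from utils.gpt_utils
result, with_error = chat(
    prompt=prompts,
    response_format={"type": "json_object"},
    max_tokens=1000,
)

# After: Qwen-backed helper from utils.qwen_utils; the text and image model
# names are passed explicitly (e.g. "qwen-plus"/"qwen-max" and "qwen-vl-plus").
result, with_error = chat(
    prompt=prompts,
    text_model=self.text_model,
    image_model=self.image_model,
)
response = result.get("response", "")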
@@ -142,11 +144,17 @@ class DataExtraction:
if len(objective_fund_name_prompt) > 0:
prompts = f"Context: \n{page_text}\n\Instructions: \n{objective_fund_name_prompt}"
result, with_error = chat(
prompt=prompts, response_format={"type": "json_object"}, max_tokens=1000
prompt=prompts, text_model=self.text_model, image_model=self.image_model
)
response = result.get("response", "")
if not with_error:
try:
if response.startswith("```json"):
response = response.replace("```json", "").replace("```", "").strip()
if response.startswith("```JSON"):
response = response.replace("```JSON", "").replace("```", "").strip()
if response.startswith("```"):
response = response.replace("```", "").strip()
data = json.loads(response)
fund_name = data.get("fund_name", "")
except:

@@ -187,8 +195,8 @@ class DataExtraction:
with open(language_config_file, "r", encoding="utf-8") as file:
self.language_config = json.load(file)
self.language_id = self.document_mapping_info_df["Language"].iloc[0]
self.language = self.language_config.get(self.language_id, None)
self.language_id = "0L00000122"
self.language = "english"
datapoint_reported_name_config_file = os.path.join(self.configuration_folder, "datapoint_reported_name.json")
all_datapoint_reported_name = {}

@@ -210,20 +218,6 @@ class DataExtraction:
reported_name_list.sort()
datapoint_reported_name_config[datapoint] = reported_name_list
return datapoint_reported_name_config, non_english_reported_name_config

def get_provider_mapping(self):
if len(self.document_mapping_info_df) == 0:
return pd.DataFrame()
provider_id_list = (
self.document_mapping_info_df["ProviderId"].unique().tolist()
)
provider_mapping_list = []
for provider_id in provider_id_list:
provider_mapping_list.append(query_investment_by_provider(provider_id, rerun=False))
provider_mapping_df = pd.concat(provider_mapping_list)
provider_mapping_df = provider_mapping_df.drop_duplicates()
provider_mapping_df.reset_index(drop=True, inplace=True)
return provider_mapping_df

def get_pdf_image_base64(self, page_index: int) -> dict:
pdf_util = PDFUtil(self.pdf_file)

@@ -557,8 +551,6 @@ class DataExtraction:
"""
If some datapoint with production name, then each fund/ share class in the same document for the datapoint should be with same value.
"""
if len(self.fund_name_list) < 3:
return data_list, []
raw_name_dict = self.get_raw_name_dict(data_list)
raw_name_list = list(raw_name_dict.keys())
if len(raw_name_list) < 3:

@@ -1125,11 +1117,17 @@ class DataExtraction:
if len(compare_table_structure_prompts) > 0:
prompts = f"Context: \ncurrent page contents:\n{current_page_text}\nnext page contents:\n{next_page_text}\nInstructions:\n{compare_table_structure_prompts}\n"
result, with_error = chat(
prompt=prompts, response_format={"type": "json_object"}, max_tokens=100
prompt=prompts, text_model="qwen-plus", image_model="qwen-vl-plus"
)
response = result.get("response", "")
if not with_error:
try:
if response.startswith("```json"):
response = response.replace("```json", "").replace("```", "").strip()
if response.startswith("```JSON"):
response = response.replace("```JSON", "").replace("```", "").strip()
if response.startswith("```"):
response = response.replace("```", "").strip()
data = json.loads(response)
answer = data.get("answer", "No")
if answer.lower() == "yes":

@@ -1300,9 +1298,6 @@ class DataExtraction:
"""
logger.info(f"Extracting data from page {page_num}")
if self.document_type == 1:
# pre_context = f"The document type is prospectus. \nThe fund names in this document are {', '.join(self.fund_name_list)}."
# if pre_context in page_text:
# page_text = page_text.replace(pre_context, "\n").strip()
pre_context = ""
if len(self.investment_objective_pages) > 0:
# Get the page number of the most recent investment objective at the top of the current page.

@@ -1330,8 +1325,9 @@ class DataExtraction:
extract_way="text"
)
result, with_error = chat(
prompt=instructions, response_format={"type": "json_object"}
prompt=instructions, text_model=self.text_model, image_model=self.image_model
)
response = result.get("response", "")
if with_error:
logger.error(f"Error in extracting tables from page")

@@ -1346,8 +1342,15 @@ class DataExtraction:
data_dict["prompt_token"] = result.get("prompt_token", 0)
data_dict["completion_token"] = result.get("completion_token", 0)
data_dict["total_token"] = result.get("total_token", 0)
data_dict["model"] = result.get("model", "")
return data_dict
try:
if response.startswith("```json"):
response = response.replace("```json", "").replace("```", "").strip()
if response.startswith("```JSON"):
response = response.replace("```JSON", "").replace("```", "").strip()
if response.startswith("```"):
response = response.replace("```", "").strip()
data = json.loads(response)
except:
try:
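Note: the block above is the fence-stripping pattern that now follows every chat() call, since the Qwen helper exposes no response_format option and may wrap its JSON in Markdown fences. A small helper along these lines (hypothetical, not part of the commit) captures the repeated logic:

import json

def parse_json_response(response: str) -> dict:
    # Strip a leading ```json / ```JSON / ``` fence plus any closing fence,
    # then parse whatever remains as JSON.
    for fence in ("```json", "```JSON", "```"):
        if response.startswith(fence):
            response = response.replace(fence, "").replace("```", "").strip()
            break
    return json.loads(response)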
@@ -1388,6 +1391,7 @@ class DataExtraction:
data_dict["prompt_token"] = result.get("prompt_token", 0)
data_dict["completion_token"] = result.get("completion_token", 0)
data_dict["total_token"] = result.get("total_token", 0)
data_dict["model"] = result.get("model", "")
return data_dict

def extract_data_by_page_image(

@@ -1418,6 +1422,7 @@ class DataExtraction:
data_dict["prompt_token"] = 0
data_dict["completion_token"] = 0
data_dict["total_token"] = 0
data_dict["model"] = self.image_model
return data_dict
else:
if previous_page_last_fund is not None and len(previous_page_last_fund) > 0:

@@ -1463,7 +1468,7 @@ class DataExtraction:
extract_way="image"
)
result, with_error = chat(
prompt=instructions, response_format={"type": "json_object"}, image_base64=image_base64
prompt=instructions, text_model=self.text_model, image_model=self.image_model, image_base64=image_base64
)
response = result.get("response", "")
if with_error:

@@ -1479,8 +1484,15 @@ class DataExtraction:
data_dict["prompt_token"] = result.get("prompt_token", 0)
data_dict["completion_token"] = result.get("completion_token", 0)
data_dict["total_token"] = result.get("total_token", 0)
data_dict["model"] = result.get("model", "")
return data_dict
try:
if response.startswith("```json"):
response = response.replace("```json", "").replace("```", "").strip()
if response.startswith("```JSON"):
response = response.replace("```JSON", "").replace("```", "").strip()
if response.startswith("```"):
response = response.replace("```", "").strip()
data = json.loads(response)
except:
try:

@@ -1508,6 +1520,7 @@ class DataExtraction:
data_dict["prompt_token"] = result.get("prompt_token", 0)
data_dict["completion_token"] = result.get("completion_token", 0)
data_dict["total_token"] = result.get("total_token", 0)
data_dict["model"] = result.get("model", "")
return data_dict

def get_image_text(self, page_num: int) -> str:

@@ -1515,13 +1528,19 @@ class DataExtraction:
instructions = self.instructions_config.get("get_image_text", "\n")
logger.info(f"Get text from image of page {page_num}")
result, with_error = chat(
prompt=instructions, response_format={"type": "json_object"}, image_base64=image_base64
prompt=instructions, text_model=self.text_model, image_model=self.image_model, image_base64=image_base64
)
response = result.get("response", "")
text = ""
if with_error:
logger.error(f"Can't get text from current image")
try:
if response.startswith("```json"):
response = response.replace("```json", "").replace("```", "").strip()
if response.startswith("```JSON"):
response = response.replace("```JSON", "").replace("```", "").strip()
if response.startswith("```"):
response = response.replace("```", "").strip()
data = json.loads(response)
except:
try:

@@ -1599,11 +1618,11 @@ class DataExtraction:
ter_search = re.search(ter_regex, page_text)
if ter_search is not None:
include_key_words = True
if not include_key_words:
is_share_name = self.check_fund_name_as_share(raw_fund_name)
if not is_share_name:
remove_list.append(data)
break
# if not include_key_words:
# is_share_name = self.check_fund_name_as_share(raw_fund_name)
# if not is_share_name:
# remove_list.append(data)
# break
data["share name"] = raw_fund_name
if data.get(key, "") == "":
data.pop(key)

@@ -1723,73 +1742,12 @@ class DataExtraction:
new_data[key] = value
new_data_list.append(new_data)
extract_data_info["data"] = new_data_list
if page_text is not None and len(page_text) > 0:
try:
self.set_datapoint_feature_properties(new_data_list, page_text, page_num)
except Exception as e:
logger.error(f"Error in setting datapoint feature properties: {e}")
# if page_text is not None and len(page_text) > 0:
# try:
# self.set_datapoint_feature_properties(new_data_list, page_text, page_num)
# except Exception as e:
# logger.error(f"Error in setting datapoint feature properties: {e}")
return extract_data_info

def set_datapoint_feature_properties(self, data_list: list, page_text: str, page_num: int) -> None:
for feature, properties in self.special_datapoint_feature_config.items():
if self.special_datapoint_feature.get(feature, {}).get("page_index", None) is not None:
continue
provider_ids = properties.get("provider_ids", [])
if len(provider_ids) > 0:
is_current_provider = False
doc_provider_list = self.document_mapping_info_df["ProviderId"].unique().tolist()
if len(doc_provider_list) > 0:
for provider in provider_ids:
if provider in doc_provider_list:
is_current_provider = True
break
if not is_current_provider:
continue
detail_list = properties.get("details", [])
if len(detail_list) == 0:
continue
set_feature_property = False
for detail in detail_list:
regex_text_list = detail.get("regex_text", [])
if len(regex_text_list) == 0:
continue
effective_datapoints = detail.get("effective_datapoints", [])
if len(effective_datapoints) == 0:
continue
exclude_datapoints = detail.get("exclude_datapoints", [])
exist_effective_datapoints = False
exist_exclude_datapoints = False
for data_item in data_list:
datapoints = [datapoint for datapoint in list(data_item.keys())
if datapoint in effective_datapoints]
if len(datapoints) > 0:
exist_effective_datapoints = True
datapoints = [datapoint for datapoint in list(data_item.keys())
if datapoint in exclude_datapoints]
if len(datapoints) > 0:
exist_exclude_datapoints = True
if exist_effective_datapoints and exist_exclude_datapoints:
break
if not exist_effective_datapoints:
continue
if exist_exclude_datapoints:
continue
found_regex_text = False
for regex_text in regex_text_list:
regex_search = re.search(regex_text, page_text, re.IGNORECASE)
if regex_search is not None:
found_regex_text = True
break
if found_regex_text:
if self.special_datapoint_feature[feature].get("page_index", None) is None:
self.special_datapoint_feature[feature]["page_index"] = []
self.special_datapoint_feature[feature]["datapoint"] = effective_datapoints[0]
self.special_datapoint_feature[feature]["page_index"].append(page_num)
set_feature_property = True
if set_feature_property:
break

def split_multi_share_name(self, raw_share_name: str) -> list:
"""

@@ -1836,25 +1794,25 @@ class DataExtraction:
fund_name = f"{last_fund} {fund_feature}"
return fund_name

def check_fund_name_as_share(self, fund_name: str) -> bool:
"""
Check if the fund name is the same as share name
"""
if len(fund_name) == 0 == 0:
return False
share_name_list = self.document_mapping_info_df["ShareClassName"].unique().tolist()
if len(share_name_list) == 0:
return False
max_similarity_name, max_similarity = get_most_similar_name(
text=fund_name,
name_list=share_name_list,
share_name=None,
fund_name=None,
matching_type="share",
process_cache=None)
if max_similarity >= 0.8:
return True
return False
# def check_fund_name_as_share(self, fund_name: str) -> bool:
# """
# Check if the fund name is the same as share name
# """
# if len(fund_name) == 0 == 0:
# return False
# share_name_list = self.document_mapping_info_df["ShareClassName"].unique().tolist()
# if len(share_name_list) == 0:
# return False
# max_similarity_name, max_similarity = get_most_similar_name(
# text=fund_name,
# name_list=share_name_list,
# share_name=None,
# fund_name=None,
# matching_type="share",
# process_cache=None)
# if max_similarity >= 0.8:
# return True
# return False

def get_datapoints_by_page_num(self, page_num: int) -> list:
datapoints = []

@@ -2165,13 +2123,6 @@ class DataExtraction:
for datapoint in datapoints:
investment_level = self.datapoint_level_config.get(datapoint, "")
if investment_level == "fund_level":
# fund_level_example_list = output_requirement.get("fund_level", [])
# for example in fund_level_example_list:
# try:
# sub_example_list = json.loads(example)
# except:
# sub_example_list = json_repair.loads(example)
# example_list.extend(sub_example_list)
fund_datapoint_value_example[datapoint] = fund_level_config.get(
f"{datapoint}_value", []
)

@@ -2228,131 +2179,4 @@ class DataExtraction:
instructions.append("Answer:\n")
instructions_text = "".join(instructions)
return instructions_text
# def chat_by_split_context(self,
# page_text: str,
# page_datapoints: list,
# need_exclude: bool,
# exclude_data: list) -> list:
# """
# If occur error, split the context to two parts and try to get data from the two parts
# Relevant document: 503194284, page index 147
# """
# try:
# logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
# split_context = re.split(r"\n", page_text)
# split_context = [text.strip() for text in split_context
# if len(text.strip()) > 0]
# if len(split_context) < 10:
# return {"data": []}
# split_context_len = len(split_context)
# top_10_context = split_context[:10]
# rest_context = split_context[10:]
# header = "\n".join(top_10_context)
# half_len = split_context_len // 2
# # the member of half_len should not start with number
# # reverse iterate the list by half_len
# half_len_list = [i for i in range(half_len)]
# fund_name_line = ""
# half_line = rest_context[half_len].strip()
# max_similarity_fund_name, max_similarity = get_most_similar_name(
# half_line, self.provider_fund_name_list, matching_type="fund"
# )
# if max_similarity < 0.2:
# # get the fund name line text from the first half
# for index in reversed(half_len_list):
# line_text = rest_context[index].strip()
# if len(line_text) == 0:
# continue
# line_text_split = line_text.split()
# if len(line_text_split) < 3:
# continue
# first_word = line_text_split[0]
# if first_word.lower() == "class":
# continue
# max_similarity_fund_name, max_similarity = get_most_similar_name(
# line_text, self.provider_fund_name_list, matching_type="fund"
# )
# if max_similarity >= 0.2:
# fund_name_line = line_text
# break
# else:
# fund_name_line = half_line
# half_len += 1
# if fund_name_line == "":
# return {"data": []}
# logger.info(f"Split first part from 0 to {half_len}")
# split_first_part = "\n".join(split_context[:half_len])
# first_part = '\n'.join(split_first_part)
# first_instructions = self.get_instructions_by_datapoints(
# first_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
# )
# response, with_error = chat(
# first_instructions, response_format={"type": "json_object"}
# )
# first_part_data = {"data": []}
# if not with_error:
# try:
# first_part_data = json.loads(response)
# except:
# first_part_data = json_repair.loads(response)
# logger.info(f"Split second part from {half_len} to {split_context_len}")
# split_second_part = "\n".join(split_context[half_len:])
# second_part = header + "\n" + fund_name_line + "\n" + split_second_part
# second_instructions = self.get_instructions_by_datapoints(
# second_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
# )
# response, with_error = chat(
# second_instructions, response_format={"type": "json_object"}
# )
# second_part_data = {"data": []}
# if not with_error:
# try:
# second_part_data = json.loads(response)
# except:
# second_part_data = json_repair.loads(response)
# first_part_data_list = first_part_data.get("data", [])
# logger.info(f"First part data count: {len(first_part_data_list)}")
# second_part_data_list = second_part_data.get("data", [])
# logger.info(f"Second part data count: {len(second_part_data_list)}")
# for first_data in first_part_data_list:
# if first_data in second_part_data_list:
# second_part_data_list.remove(first_data)
# else:
# # if the first part data is with same fund name and share name,
# # remove the second part data
# first_data_dp = [key for key in list(first_data.keys())
# if key not in ["fund name", "share name"]]
# # order the data points
# first_data_dp.sort()
# first_fund_name = first_data.get("fund name", "")
# first_share_name = first_data.get("share name", "")
# if len(first_fund_name) > 0 and len(first_share_name) > 0:
# remove_second_list = []
# for second_data in second_part_data_list:
# second_fund_name = second_data.get("fund name", "")
# second_share_name = second_data.get("share name", "")
# if first_fund_name == second_fund_name and \
# first_share_name == second_share_name:
# second_data_dp = [key for key in list(second_data.keys())
# if key not in ["fund name", "share name"]]
# second_data_dp.sort()
# if first_data_dp == second_data_dp:
# remove_second_list.append(second_data)
# for remove_second in remove_second_list:
# if remove_second in second_part_data_list:
# second_part_data_list.remove(remove_second)
# data_list = first_part_data_list + second_part_data_list
# extract_data = {"data": data_list}
# return extract_data
# except Exception as e:
# logger.error(f"Error in split context: {e}")
# return {"data": []}
return instructions_text
@@ -15,7 +15,6 @@ class FilterPages:
self,
doc_id: str,
pdf_file: str,
document_mapping_info_df: pd.DataFrame,
doc_source: str = "emea_ar",
output_pdf_text_folder: str = r"/data/emea_ar/output/pdf_text/",
) -> None:

@@ -32,10 +31,7 @@ class FilterPages:
else:
self.apply_pdf2html = False
self.page_text_dict = self.get_pdf_page_text_dict()
if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
else:
self.document_mapping_info_df = document_mapping_info_df
self.get_configuration_from_file()
self.doc_info = self.get_doc_info()
self.datapoint_config, self.datapoint_exclude_config = (

@@ -138,7 +134,7 @@ class FilterPages:
self.datapoint_type_config = json.load(file)

def get_doc_info(self) -> dict:
if len(self.document_mapping_info_df) == 0:
if self.doc_source == "emea_ar":
return {
"effective_date": None,
"document_type": "ar",

@@ -146,22 +142,16 @@ class FilterPages:
"language": "english",
"domicile": "LUX",
}
effective_date = self.document_mapping_info_df["EffectiveDate"].iloc[0]
document_type = self.document_mapping_info_df["DocumentType"].iloc[0]
if document_type in [4, 5] or self.doc_source == "emea_ar":
document_type = "ar"
elif document_type == 1 or self.doc_source == "aus_prospectus":
document_type = "prospectus"
language_id = self.document_mapping_info_df["Language"].iloc[0]
language = self.language_config.get(language_id, None)
domicile = self.document_mapping_info_df["Domicile"].iloc[0]
return {
"effective_date": effective_date,
"document_type": document_type,
"language_id": language_id,
"language": language,
"domicile": domicile,
}
elif self.doc_source == "aus_prospectus":
return {
"effective_date": None,
"document_type": "prospectus",
"language_id": "0L00000122",
"language": "english",
"domicile": "AUS",
}
else:
raise ValueError(f"Invalid doc_source: {self.doc_source}")

def get_datapoint_config(self) -> dict:
domicile = self.doc_info.get("domicile", None)
mini_main.py
@@ -29,19 +29,17 @@ class EMEA_AR_Parsing:
pdf_folder: str = r"/data/emea_ar/pdf/",
output_pdf_text_folder: str = r"/data/emea_ar/output/pdf_text/",
output_extract_data_folder: str = r"/data/emea_ar/output/extract_data/docs/",
output_mapping_data_folder: str = r"/data/emea_ar/output/mapping_data/docs/",
extract_way: str = "text",
drilldown_folder: str = r"/data/emea_ar/output/drilldown/",
compare_with_provider: bool = True
text_model: str = "qwen-plus",
image_model: str = "qwen-vl-plus",
) -> None:
self.doc_id = doc_id
self.doc_source = doc_source
self.pdf_folder = pdf_folder
os.makedirs(self.pdf_folder, exist_ok=True)
self.compare_with_provider = compare_with_provider
self.pdf_file = self.download_pdf()
self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
if extract_way is None or len(extract_way) == 0:
extract_way = "text"

@@ -64,21 +62,9 @@ class EMEA_AR_Parsing:
self.output_extract_data_folder = output_extract_data_folder
os.makedirs(self.output_extract_data_folder, exist_ok=True)
if output_mapping_data_folder is None or len(output_mapping_data_folder) == 0:
output_mapping_data_folder = r"/data/emea_ar/output/mapping_data/docs/"
if not output_mapping_data_folder.endswith("/"):
output_mapping_data_folder = f"{output_mapping_data_folder}/"
if extract_way is not None and len(extract_way) > 0:
output_mapping_data_folder = (
f"{output_mapping_data_folder}by_{extract_way}/"
)
self.output_mapping_data_folder = output_mapping_data_folder
os.makedirs(self.output_mapping_data_folder, exist_ok=True)
self.filter_pages = FilterPages(
self.doc_id,
self.pdf_file,
self.document_mapping_info_df,
self.doc_source,
output_pdf_text_folder,
)

@@ -100,6 +86,8 @@ class EMEA_AR_Parsing:
self.apply_drilldown = misc_config.get("apply_drilldown", False)
else:
self.apply_drilldown = False
self.text_model = text_model
self.image_model = image_model

def download_pdf(self) -> str:
pdf_file = download_pdf_from_documents_warehouse(self.pdf_folder, self.doc_id)

@@ -144,9 +132,10 @@ class EMEA_AR_Parsing:
self.page_text_dict,
self.datapoint_page_info,
self.datapoints,
self.document_mapping_info_df,
extract_way=self.extract_way,
output_image_folder=self.output_extract_image_folder,
text_model=self.text_model,
image_model=self.image_model,
)
data_from_gpt = data_extraction.extract_data()
except Exception as e:

@@ -266,70 +255,6 @@ class EMEA_AR_Parsing:
logger.error(f"Error: {e}")
return annotation_list

def mapping_data(self, data_from_gpt: list, re_run: bool = False) -> list:
if not re_run:
output_data_json_folder = os.path.join(
self.output_mapping_data_folder, "json/"
)
os.makedirs(output_data_json_folder, exist_ok=True)
json_file = os.path.join(output_data_json_folder, f"{self.doc_id}.json")
if os.path.exists(json_file):
logger.info(
f"The fund/ share of this document: {self.doc_id} has been mapped, loading data from {json_file}"
)
with open(json_file, "r", encoding="utf-8") as f:
doc_mapping_data = json.load(f)
if self.doc_source == "aus_prospectus":
output_data_folder_splits = output_data_json_folder.split("output")
if len(output_data_folder_splits) == 2:
merged_data_folder = f'{output_data_folder_splits[0]}output/merged_data/docs/'
os.makedirs(merged_data_folder, exist_ok=True)
merged_data_json_folder = os.path.join(merged_data_folder, "json/")
os.makedirs(merged_data_json_folder, exist_ok=True)
merged_data_excel_folder = os.path.join(merged_data_folder, "excel/")
os.makedirs(merged_data_excel_folder, exist_ok=True)
merged_data_file = os.path.join(merged_data_json_folder, f"merged_{self.doc_id}.json")
if os.path.exists(merged_data_file):
with open(merged_data_file, "r", encoding="utf-8") as f:
merged_data_list = json.load(f)
return merged_data_list
else:
data_mapping = DataMapping(
self.doc_id,
self.datapoints,
data_from_gpt,
self.document_mapping_info_df,
self.output_mapping_data_folder,
self.doc_source,
compare_with_provider=self.compare_with_provider
)
merged_data_list = data_mapping.merge_output_data_aus_prospectus(doc_mapping_data,
merged_data_json_folder,
merged_data_excel_folder)
return merged_data_list
else:
return doc_mapping_data
"""
doc_id,
datapoints: list,
raw_document_data_list: list,
document_mapping_info_df: pd.DataFrame,
output_data_folder: str,
"""
data_mapping = DataMapping(
self.doc_id,
self.datapoints,
data_from_gpt,
self.document_mapping_info_df,
self.output_mapping_data_folder,
self.doc_source,
compare_with_provider=self.compare_with_provider
)
return data_mapping.mapping_raw_data_entrance()

def filter_pages(doc_id: str, pdf_folder: str, doc_source: str) -> None:
logger.info(f"Filter EMEA AR PDF pages for doc_id: {doc_id}")

@@ -347,6 +272,8 @@ def extract_data(
output_data_folder: str,
extract_way: str = "text",
re_run: bool = False,
text_model: str = "qwen-plus",
image_model: str = "qwen-vl-plus",
) -> None:
logger.info(f"Extract EMEA AR data for doc_id: {doc_id}")
emea_ar_parsing = EMEA_AR_Parsing(

@@ -355,6 +282,8 @@ def extract_data(
pdf_folder=pdf_folder,
output_extract_data_folder=output_data_folder,
extract_way=extract_way,
text_model=text_model,
image_model=image_model,
)
data_from_gpt, annotation_list = emea_ar_parsing.extract_data(re_run)
return data_from_gpt, annotation_list

@@ -368,6 +297,8 @@ def batch_extract_data(
extract_way: str = "text",
special_doc_id_list: list = None,
re_run: bool = False,
text_model: str = "qwen-plus",
image_model: str = "qwen-vl-plus",
) -> None:
pdf_files = glob(pdf_folder + "*.pdf")
doc_list = []

@@ -391,6 +322,8 @@ def batch_extract_data(
output_data_folder=output_child_folder,
extract_way=extract_way,
re_run=re_run,
text_model=text_model,
image_model=image_model,
)
result_list.extend(data_from_gpt)

@@ -421,31 +354,35 @@ def test_translate_pdf():
if __name__ == "__main__":
os.environ["SSL_CERT_FILE"] = certifi.where()
doc_source = "aus_prospectus"
# doc_source = "aus_prospectus"
doc_source = "emea_ar"
re_run = True
extract_way = "text"
if doc_source == "aus_prospectus":
special_doc_id_list = ["539266874"]
pdf_folder: str = r"/data/aus_prospectus/pdf/"
output_pdf_text_folder: str = r"/data/aus_prospectus/output/pdf_text/"
special_doc_id_list = ["412778803", "539266874"]
pdf_folder: str = r"./data/aus_prospectus/pdf/"
output_pdf_text_folder: str = r"./data/aus_prospectus/output/pdf_text/"
output_child_folder: str = (
r"/data/aus_prospectus/output/extract_data/docs/"
r"./data/aus_prospectus/output/extract_data/docs/"
)
output_total_folder: str = (
r"/data/aus_prospectus/output/extract_data/total/"
r"./data/aus_prospectus/output/extract_data/total/"
)
elif doc_source == "emea_ar":
special_doc_id_list = ["514636993"]
pdf_folder: str = r"/data/emea_ar/pdf/"
pdf_folder: str = r"./data/emea_ar/pdf/"
output_child_folder: str = (
r"/data/emea_ar/output/extract_data/docs/"
r"./data/emea_ar/output/extract_data/docs/"
)
output_total_folder: str = (
r"/data/emea_ar/output/extract_data/total/"
r"./data/emea_ar/output/extract_data/total/"
)
else:
raise ValueError(f"Invalid doc_source: {doc_source}")
# text_model = "qwen-plus"
text_model = "qwen-max"
image_model = "qwen-vl-plus"
batch_extract_data(
pdf_folder=pdf_folder,
doc_source=doc_source,

@@ -454,6 +391,8 @@ if __name__ == "__main__":
extract_way=extract_way,
special_doc_id_list=special_doc_id_list,
re_run=re_run,
text_model=text_model,
image_model=image_model,
)
@@ -0,0 +1,148 @@
import requests
import json
import os
from bs4 import BeautifulSoup
import time
from time import sleep
from datetime import datetime
import pytz
import pandas as pd
import dashscope
import dotenv
import base64
dotenv.load_dotenv()

ali_api_key = os.getenv("ALI_API_KEY_QWEN")

def chat(
prompt: str,
text_model: str = "qwen-plus",
image_model: str = "qwen-vl-plus",
image_file: str = None,
image_base64: str = None,
enable_search: bool = False,
):
try:
token = 0
if (
image_base64 is None
and image_file is not None
and len(image_file) > 0
and os.path.exists(image_file)
):
image_base64 = encode_image(image_file)
use_image_model = False
if image_base64 is not None and len(image_base64) > 0:
use_image_model = True
messages = [
{
"role": "user",
"content": [
{"text": prompt},
{
"image": f"data:image/png;base64,{image_base64}",
},
],
}
]
count = 0
while count < 3:
try:
print(f"Calling Aliyun Qwen model, attempt {count + 1}")
response = dashscope.MultiModalConversation.call(
api_key=ali_api_key,
model=image_model,
messages=messages,
)
if response.status_code == 200:
break
else:
print(f"Aliyun Qwen model call failed: {response.code} {response.message}")
count += 1
sleep(2)
except Exception as e:
print(f"Aliyun Qwen model call failed: {e}")
count += 1
sleep(2)
if response.status_code == 200:
image_text = (
response.get("output", {})
.get("choices", [])[0]
.get("message", {})
.get("content", "")
)
temp_image_text = ""
if isinstance(image_text, list):
for item in image_text:
if isinstance(item, dict):
temp_image_text += item.get("text", "") + "\n\n"
elif isinstance(item, str):
temp_image_text += item + "\n\n"
else:
pass
response_contents = temp_image_text.strip()
token = response.get("usage", {}).get("total_tokens", 0)
else:
response_contents = f"{response.code} {response.message} Unable to analyze the image"
token = 0
else:
messages = [{"role": "user", "content": prompt}]
count = 0
while count < 3:
try:
print(f"Calling Aliyun Qwen model, attempt {count + 1}")
response = dashscope.Generation.call(
api_key=ali_api_key,
model=text_model,
messages=messages,
enable_search=enable_search,
search_options={"forced_search": enable_search},  # force web search
result_format="message",
)
if response.status_code == 200:
break
else:
print(f"Aliyun Qwen model call failed: {response.code} {response.message}")
count += 1
sleep(2)
except Exception as e:
print(f"Aliyun Qwen model call failed: {e}")
count += 1
sleep(2)
# get the token usage from the response
if response.status_code == 200:
response_contents = (
response.get("output", {})
.get("choices", [])[0]
.get("message", {})
.get("content", "")
)
token = response.get("usage", {}).get("total_tokens", 0)
else:
response_contents = f"{response.code} {response.message}"
token = 0
result = {}
if use_image_model:
result["model"] = image_model
else:
result["model"] = text_model
result["response"] = response_contents
result["prompt_token"] = response.get("usage", {}).get("input_tokens", 0)
result["completion_token"] = response.get("usage", {}).get("output_tokens", 0)
result["total_token"] = token
sleep(2)
return result, False
except Exception as e:
print(f"Aliyun Qwen model call failed: {e}")
return {}, True

def encode_image(image_path: str):
if image_path is None or len(image_path) == 0 or not os.path.exists(image_path):
return None
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
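Note: for orientation, a minimal usage sketch of the new helper, assuming it lives at utils/qwen_utils.py (which the updated imports suggest) and that ALI_API_KEY_QWEN is set in the environment; the image path is hypothetical.

from utils.qwen_utils import chat

# Text-only call: routed to dashscope.Generation with the given text model.
result, with_error = chat(
    prompt="Return a JSON object with a 'greeting' field.",
    text_model="qwen-plus",
)
if not with_error:
    print(result["model"], result["total_token"])
    print(result["response"])

# Image call: passing image_file (or image_base64) switches the helper to
# dashscope.MultiModalConversation with the vision model.
result, with_error = chat(
    prompt="Describe the table on this page.",
    image_model="qwen-vl-plus",
    image_file="./data/emea_ar/output/images/sample_page.png",  # hypothetical path
)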