Commit ea81197bcd (parent 255752c848): update to apply Alibaba Qwen (ALI QWEN) models as a demo
@@ -1,4 +1,4 @@
 {
-    "apply_pdf2html": true,
+    "apply_pdf2html": false,
     "apply_drilldown": false
 }
@@ -1,4 +1,4 @@
 {
     "apply_pdf2html": false,
-    "apply_drilldown": true
+    "apply_drilldown": false
 }
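Both hunks above flip flags in the per-source misc configuration JSON (one file per doc_source), so neither PDF-to-HTML conversion nor drilldown output is produced in the demo. A minimal sketch of how these flags are typically consumed, assuming the file is read with json.load and queried with .get() defaults as in the mini_main.py hunks further below; the path used here is illustrative, the real location is not shown in this diff:

import json

def load_misc_config(path: str) -> dict:
    # Hypothetical helper; the repository reads the same keys via json.load elsewhere.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

misc_config = load_misc_config("./configuration/emea_ar/misc.json")  # assumed path
apply_pdf2html = misc_config.get("apply_pdf2html", False)    # False for both sources after this commit
apply_drilldown = misc_config.get("apply_drilldown", False)  # False for both sources after this commit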
@@ -5,9 +5,8 @@ import re
 import fitz
 import pandas as pd
 from traceback import print_exc
-from utils.gpt_utils import chat
+from utils.qwen_utils import chat
 from utils.pdf_util import PDFUtil
-from utils.sql_query_util import query_document_fund_mapping, query_investment_by_provider
 from utils.logger import logger
 from utils.biz_utils import add_slash_to_text_as_regex, clean_text, \
     get_most_similar_name, remove_abundant_data, replace_special_table_header
@@ -23,11 +22,20 @@ class DataExtraction:
         page_text_dict: dict,
         datapoint_page_info: dict,
         datapoints: list,
-        document_mapping_info_df: pd.DataFrame,
         extract_way: str = "text",
         output_image_folder: str = None,
+        text_model: str = "qwen-plus",
+        image_model: str = "qwen-vl-plus",
     ) -> None:
         self.doc_source = doc_source
+        if self.doc_source == "aus_prospectus":
+            self.document_type = 1
+        elif self.doc_source == "emea_ar":
+            self.document_type = 2
+        else:
+            raise ValueError(f"Invalid document source: {self.doc_source}")
+        self.text_model = text_model
+        self.image_model = image_model
         self.doc_id = doc_id
         self.pdf_file = pdf_file
         self.configuration_folder = f"./configuration/{doc_source}/"
@@ -46,26 +54,7 @@ class DataExtraction:
             self.page_text_dict = self.get_pdf_page_text_dict()
         else:
             self.page_text_dict = page_text_dict
-        if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
-            self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
-        else:
-            self.document_mapping_info_df = document_mapping_info_df
-
-        self.fund_name_list = self.document_mapping_info_df["FundName"].unique().tolist()
-
-        # get document type by DocumentType in self.document_mapping_info_df
-        self.document_type = int(self.document_mapping_info_df["DocumentType"].iloc[0])
-        self.investment_objective_pages = []
-        if self.document_type == 1:
-            self.investment_objective_pages = self.get_investment_objective_pages()
-
-        self.provider_mapping_df = self.get_provider_mapping()
-        if len(self.provider_mapping_df) == 0:
-            self.provider_fund_name_list = []
-        else:
-            self.provider_fund_name_list = (
-                self.provider_mapping_df["FundName"].unique().tolist()
-            )
         self.document_category, self.document_production = self.get_document_category_production()
         self.datapoint_page_info = self.get_datapoint_page_info(datapoint_page_info)
         self.page_nums_with_datapoints = self.get_page_nums_from_datapoint_page_info()
@@ -77,11 +66,14 @@ class DataExtraction:
         self.replace_table_header_config = self.get_replace_table_header_config()
         self.special_datapoint_feature_config = self.get_special_datapoint_feature_config()
         self.special_datapoint_feature = self.init_special_datapoint_feature()
+        self.investment_objective_pages = self.get_investment_objective_pages()

         self.datapoint_reported_name_config, self.non_english_reported_name_config = \
             self.get_datapoint_reported_name()
         self.extract_way = extract_way
         self.output_image_folder = output_image_folder

     def get_special_datapoint_feature_config(self) -> dict:
         special_datapoint_feature_config_file = os.path.join(self.configuration_folder, "special_datapoint_feature.json")
@@ -118,17 +110,27 @@ class DataExtraction:
         if len(document_category_prompt) > 0:
             prompts = f"Context: \n{first_4_page_text}\n\Instructions: \n{document_category_prompt}"
             result, with_error = chat(
-                prompt=prompts, response_format={"type": "json_object"}, max_tokens=1000
+                prompt=prompts, text_model=self.text_model, image_model=self.image_model
             )
             response = result.get("response", "")
             if not with_error:
                 try:
+                    if response.startswith("```json"):
+                        response = response.replace("```json", "").replace("```", "").strip()
+                    if response.startswith("```JSON"):
+                        response = response.replace("```JSON", "").replace("```", "").strip()
+                    if response.startswith("```"):
+                        response = response.replace("```", "").strip()
                     data = json.loads(response)
                     document_category = data.get("document_category", None)
                     document_production = data.get("document_production", None)
                 except:
                     pass
+        if document_category is None or len(document_category) == 0:
+            print(f"Document category is None or empty, use default value: Super")
+            document_category = "Super"
+        if document_production is None or len(document_production) == 0:
+            document_production = "AUS"
         return document_category, document_production

     def get_objective_fund_name(self, page_text: str) -> str:
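The added startswith checks strip the Markdown code fences that the Qwen responses sometimes wrap around their JSON; the same six lines are repeated before every json.loads call in the hunks below. A compact equivalent helper (a sketch only, not part of this commit) would be:

def strip_code_fences(response: str) -> str:
    # Remove a leading ```json / ```JSON / ``` fence (and any closing ```), then trim whitespace.
    for fence in ("```json", "```JSON", "```"):
        if response.startswith(fence):
            return response.replace(fence, "").replace("```", "").strip()
    return response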
@@ -142,11 +144,17 @@ class DataExtraction:
         if len(objective_fund_name_prompt) > 0:
             prompts = f"Context: \n{page_text}\n\Instructions: \n{objective_fund_name_prompt}"
             result, with_error = chat(
-                prompt=prompts, response_format={"type": "json_object"}, max_tokens=1000
+                prompt=prompts, text_model=self.text_model, image_model=self.image_model
             )
             response = result.get("response", "")
             if not with_error:
                 try:
+                    if response.startswith("```json"):
+                        response = response.replace("```json", "").replace("```", "").strip()
+                    if response.startswith("```JSON"):
+                        response = response.replace("```JSON", "").replace("```", "").strip()
+                    if response.startswith("```"):
+                        response = response.replace("```", "").strip()
                     data = json.loads(response)
                     fund_name = data.get("fund_name", "")
                 except:
@@ -187,8 +195,8 @@ class DataExtraction:
         with open(language_config_file, "r", encoding="utf-8") as file:
             self.language_config = json.load(file)

-        self.language_id = self.document_mapping_info_df["Language"].iloc[0]
-        self.language = self.language_config.get(self.language_id, None)
+        self.language_id = "0L00000122"
+        self.language = "english"

         datapoint_reported_name_config_file = os.path.join(self.configuration_folder, "datapoint_reported_name.json")
         all_datapoint_reported_name = {}
@@ -210,20 +218,6 @@ class DataExtraction:
             reported_name_list.sort()
             datapoint_reported_name_config[datapoint] = reported_name_list
         return datapoint_reported_name_config, non_english_reported_name_config
-
-    def get_provider_mapping(self):
-        if len(self.document_mapping_info_df) == 0:
-            return pd.DataFrame()
-        provider_id_list = (
-            self.document_mapping_info_df["ProviderId"].unique().tolist()
-        )
-        provider_mapping_list = []
-        for provider_id in provider_id_list:
-            provider_mapping_list.append(query_investment_by_provider(provider_id, rerun=False))
-        provider_mapping_df = pd.concat(provider_mapping_list)
-        provider_mapping_df = provider_mapping_df.drop_duplicates()
-        provider_mapping_df.reset_index(drop=True, inplace=True)
-        return provider_mapping_df

     def get_pdf_image_base64(self, page_index: int) -> dict:
         pdf_util = PDFUtil(self.pdf_file)
@@ -557,8 +551,6 @@ class DataExtraction:
         """
         If some datapoint with production name, then each fund/ share class in the same document for the datapoint should be with same value.
         """
-        if len(self.fund_name_list) < 3:
-            return data_list, []
         raw_name_dict = self.get_raw_name_dict(data_list)
         raw_name_list = list(raw_name_dict.keys())
         if len(raw_name_list) < 3:
@@ -1125,11 +1117,17 @@ class DataExtraction:
         if len(compare_table_structure_prompts) > 0:
             prompts = f"Context: \ncurrent page contents:\n{current_page_text}\nnext page contents:\n{next_page_text}\nInstructions:\n{compare_table_structure_prompts}\n"
             result, with_error = chat(
-                prompt=prompts, response_format={"type": "json_object"}, max_tokens=100
+                prompt=prompts, text_model="qwen-plus", image_model="qwen-vl-plus"
            )
             response = result.get("response", "")
             if not with_error:
                 try:
+                    if response.startswith("```json"):
+                        response = response.replace("```json", "").replace("```", "").strip()
+                    if response.startswith("```JSON"):
+                        response = response.replace("```JSON", "").replace("```", "").strip()
+                    if response.startswith("```"):
+                        response = response.replace("```", "").strip()
                     data = json.loads(response)
                     answer = data.get("answer", "No")
                     if answer.lower() == "yes":
@@ -1300,9 +1298,6 @@ class DataExtraction:
         """
         logger.info(f"Extracting data from page {page_num}")
         if self.document_type == 1:
-            # pre_context = f"The document type is prospectus. \nThe fund names in this document are {', '.join(self.fund_name_list)}."
-            # if pre_context in page_text:
-            #     page_text = page_text.replace(pre_context, "\n").strip()
             pre_context = ""
             if len(self.investment_objective_pages) > 0:
                 # Get the page number of the most recent investment objective at the top of the current page.
@@ -1330,8 +1325,9 @@ class DataExtraction:
             extract_way="text"
         )
         result, with_error = chat(
-            prompt=instructions, response_format={"type": "json_object"}
+            prompt=instructions, text_model=self.text_model, image_model=self.image_model
         )

         response = result.get("response", "")
         if with_error:
             logger.error(f"Error in extracting tables from page")
@@ -1346,8 +1342,15 @@ class DataExtraction:
             data_dict["prompt_token"] = result.get("prompt_token", 0)
             data_dict["completion_token"] = result.get("completion_token", 0)
             data_dict["total_token"] = result.get("total_token", 0)
+            data_dict["model"] = result.get("model", "")
             return data_dict
         try:
+            if response.startswith("```json"):
+                response = response.replace("```json", "").replace("```", "").strip()
+            if response.startswith("```JSON"):
+                response = response.replace("```JSON", "").replace("```", "").strip()
+            if response.startswith("```"):
+                response = response.replace("```", "").strip()
             data = json.loads(response)
         except:
             try:
@@ -1388,6 +1391,7 @@ class DataExtraction:
         data_dict["prompt_token"] = result.get("prompt_token", 0)
         data_dict["completion_token"] = result.get("completion_token", 0)
         data_dict["total_token"] = result.get("total_token", 0)
+        data_dict["model"] = result.get("model", "")
         return data_dict

     def extract_data_by_page_image(
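With the new data_dict["model"] field, every per-page result now records which Qwen model produced it next to the token counts. Illustrative shape only, with made-up values and only the keys taken from the hunks above:

data_dict = {
    "prompt_token": 1523,      # result.get("prompt_token", 0)
    "completion_token": 486,   # result.get("completion_token", 0)
    "total_token": 2009,       # result.get("total_token", 0)
    "model": "qwen-plus",      # new: result.get("model", ""), or self.image_model for image pages
}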
@@ -1418,6 +1422,7 @@ class DataExtraction:
             data_dict["prompt_token"] = 0
             data_dict["completion_token"] = 0
             data_dict["total_token"] = 0
+            data_dict["model"] = self.image_model
             return data_dict
         else:
             if previous_page_last_fund is not None and len(previous_page_last_fund) > 0:
@@ -1463,7 +1468,7 @@ class DataExtraction:
             extract_way="image"
         )
         result, with_error = chat(
-            prompt=instructions, response_format={"type": "json_object"}, image_base64=image_base64
+            prompt=instructions, text_model=self.text_model, image_model=self.image_model, image_base64=image_base64
         )
         response = result.get("response", "")
         if with_error:
@@ -1479,8 +1484,15 @@ class DataExtraction:
             data_dict["prompt_token"] = result.get("prompt_token", 0)
             data_dict["completion_token"] = result.get("completion_token", 0)
             data_dict["total_token"] = result.get("total_token", 0)
+            data_dict["model"] = result.get("model", "")
             return data_dict
         try:
+            if response.startswith("```json"):
+                response = response.replace("```json", "").replace("```", "").strip()
+            if response.startswith("```JSON"):
+                response = response.replace("```JSON", "").replace("```", "").strip()
+            if response.startswith("```"):
+                response = response.replace("```", "").strip()
             data = json.loads(response)
         except:
             try:
@@ -1508,6 +1520,7 @@ class DataExtraction:
         data_dict["prompt_token"] = result.get("prompt_token", 0)
         data_dict["completion_token"] = result.get("completion_token", 0)
         data_dict["total_token"] = result.get("total_token", 0)
+        data_dict["model"] = result.get("model", "")
         return data_dict

     def get_image_text(self, page_num: int) -> str:
@@ -1515,13 +1528,19 @@ class DataExtraction:
         instructions = self.instructions_config.get("get_image_text", "\n")
         logger.info(f"Get text from image of page {page_num}")
         result, with_error = chat(
-            prompt=instructions, response_format={"type": "json_object"}, image_base64=image_base64
+            prompt=instructions, text_model=self.text_model, image_model=self.image_model, image_base64=image_base64
         )
         response = result.get("response", "")
         text = ""
         if with_error:
             logger.error(f"Can't get text from current image")
         try:
+            if response.startswith("```json"):
+                response = response.replace("```json", "").replace("```", "").strip()
+            if response.startswith("```JSON"):
+                response = response.replace("```JSON", "").replace("```", "").strip()
+            if response.startswith("```"):
+                response = response.replace("```", "").strip()
             data = json.loads(response)
         except:
             try:
@@ -1599,11 +1618,11 @@ class DataExtraction:
                 ter_search = re.search(ter_regex, page_text)
                 if ter_search is not None:
                     include_key_words = True
-            if not include_key_words:
-                is_share_name = self.check_fund_name_as_share(raw_fund_name)
-                if not is_share_name:
-                    remove_list.append(data)
-                    break
+            # if not include_key_words:
+            #     is_share_name = self.check_fund_name_as_share(raw_fund_name)
+            #     if not is_share_name:
+            #         remove_list.append(data)
+            #         break
             data["share name"] = raw_fund_name
             if data.get(key, "") == "":
                 data.pop(key)
@@ -1723,73 +1742,12 @@ class DataExtraction:
                 new_data[key] = value
             new_data_list.append(new_data)
         extract_data_info["data"] = new_data_list
-        if page_text is not None and len(page_text) > 0:
-            try:
-                self.set_datapoint_feature_properties(new_data_list, page_text, page_num)
-            except Exception as e:
-                logger.error(f"Error in setting datapoint feature properties: {e}")
+        # if page_text is not None and len(page_text) > 0:
+        #     try:
+        #         self.set_datapoint_feature_properties(new_data_list, page_text, page_num)
+        #     except Exception as e:
+        #         logger.error(f"Error in setting datapoint feature properties: {e}")
         return extract_data_info
-
-    def set_datapoint_feature_properties(self, data_list: list, page_text: str, page_num: int) -> None:
-        for feature, properties in self.special_datapoint_feature_config.items():
-            if self.special_datapoint_feature.get(feature, {}).get("page_index", None) is not None:
-                continue
-            provider_ids = properties.get("provider_ids", [])
-            if len(provider_ids) > 0:
-                is_current_provider = False
-                doc_provider_list = self.document_mapping_info_df["ProviderId"].unique().tolist()
-                if len(doc_provider_list) > 0:
-                    for provider in provider_ids:
-                        if provider in doc_provider_list:
-                            is_current_provider = True
-                            break
-                if not is_current_provider:
-                    continue
-            detail_list = properties.get("details", [])
-            if len(detail_list) == 0:
-                continue
-            set_feature_property = False
-            for detail in detail_list:
-                regex_text_list = detail.get("regex_text", [])
-                if len(regex_text_list) == 0:
-                    continue
-                effective_datapoints = detail.get("effective_datapoints", [])
-                if len(effective_datapoints) == 0:
-                    continue
-                exclude_datapoints = detail.get("exclude_datapoints", [])
-
-                exist_effective_datapoints = False
-                exist_exclude_datapoints = False
-                for data_item in data_list:
-                    datapoints = [datapoint for datapoint in list(data_item.keys())
-                                  if datapoint in effective_datapoints]
-                    if len(datapoints) > 0:
-                        exist_effective_datapoints = True
-                    datapoints = [datapoint for datapoint in list(data_item.keys())
-                                  if datapoint in exclude_datapoints]
-                    if len(datapoints) > 0:
-                        exist_exclude_datapoints = True
-                    if exist_effective_datapoints and exist_exclude_datapoints:
-                        break
-
-                if not exist_effective_datapoints:
-                    continue
-                if exist_exclude_datapoints:
-                    continue
-                found_regex_text = False
-                for regex_text in regex_text_list:
-                    regex_search = re.search(regex_text, page_text, re.IGNORECASE)
-                    if regex_search is not None:
-                        found_regex_text = True
-                        break
-                if found_regex_text:
-                    if self.special_datapoint_feature[feature].get("page_index", None) is None:
-                        self.special_datapoint_feature[feature]["page_index"] = []
-                    self.special_datapoint_feature[feature]["datapoint"] = effective_datapoints[0]
-                    self.special_datapoint_feature[feature]["page_index"].append(page_num)
-                    set_feature_property = True
-            if set_feature_property:
-                break

     def split_multi_share_name(self, raw_share_name: str) -> list:
         """
@@ -1836,25 +1794,25 @@ class DataExtraction:
             fund_name = f"{last_fund} {fund_feature}"
         return fund_name

-    def check_fund_name_as_share(self, fund_name: str) -> bool:
-        """
-        Check if the fund name is the same as share name
-        """
-        if len(fund_name) == 0 == 0:
-            return False
-        share_name_list = self.document_mapping_info_df["ShareClassName"].unique().tolist()
-        if len(share_name_list) == 0:
-            return False
-        max_similarity_name, max_similarity = get_most_similar_name(
-            text=fund_name,
-            name_list=share_name_list,
-            share_name=None,
-            fund_name=None,
-            matching_type="share",
-            process_cache=None)
-        if max_similarity >= 0.8:
-            return True
-        return False
+    # def check_fund_name_as_share(self, fund_name: str) -> bool:
+    #     """
+    #     Check if the fund name is the same as share name
+    #     """
+    #     if len(fund_name) == 0 == 0:
+    #         return False
+    #     share_name_list = self.document_mapping_info_df["ShareClassName"].unique().tolist()
+    #     if len(share_name_list) == 0:
+    #         return False
+    #     max_similarity_name, max_similarity = get_most_similar_name(
+    #         text=fund_name,
+    #         name_list=share_name_list,
+    #         share_name=None,
+    #         fund_name=None,
+    #         matching_type="share",
+    #         process_cache=None)
+    #     if max_similarity >= 0.8:
+    #         return True
+    #     return False

     def get_datapoints_by_page_num(self, page_num: int) -> list:
         datapoints = []
@@ -2165,13 +2123,6 @@ class DataExtraction:
         for datapoint in datapoints:
             investment_level = self.datapoint_level_config.get(datapoint, "")
             if investment_level == "fund_level":
-                # fund_level_example_list = output_requirement.get("fund_level", [])
-                # for example in fund_level_example_list:
-                #     try:
-                #         sub_example_list = json.loads(example)
-                #     except:
-                #         sub_example_list = json_repair.loads(example)
-                #     example_list.extend(sub_example_list)
                 fund_datapoint_value_example[datapoint] = fund_level_config.get(
                     f"{datapoint}_value", []
                 )
@@ -2228,131 +2179,4 @@ class DataExtraction:
         instructions.append("Answer:\n")

         instructions_text = "".join(instructions)
         return instructions_text
-
-    # def chat_by_split_context(self,
-    #                           page_text: str,
-    #                           page_datapoints: list,
-    #                           need_exclude: bool,
-    #                           exclude_data: list) -> list:
-    #     """
-    #     If occur error, split the context to two parts and try to get data from the two parts
-    #     Relevant document: 503194284, page index 147
-    #     """
-    #     try:
-    #         logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
-    #         split_context = re.split(r"\n", page_text)
-    #         split_context = [text.strip() for text in split_context
-    #                          if len(text.strip()) > 0]
-    #         if len(split_context) < 10:
-    #             return {"data": []}
-
-    #         split_context_len = len(split_context)
-    #         top_10_context = split_context[:10]
-    #         rest_context = split_context[10:]
-    #         header = "\n".join(top_10_context)
-    #         half_len = split_context_len // 2
-    #         # the member of half_len should not start with number
-    #         # reverse iterate the list by half_len
-    #         half_len_list = [i for i in range(half_len)]
-
-    #         fund_name_line = ""
-    #         half_line = rest_context[half_len].strip()
-    #         max_similarity_fund_name, max_similarity = get_most_similar_name(
-    #             half_line, self.provider_fund_name_list, matching_type="fund"
-    #         )
-    #         if max_similarity < 0.2:
-    #             # get the fund name line text from the first half
-    #             for index in reversed(half_len_list):
-    #                 line_text = rest_context[index].strip()
-    #                 if len(line_text) == 0:
-    #                     continue
-    #                 line_text_split = line_text.split()
-    #                 if len(line_text_split) < 3:
-    #                     continue
-    #                 first_word = line_text_split[0]
-    #                 if first_word.lower() == "class":
-    #                     continue
-
-    #                 max_similarity_fund_name, max_similarity = get_most_similar_name(
-    #                     line_text, self.provider_fund_name_list, matching_type="fund"
-    #                 )
-    #                 if max_similarity >= 0.2:
-    #                     fund_name_line = line_text
-    #                     break
-    #         else:
-    #             fund_name_line = half_line
-    #             half_len += 1
-    #         if fund_name_line == "":
-    #             return {"data": []}
-
-    #         logger.info(f"Split first part from 0 to {half_len}")
-    #         split_first_part = "\n".join(split_context[:half_len])
-    #         first_part = '\n'.join(split_first_part)
-    #         first_instructions = self.get_instructions_by_datapoints(
-    #             first_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
-    #         )
-    #         response, with_error = chat(
-    #             first_instructions, response_format={"type": "json_object"}
-    #         )
-    #         first_part_data = {"data": []}
-    #         if not with_error:
-    #             try:
-    #                 first_part_data = json.loads(response)
-    #             except:
-    #                 first_part_data = json_repair.loads(response)
-
-    #         logger.info(f"Split second part from {half_len} to {split_context_len}")
-    #         split_second_part = "\n".join(split_context[half_len:])
-    #         second_part = header + "\n" + fund_name_line + "\n" + split_second_part
-    #         second_instructions = self.get_instructions_by_datapoints(
-    #             second_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
-    #         )
-    #         response, with_error = chat(
-    #             second_instructions, response_format={"type": "json_object"}
-    #         )
-    #         second_part_data = {"data": []}
-    #         if not with_error:
-    #             try:
-    #                 second_part_data = json.loads(response)
-    #             except:
-    #                 second_part_data = json_repair.loads(response)
-
-    #         first_part_data_list = first_part_data.get("data", [])
-    #         logger.info(f"First part data count: {len(first_part_data_list)}")
-    #         second_part_data_list = second_part_data.get("data", [])
-    #         logger.info(f"Second part data count: {len(second_part_data_list)}")
-    #         for first_data in first_part_data_list:
-    #             if first_data in second_part_data_list:
-    #                 second_part_data_list.remove(first_data)
-    #             else:
-    #                 # if the first part data is with same fund name and share name,
-    #                 # remove the second part data
-    #                 first_data_dp = [key for key in list(first_data.keys())
-    #                                  if key not in ["fund name", "share name"]]
-    #                 # order the data points
-    #                 first_data_dp.sort()
-    #                 first_fund_name = first_data.get("fund name", "")
-    #                 first_share_name = first_data.get("share name", "")
-    #                 if len(first_fund_name) > 0 and len(first_share_name) > 0:
-    #                     remove_second_list = []
-    #                     for second_data in second_part_data_list:
-    #                         second_fund_name = second_data.get("fund name", "")
-    #                         second_share_name = second_data.get("share name", "")
-    #                         if first_fund_name == second_fund_name and \
-    #                             first_share_name == second_share_name:
-    #                             second_data_dp = [key for key in list(second_data.keys())
-    #                                               if key not in ["fund name", "share name"]]
-    #                             second_data_dp.sort()
-    #                             if first_data_dp == second_data_dp:
-    #                                 remove_second_list.append(second_data)
-    #                     for remove_second in remove_second_list:
-    #                         if remove_second in second_part_data_list:
-    #                             second_part_data_list.remove(remove_second)
-
-    #         data_list = first_part_data_list + second_part_data_list
-    #         extract_data = {"data": data_list}
-    #         return extract_data
-    #     except Exception as e:
-    #         logger.error(f"Error in split context: {e}")
-    #         return {"data": []}
@@ -15,7 +15,6 @@ class FilterPages:
         self,
         doc_id: str,
         pdf_file: str,
-        document_mapping_info_df: pd.DataFrame,
         doc_source: str = "emea_ar",
         output_pdf_text_folder: str = r"/data/emea_ar/output/pdf_text/",
     ) -> None:
@@ -32,10 +31,7 @@ class FilterPages:
         else:
             self.apply_pdf2html = False
         self.page_text_dict = self.get_pdf_page_text_dict()
-        if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
-            self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
-        else:
-            self.document_mapping_info_df = document_mapping_info_df
         self.get_configuration_from_file()
         self.doc_info = self.get_doc_info()
         self.datapoint_config, self.datapoint_exclude_config = (
@@ -138,7 +134,7 @@ class FilterPages:
             self.datapoint_type_config = json.load(file)

     def get_doc_info(self) -> dict:
-        if len(self.document_mapping_info_df) == 0:
+        if self.doc_source == "emea_ar":
             return {
                 "effective_date": None,
                 "document_type": "ar",
@@ -146,22 +142,16 @@ class FilterPages:
                 "language": "english",
                 "domicile": "LUX",
             }
-        effective_date = self.document_mapping_info_df["EffectiveDate"].iloc[0]
-        document_type = self.document_mapping_info_df["DocumentType"].iloc[0]
-        if document_type in [4, 5] or self.doc_source == "emea_ar":
-            document_type = "ar"
-        elif document_type == 1 or self.doc_source == "aus_prospectus":
-            document_type = "prospectus"
-        language_id = self.document_mapping_info_df["Language"].iloc[0]
-        language = self.language_config.get(language_id, None)
-        domicile = self.document_mapping_info_df["Domicile"].iloc[0]
-        return {
-            "effective_date": effective_date,
-            "document_type": document_type,
-            "language_id": language_id,
-            "language": language,
-            "domicile": domicile,
-        }
+        elif self.doc_source == "aus_prospectus":
+            return {
+                "effective_date": None,
+                "document_type": "prospectus",
+                "language_id": "0L00000122",
+                "language": "english",
+                "domicile": "AUS",
+            }
+        else:
+            raise ValueError(f"Invalid doc_source: {self.doc_source}")

     def get_datapoint_config(self) -> dict:
         domicile = self.doc_info.get("domicile", None)
mini_main.py (121 changed lines)
@@ -29,19 +29,17 @@ class EMEA_AR_Parsing:
         pdf_folder: str = r"/data/emea_ar/pdf/",
         output_pdf_text_folder: str = r"/data/emea_ar/output/pdf_text/",
         output_extract_data_folder: str = r"/data/emea_ar/output/extract_data/docs/",
-        output_mapping_data_folder: str = r"/data/emea_ar/output/mapping_data/docs/",
         extract_way: str = "text",
         drilldown_folder: str = r"/data/emea_ar/output/drilldown/",
-        compare_with_provider: bool = True
+        text_model: str = "qwen-plus",
+        image_model: str = "qwen-vl-plus",
     ) -> None:
         self.doc_id = doc_id
         self.doc_source = doc_source
         self.pdf_folder = pdf_folder
         os.makedirs(self.pdf_folder, exist_ok=True)
-        self.compare_with_provider = compare_with_provider

         self.pdf_file = self.download_pdf()
-        self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)

         if extract_way is None or len(extract_way) == 0:
             extract_way = "text"
@@ -64,21 +62,9 @@ class EMEA_AR_Parsing:
         self.output_extract_data_folder = output_extract_data_folder
         os.makedirs(self.output_extract_data_folder, exist_ok=True)
-
-        if output_mapping_data_folder is None or len(output_mapping_data_folder) == 0:
-            output_mapping_data_folder = r"/data/emea_ar/output/mapping_data/docs/"
-        if not output_mapping_data_folder.endswith("/"):
-            output_mapping_data_folder = f"{output_mapping_data_folder}/"
-        if extract_way is not None and len(extract_way) > 0:
-            output_mapping_data_folder = (
-                f"{output_mapping_data_folder}by_{extract_way}/"
-            )
-        self.output_mapping_data_folder = output_mapping_data_folder
-        os.makedirs(self.output_mapping_data_folder, exist_ok=True)

         self.filter_pages = FilterPages(
             self.doc_id,
             self.pdf_file,
-            self.document_mapping_info_df,
             self.doc_source,
             output_pdf_text_folder,
         )
@@ -100,6 +86,8 @@ class EMEA_AR_Parsing:
             self.apply_drilldown = misc_config.get("apply_drilldown", False)
         else:
             self.apply_drilldown = False
+        self.text_model = text_model
+        self.image_model = image_model

     def download_pdf(self) -> str:
         pdf_file = download_pdf_from_documents_warehouse(self.pdf_folder, self.doc_id)
@@ -144,9 +132,10 @@ class EMEA_AR_Parsing:
                 self.page_text_dict,
                 self.datapoint_page_info,
                 self.datapoints,
-                self.document_mapping_info_df,
                 extract_way=self.extract_way,
                 output_image_folder=self.output_extract_image_folder,
+                text_model=self.text_model,
+                image_model=self.image_model,
             )
             data_from_gpt = data_extraction.extract_data()
         except Exception as e:
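Taken together with the constructor hunks above, the model names are now threaded from the entry point through EMEA_AR_Parsing and DataExtraction down to every chat() call. A self-contained stub that mimics only that plumbing (hypothetical names; the real signatures carry many more parameters):

def chat_stub(prompt: str, text_model: str = "qwen-plus",
              image_model: str = "qwen-vl-plus", image_base64: str = None):
    # Stand-in for utils.qwen_utils.chat: the image model is used only when an image is supplied.
    chosen = image_model if image_base64 else text_model
    return {"model": chosen, "response": "{}"}, False

class DataExtractionStub:
    def __init__(self, text_model: str = "qwen-plus", image_model: str = "qwen-vl-plus"):
        self.text_model = text_model
        self.image_model = image_model

    def extract_page(self, prompt: str, image_base64: str = None) -> dict:
        result, with_error = chat_stub(prompt, text_model=self.text_model,
                                       image_model=self.image_model, image_base64=image_base64)
        return result

extraction = DataExtractionStub(text_model="qwen-max", image_model="qwen-vl-plus")
print(extraction.extract_page("demo prompt")["model"])  # -> qwen-max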
@@ -266,70 +255,6 @@ class EMEA_AR_Parsing:
             logger.error(f"Error: {e}")
         return annotation_list
-
-    def mapping_data(self, data_from_gpt: list, re_run: bool = False) -> list:
-        if not re_run:
-            output_data_json_folder = os.path.join(
-                self.output_mapping_data_folder, "json/"
-            )
-            os.makedirs(output_data_json_folder, exist_ok=True)
-            json_file = os.path.join(output_data_json_folder, f"{self.doc_id}.json")
-            if os.path.exists(json_file):
-                logger.info(
-                    f"The fund/ share of this document: {self.doc_id} has been mapped, loading data from {json_file}"
-                )
-                with open(json_file, "r", encoding="utf-8") as f:
-                    doc_mapping_data = json.load(f)
-                if self.doc_source == "aus_prospectus":
-                    output_data_folder_splits = output_data_json_folder.split("output")
-                    if len(output_data_folder_splits) == 2:
-                        merged_data_folder = f'{output_data_folder_splits[0]}output/merged_data/docs/'
-                        os.makedirs(merged_data_folder, exist_ok=True)
-
-                        merged_data_json_folder = os.path.join(merged_data_folder, "json/")
-                        os.makedirs(merged_data_json_folder, exist_ok=True)
-
-                        merged_data_excel_folder = os.path.join(merged_data_folder, "excel/")
-                        os.makedirs(merged_data_excel_folder, exist_ok=True)
-
-                        merged_data_file = os.path.join(merged_data_json_folder, f"merged_{self.doc_id}.json")
-                        if os.path.exists(merged_data_file):
-                            with open(merged_data_file, "r", encoding="utf-8") as f:
-                                merged_data_list = json.load(f)
-                            return merged_data_list
-                        else:
-                            data_mapping = DataMapping(
-                                self.doc_id,
-                                self.datapoints,
-                                data_from_gpt,
-                                self.document_mapping_info_df,
-                                self.output_mapping_data_folder,
-                                self.doc_source,
-                                compare_with_provider=self.compare_with_provider
-                            )
-                            merged_data_list = data_mapping.merge_output_data_aus_prospectus(doc_mapping_data,
-                                                                                             merged_data_json_folder,
-                                                                                             merged_data_excel_folder)
-                            return merged_data_list
-                else:
-                    return doc_mapping_data
-        """
-        doc_id,
-        datapoints: list,
-        raw_document_data_list: list,
-        document_mapping_info_df: pd.DataFrame,
-        output_data_folder: str,
-        """
-        data_mapping = DataMapping(
-            self.doc_id,
-            self.datapoints,
-            data_from_gpt,
-            self.document_mapping_info_df,
-            self.output_mapping_data_folder,
-            self.doc_source,
-            compare_with_provider=self.compare_with_provider
-        )
-        return data_mapping.mapping_raw_data_entrance()

 def filter_pages(doc_id: str, pdf_folder: str, doc_source: str) -> None:
     logger.info(f"Filter EMEA AR PDF pages for doc_id: {doc_id}")
@@ -347,6 +272,8 @@ def extract_data(
     output_data_folder: str,
     extract_way: str = "text",
     re_run: bool = False,
+    text_model: str = "qwen-plus",
+    image_model: str = "qwen-vl-plus",
 ) -> None:
     logger.info(f"Extract EMEA AR data for doc_id: {doc_id}")
     emea_ar_parsing = EMEA_AR_Parsing(
@@ -355,6 +282,8 @@ def extract_data(
         pdf_folder=pdf_folder,
         output_extract_data_folder=output_data_folder,
         extract_way=extract_way,
+        text_model=text_model,
+        image_model=image_model,
     )
     data_from_gpt, annotation_list = emea_ar_parsing.extract_data(re_run)
     return data_from_gpt, annotation_list
@@ -368,6 +297,8 @@ def batch_extract_data(
     extract_way: str = "text",
     special_doc_id_list: list = None,
     re_run: bool = False,
+    text_model: str = "qwen-plus",
+    image_model: str = "qwen-vl-plus",
 ) -> None:
     pdf_files = glob(pdf_folder + "*.pdf")
     doc_list = []
@@ -391,6 +322,8 @@ def batch_extract_data(
             output_data_folder=output_child_folder,
             extract_way=extract_way,
             re_run=re_run,
+            text_model=text_model,
+            image_model=image_model,
         )
         result_list.extend(data_from_gpt)
@@ -421,31 +354,35 @@ def test_translate_pdf():
 if __name__ == "__main__":
     os.environ["SSL_CERT_FILE"] = certifi.where()

-    doc_source = "aus_prospectus"
+    # doc_source = "aus_prospectus"
+    doc_source = "emea_ar"
     re_run = True
     extract_way = "text"
     if doc_source == "aus_prospectus":
-        special_doc_id_list = ["539266874"]
-        pdf_folder: str = r"/data/aus_prospectus/pdf/"
-        output_pdf_text_folder: str = r"/data/aus_prospectus/output/pdf_text/"
+        special_doc_id_list = ["412778803", "539266874"]
+        pdf_folder: str = r"./data/aus_prospectus/pdf/"
+        output_pdf_text_folder: str = r"./data/aus_prospectus/output/pdf_text/"
         output_child_folder: str = (
-            r"/data/aus_prospectus/output/extract_data/docs/"
+            r"./data/aus_prospectus/output/extract_data/docs/"
         )
         output_total_folder: str = (
-            r"/data/aus_prospectus/output/extract_data/total/"
+            r"./data/aus_prospectus/output/extract_data/total/"
         )
     elif doc_source == "emea_ar":
         special_doc_id_list = ["514636993"]
-        pdf_folder: str = r"/data/emea_ar/pdf/"
+        pdf_folder: str = r"./data/emea_ar/pdf/"
         output_child_folder: str = (
-            r"/data/emea_ar/output/extract_data/docs/"
+            r"./data/emea_ar/output/extract_data/docs/"
        )
         output_total_folder: str = (
-            r"/data/emea_ar/output/extract_data/total/"
+            r"./data/emea_ar/output/extract_data/total/"
         )
     else:
         raise ValueError(f"Invalid doc_source: {doc_source}")

+    # text_model = "qwen-plus"
+    text_model = "qwen-max"
+    image_model = "qwen-vl-plus"
     batch_extract_data(
         pdf_folder=pdf_folder,
         doc_source=doc_source,
@@ -454,6 +391,8 @@ if __name__ == "__main__":
         extract_way=extract_way,
         special_doc_id_list=special_doc_id_list,
         re_run=re_run,
+        text_model=text_model,
+        image_model=image_model,
     )
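With the __main__ block switched to doc_source = "emea_ar" and local ./data paths, the demo can be run directly; a short reminder of the assumed setup (the key name comes from the new Qwen utility file below):

# 1. Put ALI_API_KEY_QWEN=<your DashScope key> in a .env file; the new module calls dotenv.load_dotenv().
# 2. Place the test PDFs under ./data/emea_ar/pdf/ (doc id 514636993 is in the default list).
# 3. Run the demo entry point:
#        python mini_main.py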
@@ -0,0 +1,148 @@
+import requests
+import json
+import os
+from bs4 import BeautifulSoup
+import time
+from time import sleep
+from datetime import datetime
+import pytz
+import pandas as pd
+import dashscope
+import dotenv
+import base64
+dotenv.load_dotenv()
+
+
+ali_api_key = os.getenv("ALI_API_KEY_QWEN")
+
+
+def chat(
+    prompt: str,
+    text_model: str = "qwen-plus",
+    image_model: str = "qwen-vl-plus",
+    image_file: str = None,
+    image_base64: str = None,
+    enable_search: bool = False,
+):
+    try:
+        token = 0
+        if (
+            image_base64 is None
+            and image_file is not None
+            and len(image_file) > 0
+            and os.path.exists(image_file)
+        ):
+            image_base64 = encode_image(image_file)
+
+        use_image_model = False
+        if image_base64 is not None and len(image_base64) > 0:
+            use_image_model = True
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"text": prompt},
+                        {
+                            "image": f"data:image/png;base64,{image_base64}",
+                        },
+                    ],
+                }
+            ]
+            count = 0
+            while count < 3:
+                try:
+                    print(f"调用阿里云Qwen模型, 次数: {count + 1}")
+                    response = dashscope.MultiModalConversation.call(
+                        api_key=ali_api_key,
+                        model=image_model,
+                        messages=messages,
+                    )
+                    if response.status_code == 200:
+                        break
+                    else:
+                        print(f"调用阿里云Qwen模型失败: {response.code} {response.message}")
+                        count += 1
+                        sleep(2)
+                except Exception as e:
+                    print(f"调用阿里云Qwen模型失败: {e}")
+                    count += 1
+                    sleep(2)
+            if response.status_code == 200:
+                image_text = (
+                    response.get("output", {})
+                    .get("choices", [])[0]
+                    .get("message", {})
+                    .get("content", "")
+                )
+                temp_image_text = ""
+                if isinstance(image_text, list):
+                    for item in image_text:
+                        if isinstance(item, dict):
+                            temp_image_text += item.get("text", "") + "\n\n"
+                        elif isinstance(item, str):
+                            temp_image_text += item + "\n\n"
+                        else:
+                            pass
+                response_contents = temp_image_text.strip()
+                token = response.get("usage", {}).get("total_tokens", 0)
+            else:
+                response_contents = f"{response.code} {response.message} 无法分析图片"
+                token = 0
+        else:
+            messages = [{"role": "user", "content": prompt}]
+            count = 0
+            while count < 3:
+                try:
+                    print(f"调用阿里云Qwen模型, 次数: {count + 1}")
+                    response = dashscope.Generation.call(
+                        api_key=ali_api_key,
+                        model=text_model,
+                        messages=messages,
+                        enable_search=enable_search,
+                        search_options={"forced_search": enable_search},  # 强制联网搜索
+                        result_format="message",
+                    )
+                    if response.status_code == 200:
+                        break
+                    else:
+                        print(f"调用阿里云Qwen模型失败: {response.code} {response.message}")
+                        count += 1
+                        sleep(2)
+                except Exception as e:
+                    print(f"调用阿里云Qwen模型失败: {e}")
+                    count += 1
+                    sleep(2)
+
+            # 获取response的token
+            if response.status_code == 200:
+                response_contents = (
+                    response.get("output", {})
+                    .get("choices", [])[0]
+                    .get("message", {})
+                    .get("content", "")
+                )
+                token = response.get("usage", {}).get("total_tokens", 0)
+            else:
+                response_contents = f"{response.code} {response.message}"
+                token = 0
+        result = {}
+        if use_image_model:
+            result["model"] = image_model
+        else:
+            result["model"] = text_model
+        result["response"] = response_contents
+        result["prompt_token"] = response.get("usage", {}).get("input_tokens", 0)
+        result["completion_token"] = response.get("usage", {}).get("output_tokens", 0)
+        result["total_token"] = token
+        sleep(2)
+        return result, False
+    except Exception as e:
+        print(f"调用阿里云Qwen模型失败: {e}")
+        return {}, True
+
+
+def encode_image(image_path: str):
+    if image_path is None or len(image_path) == 0 or not os.path.exists(image_path):
+        return None
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode("utf-8")
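For reference, a minimal usage sketch of the new chat() helper (assuming the file lands at utils/qwen_utils.py, matching the import change in the first Python hunk, and that ALI_API_KEY_QWEN is set in the environment). The helper retries each DashScope call up to three times, logs failures, and returns a (result, with_error) tuple:

from utils.qwen_utils import chat  # module path assumed from "from utils.qwen_utils import chat" above

# Text-only call: routed to dashscope.Generation with the text model.
result, with_error = chat(prompt='Answer in JSON: {"answer": "Yes" or "No"}', text_model="qwen-plus")
if not with_error:
    print(result["model"], result["total_token"])  # e.g. qwen-plus 1234
    print(result["response"])                      # raw model output, possibly wrapped in ```json fences

# Supplying an image (file path or base64) switches the call to dashscope.MultiModalConversation
# with the image model instead.
result, with_error = chat(prompt="Extract the fee table from this page.",
                          image_model="qwen-vl-plus", image_file="./page_1.png")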