add mini_main.py

This commit is contained in:
blade 2025-11-10 16:55:55 +08:00
parent 37cf06a394
commit 255752c848
4 changed files with 585 additions and 48 deletions

1
.gitignore vendored

@@ -16,3 +16,4 @@
/performance.ipynb
/sample_documents/special_cases.txt
/aus-prospectus/
/output/log/*.log

459
mini_main.py Normal file

@@ -0,0 +1,459 @@
import os
import json
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm
import time
import fitz
import re
from io import BytesIO
from traceback import print_exc
from utils.logger import logger
from utils.pdf_download import download_pdf_from_documents_warehouse
from utils.sql_query_util import query_document_fund_mapping
from utils.pdf_util import PDFUtil
from utils.biz_utils import add_slash_to_text_as_regex
from core.page_filter import FilterPages
from core.data_extraction import DataExtraction
from core.data_mapping import DataMapping
from core.auz_nz.hybrid_solution_script import api_for_fund_matching_call
from core.metrics import Metrics
import certifi
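# Pipeline for one document: download the PDF, filter the pages that may contain the
# target data points, extract values with GPT, map them to funds/share classes, and
# optionally drill the values back down to PDF coordinates.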
class EMEA_AR_Parsing:
def __init__(
self,
doc_id: str,
doc_source: str = "emea_ar",
pdf_folder: str = r"/data/emea_ar/pdf/",
output_pdf_text_folder: str = r"/data/emea_ar/output/pdf_text/",
output_extract_data_folder: str = r"/data/emea_ar/output/extract_data/docs/",
output_mapping_data_folder: str = r"/data/emea_ar/output/mapping_data/docs/",
extract_way: str = "text",
drilldown_folder: str = r"/data/emea_ar/output/drilldown/",
compare_with_provider: bool = True
) -> None:
self.doc_id = doc_id
self.doc_source = doc_source
self.pdf_folder = pdf_folder
os.makedirs(self.pdf_folder, exist_ok=True)
self.compare_with_provider = compare_with_provider
self.pdf_file = self.download_pdf()
self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
if extract_way is None or len(extract_way) == 0:
extract_way = "text"
self.extract_way = extract_way
self.output_extract_image_folder = None
if self.extract_way == "image":
self.output_extract_image_folder = (
r"/data/emea_ar/output/extract_data/images/"
)
os.makedirs(self.output_extract_image_folder, exist_ok=True)
if output_extract_data_folder is None or len(output_extract_data_folder) == 0:
output_extract_data_folder = r"/data/emea_ar/output/extract_data/docs/"
if not output_extract_data_folder.endswith("/"):
output_extract_data_folder = f"{output_extract_data_folder}/"
if extract_way is not None and len(extract_way) > 0:
output_extract_data_folder = (
f"{output_extract_data_folder}by_{extract_way}/"
)
self.output_extract_data_folder = output_extract_data_folder
os.makedirs(self.output_extract_data_folder, exist_ok=True)
if output_mapping_data_folder is None or len(output_mapping_data_folder) == 0:
output_mapping_data_folder = r"/data/emea_ar/output/mapping_data/docs/"
if not output_mapping_data_folder.endswith("/"):
output_mapping_data_folder = f"{output_mapping_data_folder}/"
if extract_way is not None and len(extract_way) > 0:
output_mapping_data_folder = (
f"{output_mapping_data_folder}by_{extract_way}/"
)
self.output_mapping_data_folder = output_mapping_data_folder
os.makedirs(self.output_mapping_data_folder, exist_ok=True)
self.filter_pages = FilterPages(
self.doc_id,
self.pdf_file,
self.document_mapping_info_df,
self.doc_source,
output_pdf_text_folder,
)
self.page_text_dict = self.filter_pages.page_text_dict
self.datapoint_page_info, self.result_details = self.get_datapoint_page_info()
self.datapoints = self.get_datapoints_from_datapoint_page_info()
if drilldown_folder is None or len(drilldown_folder) == 0:
drilldown_folder = r"/data/emea_ar/output/drilldown/"
os.makedirs(drilldown_folder, exist_ok=True)
self.drilldown_folder = drilldown_folder
misc_config_file = os.path.join(
f"./configuration/{doc_source}/", "misc_config.json"
)
if os.path.exists(misc_config_file):
with open(misc_config_file, "r", encoding="utf-8") as f:
misc_config = json.load(f)
self.apply_drilldown = misc_config.get("apply_drilldown", False)
else:
self.apply_drilldown = False
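# Fetch the PDF for this doc_id from the documents warehouse into pdf_folder.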
def download_pdf(self) -> str:
pdf_file = download_pdf_from_documents_warehouse(self.pdf_folder, self.doc_id)
return pdf_file
def get_datapoint_page_info(self) -> tuple:
datapoint_page_info, result_details = self.filter_pages.start_job()
return datapoint_page_info, result_details
def get_datapoints_from_datapoint_page_info(self) -> list:
datapoints = list(self.datapoint_page_info.keys())
if "doc_id" in datapoints:
datapoints.remove("doc_id")
return datapoints
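# Extract data points from the filtered pages; unless re_run is True, reuse a
# previously saved JSON result for this doc_id.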
def extract_data(
self,
re_run: bool = False,
) -> tuple:
found_data = False
if not re_run:
output_data_json_folder = os.path.join(
self.output_extract_data_folder, "json/"
)
os.makedirs(output_data_json_folder, exist_ok=True)
json_file = os.path.join(output_data_json_folder, f"{self.doc_id}.json")
if os.path.exists(json_file):
logger.info(
f"The document: {self.doc_id} has been parsed, loading data from {json_file}"
)
with open(json_file, "r", encoding="utf-8") as f:
data_from_gpt = json.load(f)
found_data = True
if not found_data:
try:
data_extraction = DataExtraction(
self.doc_source,
self.doc_id,
self.pdf_file,
self.output_extract_data_folder,
self.page_text_dict,
self.datapoint_page_info,
self.datapoints,
self.document_mapping_info_df,
extract_way=self.extract_way,
output_image_folder=self.output_extract_image_folder,
)
data_from_gpt = data_extraction.extract_data()
except Exception as e:
logger.error(f"Error: {e}")
print_exc()
data_from_gpt = {"data": []}
# Drilldown data to relevant PDF document
annotation_list = []
if self.apply_drilldown:
try:
annotation_list = self.drilldown_pdf_document(data_from_gpt)
except Exception as e:
logger.error(f"Error: {e}")
return data_from_gpt, annotation_list
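# Locate each extracted value and reported name on its source page via
# PDFUtil.batch_drilldown and save the resulting annotations as Excel/JSON
# under drilldown_folder.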
def drilldown_pdf_document(self, data_from_gpt: list) -> list:
logger.info(f"Drilldown PDF document for doc_id: {self.doc_id}")
pdf_util = PDFUtil(self.pdf_file)
drilldown_data_list = []
for data in data_from_gpt:
doc_id = str(data.get("doc_id", ""))
page_index = data.get("page_index", -1)
if page_index == -1:
continue
extract_data_list = data.get("extract_data", {}).get("data", [])
dp_reported_name_dict = data.get("extract_data", {}).get(
"dp_reported_name", {}
)
highlighted_value_list = []
for extract_data in extract_data_list:
for data_point, value in extract_data.items():
if value in highlighted_value_list:
continue
if data_point in ["ter", "ogc", "performance_fee"]:
continue
drilldown_data = {
"doc_id": doc_id,
"page_index": page_index,
"data_point": data_point,
"parent_text_block": None,
"value": value,
"annotation_attribute": {},
}
drilldown_data_list.append(drilldown_data)
highlighted_value_list.append(value)
for data_point, reported_name in dp_reported_name_dict.items():
if reported_name in highlighted_value_list:
continue
data_point = f"{data_point}_reported_name"
drilldown_data = {
"doc_id": doc_id,
"page_index": page_index,
"data_point": data_point,
"parent_text_block": None,
"value": reported_name,
"annotation_attribute": {},
}
drilldown_data_list.append(drilldown_data)
highlighted_value_list.append(reported_name)
drilldown_result = pdf_util.batch_drilldown(
drilldown_data_list=drilldown_data_list,
output_pdf_folder=self.drilldown_folder,
)
annotation_list = []
if len(drilldown_result) > 0:
logger.info(f"Drilldown PDF document for doc_id: {doc_id} successfully")
annotation_list = drilldown_result.get("annotation_list", [])
for annotation in annotation_list:
annotation["doc_id"] = doc_id
if self.drilldown_folder is not None and len(self.drilldown_folder) > 0:
drilldown_data_folder = os.path.join(self.drilldown_folder, "data/")
os.makedirs(drilldown_data_folder, exist_ok=True)
drilldown_file = os.path.join(
drilldown_data_folder, f"{doc_id}_drilldown.xlsx"
)
drilldown_source_df = pd.DataFrame(drilldown_data_list)
annotation_list_df = pd.DataFrame(annotation_list)
# set drilldown_result_df column order as doc_id, pdf_file, page_index,
# data_point, value, matching_val_area, normalized_bbox
try:
annotation_list_df = annotation_list_df[
[
"doc_id",
"pdf_file",
"page_index",
"data_point",
"value",
"matching_val_area",
"normalized_bbox",
]
]
except Exception as e:
logger.error(f"Error: {e}")
logger.info(f"Writing drilldown data to {drilldown_file}")
try:
with pd.ExcelWriter(drilldown_file) as writer:
drilldown_source_df.to_excel(
writer, index=False, sheet_name="source_data"
)
annotation_list_df.to_excel(
writer, index=False, sheet_name="drilldown_data"
)
except Exception as e:
logger.error(f"Error: {e}")
annotation_list = annotation_list_df.to_dict(orient="records")
try:
drilldown_json_file = os.path.join(
drilldown_data_folder, f"{doc_id}_drilldown.json"
)
with open(drilldown_json_file, "w", encoding="utf-8") as f:
json.dump(annotation_list, f, ensure_ascii=False, indent=4)
except Exception as e:
logger.error(f"Error: {e}")
return annotation_list
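# Map the raw extracted values to funds/share classes; unless re_run is True, reuse
# saved mapping output, and for aus_prospectus also build/merge the merged_data output.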
def mapping_data(self, data_from_gpt: list, re_run: bool = False) -> list:
if not re_run:
output_data_json_folder = os.path.join(
self.output_mapping_data_folder, "json/"
)
os.makedirs(output_data_json_folder, exist_ok=True)
json_file = os.path.join(output_data_json_folder, f"{self.doc_id}.json")
if os.path.exists(json_file):
logger.info(
f"The fund/ share of this document: {self.doc_id} has been mapped, loading data from {json_file}"
)
with open(json_file, "r", encoding="utf-8") as f:
doc_mapping_data = json.load(f)
if self.doc_source == "aus_prospectus":
output_data_folder_splits = output_data_json_folder.split("output")
if len(output_data_folder_splits) == 2:
merged_data_folder = f'{output_data_folder_splits[0]}output/merged_data/docs/'
os.makedirs(merged_data_folder, exist_ok=True)
merged_data_json_folder = os.path.join(merged_data_folder, "json/")
os.makedirs(merged_data_json_folder, exist_ok=True)
merged_data_excel_folder = os.path.join(merged_data_folder, "excel/")
os.makedirs(merged_data_excel_folder, exist_ok=True)
merged_data_file = os.path.join(merged_data_json_folder, f"merged_{self.doc_id}.json")
if os.path.exists(merged_data_file):
with open(merged_data_file, "r", encoding="utf-8") as f:
merged_data_list = json.load(f)
return merged_data_list
else:
data_mapping = DataMapping(
self.doc_id,
self.datapoints,
data_from_gpt,
self.document_mapping_info_df,
self.output_mapping_data_folder,
self.doc_source,
compare_with_provider=self.compare_with_provider
)
merged_data_list = data_mapping.merge_output_data_aus_prospectus(doc_mapping_data,
merged_data_json_folder,
merged_data_excel_folder)
return merged_data_list
else:
return doc_mapping_data
"""
doc_id,
datapoints: list,
raw_document_data_list: list,
document_mapping_info_df: pd.DataFrame,
output_data_folder: str,
"""
data_mapping = DataMapping(
self.doc_id,
self.datapoints,
data_from_gpt,
self.document_mapping_info_df,
self.output_mapping_data_folder,
self.doc_source,
compare_with_provider=self.compare_with_provider
)
return data_mapping.mapping_raw_data_entrance()
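# Module-level helper: run only the page-filtering step for one document.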
def filter_pages(doc_id: str, pdf_folder: str, doc_source: str) -> tuple:
logger.info(f"Filter EMEA AR PDF pages for doc_id: {doc_id}")
emea_ar_parsing = EMEA_AR_Parsing(
doc_id, doc_source=doc_source, pdf_folder=pdf_folder
)
datapoint_page_info, result_details = emea_ar_parsing.get_datapoint_page_info()
return datapoint_page_info, result_details
def extract_data(
doc_id: str,
doc_source: str,
pdf_folder: str,
output_data_folder: str,
extract_way: str = "text",
re_run: bool = False,
) -> tuple:
logger.info(f"Extract EMEA AR data for doc_id: {doc_id}")
emea_ar_parsing = EMEA_AR_Parsing(
doc_id,
doc_source=doc_source,
pdf_folder=pdf_folder,
output_extract_data_folder=output_data_folder,
extract_way=extract_way,
)
data_from_gpt, annotation_list = emea_ar_parsing.extract_data(re_run)
return data_from_gpt, annotation_list
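# Batch driver: run extraction for the doc ids listed in special_doc_id_list whose PDFs
# are present in pdf_folder; without an explicit list the function returns early.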
def batch_extract_data(
pdf_folder: str,
doc_source: str = "emea_ar",
output_child_folder: str = r"/data/emea_ar/output/extract_data/docs/",
output_total_folder: str = r"/data/emea_ar/output/extract_data/total/",
extract_way: str = "text",
special_doc_id_list: list = None,
re_run: bool = False,
) -> None:
pdf_files = glob(pdf_folder + "*.pdf")
doc_list = []
if special_doc_id_list is not None and len(special_doc_id_list) > 0:
doc_list = special_doc_id_list
if len(doc_list) == 0:
logger.info(f"No special doc_id list provided, extracting all documents in {pdf_folder}")
return
result_list = []
for pdf_file in tqdm(pdf_files):
pdf_base_name = os.path.basename(pdf_file)
doc_id = pdf_base_name.split(".")[0]
if doc_list is not None and doc_id not in doc_list:
continue
data_from_gpt, annotation_list = extract_data(
doc_id=doc_id,
doc_source=doc_source,
pdf_folder=pdf_folder,
output_data_folder=output_child_folder,
extract_way=extract_way,
re_run=re_run,
)
result_list.extend(data_from_gpt)
if special_doc_id_list is None or len(special_doc_id_list) == 0:
result_df = pd.DataFrame(result_list)
result_df.reset_index(drop=True, inplace=True)
logger.info(f"Saving the result to {output_total_folder}")
os.makedirs(output_total_folder, exist_ok=True)
time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
output_file = os.path.join(
output_total_folder,
f"extract_data_info_{len(pdf_files)}_documents_{time_stamp}.xlsx",
)
with pd.ExcelWriter(output_file) as writer:
result_df.to_excel(writer, index=False, sheet_name="extract_data_info")
def test_translate_pdf():
from core.data_translate import Translate_PDF
pdf_file = r"/data/emea_ar/pdf/451063582.pdf"
output_folder = r"/data/translate/output/"
translate_pdf = Translate_PDF(pdf_file, output_folder)
translate_pdf.start_job()
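# Entry point: pick a doc_source, point at the matching folders, and run batch
# extraction for the listed doc ids.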
if __name__ == "__main__":
os.environ["SSL_CERT_FILE"] = certifi.where()
doc_source = "aus_prospectus"
re_run = True
extract_way = "text"
if doc_source == "aus_prospectus":
special_doc_id_list = ["539266874"]
pdf_folder: str = r"/data/aus_prospectus/pdf/"
output_pdf_text_folder: str = r"/data/aus_prospectus/output/pdf_text/"
output_child_folder: str = (
r"/data/aus_prospectus/output/extract_data/docs/"
)
output_total_folder: str = (
r"/data/aus_prospectus/output/extract_data/total/"
)
elif doc_source == "emea_ar":
special_doc_id_list = ["514636993"]
pdf_folder: str = r"/data/emea_ar/pdf/"
output_child_folder: str = (
r"/data/emea_ar/output/extract_data/docs/"
)
output_total_folder: str = (
r"/data/emea_ar/output/extract_data/total/"
)
else:
raise ValueError(f"Invalid doc_source: {doc_source}")
batch_extract_data(
pdf_folder=pdf_folder,
doc_source=doc_source,
output_child_folder=output_child_folder,
output_total_folder=output_total_folder,
extract_way=extract_way,
special_doc_id_list=special_doc_id_list,
re_run=re_run,
)

File diff suppressed because one or more lines are too long

77
test_k_shape.py Normal file

@@ -0,0 +1,77 @@
import pandas as pd
import numpy as np
import sys
import os
# Add the project path so crypto_quant modules can be imported
sys.path.append('crypto_quant')
from crypto_quant.core.biz.metrics_calculation import MetricsCalculation
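# Rebuild the features that set_k_shape uses for a single near-flat bar and check them
# against the one-price (yi-zi) thresholds.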
def test_k_shape():
# Create test data for a single near-flat bar
test_data = pd.DataFrame({
'open': [9.3030000000],
'high': [9.3030000000],
'low': [9.3020000000],
'close': [9.3020000000]
})
print("测试数据:")
print(test_data)
print()
# Compute basic features
test_data['high_low_diff'] = test_data['high'] - test_data['low']
test_data['open_close_diff'] = abs(test_data['close'] - test_data['open'])
test_data['open_close_fill'] = test_data['open_close_diff'] / test_data['high_low_diff']
test_data['price_range_ratio'] = test_data['high_low_diff'] / test_data['close'] * 100
print("计算的特征:")
print(f"high_low_diff: {test_data['high_low_diff'].iloc[0]}")
print(f"open_close_diff: {test_data['open_close_diff'].iloc[0]}")
print(f"open_close_fill: {test_data['open_close_fill'].iloc[0]}")
print(f"price_range_ratio: {test_data['price_range_ratio'].iloc[0]}%")
print()
# Check the "one-price" (yi-zi) bar conditions
price_range_ratio = test_data['price_range_ratio'].iloc[0]
open_close_fill = test_data['open_close_fill'].iloc[0]
print("条件检查:")
print(f"price_range_ratio < 0.01: {price_range_ratio < 0.01}")
print(f"open_close_fill > 0.9: {open_close_fill > 0.9}")
print()
# Use the MetricsCalculation class
mc = MetricsCalculation()
# For this test we need a DataFrame with enough rows
# Duplicate the test data to fill the rolling window
extended_data = pd.concat([test_data] * 25, ignore_index=True)
# Run the set_k_shape function
result = mc.set_k_shape(extended_data.copy())
print("分类结果:")
print(f"k_shape: {result['k_shape'].iloc[0]}")
print()
# Detailed analysis of why the bar was not classified as "one-price" (yi-zi)
print("Detailed analysis:")
print(f"Price range ratio: {price_range_ratio:.6f}%")
print(f"Body fill ratio: {open_close_fill:.6f}")
print()
if price_range_ratio < 0.01:
print("✓ Meets the price range ratio < 0.01% condition")
else:
print(f"✗ Does not meet the price range ratio < 0.01% condition (actual: {price_range_ratio:.6f}%)")
if open_close_fill > 0.9:
print("✓ Meets the body fill ratio > 0.9 condition")
else:
print(f"✗ Does not meet the body fill ratio > 0.9 condition (actual: {open_close_fill:.6f})")
if __name__ == "__main__":
test_k_shape()