dc-ml-emea-ar/test_specific_biz_logic.py


import os
import json

from utils.logger import logger
from utils.sql_query_util import query_document_fund_mapping
from core.page_filter import FilterPages
from core.data_extraction import DataExtraction
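

# Standalone smoke test for the EMEA AR extraction pipeline: filter the pages of
# one hard-coded document, build a DataExtraction instance, and re-run validation
# on a previously cached GPT answer for a single page. The document ID and the
# /data/emea_ar/... paths below assume the team's local data layout.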
def test_validate_extraction_data():
    """Re-run validate_data() on the cached GPT answer for one page of one document."""
    document_id = "481482392"
    pdf_file = f"/data/emea_ar/pdf/{document_id}.pdf"
    output_extract_data_child_folder = "/data/emea_ar/output/extract_data/docs/"
    output_extract_data_total_folder = "/data/emea_ar/output/extract_data/total/"  # not used in this test

    # Map the document to its funds, then filter the PDF down to candidate pages.
    document_mapping_info_df = query_document_fund_mapping(document_id, rerun=False)
    filter_pages = FilterPages(document_id, pdf_file, document_mapping_info_df)
    page_text_dict = filter_pages.page_text_dict
    datapoint_page_info, result_details = get_datapoint_page_info(filter_pages)
    datapoints = get_datapoints_from_datapoint_page_info(datapoint_page_info)

    data_extraction = DataExtraction(
        doc_id=document_id,
        pdf_file=pdf_file,
        output_data_folder=output_extract_data_child_folder,
        page_text_dict=page_text_dict,
        datapoint_page_info=datapoint_page_info,
        datapoints=datapoints,
        document_mapping_info_df=document_mapping_info_df,
        extract_way="text",
        output_image_folder=None,
    )

    output_data_json_folder = os.path.join(
        "/data/emea_ar/output/extract_data/docs/by_text/", "json/"
    )
    os.makedirs(output_data_json_folder, exist_ok=True)
    json_file = os.path.join(output_data_json_folder, f"{document_id}.json")

    data_from_gpt = None
    if os.path.exists(json_file):
        logger.info(
            f"The document: {document_id} has been parsed, loading data from {json_file}"
        )
        with open(json_file, "r", encoding="utf-8") as f:
            data_from_gpt = json.load(f)
        # Re-validate only the cached answer for page 451.
        for extract_data in data_from_gpt:
            page_index = extract_data["page_index"]
            if page_index == 451:
                logger.info(f"Page index: {page_index}")
                # raw_answer is the GPT response stored as a JSON string.
                raw_answer = extract_data["raw_answer"]
                raw_answer_json = json.loads(raw_answer)
                extract_data_info = data_extraction.validate_data(raw_answer_json)
                print(extract_data_info)
    else:
        logger.warning(f"No cached extraction found at {json_file}; nothing to validate.")
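

# A minimal sketch, not part of the original test: the same validation applied to
# every cached page instead of only page 451. `validate_all_pages` is a hypothetical
# helper; it assumes the cached JSON keeps the "page_index" / "raw_answer" layout
# used above and that validate_data() accepts the parsed answer.
def validate_all_pages(data_extraction, data_from_gpt) -> dict:
    """Return {page_index: validated_info} for every cached GPT answer."""
    results = {}
    for extract_data in data_from_gpt or []:
        raw_answer_json = json.loads(extract_data["raw_answer"])
        results[extract_data["page_index"]] = data_extraction.validate_data(raw_answer_json)
    return results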


def get_datapoint_page_info(filter_pages) -> tuple:
    """Run the page-filter job and return (datapoint_page_info, result_details)."""
    datapoint_page_info, result_details = filter_pages.start_job()
    return datapoint_page_info, result_details


def get_datapoints_from_datapoint_page_info(datapoint_page_info) -> list:
    """Datapoint names are the keys of datapoint_page_info, minus the doc_id entry."""
    datapoints = list(datapoint_page_info.keys())
    if "doc_id" in datapoints:
        datapoints.remove("doc_id")
    return datapoints


if __name__ == "__main__":
    test_validate_extraction_data()