110 lines
4.5 KiB
Python
110 lines
4.5 KiB
Python
import pandas as pd
|
|
import os
|
|
import tqdm
|
|
import json_repair
|
|
import json
|
|
from glob import glob
|
|
import fitz
|
|
import re
|
|
import time
|
|
import traceback
|
|
|
|
from utils.logger import logger
|
|
from utils.pdf_download import download_pdf_from_documents_warehouse
|
|
from utils.pdf_util import PDFUtil
|
|
from utils.gpt_utils import chat
|
|
|
|
|
|
class PDFTableExtraction:
    """
    Iterate PDF pages
    Extract tables from PDF pages
    Save these tables as markdown files

    For every page the page text is substituted into a prompt template, the
    prompt is sent through ``chat`` and the fenced JSON block of the reply
    is parsed. Three kinds of artefacts are written below ``output_folder``:

    * ``pdf_table_prompts/``  - prompt plus raw reply, one ``.txt`` per page.
      Its existence also marks the page as processed, so reruns skip it.
    * ``pdf_table_json/``     - the parsed JSON payload, one file per page.
    * ``pdf_table_markdown/`` - one ``.md`` file per extracted table.
    """

    # Matches a fenced json block in the model reply. The capture is greedy,
    # so it runs up to the LAST closing fence found in the reply.
    _JSON_BLOCK_PATTERN = re.compile(r'\`\`\`json([\s\S]*)\`\`\`')

    def __init__(self,
                 pdf_file: str,
                 output_folder: str) -> None:
        """Remember the input PDF and create all output folders.

        Args:
            pdf_file: Path of the PDF to process. It is validated later, in
                ``extract_tables``.
            output_folder: Root folder for all generated files; it is
                created when missing.
        """
        self.pdf_file = pdf_file
        # Tolerate a missing path here so the validation in extract_tables()
        # can report it instead of os.path.basename() raising a TypeError.
        self.pdf_file_name = os.path.basename(pdf_file) if pdf_file else ''
        self.table_extraction_prompts = self.get_table_extraction_prompts()

        self.output_folder = output_folder
        os.makedirs(output_folder, exist_ok=True)

        self.prompts_output_folder = os.path.join(output_folder, 'pdf_table_prompts/')
        os.makedirs(self.prompts_output_folder, exist_ok=True)

        self.json_output_folder = os.path.join(output_folder, 'pdf_table_json/')
        os.makedirs(self.json_output_folder, exist_ok=True)

        self.table_md_output_folder = os.path.join(output_folder, 'pdf_table_markdown/')
        os.makedirs(self.table_md_output_folder, exist_ok=True)

    def get_table_extraction_prompts(self):
        """Return the prompt template text read from the instructions file.

        The path is relative, so it is resolved against the current working
        directory of the process.

        Raises:
            OSError: If the instructions file cannot be opened.
        """
        instructions_file = r'./instructions/table_extraction_prompts.txt'
        with open(instructions_file, 'r', encoding='utf-8') as file:
            return file.read()

    def _pdf_stem(self) -> str:
        """Return the PDF file name without its trailing ``.pdf`` extension.

        Only the final extension is removed (case-insensitively); a ``.pdf``
        that appears in the middle of the name is kept.
        """
        stem, extension = os.path.splitext(self.pdf_file_name)
        return stem if extension.lower() == '.pdf' else self.pdf_file_name

    def extract_tables(self):
        """Extract the tables of every page of the PDF.

        Problems are logged rather than raised: an invalid path, a failed
        text extraction and a failure on a single page never propagate to
        the caller. A failing page does not stop the remaining pages.
        """
        try:
            if not self.pdf_file or not os.path.exists(self.pdf_file):
                logger.error(f"Invalid pdf_file: {self.pdf_file}")
                return
            logger.info(f"Start processing {self.pdf_file}")
            pdf_util = PDFUtil(self.pdf_file)
            success, _text, page_text_dict = pdf_util.extract_text(output_folder=self.output_folder)
            if success:
                logger.info(f"Successfully extracted text from {self.pdf_file}")
            else:
                logger.error(f"Failed to extract text from {self.pdf_file}")

            # NOTE(review): pages returned alongside a failure flag are still
            # processed - confirm that PDFUtil returns no pages on failure.
            for page_num, page_text in (page_text_dict or {}).items():
                try:
                    self.extract_tables_from_page(page_text, page_num)
                except Exception as e:
                    traceback.print_exc()
                    logger.error(f"Error in extracting tables from page {page_num}: {str(e)}")
        except Exception as e:
            logger.error(f"Error in extracting PDF tables: {str(e)}")

    def extract_tables_from_page(self, page_text: str, page_num: int):
        """Ask the model for the tables of one page and persist the result.

        The page is skipped when its prompt/response file already exists.
        That file is written last, so a page that fails midway is retried
        on the next run.

        Args:
            page_text: Text of the page, substituted for ``{page_text}`` in
                the prompt template.
            page_num: Page identifier used in log messages and file names.
        """
        pure_pdf_name = self._pdf_stem()
        prompts_response_file = os.path.join(self.prompts_output_folder, f'{pure_pdf_name}_{page_num}.txt')
        if os.path.exists(prompts_response_file):
            logger.info(f"Prompts response file already exists: {prompts_response_file}")
            return

        table_extraction_prompts = self.table_extraction_prompts.replace(r'{page_text}', page_text)
        response, with_error = chat(table_extraction_prompts)
        if with_error:
            logger.error(f"Error in extracting tables from page {page_num} of {self.pdf_file_name}")
            return

        json_response = self._JSON_BLOCK_PATTERN.search(response)
        if json_response is None:
            logger.info(f"Can't extract tables from page {page_num} of {self.pdf_file_name}")
            return

        table_json_text = json_response.group(1)
        try:
            table_data = json.loads(table_json_text)
        except ValueError:
            # json.JSONDecodeError is a ValueError: the reply is not strict
            # JSON, so fall back to the lenient parser.
            table_data = json_repair.loads(table_json_text)
        self.save_table_data(table_data, page_num)

        prompts_response = f'{table_extraction_prompts}\n\n{response}'
        with open(prompts_response_file, 'w', encoding='utf-8') as file:
            file.write(prompts_response)

    def save_table_data(self, table_data: dict, page_num: int):
        """Write the parsed payload as JSON and each table as a markdown file.

        Args:
            table_data: Parsed model output; the tables are read from its
                ``tables`` key. The payload is always dumped to the JSON
                file, but markdown files are only produced when it is a dict
                whose ``tables`` value is a list.
            page_num: Page identifier used in the output file names.
        """
        pdf_pure_name = self._pdf_stem()
        json_output_file = os.path.join(self.json_output_folder, f'{pdf_pure_name}_{page_num}.json')
        with open(json_output_file, 'w', encoding='utf-8') as file:
            file.write(json.dumps(table_data, indent=4))

        # The lenient JSON fallback may yield something other than a dict.
        table_list = table_data.get('tables') if isinstance(table_data, dict) else None
        if table_list is None:
            table_list = []
        if not isinstance(table_list, list):
            logger.warning(f"Unexpected tables payload on page {page_num} of {self.pdf_file_name}")
            return

        for table_num, table in enumerate(table_list):
            if not isinstance(table, str):
                # Keep the original index in file names so numbering stays
                # aligned with the JSON payload.
                logger.warning(f"Skip non-text table {table_num} on page {page_num} of {self.pdf_file_name}")
                continue
            table_md_file = os.path.join(self.table_md_output_folder, f'{pdf_pure_name}_{page_num}_{table_num}.md')
            # Collapse runs of newlines so blank lines do not break the table.
            table = re.sub(r'(\n)+', '\n', table)
            with open(table_md_file, 'w', encoding='utf-8') as file:
                file.write(table)