From 843f588015d33e14283d3424d67244ad59060757 Mon Sep 17 00:00:00 2001 From: Blade He Date: Mon, 26 Aug 2024 11:19:07 -0500 Subject: [PATCH] support chat with image by ChatGPT4o --- .../table_extraction_image_prompts.txt | 18 ++++ playground.py | 72 +++++++++++++ utils/gpt_utils.py | 100 ++++++++++++------ utils/pdf_util.py | 38 ++++++- 4 files changed, 197 insertions(+), 31 deletions(-) create mode 100644 instructions/table_extraction_image_prompts.txt create mode 100644 playground.py diff --git a/instructions/table_extraction_image_prompts.txt b/instructions/table_extraction_image_prompts.txt new file mode 100644 index 0000000..c823582 --- /dev/null +++ b/instructions/table_extraction_image_prompts.txt @@ -0,0 +1,18 @@ +Instructions: +Please read the image carefully. +Answer below questions: +1. Please find the table or tables in the image. +2. Output the table contents as markdown format, it's like: +|name|age|hobby| +|Annie|18|music| +The contents should be exactly precise as the image contents. +3. Please output the results as JSON format, the result member is with legal markdown table format, the example is: +{ +"tables": [" +|name|age|hobby| +|Annie|18|music| +"] +} +4. Only output JSON with tables + +Answer: \ No newline at end of file diff --git a/playground.py b/playground.py new file mode 100644 index 0000000..cfb97ad --- /dev/null +++ b/playground.py @@ -0,0 +1,72 @@ +import os +import json +import base64 +import json_repair +from utils.pdf_util import PDFUtil +from utils.logger import logger +from utils.gpt_utils import chat + + + +def get_base64_pdf_image_list(pdf_file: str, + pdf_page_index_list: list, + output_folder: str=None) -> dict: + if pdf_file is None or pdf_file == "" or not os.path.exists(pdf_file): + logger.error("pdf_file is not provided") + return None + pdf_util = PDFUtil(pdf_file) + if pdf_page_index_list is None or len(pdf_page_index_list) == 0: + pdf_page_index_list = list(range(pdf_util.get_page_count())) + if output_folder is not None and len(output_folder) > 0: + os.makedirs(output_folder, exist_ok=True) + pdf_image_info = pdf_util.extract_images(pdf_page_index_list=pdf_page_index_list, + output_folder=output_folder) + return pdf_image_info + + +def encode_image(image_path: str): + if image_path is None or len(image_path) == 0 or not os.path.exists(image_path): + return None + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + +def chat_with_image(pdf_file: str, + pdf_page_index_list: list, + image_folder: str, + gpt_folder: str): + if pdf_file is None or pdf_file == "" or not os.path.exists(pdf_file): + logger.error("pdf_file is not provided") + return None + pdf_image_info = get_base64_pdf_image_list(pdf_file, pdf_page_index_list, image_folder) + image_instructions_file = r'./instructions/table_extraction_image_prompts.txt' + with open(image_instructions_file, "r", encoding="utf-8") as file: + image_instructions = file.read() + os.makedirs(gpt_folder, exist_ok=True) + pdf_base_name = os.path.basename(pdf_file).replace(".pdf", "") + response_list = {} + for page_index, data in pdf_image_info.items(): + logger.info(f"Processing image in page {page_index}") + image_file = data.get("img_file", None) + image_base64 = data.get("img_base64", None) + response, error = chat(prompt=image_instructions, image_base64=image_base64) + if error: + logger.error(f"Error in processing image in page {page_index}") + continue + try: + response_json = json.loads(response) + except: + response_json = json_repair.loads(response) + response_json_file = os.path.join(gpt_folder, f"{pdf_base_name}_{page_index}.json") + with open(response_json_file, "w", encoding="utf-8") as file: + json.dump(response_json, file, indent=4) + logger.info(f"Response for image in page {page_index}: {response}") + logger.info("Done") + + +if __name__ == "__main__": + pdf_file = r"/data/emea_ar/small_pdf/382366116.pdf" + pdf_page_index_list = [29, 35, 71, 77, 83, 89, 97, 103, 112, 121, 130, 140, 195, 250, 305] + image_output_folder = r"/data/emea_ar/small_pdf_image/" + gpt_output_folder = r"/data/emea_ar/output/gpt_image_response/" + chat_with_image(pdf_file, pdf_page_index_list, image_output_folder, gpt_output_folder) \ No newline at end of file diff --git a/utils/gpt_utils.py b/utils/gpt_utils.py index e248e92..628cd50 100644 --- a/utils/gpt_utils.py +++ b/utils/gpt_utils.py @@ -4,20 +4,26 @@ from openai import AzureOpenAI import openai import os from time import sleep +import base64 import dotenv + # loads .env file with your OPENAI_API_KEY dotenv.load_dotenv() # tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") tokenizer = tiktoken.get_encoding("cl100k_base") + + def get_embedding(text, engine=os.getenv("EMBEDDING_ENGINE")): count = 0 - error = '' + error = "" while count < 5: try: if count > 0: - print(f'retrying the {count} time for getting text embedding...') - return openai.Embedding.create(input=text, engine=engine)['data'][0]['embedding'] + print(f"retrying the {count} time for getting text embedding...") + return openai.Embedding.create(input=text, engine=engine)["data"][0][ + "embedding" + ] except Exception as e: error = str(e) print(error) @@ -35,7 +41,9 @@ def num_tokens_from_messages(messages, model="gpt-35-turbo-16k"): """Returns the number of tokens used by a list of messages.""" encoding = tiktoken.get_encoding("cl100k_base") if model == "gpt-35-turbo-16k": - tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + tokens_per_message = ( + 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + ) tokens_per_name = -1 # if there's a name, the role is omitted elif model == "gpt-4-32k": tokens_per_message = 3 @@ -54,45 +62,77 @@ def num_tokens_from_messages(messages, model="gpt-35-turbo-16k"): return num_tokens -def chat(prompt: str, - engine = os.getenv("Engine_GPT4o"), - azure_endpoint=os.getenv("OPENAI_API_BASE_GPT4o"), - api_key=os.getenv("OPENAI_API_KEY_GPT4o"), - api_version=os.getenv("OPENAI_API_VERSION_GPT4o"), - temperature: float = 0.0): +def chat( + prompt: str, + engine=os.getenv("Engine_GPT4o"), + azure_endpoint=os.getenv("OPENAI_API_BASE_GPT4o"), + api_key=os.getenv("OPENAI_API_KEY_GPT4o"), + api_version=os.getenv("OPENAI_API_VERSION_GPT4o"), + temperature: float = 0.0, + image_file: str = None, + image_base64: str = None, +): client = AzureOpenAI( - azure_endpoint=azure_endpoint, - api_key=api_key, - api_version=api_version + azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version ) + if ( + image_base64 is None + and image_file is not None + and len(image_file) > 0 + and os.path.exists(image_file) + ): + image_base64 = encode_image(image_file) + + if image_base64 is not None and len(image_base64) > 0: + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image_base64}"}, + }, + ], + } + ] + else: + messages = [{"role": "user", "content": prompt}] + count = 0 - error = '' + error = "" max_tokens = 4000 request_timeout = 120 while count < 8: try: if count > 0: - print(f'retrying the {count} time...') - response = client.chat.completions.create( - model=engine, - temperature=temperature, - max_tokens=max_tokens, - top_p=0.95, - frequency_penalty=0, - presence_penalty=0, - timeout=request_timeout, - stop=None, - messages=[ - {"role": "user", "content": prompt} - ] - ) + print(f"retrying the {count} time...") + response = client.chat.completions.create( + model=engine, + temperature=temperature, + max_tokens=max_tokens, + top_p=0.95, + frequency_penalty=0, + presence_penalty=0, + timeout=request_timeout, + stop=None, + messages=messages, + response_format={"type": "json_object"}, + ) return response.choices[0].message.content, False except Exception as e: error = str(e) print(f"error message: {error}") - if 'maximum context length' in error: + if "maximum context length" in error: return error, True count += 1 sleep(3) - return error, True \ No newline at end of file + return error, True + + +def encode_image(image_path: str): + if image_path is None or len(image_path) == 0 or not os.path.exists(image_path): + return None + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") diff --git a/utils/pdf_util.py b/utils/pdf_util.py index 78b8ecf..4d0af21 100644 --- a/utils/pdf_util.py +++ b/utils/pdf_util.py @@ -8,6 +8,7 @@ import fitz import json from traceback import print_exc from tqdm import tqdm +import base64 from utils.similarity import Similarity from utils.logger import logger @@ -110,7 +111,42 @@ class PDFUtil: logger.error(f"Error extracting text: {e}") print_exc() return False, str(e), {} - + + def extract_images(self, + zoom:float = 2.0, + pdf_page_index_list: list = None, + output_folder: str = None): + try: + pdf_doc = fitz.open(self.pdf_file) + try: + pdf_encrypted = pdf_doc.isEncrypted + except: + pdf_encrypted = pdf_doc.is_encrypted + if pdf_encrypted: + pdf_doc.authenticate("") + if pdf_page_index_list is None or len(pdf_page_index_list) == 0: + pdf_page_index_list = range(pdf_doc.page_count) + pdf_base_name = os.path.basename(self.pdf_file).replace(".pdf", "") + mat = fitz.Matrix(zoom, zoom) + output_data = {} + for page_num in tqdm(pdf_page_index_list, disable=False): + page = pdf_doc[page_num] + pix = page.get_pixmap(matrix=mat) + img_buffer = pix.tobytes(output='png') + output_data[page_num] = {} + img_base64 = base64.b64encode(img_buffer).decode('utf-8') + if output_folder and len(output_folder) > 0: + os.makedirs(output_folder, exist_ok=True) + image_file = os.path.join(output_folder, f"{pdf_base_name}_{page_num}.png") + pix.save(image_file) + output_data[page_num]["img_file"] = image_file + output_data[page_num]["img_base64"] = img_base64 + return output_data + except Exception as e: + logger.error(f"Error extracting images: {e}") + print_exc() + return {} + def parse_blocks_page(self, page: fitz.Page): blocks = page.get_text("blocks") list_of_blocks = []