support chat with image by ChatGPT4o
This commit is contained in:
parent
6519dc23d4
commit
843f588015
|
|
@ -0,0 +1,18 @@
|
||||||
|
Instructions:
|
||||||
|
Please read the image carefully.
|
||||||
|
Answer below questions:
|
||||||
|
1. Please find the table or tables in the image.
|
||||||
|
2. Output the table contents as markdown format, it's like:
|
||||||
|
|name|age|hobby|
|
||||||
|
|Annie|18|music|
|
||||||
|
The contents should be exactly precise as the image contents.
|
||||||
|
3. Please output the results as JSON format, the result member is with legal markdown table format, the example is:
|
||||||
|
{
|
||||||
|
"tables": ["
|
||||||
|
|name|age|hobby|
|
||||||
|
|Annie|18|music|
|
||||||
|
"]
|
||||||
|
}
|
||||||
|
4. Only output JSON with tables
|
||||||
|
|
||||||
|
Answer:
|
||||||
|
|
@ -0,0 +1,72 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
import json_repair
|
||||||
|
from utils.pdf_util import PDFUtil
|
||||||
|
from utils.logger import logger
|
||||||
|
from utils.gpt_utils import chat
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_base64_pdf_image_list(pdf_file: str,
|
||||||
|
pdf_page_index_list: list,
|
||||||
|
output_folder: str=None) -> dict:
|
||||||
|
if pdf_file is None or pdf_file == "" or not os.path.exists(pdf_file):
|
||||||
|
logger.error("pdf_file is not provided")
|
||||||
|
return None
|
||||||
|
pdf_util = PDFUtil(pdf_file)
|
||||||
|
if pdf_page_index_list is None or len(pdf_page_index_list) == 0:
|
||||||
|
pdf_page_index_list = list(range(pdf_util.get_page_count()))
|
||||||
|
if output_folder is not None and len(output_folder) > 0:
|
||||||
|
os.makedirs(output_folder, exist_ok=True)
|
||||||
|
pdf_image_info = pdf_util.extract_images(pdf_page_index_list=pdf_page_index_list,
|
||||||
|
output_folder=output_folder)
|
||||||
|
return pdf_image_info
|
||||||
|
|
||||||
|
|
||||||
|
def encode_image(image_path: str):
|
||||||
|
if image_path is None or len(image_path) == 0 or not os.path.exists(image_path):
|
||||||
|
return None
|
||||||
|
with open(image_path, "rb") as image_file:
|
||||||
|
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def chat_with_image(pdf_file: str,
|
||||||
|
pdf_page_index_list: list,
|
||||||
|
image_folder: str,
|
||||||
|
gpt_folder: str):
|
||||||
|
if pdf_file is None or pdf_file == "" or not os.path.exists(pdf_file):
|
||||||
|
logger.error("pdf_file is not provided")
|
||||||
|
return None
|
||||||
|
pdf_image_info = get_base64_pdf_image_list(pdf_file, pdf_page_index_list, image_folder)
|
||||||
|
image_instructions_file = r'./instructions/table_extraction_image_prompts.txt'
|
||||||
|
with open(image_instructions_file, "r", encoding="utf-8") as file:
|
||||||
|
image_instructions = file.read()
|
||||||
|
os.makedirs(gpt_folder, exist_ok=True)
|
||||||
|
pdf_base_name = os.path.basename(pdf_file).replace(".pdf", "")
|
||||||
|
response_list = {}
|
||||||
|
for page_index, data in pdf_image_info.items():
|
||||||
|
logger.info(f"Processing image in page {page_index}")
|
||||||
|
image_file = data.get("img_file", None)
|
||||||
|
image_base64 = data.get("img_base64", None)
|
||||||
|
response, error = chat(prompt=image_instructions, image_base64=image_base64)
|
||||||
|
if error:
|
||||||
|
logger.error(f"Error in processing image in page {page_index}")
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
response_json = json.loads(response)
|
||||||
|
except:
|
||||||
|
response_json = json_repair.loads(response)
|
||||||
|
response_json_file = os.path.join(gpt_folder, f"{pdf_base_name}_{page_index}.json")
|
||||||
|
with open(response_json_file, "w", encoding="utf-8") as file:
|
||||||
|
json.dump(response_json, file, indent=4)
|
||||||
|
logger.info(f"Response for image in page {page_index}: {response}")
|
||||||
|
logger.info("Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pdf_file = r"/data/emea_ar/small_pdf/382366116.pdf"
|
||||||
|
pdf_page_index_list = [29, 35, 71, 77, 83, 89, 97, 103, 112, 121, 130, 140, 195, 250, 305]
|
||||||
|
image_output_folder = r"/data/emea_ar/small_pdf_image/"
|
||||||
|
gpt_output_folder = r"/data/emea_ar/output/gpt_image_response/"
|
||||||
|
chat_with_image(pdf_file, pdf_page_index_list, image_output_folder, gpt_output_folder)
|
||||||
|
|
@ -4,20 +4,26 @@ from openai import AzureOpenAI
|
||||||
import openai
|
import openai
|
||||||
import os
|
import os
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
import base64
|
||||||
import dotenv
|
import dotenv
|
||||||
|
|
||||||
# loads .env file with your OPENAI_API_KEY
|
# loads .env file with your OPENAI_API_KEY
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
# tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
# tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
|
|
||||||
def get_embedding(text, engine=os.getenv("EMBEDDING_ENGINE")):
|
def get_embedding(text, engine=os.getenv("EMBEDDING_ENGINE")):
|
||||||
count = 0
|
count = 0
|
||||||
error = ''
|
error = ""
|
||||||
while count < 5:
|
while count < 5:
|
||||||
try:
|
try:
|
||||||
if count > 0:
|
if count > 0:
|
||||||
print(f'retrying the {count} time for getting text embedding...')
|
print(f"retrying the {count} time for getting text embedding...")
|
||||||
return openai.Embedding.create(input=text, engine=engine)['data'][0]['embedding']
|
return openai.Embedding.create(input=text, engine=engine)["data"][0][
|
||||||
|
"embedding"
|
||||||
|
]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error = str(e)
|
error = str(e)
|
||||||
print(error)
|
print(error)
|
||||||
|
|
@ -35,7 +41,9 @@ def num_tokens_from_messages(messages, model="gpt-35-turbo-16k"):
|
||||||
"""Returns the number of tokens used by a list of messages."""
|
"""Returns the number of tokens used by a list of messages."""
|
||||||
encoding = tiktoken.get_encoding("cl100k_base")
|
encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
if model == "gpt-35-turbo-16k":
|
if model == "gpt-35-turbo-16k":
|
||||||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
tokens_per_message = (
|
||||||
|
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||||
|
)
|
||||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||||
elif model == "gpt-4-32k":
|
elif model == "gpt-4-32k":
|
||||||
tokens_per_message = 3
|
tokens_per_message = 3
|
||||||
|
|
@ -54,26 +62,52 @@ def num_tokens_from_messages(messages, model="gpt-35-turbo-16k"):
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
|
|
||||||
def chat(prompt: str,
|
def chat(
|
||||||
engine = os.getenv("Engine_GPT4o"),
|
prompt: str,
|
||||||
|
engine=os.getenv("Engine_GPT4o"),
|
||||||
azure_endpoint=os.getenv("OPENAI_API_BASE_GPT4o"),
|
azure_endpoint=os.getenv("OPENAI_API_BASE_GPT4o"),
|
||||||
api_key=os.getenv("OPENAI_API_KEY_GPT4o"),
|
api_key=os.getenv("OPENAI_API_KEY_GPT4o"),
|
||||||
api_version=os.getenv("OPENAI_API_VERSION_GPT4o"),
|
api_version=os.getenv("OPENAI_API_VERSION_GPT4o"),
|
||||||
temperature: float = 0.0):
|
temperature: float = 0.0,
|
||||||
|
image_file: str = None,
|
||||||
|
image_base64: str = None,
|
||||||
|
):
|
||||||
client = AzureOpenAI(
|
client = AzureOpenAI(
|
||||||
azure_endpoint=azure_endpoint,
|
azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version
|
||||||
api_key=api_key,
|
|
||||||
api_version=api_version
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
image_base64 is None
|
||||||
|
and image_file is not None
|
||||||
|
and len(image_file) > 0
|
||||||
|
and os.path.exists(image_file)
|
||||||
|
):
|
||||||
|
image_base64 = encode_image(image_file)
|
||||||
|
|
||||||
|
if image_base64 is not None and len(image_base64) > 0:
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": prompt},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
error = ''
|
error = ""
|
||||||
max_tokens = 4000
|
max_tokens = 4000
|
||||||
request_timeout = 120
|
request_timeout = 120
|
||||||
while count < 8:
|
while count < 8:
|
||||||
try:
|
try:
|
||||||
if count > 0:
|
if count > 0:
|
||||||
print(f'retrying the {count} time...')
|
print(f"retrying the {count} time...")
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
model=engine,
|
model=engine,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
|
@ -83,16 +117,22 @@ def chat(prompt: str,
|
||||||
presence_penalty=0,
|
presence_penalty=0,
|
||||||
timeout=request_timeout,
|
timeout=request_timeout,
|
||||||
stop=None,
|
stop=None,
|
||||||
messages=[
|
messages=messages,
|
||||||
{"role": "user", "content": prompt}
|
response_format={"type": "json_object"},
|
||||||
]
|
|
||||||
)
|
)
|
||||||
return response.choices[0].message.content, False
|
return response.choices[0].message.content, False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error = str(e)
|
error = str(e)
|
||||||
print(f"error message: {error}")
|
print(f"error message: {error}")
|
||||||
if 'maximum context length' in error:
|
if "maximum context length" in error:
|
||||||
return error, True
|
return error, True
|
||||||
count += 1
|
count += 1
|
||||||
sleep(3)
|
sleep(3)
|
||||||
return error, True
|
return error, True
|
||||||
|
|
||||||
|
|
||||||
|
def encode_image(image_path: str):
|
||||||
|
if image_path is None or len(image_path) == 0 or not os.path.exists(image_path):
|
||||||
|
return None
|
||||||
|
with open(image_path, "rb") as image_file:
|
||||||
|
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import fitz
|
||||||
import json
|
import json
|
||||||
from traceback import print_exc
|
from traceback import print_exc
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import base64
|
||||||
from utils.similarity import Similarity
|
from utils.similarity import Similarity
|
||||||
|
|
||||||
from utils.logger import logger
|
from utils.logger import logger
|
||||||
|
|
@ -111,6 +112,41 @@ class PDFUtil:
|
||||||
print_exc()
|
print_exc()
|
||||||
return False, str(e), {}
|
return False, str(e), {}
|
||||||
|
|
||||||
|
def extract_images(self,
|
||||||
|
zoom:float = 2.0,
|
||||||
|
pdf_page_index_list: list = None,
|
||||||
|
output_folder: str = None):
|
||||||
|
try:
|
||||||
|
pdf_doc = fitz.open(self.pdf_file)
|
||||||
|
try:
|
||||||
|
pdf_encrypted = pdf_doc.isEncrypted
|
||||||
|
except:
|
||||||
|
pdf_encrypted = pdf_doc.is_encrypted
|
||||||
|
if pdf_encrypted:
|
||||||
|
pdf_doc.authenticate("")
|
||||||
|
if pdf_page_index_list is None or len(pdf_page_index_list) == 0:
|
||||||
|
pdf_page_index_list = range(pdf_doc.page_count)
|
||||||
|
pdf_base_name = os.path.basename(self.pdf_file).replace(".pdf", "")
|
||||||
|
mat = fitz.Matrix(zoom, zoom)
|
||||||
|
output_data = {}
|
||||||
|
for page_num in tqdm(pdf_page_index_list, disable=False):
|
||||||
|
page = pdf_doc[page_num]
|
||||||
|
pix = page.get_pixmap(matrix=mat)
|
||||||
|
img_buffer = pix.tobytes(output='png')
|
||||||
|
output_data[page_num] = {}
|
||||||
|
img_base64 = base64.b64encode(img_buffer).decode('utf-8')
|
||||||
|
if output_folder and len(output_folder) > 0:
|
||||||
|
os.makedirs(output_folder, exist_ok=True)
|
||||||
|
image_file = os.path.join(output_folder, f"{pdf_base_name}_{page_num}.png")
|
||||||
|
pix.save(image_file)
|
||||||
|
output_data[page_num]["img_file"] = image_file
|
||||||
|
output_data[page_num]["img_base64"] = img_base64
|
||||||
|
return output_data
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error extracting images: {e}")
|
||||||
|
print_exc()
|
||||||
|
return {}
|
||||||
|
|
||||||
def parse_blocks_page(self, page: fitz.Page):
|
def parse_blocks_page(self, page: fitz.Page):
|
||||||
blocks = page.get_text("blocks")
|
blocks = page.get_text("blocks")
|
||||||
list_of_blocks = []
|
list_of_blocks = []
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue