optimize data extraction algorithm: if a numeric cost value cannot be found in the PDF page text, extract the data with Vision ChatGPT instead
parent 8b651f374c
commit f166e73362
@@ -136,7 +136,7 @@ class DataExtraction:
             page_datapoints = self.get_datapoints_by_page_num(page_num)
             if len(page_datapoints) == 0:
                 continue
-            extract_data = self.extract_data_by_page_text(
+            extract_data = self.extract_data_by_page(
                 page_num,
                 page_text,
                 page_datapoints,
@@ -179,7 +179,7 @@ class DataExtraction:
             next_page_text = self.page_text_dict.get(next_page_num, "")
             target_text = current_text + next_page_text
             # try to get data by current page_datapoints
-            next_page_extract_data = self.extract_data_by_page_text(
+            next_page_extract_data = self.extract_data_by_page(
                 next_page_num,
                 target_text,
                 next_datapoints,
@@ -313,6 +313,25 @@ class DataExtraction:
         )
         with pd.ExcelWriter(excel_data_file) as writer:
             data_df.to_excel(writer, sheet_name="extract_data", index=False)
 
+    def extract_data_by_page(
+        self,
+        page_num: int,
+        page_text: str,
+        page_datapoints: list,
+        need_exclude: bool = False,
+        exclude_data: list = None,) -> dict:
+        # If no numeric value (e.g. 1.25 or 3,88) can be found in the page text,
+        # apply Vision ChatGPT to extract the data from the page image instead
+        numeric_regex = r"\d+(\.|\,)\d+"
+        if not re.search(numeric_regex, page_text):
+            logger.info(f"Can't find numeric value in page {page_num}, apply Vision ChatGPT to extract data")
+            return self.extract_data_by_page_image(
+                page_num, page_datapoints, need_exclude, exclude_data)
+        else:
+            return self.extract_data_by_page_text(
+                page_num, page_text, page_datapoints, need_exclude, exclude_data
+            )
+
     def extract_data_by_page_text(
         self,
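The new extract_data_by_page helper decides between the existing text path and the new Vision ChatGPT image path with a single regex check. Below is a minimal standalone sketch of that dispatch rule; extract_by_text and extract_by_image are hypothetical stand-ins for the class methods above, not the project's actual helpers.

import re

NUMERIC_REGEX = r"\d+(\.|\,)\d+"  # matches values such as "1.25" or "3,88"

def extract(page_num, page_text, extract_by_text, extract_by_image):
    # No decimal- or comma-separated number in the text layer usually means the
    # page is a scanned table, so fall back to the vision model on the page image.
    if not re.search(NUMERIC_REGEX, page_text):
        return extract_by_image(page_num)
    return extract_by_text(page_num, page_text)

# A page whose text layer carries no numeric values is routed to the image path.
result = extract(
    7,
    "Ongoing charges table (see scanned image)",
    extract_by_text=lambda n, t: {"data": [], "extract_way": "text"},
    extract_by_image=lambda n: {"data": [], "extract_way": "image"},
)
print(result["extract_way"])  # -> "image"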
@@ -328,7 +347,11 @@ class DataExtraction:
         """
         logger.info(f"Extracting data from page {page_num}")
         instructions = self.get_instructions_by_datapoints(
-            page_text, page_datapoints, need_exclude, exclude_data
+            page_text,
+            page_datapoints,
+            need_exclude,
+            exclude_data,
+            extract_way="text"
         )
         response, with_error = chat(
             instructions, response_format={"type": "json_object"}
@@ -342,6 +365,7 @@ class DataExtraction:
             data_dict["instructions"] = instructions
             data_dict["raw_answer"] = response
             data_dict["extract_data"] = {"data": []}
+            data_dict["extract_way"] = "text"
             return data_dict
         try:
             data = json.loads(response)
@@ -367,12 +391,15 @@ class DataExtraction:
         data_dict["instructions"] = instructions
         data_dict["raw_answer"] = response
         data_dict["extract_data"] = data
+        data_dict["extract_way"] = "text"
         return data_dict
 
     def extract_data_by_page_image(
         self,
         page_num: int,
-        page_datapoints: list
+        page_datapoints: list,
+        need_exclude: bool = False,
+        exclude_data: list = None,
     ) -> dict:
         """
         keys are
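Alongside extract_data, each result row now also carries the new extract_way flag (and, on the image path, an empty page_text). An illustrative sketch of one such record as assembled in the hunks above; all values are made up.

record = {
    "doc_id": "123456789",                 # illustrative document id
    "page_index": 12,
    "datapoints": "ongoing charge, performance fee",
    "page_text": "",                       # empty when the page image was used
    "instructions": "...prompt sent to the model...",
    "raw_answer": '{"data": []}',
    "extract_data": {"data": []},
    "extract_way": "image",                # new field: "text" or "image"
}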
@@ -381,7 +408,11 @@ class DataExtraction:
         logger.info(f"Extracting data from page {page_num}")
         image_base64 = self.get_pdf_image_base64(page_num)
         instructions = self.get_instructions_by_datapoints(
-            "", page_datapoints, need_exclude=False, exclude_data=None
+            "",
+            page_datapoints,
+            need_exclude=need_exclude,
+            exclude_data=exclude_data,
+            extract_way="image"
         )
         response, with_error = chat(
             instructions, response_format={"type": "json_object"}, image_base64=image_base64
@@ -391,9 +422,11 @@ class DataExtraction:
             data_dict = {"doc_id": self.doc_id}
             data_dict["page_index"] = page_num
             data_dict["datapoints"] = ", ".join(page_datapoints)
+            data_dict["page_text"] = ""
             data_dict["instructions"] = instructions
             data_dict["raw_answer"] = response
             data_dict["extract_data"] = {"data": []}
+            data_dict["extract_way"] = "image"
             return data_dict
         try:
             data = json.loads(response)
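The image path renders the page, base64-encodes it, and hands image_base64 to the shared chat() helper. A hedged sketch of how such a base64 page image can be attached to an OpenAI-compatible vision request follows; the client wiring and the deployment argument are assumptions, not code from this repo.

from openai import AzureOpenAI  # the repo's chat() wraps a client like this

def vision_chat(client: AzureOpenAI, deployment: str, instructions: str, image_base64: str) -> str:
    # Send the prompt text together with the page rendered as a data-URI image,
    # and ask for a JSON object so the answer can be parsed with json.loads().
    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": instructions},
            {"type": "image_url",
             "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
        ],
    }]
    response = client.chat.completions.create(
        model=deployment,
        messages=messages,
        response_format={"type": "json_object"},
    )
    return response.choices[0].message.content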
@@ -407,137 +440,12 @@ class DataExtraction:
         data_dict = {"doc_id": self.doc_id}
         data_dict["page_index"] = page_num
         data_dict["datapoints"] = ", ".join(page_datapoints)
+        data_dict["page_text"] = ""
         data_dict["instructions"] = instructions
         data_dict["raw_answer"] = response
         data_dict["extract_data"] = data
+        data_dict["extract_way"] = "image"
         return data_dict
 
-    def chat_by_split_context(self,
-                              page_text: str,
-                              page_datapoints: list,
-                              need_exclude: bool,
-                              exclude_data: list) -> list:
-        """
-        If occur error, split the context to two parts and try to get data from the two parts
-        Relevant document: 503194284, page index 147
-        """
-        try:
-            logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
-            split_context = re.split(r"\n", page_text)
-            split_context = [text.strip() for text in split_context
-                             if len(text.strip()) > 0]
-            if len(split_context) < 10:
-                return {"data": []}
-
-            split_context_len = len(split_context)
-            top_10_context = split_context[:10]
-            rest_context = split_context[10:]
-            header = "\n".join(top_10_context)
-            half_len = split_context_len // 2
-            # the member of half_len should not start with number
-            # reverse iterate the list by half_len
-            half_len_list = [i for i in range(half_len)]
-
-            fund_name_line = ""
-            half_line = rest_context[half_len].strip()
-            max_similarity_fund_name, max_similarity = get_most_similar_name(
-                half_line, self.provider_fund_name_list, matching_type="fund"
-            )
-            if max_similarity < 0.2:
-                # get the fund name line text from the first half
-                for index in reversed(half_len_list):
-                    line_text = rest_context[index].strip()
-                    if len(line_text) == 0:
-                        continue
-                    line_text_split = line_text.split()
-                    if len(line_text_split) < 3:
-                        continue
-                    first_word = line_text_split[0]
-                    if first_word.lower() == "class":
-                        continue
-
-                    max_similarity_fund_name, max_similarity = get_most_similar_name(
-                        line_text, self.provider_fund_name_list, matching_type="fund"
-                    )
-                    if max_similarity >= 0.2:
-                        fund_name_line = line_text
-                        break
-            else:
-                fund_name_line = half_line
-                half_len += 1
-            if fund_name_line == "":
-                return {"data": []}
-
-            logger.info(f"Split first part from 0 to {half_len}")
-            split_first_part = "\n".join(split_context[:half_len])
-            first_part = '\n'.join(split_first_part)
-            first_instructions = self.get_instructions_by_datapoints(
-                first_part, page_datapoints, need_exclude, exclude_data
-            )
-            response, with_error = chat(
-                first_instructions, response_format={"type": "json_object"}
-            )
-            first_part_data = {"data": []}
-            if not with_error:
-                try:
-                    first_part_data = json.loads(response)
-                except:
-                    first_part_data = json_repair.loads(response)
-
-            logger.info(f"Split second part from {half_len} to {split_context_len}")
-            split_second_part = "\n".join(split_context[half_len:])
-            second_part = header + "\n" + fund_name_line + "\n" + split_second_part
-            second_instructions = self.get_instructions_by_datapoints(
-                second_part, page_datapoints, need_exclude, exclude_data
-            )
-            response, with_error = chat(
-                second_instructions, response_format={"type": "json_object"}
-            )
-            second_part_data = {"data": []}
-            if not with_error:
-                try:
-                    second_part_data = json.loads(response)
-                except:
-                    second_part_data = json_repair.loads(response)
-
-            first_part_data_list = first_part_data.get("data", [])
-            logger.info(f"First part data count: {len(first_part_data_list)}")
-            second_part_data_list = second_part_data.get("data", [])
-            logger.info(f"Second part data count: {len(second_part_data_list)}")
-            for first_data in first_part_data_list:
-                if first_data in second_part_data_list:
-                    second_part_data_list.remove(first_data)
-                else:
-                    # if the first part data is with same fund name and share name,
-                    # remove the second part data
-                    first_data_dp = [key for key in list(first_data.keys())
-                                     if key not in ["fund name", "share name"]]
-                    # order the data points
-                    first_data_dp.sort()
-                    first_fund_name = first_data.get("fund name", "")
-                    first_share_name = first_data.get("share name", "")
-                    if len(first_fund_name) > 0 and len(first_share_name) > 0:
-                        remove_second_list = []
-                        for second_data in second_part_data_list:
-                            second_fund_name = second_data.get("fund name", "")
-                            second_share_name = second_data.get("share name", "")
-                            if first_fund_name == second_fund_name and \
-                                    first_share_name == second_share_name:
-                                second_data_dp = [key for key in list(second_data.keys())
-                                                  if key not in ["fund name", "share name"]]
-                                second_data_dp.sort()
-                                if first_data_dp == second_data_dp:
-                                    remove_second_list.append(second_data)
-                        for remove_second in remove_second_list:
-                            if remove_second in second_part_data_list:
-                                second_part_data_list.remove(remove_second)
-
-            data_list = first_part_data_list + second_part_data_list
-            extract_data = {"data": data_list}
-            return extract_data
-        except Exception as e:
-            logger.error(f"Error in split context: {e}")
-            return {"data": []}
-
     def validate_data(self, extract_data_info: dict) -> dict:
         """
@@ -634,7 +542,6 @@ class DataExtraction:
             return True
         return False
 
-
     def get_datapoints_by_page_num(self, page_num: int) -> list:
         datapoints = []
         for datapoint in self.datapoints:
@@ -648,6 +555,7 @@ class DataExtraction:
         datapoints: list,
         need_exclude: bool = False,
         exclude_data: list = None,
+        extract_way: str = "text",
    ) -> str:
        """
        Get instructions to extract data from the page by the datapoints
@@ -678,7 +586,7 @@ class DataExtraction:
         end
         """
         instructions = []
-        if self.extract_way == "text":
+        if extract_way == "text":
             instructions = [f"Context:\n{page_text}\n\nInstructions:\n"]
 
         datapoint_name_list = []
@@ -686,9 +594,9 @@ class DataExtraction:
             datapoint_name = self.datapoint_name_config.get(datapoint, "")
             datapoint_name_list.append(datapoint_name)
 
-        if self.extract_way == "text":
+        if extract_way == "text":
             summary = self.instructions_config.get("summary", "\n")
-        elif self.extract_way == "image":
+        elif extract_way == "image":
             summary = self.instructions_config.get("summary_image", "\n")
         else:
             summary = self.instructions_config.get("summary", "\n")
@@ -696,7 +604,7 @@ class DataExtraction:
         instructions.append(summary.format(", ".join(datapoint_name_list)))
         instructions.append("\n")
 
-        if self.extract_way == "image":
+        if extract_way == "image":
             image_features = self.instructions_config.get("image_features", [])
             instructions.extend(image_features)
             instructions.append("\n")
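get_instructions_by_datapoints now takes extract_way as a parameter instead of reading self.extract_way, so one builder can produce either a text prompt or an image prompt. A minimal sketch of that selection logic follows; the config keys mirror the diff, while the templates and values are made up.

def build_instructions(page_text, datapoint_names, config, extract_way="text"):
    # Text extraction embeds the page text as context; image extraction relies
    # on the attached page image instead, so no context block is added.
    instructions = []
    if extract_way == "text":
        instructions.append(f"Context:\n{page_text}\n\nInstructions:\n")

    # "summary" vs "summary_image" mirrors the config lookup in the diff above.
    key = "summary_image" if extract_way == "image" else "summary"
    instructions.append(config.get(key, "\n").format(", ".join(datapoint_names)))
    instructions.append("\n")

    if extract_way == "image":
        instructions.extend(config.get("image_features", []))
        instructions.append("\n")
    return "".join(instructions)

# Usage with made-up templates:
config = {
    "summary": "Extract the following data points: {}.",
    "summary_image": "Extract the following data points from the page image: {}.",
    "image_features": ["The page is a scanned table of ongoing charges."],
}
print(build_instructions("TER 1.25%", ["ongoing charge"], config, extract_way="image"))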
@@ -831,3 +739,130 @@ class DataExtraction:
 
         instructions_text = "".join(instructions)
         return instructions_text
+
+    # def chat_by_split_context(self,
+    #                           page_text: str,
+    #                           page_datapoints: list,
+    #                           need_exclude: bool,
+    #                           exclude_data: list) -> list:
+    #     """
+    #     If occur error, split the context to two parts and try to get data from the two parts
+    #     Relevant document: 503194284, page index 147
+    #     """
+    #     try:
+    #         logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
+    #         split_context = re.split(r"\n", page_text)
+    #         split_context = [text.strip() for text in split_context
+    #                          if len(text.strip()) > 0]
+    #         if len(split_context) < 10:
+    #             return {"data": []}
+
+    #         split_context_len = len(split_context)
+    #         top_10_context = split_context[:10]
+    #         rest_context = split_context[10:]
+    #         header = "\n".join(top_10_context)
+    #         half_len = split_context_len // 2
+    #         # the member of half_len should not start with number
+    #         # reverse iterate the list by half_len
+    #         half_len_list = [i for i in range(half_len)]
+
+    #         fund_name_line = ""
+    #         half_line = rest_context[half_len].strip()
+    #         max_similarity_fund_name, max_similarity = get_most_similar_name(
+    #             half_line, self.provider_fund_name_list, matching_type="fund"
+    #         )
+    #         if max_similarity < 0.2:
+    #             # get the fund name line text from the first half
+    #             for index in reversed(half_len_list):
+    #                 line_text = rest_context[index].strip()
+    #                 if len(line_text) == 0:
+    #                     continue
+    #                 line_text_split = line_text.split()
+    #                 if len(line_text_split) < 3:
+    #                     continue
+    #                 first_word = line_text_split[0]
+    #                 if first_word.lower() == "class":
+    #                     continue
+
+    #                 max_similarity_fund_name, max_similarity = get_most_similar_name(
+    #                     line_text, self.provider_fund_name_list, matching_type="fund"
+    #                 )
+    #                 if max_similarity >= 0.2:
+    #                     fund_name_line = line_text
+    #                     break
+    #         else:
+    #             fund_name_line = half_line
+    #             half_len += 1
+    #         if fund_name_line == "":
+    #             return {"data": []}
+
+    #         logger.info(f"Split first part from 0 to {half_len}")
+    #         split_first_part = "\n".join(split_context[:half_len])
+    #         first_part = '\n'.join(split_first_part)
+    #         first_instructions = self.get_instructions_by_datapoints(
+    #             first_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
+    #         )
+    #         response, with_error = chat(
+    #             first_instructions, response_format={"type": "json_object"}
+    #         )
+    #         first_part_data = {"data": []}
+    #         if not with_error:
+    #             try:
+    #                 first_part_data = json.loads(response)
+    #             except:
+    #                 first_part_data = json_repair.loads(response)
+
+    #         logger.info(f"Split second part from {half_len} to {split_context_len}")
+    #         split_second_part = "\n".join(split_context[half_len:])
+    #         second_part = header + "\n" + fund_name_line + "\n" + split_second_part
+    #         second_instructions = self.get_instructions_by_datapoints(
+    #             second_part, page_datapoints, need_exclude, exclude_data, extract_way="text"
+    #         )
+    #         response, with_error = chat(
+    #             second_instructions, response_format={"type": "json_object"}
+    #         )
+    #         second_part_data = {"data": []}
+    #         if not with_error:
+    #             try:
+    #                 second_part_data = json.loads(response)
+    #             except:
+    #                 second_part_data = json_repair.loads(response)
+
+    #         first_part_data_list = first_part_data.get("data", [])
+    #         logger.info(f"First part data count: {len(first_part_data_list)}")
+    #         second_part_data_list = second_part_data.get("data", [])
+    #         logger.info(f"Second part data count: {len(second_part_data_list)}")
+    #         for first_data in first_part_data_list:
+    #             if first_data in second_part_data_list:
+    #                 second_part_data_list.remove(first_data)
+    #             else:
+    #                 # if the first part data is with same fund name and share name,
+    #                 # remove the second part data
+    #                 first_data_dp = [key for key in list(first_data.keys())
+    #                                  if key not in ["fund name", "share name"]]
+    #                 # order the data points
+    #                 first_data_dp.sort()
+    #                 first_fund_name = first_data.get("fund name", "")
+    #                 first_share_name = first_data.get("share name", "")
+    #                 if len(first_fund_name) > 0 and len(first_share_name) > 0:
+    #                     remove_second_list = []
+    #                     for second_data in second_part_data_list:
+    #                         second_fund_name = second_data.get("fund name", "")
+    #                         second_share_name = second_data.get("share name", "")
+    #                         if first_fund_name == second_fund_name and \
+    #                                 first_share_name == second_share_name:
+    #                             second_data_dp = [key for key in list(second_data.keys())
+    #                                               if key not in ["fund name", "share name"]]
+    #                             second_data_dp.sort()
+    #                             if first_data_dp == second_data_dp:
+    #                                 remove_second_list.append(second_data)
+    #                     for remove_second in remove_second_list:
+    #                         if remove_second in second_part_data_list:
+    #                             second_part_data_list.remove(remove_second)
+
+    #         data_list = first_part_data_list + second_part_data_list
+    #         extract_data = {"data": data_list}
+    #         return extract_data
+    #     except Exception as e:
+    #         logger.error(f"Error in split context: {e}")
+    #         return {"data": []}
main.py (4 changes)
@@ -809,10 +809,10 @@ if __name__ == "__main__":
     ]
     # special_doc_id_list = check_mapping_doc_id_list
     special_doc_id_list = check_db_mapping_doc_id_list
-    special_doc_id_list = ["423395975"]
+    special_doc_id_list = ["514213638"]
     output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
     output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
-    re_run_extract_data = True
+    re_run_extract_data = False
     re_run_mapping_data = True
     force_save_total_data = False
     calculate_metrics = False
@@ -74,7 +74,7 @@ def chat(
     image_file: str = None,
     image_base64: str = None,
 ):
-    if engine != "gpt-4o-2024-08-06-research":
+    if not engine.startswith("gpt-4o"):
         max_tokens = 4096
 
     client = AzureOpenAI(
@@ -138,6 +138,7 @@ def chat(
                 messages=messages,
                 response_format=response_format,
             )
+            sleep(1)
             return response.choices[0].message.content, False
         except Exception as e:
             error = str(e)
@@ -145,7 +146,7 @@ def chat(
             if "maximum context length" in error:
                 return error, True
             count += 1
-            sleep(3)
+            sleep(2)
     return error, True
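The chat() changes add a one-second pause after each successful call and shorten the retry delay to two seconds. A rough sketch of the retry pattern implied by these hunks follows; the loop bound and the send_request callable are assumptions, not the function's actual signature.

from time import sleep

def chat_with_retry(send_request, instructions, max_attempts=3):
    count = 0
    error = ""
    while count < max_attempts:
        try:
            content = send_request(instructions)
            sleep(1)   # brief pause after a successful call to stay under rate limits
            return content, False
        except Exception as e:
            error = str(e)
            # A context-length error will not succeed on retry, so give up early.
            if "maximum context length" in error:
                return error, True
            count += 1
            sleep(2)   # short back-off before the next attempt
    return error, True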