Support extracting data from continuation page(s), so that next-page data without a table header is not missed.
This commit is contained in:
parent
1caf552065
commit
878383a72c
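The idea behind the change, as a minimal standalone sketch (names such as extract_with_continuation and extract_rows are illustrative stand-ins for the prompt-plus-LLM call; the actual implementation is DataExtraction.extract_data / extract_data_by_page in the diff below):

def extract_with_continuation(page_texts: dict, start_page: int, datapoints: list,
                              extract_rows, max_follow: int = 2) -> list:
    # Extract rows from the page that carries the table header, then keep
    # appending the following pages' text (which may lack the header) and
    # re-extract, asking to exclude rows that were already returned.
    text = page_texts[start_page]
    rows = extract_rows(text, datapoints, exclude=None)
    for offset in range(1, max_follow + 1):
        next_page = start_page + offset
        if next_page not in page_texts:
            break
        text += page_texts[next_page]                      # keep the header context from the first page
        new_rows = extract_rows(text, datapoints, exclude=rows)
        new_rows = [r for r in new_rows if r not in rows]  # defensive de-duplication
        if not new_rows:
            break                                          # the table did not continue onto this page
        rows.extend(new_rows)
    return rows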
@@ -13,26 +13,26 @@ from utils.biz_utils import add_slash_to_text_as_regex, clean_text
class DataExtraction:
    def __init__(
        self,
        doc_id: str,
        pdf_file: str,
        output_data_folder: str,
        page_text_dict: dict,
        datapoint_page_info: dict,
        document_mapping_info_df: pd.DataFrame
        document_mapping_info_df: pd.DataFrame,
    ) -> None:
        self.doc_id = doc_id
        self.pdf_file = pdf_file
        if output_data_folder is None or len(output_data_folder) == 0:
            output_data_folder = r"/data/emea_ar/output/extract_data/docs/"
        os.makedirs(output_data_folder, exist_ok=True)

        self.output_data_json_folder = os.path.join(output_data_folder, "json/")
        os.makedirs(self.output_data_json_folder, exist_ok=True)

        self.output_data_excel_folder = os.path.join(output_data_folder, "excel/")
        os.makedirs(self.output_data_excel_folder, exist_ok=True)

        if page_text_dict is None or len(page_text_dict.keys()) == 0:
            self.page_text_dict = self.get_pdf_page_text_dict()
        else:
@@ -42,17 +42,19 @@ class DataExtraction:
        else:
            self.document_mapping_info_df = document_mapping_info_df
        self.datapoint_page_info = datapoint_page_info
        self.page_nums_with_datapoints = self.get_page_nums_from_datapoint_page_info()

        self.datapoints = self.get_datapoints_from_datapoint_page_info()
        self.instructions_config = self.get_instructions_config()
        self.datapoint_level_config = self.get_datapoint_level()
        self.datapoint_name_config = self.get_datapoint_name()

    def get_instructions_config(self) -> dict:
        instructions_config_file = r"./instructions/data_extraction_prompts_config.json"
        with open(instructions_config_file, "r", encoding="utf-8") as f:
            instructions_config = json.load(f)
        return instructions_config

    def get_datapoint_level(self) -> dict:
        datapoint_level_file = r"./configuration/datapoint_level.json"
        with open(datapoint_level_file, "r", encoding="utf-8") as f:
@@ -64,68 +66,181 @@ class DataExtraction:
        with open(datapoint_name_file, "r", encoding="utf-8") as f:
            datapoint_name = json.load(f)
        return datapoint_name

    def get_pdf_page_text_dict(self) -> dict:
        pdf_util = PDFUtil(self.pdf_file)
        success, text, page_text_dict = pdf_util.extract_text()
        return page_text_dict

    def get_datapoints_from_datapoint_page_info(self) -> list:
        datapoints = list(self.datapoint_page_info.keys())
        if "doc_id" in datapoints:
            datapoints.remove("doc_id")
        return datapoints

    def get_page_nums_from_datapoint_page_info(self) -> list:
        page_nums_with_datapoints = []
        for datapoint, page_nums in self.datapoint_page_info.items():
            if datapoint == "doc_id":
                continue
            page_nums_with_datapoints.extend(page_nums)
        page_nums_with_datapoints = list(set(page_nums_with_datapoints))
        # sort the page numbers
        page_nums_with_datapoints.sort()
        return page_nums_with_datapoints

    def extract_data(self) -> dict:
        """
        keys are
        doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
        """
        data_list = []
        pdf_page_count = len(self.page_text_dict.keys())
        handled_page_num_list = []
        for page_num, page_text in self.page_text_dict.items():
            if page_num in handled_page_num_list:
                continue
            page_datapoints = self.get_datapoints_by_page_num(page_num)
            if len(page_datapoints) == 0:
                continue
            instructions = self.get_instructions_by_datapoints(page_text, page_datapoints)
            response, with_error = chat(instructions)
            if with_error:
                logger.error(f"Error in extracting tables from page")
                return ""
            try:
                data = json.loads(response)
            except:
                try:
                    data = json_repair.loads(response)
                except:
                    data = {}
            data_dict = {"doc_id": self.doc_id}
            data_dict["page_index"] = page_num
            data_dict["datapoints"] = ", ".join(page_datapoints)
            data_dict["page_text"] = page_text
            data_dict["instructions"] = instructions
            data_dict["raw_answer"] = response
            data_dict["data"] = data
            data_list.append(data_dict)
            json_data_file = os.path.join(self.output_data_json_folder, f"{self.doc_id}.json")
            extract_data = self.extract_data_by_page(
                page_num,
                page_text,
                page_datapoints,
                need_exclude=False,
                exclude_data=None,
            )
            data_list.append(extract_data)

            page_data_list = extract_data.get("extract_data", {}).get("data", [])

            current_page_data_count = len(page_data_list)
            if current_page_data_count > 0:
                count = 1
                # some pdf documents have multiple pages for the same data
                # and the next page may without table header with data point keywords.
                # the purpose is try to get data from the next page
                current_text = page_text

                while count < 3:
                    try:
                        next_page_num = page_num + count
                        if next_page_num >= pdf_page_count:
                            break
                        next_datapoints = page_datapoints
                        if next_page_num in self.page_nums_with_datapoints:
                            should_continue = False
                            next_datapoints = self.get_datapoints_by_page_num(next_page_num)
                            if len(next_datapoints) == 0:
                                should_continue = True
                            else:
                                for next_datapoint in next_datapoints:
                                    if next_datapoint not in page_datapoints:
                                        should_continue = True
                                        break
                                next_datapoints.extend(page_datapoints)
                                # remove duplicate datapoints
                                next_datapoints = list(set(next_datapoints))
                            if not should_continue:
                                break
                        next_page_text = self.page_text_dict.get(next_page_num, "")
                        target_text = current_text + next_page_text
                        # try to get data by current page_datapoints
                        next_page_extract_data = self.extract_data_by_page(
                            next_page_num,
                            target_text,
                            next_datapoints,
                            need_exclude=True,
                            exclude_data=page_data_list,
                        )
                        next_page_data_list = next_page_extract_data.get(
                            "extract_data", {}
                        ).get("data", [])

                        if next_page_data_list is not None and len(next_page_data_list) > 0:
                            for current_page_data in page_data_list:
                                if current_page_data in next_page_data_list:
                                    next_page_data_list.remove(current_page_data)
                            next_page_extract_data["extract_data"][
                                "data"
                            ] = next_page_data_list
                            data_list.append(next_page_extract_data)
                            handled_page_num_list.append(next_page_num)
                        else:
                            break
                        count += 1
                    except Exception as e:
                        logger.error(f"Error in extracting data from next page: {e}")
                        break

        json_data_file = os.path.join(
            self.output_data_json_folder, f"{self.doc_id}.json"
        )
        with open(json_data_file, "w", encoding="utf-8") as f:
            json.dump(data_list, f, ensure_ascii=False, indent=4)

        data_df = pd.DataFrame(data_list)
        data_df.reset_index(drop=True, inplace=True)
        excel_data_file = os.path.join(self.output_data_excel_folder, f"{self.doc_id}.xlsx")
        excel_data_file = os.path.join(
            self.output_data_excel_folder, f"{self.doc_id}.xlsx"
        )
        with pd.ExcelWriter(excel_data_file) as writer:
            data_df.to_excel(writer, sheet_name="extract_data", index=False)

        return data_list

    def extract_data_by_page(
        self,
        page_num: int,
        page_text: str,
        page_datapoints: list,
        need_exclude: bool = False,
        exclude_data: list = None,
    ) -> dict:
        """
        keys are
        doc_id, page_index, datapoint, value, raw_fund_name, fund_id, fund_name, raw_share_name, share_id, share_name
        """
        logger.info(f"Extracting data from page {page_num}")
        instructions = self.get_instructions_by_datapoints(
            page_text, page_datapoints, need_exclude, exclude_data
        )
        response, with_error = chat(
            instructions, response_format={"type": "json_object"}
        )
        if with_error:
            logger.error(f"Error in extracting tables from page")
            return ""
        try:
            data = json.loads(response)
        except:
            try:
                data = json_repair.loads(response)
            except:
                data = {"data": []}
        data_dict = {"doc_id": self.doc_id}
        data_dict["page_index"] = page_num
        data_dict["datapoints"] = ", ".join(page_datapoints)
        data_dict["page_text"] = page_text
        data_dict["instructions"] = instructions
        data_dict["raw_answer"] = response
        data_dict["extract_data"] = data
        return data_dict

    def get_datapoints_by_page_num(self, page_num: int) -> list:
        datapoints = []
        for datapoint in self.datapoints:
            if page_num in self.datapoint_page_info[datapoint]:
                datapoints.append(datapoint)
        return datapoints

    def get_instructions_by_datapoints(self, page_text: str, datapoints: list) -> str:
    def get_instructions_by_datapoints(
        self,
        page_text: str,
        datapoints: list,
        need_exclude: bool = False,
        exclude_data: list = None,
    ) -> str:
        """
        Get instructions to extract data from the page by the datapoints
        Below is the instructions sections:
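A side note on the de-duplication step in extract_data above: the extracted rows are plain dicts, so `in` and `list.remove` compare by value, which is what lets the current page's rows be filtered out of the next page's result. A tiny illustration with made-up rows (field names and values are hypothetical):

current_rows = [{"fund name": "Fund A", "TER": "0.95%"}]
next_rows = [{"fund name": "Fund A", "TER": "0.95%"},
             {"fund name": "Fund B", "TER": "1.10%"}]
for row in current_rows:
    if row in next_rows:          # dicts compare by content
        next_rows.remove(row)     # removes the first equal element only
print(next_rows)                  # [{'fund name': 'Fund B', 'TER': '1.10%'}]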
@@ -159,11 +274,11 @@ class DataExtraction:
        for datapoint in datapoints:
            datapoint_name = self.datapoint_name_config.get(datapoint, "")
            datapoint_name_list.append(datapoint_name)

        summary = self.instructions_config.get("summary", "\n")
        instructions.append(summary.format(', '.join(datapoint_name_list)))
        instructions.append(summary.format(", ".join(datapoint_name_list)))
        instructions.append("\n")

        instructions.append("Datapoints Reported name:\n")
        reported_name_info = self.instructions_config.get("reported_name", {})
        for datapoint in datapoints:
@@ -171,13 +286,15 @@ class DataExtraction:
            instructions.append(reported_name)
            instructions.append("\n")
        instructions.append("\n")

        instructions.append("Data business features:\n")
        data_business_features = self.instructions_config.get("data_business_features", {})
        common = '\n'.join(data_business_features.get("common", []))
        data_business_features = self.instructions_config.get(
            "data_business_features", {}
        )
        common = "\n".join(data_business_features.get("common", []))
        instructions.append(common)
        instructions.append("\n")

        instructions.append("Datapoints investment level:\n")
        investment_level_info = data_business_features.get("investment_level", {})
        for datapoint in datapoints:
@@ -185,7 +302,7 @@ class DataExtraction:
            instructions.append(investment_level)
            instructions.append("\n")
        instructions.append("\n")

        instructions.append("Datapoints value range:\n")
        data_value_range_info = data_business_features.get("data_value_range", {})
        for datapoint in datapoints:
@@ -193,7 +310,7 @@ class DataExtraction:
            instructions.append(data_value_range)
            instructions.append("\n")
        instructions.append("\n")

        special_rule_info = data_business_features.get("special_rule", {})
        with_special_rule_title = False
        for datapoint in datapoints:
@@ -202,11 +319,11 @@ class DataExtraction:
            if not with_special_rule_title:
                instructions.append("Special rule:\n")
                with_special_rule_title = True
            special_rule = '\n'.join(special_rule_list)
            special_rule = "\n".join(special_rule_list)
            instructions.append(special_rule)
            instructions.append("\n\n")
        instructions.append("\n")

        instructions.append("Special cases:\n")
        special_cases = self.instructions_config.get("special_cases", {})
        special_cases_common_list = special_cases.get("common", [])
@@ -215,10 +332,10 @@ class DataExtraction:
            instructions.append(title)
            instructions.append("\n")
            contents_list = special_cases_common.get("contents", [])
            contents = '\n'.join(contents_list)
            contents = "\n".join(contents_list)
            instructions.append(contents)
            instructions.append("\n\n")

        for datapoint in datapoints:
            special_case_list = special_cases.get(datapoint, [])
            for special_case in special_case_list:
@@ -226,51 +343,69 @@ class DataExtraction:
                instructions.append(title)
                instructions.append("\n")
                contents_list = special_case.get("contents", [])
                contents = '\n'.join(contents_list)
                contents = "\n".join(contents_list)
                instructions.append(contents)
                instructions.append("\n\n")
        instructions.append("\n")

        instructions.append("Output requirement:\n")
        output_requirement = self.instructions_config.get("output_requirement", {})
        output_requirement_common_list = output_requirement.get("common", [])
        instructions.append("\n".join(output_requirement_common_list))
        instructions.append("\n")

        share_datapoint_value_example = {}
        share_level_config = output_requirement.get("share_level", {})

        example_list = []
        for datapoint in datapoints:
            investment_level = self.datapoint_level_config.get(datapoint, "")
            if investment_level == "fund_level":
                fund_level_example_list = output_requirement.get("fund_level", [])
                for example in fund_level_example_list:
                    instructions.append(example)
                    instructions.append("\n")
                instructions.append("\n")
                    try:
                        sub_example_list = json.loads(example)
                    except:
                        sub_example_list = json_repair.loads(example)
                    example_list.extend(sub_example_list)
            elif investment_level == "share_level":
                share_datapoint_value_example[datapoint] = share_level_config.get(f"{datapoint}_value", [])
                share_datapoint_value_example[datapoint] = share_level_config.get(
                    f"{datapoint}_value", []
                )

        share_datapoint_list = list(share_datapoint_value_example.keys())
        instructions.append(f"Example:\n")
        if len(share_datapoint_list) > 0:
            fund_name_example_list = share_level_config.get("fund_name", [])
            share_name_example_list = share_level_config.get("share_name", [])

            for index in range(len(fund_name_example_list)):
                example_dict = {"fund name": fund_name_example_list[index],
                                "share name": share_name_example_list[index]}
                example_dict = {
                    "fund name": fund_name_example_list[index],
                    "share name": share_name_example_list[index],
                }
                for share_datapoint in share_datapoint_list:
                    share_datapoint_values = share_datapoint_value_example[share_datapoint]
                    share_datapoint_values = share_datapoint_value_example[
                        share_datapoint
                    ]
                    if index < len(share_datapoint_values):
                        example_dict[share_datapoint] = share_datapoint_values[index]
                instructions.append(f"Example {index + 1}:\n")
                instructions.append(json.dumps(example_dict, ensure_ascii=False))
                instructions.append("\n")
                instructions.append("\n")

        end_list = self.instructions_config.get("end", [])
        instructions.append('\n'.join(end_list))
                example_list.append(example_dict)
        example_data = {"data": example_list}
        instructions.append(json.dumps(example_data, ensure_ascii=False, indent=4))
        instructions.append("\n")
        instructions.append("\n")

        end_list = self.instructions_config.get("end", [])
        instructions.append("\n".join(end_list))
        instructions.append("\n")

        if need_exclude and exclude_data is not None and isinstance(exclude_data, list):
            instructions.append("Please exclude below data from output:\n")
            instructions.append(json.dumps(exclude_data, ensure_ascii=False, indent=4))
            instructions.append("\n")
        instructions.append("\n")
        instructions.append("Answer:\n")

        instructions_text = ''.join(instructions)
        return instructions_text

        instructions_text = "".join(instructions)
        return instructions_text
@@ -135,7 +135,7 @@
    },
    "end": [
        "Only output JSON data.",
        "Don't output the value which not exist in context, especiall for fund level datapoint: TOR.",
        "If can't find share class name in context, please output empty JSON data: []"
        "Don't output the value which not exist in context, especially for fund level datapoint: TOR.",
        "If can't find share class name in context, please output empty JSON data: {\"data\": []}"
    ]
}
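The revised empty-result instruction matches the fallback shape used in extract_data_by_page ({"data": []}), so an "empty" answer and a failed parse yield the same structure downstream. A quick check (illustrative only):

import json

empty = json.loads('{"data": []}')
print(empty.get("data", []))  # [] -- same shape as the except-fallback {"data": []}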
main.py (7 changed lines)
@@ -335,15 +335,16 @@ if __name__ == "__main__":
    # )

    # test_auto_generate_instructions()
    # doc_id = "294132333"
    # extract_data(doc_id, pdf_folder)

    output_child_folder = r"/data/emea_ar/output/extract_data/docs/"
    output_total_folder = r"/data/emea_ar/output/extract_data/total/"
    re_run = False
    re_run = True
    batch_extract_data(pdf_folder,
                       page_filter_ground_truth_file,
                       output_child_folder,
                       output_total_folder,
                       special_doc_id_list,
                       re_run)

    # doc_id = "476492237"
    # extract_data(doc_id, pdf_folder, output_child_folder, re_run)
@@ -726,13 +726,31 @@ def pickup_document_from_top_100_providers():
        top_100_provider_document_file, sheet_name="all_data"
    )

    top_100_provider_document_fund_count = pd.read_excel(
        top_100_provider_document_file, sheet_name="doc_fund_count"
    )
    top_100_provider_document_fund_count.reset_index(drop=True, inplace=True)

    top_100_provider_document_share_count = pd.read_excel(
        top_100_provider_document_file, sheet_name="doc_share_class_count"
    )
    top_100_provider_document_share_count = \
        top_100_provider_document_share_count[top_100_provider_document_share_count["with_ar_data"] == True]
    top_100_provider_document_share_count.reset_index(drop=True, inplace=True)

    top_100_provider_document_share_count = pd.merge(
        top_100_provider_document_share_count,
        top_100_provider_document_fund_count,
        on=["DocumentId"],
        how="left",
    )
    top_100_provider_document_share_count = top_100_provider_document_share_count[
        ["DocumentId", "CompanyId_x", "CompanyName_x", "fund_count", "share_class_count"]
    ]
    top_100_provider_document_share_count.rename(
        columns={"CompanyId_x": "CompanyId"}, inplace=True
    )

    # add a new column with name share_count_rank to top_100_provider_document_share_count by merge with provider_share_count
    top_100_provider_document_share_count = pd.merge(
        top_100_provider_document_share_count,
@@ -742,12 +760,11 @@ def pickup_document_from_top_100_providers():
    )
    # Keep columns: DocumentId, CompanyId, CompanyName, share_class_count_x, share_count_rank
    top_100_provider_document_share_count = top_100_provider_document_share_count[
        ["DocumentId", "CompanyId", "CompanyName_x", "share_class_count_x", "share_count_rank"]
        ["DocumentId", "CompanyId", "CompanyName", "fund_count", "share_class_count_x", "share_count_rank"]
    ]
    # rename column share_class_count_x to share_class_count
    top_100_provider_document_share_count.rename(
        columns={"share_class_count_x": "share_class_count",
                 "CompanyName_x": "Company_Name",
                 "share_count_rank": "provider_share_count_rank"}, inplace=True
    )
    top_100_provider_document_share_count = top_100_provider_document_share_count.sort_values(
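For readers wondering where the `_x` columns come from: pd.merge suffixes overlapping column names with `_x` (left frame) and `_y` (right frame), which the code above then renames. A minimal sketch with made-up frames (column names and values here are illustrative, not the real workbook schema):

import pandas as pd

docs = pd.DataFrame({"DocumentId": [1], "CompanyName": ["Acme"], "share_class_count": [12]})
providers = pd.DataFrame({"CompanyName": ["Acme"], "share_class_count": [340], "share_count_rank": [7]})

merged = pd.merge(docs, providers, on=["CompanyName"], how="left")
# overlapping columns become share_class_count_x / share_class_count_y
merged = merged.rename(columns={"share_class_count_x": "share_class_count",
                                "share_count_rank": "provider_share_count_rank"})
print(merged.columns.tolist())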
@@ -833,8 +850,8 @@ if __name__ == "__main__":
    random_small_document_data_file = (
        r"/data/emea_ar/basic_information/English/lux_english_ar_top_100_provider_random_small_document.xlsx"
    )
    download_pdf(random_small_document_data_file, 'random_small_document', pdf_folder)
    output_pdf_page_text(pdf_folder, output_folder)
    # download_pdf(random_small_document_data_file, 'random_small_document', pdf_folder)
    # output_pdf_page_text(pdf_folder, output_folder)

    # extract_pdf_table(pdf_folder, output_folder)
    # analyze_json_error()
@@ -846,4 +863,4 @@ if __name__ == "__main__":
    # output_folder=basic_info_folder,
    # )
    # statistics_document_fund_share_count(doc_mapping_from_top_100_provider_file)
    # pickup_document_from_top_100_providers()
    pickup_document_from_top_100_providers()
@@ -69,6 +69,7 @@ def chat(
    api_key=os.getenv("OPENAI_API_KEY_GPT4o"),
    api_version=os.getenv("OPENAI_API_VERSION_GPT4o"),
    temperature: float = 0.0,
    response_format: dict = None,
    image_file: str = None,
    image_base64: str = None,
):
@@ -108,18 +109,32 @@ def chat(
        try:
            if count > 0:
                print(f"retrying the {count} time...")
            response = client.chat.completions.create(
                model=engine,
                temperature=temperature,
                max_tokens=max_tokens,
                top_p=0.95,
                frequency_penalty=0,
                presence_penalty=0,
                timeout=request_timeout,
                stop=None,
                messages=messages,
                response_format={"type": "json_object"},
            )
            if response_format is None:
                response = client.chat.completions.create(
                    model=engine,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=0.95,
                    frequency_penalty=0,
                    presence_penalty=0,
                    timeout=request_timeout,
                    stop=None,
                    messages=messages,
                )
            else:
                # response_format={"type": "json_object"}
                response = client.chat.completions.create(
                    model=engine,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=0.95,
                    frequency_penalty=0,
                    presence_penalty=0,
                    timeout=request_timeout,
                    stop=None,
                    messages=messages,
                    response_format=response_format,
                )
            return response.choices[0].message.content, False
        except Exception as e:
            error = str(e)
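With the response_format pass-through above, a caller that needs strict JSON can request it and still guard against malformed output. A hedged usage sketch based only on the signatures shown in this diff (chat and instructions come from the code above):

import json
import json_repair

response, with_error = chat(instructions, response_format={"type": "json_object"})
if not with_error:
    try:
        data = json.loads(response)
    except Exception:
        data = json_repair.loads(response)  # repair slightly broken JSON instead of failing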