optimize metrics calculation algorithm

This commit is contained in:
Blade He 2024-09-19 11:44:17 -05:00
parent 98e86a6cfd
commit 27b3540c63
3 changed files with 55 additions and 24 deletions

View File

@@ -253,7 +253,7 @@ class DataExtraction:
exclude_data: list) -> list: exclude_data: list) -> list:
""" """
If occur error, split the context to two parts and try to get data from the two parts If occur error, split the context to two parts and try to get data from the two parts
Relevant document: 503194284 Relevant document: 503194284, page index 147
""" """
try: try:
logger.info(f"Split context to get data to fix issue which output length is over 4K tokens") logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")

View File

@@ -53,7 +53,7 @@ class Metrics:
{"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0} {"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0}
] ]
missing_error_list, metrics_list = self.get_metrics() missing_error_list, metrics_list = self.calculate_metrics()
missing_error_df = pd.DataFrame(missing_error_list) missing_error_df = pd.DataFrame(missing_error_list)
missing_error_df.reset_index(drop=True, inplace=True) missing_error_df.reset_index(drop=True, inplace=True)
@@ -66,7 +66,7 @@ class Metrics:
metrics_df.to_excel(writer, sheet_name="Metrics", index=False) metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
return missing_error_list, metrics_list, self.output_file return missing_error_list, metrics_list, self.output_file
def get_metrics(self): def calculate_metrics(self):
prediction_df = pd.read_excel( prediction_df = pd.read_excel(
self.prediction_file, sheet_name=self.prediction_sheet_name self.prediction_file, sheet_name=self.prediction_sheet_name
) )
@@ -302,11 +302,13 @@ class Metrics:
get_unique_words_text get_unique_words_text
) )
ground_truth_unique_words = dp_ground_truth["unique_words"].unique().tolist() ground_truth_unique_words = dp_ground_truth["unique_words"].unique().tolist()
ground_truth_raw_names = dp_ground_truth["raw_name"].unique().tolist()
# add new column to store unique words for dp_prediction # add new column to store unique words for dp_prediction
dp_prediction["unique_words"] = dp_prediction["raw_name"].apply( dp_prediction["unique_words"] = dp_prediction["raw_name"].apply(
get_unique_words_text get_unique_words_text
) )
pred_unique_words = dp_prediction["unique_words"].unique().tolist() pred_unique_words = dp_prediction["unique_words"].unique().tolist()
pred_raw_names = dp_prediction["raw_name"].unique().tolist()
true_data = [] true_data = []
pred_data = [] pred_data = []
@@ -325,11 +327,18 @@ class Metrics:
pred_data_point_value = prediction["value"] pred_data_point_value = prediction["value"]
pred_investment_type = prediction["investment_type"] pred_investment_type = prediction["investment_type"]
if pred_unique_words in ground_truth_unique_words: find_raw_name_in_gt = [gt_raw_name for gt_raw_name in ground_truth_raw_names
if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
if pred_unique_words in ground_truth_unique_words or len(find_raw_name_in_gt) > 0:
# get the ground truth data with the same unique words # get the ground truth data with the same unique words
gt_data = dp_ground_truth[ if pred_unique_words in ground_truth_unique_words:
dp_ground_truth["unique_words"] == pred_unique_words gt_data = dp_ground_truth[
].iloc[0] dp_ground_truth["unique_words"] == pred_unique_words
].iloc[0]
else:
gt_data = dp_ground_truth[
dp_ground_truth["raw_name"] == find_raw_name_in_gt[0]
].iloc[0]
gt_data_point_value = gt_data["value"] gt_data_point_value = gt_data["value"]
if pred_data_point_value == gt_data_point_value: if pred_data_point_value == gt_data_point_value:
true_data.append(1) true_data.append(1)
@@ -370,7 +379,11 @@ class Metrics:
gt_data_point_value = ground_truth["value"] gt_data_point_value = ground_truth["value"]
gt_investment_type = ground_truth["investment_type"] gt_investment_type = ground_truth["investment_type"]
if gt_unique_words not in pred_unique_words: find_raw_name_in_pred = [pred_raw_name for pred_raw_name in pred_raw_names
if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
if gt_unique_words not in pred_unique_words and \
len(find_raw_name_in_pred) == 0:
true_data.append(1) true_data.append(1)
pred_data.append(0) pred_data.append(0)
error_data = { error_data = {

50
main.py
View File

@@ -355,6 +355,7 @@ def get_metrics(
prediction_file: str, prediction_file: str,
prediction_sheet_name: str, prediction_sheet_name: str,
ground_truth_file: str, ground_truth_file: str,
ground_truth_sheet_name: str = None,
output_folder: str = None, output_folder: str = None,
) -> None: ) -> None:
metrics = Metrics( metrics = Metrics(
@@ -362,6 +363,7 @@ def get_metrics(
prediction_file=prediction_file, prediction_file=prediction_file,
prediction_sheet_name=prediction_sheet_name, prediction_sheet_name=prediction_sheet_name,
ground_truth_file=ground_truth_file, ground_truth_file=ground_truth_file,
ground_truth_sheet_name=ground_truth_sheet_name,
output_folder=output_folder, output_folder=output_folder,
) )
missing_error_list, metrics_list, metrics_file = metrics.get_metrics() missing_error_list, metrics_list, metrics_file = metrics.get_metrics()
@@ -472,6 +474,22 @@ def test_auto_generate_instructions():
f.write(ogc_ter_performance_fee_instructions_text) f.write(ogc_ter_performance_fee_instructions_text)
def test_data_extraction_metrics():
data_type = "data_extraction"
prediction_file = r"/data/emea_ar/output/mapping_data/docs/excel/292989214.xlsx"
prediction_sheet_name = "mapping_data"
ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_292989214.xlsx"
ground_truth_sheet_name = "mapping_data"
metrics_output_folder = r"/data/emea_ar/output/metrics/"
missing_error_list, metrics_list, metrics_file = get_metrics(
data_type,
prediction_file,
prediction_sheet_name,
ground_truth_file,
ground_truth_sheet_name,
metrics_output_folder
)
if __name__ == "__main__": if __name__ == "__main__":
pdf_folder = r"/data/emea_ar/small_pdf/" pdf_folder = r"/data/emea_ar/small_pdf/"
page_filter_ground_truth_file = ( page_filter_ground_truth_file = (
@@ -506,23 +524,23 @@ if __name__ == "__main__":
# doc_id = "476492237" # doc_id = "476492237"
# extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run) # extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run)
special_doc_id_list = [ special_doc_id_list = []
"525574973",
]
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/" output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/" output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
re_run_mapping_data = True re_run_mapping_data = True
force_save_total_data = False force_save_total_data = True
batch_start_job( # batch_start_job(
pdf_folder, # pdf_folder,
page_filter_ground_truth_file, # page_filter_ground_truth_file,
output_extract_data_child_folder, # output_extract_data_child_folder,
output_mapping_child_folder, # output_mapping_child_folder,
output_extract_data_total_folder, # output_extract_data_total_folder,
output_mapping_total_folder, # output_mapping_total_folder,
special_doc_id_list, # special_doc_id_list,
re_run_extract_data, # re_run_extract_data,
re_run_mapping_data, # re_run_mapping_data,
force_save_total_data=force_save_total_data, # force_save_total_data=force_save_total_data,
) # )
test_data_extraction_metrics()