optimize metrics calculation algorithm

Blade He 2024-09-19 11:44:17 -05:00
parent 98e86a6cfd
commit 27b3540c63
3 changed files with 55 additions and 24 deletions

@@ -253,7 +253,7 @@ class DataExtraction:
                     exclude_data: list) -> list:
        """
        If an error occurs, split the context into two parts and try to get data from each part.
-       Relevant document: 503194284
+       Relevant document: 503194284, page index 147
        """
        try:
            logger.info("Split context to get data; works around outputs over 4K tokens")

@@ -53,7 +53,7 @@ class Metrics:
            {"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0}
        ]
-       missing_error_list, metrics_list = self.get_metrics()
+       missing_error_list, metrics_list = self.calculate_metrics()
        missing_error_df = pd.DataFrame(missing_error_list)
        missing_error_df.reset_index(drop=True, inplace=True)
@@ -66,7 +66,7 @@ class Metrics:
            metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
        return missing_error_list, metrics_list, self.output_file

-   def get_metrics(self):
+   def calculate_metrics(self):
        prediction_df = pd.read_excel(
            self.prediction_file, sheet_name=self.prediction_sheet_name
        )
@@ -302,11 +302,13 @@ class Metrics:
            get_unique_words_text
        )
        ground_truth_unique_words = dp_ground_truth["unique_words"].unique().tolist()
+       ground_truth_raw_names = dp_ground_truth["raw_name"].unique().tolist()
        # add a new column to store unique words for dp_prediction
        dp_prediction["unique_words"] = dp_prediction["raw_name"].apply(
            get_unique_words_text
        )
        pred_unique_words = dp_prediction["unique_words"].unique().tolist()
+       pred_raw_names = dp_prediction["raw_name"].unique().tolist()
        true_data = []
        pred_data = []
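get_unique_words_text is applied to both raw_name columns but is not defined in this diff. A plausible minimal sketch, assuming the helper canonicalizes a raw name into a sorted, deduplicated token string (an illustration, not the repository's actual helper):

import re

def get_unique_words_text(raw_name: str) -> str:
    # Lowercase, split on non-word characters, dedupe, and sort, so that
    # "Ongoing Charges (OGC)" and "ogc Ongoing charges" map to the same key.
    words = re.split(r"\W+", str(raw_name).lower())
    return " ".join(sorted({w for w in words if w}))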
@@ -325,11 +327,18 @@ class Metrics:
            pred_data_point_value = prediction["value"]
            pred_investment_type = prediction["investment_type"]
-           if pred_unique_words in ground_truth_unique_words:
+           find_raw_name_in_gt = [gt_raw_name for gt_raw_name in ground_truth_raw_names
+                                  if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
+           if pred_unique_words in ground_truth_unique_words or len(find_raw_name_in_gt) > 0:
                # get the ground-truth row with the same unique words, else the first raw-name match
-               gt_data = dp_ground_truth[
-                   dp_ground_truth["unique_words"] == pred_unique_words
-               ].iloc[0]
+               if pred_unique_words in ground_truth_unique_words:
+                   gt_data = dp_ground_truth[
+                       dp_ground_truth["unique_words"] == pred_unique_words
+                   ].iloc[0]
+               else:
+                   gt_data = dp_ground_truth[
+                       dp_ground_truth["raw_name"] == find_raw_name_in_gt[0]
+                   ].iloc[0]
                gt_data_point_value = gt_data["value"]
                if pred_data_point_value == gt_data_point_value:
                    true_data.append(1)
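The fallback added above pairs names by bidirectional substring containment. A minimal sketch of that rule in isolation (the example names are illustrative):

def match_raw_name(pred_raw_name: str, ground_truth_raw_names: list) -> list:
    # A ground-truth name matches when either string contains the other.
    return [gt for gt in ground_truth_raw_names
            if gt in pred_raw_name or pred_raw_name in gt]

# e.g. match_raw_name("Ongoing Charges Figure", ["Ongoing Charges", "TER"])
# returns ["Ongoing Charges"], so the pair is still scored even though the
# unique-word keys differ.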
@@ -370,7 +379,11 @@ class Metrics:
            gt_data_point_value = ground_truth["value"]
            gt_investment_type = ground_truth["investment_type"]
-           if gt_unique_words not in pred_unique_words:
+           find_raw_name_in_pred = [pred_raw_name for pred_raw_name in pred_raw_names
+                                    if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
+           if gt_unique_words not in pred_unique_words and \
+                   len(find_raw_name_in_pred) == 0:
                true_data.append(1)
                pred_data.append(0)
                error_data = {
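Both passes append paired binary labels to true_data and pred_data (the second pass above records 1/0 when a ground-truth row has no matching prediction), so the Precision, Recall, F1, and Support columns written to the "Metrics" sheet can be derived from those two lists. A sketch of that step assuming scikit-learn; the diff does not show which library the class actually uses:

from sklearn.metrics import precision_recall_fscore_support

def score(true_data: list, pred_data: list) -> dict:
    # Treat the lists as y_true / y_pred for the positive ("found") class.
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_data, pred_data, average="binary", zero_division=0
    )
    return {"Precision": precision, "Recall": recall,
            "F1": f1, "Support": sum(true_data)}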

main.py

@@ -355,6 +355,7 @@ def get_metrics(
    prediction_file: str,
    prediction_sheet_name: str,
    ground_truth_file: str,
+   ground_truth_sheet_name: str = None,
    output_folder: str = None,
) -> tuple:
    metrics = Metrics(
@@ -362,6 +363,7 @@ def get_metrics(
        prediction_file=prediction_file,
        prediction_sheet_name=prediction_sheet_name,
        ground_truth_file=ground_truth_file,
+       ground_truth_sheet_name=ground_truth_sheet_name,
        output_folder=output_folder,
    )
    missing_error_list, metrics_list, metrics_file = metrics.get_metrics()
@@ -472,6 +474,22 @@ def test_auto_generate_instructions():
        f.write(ogc_ter_performance_fee_instructions_text)
+
+def test_data_extraction_metrics():
+    data_type = "data_extraction"
+    prediction_file = r"/data/emea_ar/output/mapping_data/docs/excel/292989214.xlsx"
+    prediction_sheet_name = "mapping_data"
+    ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_292989214.xlsx"
+    ground_truth_sheet_name = "mapping_data"
+    metrics_output_folder = r"/data/emea_ar/output/metrics/"
+    missing_error_list, metrics_list, metrics_file = get_metrics(
+        data_type,
+        prediction_file,
+        prediction_sheet_name,
+        ground_truth_file,
+        ground_truth_sheet_name,
+        metrics_output_folder
+    )
if __name__ == "__main__":
pdf_folder = r"/data/emea_ar/small_pdf/"
page_filter_ground_truth_file = (
@ -506,23 +524,23 @@ if __name__ == "__main__":
# doc_id = "476492237"
# extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run)
special_doc_id_list = [
"525574973",
]
special_doc_id_list = []
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
re_run_mapping_data = True
force_save_total_data = False
batch_start_job(
pdf_folder,
page_filter_ground_truth_file,
output_extract_data_child_folder,
output_mapping_child_folder,
output_extract_data_total_folder,
output_mapping_total_folder,
special_doc_id_list,
re_run_extract_data,
re_run_mapping_data,
force_save_total_data=force_save_total_data,
)
force_save_total_data = True
# batch_start_job(
# pdf_folder,
# page_filter_ground_truth_file,
# output_extract_data_child_folder,
# output_mapping_child_folder,
# output_extract_data_total_folder,
# output_mapping_total_folder,
# special_doc_id_list,
# re_run_extract_data,
# re_run_mapping_data,
# force_save_total_data=force_save_total_data,
# )
test_data_extraction_metrics()
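get_metrics returns the path of the workbook it writes as its third value, and the Metrics class saves its scores to a "Metrics" sheet (see the to_excel call above). A small follow-on sketch for inspecting the output of the new test, assuming pandas and reusing the test's returned metrics_file:

import pandas as pd

# metrics_file is the third value returned by get_metrics(...) in
# test_data_extraction_metrics above.
metrics_df = pd.read_excel(metrics_file, sheet_name="Metrics")
print(metrics_df[["Data_Point", "Precision", "Recall", "F1", "Support"]])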