only calculate metrics for intersection document list

This commit is contained in:
Blade He 2024-09-19 11:54:51 -05:00
parent 27b3540c63
commit 67371e534e
2 changed files with 17 additions and 3 deletions

View File

@ -119,7 +119,8 @@ class Metrics:
else:
prediction_doc_id_list = prediction_df["doc_id"].unique().tolist()
ground_truth_doc_id_list = ground_truth_df["doc_id"].unique().tolist()
doc_id_list = list(set(prediction_doc_id_list + ground_truth_doc_id_list))
# get intersection of doc_id_list
doc_id_list = list(set(prediction_doc_id_list) & set(ground_truth_doc_id_list))
# order by doc_id
doc_id_list.sort()

17
main.py
View File

@ -285,6 +285,19 @@ def batch_start_job(
writer, index=False, sheet_name="extract_data"
)
prediction_sheet_name = "mapping_data"
ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx"
ground_truth_sheet_name = "mapping_data"
metrics_output_folder = r"/data/emea_ar/output/metrics/"
missing_error_list, metrics_list, metrics_file = get_metrics(
"data_extraction",
output_file,
prediction_sheet_name,
ground_truth_file,
ground_truth_sheet_name,
metrics_output_folder
)
def batch_filter_pdf_files(
pdf_folder: str,
@ -476,9 +489,9 @@ def test_auto_generate_instructions():
def test_data_extraction_metrics():
data_type = "data_extraction"
prediction_file = r"/data/emea_ar/output/mapping_data/docs/excel/292989214.xlsx"
prediction_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_88_documents_20240917121708.xlsx"
prediction_sheet_name = "mapping_data"
ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_292989214.xlsx"
ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx"
ground_truth_sheet_name = "mapping_data"
metrics_output_folder = r"/data/emea_ar/output/metrics/"
missing_error_list, metrics_list, metrics_file = get_metrics(