diff --git a/core/metrics.py b/core/metrics.py index 8bbd539..895f587 100644 --- a/core/metrics.py +++ b/core/metrics.py @@ -119,7 +119,8 @@ class Metrics: else: prediction_doc_id_list = prediction_df["doc_id"].unique().tolist() ground_truth_doc_id_list = ground_truth_df["doc_id"].unique().tolist() - doc_id_list = list(set(prediction_doc_id_list + ground_truth_doc_id_list)) + # get intersection of doc_id_list + doc_id_list = list(set(prediction_doc_id_list) & set(ground_truth_doc_id_list)) # order by doc_id doc_id_list.sort() diff --git a/main.py b/main.py index e71e647..125cd57 100644 --- a/main.py +++ b/main.py @@ -284,6 +284,19 @@ def batch_start_job( result_extract_data_df.to_excel( writer, index=False, sheet_name="extract_data" ) + + prediction_sheet_name = "mapping_data" + ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx" + ground_truth_sheet_name = "mapping_data" + metrics_output_folder = r"/data/emea_ar/output/metrics/" + missing_error_list, metrics_list, metrics_file = get_metrics( + "data_extraction", + output_file, + prediction_sheet_name, + ground_truth_file, + ground_truth_sheet_name, + metrics_output_folder + ) def batch_filter_pdf_files( @@ -476,9 +489,9 @@ def test_auto_generate_instructions(): def test_data_extraction_metrics(): data_type = "data_extraction" - prediction_file = r"/data/emea_ar/output/mapping_data/docs/excel/292989214.xlsx" + prediction_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_88_documents_20240917121708.xlsx" prediction_sheet_name = "mapping_data" - ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_292989214.xlsx" + ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx" ground_truth_sheet_name = "mapping_data" metrics_output_folder = r"/data/emea_ar/output/metrics/" missing_error_list, metrics_list, metrics_file = get_metrics(