optimize metrics calculation algorithm
This commit is contained in:
parent
98e86a6cfd
commit
27b3540c63
|
|
@ -253,7 +253,7 @@ class DataExtraction:
|
|||
exclude_data: list) -> list:
|
||||
"""
|
||||
If an error occurs, split the context into two parts and try to get data from each of the two parts
|
||||
Relevant document: 503194284
|
||||
Relevant document: 503194284, page index 147
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class Metrics:
|
|||
{"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0}
|
||||
]
|
||||
|
||||
missing_error_list, metrics_list = self.get_metrics()
|
||||
missing_error_list, metrics_list = self.calculate_metrics()
|
||||
|
||||
missing_error_df = pd.DataFrame(missing_error_list)
|
||||
missing_error_df.reset_index(drop=True, inplace=True)
|
||||
|
|
@ -66,7 +66,7 @@ class Metrics:
|
|||
metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
|
||||
return missing_error_list, metrics_list, self.output_file
|
||||
|
||||
def get_metrics(self):
|
||||
def calculate_metrics(self):
|
||||
prediction_df = pd.read_excel(
|
||||
self.prediction_file, sheet_name=self.prediction_sheet_name
|
||||
)
|
||||
|
|
@ -302,11 +302,13 @@ class Metrics:
|
|||
get_unique_words_text
|
||||
)
|
||||
ground_truth_unique_words = dp_ground_truth["unique_words"].unique().tolist()
|
||||
ground_truth_raw_names = dp_ground_truth["raw_name"].unique().tolist()
|
||||
# add new column to store unique words for dp_prediction
|
||||
dp_prediction["unique_words"] = dp_prediction["raw_name"].apply(
|
||||
get_unique_words_text
|
||||
)
|
||||
pred_unique_words = dp_prediction["unique_words"].unique().tolist()
|
||||
pred_raw_names = dp_prediction["raw_name"].unique().tolist()
|
||||
|
||||
true_data = []
|
||||
pred_data = []
|
||||
|
|
@ -325,11 +327,18 @@ class Metrics:
|
|||
pred_data_point_value = prediction["value"]
|
||||
pred_investment_type = prediction["investment_type"]
|
||||
|
||||
if pred_unique_words in ground_truth_unique_words:
|
||||
find_raw_name_in_gt = [gt_raw_name for gt_raw_name in ground_truth_raw_names
|
||||
if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
|
||||
if pred_unique_words in ground_truth_unique_words or len(find_raw_name_in_gt) > 0:
|
||||
# get the ground truth data with the same unique words
|
||||
if pred_unique_words in ground_truth_unique_words:
|
||||
gt_data = dp_ground_truth[
|
||||
dp_ground_truth["unique_words"] == pred_unique_words
|
||||
].iloc[0]
|
||||
else:
|
||||
gt_data = dp_ground_truth[
|
||||
dp_ground_truth["raw_name"] == find_raw_name_in_gt[0]
|
||||
].iloc[0]
|
||||
gt_data_point_value = gt_data["value"]
|
||||
if pred_data_point_value == gt_data_point_value:
|
||||
true_data.append(1)
|
||||
|
|
@ -370,7 +379,11 @@ class Metrics:
|
|||
gt_data_point_value = ground_truth["value"]
|
||||
gt_investment_type = ground_truth["investment_type"]
|
||||
|
||||
if gt_unique_words not in pred_unique_words:
|
||||
find_raw_name_in_pred = [pred_raw_name for pred_raw_name in pred_raw_names
|
||||
if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
|
||||
|
||||
if gt_unique_words not in pred_unique_words and \
|
||||
len(find_raw_name_in_pred) == 0:
|
||||
true_data.append(1)
|
||||
pred_data.append(0)
|
||||
error_data = {
|
||||
|
|
|
|||
50
main.py
50
main.py
|
|
@ -355,6 +355,7 @@ def get_metrics(
|
|||
prediction_file: str,
|
||||
prediction_sheet_name: str,
|
||||
ground_truth_file: str,
|
||||
ground_truth_sheet_name: str = None,
|
||||
output_folder: str = None,
|
||||
) -> None:
|
||||
metrics = Metrics(
|
||||
|
|
@ -362,6 +363,7 @@ def get_metrics(
|
|||
prediction_file=prediction_file,
|
||||
prediction_sheet_name=prediction_sheet_name,
|
||||
ground_truth_file=ground_truth_file,
|
||||
ground_truth_sheet_name=ground_truth_sheet_name,
|
||||
output_folder=output_folder,
|
||||
)
|
||||
missing_error_list, metrics_list, metrics_file = metrics.get_metrics()
|
||||
|
|
@ -472,6 +474,22 @@ def test_auto_generate_instructions():
|
|||
f.write(ogc_ter_performance_fee_instructions_text)
|
||||
|
||||
|
||||
def test_data_extraction_metrics():
|
||||
data_type = "data_extraction"
|
||||
prediction_file = r"/data/emea_ar/output/mapping_data/docs/excel/292989214.xlsx"
|
||||
prediction_sheet_name = "mapping_data"
|
||||
ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_292989214.xlsx"
|
||||
ground_truth_sheet_name = "mapping_data"
|
||||
metrics_output_folder = r"/data/emea_ar/output/metrics/"
|
||||
missing_error_list, metrics_list, metrics_file = get_metrics(
|
||||
data_type,
|
||||
prediction_file,
|
||||
prediction_sheet_name,
|
||||
ground_truth_file,
|
||||
ground_truth_sheet_name,
|
||||
metrics_output_folder
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
pdf_folder = r"/data/emea_ar/small_pdf/"
|
||||
page_filter_ground_truth_file = (
|
||||
|
|
@ -506,23 +524,23 @@ if __name__ == "__main__":
|
|||
|
||||
# doc_id = "476492237"
|
||||
# extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run)
|
||||
special_doc_id_list = [
|
||||
"525574973",
|
||||
]
|
||||
special_doc_id_list = []
|
||||
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
|
||||
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
|
||||
re_run_mapping_data = True
|
||||
|
||||
force_save_total_data = False
|
||||
batch_start_job(
|
||||
pdf_folder,
|
||||
page_filter_ground_truth_file,
|
||||
output_extract_data_child_folder,
|
||||
output_mapping_child_folder,
|
||||
output_extract_data_total_folder,
|
||||
output_mapping_total_folder,
|
||||
special_doc_id_list,
|
||||
re_run_extract_data,
|
||||
re_run_mapping_data,
|
||||
force_save_total_data=force_save_total_data,
|
||||
)
|
||||
force_save_total_data = True
|
||||
# batch_start_job(
|
||||
# pdf_folder,
|
||||
# page_filter_ground_truth_file,
|
||||
# output_extract_data_child_folder,
|
||||
# output_mapping_child_folder,
|
||||
# output_extract_data_total_folder,
|
||||
# output_mapping_total_folder,
|
||||
# special_doc_id_list,
|
||||
# re_run_extract_data,
|
||||
# re_run_mapping_data,
|
||||
# force_save_total_data=force_save_total_data,
|
||||
# )
|
||||
|
||||
test_data_extraction_metrics()
|
||||
|
|
|
|||
Loading…
Reference in New Issue