optimize metrics calculation algorithm
This commit is contained in:
parent
98e86a6cfd
commit
27b3540c63
|
|
@ -253,7 +253,7 @@ class DataExtraction:
|
||||||
exclude_data: list) -> list:
|
exclude_data: list) -> list:
|
||||||
"""
|
"""
|
||||||
If an error occurs, split the context into two parts and try to get data from both parts
|
If an error occurs, split the context into two parts and try to get data from both parts
|
||||||
Relevant document: 503194284
|
Relevant document: 503194284, page index 147
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
|
logger.info(f"Split context to get data to fix issue which output length is over 4K tokens")
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ class Metrics:
|
||||||
{"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0}
|
{"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0}
|
||||||
]
|
]
|
||||||
|
|
||||||
missing_error_list, metrics_list = self.get_metrics()
|
missing_error_list, metrics_list = self.calculate_metrics()
|
||||||
|
|
||||||
missing_error_df = pd.DataFrame(missing_error_list)
|
missing_error_df = pd.DataFrame(missing_error_list)
|
||||||
missing_error_df.reset_index(drop=True, inplace=True)
|
missing_error_df.reset_index(drop=True, inplace=True)
|
||||||
|
|
@ -66,7 +66,7 @@ class Metrics:
|
||||||
metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
|
metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
|
||||||
return missing_error_list, metrics_list, self.output_file
|
return missing_error_list, metrics_list, self.output_file
|
||||||
|
|
||||||
def get_metrics(self):
|
def calculate_metrics(self):
|
||||||
prediction_df = pd.read_excel(
|
prediction_df = pd.read_excel(
|
||||||
self.prediction_file, sheet_name=self.prediction_sheet_name
|
self.prediction_file, sheet_name=self.prediction_sheet_name
|
||||||
)
|
)
|
||||||
|
|
@ -302,11 +302,13 @@ class Metrics:
|
||||||
get_unique_words_text
|
get_unique_words_text
|
||||||
)
|
)
|
||||||
ground_truth_unique_words = dp_ground_truth["unique_words"].unique().tolist()
|
ground_truth_unique_words = dp_ground_truth["unique_words"].unique().tolist()
|
||||||
|
ground_truth_raw_names = dp_ground_truth["raw_name"].unique().tolist()
|
||||||
# add new column to store unique words for dp_prediction
|
# add new column to store unique words for dp_prediction
|
||||||
dp_prediction["unique_words"] = dp_prediction["raw_name"].apply(
|
dp_prediction["unique_words"] = dp_prediction["raw_name"].apply(
|
||||||
get_unique_words_text
|
get_unique_words_text
|
||||||
)
|
)
|
||||||
pred_unique_words = dp_prediction["unique_words"].unique().tolist()
|
pred_unique_words = dp_prediction["unique_words"].unique().tolist()
|
||||||
|
pred_raw_names = dp_prediction["raw_name"].unique().tolist()
|
||||||
|
|
||||||
true_data = []
|
true_data = []
|
||||||
pred_data = []
|
pred_data = []
|
||||||
|
|
@ -325,11 +327,18 @@ class Metrics:
|
||||||
pred_data_point_value = prediction["value"]
|
pred_data_point_value = prediction["value"]
|
||||||
pred_investment_type = prediction["investment_type"]
|
pred_investment_type = prediction["investment_type"]
|
||||||
|
|
||||||
if pred_unique_words in ground_truth_unique_words:
|
find_raw_name_in_gt = [gt_raw_name for gt_raw_name in ground_truth_raw_names
|
||||||
|
if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
|
||||||
|
if pred_unique_words in ground_truth_unique_words or len(find_raw_name_in_gt) > 0:
|
||||||
# get the ground truth data with the same unique words
|
# get the ground truth data with the same unique words
|
||||||
gt_data = dp_ground_truth[
|
if pred_unique_words in ground_truth_unique_words:
|
||||||
dp_ground_truth["unique_words"] == pred_unique_words
|
gt_data = dp_ground_truth[
|
||||||
].iloc[0]
|
dp_ground_truth["unique_words"] == pred_unique_words
|
||||||
|
].iloc[0]
|
||||||
|
else:
|
||||||
|
gt_data = dp_ground_truth[
|
||||||
|
dp_ground_truth["raw_name"] == find_raw_name_in_gt[0]
|
||||||
|
].iloc[0]
|
||||||
gt_data_point_value = gt_data["value"]
|
gt_data_point_value = gt_data["value"]
|
||||||
if pred_data_point_value == gt_data_point_value:
|
if pred_data_point_value == gt_data_point_value:
|
||||||
true_data.append(1)
|
true_data.append(1)
|
||||||
|
|
@ -370,7 +379,11 @@ class Metrics:
|
||||||
gt_data_point_value = ground_truth["value"]
|
gt_data_point_value = ground_truth["value"]
|
||||||
gt_investment_type = ground_truth["investment_type"]
|
gt_investment_type = ground_truth["investment_type"]
|
||||||
|
|
||||||
if gt_unique_words not in pred_unique_words:
|
find_raw_name_in_pred = [pred_raw_name for pred_raw_name in pred_raw_names
|
||||||
|
if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
|
||||||
|
|
||||||
|
if gt_unique_words not in pred_unique_words and \
|
||||||
|
len(find_raw_name_in_pred) == 0:
|
||||||
true_data.append(1)
|
true_data.append(1)
|
||||||
pred_data.append(0)
|
pred_data.append(0)
|
||||||
error_data = {
|
error_data = {
|
||||||
|
|
|
||||||
50
main.py
50
main.py
|
|
@ -355,6 +355,7 @@ def get_metrics(
|
||||||
prediction_file: str,
|
prediction_file: str,
|
||||||
prediction_sheet_name: str,
|
prediction_sheet_name: str,
|
||||||
ground_truth_file: str,
|
ground_truth_file: str,
|
||||||
|
ground_truth_sheet_name: str = None,
|
||||||
output_folder: str = None,
|
output_folder: str = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
metrics = Metrics(
|
metrics = Metrics(
|
||||||
|
|
@ -362,6 +363,7 @@ def get_metrics(
|
||||||
prediction_file=prediction_file,
|
prediction_file=prediction_file,
|
||||||
prediction_sheet_name=prediction_sheet_name,
|
prediction_sheet_name=prediction_sheet_name,
|
||||||
ground_truth_file=ground_truth_file,
|
ground_truth_file=ground_truth_file,
|
||||||
|
ground_truth_sheet_name=ground_truth_sheet_name,
|
||||||
output_folder=output_folder,
|
output_folder=output_folder,
|
||||||
)
|
)
|
||||||
missing_error_list, metrics_list, metrics_file = metrics.get_metrics()
|
missing_error_list, metrics_list, metrics_file = metrics.get_metrics()
|
||||||
|
|
@ -472,6 +474,22 @@ def test_auto_generate_instructions():
|
||||||
f.write(ogc_ter_performance_fee_instructions_text)
|
f.write(ogc_ter_performance_fee_instructions_text)
|
||||||
|
|
||||||
|
|
||||||
|
def test_data_extraction_metrics():
|
||||||
|
data_type = "data_extraction"
|
||||||
|
prediction_file = r"/data/emea_ar/output/mapping_data/docs/excel/292989214.xlsx"
|
||||||
|
prediction_sheet_name = "mapping_data"
|
||||||
|
ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_292989214.xlsx"
|
||||||
|
ground_truth_sheet_name = "mapping_data"
|
||||||
|
metrics_output_folder = r"/data/emea_ar/output/metrics/"
|
||||||
|
missing_error_list, metrics_list, metrics_file = get_metrics(
|
||||||
|
data_type,
|
||||||
|
prediction_file,
|
||||||
|
prediction_sheet_name,
|
||||||
|
ground_truth_file,
|
||||||
|
ground_truth_sheet_name,
|
||||||
|
metrics_output_folder
|
||||||
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pdf_folder = r"/data/emea_ar/small_pdf/"
|
pdf_folder = r"/data/emea_ar/small_pdf/"
|
||||||
page_filter_ground_truth_file = (
|
page_filter_ground_truth_file = (
|
||||||
|
|
@ -506,23 +524,23 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# doc_id = "476492237"
|
# doc_id = "476492237"
|
||||||
# extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run)
|
# extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run)
|
||||||
special_doc_id_list = [
|
special_doc_id_list = []
|
||||||
"525574973",
|
|
||||||
]
|
|
||||||
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
|
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
|
||||||
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
|
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
|
||||||
re_run_mapping_data = True
|
re_run_mapping_data = True
|
||||||
|
|
||||||
force_save_total_data = False
|
force_save_total_data = True
|
||||||
batch_start_job(
|
# batch_start_job(
|
||||||
pdf_folder,
|
# pdf_folder,
|
||||||
page_filter_ground_truth_file,
|
# page_filter_ground_truth_file,
|
||||||
output_extract_data_child_folder,
|
# output_extract_data_child_folder,
|
||||||
output_mapping_child_folder,
|
# output_mapping_child_folder,
|
||||||
output_extract_data_total_folder,
|
# output_extract_data_total_folder,
|
||||||
output_mapping_total_folder,
|
# output_mapping_total_folder,
|
||||||
special_doc_id_list,
|
# special_doc_id_list,
|
||||||
re_run_extract_data,
|
# re_run_extract_data,
|
||||||
re_run_mapping_data,
|
# re_run_mapping_data,
|
||||||
force_save_total_data=force_save_total_data,
|
# force_save_total_data=force_save_total_data,
|
||||||
)
|
# )
|
||||||
|
|
||||||
|
test_data_extraction_metrics()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue