145 lines
4.9 KiB
Python
145 lines
4.9 KiB
Python
from tqdm import tqdm
|
|
from glob import glob
|
|
import json
|
|
import pandas as pd
|
|
import os
|
|
from traceback import print_exc
|
|
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
|
|
|
|
from utils.logger import logger
|
|
|
|
|
|
def calculate_complex_document_metrics(verify_file_path: str, document_list: list = None):
    """Compute verification metrics per datapoint and write them to an Excel report.

    Reads the ``data_in_doc_mapping`` sheet of *verify_file_path*, keeps only
    rows that were actually reviewed (``raw_check`` is 0 or 1), optionally
    restricts to *document_list*, computes F1/Precision/Recall/Accuracy/Support
    for each datapoint (tor, ter, ogc, performance_fee) plus an average row,
    and writes ``complex_<workbook_name>_metrics.xlsx`` to the hard-coded
    output folder.

    :param verify_file_path: path to the verification ``.xlsx`` workbook.
    :param document_list: optional list of ``doc_id`` strings to restrict the
        evaluation to; ``None`` or empty means all documents.
        (Default changed from a mutable ``[]`` to ``None`` — same behavior,
        avoids the shared-mutable-default pitfall.)
    """
    data_df = pd.read_excel(verify_file_path, sheet_name="data_in_doc_mapping")
    # doc_id may be parsed as int by pandas; compare as strings throughout
    data_df["doc_id"] = data_df["doc_id"].astype(str)
    # keep only reviewed rows (raw_check flag explicitly 0 or 1)
    data_df = data_df[data_df["raw_check"].isin([0, 1])]

    if document_list:  # None and [] both mean "all documents"
        data_df = data_df[data_df["doc_id"].isin(document_list)]

    data_df.fillna("", inplace=True)
    data_df.reset_index(drop=True, inplace=True)

    # One metrics row per datapoint, in fixed report order.
    # Labels preserve the original per-datapoint log-message wording.
    datapoint_labels = {
        "tor": "TOR",
        "ter": "TER",
        "ogc": "OGC",
        "performance_fee": "Performance fee",
    }
    metrics_rows = []
    for datapoint, label in datapoint_labels.items():
        sub_df = data_df[data_df["datapoint"] == datapoint]
        sub_metrics = get_sub_metrics(sub_df, datapoint)
        logger.info(f"{label} metrics: {sub_metrics}")
        metrics_rows.append(sub_metrics)

    # Unweighted average across datapoints; Support is summed, not averaged.
    per_dp_df = pd.DataFrame(metrics_rows)
    avg_metrics = {
        "DataPoint": "average",
        "F1": per_dp_df["F1"].mean(),
        "Precision": per_dp_df["Precision"].mean(),
        "Recall": per_dp_df["Recall"].mean(),
        "Accuracy": per_dp_df["Accuracy"].mean(),
        "Support": per_dp_df["Support"].sum()
    }

    metrics_df = pd.DataFrame(metrics_rows + [avg_metrics])
    metrics_df.reset_index(drop=True, inplace=True)

    output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
    verify_file_name = os.path.basename(verify_file_path).replace(".xlsx", "")
    output_metrics_file = os.path.join(output_folder,
                                       f"complex_{verify_file_name}_metrics.xlsx")
    with pd.ExcelWriter(output_metrics_file) as writer:
        metrics_df.to_excel(writer, index=False, sheet_name="metrics")
|
|
|
|
|
|
def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict:
    """Derive binary-classification metrics from reviewer check flags.

    Rows with ``raw_check == 1`` are correct extractions (true positives).
    Rows with ``raw_check == 0`` are classified by ``raw_check_comment``:

    - ``"modify"``:     wrong value extracted AND true value missed
                        → counts as one false positive + one false negative
    - ``"incorrect"``:  spurious extraction → one false positive
    - ``"supplement"``: missed value → one false negative
    - anything else is ignored (matches the original silent pass-through)

    No true negatives are ever recorded, so accuracy is TP / (TP + FP + FN).
    Zero denominators yield 0.0 (same as sklearn's ``zero_division`` default);
    an empty frame returns all-zero metrics instead of erroring.

    :param data_df: reviewed rows for one datapoint, with ``raw_check`` and
        ``raw_check_comment`` columns.
    :param data_point: datapoint name echoed into the result row.
    :return: dict with keys DataPoint, F1, Precision, Recall, Accuracy, Support.
    """
    tp = int((data_df["raw_check"] == 1).sum())

    # Tally rejection reasons without an O(n) Python-level iterrows loop.
    rejected_comments = data_df.loc[data_df["raw_check"] == 0, "raw_check_comment"]
    comment_counts = rejected_comments.value_counts()
    n_modify = int(comment_counts.get("modify", 0))
    fp = n_modify + int(comment_counts.get("incorrect", 0))
    fn = n_modify + int(comment_counts.get("supplement", 0))

    total = tp + fp + fn
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    accuracy = tp / total if total else 0.0
    # support = number of ground-truth positives (TP rows + missed values)
    support = tp + fn

    metrics = {
        "DataPoint": data_point,
        "F1": f1,
        "Precision": precision,
        "Recall": recall,
        "Accuracy": accuracy,
        "Support": support
    }
    return metrics
|
|
|
|
|
|
if __name__ == "__main__":
    # Second-round verification workbook for the complex-document set.
    base_dir = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
    workbook_name = "mapping_data_info_31_documents_by_text_second_round.xlsx"

    # doc_ids (as strings) to restrict the metric computation to
    target_doc_ids = [
        "334584772", "337293427", "337937633", "404712928", "406913630",
        "407275419", "422686965", "422760148", "422760156", "422761666",
        "423364758", "423365707", "423395975", "423418395", "423418540",
        "425595958", "451063582", "451878128", "466580448", "481482392",
        "508704368", "532998065", "536344026", "540307575",
    ]

    calculate_complex_document_metrics(
        verify_file_path=os.path.join(base_dir, workbook_name),
        document_list=target_doc_ids,
    )