import json
import os
from glob import glob
from traceback import print_exc

import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from tqdm import tqdm

from utils.logger import logger


def calculate_complex_document_metrics(verify_file_path: str, document_list: list = None):
    """Compute per-document and overall extraction metrics for complex documents.

    Reads the "data_in_doc_mapping" and "total_mapping_data" sheets from the
    verification workbook, keeps only rows whose ``raw_check`` flag is 0 or 1,
    optionally restricts the first sheet to *document_list*, computes a metrics
    DataFrame per document plus an overall "All" row via :func:`calc_metrics`,
    and writes the combined result to a new Excel file in the configured
    output folder.

    Args:
        verify_file_path: Path to the verification ``.xlsx`` workbook.
        document_list: Optional list of doc_id strings to restrict the
            calculation to. ``None`` or an empty list means "all documents".
            (Default changed from a mutable ``[]`` to ``None``; behavior is
            identical because the empty list was treated the same as ``None``.)
    """
    data_df_1 = pd.read_excel(verify_file_path, sheet_name="data_in_doc_mapping")
    # doc_id may be parsed as int by pandas; normalize to string for comparisons.
    data_df_1["doc_id"] = data_df_1["doc_id"].astype(str)
    data_df_1 = data_df_1[data_df_1["raw_check"].isin([0, 1])]

    # Documents always excluded from the calculation.
    exclude_documents = ["532422548"]
    data_df_1 = data_df_1[~data_df_1["doc_id"].isin(exclude_documents)]

    # NOTE(review): the document_list filter is applied only to the first
    # sheet; rows from "total_mapping_data" are not restricted — confirm
    # this asymmetry is intended.
    if document_list:
        data_df_1 = data_df_1[data_df_1["doc_id"].isin(document_list)]

    data_df_2 = pd.read_excel(verify_file_path, sheet_name="total_mapping_data")
    data_df_2["doc_id"] = data_df_2["doc_id"].astype(str)
    data_df_2 = data_df_2[data_df_2["raw_check"].isin([0, 1])]

    data_df = pd.concat([data_df_1, data_df_2], ignore_index=True)
    data_df.fillna("", inplace=True)
    data_df.reset_index(drop=True, inplace=True)

    # One metrics DataFrame per document; failures are logged and skipped so a
    # single bad document does not abort the whole run.
    metrics_df_list = []
    doc_id_list = data_df["doc_id"].unique().tolist()
    for doc_id in tqdm(doc_id_list):
        try:
            document_data_df = data_df[data_df["doc_id"] == doc_id]
            document_metrics_df = calc_metrics(document_data_df, doc_id)
            metrics_df_list.append(document_metrics_df)
        except Exception:
            logger.error(f"Error when calculating metrics for document {doc_id}")
            print_exc()

    # Overall metrics across all remaining documents ("All" rows).
    total_metrics_df = calc_metrics(data_df, doc_id=None)
    metrics_df_list.append(total_metrics_df)

    all_metrics_df = pd.concat(metrics_df_list, ignore_index=True)
    all_metrics_df.reset_index(drop=True, inplace=True)

    output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
    verify_file_name = os.path.basename(verify_file_path).replace(".xlsx", "")
    output_metrics_file = os.path.join(output_folder,
                                       f"complex_{verify_file_name}_metrics_all.xlsx")
    with pd.ExcelWriter(output_metrics_file) as writer:
        all_metrics_df.to_excel(writer, index=False, sheet_name="metrics")


def calc_metrics(data_df: pd.DataFrame, doc_id: str = None):
    """Build a metrics DataFrame for one document (or for all documents).

    For each known datapoint present in *data_df* ("tor", "ter", "ogc",
    "performance_fee") one row of F1/Precision/Recall/Accuracy/Support is
    computed via :func:`get_sub_metrics`, followed by a final "average" row
    that averages the per-datapoint scores and sums their supports.

    Args:
        data_df: Verification rows; must contain the "datapoint" column and
            the raw-check columns consumed by get_sub_metrics.
        doc_id: Document id used to label the rows; ``None`` or ``""`` labels
            the rows as "All" (aggregate over every document).

    Returns:
        pd.DataFrame with one row per present datapoint plus an "average" row,
        or an empty DataFrame if no known datapoint is present.
    """
    # Datapoint keys mapped to the exact labels used in the log messages.
    datapoint_labels = {
        "tor": "TOR",
        "ter": "TER",
        "ogc": "OGC",
        "performance_fee": "Performance fee",
    }

    metrics_list = []
    for data_point, label in datapoint_labels.items():
        sub_data_df = data_df[data_df["datapoint"] == data_point]
        if len(sub_data_df) > 0:
            sub_metrics = get_sub_metrics(sub_data_df, data_point, doc_id)
            logger.info(f"{label} metrics: {sub_metrics}")
            metrics_list.append(sub_metrics)

    # Robustness: with no datapoint rows at all, return an empty frame instead
    # of raising KeyError below on the missing score columns.
    if not metrics_list:
        return pd.DataFrame(metrics_list)

    metrics_df = pd.DataFrame(metrics_list)

    # Append the "average" row: mean of each score, sum of the supports.
    document_label = doc_id if doc_id is not None and len(doc_id) > 0 else "All"
    avg_metrics = {
        "DocumentId": document_label,
        "DataPoint": "average",
        "F1": metrics_df["F1"].mean(),
        "Precision": metrics_df["Precision"].mean(),
        "Recall": metrics_df["Recall"].mean(),
        "Accuracy": metrics_df["Accuracy"].mean(),
        "Support": metrics_df["Support"].sum(),
    }
    metrics_list.append(avg_metrics)

    metrics_df = pd.DataFrame(metrics_list)
    metrics_df.reset_index(drop=True, inplace=True)
    return metrics_df


def get_sub_metrics(data_df: pd.DataFrame, data_point: str, doc_id: str = None) -> dict:
    """Compute binary-classification metrics for one datapoint's verification rows.

    Rows with ``raw_check == 1`` are verified-correct extractions and count as
    true positives (gt=1, pred=1). Rows with ``raw_check == 0`` are errors,
    classified by ``raw_check_comment``:

    - ``"modify"``: counted as both a false positive (gt=0, pred=1) and a
      false negative (gt=1, pred=0) — a wrong value was extracted AND the
      right value was missed.
    - ``"incorrect"``: false positive only.
    - ``"supplement"``: false negative only (a value was missed).
    - any other comment: ignored.

    Args:
        data_df: Verification rows for a single datapoint.
        data_point: Datapoint name stored in the result ("tor", "ter", ...).
        doc_id: Document id for the result label; ``None``/``""`` means "All".

    Returns:
        dict with DocumentId, DataPoint, F1, Precision, Recall, Accuracy and
        Support (Support is the number of ground-truth-positive samples).
    """
    # True positives: one (1, 1) sample per verified-correct row.
    correct_count = len(data_df[data_df["raw_check"] == 1])
    gt_list = [1] * correct_count
    pre_list = [1] * correct_count

    # Samples contributed by each error category as (ground_truth, prediction)
    # pairs; unknown comments contribute nothing.
    error_samples = {
        "modify": [(0, 1), (1, 0)],
        "incorrect": [(0, 1)],
        "supplement": [(1, 0)],
    }
    error_df = data_df[data_df["raw_check"] == 0]
    for _, row in error_df.iterrows():
        for gt, pre in error_samples.get(row["raw_check_comment"], []):
            gt_list.append(gt)
            pre_list.append(pre)

    # Calculate metrics over the assembled label/prediction lists.
    accuracy = accuracy_score(gt_list, pre_list)
    precision = precision_score(gt_list, pre_list)
    recall = recall_score(gt_list, pre_list)
    f1 = f1_score(gt_list, pre_list)
    support = sum(gt_list)

    document_label = doc_id if doc_id is not None and len(doc_id) > 0 else "All"
    return {
        "DocumentId": document_label,
        "DataPoint": data_point,
        "F1": f1,
        "Precision": precision,
        "Recall": recall,
        "Accuracy": accuracy,
        "Support": support,
    }


def _aggregate_datapoint_metrics(metrics_document_df: pd.DataFrame, data_point: str):
    """Return the across-document aggregate row for one datapoint.

    Averages F1/Precision/Recall/Accuracy and sums Support over every row of
    *metrics_document_df* whose DataPoint equals *data_point*. Returns None
    when the datapoint has no rows.
    """
    sub_df = metrics_document_df[metrics_document_df["DataPoint"] == data_point]
    if len(sub_df) == 0:
        return None
    return {
        "DocumentId": "All",
        "DataPoint": data_point,
        "F1": sub_df["F1"].mean(),
        "Precision": sub_df["Precision"].mean(),
        "Recall": sub_df["Recall"].mean(),
        "Accuracy": sub_df["Accuracy"].mean(),
        "Support": sub_df["Support"].sum(),
    }


def get_metrics_based_documents(metrics_file: str, document_list: list):
    """Re-aggregate an existing metrics workbook over a subset of documents.

    Reads the "metrics" sheet from *metrics_file*, keeps only the rows for
    the documents in *document_list*, appends one "All" aggregate row per
    datapoint ("tor", "ter", "ogc", "performance_fee", "average"), writes the
    result to a hard-coded output workbook, and returns it.

    Args:
        metrics_file: Path to a metrics ``.xlsx`` produced by
            calculate_complex_document_metrics.
        document_list: doc_id strings to keep.

    Returns:
        pd.DataFrame with the selected per-document rows followed by the
        aggregate rows.
    """
    metrics_df = pd.read_excel(metrics_file, sheet_name="metrics")

    # Collect the per-document slices; failures are logged and skipped.
    metrics_df_list = []
    for doc_id in tqdm(document_list):
        try:
            document_metrics_df = metrics_df[metrics_df["DocumentId"] == doc_id]
            metrics_df_list.append(document_metrics_df)
        except Exception:
            logger.error(f"Error when calculating metrics for document {doc_id}")
            print_exc()

    metrics_document_df = pd.concat(metrics_df_list, ignore_index=True)

    # One aggregate row per datapoint that actually appears in the selection.
    stats_metrics_list = []
    for data_point in ("tor", "ter", "ogc", "performance_fee", "average"):
        aggregate = _aggregate_datapoint_metrics(metrics_document_df, data_point)
        if aggregate is not None:
            stats_metrics_list.append(aggregate)

    stats_metrics_df = pd.DataFrame(stats_metrics_list)
    metrics_df_list.append(stats_metrics_df)
    all_metrics_df = pd.concat(metrics_df_list, ignore_index=True)
    all_metrics_df.reset_index(drop=True, inplace=True)

    output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
    verify_file_name = "complex_mapping_data_info_31_documents_by_text_second_round_metrics_remain_7.xlsx"
    output_metrics_file = os.path.join(output_folder, verify_file_name)
    with pd.ExcelWriter(output_metrics_file) as writer:
        all_metrics_df.to_excel(writer, index=False, sheet_name="metrics")

    return all_metrics_df


if __name__ == "__main__":
    # Location of the verification workbooks and metrics outputs.
    file_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/complex/"
    verify_file = "mapping_data_info_31_documents_by_text_second_round.xlsx"
    verify_file_path = os.path.join(file_folder, verify_file)
    calculate_complex_document_metrics(verify_file_path=verify_file_path,
                                       document_list=None)

    # Document subset for the (currently disabled) per-document rollup below.
    document_list = [
        "492029971",
        "510300817",
        "512745032",
        "514213638",
        "527525440",
        "534535767",
    ]
    metrics_file = "complex_mapping_data_info_31_documents_by_text_second_round_metrics_all.xlsx"
    metrics_file_path = os.path.join(file_folder, metrics_file)
    # get_metrics_based_documents(metrics_file=metrics_file_path,
    #                             document_list=document_list)
|