# dc-ml-emea-ar/core/metrics.py
import json
import os
import time
from typing import Optional

import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score

from utils.logger import logger


class Metrics:
    def __init__(
        self,
        data_type: str,
        prediction_file: str,
        prediction_sheet_name: str = "Sheet1",
        ground_truth_file: Optional[str] = None,
        output_folder: Optional[str] = None,
    ) -> None:
        self.data_type = data_type
        self.prediction_file = prediction_file
        self.prediction_sheet_name = prediction_sheet_name
        self.ground_truth_file = ground_truth_file
        # Fall back to the default metrics folder when none is given.
        if not output_folder:
            output_folder = r"/data/emea_ar/output/metrics/"
        os.makedirs(output_folder, exist_ok=True)
        # Timestamp the output file so repeated runs do not overwrite each other.
        time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
        self.output_file = os.path.join(
            output_folder,
            f"metrics_{data_type}_{time_stamp}.xlsx",
        )

    def get_metrics(self):
        # Validate the input files up front; always return the same
        # (missing_error_list, metrics_list, output_file) shape.
        if not self.prediction_file or not os.path.exists(self.prediction_file):
            logger.error(f"Invalid prediction file: {self.prediction_file}")
            return [], [], None
        if not self.ground_truth_file or not os.path.exists(self.ground_truth_file):
            logger.error(f"Invalid ground truth file: {self.ground_truth_file}")
            return [], [], None
        missing_error_list = []
        metrics_list = [
            {"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0}
        ]
        if self.data_type == "page_filter":
            missing_error_list, metrics_list = self.get_page_filter_metrics()
        elif self.data_type == "datapoint":
            missing_error_list, metrics_list = self.get_datapoint_metrics()
        else:
            logger.error(f"Invalid data type: {self.data_type}")
        missing_error_df = pd.DataFrame(missing_error_list)
        missing_error_df.reset_index(drop=True, inplace=True)
        metrics_df = pd.DataFrame(metrics_list)
        metrics_df.reset_index(drop=True, inplace=True)
        # Write both result tables into one workbook, one sheet each.
        with pd.ExcelWriter(self.output_file) as writer:
            missing_error_df.to_excel(writer, sheet_name="Missing_Error", index=False)
            metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
        return missing_error_list, metrics_list, self.output_file
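
    # The workbook written by get_metrics has two sheets (layout sketch;
    # column names are the ones used in the code above):
    #   "Missing_Error": doc_id | data_point | missing | error
    #   "Metrics":       Data_Point | Precision | Recall | F1 | Support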

    def get_page_filter_metrics(self):
        prediction_df = pd.read_excel(
            self.prediction_file, sheet_name=self.prediction_sheet_name
        )
        ground_truth_df = pd.read_excel(self.ground_truth_file, sheet_name="Sheet1")
        # Only manually verified rows count as ground truth.
        ground_truth_df = ground_truth_df[ground_truth_df["Checked"] == 1]
        data_point_list = ["tor", "ter", "ogc", "performance_fee"]
        # Accumulate binary labels per data point across all documents.
        true_by_data_point = {data_point: [] for data_point in data_point_list}
        pred_by_data_point = {data_point: [] for data_point in data_point_list}
        missing_error_list = []
        for _, row in ground_truth_df.iterrows():
            doc_id = row["doc_id"]
            # Take the first prediction row with the same doc_id; skip
            # documents that have no prediction at all.
            matched_rows = prediction_df[prediction_df["doc_id"] == doc_id]
            if matched_rows.empty:
                logger.error(f"No prediction found for doc_id: {doc_id}")
                continue
            prediction_data = matched_rows.iloc[0]
            for data_point in data_point_list:
                true_data, pred_data, missing_error_data = self.get_true_pred_data(
                    doc_id, row, prediction_data, data_point
                )
                true_by_data_point[data_point].extend(true_data)
                pred_by_data_point[data_point].extend(pred_data)
                missing_error_list.append(missing_error_data)
        metrics_list = []
        for data_point in data_point_list:
            true_data = true_by_data_point[data_point]
            pred_data = pred_by_data_point[data_point]
            precision, recall, f1 = self.get_specific_metrics(true_data, pred_data)
            metrics_list.append(
                {
                    "Data_Point": data_point,
                    "Precision": precision,
                    "Recall": recall,
                    "F1": f1,
                    "Support": len(true_data),
                }
            )
            logger.info(
                f"{data_point.upper()} Precision: {precision}, Recall: {recall}, "
                f"F1: {f1}, Support: {len(true_data)}"
            )
        # Append the unweighted average across the four data points.
        precision_list = [metric["Precision"] for metric in metrics_list]
        recall_list = [metric["Recall"] for metric in metrics_list]
        f1_list = [metric["F1"] for metric in metrics_list]
        metrics_list.append(
            {
                "Data_Point": "Average",
                "Precision": sum(precision_list) / len(precision_list),
                "Recall": sum(recall_list) / len(recall_list),
                "F1": sum(f1_list) / len(f1_list),
                "Support": sum(metric["Support"] for metric in metrics_list),
            }
        )
        return missing_error_list, metrics_list
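
    # Shape of one metrics_list row, with illustrative (made-up) numbers:
    #   {"Data_Point": "tor", "Precision": 0.92, "Recall": 0.88,
    #    "F1": 0.90, "Support": 120}
    # The "Average" row is the unweighted mean of the four data-point rows.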

    def get_true_pred_data(
        self, doc_id, ground_truth_data: pd.Series, prediction_data: pd.Series, data_point: str
    ):
        # Cells may hold JSON-encoded lists; decode them first.
        ground_truth_list = ground_truth_data[data_point]
        if isinstance(ground_truth_list, str):
            ground_truth_list = json.loads(ground_truth_list)
        prediction_list = prediction_data[data_point]
        if isinstance(prediction_list, str):
            prediction_list = json.loads(prediction_list)
        true_data = []
        pred_data = []
        missing_data = []
        error_data = []
        # Both empty counts as a single correct "nothing to find" prediction.
        if len(ground_truth_list) == 0 and len(prediction_list) == 0:
            true_data.append(1)
            pred_data.append(1)
        else:
            # Predictions found in the ground truth are true positives;
            # the rest are false positives, recorded under "error".
            for prediction in prediction_list:
                if prediction in ground_truth_list:
                    true_data.append(1)
                    pred_data.append(1)
                else:
                    true_data.append(0)
                    pred_data.append(1)
                    error_data.append(prediction)
            # Ground-truth entries that were never predicted are false
            # negatives, recorded under "missing".
            for ground_truth in ground_truth_list:
                if ground_truth not in prediction_list:
                    true_data.append(1)
                    pred_data.append(0)
                    missing_data.append(ground_truth)
        missing_error_data = {
            "doc_id": doc_id,
            "data_point": data_point,
            "missing": missing_data,
            "error": error_data,
        }
        return true_data, pred_data, missing_error_data
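
    # Worked example of the encoding above (illustrative page numbers):
    # ground truth [3, 7] vs prediction [3, 9] yields
    #   page 3 -> true positive  (true=1, pred=1)
    #   page 9 -> false positive (true=0, pred=1), recorded under "error"
    #   page 7 -> false negative (true=1, pred=0), recorded under "missing"
    # i.e. true_data=[1, 0, 1], pred_data=[1, 1, 0].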

    def get_specific_metrics(self, true_data: list, pred_data: list):
        # zero_division=0 avoids sklearn warnings/errors when a label never
        # occurs in the predictions or the ground truth.
        precision = precision_score(true_data, pred_data, zero_division=0)
        recall = recall_score(true_data, pred_data, zero_division=0)
        f1 = f1_score(true_data, pred_data, zero_division=0)
        return precision, recall, f1

    def get_datapoint_metrics(self):
        # TODO: not implemented yet; return empty results so get_metrics can
        # still unpack a (missing_error_list, metrics_list) pair.
        return [], []
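

# Minimal usage sketch, assuming the module is importable and the two Excel
# files exist; both paths below are hypothetical, not part of this repo:
if __name__ == "__main__":
    metrics = Metrics(
        data_type="page_filter",
        prediction_file="/data/emea_ar/output/prediction.xlsx",
        ground_truth_file="/data/emea_ar/ground_truth/page_filter_gt.xlsx",
    )
    missing_error_list, metrics_list, output_file = metrics.get_metrics()
    logger.info(f"Metrics report written to: {output_file}")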