import json
import os
import time

import pandas as pd
from sklearn.metrics import precision_score, recall_score, f1_score

from utils.biz_utils import (
    get_unique_words_text,
    get_beginning_common_words,
    remove_special_characters,
    simple_most_similarity_name,
)
from utils.logger import logger
from utils.sql_query_util import query_document_fund_mapping


class Metrics:

    def __init__(
        self,
        data_type: str,
        prediction_file: str,
        prediction_sheet_name: str = "Sheet1",
        ground_truth_file: str = None,
        ground_truth_sheet_name: str = "Sheet1",
        output_folder: str = None,
    ) -> None:
        self.data_type = data_type
        self.prediction_file = prediction_file
        self.prediction_sheet_name = prediction_sheet_name
        self.ground_truth_file = ground_truth_file
        self.ground_truth_sheet_name = ground_truth_sheet_name

        if output_folder is None or len(output_folder) == 0:
            output_folder = r"/data/emea_ar/output/metrics/"

        os.makedirs(output_folder, exist_ok=True)
        time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
        self.output_file = os.path.join(
            output_folder,
            f"metrics_{data_type}_{time_stamp}.xlsx",
        )

    def get_metrics(self, strict_mode: bool = False):
        if (
            self.prediction_file is None
            or len(self.prediction_file) == 0
            or not os.path.exists(self.prediction_file)
        ):
            logger.error(f"Invalid prediction file: {self.prediction_file}")
            # Keep the 3-tuple shape so callers can always unpack the result.
            return [], [], None
        if (
            self.ground_truth_file is None
            or len(self.ground_truth_file) == 0
            or not os.path.exists(self.ground_truth_file)
        ):
            logger.error(f"Invalid ground truth file: {self.ground_truth_file}")
            return [], [], None

        missing_error_list, metrics_list = self.calculate_metrics(
            strict_mode=strict_mode
        )

        missing_error_df = pd.DataFrame(missing_error_list)
        missing_error_df.reset_index(drop=True, inplace=True)

        metrics_df = pd.DataFrame(metrics_list)
        metrics_df.reset_index(drop=True, inplace=True)

        with pd.ExcelWriter(self.output_file) as writer:
            missing_error_df.to_excel(writer, sheet_name="Missing_Error", index=False)
            metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
        return missing_error_list, metrics_list, self.output_file

    def calculate_metrics(self, strict_mode: bool = False):
        prediction_df = pd.read_excel(
            self.prediction_file, sheet_name=self.prediction_sheet_name
        )
        ground_truth_df = pd.read_excel(
            self.ground_truth_file, sheet_name=self.ground_truth_sheet_name
        )
        # Keep only the manually verified ground truth rows for each data type.
        if self.data_type == "page_filter":
            ground_truth_df = ground_truth_df[ground_truth_df["Checked"] == 1]
        elif self.data_type == "data_extraction":
            ground_truth_df = ground_truth_df[ground_truth_df["rawname_checked"] == 1]
        elif self.data_type in ["investment_mapping", "document_mapping_in_db"]:
            ground_truth_df = ground_truth_df[ground_truth_df["mapping_checked"] == 1]
        else:
            logger.error(f"Invalid data type: {self.data_type}")
            return [], []

        tor_true = []
        tor_pred = []

        ter_true = []
        ter_pred = []

        ogc_true = []
        ogc_pred = []

        performance_fee_true = []
        performance_fee_pred = []

        investment_mapping_true = []
        investment_mapping_pred = []

        missing_error_list = []
        data_point_list = ["tor", "ter", "ogc", "performance_fee"]

        if self.data_type == "page_filter":
            for index, row in ground_truth_df.iterrows():
                doc_id = row["doc_id"]
                # Take the first prediction row with the same doc_id.
                prediction_data = prediction_df[prediction_df["doc_id"] == doc_id].iloc[0]
                for data_point in data_point_list:
                    true_data, pred_data, missing_error_data = (
                        self.get_page_filter_true_pred_data(
                            doc_id, row, prediction_data, data_point
                        )
                    )
                    if data_point == "tor":
                        tor_true.extend(true_data)
                        tor_pred.extend(pred_data)
                    elif data_point == "ter":
                        ter_true.extend(true_data)
                        ter_pred.extend(pred_data)
                    elif data_point == "ogc":
                        ogc_true.extend(true_data)
                        ogc_pred.extend(pred_data)
                    elif data_point == "performance_fee":
                        performance_fee_true.extend(true_data)
                        performance_fee_pred.extend(pred_data)
                    missing_error_list.append(missing_error_data)
        elif self.data_type == "data_extraction":
            prediction_doc_id_list = prediction_df["doc_id"].unique().tolist()
            ground_truth_doc_id_list = ground_truth_df["doc_id"].unique().tolist()
            # Evaluate only doc_ids present in both files, in sorted order.
            doc_id_list = sorted(
                set(prediction_doc_id_list) & set(ground_truth_doc_id_list)
            )

            for doc_id in doc_id_list:
                prediction_data = prediction_df[prediction_df["doc_id"] == doc_id]
                ground_truth_data = ground_truth_df[ground_truth_df["doc_id"] == doc_id]
                for data_point in data_point_list:
                    true_data, pred_data, missing_error_data = (
                        self.get_data_extraction_true_pred_data(
                            doc_id, ground_truth_data, prediction_data, data_point
                        )
                    )
                    if data_point == "tor":
                        tor_true.extend(true_data)
                        tor_pred.extend(pred_data)
                    elif data_point == "ter":
                        ter_true.extend(true_data)
                        ter_pred.extend(pred_data)
                    elif data_point == "ogc":
                        ogc_true.extend(true_data)
                        ogc_pred.extend(pred_data)
                    elif data_point == "performance_fee":
                        performance_fee_true.extend(true_data)
                        performance_fee_pred.extend(pred_data)
                    missing_error_list.extend(missing_error_data)
        elif self.data_type == "investment_mapping":
            prediction_doc_id_list = prediction_df["doc_id"].unique().tolist()
            ground_truth_doc_id_list = ground_truth_df["doc_id"].unique().tolist()
            doc_id_list = sorted(
                set(prediction_doc_id_list) & set(ground_truth_doc_id_list)
            )

            for doc_id in doc_id_list:
                prediction_data = prediction_df[prediction_df["doc_id"] == doc_id]
                ground_truth_data = ground_truth_df[ground_truth_df["doc_id"] == doc_id]
                for data_point in data_point_list:
                    true_data, pred_data, missing_error_data = (
                        self.get_investment_mapping_true_pred_data(
                            doc_id, ground_truth_data, prediction_data, data_point
                        )
                    )
                    investment_mapping_true.extend(true_data)
                    investment_mapping_pred.extend(pred_data)
                    missing_error_list.extend(missing_error_data)
        elif self.data_type == "document_mapping_in_db":
            prediction_doc_id_list = prediction_df["doc_id"].unique().tolist()
            ground_truth_doc_id_list = ground_truth_df["doc_id"].unique().tolist()
            doc_id_list = sorted(
                set(prediction_doc_id_list) & set(ground_truth_doc_id_list)
            )

            for doc_id in doc_id_list:
                prediction_data = prediction_df[prediction_df["doc_id"] == doc_id]
                ground_truth_data = ground_truth_df[ground_truth_df["doc_id"] == doc_id]
                true_data, pred_data, missing_error_data = (
                    self.get_document_mapping_in_db_true_pred_data(
                        doc_id, ground_truth_data, prediction_data,
                        strict_mode=strict_mode,
                    )
                )
                investment_mapping_true.extend(true_data)
                investment_mapping_pred.extend(pred_data)
                missing_error_list.extend(missing_error_data)

        metrics_list = []
        if self.data_type in ["investment_mapping", "document_mapping_in_db"]:
            # Pad with one correct pair so the metrics stay defined even when
            # there was nothing to score.
            if len(investment_mapping_true) == 0 and len(investment_mapping_pred) == 0:
                investment_mapping_true.append(1)
                investment_mapping_pred.append(1)
            precision, recall, f1 = self.get_specific_metrics(
                investment_mapping_true, investment_mapping_pred
            )
            investment_mapping_support = self.get_support_number(investment_mapping_true)
            metrics_list.append(
                {
                    "Data_Point": "Investment Mapping",
                    "Precision": precision,
                    "Recall": recall,
                    "F1": f1,
                    "Support": investment_mapping_support,
                }
            )
            logger.info(
                f"Investment mapping Precision: {precision}, Recall: {recall}, "
                f"F1: {f1}, Support: {investment_mapping_support}"
            )
        else:
            # The four data points share identical bookkeeping, so score them
            # in a single loop instead of four duplicated branches.
            data_point_results = [
                ("tor", "TOR", tor_true, tor_pred),
                ("ter", "TER", ter_true, ter_pred),
                ("ogc", "OGC", ogc_true, ogc_pred),
                ("performance_fee", "Performance Fee", performance_fee_true, performance_fee_pred),
            ]
            for data_point, display_name, dp_true, dp_pred in data_point_results:
                precision, recall, f1 = self.get_specific_metrics(dp_true, dp_pred)
                support = self.get_support_number(dp_true)
                metrics_list.append(
                    {
                        "Data_Point": data_point,
                        "Precision": precision,
                        "Recall": recall,
                        "F1": f1,
                        "Support": support,
                    }
                )
                logger.info(
                    f"{display_name} Precision: {precision}, Recall: {recall}, "
                    f"F1: {f1}, Support: {support}"
                )

        # Append the macro average across all scored data points.
        precision_list = [metric["Precision"] for metric in metrics_list]
        recall_list = [metric["Recall"] for metric in metrics_list]
        f1_list = [metric["F1"] for metric in metrics_list]
        metrics_list.append(
            {
                "Data_Point": "Average",
                "Precision": sum(precision_list) / len(precision_list),
                "Recall": sum(recall_list) / len(recall_list),
                "F1": sum(f1_list) / len(f1_list),
                "Support": sum(metric["Support"] for metric in metrics_list),
            }
        )
        return missing_error_list, metrics_list

    def get_support_number(self, true_data: list):
        # Support is the count of positive (1) labels in the ground truth.
        return sum(true_data)

    def get_page_filter_true_pred_data(
        self,
        doc_id,
        ground_truth_data: pd.Series,
        prediction_data: pd.Series,
        data_point: str,
    ):
        ground_truth_list = ground_truth_data[data_point]
        if isinstance(ground_truth_list, str):
            ground_truth_list = json.loads(ground_truth_list)
        prediction_list = prediction_data[data_point]
        if isinstance(prediction_list, str):
            prediction_list = json.loads(prediction_list)

        true_data = []
        pred_data = []

        missing_error_data = {
            "doc_id": doc_id,
            "data_point": data_point,
            "missing": [],
            "error": [],
        }

        missing_data = []
        error_data = []

        # Both lists empty counts as one correct "no relevant pages" prediction.
        if len(ground_truth_list) == 0 and len(prediction_list) == 0:
            true_data.append(1)
            pred_data.append(1)
            return true_data, pred_data, missing_error_data

        # Predicted pages absent from the ground truth are false positives.
        for prediction in prediction_list:
            if prediction in ground_truth_list:
                true_data.append(1)
                pred_data.append(1)
            else:
                true_data.append(0)
                pred_data.append(1)
                error_data.append(prediction)

        # Ground truth pages that were not predicted are false negatives.
        for ground_truth in ground_truth_list:
            if ground_truth not in prediction_list:
                true_data.append(1)
                pred_data.append(0)
                missing_data.append(ground_truth)
        missing_error_data = {
            "doc_id": doc_id,
            "data_point": data_point,
            "missing": missing_data,
            "error": error_data,
        }

        return true_data, pred_data, missing_error_data
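
    # Worked example of the page-filter encoding above (illustrative values):
    # ground truth pages [3, 5] against predicted pages [3, 7] yield
    #   true = [1, 0, 1], pred = [1, 1, 0]
    # (3 matches, 7 is an error, 5 is missing), so precision = recall = 1/2.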

    def get_data_extraction_true_pred_data(
        self,
        doc_id,
        ground_truth_data: pd.DataFrame,
        prediction_data: pd.DataFrame,
        data_point: str,
    ):
        dp_prediction = prediction_data[prediction_data["datapoint"] == data_point]
        dp_prediction = self.modify_data(dp_prediction)
        pred_simple_raw_names = dp_prediction["simple_raw_name"].unique().tolist()
        pred_simple_name_unique_words_list = (
            dp_prediction["simple_name_unique_words"].unique().tolist()
        )

        dp_ground_truth = ground_truth_data[
            ground_truth_data["datapoint"] == data_point
        ]
        dp_ground_truth = self.modify_data(dp_ground_truth)
        gt_simple_raw_names = dp_ground_truth["simple_raw_name"].unique().tolist()
        gt_simple_name_unique_words_list = (
            dp_ground_truth["simple_name_unique_words"].unique().tolist()
        )

        true_data = []
        pred_data = []

        missing_error_data = []

        # Both empty counts as one correct "nothing to extract" prediction.
        if len(dp_ground_truth) == 0 and len(dp_prediction) == 0:
            true_data.append(1)
            pred_data.append(1)
            return true_data, pred_data, missing_error_data

        for index, prediction in dp_prediction.iterrows():
            pred_page_index = prediction["page_index"]
            pred_raw_name = prediction["raw_name"]
            pred_simple_raw_name = prediction["simple_raw_name"]
            pred_simple_name_unique_words = prediction["simple_name_unique_words"]
            pred_data_point_value = prediction["value"]
            pred_investment_type = prediction["investment_type"]

            # A ground truth name matches when one simplified name contains
            # the other and ends with the prediction's last word.
            find_raw_name_in_gt = [
                gt_raw_name
                for gt_raw_name in gt_simple_raw_names
                if (
                    gt_raw_name in pred_simple_raw_name
                    or pred_simple_raw_name in gt_raw_name
                )
                and gt_raw_name.endswith(pred_simple_raw_name.split()[-1])
            ]

            if (
                pred_simple_name_unique_words in gt_simple_name_unique_words_list
                or len(find_raw_name_in_gt) > 0
            ):
                # Get the ground truth rows with the same unique words,
                # falling back to the matched raw name.
                if pred_simple_name_unique_words in gt_simple_name_unique_words_list:
                    gt_data_df = dp_ground_truth[
                        dp_ground_truth["simple_name_unique_words"]
                        == pred_simple_name_unique_words
                    ]
                else:
                    gt_data_df = dp_ground_truth[
                        dp_ground_truth["simple_raw_name"] == find_raw_name_in_gt[0]
                    ]
                # Prefer the row on the same page; otherwise take the first
                # match, or None when there is no match at all.
                if len(gt_data_df) == 0:
                    gt_data = None
                else:
                    page_match = gt_data_df[gt_data_df["page_index"] == pred_page_index]
                    gt_data = (
                        page_match.iloc[0] if len(page_match) > 0 else gt_data_df.iloc[0]
                    )
                gt_data_point_value = None if gt_data is None else gt_data["value"]
                if (
                    gt_data_point_value is not None
                    and pred_data_point_value == gt_data_point_value
                ):
                    true_data.append(1)
                    pred_data.append(1)
                else:
                    true_data.append(0)
                    pred_data.append(1)
                    error_data = {
                        "doc_id": doc_id,
                        "data_point": data_point,
                        "page_index": pred_page_index,
                        "pred_raw_name": pred_raw_name,
                        "investment_type": pred_investment_type,
                        "error_type": "data value incorrect",
                        "error_value": pred_data_point_value,
                        "correct_value": gt_data_point_value,
                    }
                    missing_error_data.append(error_data)
            else:
                # A performance fee predicted as 0 for an unmatched name is
                # still counted as correct.
                pred_value_num = None
                try:
                    pred_value_num = float(pred_data_point_value)
                except (TypeError, ValueError):
                    pass
                if data_point == "performance_fee" and pred_value_num == 0:
                    true_data.append(1)
                    pred_data.append(1)
                else:
                    true_data.append(0)
                    pred_data.append(1)
                    error_data = {
                        "doc_id": doc_id,
                        "data_point": data_point,
                        "page_index": pred_page_index,
                        "pred_raw_name": pred_raw_name,
                        "investment_type": pred_investment_type,
                        "error_type": "raw name incorrect",
                        "error_value": pred_raw_name,
                        "correct_value": "",
                    }
                    missing_error_data.append(error_data)

        for index, ground_truth in dp_ground_truth.iterrows():
            gt_page_index = ground_truth["page_index"]
            gt_raw_name = ground_truth["raw_name"]
            gt_simple_raw_name = ground_truth["simple_raw_name"]
            gt_simple_name_unique_words = ground_truth["simple_name_unique_words"]
            gt_data_point_value = ground_truth["value"]
            gt_investment_type = ground_truth["investment_type"]

            find_raw_name_in_pred = [
                pred_raw_name
                for pred_raw_name in pred_simple_raw_names
                if (
                    gt_simple_raw_name in pred_raw_name
                    or pred_raw_name in gt_simple_raw_name
                )
                and pred_raw_name.endswith(gt_simple_raw_name.split()[-1])
            ]

            if (
                gt_simple_name_unique_words not in pred_simple_name_unique_words_list
                and len(find_raw_name_in_pred) == 0
            ):
                gt_value_num = None
                try:
                    gt_value_num = float(gt_data_point_value)
                except (TypeError, ValueError):
                    pass
                # A missing performance fee whose ground truth value is 0 is
                # still counted as correct.
                if data_point == "performance_fee" and gt_value_num == 0:
                    true_data.append(1)
                    pred_data.append(1)
                else:
                    true_data.append(1)
                    pred_data.append(0)
                    error_data = {
                        "doc_id": doc_id,
                        "data_point": data_point,
                        "page_index": gt_page_index,
                        "pred_raw_name": "",
                        "investment_type": gt_investment_type,
                        "error_type": "raw name missing",
                        "error_value": "",
                        "correct_value": gt_raw_name,
                    }
                    missing_error_data.append(error_data)

        return true_data, pred_data, missing_error_data
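
    # Worked example of the data-extraction encoding above (illustrative):
    # ground truth rows ("Fund A", 0.5) and ("Fund B", 0.3) scored against
    # predictions ("Fund A", 0.5) and ("Fund C", 0.2) yield
    #   true = [1, 0, 1], pred = [1, 1, 0]
    # (value match, wrong raw name, missing raw name), so precision = recall = 1/2.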

    def get_investment_mapping_true_pred_data(
        self,
        doc_id,
        ground_truth_data: pd.DataFrame,
        prediction_data: pd.DataFrame,
        data_point: str,
    ):
        dp_prediction = prediction_data[prediction_data["datapoint"] == data_point]
        dp_prediction = self.modify_data(dp_prediction)
        pred_simple_raw_names = dp_prediction["simple_raw_name"].unique().tolist()
        pred_simple_name_unique_words_list = (
            dp_prediction["simple_name_unique_words"].unique().tolist()
        )

        dp_ground_truth = ground_truth_data[
            ground_truth_data["datapoint"] == data_point
        ]
        dp_ground_truth = self.modify_data(dp_ground_truth)

        compare_data_list = []
        for index, ground_truth in dp_ground_truth.iterrows():
            gt_page_index = ground_truth["page_index"]
            gt_raw_name = ground_truth["raw_name"]
            gt_simple_raw_name = ground_truth["simple_raw_name"]
            gt_simple_name_unique_words = ground_truth["simple_name_unique_words"]
            gt_investment_type = ground_truth["investment_type"]

            find_raw_name_in_pred = [
                pred_raw_name
                for pred_raw_name in pred_simple_raw_names
                if (
                    gt_simple_raw_name in pred_raw_name
                    or pred_raw_name in gt_simple_raw_name
                )
                and pred_raw_name.endswith(gt_simple_raw_name.split()[-1])
            ]

            if (
                gt_simple_name_unique_words in pred_simple_name_unique_words_list
                or len(find_raw_name_in_pred) > 0
            ):
                # Get the prediction rows with the same unique words, falling
                # back to the matched raw name.
                if gt_simple_name_unique_words in pred_simple_name_unique_words_list:
                    pred_data_df = dp_prediction[
                        dp_prediction["simple_name_unique_words"]
                        == gt_simple_name_unique_words
                    ]
                else:
                    pred_data_df = dp_prediction[
                        dp_prediction["simple_raw_name"] == find_raw_name_in_pred[0]
                    ]
                # Prefer the row on the same page; otherwise take the first
                # match, or None when there is no match at all.
                if len(pred_data_df) == 0:
                    pred_row = None
                else:
                    page_match = pred_data_df[pred_data_df["page_index"] == gt_page_index]
                    pred_row = (
                        page_match.iloc[0] if len(page_match) > 0 else pred_data_df.iloc[0]
                    )
                if pred_row is not None:
                    compare_data = {
                        "raw_name": gt_raw_name,
                        "investment_type": gt_investment_type,
                        "gt_investment_id": ground_truth["investment_id"],
                        "gt_investment_name": ground_truth["investment_name"],
                        "pred_investment_id": pred_row["investment_id"],
                        "pred_investment_name": pred_row["investment_name"],
                    }
                    compare_data_list.append(compare_data)

        true_data = []
        pred_data = []
        missing_error_data = []

        for compare_data in compare_data_list:
            if compare_data["gt_investment_id"] == compare_data["pred_investment_id"]:
                true_data.append(1)
                pred_data.append(1)
            else:
                true_data.append(1)
                pred_data.append(0)
                error_data = {
                    "doc_id": doc_id,
                    "data_point": data_point,
                    "raw_name": compare_data["raw_name"],
                    "investment_type": compare_data["investment_type"],
                    "error_type": "mapping missing",
                    "error_id": compare_data["pred_investment_id"],
                    "error_name": compare_data["pred_investment_name"],
                    "correct_id": compare_data["gt_investment_id"],
                    "correct_name": compare_data["gt_investment_name"],
                }
                missing_error_data.append(error_data)

        # Predictions mapped to an id that never occurs in the ground truth
        # count against precision.
        for index, prediction in dp_prediction.iterrows():
            pred_raw_name = prediction["raw_name"]
            pred_investment_id = prediction["investment_id"]
            pred_investment_name = prediction["investment_name"]
            pred_investment_type = prediction["investment_type"]
            gt_data_df = dp_ground_truth[
                dp_ground_truth["investment_id"] == pred_investment_id
            ]
            if len(gt_data_df) == 0:
                true_data.append(0)
                pred_data.append(1)
                error_data = {
                    "doc_id": doc_id,
                    "data_point": data_point,
                    "raw_name": pred_raw_name,
                    "investment_type": pred_investment_type,
                    "error_type": "mapping incorrect",
                    "error_id": pred_investment_id,
                    "error_name": pred_investment_name,
                    "correct_id": "",
                    "correct_name": "",
                }
                missing_error_data.append(error_data)
        return true_data, pred_data, missing_error_data
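
    # Worked example of the investment-mapping encoding above (illustrative):
    # two matched names, one mapped to the correct id and one to an id absent
    # from the ground truth, yield true = [1, 1, 0] and pred = [1, 0, 1]: the
    # wrong mapping costs recall once and precision once.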

    def get_document_mapping_in_db_true_pred_data(
        self,
        doc_id,
        ground_truth_data: pd.DataFrame,
        prediction_data: pd.DataFrame,
        strict_mode: bool = False,
    ):
        """
        EMEA AR mapping metrics based on the document mapping in the DB.

        1. The ground truth is made manually: according to the fund/share
           names in the document mapping, find the relevant rows in the
           document data extraction and fill in the mapping id.
        2. Metrics calculation, per document:
           a. Ground truth data: filter the document data extraction records
              by the mapping ids found in the document mapping.
           b. Prediction data: get the document mapping by the fund/share raw
              name extracted from the PDF document.
           Scoring:
           - mapping correct:  true 1, pred 1
           - mapping empty:    true 1, pred 0  (hurts recall)
           - mapping incorrect (another fund/share id):
                               true 1, pred 0  (hurts recall)
             and, if the incorrect id belongs to the document mapping,
             additionally:     true 0, pred 1  (hurts precision)
        """
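        # Worked example of the scoring above (illustrative): ground truth ids
        # [A, B, C]; predictions map A correctly, map B to another id from the
        # document mapping, and (in strict mode) miss C entirely:
        #   true = [1, 1, 0, 1], pred = [1, 0, 1, 0]
        # giving precision = 1/2 and recall = 1/3.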
        document_mapping_data = query_document_fund_mapping(doc_id)
        if len(document_mapping_data) == 0:
            return [1], [1], []
        fund_id_list = document_mapping_data["FundId"].unique().tolist()
        share_id_list = document_mapping_data["SecId"].unique().tolist()
        id_list = fund_id_list + share_id_list

        # Keep only ground truth rows whose investment_id belongs to the
        # document mapping.
        dp_ground_truth = ground_truth_data[
            ground_truth_data["investment_id"].isin(id_list)
        ]

        dp_ground_truth = self.modify_data(dp_ground_truth)
        # Keep only the comparison columns and drop duplicate rows.
        comparison_columns = [
            "doc_id", "page_index", "raw_name", "simple_raw_name",
            "simple_name_unique_words", "investment_type",
            "investment_id", "investment_name",
        ]
        dp_ground_truth = dp_ground_truth[comparison_columns]
        dp_ground_truth.drop_duplicates(inplace=True)
        dp_ground_truth.reset_index(drop=True, inplace=True)

        # Treat missing prediction ids/names as empty strings.
        prediction_data["investment_id"] = prediction_data["investment_id"].fillna("")
        prediction_data["investment_name"] = prediction_data["investment_name"].fillna("")
        dp_prediction = self.modify_data(prediction_data)
        dp_prediction = dp_prediction[comparison_columns]
        dp_prediction.drop_duplicates(inplace=True)
        dp_prediction.reset_index(drop=True, inplace=True)

        compare_data_list = []
        gt_investment_id_list = []
        for index, ground_truth in dp_ground_truth.iterrows():
            gt_page_index = ground_truth["page_index"]
            gt_raw_name = ground_truth["raw_name"]
            gt_simple_raw_name = ground_truth["simple_raw_name"]
            gt_simple_name_unique_words = ground_truth["simple_name_unique_words"]
            gt_investment_type = ground_truth["investment_type"]
            gt_investment_id = ground_truth["investment_id"]
            gt_investment_name = ground_truth["investment_name"]

            # Score each ground truth investment id only once.
            if gt_investment_id in gt_investment_id_list:
                continue

            # Restrict candidate predictions to the page of the ground truth row.
            pred_page_data = dp_prediction[dp_prediction["page_index"] == gt_page_index]
            if len(pred_page_data) > 0:
                pred_simple_raw_names = pred_page_data["simple_raw_name"].unique().tolist()
                pred_simple_name_unique_words_list = (
                    pred_page_data["simple_name_unique_words"].unique().tolist()
                )
            else:
                pred_simple_raw_names = []
                pred_simple_name_unique_words_list = []

            find_raw_name_in_pred = [
                pred_raw_name
                for pred_raw_name in pred_simple_raw_names
                if (
                    gt_simple_raw_name in pred_raw_name
                    or pred_raw_name in gt_simple_raw_name
                )
                and pred_raw_name.endswith(gt_simple_raw_name.split()[-1])
            ]

            if (
                gt_simple_name_unique_words in pred_simple_name_unique_words_list
                or len(find_raw_name_in_pred) > 0
            ):
                # Get the prediction rows with the same unique words, falling
                # back to the most similar matched raw name.
                if gt_simple_name_unique_words in pred_simple_name_unique_words_list:
                    pred_data_df = dp_prediction[
                        dp_prediction["simple_name_unique_words"]
                        == gt_simple_name_unique_words
                    ]
                elif len(find_raw_name_in_pred) > 1:
                    max_similarity_name, _ = simple_most_similarity_name(
                        gt_raw_name, find_raw_name_in_pred
                    )
                    if max_similarity_name is not None and len(max_similarity_name) > 0:
                        pred_data_df = dp_prediction[
                            dp_prediction["simple_raw_name"] == max_similarity_name
                        ]
                    else:
                        pred_data_df = dp_prediction[
                            dp_prediction["simple_raw_name"] == find_raw_name_in_pred[0]
                        ]
                else:
                    pred_data_df = dp_prediction[
                        dp_prediction["simple_raw_name"] == find_raw_name_in_pred[0]
                    ]
                # Prefer the row on the same page; otherwise take the first
                # match, or None when there is no match at all.
                if len(pred_data_df) == 0:
                    pred_row = None
                else:
                    page_match = pred_data_df[pred_data_df["page_index"] == gt_page_index]
                    pred_row = (
                        page_match.iloc[0] if len(page_match) > 0 else pred_data_df.iloc[0]
                    )
                if pred_row is not None:
                    compare_data = {
                        "raw_name": gt_raw_name,
                        "investment_type": gt_investment_type,
                        "gt_investment_id": gt_investment_id,
                        "gt_investment_name": gt_investment_name,
                        "pred_investment_id": pred_row["investment_id"],
                        "pred_investment_name": pred_row["investment_name"],
                    }
                    gt_investment_id_list.append(gt_investment_id)
                    compare_data_list.append(compare_data)
            elif strict_mode:
                # In strict mode an unmatched ground truth name still counts
                # as a missed mapping.
                compare_data = {
                    "raw_name": gt_raw_name,
                    "investment_type": gt_investment_type,
                    "gt_investment_id": gt_investment_id,
                    "gt_investment_name": gt_investment_name,
                    "pred_investment_id": "",
                    "pred_investment_name": "",
                }
                compare_data_list.append(compare_data)

        true_data = []
        pred_data = []
        missing_error_data = []

        for compare_data in compare_data_list:
            gt_investment_id = compare_data["gt_investment_id"]
            pred_investment_id = compare_data["pred_investment_id"]
            if gt_investment_id == pred_investment_id:
                true_data.append(1)
                pred_data.append(1)
            else:
                # A wrong or empty mapping hurts recall once; a wrong mapping
                # that points at another id of the same document mapping also
                # hurts precision once.
                true_data.append(1)
                pred_data.append(0)
                if pred_investment_id is not None and len(pred_investment_id) > 0:
                    if pred_investment_id in id_list:
                        true_data.append(0)
                        pred_data.append(1)
                    error_data = {
                        "doc_id": doc_id,
                        "raw_name": compare_data["raw_name"],
                        "investment_type": compare_data["investment_type"],
                        "error_type": "mapping incorrect",
                        "error_id": pred_investment_id,
                        "error_name": compare_data["pred_investment_name"],
                        "correct_id": compare_data["gt_investment_id"],
                        "correct_name": compare_data["gt_investment_name"],
                    }
                else:
                    error_data = {
                        "doc_id": doc_id,
                        "raw_name": compare_data["raw_name"],
                        "investment_type": compare_data["investment_type"],
                        "error_type": "mapping missing",
                        "error_id": "",
                        "error_name": "",
                        "correct_id": compare_data["gt_investment_id"],
                        "correct_name": compare_data["gt_investment_name"],
                    }
                missing_error_data.append(error_data)

        return true_data, pred_data, missing_error_data

    def modify_data(self, data: pd.DataFrame):
        # Work on a copy so callers' DataFrames are not mutated through a slice.
        data = data.copy()
        data["simple_raw_name"] = ""
        data["simple_name_unique_words"] = ""
        page_index_list = data["page_index"].unique().tolist()
        for page_index in page_index_list:
            page_data = data[data["page_index"] == page_index]
            raw_name_list = page_data["raw_name"].unique().tolist()
            # Strip the leading words shared by all raw names on the page.
            beginning_common_words = get_beginning_common_words(raw_name_list)
            for raw_name in raw_name_list:
                if (
                    beginning_common_words is not None
                    and len(beginning_common_words) > 0
                ):
                    simple_raw_name = raw_name.replace(
                        beginning_common_words, ""
                    ).strip()
                    if len(simple_raw_name) == 0:
                        simple_raw_name = raw_name
                else:
                    simple_raw_name = raw_name
                simple_raw_name = remove_special_characters(simple_raw_name)
                # Drop the word "class" unless it is the only word left.
                temp_splits = [
                    word
                    for word in simple_raw_name.split()
                    if word.lower() not in ["class"]
                ]
                if len(temp_splits) > 0:
                    simple_raw_name = " ".join(temp_splits)
                # Drop a trailing "USD" currency token from longer names.
                simple_raw_name_splits = simple_raw_name.split()
                if (
                    len(simple_raw_name_splits) > 2
                    and simple_raw_name_splits[-1] == "USD"
                ):
                    simple_raw_name = " ".join(simple_raw_name_splits[:-1])
                # Write both derived columns back to every row with the same
                # page and raw name.
                data.loc[
                    (data["page_index"] == page_index)
                    & (data["raw_name"] == raw_name),
                    "simple_raw_name",
                ] = simple_raw_name
                data.loc[
                    (data["page_index"] == page_index)
                    & (data["raw_name"] == raw_name),
                    "simple_name_unique_words",
                ] = get_unique_words_text(simple_raw_name)
        return data
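
    # Example of modify_data's simplification (illustrative, assuming
    # get_beginning_common_words returns the leading words shared by all
    # names): raw names ["Global Fund A Acc", "Global Fund B Inc"] on one page
    # share the prefix "Global Fund", so the simple names become "A Acc" and
    # "B Inc" before special-character removal and unique-word extraction.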

    def get_specific_metrics(self, true_data: list, pred_data: list):
        precision = precision_score(true_data, pred_data)
        recall = recall_score(true_data, pred_data)
        f1 = f1_score(true_data, pred_data)
        return precision, recall, f1
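
    # For example, get_specific_metrics([1, 1, 0, 1], [1, 0, 1, 1]) returns
    # precision = recall = f1 = 2/3 (one false positive, one false negative).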

    def get_datapoint_metrics(self):
        pass
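

# Minimal usage sketch (illustrative only; the file paths below are
# placeholders, not real project data).
if __name__ == "__main__":
    metrics = Metrics(
        data_type="data_extraction",
        prediction_file="/data/emea_ar/output/prediction.xlsx",
        ground_truth_file="/data/emea_ar/ground_truth/ground_truth.xlsx",
    )
    missing_error_list, metrics_list, output_file = metrics.get_metrics()
    logger.info(f"Metrics written to {output_file}")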