Support calculating mapping metrics based on the document investment mapping in the database
This commit is contained in:
parent 0c4c541319
commit 39cd53dc33

core/metrics.py | 232
@@ -4,6 +4,7 @@ import time
 import json
 from sklearn.metrics import precision_score, recall_score, f1_score
 from utils.biz_utils import get_unique_words_text, get_beginning_common_words, remove_special_characters
+from utils.sql_query_util import query_document_fund_mapping
 from utils.logger import logger
 
 
@@ -33,7 +34,7 @@ class Metrics:
             f"metrics_{data_type}_{time_stamp}.xlsx",
         )
 
-    def get_metrics(self):
+    def get_metrics(self, strict_mode: bool = False):
         if (
             self.prediction_file is None
             or len(self.prediction_file) == 0
@@ -53,7 +54,7 @@ class Metrics:
             {"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0}
         ]
 
-        missing_error_list, metrics_list = self.calculate_metrics()
+        missing_error_list, metrics_list = self.calculate_metrics(strict_mode=strict_mode)
 
         missing_error_df = pd.DataFrame(missing_error_list)
         missing_error_df.reset_index(drop=True, inplace=True)
@@ -66,7 +67,7 @@ class Metrics:
             metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
         return missing_error_list, metrics_list, self.output_file
 
-    def calculate_metrics(self):
+    def calculate_metrics(self, strict_mode: bool = False):
         prediction_df = pd.read_excel(
             self.prediction_file, sheet_name=self.prediction_sheet_name
         )
@@ -77,7 +78,7 @@ class Metrics:
             ground_truth_df = ground_truth_df[ground_truth_df["Checked"] == 1]
         elif self.data_type == "data_extraction":
             ground_truth_df = ground_truth_df[ground_truth_df["rawname_checked"] == 1]
-        elif self.data_type == "investment_mapping":
+        elif self.data_type in ["investment_mapping", "document_mapping_in_db"]:
             ground_truth_df = ground_truth_df[ground_truth_df["mapping_checked"] == 1]
         else:
             logger.error(f"Invalid data type: {self.data_type}")
@@ -179,9 +180,28 @@ class Metrics:
                 investment_mapping_true.extend(true_data)
                 investment_mapping_pred.extend(pred_data)
                 missing_error_list.extend(missing_error_data)
+        elif self.data_type == "document_mapping_in_db":
+            prediction_doc_id_list = prediction_df["doc_id"].unique().tolist()
+            ground_truth_doc_id_list = ground_truth_df["doc_id"].unique().tolist()
+            # score only the doc_ids present in both prediction and ground truth
+            doc_id_list = list(
+                set(prediction_doc_id_list) & set(ground_truth_doc_id_list)
+            )
+            # order by doc_id
+            doc_id_list.sort()
+
+            for doc_id in doc_id_list:
+                prediction_data = prediction_df[prediction_df["doc_id"] == doc_id]
+                ground_truth_data = ground_truth_df[ground_truth_df["doc_id"] == doc_id]
+                true_data, pred_data, missing_error_data = self.get_document_mapping_in_db_true_pred_data(
+                    doc_id, ground_truth_data, prediction_data, strict_mode=strict_mode
+                )
+                investment_mapping_true.extend(true_data)
+                investment_mapping_pred.extend(pred_data)
+                missing_error_list.extend(missing_error_data)
 
         metrics_list = []
-        if self.data_type == "investment_mapping":
+        if self.data_type in ["investment_mapping", "document_mapping_in_db"]:
             if len(investment_mapping_true) == 0 and len(investment_mapping_pred) == 0:
                 investment_mapping_true.append(1)
                 investment_mapping_pred.append(1)
@@ -669,6 +689,208 @@ class Metrics:
             missing_error_data.append(error_data)
         return true_data, pred_data, missing_error_data
 
+    def get_document_mapping_in_db_true_pred_data(
+        self,
+        doc_id,
+        ground_truth_data: pd.DataFrame,
+        prediction_data: pd.DataFrame,
+        strict_mode: bool = False,
+    ):
+        """
+        EMEA AR mapping metrics based on the document mapping in the DB.
+
+        1. Build the ground truth manually:
+           according to the fund/share names in the document mapping, find the
+           relevant records in the document data extraction and fill in the
+           mapping id.
+        2. Metrics calculation, based on each document:
+           a. Ground truth data: filter the relevant document data extraction
+              records by the mapping ids in the document mapping.
+           b. Prediction data: get the document mapping by the fund/share raw
+              names from the PDF document.
+           Scoring:
+           - mapping correct:   true 1, pred 1
+           - mapping empty:     true 1, pred 0  -> hurts recall
+           - mapping incorrect (other fund/share id):
+                                true 1, pred 0  -> hurts recall
+             and, if the incorrect id is itself in the document mapping:
+                                true 0, pred 1  -> hurts precision
+        """
+        document_mapping_data = query_document_fund_mapping(doc_id)
+        if len(document_mapping_data) == 0:
+            return [1], [1], []
+        fund_id_list = document_mapping_data["FundId"].unique().tolist()
+        share_id_list = document_mapping_data["SecId"].unique().tolist()
+        id_list = fund_id_list + share_id_list
+
+        # keep only the ground truth rows whose investment_id is in the
+        # document mapping
+        dp_ground_truth = ground_truth_data[
+            ground_truth_data["investment_id"].isin(id_list)
+        ]
+
+        dp_ground_truth = self.modify_data(dp_ground_truth)
+        # keep only the columns needed for the comparison
+        dp_ground_truth = dp_ground_truth[["doc_id", "page_index", "raw_name", "simple_raw_name",
+                                           "simple_name_unique_words", "investment_type",
+                                           "investment_id", "investment_name"]]
+        dp_ground_truth.drop_duplicates(inplace=True)
+        dp_ground_truth.reset_index(drop=True, inplace=True)
+
+        # fill nan investment_id/investment_name in the prediction with ""
+        prediction_data["investment_id"].fillna("", inplace=True)
+        prediction_data["investment_name"].fillna("", inplace=True)
+        dp_prediction = self.modify_data(prediction_data)
+        dp_prediction = dp_prediction[["doc_id", "page_index", "raw_name", "simple_raw_name",
+                                       "simple_name_unique_words", "investment_type",
+                                       "investment_id", "investment_name"]]
+        dp_prediction.drop_duplicates(inplace=True)
+        dp_prediction.reset_index(drop=True, inplace=True)
+
+        compare_data_list = []
+        gt_investment_id_list = []
+        for index, ground_truth in dp_ground_truth.iterrows():
+            gt_page_index = ground_truth["page_index"]
+            gt_raw_name = ground_truth["raw_name"]
+            gt_simple_raw_name = ground_truth["simple_raw_name"]
+            gt_simple_name_unique_words = ground_truth["simple_name_unique_words"]
+            gt_investment_type = ground_truth["investment_type"]
+            gt_investment_id = ground_truth["investment_id"]
+            gt_investment_name = ground_truth["investment_name"]
+
+            # get the predicted names on the same page as the ground truth row
+            pred_page_data = dp_prediction[dp_prediction["page_index"] == gt_page_index]
+            if len(pred_page_data) > 0:
+                pred_simple_raw_names = pred_page_data["simple_raw_name"].unique().tolist()
+                pred_simple_name_unique_words_list = (
+                    pred_page_data["simple_name_unique_words"].unique().tolist()
+                )
+            else:
+                pred_simple_raw_names = []
+                pred_simple_name_unique_words_list = []
+
+            # score each ground truth investment only once
+            if gt_investment_id in gt_investment_id_list:
+                continue
+            # predicted names that contain (or are contained in) the ground
+            # truth name and end with its last word
+            find_raw_name_in_pred = [
+                pred_raw_name
+                for pred_raw_name in pred_simple_raw_names
+                if (
+                    gt_simple_raw_name in pred_raw_name
+                    or pred_raw_name in gt_simple_raw_name
+                )
+                and pred_raw_name.endswith(gt_simple_raw_name.split()[-1])
+            ]
+
+            if (
+                gt_simple_name_unique_words in pred_simple_name_unique_words_list
+                or len(find_raw_name_in_pred) > 0
+            ):
+                # get the prediction rows with the same unique words,
+                # otherwise the rows matching the found raw name
+                if gt_simple_name_unique_words in pred_simple_name_unique_words_list:
+                    pred_data_df = dp_prediction[
+                        dp_prediction["simple_name_unique_words"]
+                        == gt_simple_name_unique_words
+                    ]
+                else:
+                    pred_data_df = dp_prediction[
+                        dp_prediction["simple_raw_name"] == find_raw_name_in_pred[0]
+                    ]
+                # prefer the row on the same page; otherwise take the first one
+                if len(pred_data_df) > 1:
+                    same_page_df = pred_data_df[pred_data_df["page_index"] == gt_page_index]
+                    if len(same_page_df) == 0:
+                        pred_data = pred_data_df.iloc[0]
+                    else:
+                        pred_data = same_page_df.iloc[0]
+                elif len(pred_data_df) == 1:
+                    pred_data = pred_data_df.iloc[0]
+                else:
+                    pred_data = None
+                if pred_data is not None:
+                    compare_data = {"raw_name": gt_raw_name,
+                                    "investment_type": gt_investment_type,
+                                    "gt_investment_id": gt_investment_id,
+                                    "gt_investment_name": gt_investment_name,
+                                    "pred_investment_id": pred_data["investment_id"],
+                                    "pred_investment_name": pred_data["investment_name"]}
+                    gt_investment_id_list.append(gt_investment_id)
+                    compare_data_list.append(compare_data)
+            else:
+                # in strict mode, a ground truth name that never appears in the
+                # prediction still counts as a missed mapping
+                if strict_mode:
+                    compare_data = {"raw_name": gt_raw_name,
+                                    "investment_type": gt_investment_type,
+                                    "gt_investment_id": gt_investment_id,
+                                    "gt_investment_name": gt_investment_name,
+                                    "pred_investment_id": "",
+                                    "pred_investment_name": ""}
+                    compare_data_list.append(compare_data)
+
+        true_data = []
+        pred_data = []
+        missing_error_data = []
+
+        for compare_data in compare_data_list:
+            gt_investment_id = compare_data["gt_investment_id"]
+            pred_investment_id = compare_data["pred_investment_id"]
+            if gt_investment_id == pred_investment_id:
+                true_data.append(1)
+                pred_data.append(1)
+            else:
+                true_data.append(1)
+                pred_data.append(0)
+                if pred_investment_id is not None and len(pred_investment_id) > 0:
+                    # a wrong id that still belongs to this document's mapping
+                    # also hurts precision
+                    if pred_investment_id in id_list:
+                        true_data.append(0)
+                        pred_data.append(1)
+                    error_data = {
+                        "doc_id": doc_id,
+                        "raw_name": compare_data["raw_name"],
+                        "investment_type": compare_data["investment_type"],
+                        "error_type": "mapping incorrect",
+                        "error_id": pred_investment_id,
+                        "error_name": compare_data["pred_investment_name"],
+                        "correct_id": compare_data["gt_investment_id"],
+                        "correct_name": compare_data["gt_investment_name"],
+                    }
+                else:
+                    error_data = {
+                        "doc_id": doc_id,
+                        "raw_name": compare_data["raw_name"],
+                        "investment_type": compare_data["investment_type"],
+                        "error_type": "mapping missing",
+                        "error_id": "",
+                        "error_name": "",
+                        "correct_id": compare_data["gt_investment_id"],
+                        "correct_name": compare_data["gt_investment_name"],
+                    }
+                missing_error_data.append(error_data)
+
+        return true_data, pred_data, missing_error_data
+
     def modify_data(self, data: pd.DataFrame):
         data["simple_raw_name"] = ""
         data["simple_name_unique_words"] = ""
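As a sanity check on the scoring scheme described in the docstring above, here is a minimal, self-contained sketch of how those true/pred pairs feed the sklearn functions that calculate_metrics already imports. The label lists are hypothetical examples, not taken from any real document:

from sklearn.metrics import precision_score, recall_score, f1_score

# Hypothetical labels for one document with three ground truth mappings:
# - mapping 1 predicted correctly            -> true 1, pred 1
# - mapping 2 empty in the prediction        -> true 1, pred 0 (hurts recall)
# - mapping 3 mapped to a wrong id that is
#   itself in the document mapping           -> true 1, pred 0 and true 0, pred 1
true_labels = [1, 1, 1, 0]
pred_labels = [1, 0, 0, 1]

print(precision_score(true_labels, pred_labels))  # 0.5    (1 correct of 2 predicted)
print(recall_score(true_labels, pred_labels))     # 0.333  (1 correct of 3 expected)
print(f1_score(true_labels, pred_labels))         # 0.4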
main.py | 49
@@ -345,9 +345,19 @@ def batch_start_job(
         metrics_output_folder,
     )
 
-    logger.info(f"Calculating metrics for investment mapping")
+    # logger.info(f"Calculating metrics for investment mapping by actual document mapping")
+    # missing_error_list, metrics_list, metrics_file = get_metrics(
+    #     "investment_mapping",
+    #     output_file,
+    #     prediction_sheet_name,
+    #     ground_truth_file,
+    #     ground_truth_sheet_name,
+    #     metrics_output_folder,
+    # )
+
+    logger.info(f"Calculating metrics for investment mapping by database document mapping")
     missing_error_list, metrics_list, metrics_file = get_metrics(
-        "investment_mapping",
+        "document_mapping_in_db",
         output_file,
         prediction_sheet_name,
         ground_truth_file,
@@ -436,7 +446,7 @@ def get_metrics(
         ground_truth_sheet_name=ground_truth_sheet_name,
         output_folder=output_folder,
     )
-    missing_error_list, metrics_list, metrics_file = metrics.get_metrics()
+    missing_error_list, metrics_list, metrics_file = metrics.get_metrics(strict_mode=True)
     return missing_error_list, metrics_list, metrics_file
 
 
@@ -657,13 +667,38 @@ if __name__ == "__main__":
         "479793787",
         "471641628",
     ]
-    special_doc_id_list = check_mapping_doc_id_list
-    special_doc_id_list = ["402113224"]
+    check_db_mapping_doc_id_list = [
+        "292989214",
+        "316237292",
+        "321733631",
+        "323390570",
+        "327956364",
+        "332223498",
+        "333207452",
+        "334718372",
+        "344636875",
+        "349679479",
+        "362246081",
+        "366179419",
+        "380945052",
+        "382366116",
+        "387202452",
+        "389171486",
+        "391456740",
+        "391736837",
+        "394778487",
+        "401684600",
+        "402113224",
+        "402181770"
+    ]
+    # special_doc_id_list = check_mapping_doc_id_list
+    special_doc_id_list = check_db_mapping_doc_id_list
+    # special_doc_id_list = ["382366116"]
     output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
     output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
     re_run_extract_data = False
-    re_run_mapping_data = True
-    force_save_total_data = False
+    re_run_mapping_data = False
+    force_save_total_data = True
 
     extract_ways = ["text"]
     for extract_way in extract_ways:
@@ -119,10 +119,15 @@ def get_most_similar_name(text: str, name_list: list, pre_common_word_list: list
     for i in range(len(copy_name_list)):
         temp_splits = copy_name_list[i].split()
-        copy_name_list[i] = ' '.join([split for split in temp_splits
-                                      if remove_special_characters(split).lower() not in ['fund', 'portfolio', 'class', 'share', 'shares']])
+        copy_name_list[i] = ' '.join([split for split in temp_splits
+                                      if remove_special_characters(split).lower()
+                                      not in ['fund', 'funds', 'portfolio',
+                                              'class', 'classes',
+                                              'share', 'shares']])
     final_splits = []
     for split in new_splits:
-        if split.lower() not in ['fund', 'portfolio', 'class', 'share', 'shares']:
+        if split.lower() not in ['fund', 'funds', 'portfolio',
+                                 'class', 'classes',
+                                 'share', 'shares']:
             final_splits.append(split)
 
     text = ' '.join(final_splits)
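To illustrate what the widened stop-word list above does to a fund/share name, a small self-contained sketch. The local remove_special_characters here is only a stand-in for utils.biz_utils.remove_special_characters, assumed to strip punctuation as its name suggests:

import re

# stand-in for utils.biz_utils.remove_special_characters (assumption:
# it strips non-alphanumeric characters, as the name suggests)
def remove_special_characters(text: str) -> str:
    return re.sub(r"[^A-Za-z0-9]", "", text)

# the widened stop-word list from the hunk above
stop_words = ['fund', 'funds', 'portfolio', 'class', 'classes', 'share', 'shares']

name = "Global Equity Funds - Class A Shares"
simplified = ' '.join(
    split for split in name.split()
    if remove_special_characters(split).lower() not in stop_words
)
print(simplified)  # "Global Equity - A"; before this commit "Funds" (plural)
                   # was not in the list and would have survived the filter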