Support calculating mapping metrics based on the document investment mapping in the database

Blade He 2024-09-27 13:20:50 -05:00
parent 0c4c541319
commit 39cd53dc33
3 changed files with 276 additions and 14 deletions

@@ -4,6 +4,7 @@ import time
 import json
 from sklearn.metrics import precision_score, recall_score, f1_score
 from utils.biz_utils import get_unique_words_text, get_beginning_common_words, remove_special_characters
+from utils.sql_query_util import query_document_fund_mapping
 from utils.logger import logger
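
Note: the new helper is consumed later in this file as a pandas DataFrame carrying FundId and SecId columns. A minimal stub sketching that assumed interface (the real implementation lives in utils/sql_query_util and queries the database; this stub is for illustration only):

    import pandas as pd

    def query_document_fund_mapping(doc_id: str) -> pd.DataFrame:
        # Assumed interface, not the real implementation: the actual helper
        # queries the document/fund mapping tables. Downstream code only
        # relies on the "FundId" and "SecId" columns and on len(result).
        return pd.DataFrame(columns=["FundId", "SecId"])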
@@ -33,7 +34,7 @@ class Metrics:
             f"metrics_{data_type}_{time_stamp}.xlsx",
         )

-    def get_metrics(self):
+    def get_metrics(self, strict_model: bool = False):
         if (
             self.prediction_file is None
             or len(self.prediction_file) == 0
@@ -53,7 +54,7 @@ class Metrics:
             {"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0}
         ]
-        missing_error_list, metrics_list = self.calculate_metrics()
+        missing_error_list, metrics_list = self.calculate_metrics(strict_model=strict_model)
         missing_error_df = pd.DataFrame(missing_error_list)
         missing_error_df.reset_index(drop=True, inplace=True)
@@ -66,7 +67,7 @@ class Metrics:
             metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
         return missing_error_list, metrics_list, self.output_file

-    def calculate_metrics(self):
+    def calculate_metrics(self, strict_model: bool = False):
         prediction_df = pd.read_excel(
             self.prediction_file, sheet_name=self.prediction_sheet_name
         )
@@ -77,7 +78,7 @@ class Metrics:
             ground_truth_df = ground_truth_df[ground_truth_df["Checked"] == 1]
         elif self.data_type == "data_extraction":
             ground_truth_df = ground_truth_df[ground_truth_df["rawname_checked"] == 1]
-        elif self.data_type == "investment_mapping":
+        elif self.data_type in ["investment_mapping", "document_mapping_in_db"]:
             ground_truth_df = ground_truth_df[ground_truth_df["mapping_checked"] == 1]
         else:
             logger.error(f"Invalid data type: {self.data_type}")
@@ -179,9 +180,28 @@ class Metrics:
                 investment_mapping_true.extend(true_data)
                 investment_mapping_pred.extend(pred_data)
                 missing_error_list.extend(missing_error_data)
+        elif self.data_type == "document_mapping_in_db":
+            prediction_doc_id_list = prediction_df["doc_id"].unique().tolist()
+            ground_truth_doc_id_list = ground_truth_df["doc_id"].unique().tolist()
+            # score only doc_ids present in both prediction and ground truth
+            doc_id_list = list(
+                set(prediction_doc_id_list) & set(ground_truth_doc_id_list)
+            )
+            # iterate documents in a stable order
+            doc_id_list.sort()
+            for doc_id in doc_id_list:
+                prediction_data = prediction_df[prediction_df["doc_id"] == doc_id]
+                ground_truth_data = ground_truth_df[ground_truth_df["doc_id"] == doc_id]
+                true_data, pred_data, missing_error_data = self.get_document_mapping_in_db_true_pred_data(
+                    doc_id, ground_truth_data, prediction_data, strict_mode=strict_model
+                )
+                investment_mapping_true.extend(true_data)
+                investment_mapping_pred.extend(pred_data)
+                missing_error_list.extend(missing_error_data)

         metrics_list = []
-        if self.data_type == "investment_mapping":
+        if self.data_type in ["investment_mapping", "document_mapping_in_db"]:
             if len(investment_mapping_true) == 0 and len(investment_mapping_pred) == 0:
                 investment_mapping_true.append(1)
                 investment_mapping_pred.append(1)
@@ -669,6 +689,208 @@ class Metrics:
             missing_error_data.append(error_data)
         return true_data, pred_data, missing_error_data

+    def get_document_mapping_in_db_true_pred_data(
+        self,
+        doc_id,
+        ground_truth_data: pd.DataFrame,
+        prediction_data: pd.DataFrame,
+        strict_mode: bool = False,
+    ):
+        """
+        EMEA AR mapping metrics based on the document mapping in the database.
+
+        1. The ground truth is made manually: according to the fund/share
+           names in the document mapping, find the relevant records in the
+           document data extraction and fill in the mapping id.
+        2. Metrics are calculated per document:
+           a. Ground truth data: the document data extraction records,
+              filtered by the mapping ids in the document mapping.
+           b. Prediction data: the document mapping obtained from the
+              fund/share raw names in the PDF document.
+           Scoring:
+           - mapping correct:   true 1, pred 1
+           - mapping empty:     true 1, pred 0  -> hurts recall
+           - mapping incorrect (another fund/share id):
+                                true 1, pred 0  -> hurts recall
+             and, if the incorrect id is itself in the document mapping:
+                                true 0, pred 1  -> hurts precision
+        """
+        document_mapping_data = query_document_fund_mapping(doc_id)
+        if len(document_mapping_data) == 0:
+            return [1], [1], []
+        fund_id_list = document_mapping_data["FundId"].unique().tolist()
+        share_id_list = document_mapping_data["SecId"].unique().tolist()
+        id_list = fund_id_list + share_id_list
+        # keep only ground-truth rows whose investment_id belongs to this document's mapping
+        dp_ground_truth = ground_truth_data[
+            ground_truth_data["investment_id"].isin(id_list)
+        ]
+        dp_ground_truth = self.modify_data(dp_ground_truth)
+        # keep only the columns needed for the comparison
+        dp_ground_truth = dp_ground_truth[["doc_id", "page_index", "raw_name", "simple_raw_name",
+                                           "simple_name_unique_words", "investment_type",
+                                           "investment_id", "investment_name"]]
+        dp_ground_truth.drop_duplicates(inplace=True)
+        dp_ground_truth.reset_index(drop=True, inplace=True)
+        # replace NaN investment_id/investment_name in the prediction with ""
+        prediction_data["investment_id"].fillna("", inplace=True)
+        prediction_data["investment_name"].fillna("", inplace=True)
+        dp_prediction = self.modify_data(prediction_data)
+        dp_prediction = dp_prediction[["doc_id", "page_index", "raw_name", "simple_raw_name",
+                                       "simple_name_unique_words", "investment_type",
+                                       "investment_id", "investment_name"]]
+        dp_prediction.drop_duplicates(inplace=True)
+        dp_prediction.reset_index(drop=True, inplace=True)
+        # pred_simple_raw_names = dp_prediction["simple_raw_name"].unique().tolist()
+        # pred_simple_name_unique_words_list = (
+        #     dp_prediction["simple_name_unique_words"].unique().tolist()
+        # )
+        compare_data_list = []
+        gt_investment_id_list = []
+        for index, ground_truth in dp_ground_truth.iterrows():
+            gt_page_index = ground_truth["page_index"]
+            gt_raw_name = ground_truth["raw_name"]
+            gt_simple_raw_name = ground_truth["simple_raw_name"]
+            gt_simple_name_unique_words = ground_truth["simple_name_unique_words"]
+            gt_investment_type = ground_truth["investment_type"]
+            gt_investment_id = ground_truth["investment_id"]
+            gt_investment_name = ground_truth["investment_name"]
+            # restrict candidate predictions to the same page as the ground-truth row
+            pred_page_data = dp_prediction[dp_prediction["page_index"] == gt_page_index]
+            if len(pred_page_data) > 0:
+                pred_simple_raw_names = pred_page_data["simple_raw_name"].unique().tolist()
+                pred_simple_name_unique_words_list = (
+                    pred_page_data["simple_name_unique_words"].unique().tolist()
+                )
+            else:
+                pred_simple_raw_names = []
+                pred_simple_name_unique_words_list = []
+            # each ground-truth investment id is scored only once
+            if gt_investment_id in gt_investment_id_list:
+                continue
+            find_raw_name_in_pred = [
+                pred_raw_name
+                for pred_raw_name in pred_simple_raw_names
+                if (
+                    gt_simple_raw_name in pred_raw_name
+                    or pred_raw_name in gt_simple_raw_name
+                )
+                and pred_raw_name.endswith(gt_simple_raw_name.split()[-1])
+            ]
+            if (
+                gt_simple_name_unique_words in pred_simple_name_unique_words_list
+                or len(find_raw_name_in_pred) > 0
+            ):
+                # prefer the prediction rows sharing the ground truth's unique
+                # words; otherwise fall back to the matched raw name
+                if gt_simple_name_unique_words in pred_simple_name_unique_words_list:
+                    pred_data_df = dp_prediction[
+                        dp_prediction["simple_name_unique_words"]
+                        == gt_simple_name_unique_words
+                    ]
+                else:
+                    pred_data_df = dp_prediction[
+                        dp_prediction["simple_raw_name"] == find_raw_name_in_pred[0]
+                    ]
+                # among several candidates, prefer one on the same page
+                if len(pred_data_df) > 1:
+                    if (
+                        len(pred_data_df[pred_data_df["page_index"] == gt_page_index])
+                        == 0
+                    ):
+                        pred_data = pred_data_df.iloc[0]
+                    else:
+                        pred_data = pred_data_df[
+                            pred_data_df["page_index"] == gt_page_index
+                        ].iloc[0]
+                elif len(pred_data_df) == 1:
+                    pred_data = pred_data_df.iloc[0]
+                else:
+                    pred_data = None
+                if pred_data is not None:
+                    compare_data = {"raw_name": gt_raw_name,
+                                    "investment_type": gt_investment_type,
+                                    "gt_investment_id": gt_investment_id,
+                                    "gt_investment_name": gt_investment_name,
+                                    "pred_investment_id": pred_data["investment_id"],
+                                    "pred_investment_name": pred_data["investment_name"]}
+                    gt_investment_id_list.append(gt_investment_id)
+                    compare_data_list.append(compare_data)
+            else:
+                # no prediction row matches this ground-truth name;
+                # in strict mode this still counts as a miss
+                if strict_mode:
+                    compare_data = {"raw_name": gt_raw_name,
+                                    "investment_type": gt_investment_type,
+                                    "gt_investment_id": gt_investment_id,
+                                    "gt_investment_name": gt_investment_name,
+                                    "pred_investment_id": "",
+                                    "pred_investment_name": ""}
+                    compare_data_list.append(compare_data)
+        true_data = []
+        pred_data = []
+        missing_error_data = []
+        for compare_data in compare_data_list:
+            gt_investment_id = compare_data["gt_investment_id"]
+            pred_investment_id = compare_data["pred_investment_id"]
+            if gt_investment_id == pred_investment_id:
+                true_data.append(1)
+                pred_data.append(1)
+            else:
+                true_data.append(1)
+                pred_data.append(0)
+                if pred_investment_id is not None and len(pred_investment_id) > 0:
+                    if pred_investment_id in id_list:
+                        # the wrong id belongs to this document's mapping,
+                        # so it also hurts precision
+                        true_data.append(0)
+                        pred_data.append(1)
+                    error_data = {
+                        "doc_id": doc_id,
+                        "raw_name": compare_data["raw_name"],
+                        "investment_type": compare_data["investment_type"],
+                        "error_type": "mapping incorrect",
+                        "error_id": pred_investment_id,
+                        "error_name": compare_data["pred_investment_name"],
+                        "correct_id": compare_data["gt_investment_id"],
+                        "correct_name": compare_data["gt_investment_name"]
+                    }
+                else:
+                    error_data = {
+                        "doc_id": doc_id,
+                        "raw_name": compare_data["raw_name"],
+                        "investment_type": compare_data["investment_type"],
+                        "error_type": "mapping missing",
+                        "error_id": "",
+                        "error_name": "",
+                        "correct_id": compare_data["gt_investment_id"],
+                        "correct_name": compare_data["gt_investment_name"]
+                    }
+                missing_error_data.append(error_data)
+        return true_data, pred_data, missing_error_data
+
     def modify_data(self, data: pd.DataFrame):
         data["simple_raw_name"] = ""
         data["simple_name_unique_words"] = ""

main.py

@@ -345,9 +345,19 @@ def batch_start_job(
         metrics_output_folder,
     )

-    logger.info(f"Calculating metrics for investment mapping")
+    # logger.info(f"Calculating metrics for investment mapping by actual document mapping")
+    # missing_error_list, metrics_list, metrics_file = get_metrics(
+    #     "investment_mapping",
+    #     output_file,
+    #     prediction_sheet_name,
+    #     ground_truth_file,
+    #     ground_truth_sheet_name,
+    #     metrics_output_folder,
+    # )
+    logger.info(f"Calculating metrics for investment mapping by database document mapping")
     missing_error_list, metrics_list, metrics_file = get_metrics(
-        "investment_mapping",
+        "document_mapping_in_db",
         output_file,
         prediction_sheet_name,
         ground_truth_file,
@@ -436,7 +446,7 @@ def get_metrics(
         ground_truth_sheet_name=ground_truth_sheet_name,
         output_folder=output_folder,
     )
-    missing_error_list, metrics_list, metrics_file = metrics.get_metrics()
+    missing_error_list, metrics_list, metrics_file = metrics.get_metrics(strict_model=True)
     return missing_error_list, metrics_list, metrics_file
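
Worth noting: with strict_model=True, ground-truth names that have no matching raw name in the prediction at all are still scored as misses (true 1, pred 0), so recall also reflects names the extraction step never produced. A hypothetical standalone invocation, assuming the Metrics constructor mirrors this wrapper's parameters (all file paths and sheet names below are illustrative placeholders):

    # Assumed constructor signature, inferred from the wrapper above.
    metrics = Metrics(
        data_type="document_mapping_in_db",
        prediction_file="/data/emea_ar/output/prediction.xlsx",
        prediction_sheet_name="Sheet1",
        ground_truth_file="/data/emea_ar/ground_truth.xlsx",
        ground_truth_sheet_name="Sheet1",
        output_folder="/data/emea_ar/output/metrics/",
    )
    missing_error_list, metrics_list, metrics_file = metrics.get_metrics(strict_model=True)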
@@ -657,13 +667,38 @@ if __name__ == "__main__":
         "479793787",
         "471641628",
     ]
-    special_doc_id_list = check_mapping_doc_id_list
-    special_doc_id_list = ["402113224"]
+    check_db_mapping_doc_id_list = [
+        "292989214",
+        "316237292",
+        "321733631",
+        "323390570",
+        "327956364",
+        "332223498",
+        "333207452",
+        "334718372",
+        "344636875",
+        "349679479",
+        "362246081",
+        "366179419",
+        "380945052",
+        "382366116",
+        "387202452",
+        "389171486",
+        "391456740",
+        "391736837",
+        "394778487",
+        "401684600",
+        "402113224",
+        "402181770"
+    ]
+    # special_doc_id_list = check_mapping_doc_id_list
+    special_doc_id_list = check_db_mapping_doc_id_list
+    # special_doc_id_list = ["382366116"]
     output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
     output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
     re_run_extract_data = False
-    re_run_mapping_data = True
-    force_save_total_data = False
+    re_run_mapping_data = False
+    force_save_total_data = True
     extract_ways = ["text"]
     for extract_way in extract_ways:

@@ -119,10 +119,15 @@ def get_most_similar_name(text: str, name_list: list, pre_common_word_list: list
     for i in range(len(copy_name_list)):
         temp_splits = copy_name_list[i].split()
-        copy_name_list[i] = ' '.join([split for split in temp_splits
-                                      if remove_special_characters(split).lower() not in ['fund', 'portfolio', 'class', 'share', 'shares']])
+        copy_name_list[i] = ' '.join([split for split in temp_splits
+                                      if remove_special_characters(split).lower()
+                                      not in ['fund', 'funds', 'portfolio',
+                                              'class', 'classes',
+                                              'share', 'shares']])

     final_splits = []
     for split in new_splits:
-        if split.lower() not in ['fund', 'portfolio', 'class', 'share', 'shares']:
+        if split.lower() not in ['fund', 'funds', 'portfolio',
+                                 'class', 'classes',
+                                 'share', 'shares']:
             final_splits.append(split)
     text = ' '.join(final_splits)
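
The effect of the widened stop-word list, sketched as a standalone re-implementation of the filtering step (re.sub stands in for remove_special_characters here; this is not the library function itself):

    import re

    GENERIC_WORDS = {'fund', 'funds', 'portfolio', 'class', 'classes', 'share', 'shares'}

    def strip_generic_words(name: str) -> str:
        # Drop tokens that are generic fund-naming words once
        # special characters are removed (case-insensitive).
        return ' '.join(
            token for token in name.split()
            if re.sub(r'[^A-Za-z0-9]', '', token).lower() not in GENERIC_WORDS
        )

    print(strip_generic_words("Global Equity Funds Class A Shares"))
    # Now prints "Global Equity A"; before this change 'funds' was not in
    # the list, so the result would have been "Global Equity Funds A".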