Support calculating page filter metrics.

Blade He 2024-09-03 17:07:53 -05:00
parent f81e2862f3
commit 7198450e53
6 changed files with 428 additions and 23 deletions


@@ -0,0 +1,17 @@
{
    "ISIN": {
        "english": []
    },
    "ter": {
        "english": []
    },
    "tor": {
        "english": []
    },
    "ogc": {
        "english": ["operating expenses paid"]
    },
    "performance_fee": {
        "english": []
    }
}
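Note: judging from the loader added in core/page_filter.py below, this new 17-line file is presumably configuration/datapoint_exclude_keyword.json, the per-datapoint exclude-keyword config. A minimal sketch of how such a config can suppress a false-positive hit follows; the sample page text is an assumption:

import json
import re

# Assumed path, matching datapoint_exclude_keywords_config_file in core/page_filter.py below.
with open("configuration/datapoint_exclude_keyword.json", "r", encoding="utf-8") as f:
    exclude_config = json.load(f)

page_text = "Operating expenses paid to the Management Company during the year ..."
for phrase in exclude_config["ogc"]["english"]:
    if re.search(re.escape(phrase), page_text, re.IGNORECASE):
        print(f"OGC hit suppressed by exclude phrase: {phrase!r}")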


@@ -299,7 +299,6 @@
             "Ongoing Fund Charge",
             "Operating Charge",
             "Operating Charges",
-            "Operating expenses",
             "Operating, Administrative and Servicing Expenses"
         ],
         "spanish": [
@@ -340,8 +339,12 @@
             "Performance Fees",
             "performance-based fee",
             "performance-related fee",
-            "with performance)",
-            "with performance fee)"
+            "Performance- related Fee",
+            "perform- mance fees",
+            "per- formance fees",
+            "with performance",
+            "with performance fee",
+            "de Performance"
         ],
         "spanish": [
             "Comisión de Gestión sobre Resultados",

core/metrics.py (new file, 237 lines)

@@ -0,0 +1,237 @@
import os
import pandas as pd
import time
import json
from sklearn.metrics import precision_score, recall_score, f1_score
from utils.logger import logger


class Metrics:
    def __init__(
        self,
        data_type: str,
        prediction_file: str,
        prediction_sheet_name: str = "Sheet1",
        ground_truth_file: str = None,
        output_folder: str = None,
    ) -> None:
        self.data_type = data_type
        self.prediction_file = prediction_file
        self.prediction_sheet_name = prediction_sheet_name
        self.ground_truth_file = ground_truth_file
        if output_folder is None or len(output_folder) == 0:
            output_folder = r"/data/emea_ar/output/metrics/"
        os.makedirs(output_folder, exist_ok=True)
        time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
        self.output_file = os.path.join(
            output_folder,
            f"metrics_{data_type}_{time_stamp}.xlsx",
        )
    def get_metrics(self):
        if (
            self.prediction_file is None
            or len(self.prediction_file) == 0
            or not os.path.exists(self.prediction_file)
        ):
            logger.error(f"Invalid prediction file: {self.prediction_file}")
            return []
        if (
            self.ground_truth_file is None
            or len(self.ground_truth_file) == 0
            or not os.path.exists(self.ground_truth_file)
        ):
            logger.error(f"Invalid ground truth file: {self.ground_truth_file}")
            return []
        missing_error_list = []
        metrics_list = [
            {"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0}
        ]
        if self.data_type == "page_filter":
            missing_error_list, metrics_list = self.get_page_filter_metrics()
        elif self.data_type == "datapoint":
            missing_error_list, metrics_list = self.get_datapoint_metrics()
        else:
            logger.error(f"Invalid data type: {self.data_type}")
        missing_error_df = pd.DataFrame(missing_error_list)
        missing_error_df.reset_index(drop=True, inplace=True)
        metrics_df = pd.DataFrame(metrics_list)
        metrics_df.reset_index(drop=True, inplace=True)
        with pd.ExcelWriter(self.output_file) as writer:
            missing_error_df.to_excel(writer, sheet_name="Missing_Error", index=False)
            metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
        return missing_error_list, metrics_list, self.output_file
    def get_page_filter_metrics(self):
        prediction_df = pd.read_excel(self.prediction_file, sheet_name=self.prediction_sheet_name)
        ground_truth_df = pd.read_excel(self.ground_truth_file, sheet_name="Sheet1")
        ground_truth_df = ground_truth_df[ground_truth_df["Checked"] == 1]
        tor_true = []
        tor_pred = []
        ter_true = []
        ter_pred = []
        ogc_true = []
        ogc_pred = []
        performance_fee_true = []
        performance_fee_pred = []
        missing_error_list = []
        data_point_list = ["tor", "ter", "ogc", "performance_fee"]
        for index, row in ground_truth_df.iterrows():
            doc_id = row["doc_id"]
            # get first row with the same doc_id
            prediction_data = prediction_df[prediction_df["doc_id"] == doc_id].iloc[0]
            for data_point in data_point_list:
                true_data, pred_data, missing_error_data = self.get_true_pred_data(
                    doc_id, row, prediction_data, data_point
                )
                if data_point == "tor":
                    tor_true.extend(true_data)
                    tor_pred.extend(pred_data)
                elif data_point == "ter":
                    ter_true.extend(true_data)
                    ter_pred.extend(pred_data)
                elif data_point == "ogc":
                    ogc_true.extend(true_data)
                    ogc_pred.extend(pred_data)
                elif data_point == "performance_fee":
                    performance_fee_true.extend(true_data)
                    performance_fee_pred.extend(pred_data)
                missing_error_list.append(missing_error_data)
        metrics_list = []
        for data_point in data_point_list:
            if data_point == "tor":
                precision, recall, f1 = self.get_specific_metrics(tor_true, tor_pred)
                metrics_list.append(
                    {
                        "Data_Point": data_point,
                        "Precision": precision,
                        "Recall": recall,
                        "F1": f1,
                        "Support": len(tor_true),
                    }
                )
                logger.info(
                    f"TOR Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {len(tor_true)}"
                )
            elif data_point == "ter":
                precision, recall, f1 = self.get_specific_metrics(ter_true, ter_pred)
                metrics_list.append(
                    {
                        "Data_Point": data_point,
                        "Precision": precision,
                        "Recall": recall,
                        "F1": f1,
                        "Support": len(ter_true),
                    }
                )
                logger.info(
                    f"TER Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {len(ter_true)}"
                )
            elif data_point == "ogc":
                precision, recall, f1 = self.get_specific_metrics(ogc_true, ogc_pred)
                metrics_list.append(
                    {
                        "Data_Point": data_point,
                        "Precision": precision,
                        "Recall": recall,
                        "F1": f1,
                        "Support": len(ogc_true),
                    }
                )
                logger.info(
                    f"OGC Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {len(ogc_true)}"
                )
            elif data_point == "performance_fee":
                precision, recall, f1 = self.get_specific_metrics(
                    performance_fee_true, performance_fee_pred
                )
                metrics_list.append(
                    {
                        "Data_Point": data_point,
                        "Precision": precision,
                        "Recall": recall,
                        "F1": f1,
                        "Support": len(performance_fee_true),
                    }
                )
                logger.info(
                    f"Performance Fee Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {len(performance_fee_true)}"
                )
        # get average metrics
        precision_list = [metric["Precision"] for metric in metrics_list]
        recall_list = [metric["Recall"] for metric in metrics_list]
        f1_list = [metric["F1"] for metric in metrics_list]
        metrics_list.append(
            {
                "Data_Point": "Average",
                "Precision": sum(precision_list) / len(precision_list),
                "Recall": sum(recall_list) / len(recall_list),
                "F1": sum(f1_list) / len(f1_list),
                "Support": sum([metric["Support"] for metric in metrics_list]),
            }
        )
        return missing_error_list, metrics_list
    def get_true_pred_data(
        self, doc_id, ground_truth_data: pd.Series, prediction_data: pd.Series, data_point: str
    ):
        ground_truth_list = ground_truth_data[data_point]
        if isinstance(ground_truth_list, str):
            ground_truth_list = json.loads(ground_truth_list)
        prediction_list = prediction_data[data_point]
        if isinstance(prediction_list, str):
            prediction_list = json.loads(prediction_list)
        true_data = []
        pred_data = []
        missing_error_data = {"doc_id": doc_id, "data_point": data_point, "missing": [], "error": []}
        missing_data = []
        error_data = []
        if len(ground_truth_list) == 0 and len(prediction_list) == 0:
            true_data.append(1)
            pred_data.append(1)
            return true_data, pred_data, missing_error_data
        for prediction in prediction_list:
            if prediction in ground_truth_list:
                true_data.append(1)
                pred_data.append(1)
            else:
                true_data.append(0)
                pred_data.append(1)
                error_data.append(prediction)
        for ground_truth in ground_truth_list:
            if ground_truth not in prediction_list:
                true_data.append(1)
                pred_data.append(0)
                missing_data.append(ground_truth)
        missing_error_data = {"doc_id": doc_id, "data_point": data_point, "missing": missing_data, "error": error_data}
        return true_data, pred_data, missing_error_data
    def get_specific_metrics(self, true_data: list, pred_data: list):
        precision = precision_score(true_data, pred_data)
        recall = recall_score(true_data, pred_data)
        f1 = f1_score(true_data, pred_data)
        return precision, recall, f1

    def get_datapoint_metrics(self):
        pass
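For orientation, this is how get_true_pred_data turns the per-document page lists into the binary vectors that get_specific_metrics passes to sklearn. A small worked example with purely illustrative page numbers:

from sklearn.metrics import precision_score, recall_score

ground_truth_pages = [12, 13, 20]  # pages annotated for one datapoint in one document
predicted_pages = [12, 14]         # pages returned by the page filter

# Matched prediction (12)        -> true=1, pred=1
# False-positive prediction (14) -> true=0, pred=1 (reported under "error")
# Missed ground truth (13, 20)   -> true=1, pred=0 (reported under "missing")
true_data = [1, 0, 1, 1]
pred_data = [1, 1, 0, 0]

print(precision_score(true_data, pred_data))  # 0.5
print(recall_score(true_data, pred_data))     # 0.333...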


@@ -22,7 +22,7 @@ class FilterPages:
         self.document_mapping_info_df = document_mapping_info_df
         self.get_configuration_from_file()
         self.doc_info = self.get_doc_info()
-        self.datapoint_config = self.get_datapoint_config()
+        self.datapoint_config, self.datapoint_exclude_config = self.get_datapoint_config()

     def get_pdf_page_text_dict(self) -> dict:
         pdf_util = PDFUtil(self.pdf_file)
@@ -33,12 +33,15 @@
         language_config_file = r"./configuration/language.json"
         domicile_datapoint_config_file = r"./configuration/domicile_datapoints.json"
         datapoint_keywords_config_file = r"./configuration/datapoint_keyword.json"
+        datapoint_exclude_keywords_config_file = r"./configuration/datapoint_exclude_keyword.json"
         with open(language_config_file, "r", encoding="utf-8") as file:
             self.language_config = json.load(file)
         with open(domicile_datapoint_config_file, "r", encoding="utf-8") as file:
             self.domicile_datapoint_config = json.load(file)
         with open(datapoint_keywords_config_file, "r", encoding="utf-8") as file:
             self.datapoint_keywords_config = json.load(file)
+        with open(datapoint_exclude_keywords_config_file, "r", encoding="utf-8") as file:
+            self.datapoint_exclude_keywords_config = json.load(file)

     def get_doc_info(self) -> dict:
         if len(self.document_mapping_info_df) == 0:
@@ -77,18 +80,30 @@
         if self.domicile_datapoint_config[domicile].get(document_type, None) is None:
             document_type = "ar"
         datapoint_list = self.domicile_datapoint_config[domicile][document_type]
+        datapoint_keywords = self.get_keywords("include", datapoint_list, language)
+        datapoint_exclude_keywords = self.get_keywords("exclude", datapoint_list, language)
+        return datapoint_keywords, datapoint_exclude_keywords
+
+    def get_keywords(self, keywords_type: str, datapoint_list: list, language: str) -> dict:
+        if keywords_type == "include":
+            config = self.datapoint_keywords_config
+        elif keywords_type == "exclude":
+            config = self.datapoint_exclude_keywords_config
+        else:
+            config = self.datapoint_keywords_config
         datapoint_keywords = {}
         for datapoint in datapoint_list:
-            keywords = self.datapoint_keywords_config.get(datapoint, {}).get(language, [])
+            keywords = config.get(datapoint, {}).get(language, [])
             if len(keywords) > 0:
                 keywords = self.optimize_keywords_regex(keywords)
                 datapoint_keywords[datapoint] = keywords
                 if language != "english":
-                    english_keywords = self.datapoint_keywords_config.get(datapoint, {}).get("english", [])
+                    english_keywords = config.get(datapoint, {}).get("english", [])
                     if len(english_keywords) > 0:
                         english_keywords = self.optimize_keywords_regex(english_keywords)
                         datapoint_keywords[datapoint] += english_keywords
         return datapoint_keywords

     def optimize_keywords_regex(self, keywords: list) -> list:
         new_keywords = []
@@ -110,15 +125,78 @@
         }
         """
         result = {"doc_id": self.doc_id}
+        result_details = []
         for datapoint in self.datapoint_config.keys():
             result[datapoint] = []
         for page_num, page_text in self.page_text_dict.items():
-            text = clean_text(page_text)
+            text = "\n" + clean_text(page_text) + "\n"
             for datapoint, keywords in self.datapoint_config.items():
                 # idx = idx & np.array([re.findall(r'\b' + word + r'\d*\b', page) != [] for page in self.pages_clean])
+                find_datapoint = False
                 for keyword in keywords:
-                    search_regex = r"\b{0}\d*\b\s*".format(keyword)
-                    if re.search(search_regex, text, re.IGNORECASE):
-                        result[datapoint].append(page_num)
-                        break
-        return result
+                    search_iter = self.search_keyword(text, keyword)
+                    for search in search_iter:
+                        search_text = search.group().strip()
+                        exclude_search_list = self.search_exclude_keywords(text, datapoint)
+                        if exclude_search_list is not None:
+                            need_exclude = False
+                            for exclude_search_text in exclude_search_list:
+                                if search_text in exclude_search_text:
+                                    need_exclude = True
+                                    break
+                            if need_exclude:
+                                continue
+                        is_valid = self.search_in_sentence_is_valid(search_text, text)
+                        if not is_valid:
+                            continue
+                        result[datapoint].append(page_num)
+                        detail = {
+                            "doc_id": self.doc_id,
+                            "datapoint": datapoint,
+                            "page_num": page_num,
+                            "keyword": keyword,
+                            "text": search_text
+                        }
+                        result_details.append(detail)
+                        find_datapoint = True
+                        break
+                    if find_datapoint:
+                        break
+        return result, result_details
+
+    def search_in_sentence_is_valid(self,
+                                    search_text: str,
+                                    text: str):
+        search_text_regex = add_slash_to_text_as_regex(search_text)
+        search_regex = r"\n.*{0}.*\n".format(search_text_regex)
+        search_iter = re.finditer(search_regex, text, re.IGNORECASE)
+        is_valid = False
+        lower_word_count_threshold = 7
+        for search in search_iter:
+            lower_word_count = 0
+            if search is not None:
+                # get the word count in search text which start with lower case
+                search_text = search.group().strip()
+                search_text_split = search_text.split()
+                for split in search_text_split:
+                    if split[0].islower():
+                        lower_word_count += 1
+                if lower_word_count < lower_word_count_threshold:
+                    is_valid = True
+                    break
+        return is_valid
+
+    def search_keyword(self, text: str, keyword: str):
+        search_regex = r"\b{0}\d*\W*\s*\b".format(keyword)
+        return re.finditer(search_regex, text, re.IGNORECASE)
+
+    def search_exclude_keywords(self, text: str, datapoint: str):
+        exclude_keywords = self.datapoint_exclude_config.get(datapoint, [])
+        search_list = []
+        for keyword in exclude_keywords:
+            search_iter = self.search_keyword(text, keyword)
+            for search in search_iter:
+                search_list.append(search.group())
+        return search_list
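The new search_in_sentence_is_valid check is effectively a prose filter: a keyword hit only survives if at least one line containing it has fewer than 7 words that start with a lowercase letter, which keeps table rows and headings while discarding hits buried in running text. A rough standalone illustration of that heuristic (not the FilterPages method itself):

def looks_like_table_row_or_heading(line: str, threshold: int = 7) -> bool:
    # Count the words that start with a lowercase letter, as the new check does.
    lower_initial = sum(1 for word in line.split() if word and word[0].islower())
    return lower_initial < threshold

print(looks_like_table_row_or_heading("Performance Fee 1,234 5,678"))  # True -> hit kept
print(looks_like_table_row_or_heading(
    "during the year the fund did not accrue any performance fee on its share classes"
))  # False -> hit dropped as ordinary prose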

main.py (90 changed lines)

@@ -8,6 +8,7 @@ from utils.logger import logger
 from utils.pdf_download import download_pdf_from_documents_warehouse
 from utils.sql_query_util import query_document_fund_mapping
 from core.page_filter import FilterPages
+from core.metrics import Metrics


 class EMEA_AR_Parsing:
@@ -17,7 +18,7 @@ class EMEA_AR_Parsing:
         os.makedirs(self.pdf_folder, exist_ok=True)
         self.pdf_file = self.download_pdf()
         self.document_mapping_info_df = query_document_fund_mapping(doc_id)
-        self.datapoint_page_info = self.get_datapoint_page_info()
+        self.datapoint_page_info, self.result_details = self.get_datapoint_page_info()

     def download_pdf(self) -> str:
         pdf_file = download_pdf_from_documents_warehouse(self.pdf_folder, self.doc_id)
@@ -27,26 +28,51 @@
         filter_pages = FilterPages(
             self.doc_id, self.pdf_file, self.document_mapping_info_df
         )
-        datapoint_page_info = filter_pages.start_job()
-        return datapoint_page_info
+        datapoint_page_info, result_details = filter_pages.start_job()
+        return datapoint_page_info, result_details


 def filter_pages(doc_id: str, pdf_folder: str) -> None:
     logger.info(f"Parsing EMEA AR for doc_id: {doc_id}")
     emea_ar_parsing = EMEA_AR_Parsing(doc_id, pdf_folder)
-    return emea_ar_parsing.datapoint_page_info
+    return emea_ar_parsing.datapoint_page_info, emea_ar_parsing.result_details


-def batch_filter_pdf_files(pdf_folder: str, output_folder: str) -> None:
+def batch_filter_pdf_files(
+    pdf_folder: str,
+    doc_data_excel_file: str = None,
+    output_folder: str = r"/data/emea_ar/output/filter_pages/",
+    special_doc_id_list: list = None,
+) -> None:
     pdf_files = glob(pdf_folder + "*.pdf")
+    doc_list = []
+    if special_doc_id_list is not None and len(special_doc_id_list) > 0:
+        doc_list = special_doc_id_list
+    if (
+        len(doc_list) == 0
+        and doc_data_excel_file is not None
+        and len(doc_data_excel_file) > 0
+        and os.path.exists(doc_data_excel_file)
+    ):
+        doc_data_df = pd.read_excel(doc_data_excel_file)
+        doc_data_df = doc_data_df[doc_data_df["Checked"] == 1]
+        doc_list = [str(doc_id) for doc_id in doc_data_df["doc_id"].tolist()]
     result_list = []
+    result_details = []
     for pdf_file in tqdm(pdf_files):
         pdf_base_name = os.path.basename(pdf_file)
         doc_id = pdf_base_name.split(".")[0]
-        datapoint_page_info = filter_pages(doc_id=doc_id, pdf_folder=pdf_folder)
-        result_list.append(datapoint_page_info)
+        if doc_list is not None and doc_id not in doc_list:
+            continue
+        doc_datapoint_page_info, doc_result_details = filter_pages(doc_id=doc_id, pdf_folder=pdf_folder)
+        result_list.append(doc_datapoint_page_info)
+        result_details.extend(doc_result_details)
     result_df = pd.DataFrame(result_list)
     result_df.reset_index(drop=True, inplace=True)
+    result_details_df = pd.DataFrame(result_details)
+    result_details_df.reset_index(drop=True, inplace=True)

     logger.info(f"Saving the result to {output_folder}")
     os.makedirs(output_folder, exist_ok=True)
@@ -56,10 +82,54 @@ def batch_filter_pdf_files(pdf_folder: str, output_folder: str) -> None:
         f"datapoint_page_info_{len(result_df)}_documents_{time_stamp}.xlsx",
     )
     with pd.ExcelWriter(output_file) as writer:
-        result_df.to_excel(writer, index=False)
+        result_df.to_excel(writer, index=False, sheet_name="dp_page_info")
+        result_details_df.to_excel(writer, index=False, sheet_name="dp_page_info_details")
+
+    if len(special_doc_id_list) == 0:
+        logger.info(f"Calculating metrics for {output_file}")
+        metrics_output_folder = r"/data/emea_ar/output/metrics/"
+        missing_error_list, metrics_list, metrics_file = get_metrics(
+            data_type="page_filter",
+            prediction_file=output_file,
+            prediction_sheet_name="dp_page_info",
+            ground_truth_file=doc_data_excel_file,
+            output_folder=metrics_output_folder,
+        )
+        return missing_error_list, metrics_list, metrics_file
+
+
+def get_metrics(
+    data_type: str,
+    prediction_file: str,
+    prediction_sheet_name: str,
+    ground_truth_file: str,
+    output_folder: str = None
+) -> None:
+    metrics = Metrics(
+        data_type=data_type,
+        prediction_file=prediction_file,
+        prediction_sheet_name=prediction_sheet_name,
+        ground_truth_file=ground_truth_file,
+        output_folder=output_folder
+    )
+    missing_error_list, metrics_list, metrics_file = metrics.get_metrics()
+    return missing_error_list, metrics_list, metrics_file
+
+
 if __name__ == "__main__":
     pdf_folder = r"/data/emea_ar/small_pdf/"
-    output_folder = r"/data/emea_ar/output/filter_pages/"
-    batch_filter_pdf_files(pdf_folder, output_folder)
+    page_filter_ground_truth_file = (
+        r"/data/emea_ar/ground_truth/page_filter/datapoint_page_info_88_documents.xlsx"
+    )
+    prediction_output_folder = r"/data/emea_ar/output/filter_pages/"
+    metrics_output_folder = r"/data/emea_ar/output/metrics/"
+    special_doc_id_list = []
+    batch_filter_pdf_files(
+        pdf_folder, page_filter_ground_truth_file, prediction_output_folder, special_doc_id_list
+    )
+    # data_type = "page_filter"
+    # prediction_file = r"/data/emea_ar/output/filter_pages/datapoint_page_info_73_documents_20240903145002.xlsx"
+    # missing_error_list, metrics_list, metrics_file = get_metrics(
+    #     data_type, prediction_file, page_filter_ground_truth_file, metrics_output_folder
+    # )


@@ -15,8 +15,8 @@ def add_slash_to_text_as_regex(text: str):

 def clean_text(text: str) -> str:
-    text = text.lower()
+    # text = text.lower()
     # update the specical character which begin with \u, e.g \u2004 or \u00a0 to be space
-    text = re.sub(r"\\u[0-9a-z]{4}", ' ', text)
+    text = re.sub(r"\\u[A-Z0-9a-z]{4}", ' ', text)
     text = re.sub(r"( ){2,}", ' ', text.strip())
     return text
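As a quick sanity check of the two changes above (case is now preserved, and literal \uXXXX escapes containing uppercase hex digits are also replaced), a small sketch that reuses the updated body of clean_text; the sample string is an assumption:

import re

def clean_text(text: str) -> str:
    # text = text.lower()  # case is now preserved
    text = re.sub(r"\\u[A-Z0-9a-z]{4}", ' ', text)
    text = re.sub(r"( ){2,}", ' ', text.strip())
    return text

print(clean_text(r"Ongoing\u00A0Charges  Figure"))  # -> "Ongoing Charges Figure"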