Support calculating page filter metrics.
This commit is contained in:
parent f81e2862f3
commit 7198450e53
@@ -0,0 +1,17 @@
+{
+    "ISIN": {
+        "english": []
+    },
+    "ter": {
+        "english": []
+    },
+    "tor": {
+        "english": []
+    },
+    "ogc": {
+        "english": ["operating expenses paid"]
+    },
+    "performance_fee": {
+        "english": []
+    }
+}
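Note: this new configuration file (presumably ./configuration/datapoint_exclude_keyword.json, which the page-filter change below starts loading) mirrors the shape of the existing include-keyword config: one object per datapoint with per-language keyword lists. So far only "ogc" carries an exclude phrase. A minimal sketch of how such a file is read, assuming the same json.load pattern used for the other configuration files:

import json

with open("./configuration/datapoint_exclude_keyword.json", "r", encoding="utf-8") as file:
    exclude_config = json.load(file)

print(exclude_config["ogc"]["english"])  # ['operating expenses paid']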
@@ -299,7 +299,6 @@
             "Ongoing Fund Charge",
             "Operating Charge",
             "Operating Charges",
-            "Operating expenses",
             "Operating, Administrative and Servicing Expenses"
         ],
         "spanish": [
@@ -340,8 +339,12 @@
             "Performance Fees",
             "performance-based fee",
             "performance-related fee",
-            "with performance)",
-            "with performance fee)"
+            "Performance- related Fee",
+            "perform- mance fees",
+            "per- formance fees",
+            "with performance",
+            "with performance fee",
+            "de Performance"
         ],
         "spanish": [
             "Comisión de Gestión sobre Resultados",
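The performance-fee list now carries hyphenation and line-break variants ("Performance- related Fee", "perform- mance fees", "per- formance fees") because PDF text extraction often splits a word across lines, and a search for the clean phrase alone would miss those pages. A minimal illustration (the example string is hypothetical):

import re

page_text = "Total perform- mance fees charged during the year"
print(re.search(r"\bperformance fees\b", page_text, re.IGNORECASE))     # None - the clean phrase misses
print(re.search(r"\bperform- mance fees\b", page_text, re.IGNORECASE))  # matches the extracted variant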
@@ -0,0 +1,237 @@
+import os
+import pandas as pd
+import time
+import json
+from sklearn.metrics import precision_score, recall_score, f1_score
+from utils.logger import logger
+
+
+class Metrics:
+    def __init__(
+        self,
+        data_type: str,
+        prediction_file: str,
+        prediction_sheet_name: str = "Sheet1",
+        ground_truth_file: str = None,
+        output_folder: str = None,
+    ) -> None:
+        self.data_type = data_type
+        self.prediction_file = prediction_file
+        self.prediction_sheet_name = prediction_sheet_name
+        self.ground_truth_file = ground_truth_file
+
+        if output_folder is None or len(output_folder) == 0:
+            output_folder = r"/data/emea_ar/output/metrics/"
+
+        os.makedirs(output_folder, exist_ok=True)
+        time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
+        self.output_file = os.path.join(
+            output_folder,
+            f"metrics_{data_type}_{time_stamp}.xlsx",
+        )
+
+    def get_metrics(self):
+        if (
+            self.prediction_file is None
+            or len(self.prediction_file) == 0
+            or not os.path.exists(self.prediction_file)
+        ):
+            logger.error(f"Invalid prediction file: {self.prediction_file}")
+            return []
+        if (
+            self.ground_truth_file is None
+            or len(self.ground_truth_file) == 0
+            or not os.path.exists(self.ground_truth_file)
+        ):
+            logger.error(f"Invalid ground truth file: {self.ground_truth_file}")
+            return []
+
+        metrics_list = [
+            {"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0}
+        ]
+        if self.data_type == "page_filter":
+            missing_error_list, metrics_list = self.get_page_filter_metrics()
+        elif self.data_type == "datapoint":
+            missing_error_list, metrics_list = self.get_datapoint_metrics()
+        else:
+            logger.error(f"Invalid data type: {self.data_type}")
+
+        missing_error_df = pd.DataFrame(missing_error_list)
+        missing_error_df.reset_index(drop=True, inplace=True)
+
+        metrics_df = pd.DataFrame(metrics_list)
+        metrics_df.reset_index(drop=True, inplace=True)
+
+        with pd.ExcelWriter(self.output_file) as writer:
+            missing_error_df.to_excel(writer, sheet_name="Missing_Error", index=False)
+            metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
+        return missing_error_list, metrics_list, self.output_file
+
+    def get_page_filter_metrics(self):
+        prediction_df = pd.read_excel(self.prediction_file, sheet_name=self.prediction_sheet_name)
+        ground_truth_df = pd.read_excel(self.ground_truth_file, sheet_name="Sheet1")
+        ground_truth_df = ground_truth_df[ground_truth_df["Checked"] == 1]
+
+        tor_true = []
+        tor_pred = []
+
+        ter_true = []
+        ter_pred = []
+
+        ogc_true = []
+        ogc_pred = []
+
+        performance_fee_true = []
+        performance_fee_pred = []
+
+        missing_error_list = []
+        data_point_list = ["tor", "ter", "ogc", "performance_fee"]
+
+        for index, row in ground_truth_df.iterrows():
+            doc_id = row["doc_id"]
+            # get first row with the same doc_id
+            prediction_data = prediction_df[prediction_df["doc_id"] == doc_id].iloc[0]
+            for data_point in data_point_list:
+                true_data, pred_data, missing_error_data = self.get_true_pred_data(
+                    doc_id, row, prediction_data, data_point
+                )
+                if data_point == "tor":
+                    tor_true.extend(true_data)
+                    tor_pred.extend(pred_data)
+                elif data_point == "ter":
+                    ter_true.extend(true_data)
+                    ter_pred.extend(pred_data)
+                elif data_point == "ogc":
+                    ogc_true.extend(true_data)
+                    ogc_pred.extend(pred_data)
+                elif data_point == "performance_fee":
+                    performance_fee_true.extend(true_data)
+                    performance_fee_pred.extend(pred_data)
+                missing_error_list.append(missing_error_data)
+
+        metrics_list = []
+        for data_point in data_point_list:
+            if data_point == "tor":
+                precision, recall, f1 = self.get_specific_metrics(tor_true, tor_pred)
+                metrics_list.append(
+                    {
+                        "Data_Point": data_point,
+                        "Precision": precision,
+                        "Recall": recall,
+                        "F1": f1,
+                        "Support": len(tor_true),
+                    }
+                )
+                logger.info(
+                    f"TOR Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {len(tor_true)}"
+                )
+            elif data_point == "ter":
+                precision, recall, f1 = self.get_specific_metrics(ter_true, ter_pred)
+                metrics_list.append(
+                    {
+                        "Data_Point": data_point,
+                        "Precision": precision,
+                        "Recall": recall,
+                        "F1": f1,
+                        "Support": len(ter_true),
+                    }
+                )
+                logger.info(
+                    f"TER Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {len(ter_true)}"
+                )
+            elif data_point == "ogc":
+                precision, recall, f1 = self.get_specific_metrics(ogc_true, ogc_pred)
+                metrics_list.append(
+                    {
+                        "Data_Point": data_point,
+                        "Precision": precision,
+                        "Recall": recall,
+                        "F1": f1,
+                        "Support": len(ogc_true),
+                    }
+                )
+                logger.info(
+                    f"OGC Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {len(ogc_true)}"
+                )
+            elif data_point == "performance_fee":
+                precision, recall, f1 = self.get_specific_metrics(
+                    performance_fee_true, performance_fee_pred
+                )
+                metrics_list.append(
+                    {
+                        "Data_Point": data_point,
+                        "Precision": precision,
+                        "Recall": recall,
+                        "F1": f1,
+                        "Support": len(performance_fee_true),
+                    }
+                )
+                logger.info(
+                    f"Performance Fee Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {len(performance_fee_true)}"
+                )
+
+        # get average metrics
+        precision_list = [metric["Precision"] for metric in metrics_list]
+        recall_list = [metric["Recall"] for metric in metrics_list]
+        f1_list = [metric["F1"] for metric in metrics_list]
+        metrics_list.append(
+            {
+                "Data_Point": "Average",
+                "Precision": sum(precision_list) / len(precision_list),
+                "Recall": sum(recall_list) / len(recall_list),
+                "F1": sum(f1_list) / len(f1_list),
+                "Support": sum([metric["Support"] for metric in metrics_list]),
+            }
+        )
+        return missing_error_list, metrics_list
+
+    def get_true_pred_data(
+        self, doc_id, ground_truth_data: pd.Series, prediction_data: pd.Series, data_point: str
+    ):
+        ground_truth_list = ground_truth_data[data_point]
+        if isinstance(ground_truth_list, str):
+            ground_truth_list = json.loads(ground_truth_list)
+        prediction_list = prediction_data[data_point]
+        if isinstance(prediction_list, str):
+            prediction_list = json.loads(prediction_list)
+
+        true_data = []
+        pred_data = []
+
+        missing_error_data = {"doc_id": doc_id, "data_point": data_point, "missing": [], "error": []}
+
+        missing_data = []
+        error_data = []
+
+        if len(ground_truth_list) == 0 and len(prediction_list) == 0:
+            true_data.append(1)
+            pred_data.append(1)
+            return true_data, pred_data, missing_error_data
+
+        for prediction in prediction_list:
+            if prediction in ground_truth_list:
+                true_data.append(1)
+                pred_data.append(1)
+            else:
+                true_data.append(0)
+                pred_data.append(1)
+                error_data.append(prediction)
+
+        for ground_truth in ground_truth_list:
+            if ground_truth not in prediction_list:
+                true_data.append(1)
+                pred_data.append(0)
+                missing_data.append(ground_truth)
+        missing_error_data = {"doc_id": doc_id, "data_point": data_point, "missing": missing_data, "error": error_data}
+
+        return true_data, pred_data, missing_error_data
+
+    def get_specific_metrics(self, true_data: list, pred_data: list):
+        precision = precision_score(true_data, pred_data)
+        recall = recall_score(true_data, pred_data)
+        f1 = f1_score(true_data, pred_data)
+        return precision, recall, f1
+
+    def get_datapoint_metrics(self):
+        pass
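The new Metrics class (imported below as core.metrics) compares, per document and per datapoint, the predicted page list against the ground-truth page list and encodes each comparison as a binary label pair: a correctly predicted page contributes (1, 1), a spurious page (0, 1), a missed page (1, 0), and a document where both lists are empty a single (1, 1). Precision, recall, and F1 then come straight from scikit-learn. A minimal worked example of that encoding (the page numbers are hypothetical):

from sklearn.metrics import precision_score, recall_score, f1_score

# Ground-truth pages for one datapoint: [3, 4]; predicted pages: [3, 5]
y_true = [1, 0, 1]  # page 3 correct, page 5 is a spurious prediction, page 4 was missed
y_pred = [1, 1, 0]

print(precision_score(y_true, y_pred))  # 0.5 -> one of the two predicted pages is correct
print(recall_score(y_true, y_pred))     # 0.5 -> one of the two ground-truth pages was found
print(f1_score(y_true, y_pred))         # 0.5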
@@ -22,7 +22,7 @@ class FilterPages:
         self.document_mapping_info_df = document_mapping_info_df
         self.get_configuration_from_file()
         self.doc_info = self.get_doc_info()
-        self.datapoint_config = self.get_datapoint_config()
+        self.datapoint_config, self.datapoint_exclude_config = self.get_datapoint_config()

     def get_pdf_page_text_dict(self) -> dict:
         pdf_util = PDFUtil(self.pdf_file)
@@ -33,12 +33,15 @@ class FilterPages:
         language_config_file = r"./configuration/language.json"
         domicile_datapoint_config_file = r"./configuration/domicile_datapoints.json"
         datapoint_keywords_config_file = r"./configuration/datapoint_keyword.json"
+        datapoint_exclude_keywords_config_file = r"./configuration/datapoint_exclude_keyword.json"
         with open(language_config_file, "r", encoding="utf-8") as file:
             self.language_config = json.load(file)
         with open(domicile_datapoint_config_file, "r", encoding="utf-8") as file:
             self.domicile_datapoint_config = json.load(file)
         with open(datapoint_keywords_config_file, "r", encoding="utf-8") as file:
             self.datapoint_keywords_config = json.load(file)
+        with open(datapoint_exclude_keywords_config_file, "r", encoding="utf-8") as file:
+            self.datapoint_exclude_keywords_config = json.load(file)

     def get_doc_info(self) -> dict:
         if len(self.document_mapping_info_df) == 0:
@@ -77,14 +80,26 @@ class FilterPages:
         if self.domicile_datapoint_config[domicile].get(document_type, None) is None:
             document_type = "ar"
         datapoint_list = self.domicile_datapoint_config[domicile][document_type]
+        datapoint_keywords = self.get_keywords("include", datapoint_list, language)
+        datapoint_exclude_keywords = self.get_keywords("exclude", datapoint_list, language)
+        return datapoint_keywords, datapoint_exclude_keywords
+
+    def get_keywords(self, keywords_type: str, datapoint_list: list, language: str) -> dict:
+        if keywords_type == "include":
+            config = self.datapoint_keywords_config
+        elif keywords_type == "exclude":
+            config = self.datapoint_exclude_keywords_config
+        else:
+            config = self.datapoint_keywords_config
         datapoint_keywords = {}

         for datapoint in datapoint_list:
-            keywords = self.datapoint_keywords_config.get(datapoint, {}).get(language, [])
+            keywords = config.get(datapoint, {}).get(language, [])
             if len(keywords) > 0:
                 keywords = self.optimize_keywords_regex(keywords)
                 datapoint_keywords[datapoint] = keywords
                 if language != "english":
-                    english_keywords = self.datapoint_keywords_config.get(datapoint, {}).get("english", [])
+                    english_keywords = config.get(datapoint, {}).get("english", [])
                     if len(english_keywords) > 0:
                         english_keywords = self.optimize_keywords_regex(english_keywords)
                         datapoint_keywords[datapoint] += english_keywords
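get_datapoint_config now builds two dictionaries from the same helper, one from the include-keyword config and one from the exclude-keyword config; for a non-English document, the English keywords are appended after the language-specific ones as a fallback. A rough sketch of the resulting shapes (the German phrases are hypothetical, not taken from the actual config files):

# In FilterPages.__init__ (as changed above), both dictionaries are unpacked:
self.datapoint_config, self.datapoint_exclude_config = self.get_datapoint_config()

# Hypothetical shapes for a German-language document:
# self.datapoint_config         -> {"ogc": ["laufende Kosten", "Ongoing Charges"], ...}
# self.datapoint_exclude_config -> {"ogc": ["operating expenses paid"]}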
@@ -110,15 +125,78 @@ class FilterPages:
         }
         """
         result = {"doc_id": self.doc_id}
+        result_details = []
         for datapoint in self.datapoint_config.keys():
             result[datapoint] = []
         for page_num, page_text in self.page_text_dict.items():
-            text = clean_text(page_text)
+            text = "\n" + clean_text(page_text) + "\n"
             for datapoint, keywords in self.datapoint_config.items():
                 # idx = idx & np.array([re.findall(r'\b' + word + r'\d*\b', page) != [] for page in self.pages_clean])
+                find_datapoint = False
                 for keyword in keywords:
-                    search_regex = r"\b{0}\d*\b\s*".format(keyword)
-                    if re.search(search_regex, text, re.IGNORECASE):
+                    search_iter = self.search_keyword(text, keyword)
+                    for search in search_iter:
+                        search_text = search.group().strip()
+                        exclude_search_list = self.search_exclude_keywords(text, datapoint)
+                        if exclude_search_list is not None:
+                            need_exclude = False
+                            for exclude_search_text in exclude_search_list:
+                                if search_text in exclude_search_text:
+                                    need_exclude = True
+                                    break
+                            if need_exclude:
+                                continue
+
+                        is_valid = self.search_in_sentence_is_valid(search_text, text)
+                        if not is_valid:
+                            continue
                         result[datapoint].append(page_num)
+                        detail = {
+                            "doc_id": self.doc_id,
+                            "datapoint": datapoint,
+                            "page_num": page_num,
+                            "keyword": keyword,
+                            "text": search_text
+                        }
+                        result_details.append(detail)
+                        find_datapoint = True
                         break
-        return result
+                    if find_datapoint:
+                        break
+        return result, result_details
+
+    def search_in_sentence_is_valid(self,
+                                    search_text: str,
+                                    text: str):
+        search_text_regex = add_slash_to_text_as_regex(search_text)
+        search_regex = r"\n.*{0}.*\n".format(search_text_regex)
+        search_iter = re.finditer(search_regex, text, re.IGNORECASE)
+        is_valid = False
+        lower_word_count_threshold = 7
+        for search in search_iter:
+            lower_word_count = 0
+            if search is not None:
+                # get the word count in search text which start with lower case
+                search_text = search.group().strip()
+                search_text_split = search_text.split()
+                for split in search_text_split:
+                    if split[0].islower():
+                        lower_word_count += 1
+                if lower_word_count < lower_word_count_threshold:
+                    is_valid = True
+                    break
+        return is_valid
+
+    def search_keyword(self, text: str, keyword: str):
+        search_regex = r"\b{0}\d*\W*\s*\b".format(keyword)
+        return re.finditer(search_regex, text, re.IGNORECASE)
+
+    def search_exclude_keywords(self, text: str, datapoint: str):
+        exclude_keywords = self.datapoint_exclude_config.get(datapoint, [])
+        search_list = []
+        for keyword in exclude_keywords:
+            search_iter = self.search_keyword(text, keyword)
+
+            for search in search_iter:
+                search_list.append(search.group())
+        return search_list
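Two new guards decide whether a keyword hit on a page counts. First, a hit that falls inside an exclude phrase for the same datapoint (e.g. "Operating expenses" inside "operating expenses paid") is skipped. Second, search_in_sentence_is_valid keeps a hit only if some line containing it has fewer than 7 words starting with a lowercase letter, i.e. the line looks like a table row or heading rather than running prose. A standalone sketch of that second heuristic (the example strings are hypothetical):

def looks_like_table_row(line: str, threshold: int = 7) -> bool:
    # Count words that start with a lowercase letter; prose sentences tend to have many.
    lower_word_count = sum(1 for word in line.split() if word[0].islower())
    return lower_word_count < threshold

print(looks_like_table_row("Ongoing Charges 0.85%"))                                       # True  -> keep the hit
print(looks_like_table_row("the ongoing charges are based on expenses incurred this year"))  # False -> drop the hit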
main.py
@@ -8,6 +8,7 @@ from utils.logger import logger
 from utils.pdf_download import download_pdf_from_documents_warehouse
 from utils.sql_query_util import query_document_fund_mapping
 from core.page_filter import FilterPages
+from core.metrics import Metrics


 class EMEA_AR_Parsing:
@@ -17,7 +18,7 @@ class EMEA_AR_Parsing:
         os.makedirs(self.pdf_folder, exist_ok=True)
         self.pdf_file = self.download_pdf()
         self.document_mapping_info_df = query_document_fund_mapping(doc_id)
-        self.datapoint_page_info = self.get_datapoint_page_info()
+        self.datapoint_page_info, self.result_details = self.get_datapoint_page_info()

     def download_pdf(self) -> str:
         pdf_file = download_pdf_from_documents_warehouse(self.pdf_folder, self.doc_id)
@@ -27,27 +28,52 @@ class EMEA_AR_Parsing:
         filter_pages = FilterPages(
             self.doc_id, self.pdf_file, self.document_mapping_info_df
         )
-        datapoint_page_info = filter_pages.start_job()
-        return datapoint_page_info
+        datapoint_page_info, result_details = filter_pages.start_job()
+        return datapoint_page_info, result_details


 def filter_pages(doc_id: str, pdf_folder: str) -> None:
     logger.info(f"Parsing EMEA AR for doc_id: {doc_id}")
     emea_ar_parsing = EMEA_AR_Parsing(doc_id, pdf_folder)
-    return emea_ar_parsing.datapoint_page_info
+    return emea_ar_parsing.datapoint_page_info, emea_ar_parsing.result_details


-def batch_filter_pdf_files(pdf_folder: str, output_folder: str) -> None:
+def batch_filter_pdf_files(
+    pdf_folder: str,
+    doc_data_excel_file: str = None,
+    output_folder: str = r"/data/emea_ar/output/filter_pages/",
+    special_doc_id_list: list = None,
+) -> None:
     pdf_files = glob(pdf_folder + "*.pdf")
+    doc_list = []
+    if special_doc_id_list is not None and len(special_doc_id_list) > 0:
+        doc_list = special_doc_id_list
+    if (
+        len(doc_list) == 0
+        and doc_data_excel_file is not None
+        and len(doc_data_excel_file) > 0
+        and os.path.exists(doc_data_excel_file)
+    ):
+        doc_data_df = pd.read_excel(doc_data_excel_file)
+        doc_data_df = doc_data_df[doc_data_df["Checked"] == 1]
+        doc_list = [str(doc_id) for doc_id in doc_data_df["doc_id"].tolist()]
     result_list = []
+    result_details = []
     for pdf_file in tqdm(pdf_files):
         pdf_base_name = os.path.basename(pdf_file)
         doc_id = pdf_base_name.split(".")[0]
-        datapoint_page_info = filter_pages(doc_id=doc_id, pdf_folder=pdf_folder)
-        result_list.append(datapoint_page_info)
+        if doc_list is not None and doc_id not in doc_list:
+            continue
+        doc_datapoint_page_info, doc_result_details = filter_pages(doc_id=doc_id, pdf_folder=pdf_folder)
+        result_list.append(doc_datapoint_page_info)
+        result_details.extend(doc_result_details)

     result_df = pd.DataFrame(result_list)
     result_df.reset_index(drop=True, inplace=True)

+    result_details_df = pd.DataFrame(result_details)
+    result_details_df.reset_index(drop=True, inplace=True)
+
     logger.info(f"Saving the result to {output_folder}")
     os.makedirs(output_folder, exist_ok=True)
     time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime())
@@ -56,10 +82,54 @@ def batch_filter_pdf_files(pdf_folder: str, output_folder: str) -> None:
         f"datapoint_page_info_{len(result_df)}_documents_{time_stamp}.xlsx",
     )
     with pd.ExcelWriter(output_file) as writer:
-        result_df.to_excel(writer, index=False)
+        result_df.to_excel(writer, index=False, sheet_name="dp_page_info")
+        result_details_df.to_excel(writer, index=False, sheet_name="dp_page_info_details")
+
+    if len(special_doc_id_list) == 0:
+        logger.info(f"Calculating metrics for {output_file}")
+        metrics_output_folder = r"/data/emea_ar/output/metrics/"
+        missing_error_list, metrics_list, metrics_file = get_metrics(
+            data_type="page_filter",
+            prediction_file=output_file,
+            prediction_sheet_name="dp_page_info",
+            ground_truth_file=doc_data_excel_file,
+            output_folder=metrics_output_folder,
+        )
+        return missing_error_list, metrics_list, metrics_file
+
+
+def get_metrics(
+    data_type: str,
+    prediction_file: str,
+    prediction_sheet_name: str,
+    ground_truth_file: str,
+    output_folder: str = None
+) -> None:
+    metrics = Metrics(
+        data_type=data_type,
+        prediction_file=prediction_file,
+        prediction_sheet_name=prediction_sheet_name,
+        ground_truth_file=ground_truth_file,
+        output_folder=output_folder
+    )
+    missing_error_list, metrics_list, metrics_file = metrics.get_metrics()
+    return missing_error_list, metrics_list, metrics_file
+
+
 if __name__ == "__main__":
     pdf_folder = r"/data/emea_ar/small_pdf/"
-    output_folder = r"/data/emea_ar/output/filter_pages/"
-    batch_filter_pdf_files(pdf_folder, output_folder)
+    page_filter_ground_truth_file = (
+        r"/data/emea_ar/ground_truth/page_filter/datapoint_page_info_88_documents.xlsx"
+    )
+    prediction_output_folder = r"/data/emea_ar/output/filter_pages/"
+    metrics_output_folder = r"/data/emea_ar/output/metrics/"
+    special_doc_id_list = []
+    batch_filter_pdf_files(
+        pdf_folder, page_filter_ground_truth_file, prediction_output_folder, special_doc_id_list
+    )
+
+    # data_type = "page_filter"
+    # prediction_file = r"/data/emea_ar/output/filter_pages/datapoint_page_info_73_documents_20240903145002.xlsx"
+    # missing_error_list, metrics_list, metrics_file = get_metrics(
+    #     data_type, prediction_file, page_filter_ground_truth_file, metrics_output_folder
+    # )
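With these changes, batch_filter_pdf_files writes both a summary sheet ("dp_page_info") and a per-hit detail sheet ("dp_page_info_details"), and when special_doc_id_list is empty it immediately scores the run against the ground-truth workbook. The same scoring can also be run on an existing prediction file via the new get_metrics helper; a sketch (the path and timestamp are placeholders, not real files):

missing_error_list, metrics_list, metrics_file = get_metrics(
    data_type="page_filter",
    prediction_file=r"/data/emea_ar/output/filter_pages/datapoint_page_info_88_documents_<timestamp>.xlsx",
    prediction_sheet_name="dp_page_info",
    ground_truth_file=r"/data/emea_ar/ground_truth/page_filter/datapoint_page_info_88_documents.xlsx",
    output_folder=r"/data/emea_ar/output/metrics/",
)
print(metrics_file)  # metrics_page_filter_<timestamp>.xlsx with "Missing_Error" and "Metrics" sheets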
@@ -15,8 +15,8 @@ def add_slash_to_text_as_regex(text: str):


 def clean_text(text: str) -> str:
-    text = text.lower()
+    # text = text.lower()
     # update the specical character which begin with \u, e.g \u2004 or \u00a0 to be space
-    text = re.sub(r"\\u[0-9a-z]{4}", ' ', text)
+    text = re.sub(r"\\u[A-Z0-9a-z]{4}", ' ', text)
     text = re.sub(r"( ){2,}", ' ', text.strip())
     return text
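clean_text no longer lowercases the page text, since the new case-based prose check in search_in_sentence_is_valid relies on the original capitalization, and the escape-sequence pattern now also matches uppercase hex digits such as \u00A0. A small illustration of the widened pattern (the sample string is hypothetical):

import re

raw = r"Ongoing\u00A0Charges"
print(re.sub(r"\\u[0-9a-z]{4}", " ", raw))     # unchanged: 'A' is not matched by the lowercase-only class
print(re.sub(r"\\u[A-Z0-9a-z]{4}", " ", raw))  # 'Ongoing Charges'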