realize to calculate data extraction metrics.

This commit is contained in:
Blade He 2024-09-18 17:10:54 -05:00
parent 50e6c3c19d
commit 98e86a6cfd
5 changed files with 305 additions and 66 deletions

View File

@ -3,6 +3,7 @@ import pandas as pd
import time
import json
from sklearn.metrics import precision_score, recall_score, f1_score
from utils.biz_utils import get_unique_words_text
from utils.logger import logger
@ -13,12 +14,14 @@ class Metrics:
prediction_file: str,
prediction_sheet_name: str = "Sheet1",
ground_truth_file: str = None,
ground_truth_sheet_name: str = "Sheet1",
output_folder: str = None,
) -> None:
self.data_type = data_type
self.prediction_file = prediction_file
self.prediction_sheet_name = prediction_sheet_name
self.ground_truth_file = ground_truth_file
self.ground_truth_sheet_name = ground_truth_sheet_name
if output_folder is None or len(output_folder) == 0:
output_folder = r"/data/emea_ar/output/metrics/"
@ -49,12 +52,8 @@ class Metrics:
metrics_list = [
{"Data_Point": "NAN", "Precision": 0, "Recall": 0, "F1": 0, "Support": 0}
]
if self.data_type == "page_filter":
missing_error_list, metrics_list = self.get_page_filter_metrics()
elif self.data_type == "datapoint":
missing_error_list, metrics_list = self.get_datapoint_metrics()
else:
logger.error(f"Invalid data type: {self.data_type}")
missing_error_list, metrics_list = self.get_metrics()
missing_error_df = pd.DataFrame(missing_error_list)
missing_error_df.reset_index(drop=True, inplace=True)
@ -67,9 +66,13 @@ class Metrics:
metrics_df.to_excel(writer, sheet_name="Metrics", index=False)
return missing_error_list, metrics_list, self.output_file
def get_page_filter_metrics(self):
prediction_df = pd.read_excel(self.prediction_file, sheet_name=self.prediction_sheet_name)
ground_truth_df = pd.read_excel(self.ground_truth_file, sheet_name="Sheet1")
def get_metrics(self):
prediction_df = pd.read_excel(
self.prediction_file, sheet_name=self.prediction_sheet_name
)
ground_truth_df = pd.read_excel(
self.ground_truth_file, sheet_name=self.ground_truth_sheet_name
)
ground_truth_df = ground_truth_df[ground_truth_df["Checked"] == 1]
tor_true = []
@ -87,14 +90,19 @@ class Metrics:
missing_error_list = []
data_point_list = ["tor", "ter", "ogc", "performance_fee"]
if self.data_type == "page_filter":
for index, row in ground_truth_df.iterrows():
doc_id = row["doc_id"]
# get first row with the same doc_id
prediction_data = prediction_df[prediction_df["doc_id"] == doc_id].iloc[0]
prediction_data = prediction_df[prediction_df["doc_id"] == doc_id].iloc[
0
]
for data_point in data_point_list:
true_data, pred_data, missing_error_data = self.get_true_pred_data(
true_data, pred_data, missing_error_data = (
self.get_page_filter_true_pred_data(
doc_id, row, prediction_data, data_point
)
)
if data_point == "tor":
tor_true.extend(true_data)
tor_pred.extend(pred_data)
@ -108,6 +116,35 @@ class Metrics:
performance_fee_true.extend(true_data)
performance_fee_pred.extend(pred_data)
missing_error_list.append(missing_error_data)
else:
prediction_doc_id_list = prediction_df["doc_id"].unique().tolist()
ground_truth_doc_id_list = ground_truth_df["doc_id"].unique().tolist()
doc_id_list = list(set(prediction_doc_id_list + ground_truth_doc_id_list))
# order by doc_id
doc_id_list.sort()
for doc_id in doc_id_list:
prediction_data = prediction_df[prediction_df["doc_id"] == doc_id]
ground_truth_data = ground_truth_df[ground_truth_df["doc_id"] == doc_id]
for data_point in data_point_list:
true_data, pred_data, missing_error_data = (
self.get_data_extraction_true_pred_data(
doc_id, ground_truth_data, prediction_data, data_point
)
)
if data_point == "tor":
tor_true.extend(true_data)
tor_pred.extend(pred_data)
elif data_point == "ter":
ter_true.extend(true_data)
ter_pred.extend(pred_data)
elif data_point == "ogc":
ogc_true.extend(true_data)
ogc_pred.extend(pred_data)
elif data_point == "performance_fee":
performance_fee_true.extend(true_data)
performance_fee_pred.extend(pred_data)
missing_error_list.extend(missing_error_data)
metrics_list = []
for data_point in data_point_list:
@ -193,9 +230,12 @@ class Metrics:
# get the count which true_data is 1
return sum(true_data)
def get_true_pred_data(
self, doc_id, ground_truth_data: pd.Series, prediction_data: pd.Series, data_point: str
def get_page_filter_true_pred_data(
self,
doc_id,
ground_truth_data: pd.Series,
prediction_data: pd.Series,
data_point: str,
):
ground_truth_list = ground_truth_data[data_point]
if isinstance(ground_truth_list, str):
@ -207,8 +247,12 @@ class Metrics:
true_data = []
pred_data = []
missing_error_data = {"doc_id": doc_id, "data_point": data_point, "missing": [], "error": []}
missing_error_data = {
"doc_id": doc_id,
"data_point": data_point,
"missing": [],
"error": [],
}
missing_data = []
error_data = []
@ -232,7 +276,114 @@ class Metrics:
true_data.append(1)
pred_data.append(0)
missing_data.append(ground_truth)
missing_error_data = {"doc_id": doc_id, "data_point": data_point, "missing": missing_data, "error": error_data}
missing_error_data = {
"doc_id": doc_id,
"data_point": data_point,
"missing": missing_data,
"error": error_data,
}
return true_data, pred_data, missing_error_data
def get_data_extraction_true_pred_data(
self,
doc_id,
ground_truth_data: pd.DataFrame,
prediction_data: pd.DataFrame,
data_point: str,
):
dp_ground_truth = ground_truth_data[
ground_truth_data["datapoint"] == data_point
]
dp_prediction = prediction_data[prediction_data["datapoint"] == data_point]
# add new column to store unique words for dp_ground_truth
dp_ground_truth["unique_words"] = dp_ground_truth["raw_name"].apply(
get_unique_words_text
)
ground_truth_unique_words = dp_ground_truth["unique_words"].unique().tolist()
# add new column to store unique words for dp_prediction
dp_prediction["unique_words"] = dp_prediction["raw_name"].apply(
get_unique_words_text
)
pred_unique_words = dp_prediction["unique_words"].unique().tolist()
true_data = []
pred_data = []
missing_error_data = []
if len(dp_ground_truth) == 0 and len(dp_prediction) == 0:
true_data.append(1)
pred_data.append(1)
return true_data, pred_data, missing_error_data
for index, prediction in dp_prediction.iterrows():
pred_page_index = prediction["page_index"]
pred_raw_name = prediction["raw_name"]
pred_unique_words = prediction["unique_words"]
pred_data_point_value = prediction["value"]
pred_investment_type = prediction["investment_type"]
if pred_unique_words in ground_truth_unique_words:
# get the ground truth data with the same unique words
gt_data = dp_ground_truth[
dp_ground_truth["unique_words"] == pred_unique_words
].iloc[0]
gt_data_point_value = gt_data["value"]
if pred_data_point_value == gt_data_point_value:
true_data.append(1)
pred_data.append(1)
else:
true_data.append(0)
pred_data.append(1)
error_data = {
"doc_id": doc_id,
"data_point": data_point,
"page_index": pred_page_index,
"pred_raw_name": pred_raw_name,
"investment_type": pred_investment_type,
"error_type": "data value incorrect",
"error_value": pred_data_point_value,
"correct_value": gt_data_point_value,
}
missing_error_data.append(error_data)
else:
true_data.append(0)
pred_data.append(1)
error_data = {
"doc_id": doc_id,
"data_point": data_point,
"page_index": pred_page_index,
"pred_raw_name": pred_raw_name,
"investment_type": pred_investment_type,
"error_type": "raw name incorrect",
"error_value": pred_raw_name,
"correct_value": "",
}
missing_error_data.append(error_data)
for index, ground_truth in dp_ground_truth.iterrows():
gt_page_index = ground_truth["page_index"]
gt_raw_name = ground_truth["raw_name"]
gt_unique_words = ground_truth["unique_words"]
gt_data_point_value = ground_truth["value"]
gt_investment_type = ground_truth["investment_type"]
if gt_unique_words not in pred_unique_words:
true_data.append(1)
pred_data.append(0)
error_data = {
"doc_id": doc_id,
"data_point": data_point,
"page_index": gt_page_index,
"pred_raw_name": "",
"investment_type": gt_investment_type,
"error_type": "raw name missing",
"error_value": pred_data_point_value,
"correct_value": gt_raw_name,
}
missing_error_data.append(error_data)
return true_data, pred_data, missing_error_data

View File

@ -8,10 +8,20 @@
},
"data_business_features": {
"common": [
"Most of cases, the data is in the table(s) of context.",
"Fund name: a. The full fund name should be main fund name + sub-fund name, e,g, main fund name is Black Rock European, sub-fund name is Growth, the full fund name is: Black Rock European Growth.\nb. The sub-fund name may be as the first column values in the table.",
"General rules:",
"- Most of cases, the data is in the table(s) of context.",
"- Fund name: ",
"a. The full fund name should be main fund name + sub-fund name, e,g, main fund name is Black Rock European, sub-fund name is Growth, the full fund name is: Black Rock European Growth.",
"b. The sub-fund name may be as the first column or first row values in the table.",
"b.1 fund name example:",
"- context:",
"Summary information\nCapital International Fund Audited Annual Report 2023 | 15\nFootnotes are on page 17.\nCapital Group Multi-Sector \nIncome Fund (LUX) \n(CGMSILU)\nCapital Group US High Yield \nFund (LUX) (CGUSHYLU)\nCapital Group Emerging \nMarkets Debt Fund (LUX) \n(CGEMDLU)",
"fund names: Capital International Group Multi-Sector Income Fund (LUX), Capital International Group US High Yield Fund (LUX), Capital International Group Emerging Markets Debt Fund (LUX)",
"- Only extract the latest data from context:",
"If with multiple data values in same row, please extract the latest.",
"Only output the values which with significant reported names.\nPlease exclude below reported names and relevant values: \"Management Fees\", \"Management\", \"Management Fees p.a.\", \"Taxe d Abonnement in % p.a.\".\nDON'T EXTRACT MANAGEMENT FEES!",
"- Reported names:",
"Only output the values which with significant reported names.",
"Please exclude below reported names and relevant values: \"Management Fees\", \"Management\", \"Management Fees p.a.\", \"Taxe d Abonnement in % p.a.\".\nDON'T EXTRACT MANAGEMENT FEES!",
"One fund could be with multiple share classes and relevant share class level data values."
],
"investment_level": {
@ -106,7 +116,7 @@
"Only output the data point which with relevant value.",
"Don't ignore the data point which with negative value, e.g. -0.12, -1.13",
"Don't ignore the data point which with explicit zero value, e.g. 0, 0.00",
"Ignore the data point which with -, N/A, N/A%, N/A %, NONE, etc.",
"Ignore the data point which value with -, *, **, N/A, N/A%, N/A %, NONE, etc.",
"Fund level data: (\"fund name\" and \"TOR\") and share level data: (\"fund name\", \"share name\", \"ter\", \"performance fees\", \"ogc\") should be output separately.",
"The output should be JSON format, the format is like below example(s):"
],

10
main.py
View File

@ -217,6 +217,7 @@ def batch_start_job(
special_doc_id_list: list = None,
re_run_extract_data: bool = False,
re_run_mapping_data: bool = False,
force_save_total_data: bool = False,
):
pdf_files = glob(pdf_folder + "*.pdf")
doc_list = []
@ -250,7 +251,7 @@ def batch_start_job(
result_extract_data_list.extend(doc_data_from_gpt)
result_mapping_data_list.extend(doc_mapping_data_list)
if special_doc_id_list is None or len(special_doc_id_list) == 0:
if force_save_total_data or (special_doc_id_list is None or len(special_doc_id_list) == 0):
result_extract_data_df = pd.DataFrame(result_extract_data_list)
result_extract_data_df.reset_index(drop=True, inplace=True)
@ -505,10 +506,14 @@ if __name__ == "__main__":
# doc_id = "476492237"
# extract_data(doc_id, pdf_folder, output_extract_data_child_folder, re_run)
special_doc_id_list = ["508854243"]
special_doc_id_list = [
"525574973",
]
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
re_run_mapping_data = True
force_save_total_data = False
batch_start_job(
pdf_folder,
page_filter_ground_truth_file,
@ -519,4 +524,5 @@ if __name__ == "__main__":
special_doc_id_list,
re_run_extract_data,
re_run_mapping_data,
force_save_total_data=force_save_total_data,
)

View File

@ -113,7 +113,11 @@ def analyze_json_error():
def statistics_document(
pdf_folder: str, doc_mapping_file_path: str, output_folder: str
pdf_folder: str,
doc_mapping_file_path: str,
sheet_name: str = "all_data",
output_folder: str = "/data/emea_ar/basic_information/English/",
output_file: str = "doc_mapping_statistics_data.xlsx"
):
if pdf_folder is None or len(pdf_folder) == 0 or not os.path.exists(pdf_folder):
logger.error(f"Invalid pdf_folder: {pdf_folder}")
@ -132,7 +136,7 @@ def statistics_document(
describe_stat_df_list = []
# statistics document mapping information
doc_mapping_data = pd.read_excel(doc_mapping_file_path, sheet_name="all_data")
doc_mapping_data = pd.read_excel(doc_mapping_file_path, sheet_name=sheet_name)
# statistics doc_mapping_data for counting FundId count based on DocumentId
logger.info(
@ -172,15 +176,15 @@ def statistics_document(
)
describe_stat_df_list.append(doc_share_class_count_stat_df)
# statistics doc_mapping_data for counting FundId count based on ProviderCompanyId and CompanyName
# statistics doc_mapping_data for counting FundId count based on CompanyId and CompanyName
logger.info(
"statistics doc_mapping_data for counting FundId count based on ProviderCompanyId and CompanyName"
"statistics doc_mapping_data for counting FundId count based on CompanyId and CompanyName"
)
provider_fund_id_df = doc_mapping_data[
["ProviderCompanyId", "CompanyName", "FundId"]
["CompanyId", "CompanyName", "FundId"]
].drop_duplicates()
provider_fund_count = (
provider_fund_id_df.groupby(["ProviderCompanyId", "CompanyName"])
provider_fund_id_df.groupby(["CompanyId", "CompanyName"])
.size()
.reset_index(name="fund_count")
)
@ -194,15 +198,15 @@ def statistics_document(
)
describe_stat_df_list.append(provider_fund_count_stat_df)
# statistics doc_mapping_data for counting FundClassId count based on ProviderCompanyId
# statistics doc_mapping_data for counting FundClassId count based on CompanyId
logger.info(
"statistics doc_mapping_data for counting FundClassId count based on ProviderCompanyId"
"statistics doc_mapping_data for counting FundClassId count based on CompanyId"
)
provider_share_class_id_df = doc_mapping_data[
["ProviderCompanyId", "CompanyName", "FundClassId"]
["CompanyId", "CompanyName", "FundClassId"]
].drop_duplicates()
provider_share_class_count = (
provider_share_class_id_df.groupby(["ProviderCompanyId", "CompanyName"])
provider_share_class_id_df.groupby(["CompanyId", "CompanyName"])
.size()
.reset_index(name="share_class_count")
)
@ -238,13 +242,18 @@ def statistics_document(
)
describe_stat_df_list.append(fund_share_class_count_stat_df)
stat_file = os.path.join(output_folder, "doc_mapping_statistics_data.xlsx")
stat_file = os.path.join(output_folder, output_file)
doc_id_list = [str(docid) for docid in doc_mapping_data["DocumentId"].unique().tolist()]
# statistics document page number
pdf_files = glob(os.path.join(pdf_folder, "*.pdf"))
logger.info(f"Total {len(pdf_files)} pdf files found in {pdf_folder}")
logger.info("statistics document page number")
doc_page_num_list = []
for pdf_file in tqdm(pdf_files):
pdf_base_name = os.path.basename(pdf_file).replace(".pdf", "")
if pdf_base_name not in doc_id_list:
continue
docid = os.path.basename(pdf_file).split(".")[0]
doc = fitz.open(pdf_file)
page_num = doc.page_count
@ -829,6 +838,46 @@ def pickup_document_from_top_100_providers():
)
def compare_records_count_by_document_id():
data_from_document = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx"
sheet_name = "mapping_data"
data_from_document_df = pd.read_excel(data_from_document, sheet_name=sheet_name)
data_from_document_df.rename(
columns={"doc_id": "DocumentId"}, inplace=True
)
# get the count of records by DocumentId
document_records_count = data_from_document_df.groupby("DocumentId").size().reset_index(name="records_count")
data_from_database = r"/data/emea_ar/basic_information/English/lux_english_ar_top_100_provider_random_small_document.xlsx"
sheet_name = "random_small_document_all_data"
data_from_database_df = pd.read_excel(data_from_database, sheet_name=sheet_name)
database_records_count = data_from_database_df.groupby("DocumentId").size().reset_index(name="records_count")
# merge document_records_count with database_records_count
records_count_compare = pd.merge(
document_records_count,
database_records_count,
on=["DocumentId"],
how="left",
)
records_count_compare["records_count_diff"] = records_count_compare["records_count_x"] - records_count_compare["records_count_y"]
records_count_compare = records_count_compare.sort_values(by="records_count_diff", ascending=False)
# rename records_count_x to records_count_document, records_count_y to records_count_database
records_count_compare.rename(
columns={"records_count_x": "records_count_document",
"records_count_y": "records_count_database"}, inplace=True
)
records_count_compare.reset_index(drop=True, inplace=True)
records_count_compare_file = (
r"/data/emea_ar/basic_information/English/records_count_compare_between_document_database.xlsx"
)
with pd.ExcelWriter(records_count_compare_file) as writer:
records_count_compare.to_excel(
writer, sheet_name="records_count_compare", index=False
)
if __name__ == "__main__":
doc_provider_file_path = (
r"/data/emea_ar/basic_information/English/latest_provider_ar_document.xlsx"
@ -845,22 +894,35 @@ if __name__ == "__main__":
output_folder = r"/data/emea_ar/output/"
# get_unique_docids_from_doc_provider_data(doc_provider_file_path)
# download_pdf(doc_provider_file_path, 'doc_provider_count', pdf_folder)
pdf_folder = r"/data/emea_ar/small_pdf/"
# pdf_folder = r"/data/emea_ar/small_pdf/"
output_folder = r"/data/emea_ar/small_pdf_txt/"
random_small_document_data_file = (
r"/data/emea_ar/basic_information/English/lux_english_ar_top_100_provider_random_small_document.xlsx"
)
# download_pdf(random_small_document_data_file, 'random_small_document', pdf_folder)
# output_pdf_page_text(pdf_folder, output_folder)
# extract_pdf_table(pdf_folder, output_folder)
# analyze_json_error()
# statistics_document(pdf_folder, doc_mapping_file_path, basic_info_folder)
latest_top_100_provider_ar_data_file = r"/data/emea_ar/basic_information/English/top_100_provider_latest_document_most_mapping/lux_english_ar_from_top_100_provider_latest_document_with_most_mappings.xlsx"
# download_pdf(latest_top_100_provider_ar_data_file,
# 'latest_ar_document_most_mapping',
# pdf_folder)
output_data_folder = r"/data/emea_ar/basic_information/English/top_100_provider_latest_document_most_mapping/"
statistics_document(pdf_folder=pdf_folder,
doc_mapping_file_path=latest_top_100_provider_ar_data_file,
sheet_name="latest_doc_ar_data",
output_folder=output_data_folder,
output_file="latest_doc_ar_mapping_statistics.xlsx")
# statistics_provider_mapping(
# provider_mapping_data_file=provider_mapping_data_file,
# output_folder=basic_info_folder,
# )
# statistics_document_fund_share_count(doc_mapping_from_top_100_provider_file)
pickup_document_from_top_100_providers()
# pickup_document_from_top_100_providers()
# compare_records_count_by_document_id()

View File

@ -165,6 +165,16 @@ def remove_special_characters(text):
text = text.strip()
return text
def get_unique_words_text(text):
text = remove_special_characters(text)
text = text.lower()
text_split = text.split()
text_split = list(set(text_split))
# sort the list
text_split.sort()
return_text = ' '.join(text_split)
return return_text
def remove_numeric_characters(text):
# remove numeric characters