dc-ml-emea-ar/calc_metrics.py

1376 lines
89 KiB
Python
Raw Normal View History

2025-03-05 15:57:02 +00:00
import os
from time import sleep
import pandas as pd
from glob import glob
from tqdm import tqdm
import numpy as np
from datetime import datetime
import re
import json
import traceback
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import requests
import fitz
from utils.similarity import Similarity
def calc_metrics(ground_truth_file: str, prediction_file: str):
    """
    Calculate AUUM and TOR metrics by comparing ground truth and prediction files.

    Both files are Excel workbooks with at least the columns ``fund_name``,
    ``auum`` and ``tor``. Rows are matched on ``fund_name``; when a fund name
    occurs several times in the other file, only its first row is used.

    * Recall side: every ground truth row adds gt=1, and pred=1 when the
      matched prediction row carries the same value (pred=0 otherwise).
    * Precision side: every prediction row whose value is not found in the
      matched ground truth row adds gt=0, pred=1.

    Values are compared with ``==``, so a NaN never matches anything.

    :param ground_truth_file: str, path of the ground truth Excel file.
    :param prediction_file: str, path of the prediction Excel file.
    :return: dict, ``{"auum": {...}, "tor": {...}}`` where each inner dict
        holds ``support``, ``f1``, ``precision``, ``recall`` and ``accuracy``.
    :raises FileNotFoundError: if either file does not exist.
    """
    if not os.path.exists(ground_truth_file):
        raise FileNotFoundError(f"File not found: {ground_truth_file}")
    if not os.path.exists(prediction_file):
        raise FileNotFoundError(f"File not found: {prediction_file}")
    ground_truth_df = pd.read_excel(ground_truth_file)
    prediction_df = pd.read_excel(prediction_file)

    def _rows(df):
        # (fund_name, auum, tor) tuples in file order.
        return list(zip(df["fund_name"], df["auum"], df["tor"]))

    def _first_row_by_fund(rows):
        # fund_name -> (auum, tor) of its FIRST row. This replaces a nested
        # O(n*m) scan that stopped at the first row with an equal fund name.
        # A name that is not equal to itself (NaN) can never match via ``==``,
        # so it is left out of the lookup.
        lookup = {}
        for fund_name, auum, tor in rows:
            if fund_name == fund_name:
                lookup.setdefault(fund_name, (auum, tor))
        return lookup

    gt_rows = _rows(ground_truth_df)
    pred_rows = _rows(prediction_df)
    gt_lookup = _first_row_by_fund(gt_rows)
    pred_lookup = _first_row_by_fund(pred_rows)

    gt_auum_list, pred_auum_list = [], []
    gt_tor_list, pred_tor_list = [], []

    # Recall side: is each ground truth value present in the prediction file?
    for fund_name, auum, tor in gt_rows:
        matched = pred_lookup.get(fund_name)
        find_auum = matched is not None and bool(auum == matched[0])
        find_tor = matched is not None and bool(tor == matched[1])
        gt_auum_list.append(1)
        pred_auum_list.append(1 if find_auum else 0)
        gt_tor_list.append(1)
        pred_tor_list.append(1 if find_tor else 0)

    # Precision side: predicted values absent from the ground truth are false positives.
    for fund_name, auum, tor in pred_rows:
        matched = gt_lookup.get(fund_name)
        if matched is None or not bool(auum == matched[0]):
            gt_auum_list.append(0)
            pred_auum_list.append(1)
        if matched is None or not bool(tor == matched[1]):
            gt_tor_list.append(0)
            pred_tor_list.append(1)

    metrics = {}
    for label, gt_list, pred_list in (("auum", gt_auum_list, pred_auum_list),
                                      ("tor", gt_tor_list, pred_tor_list)):
        metrics[label] = {
            "support": sum(gt_list),
            "f1": f1_score(gt_list, pred_list),
            "precision": precision_score(gt_list, pred_list),
            "recall": recall_score(gt_list, pred_list),
            "accuracy": accuracy_score(gt_list, pred_list),
        }

    auum_metrics = metrics["auum"]
    tor_metrics = metrics["tor"]
    print(f"AUUM Support: {auum_metrics['support']}")
    print(f"F1 AUUM: {auum_metrics['f1']}")
    print(f"Precision AUUM: {auum_metrics['precision']}")
    print(f"Recall AUUM: {auum_metrics['recall']}")
    print(f"Accuracy AUUM: {auum_metrics['accuracy']}\n")
    print(f"TOR Support: {tor_metrics['support']}")
    print(f"F1 TOR: {tor_metrics['f1']}")
    print(f"Precision TOR: {tor_metrics['precision']}")
    print(f"Recall TOR: {tor_metrics['recall']}")
    print(f"Accuracy TOR: {tor_metrics['accuracy']}")
    return metrics
def transform_pdf_2_image(folder: str = r"/Users/bhe/OneDrive - MORNINGSTAR INC/Personal Document/US_Life/pay/",
                          pdf_file: str = r"Pay_Date_2025-02-14.pdf",
                          dpi: int = 300):
    """
    Transform pdf to image: render every page of a PDF to a PNG file.

    Page ``i`` (0-based) of ``<folder>/<pdf_file>`` is saved as
    ``<folder>/<pdf_file without ".pdf">_<i>.png``.

    :param folder: str, folder containing the PDF; images are written there too.
    :param pdf_file: str, file name of the PDF inside ``folder``.
    :param dpi: int, rendering resolution passed to ``Page.get_pixmap``.
    :return: list of str, paths of the written images in page order.
    """
    import fitz
    pdf_path = os.path.join(folder, pdf_file)
    pdf_file_pure_name = pdf_file.replace(".pdf", "")
    image_paths = []
    # Use the document as a context manager so the file handle is released.
    with fitz.open(pdf_path) as pdf_doc:
        for page_num in range(pdf_doc.page_count):
            page = pdf_doc.load_page(page_num)
            image = page.get_pixmap(dpi=dpi)
            image_path = os.path.join(folder, f"{pdf_file_pure_name}_{page_num}.png")
            image.save(image_path)
            image_paths.append(image_path)
    return image_paths
def invoke_api_demo(doc_id: str = "407881493",
                    url: str = "http://127.0.0.1:8080/automation/api/model/emea_ar",
                    data_folder: str = r"/data/emea_ar/output/extract_data_by_api/",
                    timeout: float = None):
    """
    Call the EMEA AR extraction API for one document and save the JSON response.

    The request body is ``{"doc_id": doc_id}``. The parsed response is printed
    and written to ``<data_folder>/<doc_id>.json`` only when the API answers
    with a 2xx status, so an error payload never overwrites a previously saved
    result. Any exception (network error, invalid JSON, file error) is printed
    instead of raised, so batch callers keep going.

    :param doc_id: str, the document id to process.
    :param url: str, the API endpoint.
        Staging: https://internal-ts00006-stg-dcms-gpt-765982576.us-east-1.elb.amazonaws.com/automation/api/model/us_ar
    :param data_folder: str, folder where the JSON response is saved (created if missing).
    :param timeout: float, seconds passed to ``requests.post``; None (default)
        waits indefinitely, as before.
    :return: the parsed JSON response, or None when the call failed or the
        status was not 2xx.
    """
    headers = {"connection": "keep-alive", "content-type": "application/json"}
    data = {
        "doc_id": doc_id,
    }
    print(f"Start to invoke API for document: {doc_id}")
    try:
        response = requests.post(url, json=data, headers=headers, timeout=timeout)
        print("API response status code: {0}".format(response.status_code))
        json_data = json.loads(response.text)
        print(json_data)
        if not response.ok:
            # Do not clobber an earlier good result with an error payload.
            print(f"Skip saving result for document {doc_id}: status code {response.status_code}")
            return None
        os.makedirs(data_folder, exist_ok=True)
        json_file = os.path.join(data_folder, f"{doc_id}.json")
        with open(json_file, "w", encoding="utf-8") as f:
            json.dump(json_data, f, indent=4)
        return json_data
    except Exception as e:
        print("Meet exception: {0}".format(e))
        return None
def batch_run_documents(document_id_list: list = None):
    """
    Invoke the extraction API (``invoke_api_demo``) for each document, one at a time.

    :param document_id_list: list of str, document ids to process. When None
        (default), the built-in list of sample documents below is used; an
        explicit empty list processes nothing.
    """
    if document_id_list is None:
        document_id_list = [
            "292989214",
            "316237292",
            "321733631",
            "323390570",
            "327956364",
            "333207452",
            "334718372",
            "344636875",
            "362246081",
            "366179419",
            "380945052",
            "382366116",
            "387202452",
            "389171486",
            "391456740",
            "391736837",
            "394778487",
            "401684600",
            "402113224",
            "402181770",
            "402397014",
            "405803396",
            "445102363",
            "445256897",
            "448265376",
            "449555622",
            "449623976",
            "458291624",
            "458359181",
            "463081566",
            "469138353",
            "471641628",
            "476492237",
            "478585901",
            "478586066",
            "479042264",
            "479793787",
            "481475385",
            "483617247",
            "486378555",
            "486383912",
            "492121213",
            "497497599",
            "502693599",
            "502821436",
            "503194284",
            "506559375",
            "507967525",
            "508854243",
            "509845549",
            "520879048",
            "529925114",
        ]
    for doc_id in document_id_list:
        invoke_api_demo(doc_id)
def remove_ter_ogc_performance_fee_annotation(data_folder: str = r"/data/emea_ar/output/extract_data_by_api/",
                                              remove_dp_list=("ter", "ogc", "performance_fee")):
    """
    Drop TER, OGC and performance-fee annotations from saved API results.

    Every ``*.json`` file in ``data_folder`` is loaded. Entries of its
    ``annotation_data`` list whose ``data_point`` is in ``remove_dp_list`` are
    removed and the file is rewritten with ``indent=4``. Files that do not
    hold an ``annotation_data`` list (e.g. a saved API error payload) are
    skipped instead of aborting the whole batch with a KeyError.

    :param data_folder: str, folder holding the JSON files (created if missing).
    :param remove_dp_list: iterable of str, data points to remove.
    :return: int, number of files that were rewritten.
    """
    os.makedirs(data_folder, exist_ok=True)
    remove_dps = tuple(remove_dp_list)
    changed_file_count = 0
    for json_file in glob(os.path.join(data_folder, "*.json")):
        with open(json_file, "r", encoding="utf-8") as f:
            json_data = json.load(f)
        annotation_data_list = json_data.get("annotation_data") if isinstance(json_data, dict) else None
        if not isinstance(annotation_data_list, list):
            print(f"Skip {json_file}: no annotation_data list")
            continue
        # Single filtering pass instead of repeated O(n) list.remove calls.
        kept_data_list = [
            annotation_data for annotation_data in annotation_data_list
            if annotation_data.get("data_point") not in remove_dps
        ]
        if len(kept_data_list) == len(annotation_data_list):
            continue
        json_data["annotation_data"] = kept_data_list
        with open(json_file, "w", encoding="utf-8") as f:
            json.dump(json_data, f, indent=4)
        changed_file_count += 1
    return changed_file_count
def output_part_of_pages(pdf_file: str, page_list: list, output_folder: str = None):
    """
    Output part of pages from a pdf file to a new "<name>_part.pdf" file.

    :param pdf_file: str, the path of the source pdf file.
    :param page_list: list of int, page indexes to copy (PyMuPDF page numbers
        are 0-based), inserted in the given order.
    :param output_folder: str, the output folder; when None or empty,
        "./data/emea_ar/output/pdf_part/" is used. Created if missing.
    :return: str, the path of the written pdf file.
    """
    pdf_file_pure_name = os.path.basename(pdf_file).replace(".pdf", "")
    # Both documents are closed on exit, even if a page index is invalid.
    with fitz.open(pdf_file) as pdf_doc, fitz.open() as new_pdf:
        print(f"output pages: {page_list} for {pdf_file_pure_name}")
        for page_index in page_list:
            new_pdf.insert_pdf(pdf_doc, from_page=page_index, to_page=page_index)
        if not output_folder:
            output_folder = r"./data/emea_ar/output/pdf_part/"
        os.makedirs(output_folder, exist_ok=True)
        output_file = os.path.join(output_folder, f"{pdf_file_pure_name}_part.pdf")
        new_pdf.save(output_file)
    return output_file
def calculate_metrics_based_audit_file(is_strict: bool = False,
                                       audit_file_path: str = r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/Audited file_phase2.xlsx",
                                       audit_data_sheets: tuple = ("Mayank - revised ", "Prathamesh - Revised"),
                                       verify_file_path: str = r"/data/aus_prospectus/output/mapping_data/total/merged/merged_mapping_data_info_17_documents_by_text_20250205134704.xlsx",
                                       verify_data_sheet: str = "total_data",
                                       output_folder: str = r"/data/aus_prospectus/output/metrics_data/"):
    """
    Compare an audited ground truth file with an extracted ("verify") file and
    write per-data-point metrics (f1, precision, recall, accuracy, support)
    plus an "average_score" row to an Excel file.

    Within each document of the audit file, every audited fund row is matched
    to the first verify row whose fund name is identical, or whose lower-cased,
    whitespace-split tokens have a ``Similarity.jaccard_similarity`` above 0.8.

    * Matched pairs are scored with ``get_gt_pred_by_compare_values``.
    * For unmatched audited funds, each non-empty value counts as a miss
      (gt=1, pred=0).
    * With ``is_strict=True``, each non-empty value of a verify fund that
      matches no audited fund additionally counts as a false positive
      (gt=0, pred=1).

    :param is_strict: bool, also penalise extracted funds missing from the audit file.
    :param audit_file_path: str, path of the audited (ground truth) Excel file.
    :param audit_data_sheets: sequence of str, audit sheets to read and concatenate.
    :param verify_file_path: str, path of the extracted (prediction) Excel file.
    :param verify_data_sheet: str, sheet of the verify file to read.
    :param output_folder: str, folder for the metrics file (created if missing).
    :return: str, path of the written metrics file, or None when there is no data.
    """
    print("Start to calculate metrics based on audit file and verify file...")
    # Every per-data-point list, metric and output row is driven by this list,
    # so a data point cannot be paired with another one's values by mistake.
    data_points = [
        "management_fee_and_costs",
        "management_fee",
        "performance_fee",
        "performance_fee_costs",
        "buy_spread",
        "sell_spread",
        "minimum_initial_investment",
        "recoverable_expenses",
        "indirect_costs",
    ]

    audit_fields = ["doc_id", "fund_name"] + data_points
    audit_data_df = pd.concat(
        [pd.read_excel(audit_file_path, sheet_name=sheet)[audit_fields] for sheet in audit_data_sheets],
        ignore_index=True,
    )
    audit_data_df = audit_data_df.drop_duplicates()
    audit_data_df.fillna("", inplace=True)
    audit_data_df.reset_index(drop=True, inplace=True)

    verify_fields = ["DocumentId", "raw_fund_name"] + data_points
    verify_data_df = pd.read_excel(verify_file_path, sheet_name=verify_data_sheet)
    verify_data_df = verify_data_df[verify_fields]
    verify_data_df = verify_data_df.drop_duplicates()
    verify_data_df = verify_data_df.rename(columns={"DocumentId": "doc_id", "raw_fund_name": "fund_name"})
    verify_data_df.fillna("", inplace=True)
    verify_data_df.reset_index(drop=True, inplace=True)
    if len(audit_data_df) == 0 or len(verify_data_df) == 0:
        print("No data to calculate metrics.")
        return None

    gt_lists = {data_point: [] for data_point in data_points}
    pred_lists = {data_point: [] for data_point in data_points}

    document_id_list = audit_data_df["doc_id"].unique().tolist()
    print(f"Total document count: {len(document_id_list)}")
    print("Construct ground truth and prediction data...")
    similarity = Similarity()

    def _same_fund(audit_fund_name, verify_fund_name):
        # Exact name match, otherwise token-level Jaccard similarity > 0.8.
        if audit_fund_name == verify_fund_name:
            return True
        name_similarity = similarity.jaccard_similarity(
            audit_fund_name.lower().split(), verify_fund_name.lower().split()
        )
        return name_similarity > 0.8

    for document_id in document_id_list:
        doc_audit_data = audit_data_df[audit_data_df["doc_id"] == document_id]
        doc_verify_data = verify_data_df[verify_data_df["doc_id"] == document_id]
        for _, row in doc_audit_data.iterrows():
            # First verify row of this document that refers to the same fund.
            matched_row = next(
                (r for _, r in doc_verify_data.iterrows() if _same_fund(row["fund_name"], r["fund_name"])),
                None,
            )
            for data_point in data_points:
                value = str(row[data_point])
                if matched_row is not None:
                    get_gt_pred_by_compare_values(value, str(matched_row[data_point]),
                                                  gt_lists[data_point], pred_lists[data_point])
                elif len(value) > 0:
                    # Audited value whose fund was not extracted at all: a miss.
                    gt_lists[data_point].append(1)
                    pred_lists[data_point].append(0)
        if is_strict:
            for _, r in doc_verify_data.iterrows():
                if any(_same_fund(audit_row["fund_name"], r["fund_name"])
                       for _, audit_row in doc_audit_data.iterrows()):
                    continue
                # Extracted fund that is not in the audit file: false positives.
                for data_point in data_points:
                    if len(str(r[data_point])) > 0:
                        gt_lists[data_point].append(0)
                        pred_lists[data_point].append(1)

    print("Calculate metrics...")
    metrics_data = []
    for data_point in data_points:
        gt_list = gt_lists[data_point]
        pred_list = pred_lists[data_point]
        metrics_data.append({
            "item": data_point,
            "precision": precision_score(gt_list, pred_list),
            "recall": recall_score(gt_list, pred_list),
            "f1": f1_score(gt_list, pred_list),
            "accuracy": accuracy_score(gt_list, pred_list),
            "support": sum(gt_list),
        })
    metrics_data_df = pd.DataFrame(metrics_data)
    metrics_data.append({
        "item": "average_score",
        "precision": metrics_data_df["precision"].mean(),
        "recall": metrics_data_df["recall"].mean(),
        "f1": metrics_data_df["f1"].mean(),
        "accuracy": metrics_data_df["accuracy"].mean(),
        "support": metrics_data_df["support"].sum(),
    })
    metrics_data_df = pd.DataFrame(metrics_data)
    metrics_data_df = metrics_data_df[['item', 'f1', 'precision', 'recall', 'accuracy', 'support']]

    print("Output metrics data to Excel file...")
    os.makedirs(output_folder, exist_ok=True)
    verify_file_name = os.path.basename(verify_file_path).replace(".xlsx", "")
    mode = "strict" if is_strict else "not_strict"
    output_file = os.path.join(output_folder, f"metrics_{verify_file_name}_revised_{mode}.xlsx")
    with pd.ExcelWriter(output_file) as writer:
        metrics_data_df.to_excel(writer, index=False)
    return output_file
def calculate_metrics_based_db_data_file(audit_file_path: str = r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/audited_file_phase2_with_mapping.xlsx",
audit_data_sheet: str = "Sheet1",
verify_file_path: str = r"/data/aus_prospectus/output/mapping_data/total/merged/merged_mapping_data_info_17_documents_by_text_20250303171140.xlsx",
verify_data_sheet: str = "total_data",
verify_document_list_file: str = None,
is_for_all: bool = False
2025-03-05 15:57:02 +00:00
):
print("Start to calculate metrics based on DB data file and extracted file...")
audit_data_df = pd.DataFrame()
verify_data_df = pd.DataFrame()
audit_fields = [
"DocumentId",
"FundLegalName",
"FundId",
"FundClassLegalName",
"FundClassId",
"management_fee_and_costs",
"management_fee",
"administration_fees",
"minimum_initial_investment",
"benchmark_name",
"performance_fee",
"interposed_vehicle_performance_fee_cost",
"buy_spread",
"sell_spread",
"total_annual_dollar_based_charges"
# "withdrawal_fee",
# "switching_fee",
# "activity_fee",
]
audit_data_df = pd.read_excel(audit_file_path, sheet_name=audit_data_sheet)
audit_data_df = audit_data_df[audit_fields]
audit_data_df = audit_data_df.drop_duplicates()
audit_data_df = audit_data_df.rename(columns={"DocumentId": "doc_id",
"FundLegalName": "fund_name",
"FundId": "fund_id",
"FundClassLegalName": "sec_name",
"FundClassId": "sec_id"})
audit_data_df.fillna("", inplace=True)
audit_data_df.reset_index(drop=True, inplace=True)
# verify_file_path = r"/data/aus_prospectus/output/mapping_data/total/merged/merged_mapping_data_info_17_documents_by_text_20250205134704.xlsx"
# ravi_verify_file_path = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees_with_mapping.xlsx"
# verify_file_path = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees_with_mapping.xlsx"
verify_fields = [
"DocumentId",
"raw_fund_name",
"fund_id",
"fund_name",
"raw_share_name",
"sec_id",
"sec_name",
"management_fee_and_costs",
"management_fee",
"administration_fees",
"minimum_initial_investment",
"benchmark_name",
"performance_fee",
"interposed_vehicle_performance_fee_cost",
"buy_spread",
"sell_spread",
"total_annual_dollar_based_charges"
# "withdrawal_fee",
# "switching_fee",
# "activity_fee"
]
verify_data_df = pd.read_excel(verify_file_path, sheet_name=verify_data_sheet)
# ravi_verify_data_df = pd.read_excel(ravi_verify_file_path, sheet_name=verify_data_sheet)
# only get raw_verify_data_df data which sec_id is equal with sec_id in ravi_verify_data_df
# verify_data_df = raw_verify_data_df[raw_verify_data_df["sec_id"].isin(ravi_verify_data_df["sec_id"])]
verify_data_df = verify_data_df[verify_fields]
verify_data_df = verify_data_df.drop_duplicates()
verify_data_df = verify_data_df.rename(columns={"DocumentId": "doc_id"})
verify_data_df.fillna("", inplace=True)
verify_data_df.reset_index(drop=True, inplace=True)
if len(audit_data_df) == 0 or len(verify_data_df) == 0:
print("No data to calculate metrics.")
return
# Calculate metrics
gt_management_fee_and_costs_list = []
pred_management_fee_and_costs_list = []
gt_management_fee_list = []
pred_management_fee_list = []
gt_administration_fees_list = []
pred_administration_fees_list = []
gt_minimum_initial_investment_list = []
pred_minimum_initial_investment_list = []
gt_benchmark_name_list = []
pred_benchmark_name_list = []
if is_for_all:
gt_performance_fee_list = []
pred_performance_fee_list = []
gt_interposed_vehicle_performance_fee_cost_list = []
pred_interposed_vehicle_performance_fee_cost_list = []
gt_buy_spread_list = []
pred_buy_spread_list = []
gt_sell_spread_list = []
pred_sell_spread_list = []
gt_total_annual_dollar_based_charges_list = []
pred_total_annual_dollar_based_charges_list = []
2025-03-05 15:57:02 +00:00
# gt_performance_fee_costs_list = []
# pred_performance_fee_costs_list = []
# gt_buy_spread_list = []
# pred_buy_spread_list = []
# gt_sell_spread_list = []
# pred_sell_spread_list = []
# gt_withdrawal_fee_list = []
# pred_withdrawal_fee_list = []
# gt_switching_fee_list = []
# pred_switching_fee_list = []
# gt_activity_fee_list = []
# pred_activity_fee_list = []
if verify_document_list_file is not None:
with open(verify_document_list_file, "r", encoding="utf-8") as f:
verify_document_list = f.readlines()
verify_document_list = [int(doc_id.strip()) for doc_id in verify_document_list]
if len(verify_document_list) > 0:
verify_data_df = verify_data_df[verify_data_df["doc_id"].isin(verify_document_list)]
2025-03-05 15:57:02 +00:00
document_id_list = verify_data_df["doc_id"].unique().tolist()
print(f"Total document count: {len(document_id_list)}")
print("Construct ground truth and prediction data...")
# similarity = Similarity()
message_list = []
for document_id in document_id_list:
doc_audit_data = audit_data_df[audit_data_df["doc_id"] == document_id]
audit_sec_id_list = [doc_sec_id for doc_sec_id
in doc_audit_data["sec_id"].unique().tolist()
if len(doc_sec_id) > 0]
# get doc_verify_data which doc_id is same as document_id and sec_id in audit_sec_id_list
doc_verify_data = verify_data_df[(verify_data_df["doc_id"] == document_id) & (verify_data_df["sec_id"].isin(audit_sec_id_list))]
for index, row in doc_audit_data.iterrows():
fund_name = row["fund_name"]
sec_id = row["sec_id"]
management_fee_and_costs = str(row["management_fee_and_costs"])
management_fee = str(row["management_fee"])
administration_fees = str(row["administration_fees"])
minimum_initial_investment = str(row["minimum_initial_investment"])
benchmark_name = str(row["benchmark_name"])
if is_for_all:
performance_fee = str(row["performance_fee"])
interposed_vehicle_performance_fee_cost = str(row["interposed_vehicle_performance_fee_cost"])
buy_spread = str(row["buy_spread"])
sell_spread = str(row["sell_spread"])
total_annual_dollar_based_charges = str(row["total_annual_dollar_based_charges"])
2025-03-05 15:57:02 +00:00
# get the first row which sec_id in doc_verify_data is same as sec_id
doc_verify_sec_data = doc_verify_data[doc_verify_data["sec_id"] == sec_id]
if len(doc_verify_sec_data) == 0:
continue
doc_verify_sec_row = doc_verify_sec_data.iloc[0]
raw_fund_name = doc_verify_sec_row["raw_fund_name"]
v_management_fee_and_costs = str(doc_verify_sec_row["management_fee_and_costs"])
v_management_fee = str(doc_verify_sec_row["management_fee"])
v_administration_fees = str(doc_verify_sec_row["administration_fees"])
v_minimum_initial_investment = str(doc_verify_sec_row["minimum_initial_investment"])
v_benchmark_name = str(doc_verify_sec_row["benchmark_name"])
if is_for_all:
v_performance_fee = str(doc_verify_sec_row["performance_fee"])
v_interposed_vehicle_performance_fee_cost = str(doc_verify_sec_row["interposed_vehicle_performance_fee_cost"])
v_buy_spread = str(doc_verify_sec_row["buy_spread"])
v_sell_spread = str(doc_verify_sec_row["sell_spread"])
v_total_annual_dollar_based_charges = str(doc_verify_sec_row["total_annual_dollar_based_charges"])
2025-03-05 15:57:02 +00:00
# v_performance_fee_costs = str(doc_verify_sec_row["performance_fee_costs"])
# v_buy_spread = str(doc_verify_sec_row["buy_spread"])
# v_sell_spread = str(doc_verify_sec_row["sell_spread"])
# v_withdrawal_fee = str(doc_verify_sec_row["withdrawal_fee"])
# v_switching_fee = str(doc_verify_sec_row["switching_fee"])
# v_activity_fee = str(doc_verify_sec_row["activity_fee"])
message = get_gt_pred_by_compare_values(management_fee_and_costs, v_management_fee_and_costs, gt_management_fee_and_costs_list, pred_management_fee_and_costs_list, data_point="management_fee_and_costs")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "management_fee_and_costs"))
message = get_gt_pred_by_compare_values(management_fee, v_management_fee, gt_management_fee_list, pred_management_fee_list, data_point="management_fee")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "management_fee"))
message = get_gt_pred_by_compare_values(administration_fees, v_administration_fees, gt_administration_fees_list, pred_administration_fees_list, data_point="administration_fees")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "administration_fees"))
message = get_gt_pred_by_compare_values(minimum_initial_investment, v_minimum_initial_investment, gt_minimum_initial_investment_list, pred_minimum_initial_investment_list, data_point="minimum_initial_investment")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "minimum_initial_investment"))
message = get_gt_pred_by_compare_values(benchmark_name, v_benchmark_name, gt_benchmark_name_list, pred_benchmark_name_list, data_point="benchmark_name")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "benchmark_name"))
if is_for_all:
message = get_gt_pred_by_compare_values(performance_fee, v_performance_fee, gt_performance_fee_list, pred_performance_fee_list)
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "performance_fee"))
message = get_gt_pred_by_compare_values(interposed_vehicle_performance_fee_cost, v_interposed_vehicle_performance_fee_cost,
gt_interposed_vehicle_performance_fee_cost_list, pred_interposed_vehicle_performance_fee_cost_list)
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "interposed_vehicle_performance_fee_cost"))
message = get_gt_pred_by_compare_values(buy_spread, v_buy_spread, gt_buy_spread_list, pred_buy_spread_list)
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "buy_spread"))
message = get_gt_pred_by_compare_values(sell_spread, v_sell_spread, gt_sell_spread_list, pred_sell_spread_list)
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "sell_spread"))
message = get_gt_pred_by_compare_values(total_annual_dollar_based_charges, v_total_annual_dollar_based_charges,
gt_total_annual_dollar_based_charges_list, pred_total_annual_dollar_based_charges_list)
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "total_annual_dollar_based_charges"))
2025-03-05 15:57:02 +00:00
# message = get_gt_pred_by_compare_values(withdrawal_fee, v_withdrawal_fee, gt_withdrawal_fee_list, pred_withdrawal_fee_list)
# message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "withdrawal_fee"))
# message = get_gt_pred_by_compare_values(switching_fee, v_switching_fee, gt_switching_fee_list, pred_switching_fee_list)
# message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "switching_fee"))
# message = get_gt_pred_by_compare_values(activity_fee, v_activity_fee, gt_activity_fee_list, pred_activity_fee_list)
# message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "activity_fee"))
message_data_df = pd.DataFrame(message_list)
message_data_df = message_data_df[['doc_id', 'sec_id', 'raw_fund_name', 'fund_legal_name', 'data_point', 'gt_value', 'pred_value', 'error']]
# order by doc_id, raw_fund_name, data_point
message_data_df = message_data_df.sort_values(by=['doc_id', 'raw_fund_name', 'data_point'])
message_data_df.reset_index(drop=True, inplace=True)
# calculate metrics
print("Calculate metrics...")
precision_management_fee_and_costs = precision_score(gt_management_fee_and_costs_list, pred_management_fee_and_costs_list)
recall_management_fee_and_costs = recall_score(gt_management_fee_and_costs_list, pred_management_fee_and_costs_list)
f1_management_fee_and_costs = f1_score(gt_management_fee_and_costs_list, pred_management_fee_and_costs_list)
accuracy_management_fee_and_costs = accuracy_score(gt_management_fee_and_costs_list, pred_management_fee_and_costs_list)
support_management_fee_and_costs = sum(gt_management_fee_and_costs_list)
precision_management_fee = precision_score(gt_management_fee_list, pred_management_fee_list)
recall_management_fee = recall_score(gt_management_fee_list, pred_management_fee_list)
f1_management_fee = f1_score(gt_management_fee_list, pred_management_fee_list)
accuracy_management_fee = accuracy_score(gt_management_fee_list, pred_management_fee_list)
support_management_fee = sum(gt_management_fee_list)
precision_administration_fees = precision_score(gt_administration_fees_list, pred_administration_fees_list)
recall_administration_fees = recall_score(gt_administration_fees_list, pred_administration_fees_list)
f1_administration_fees = f1_score(gt_administration_fees_list, pred_administration_fees_list)
accuracy_administration_fees = accuracy_score(gt_administration_fees_list, pred_administration_fees_list)
support_administration_fees = sum(gt_administration_fees_list)
precision_miminimum_initial_investment = precision_score(gt_minimum_initial_investment_list, pred_minimum_initial_investment_list)
recall_miminimum_initial_investment = recall_score(gt_minimum_initial_investment_list, pred_minimum_initial_investment_list)
f1_miminimum_initial_investment = f1_score(gt_minimum_initial_investment_list, pred_minimum_initial_investment_list)
accuracy_miminimum_initial_investment = accuracy_score(gt_minimum_initial_investment_list, pred_minimum_initial_investment_list)
support_miminimum_initial_investment = sum(gt_minimum_initial_investment_list)
precision_benchmark_name = precision_score(gt_benchmark_name_list, pred_benchmark_name_list)
recall_benchmark_name = recall_score(gt_benchmark_name_list, pred_benchmark_name_list)
f1_benchmark_name = f1_score(gt_benchmark_name_list, pred_benchmark_name_list)
accuracy_benchmark_name = accuracy_score(gt_benchmark_name_list, pred_benchmark_name_list)
support_benchmark_name = sum(gt_benchmark_name_list)
if is_for_all:
precision_performance_fee = precision_score(gt_performance_fee_list, pred_performance_fee_list)
recall_performance_fee = recall_score(gt_performance_fee_list, pred_performance_fee_list)
f1_performance_fee = f1_score(gt_performance_fee_list, pred_performance_fee_list)
accuracy_performance_fee = accuracy_score(gt_performance_fee_list, pred_performance_fee_list)
support_performance_fee = sum(gt_performance_fee_list)
precision_interposed_vehicle_performance_fee_cost = precision_score(gt_interposed_vehicle_performance_fee_cost_list, pred_interposed_vehicle_performance_fee_cost_list)
recall_interposed_vehicle_performance_fee_cost = recall_score(gt_interposed_vehicle_performance_fee_cost_list, pred_interposed_vehicle_performance_fee_cost_list)
f1_interposed_vehicle_performance_fee_cost = f1_score(gt_interposed_vehicle_performance_fee_cost_list, pred_interposed_vehicle_performance_fee_cost_list)
accuracy_interposed_vehicle_performance_fee_cost = accuracy_score(gt_interposed_vehicle_performance_fee_cost_list, pred_interposed_vehicle_performance_fee_cost_list)
support_interposed_vehicle_performance_fee_cost = sum(gt_interposed_vehicle_performance_fee_cost_list)
precision_buy_spread = precision_score(gt_buy_spread_list, pred_buy_spread_list)
recall_buy_spread = recall_score(gt_buy_spread_list, pred_buy_spread_list)
f1_buy_spread = f1_score(gt_buy_spread_list, pred_buy_spread_list)
accuracy_buy_spread = accuracy_score(gt_buy_spread_list, pred_buy_spread_list)
support_buy_spread = sum(gt_buy_spread_list)
precision_sell_spread = precision_score(gt_sell_spread_list, pred_sell_spread_list)
recall_sell_spread = recall_score(gt_sell_spread_list, pred_sell_spread_list)
f1_sell_spread = f1_score(gt_sell_spread_list, pred_sell_spread_list)
accuracy_sell_spread = accuracy_score(gt_sell_spread_list, pred_sell_spread_list)
support_buy_spread = sum(gt_sell_spread_list)
precision_total_annual_dollar_based_charges = precision_score(gt_total_annual_dollar_based_charges_list, pred_total_annual_dollar_based_charges_list)
recall_total_annual_dollar_based_charges = recall_score(gt_total_annual_dollar_based_charges_list, pred_total_annual_dollar_based_charges_list)
f1_total_annual_dollar_based_charges = f1_score(gt_total_annual_dollar_based_charges_list, pred_total_annual_dollar_based_charges_list)
accuracy_total_annual_dollar_based_charges = accuracy_score(gt_total_annual_dollar_based_charges_list, pred_total_annual_dollar_based_charges_list)
support_total_annual_dollar_based_charges = sum(gt_total_annual_dollar_based_charges_list)
2025-03-05 15:57:02 +00:00
# precision_withdrawal_fee = precision_score(gt_withdrawal_fee_list, pred_withdrawal_fee_list)
# recall_withdrawal_fee = recall_score(gt_withdrawal_fee_list, pred_withdrawal_fee_list)
# f1_withdrawal_fee = f1_score(gt_withdrawal_fee_list, pred_withdrawal_fee_list)
# accuracy_withdrawal_fee = accuracy_score(gt_withdrawal_fee_list, pred_withdrawal_fee_list)
# support_withdrawal_fee = sum(gt_withdrawal_fee_list)
# precision_switching_fee = precision_score(gt_switching_fee_list, pred_switching_fee_list)
# recall_switching_fee = recall_score(gt_switching_fee_list, pred_switching_fee_list)
# f1_switching_fee = f1_score(gt_switching_fee_list, pred_switching_fee_list)
# accuracy_switching_fee = accuracy_score(gt_switching_fee_list, pred_switching_fee_list)
# support_switching_fee = sum(gt_switching_fee_list)
# precision_activity_fee = precision_score(gt_activity_fee_list, pred_activity_fee_list)
# recall_activity_fee = recall_score(gt_activity_fee_list, pred_activity_fee_list)
# f1_activity_fee = f1_score(gt_activity_fee_list, pred_activity_fee_list)
# accuracy_activity_fee = accuracy_score(gt_activity_fee_list, pred_activity_fee_list)
# support_activity_fee = sum(gt_activity_fee_list)
if is_for_all:
metrics_data = [{"item": "management_fee_and_costs", "precision": precision_management_fee_and_costs, "recall": recall_management_fee_and_costs, "f1": f1_management_fee_and_costs, "accuracy": accuracy_management_fee_and_costs, "support": support_management_fee_and_costs},
{"item": "management_fee", "precision": precision_management_fee, "recall": recall_management_fee, "f1": f1_management_fee, "accuracy": accuracy_management_fee, "support": support_management_fee},
{"item": "administration_fees", "precision": precision_administration_fees, "recall": recall_administration_fees, "f1": f1_administration_fees, "accuracy": accuracy_administration_fees, "support": support_administration_fees},
{"item": "minimum_initial_investment", "precision": precision_miminimum_initial_investment, "recall": recall_miminimum_initial_investment, "f1": f1_miminimum_initial_investment, "accuracy": accuracy_miminimum_initial_investment, "support": support_miminimum_initial_investment},
{"item": "benchmark_name", "precision": precision_benchmark_name, "recall": recall_benchmark_name, "f1": f1_benchmark_name, "accuracy": accuracy_benchmark_name, "support": support_benchmark_name},
{"item": "performance_fee", "precision": precision_performance_fee, "recall": recall_performance_fee, "f1": f1_performance_fee, "accuracy": accuracy_performance_fee, "support": support_performance_fee},
{"item": "interposed_vehicle_performance_fee_cost", "precision": precision_interposed_vehicle_performance_fee_cost, "recall": recall_interposed_vehicle_performance_fee_cost,
"f1": f1_interposed_vehicle_performance_fee_cost, "accuracy": accuracy_interposed_vehicle_performance_fee_cost, "support": support_interposed_vehicle_performance_fee_cost},
{"item": "buy_spread", "precision": precision_buy_spread, "recall": recall_buy_spread, "f1": f1_buy_spread, "accuracy": accuracy_buy_spread, "support": support_buy_spread},
{"item": "sell_spread", "precision": precision_sell_spread, "recall": recall_sell_spread, "f1": f1_sell_spread, "accuracy": accuracy_sell_spread, "support": support_buy_spread},
{"item": "total_annual_dollar_based_charges", "precision": precision_total_annual_dollar_based_charges, "recall": recall_total_annual_dollar_based_charges,
"f1": f1_total_annual_dollar_based_charges, "accuracy": accuracy_total_annual_dollar_based_charges, "support": support_total_annual_dollar_based_charges}
# {"item": "buy_spread", "precision": precision_buy_spread, "recall": recall_buy_spread, "f1": f1_buy_spread, "accuracy": accuracy_buy_spread, "support": support_buy_spread},
# {"item": "sell_spread", "precision": precision_sell_spread, "recall": recall_sell_spread, "f1": f1_sell_spread, "accuracy": accuracy_sell_spread, "support": support_buy_spread},
# {"item": "withdrawal_fee", "precision": precision_withdrawal_fee, "recall": recall_withdrawal_fee, "f1": f1_withdrawal_fee, "accuracy": accuracy_withdrawal_fee, "support": support_withdrawal_fee},
# {"item": "switching_fee", "precision": precision_switching_fee, "recall": recall_switching_fee, "f1": f1_switching_fee, "accuracy": accuracy_switching_fee, "support": support_switching_fee},
# {"item": "activity_fee", "precision": precision_activity_fee, "recall": recall_activity_fee, "f1": f1_activity_fee, "accuracy": accuracy_activity_fee, "support": support_activity_fee}
]
else:
metrics_data = [{"item": "management_fee_and_costs", "precision": precision_management_fee_and_costs, "recall": recall_management_fee_and_costs, "f1": f1_management_fee_and_costs, "accuracy": accuracy_management_fee_and_costs, "support": support_management_fee_and_costs},
{"item": "management_fee", "precision": precision_management_fee, "recall": recall_management_fee, "f1": f1_management_fee, "accuracy": accuracy_management_fee, "support": support_management_fee},
{"item": "administration_fees", "precision": precision_administration_fees, "recall": recall_administration_fees, "f1": f1_administration_fees, "accuracy": accuracy_administration_fees, "support": support_administration_fees},
{"item": "minimum_initial_investment", "precision": precision_miminimum_initial_investment, "recall": recall_miminimum_initial_investment, "f1": f1_miminimum_initial_investment, "accuracy": accuracy_miminimum_initial_investment, "support": support_miminimum_initial_investment},
{"item": "benchmark_name", "precision": precision_benchmark_name, "recall": recall_benchmark_name, "f1": f1_benchmark_name, "accuracy": accuracy_benchmark_name, "support": support_benchmark_name}
]
2025-03-05 15:57:02 +00:00
metrics_data_df = pd.DataFrame(metrics_data)
averate_precision = metrics_data_df["precision"].mean()
average_recall = metrics_data_df["recall"].mean()
average_f1 = metrics_data_df["f1"].mean()
average_accuracy = metrics_data_df["accuracy"].mean()
sum_support = metrics_data_df["support"].sum()
metrics_data.append({"item": "average_score", "precision": averate_precision, "recall": average_recall, "f1": average_f1, "accuracy": average_accuracy, "support": sum_support})
metrics_data_df = pd.DataFrame(metrics_data)
metrics_data_df = metrics_data_df[['item', 'f1', 'precision', 'recall', 'accuracy', 'support']]
# output metrics data to Excel file
print("Output metrics data to Excel file...")
output_folder = r"/data/aus_prospectus/output/metrics_data/"
os.makedirs(output_folder, exist_ok=True)
verify_file_name = os.path.basename(verify_file_path).replace(".xlsx", "")
if is_for_all:
verify_file_name = f"metrics_{verify_file_name}_all"
metrics_file_name = f"metrics_{verify_file_name}_{len(document_id_list)}_documents_4_dps_not_strict.xlsx"
2025-03-05 15:57:02 +00:00
output_file = os.path.join(output_folder, metrics_file_name)
with pd.ExcelWriter(output_file) as writer:
metrics_data_df.to_excel(writer, index=False, sheet_name="metrics_data")
message_data_df.to_excel(writer, index=False, sheet_name="message_data")
def calculate_metrics_by_provider(audit_file_path: str = r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/audited_file_phase2_with_mapping.xlsx",
audit_data_sheet: str = "Sheet1",
verify_file_path: str = r"/data/aus_prospectus/output/mapping_data/total/merged/merged_mapping_data_info_17_documents_by_text_20250303171140.xlsx",
verify_data_sheet: str = "total_data",
verify_document_list_file: str = None,
is_for_all: bool = False
):
print("Start to calculate metrics based on DB data file and extracted file...")
audit_data_df = pd.DataFrame()
verify_data_df = pd.DataFrame()
audit_fields = [
"DocumentId",
"provider_id",
"provider_name",
"FundLegalName",
"FundId",
"FundClassLegalName",
"FundClassId",
"management_fee_and_costs",
"management_fee",
"administration_fees",
"minimum_initial_investment",
"benchmark_name",
"performance_fee",
"interposed_vehicle_performance_fee_cost",
"buy_spread",
"sell_spread",
"total_annual_dollar_based_charges"
]
audit_data_df = pd.read_excel(audit_file_path, sheet_name=audit_data_sheet)
audit_data_df = audit_data_df[audit_fields]
audit_data_df = audit_data_df.drop_duplicates()
audit_data_df = audit_data_df.rename(columns={"DocumentId": "doc_id",
"FundLegalName": "fund_name",
"FundId": "fund_id",
"FundClassLegalName": "sec_name",
"FundClassId": "sec_id"})
audit_data_df.fillna("", inplace=True)
audit_data_df.reset_index(drop=True, inplace=True)
verify_fields = [
"DocumentId",
"raw_fund_name",
"fund_id",
"fund_name",
"raw_share_name",
"sec_id",
"sec_name",
"management_fee_and_costs",
"management_fee",
"administration_fees",
"minimum_initial_investment",
"benchmark_name",
"performance_fee",
"interposed_vehicle_performance_fee_cost",
"buy_spread",
"sell_spread",
"total_annual_dollar_based_charges"
]
verify_data_df = pd.read_excel(verify_file_path, sheet_name=verify_data_sheet)
verify_data_df = verify_data_df[verify_fields]
verify_data_df = verify_data_df.drop_duplicates()
verify_data_df = verify_data_df.rename(columns={"DocumentId": "doc_id"})
verify_data_df.fillna("", inplace=True)
verify_data_df.reset_index(drop=True, inplace=True)
if len(audit_data_df) == 0 or len(verify_data_df) == 0:
print("No data to calculate metrics.")
return
# Calculate metrics
if verify_document_list_file is not None:
with open(verify_document_list_file, "r", encoding="utf-8") as f:
verify_document_list = f.readlines()
verify_document_list = [int(doc_id.strip()) for doc_id in verify_document_list]
if len(verify_document_list) > 0:
verify_data_df = verify_data_df[verify_data_df["doc_id"].isin(verify_document_list)]
document_id_list = verify_data_df["doc_id"].unique().tolist()
print(f"Total document count: {len(document_id_list)}")
print("Construct ground truth and prediction data...")
# similarity = Similarity()
message_list = []
provider_gt_pred_data = {}
for document_id in document_id_list:
doc_audit_data = audit_data_df[audit_data_df["doc_id"] == document_id]
provider_id = doc_audit_data["provider_id"].iloc[0]
provider_name = doc_audit_data["provider_name"].iloc[0]
if provider_id not in list(provider_gt_pred_data.keys()):
provider_gt_pred_data[provider_id] = {"provider_name": provider_name,
"gt_management_fee_and_costs_list": [],
"pred_management_fee_and_costs_list": [],
"gt_management_fee_list": [],
"pred_management_fee_list": [],
"gt_administration_fees_list": [],
"pred_administration_fees_list": [],
"gt_minimum_initial_investment_list": [],
"pred_minimum_initial_investment_list": [],
"gt_benchmark_name_list": [],
"pred_benchmark_name_list": []}
if is_for_all:
provider_gt_pred_data[provider_id].update({"gt_performance_fee_list": [],
"pred_performance_fee_list": [],
"gt_interposed_vehicle_performance_fee_cost_list": [],
"pred_interposed_vehicle_performance_fee_cost_list": [],
"gt_buy_spread_list": [],
"pred_buy_spread_list": [],
"gt_sell_spread_list": [],
"pred_sell_spread_list": [],
"gt_total_annual_dollar_based_charges_list": [],
"pred_total_annual_dollar_based_charges_list": []})
audit_sec_id_list = [doc_sec_id for doc_sec_id
in doc_audit_data["sec_id"].unique().tolist()
if len(doc_sec_id) > 0]
# get doc_verify_data which doc_id is same as document_id and sec_id in audit_sec_id_list
doc_verify_data = verify_data_df[(verify_data_df["doc_id"] == document_id) & (verify_data_df["sec_id"].isin(audit_sec_id_list))]
for index, row in doc_audit_data.iterrows():
fund_name = row["fund_name"]
sec_id = row["sec_id"]
management_fee_and_costs = str(row["management_fee_and_costs"])
management_fee = str(row["management_fee"])
administration_fees = str(row["administration_fees"])
minimum_initial_investment = str(row["minimum_initial_investment"])
benchmark_name = str(row["benchmark_name"])
if is_for_all:
performance_fee = str(row["performance_fee"])
interposed_vehicle_performance_fee_cost = str(row["interposed_vehicle_performance_fee_cost"])
buy_spread = str(row["buy_spread"])
sell_spread = str(row["sell_spread"])
total_annual_dollar_based_charges = str(row["total_annual_dollar_based_charges"])
# get the first row which sec_id in doc_verify_data is same as sec_id
doc_verify_sec_data = doc_verify_data[doc_verify_data["sec_id"] == sec_id]
if len(doc_verify_sec_data) == 0:
continue
doc_verify_sec_row = doc_verify_sec_data.iloc[0]
raw_fund_name = doc_verify_sec_row["raw_fund_name"]
v_management_fee_and_costs = str(doc_verify_sec_row["management_fee_and_costs"])
v_management_fee = str(doc_verify_sec_row["management_fee"])
v_administration_fees = str(doc_verify_sec_row["administration_fees"])
v_minimum_initial_investment = str(doc_verify_sec_row["minimum_initial_investment"])
v_benchmark_name = str(doc_verify_sec_row["benchmark_name"])
if is_for_all:
v_performance_fee = str(doc_verify_sec_row["performance_fee"])
v_interposed_vehicle_performance_fee_cost = str(doc_verify_sec_row["interposed_vehicle_performance_fee_cost"])
v_buy_spread = str(doc_verify_sec_row["buy_spread"])
v_sell_spread = str(doc_verify_sec_row["sell_spread"])
v_total_annual_dollar_based_charges = str(doc_verify_sec_row["total_annual_dollar_based_charges"])
message = get_gt_pred_by_compare_values(management_fee_and_costs,
v_management_fee_and_costs,
provider_gt_pred_data[provider_id]["gt_management_fee_and_costs_list"],
provider_gt_pred_data[provider_id]["pred_management_fee_and_costs_list"],
data_point="management_fee_and_costs")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "management_fee_and_costs"))
message = get_gt_pred_by_compare_values(management_fee,
v_management_fee,
provider_gt_pred_data[provider_id]["gt_management_fee_list"],
provider_gt_pred_data[provider_id]["pred_management_fee_list"],
data_point="management_fee")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "management_fee"))
message = get_gt_pred_by_compare_values(administration_fees,
v_administration_fees,
provider_gt_pred_data[provider_id]["gt_administration_fees_list"],
provider_gt_pred_data[provider_id]["pred_administration_fees_list"],
data_point="administration_fees")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "administration_fees"))
message = get_gt_pred_by_compare_values(minimum_initial_investment,
v_minimum_initial_investment,
provider_gt_pred_data[provider_id]["gt_minimum_initial_investment_list"],
provider_gt_pred_data[provider_id]["pred_minimum_initial_investment_list"],
data_point="minimum_initial_investment")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "minimum_initial_investment"))
message = get_gt_pred_by_compare_values(benchmark_name,
v_benchmark_name,
provider_gt_pred_data[provider_id]["gt_benchmark_name_list"],
provider_gt_pred_data[provider_id]["pred_benchmark_name_list"],
data_point="benchmark_name")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "benchmark_name"))
if is_for_all:
message = get_gt_pred_by_compare_values(performance_fee,
v_performance_fee,
provider_gt_pred_data[provider_id]["gt_performance_fee_list"],
provider_gt_pred_data[provider_id]["pred_performance_fee_list"],
data_point="performance_fee")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "performance_fee"))
message = get_gt_pred_by_compare_values(interposed_vehicle_performance_fee_cost,
v_interposed_vehicle_performance_fee_cost,
provider_gt_pred_data[provider_id]["gt_interposed_vehicle_performance_fee_cost_list"],
provider_gt_pred_data[provider_id]["pred_interposed_vehicle_performance_fee_cost_list"],
data_point="interposed_vehicle_performance_fee_cost")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "interposed_vehicle_performance_fee_cost"))
message = get_gt_pred_by_compare_values(buy_spread,
v_buy_spread,
provider_gt_pred_data[provider_id]["gt_buy_spread_list"],
provider_gt_pred_data[provider_id]["pred_buy_spread_list"],
data_point="buy_spread")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "buy_spread"))
message = get_gt_pred_by_compare_values(sell_spread,
v_sell_spread,
provider_gt_pred_data[provider_id]["gt_sell_spread_list"],
provider_gt_pred_data[provider_id]["pred_sell_spread_list"],
data_point="sell_spread")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "sell_spread"))
message = get_gt_pred_by_compare_values(total_annual_dollar_based_charges,
v_total_annual_dollar_based_charges,
provider_gt_pred_data[provider_id]["gt_total_annual_dollar_based_charges_list"],
provider_gt_pred_data[provider_id]["pred_total_annual_dollar_based_charges_list"],
data_point="total_annual_dollar_based_charges")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "total_annual_dollar_based_charges"))
message_data_df = pd.DataFrame(message_list)
message_data_df = message_data_df[['doc_id', 'sec_id', 'raw_fund_name', 'fund_legal_name', 'data_point', 'gt_value', 'pred_value', 'error']]
# order by doc_id, raw_fund_name, data_point
message_data_df = message_data_df.sort_values(by=['doc_id', 'raw_fund_name', 'data_point'])
message_data_df.reset_index(drop=True, inplace=True)
# calculate metrics
print("Calculate metrics...")
provider_metrics_list = []
for provider_id, gt_pred_data in provider_gt_pred_data.items():
provider_name = gt_pred_data["provider_name"]
precision_management_fee_and_costs = precision_score(gt_pred_data["gt_management_fee_and_costs_list"],
gt_pred_data["pred_management_fee_and_costs_list"])
recall_management_fee_and_costs = recall_score(gt_pred_data["gt_management_fee_and_costs_list"], gt_pred_data["pred_management_fee_and_costs_list"])
f1_management_fee_and_costs = f1_score(gt_pred_data["gt_management_fee_and_costs_list"], gt_pred_data["pred_management_fee_and_costs_list"])
accuracy_management_fee_and_costs = accuracy_score(gt_pred_data["gt_management_fee_and_costs_list"], gt_pred_data["pred_management_fee_and_costs_list"])
support_management_fee_and_costs = sum(gt_pred_data["gt_management_fee_and_costs_list"])
precision_management_fee = precision_score(gt_pred_data["gt_management_fee_list"], gt_pred_data["pred_management_fee_list"])
recall_management_fee = recall_score(gt_pred_data["gt_management_fee_list"], gt_pred_data["pred_management_fee_list"])
f1_management_fee = f1_score(gt_pred_data["gt_management_fee_list"], gt_pred_data["pred_management_fee_list"])
accuracy_management_fee = accuracy_score(gt_pred_data["gt_management_fee_list"], gt_pred_data["pred_management_fee_list"])
support_management_fee = sum(gt_pred_data["gt_management_fee_list"])
precision_administration_fees = precision_score(gt_pred_data["gt_administration_fees_list"], gt_pred_data["pred_administration_fees_list"])
recall_administration_fees = recall_score(gt_pred_data["gt_administration_fees_list"], gt_pred_data["pred_administration_fees_list"])
f1_administration_fees = f1_score(gt_pred_data["gt_administration_fees_list"], gt_pred_data["pred_administration_fees_list"])
accuracy_administration_fees = accuracy_score(gt_pred_data["gt_administration_fees_list"], gt_pred_data["pred_administration_fees_list"])
support_administration_fees = sum(gt_pred_data["gt_administration_fees_list"])
precision_miminimum_initial_investment = precision_score(gt_pred_data["gt_minimum_initial_investment_list"],
gt_pred_data["pred_minimum_initial_investment_list"])
recall_miminimum_initial_investment = recall_score(gt_pred_data["gt_minimum_initial_investment_list"],
gt_pred_data["pred_minimum_initial_investment_list"])
f1_miminimum_initial_investment = f1_score(gt_pred_data["gt_minimum_initial_investment_list"],
gt_pred_data["pred_minimum_initial_investment_list"])
accuracy_miminimum_initial_investment = accuracy_score(gt_pred_data["gt_minimum_initial_investment_list"],
gt_pred_data["pred_minimum_initial_investment_list"])
support_miminimum_initial_investment = sum(gt_pred_data["gt_minimum_initial_investment_list"])
precision_benchmark_name = precision_score(gt_pred_data["gt_benchmark_name_list"],
gt_pred_data["pred_benchmark_name_list"])
recall_benchmark_name = recall_score(gt_pred_data["gt_benchmark_name_list"],
gt_pred_data["pred_benchmark_name_list"])
f1_benchmark_name = f1_score(gt_pred_data["gt_benchmark_name_list"],
gt_pred_data["pred_benchmark_name_list"])
accuracy_benchmark_name = accuracy_score(gt_pred_data["gt_benchmark_name_list"],
gt_pred_data["pred_benchmark_name_list"])
support_benchmark_name = sum(gt_pred_data["gt_benchmark_name_list"])
if is_for_all:
precision_performance_fee = precision_score(gt_pred_data["gt_performance_fee_list"],
gt_pred_data["pred_performance_fee_list"])
recall_performance_fee = recall_score(gt_pred_data["gt_performance_fee_list"],
gt_pred_data["pred_performance_fee_list"])
f1_performance_fee = f1_score(gt_pred_data["gt_performance_fee_list"],
gt_pred_data["pred_performance_fee_list"])
accuracy_performance_fee = accuracy_score(gt_pred_data["gt_performance_fee_list"],
gt_pred_data["pred_performance_fee_list"])
support_performance_fee = sum(gt_pred_data["gt_performance_fee_list"])
precision_interposed_vehicle_performance_fee_cost = precision_score(gt_pred_data["gt_interposed_vehicle_performance_fee_cost_list"],
gt_pred_data["pred_interposed_vehicle_performance_fee_cost_list"])
recall_interposed_vehicle_performance_fee_cost = recall_score(gt_pred_data["gt_interposed_vehicle_performance_fee_cost_list"],
gt_pred_data["pred_interposed_vehicle_performance_fee_cost_list"])
f1_interposed_vehicle_performance_fee_cost = f1_score(gt_pred_data["gt_interposed_vehicle_performance_fee_cost_list"],
gt_pred_data["pred_interposed_vehicle_performance_fee_cost_list"])
accuracy_interposed_vehicle_performance_fee_cost = accuracy_score(gt_pred_data["gt_interposed_vehicle_performance_fee_cost_list"],
gt_pred_data["pred_interposed_vehicle_performance_fee_cost_list"])
support_interposed_vehicle_performance_fee_cost = sum(gt_pred_data["gt_interposed_vehicle_performance_fee_cost_list"])
precision_buy_spread = precision_score(gt_pred_data["gt_buy_spread_list"],
gt_pred_data["pred_buy_spread_list"])
recall_buy_spread = recall_score(gt_pred_data["gt_buy_spread_list"],
gt_pred_data["pred_buy_spread_list"])
f1_buy_spread = f1_score(gt_pred_data["gt_buy_spread_list"],
gt_pred_data["pred_buy_spread_list"])
accuracy_buy_spread = accuracy_score(gt_pred_data["gt_buy_spread_list"],
gt_pred_data["pred_buy_spread_list"])
support_buy_spread = sum(gt_pred_data["gt_buy_spread_list"])
precision_sell_spread = precision_score(gt_pred_data["gt_sell_spread_list"],
gt_pred_data["pred_sell_spread_list"])
recall_sell_spread = recall_score(gt_pred_data["gt_sell_spread_list"],
gt_pred_data["pred_sell_spread_list"])
f1_sell_spread = f1_score(gt_pred_data["gt_sell_spread_list"],
gt_pred_data["pred_sell_spread_list"])
accuracy_sell_spread = accuracy_score(gt_pred_data["gt_sell_spread_list"],
gt_pred_data["pred_sell_spread_list"])
support_buy_spread = sum(gt_pred_data["gt_sell_spread_list"])
precision_total_annual_dollar_based_charges = precision_score(gt_pred_data["gt_total_annual_dollar_based_charges_list"],
gt_pred_data["pred_total_annual_dollar_based_charges_list"])
recall_total_annual_dollar_based_charges = recall_score(gt_pred_data["gt_total_annual_dollar_based_charges_list"],
gt_pred_data["pred_total_annual_dollar_based_charges_list"])
f1_total_annual_dollar_based_charges = f1_score(gt_pred_data["gt_total_annual_dollar_based_charges_list"],
gt_pred_data["pred_total_annual_dollar_based_charges_list"])
accuracy_total_annual_dollar_based_charges = accuracy_score(gt_pred_data["gt_total_annual_dollar_based_charges_list"],
gt_pred_data["pred_total_annual_dollar_based_charges_list"])
support_total_annual_dollar_based_charges = sum(gt_pred_data["gt_total_annual_dollar_based_charges_list"])
if is_for_all:
metrics_data = [{"provider_id": provider_id, "provider_name": provider_name, "item": "management_fee_and_costs", "precision": precision_management_fee_and_costs, "recall": recall_management_fee_and_costs, "f1": f1_management_fee_and_costs, "accuracy": accuracy_management_fee_and_costs, "support": support_management_fee_and_costs},
{"provider_id": provider_id, "provider_name": provider_name, "item": "management_fee", "precision": precision_management_fee, "recall": recall_management_fee, "f1": f1_management_fee, "accuracy": accuracy_management_fee, "support": support_management_fee},
{"provider_id": provider_id, "provider_name": provider_name, "item": "administration_fees", "precision": precision_administration_fees, "recall": recall_administration_fees, "f1": f1_administration_fees, "accuracy": accuracy_administration_fees, "support": support_administration_fees},
{"provider_id": provider_id, "provider_name": provider_name, "item": "minimum_initial_investment", "precision": precision_miminimum_initial_investment, "recall": recall_miminimum_initial_investment, "f1": f1_miminimum_initial_investment, "accuracy": accuracy_miminimum_initial_investment, "support": support_miminimum_initial_investment},
{"provider_id": provider_id, "provider_name": provider_name, "item": "benchmark_name", "precision": precision_benchmark_name, "recall": recall_benchmark_name, "f1": f1_benchmark_name, "accuracy": accuracy_benchmark_name, "support": support_benchmark_name},
{"provider_id": provider_id, "provider_name": provider_name, "item": "performance_fee", "precision": precision_performance_fee, "recall": recall_performance_fee, "f1": f1_performance_fee, "accuracy": accuracy_performance_fee, "support": support_performance_fee},
{"provider_id": provider_id, "provider_name": provider_name, "item": "interposed_vehicle_performance_fee_cost", "precision": precision_interposed_vehicle_performance_fee_cost, "recall": recall_interposed_vehicle_performance_fee_cost,
"f1": f1_interposed_vehicle_performance_fee_cost, "accuracy": accuracy_interposed_vehicle_performance_fee_cost, "support": support_interposed_vehicle_performance_fee_cost},
{"provider_id": provider_id, "provider_name": provider_name, "item": "buy_spread", "precision": precision_buy_spread, "recall": recall_buy_spread, "f1": f1_buy_spread, "accuracy": accuracy_buy_spread, "support": support_buy_spread},
{"provider_id": provider_id, "provider_name": provider_name, "item": "sell_spread", "precision": precision_sell_spread, "recall": recall_sell_spread, "f1": f1_sell_spread, "accuracy": accuracy_sell_spread, "support": support_buy_spread},
{"provider_id": provider_id, "provider_name": provider_name, "item": "total_annual_dollar_based_charges", "precision": precision_total_annual_dollar_based_charges, "recall": recall_total_annual_dollar_based_charges,
"f1": f1_total_annual_dollar_based_charges, "accuracy": accuracy_total_annual_dollar_based_charges, "support": support_total_annual_dollar_based_charges}
]
else:
metrics_data = [{"provider_id": provider_id, "provider_name": provider_name, "item": "management_fee_and_costs", "precision": precision_management_fee_and_costs, "recall": recall_management_fee_and_costs, "f1": f1_management_fee_and_costs, "accuracy": accuracy_management_fee_and_costs, "support": support_management_fee_and_costs},
{"provider_id": provider_id, "provider_name": provider_name, "item": "management_fee", "precision": precision_management_fee, "recall": recall_management_fee, "f1": f1_management_fee, "accuracy": accuracy_management_fee, "support": support_management_fee},
{"provider_id": provider_id, "provider_name": provider_name, "item": "administration_fees", "precision": precision_administration_fees, "recall": recall_administration_fees, "f1": f1_administration_fees, "accuracy": accuracy_administration_fees, "support": support_administration_fees},
{"provider_id": provider_id, "provider_name": provider_name, "item": "minimum_initial_investment", "precision": precision_miminimum_initial_investment, "recall": recall_miminimum_initial_investment, "f1": f1_miminimum_initial_investment, "accuracy": accuracy_miminimum_initial_investment, "support": support_miminimum_initial_investment},
{"provider_id": provider_id, "provider_name": provider_name, "item": "benchmark_name", "precision": precision_benchmark_name, "recall": recall_benchmark_name, "f1": f1_benchmark_name, "accuracy": accuracy_benchmark_name, "support": support_benchmark_name}
]
metrics_data_df = pd.DataFrame(metrics_data)
averate_precision = metrics_data_df["precision"].mean()
average_recall = metrics_data_df["recall"].mean()
average_f1 = metrics_data_df["f1"].mean()
average_accuracy = metrics_data_df["accuracy"].mean()
sum_support = metrics_data_df["support"].sum()
metrics_data.append({"provider_id": provider_id, "provider_name": provider_name, "item": "average_score", "precision": averate_precision, "recall": average_recall, "f1": average_f1, "accuracy": average_accuracy, "support": sum_support})
metrics_data_df = pd.DataFrame(metrics_data)
metrics_data_df = metrics_data_df[["provider_id", "provider_name", "item", "f1", "precision", "recall", "accuracy", "support"]]
provider_metrics_list.append(metrics_data_df)
all_provider_metrics_df = pd.concat(provider_metrics_list)
all_provider_metrics_df.reset_index(drop=True, inplace=True)
# output metrics data to Excel file
print("Output metrics data to Excel file...")
output_folder = r"/data/aus_prospectus/output/metrics_data/"
os.makedirs(output_folder, exist_ok=True)
verify_file_name = os.path.basename(verify_file_path).replace(".xlsx", "")
if is_for_all:
verify_file_name = f"{verify_file_name}_all"
metrics_file_name = f"metrics_{verify_file_name}_{len(document_id_list)}_documents_for_providers.xlsx"
output_file = os.path.join(output_folder, metrics_file_name)
with pd.ExcelWriter(output_file) as writer:
all_provider_metrics_df.to_excel(writer, index=False, sheet_name="metrics_data")
message_data_df.to_excel(writer, index=False, sheet_name="message_data")
2025-03-05 15:57:02 +00:00
def generate_message(message: dict, doc_id: str, sec_id: str, fund_legal_name: str, raw_fund_name: str, datapoint: str):
    """Attach identifying context to a comparison message.

    The dict is modified in place (keys ``data_point``, ``fund_legal_name``,
    ``raw_fund_name``, ``sec_id`` and ``doc_id``, in that order) and the same
    object is returned for convenience.

    :param message: message dict to annotate (mutated).
    :param doc_id: document identifier; always stored as ``str``.
    :param sec_id: share-class / security identifier.
    :param fund_legal_name: fund legal name.
    :param raw_fund_name: fund name as extracted.
    :param datapoint: name of the data point the message refers to.
    :return: ``message`` itself.
    """
    message.update(
        data_point=datapoint,
        fund_legal_name=fund_legal_name,
        raw_fund_name=raw_fund_name,
        sec_id=sec_id,
        doc_id=str(doc_id),
    )
    return message
def get_gt_pred_by_compare_values(gt_value, pred_value, gt_list, pred_list, data_point: str = ""):
    """Compare one ground-truth value with one prediction and record 0/1 labels.

    Labels are appended to ``gt_list`` / ``pred_list`` (both mutated) so that
    they can later be fed to the sklearn metric functions:

    * gt present and equal to pred (per ``is_equal``): gt 1, pred 1.
    * gt present but pred differs: gt 1, pred 0 (a miss); if pred is itself
      non-empty an extra gt 0, pred 1 pair is also recorded (a false positive).
    * gt empty, pred non-empty: gt 0, pred 1.
    * both empty: nothing is recorded.

    "Present" means not ``None`` and with a non-empty ``str()`` form.

    :return: dict with ``gt_value``, ``pred_value`` and an ``error`` string
        (empty when there is nothing to report).
    """
    def _has_value(value) -> bool:
        return value is not None and len(str(value)) > 0

    message = {"gt_value": gt_value, "pred_value": pred_value, "error": ""}

    if not _has_value(gt_value):
        if _has_value(pred_value):
            gt_list.append(0)
            pred_list.append(1)
            message["error"] = "gt_value is empty, but pred_value is not empty"
        return message

    gt_list.append(1)
    if is_equal(gt_value, pred_value, data_point):
        pred_list.append(1)
        return message

    pred_list.append(0)
    message["error"] = "pred_value is not equal to gt_value"
    if _has_value(pred_value):
        # The wrong prediction also counts as a false positive.
        pred_list.append(1)
        gt_list.append(0)
    return message
def is_equal(gt_value, pred_value, data_point: str = ""):
    """Return True when ``pred_value`` should be counted as matching ``gt_value``.

    Both values must be present (not ``None`` and with a non-empty ``str()``
    form); otherwise the result is False. Exact ``==`` equality always
    matches. For ``data_point == "benchmark_name"`` a fuzzy comparison is
    also applied. Both values are normalised with ``clean_text``, stripped and
    lower-cased, and they match if:

    * they are equal, or
    * one contains the other, or
    * the word-level Jaccard similarity from ``Similarity`` exceeds 0.8.

    Any other data point only matches on exact equality.
    """
    if gt_value is None or len(str(gt_value)) == 0 or \
            pred_value is None or len(str(pred_value)) == 0:
        return False
    if gt_value == pred_value:
        return True
    if data_point != "benchmark_name":
        return False

    # str() keeps non-string cells (e.g. numbers) from breaking the regex in
    # clean_text. strip() removes spaces left where edge punctuation was
    # replaced, and lower() keeps the equality/substring checks consistent
    # with the case-insensitive Jaccard fallback.
    gt_text = clean_text(str(gt_value)).strip().lower()
    pred_text = clean_text(str(pred_value)).strip().lower()
    # A punctuation-only value cleans to "", and "" is a substring of
    # everything, so it must never be treated as a match.
    if not gt_text or not pred_text:
        return False
    if gt_text == pred_text or gt_text in pred_text or pred_text in gt_text:
        return True
    similarity = Similarity()
    # NOTE(review): assumes jaccard_similarity returns an index in [0, 1]
    # over the two token lists -- confirm in utils.similarity.
    jacard_score = similarity.jaccard_similarity(gt_text.split(), pred_text.split())
    return jacard_score > 0.8
def clean_text(text: str):
    """Normalise free text for fuzzy comparison.

    Every run of non-word characters (punctuation and whitespace alike) is
    replaced with a single space. Leading and trailing whitespace is then
    stripped, so "(MSCI World.)" becomes "MSCI World".

    ``None`` and ``""`` are returned unchanged. Non-string input (e.g. a number
    or NaN read from a spreadsheet) is converted with ``str()`` rather than
    raising ``TypeError``.
    """
    if text is None:
        return text
    if not isinstance(text, str):
        text = str(text)
    if len(text) == 0:
        return text
    # \W already covers whitespace, so one pass both replaces punctuation and
    # collapses whitespace runs (same result as substituting \W then \s+).
    text = re.sub(r"\W+", " ", text)
    # Without strip(), edge punctuation leaves stray spaces that defeat the
    # equality/substring checks performed on cleaned text.
    return text.strip()
if __name__ == "__main__":
    # Ad-hoc driver: compute per-provider metrics for the 46-document ground
    # truth against a merged mapping output, once per entry in
    # verify_document_list_file_list. The None entry presumably means "no
    # document filter" -- confirm in calculate_metrics_by_provider.
    # The commented-out calls are alternative experiments kept for manual
    # toggling.
    # adjust_column_order()
    # set_mapping_to_data_side_documents_data()
    # source_file = r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/audited_file_phase2_with_mapping.xlsx"
    # target_file = r"/data/aus_prospectus/ground_truth/phase2_file/46_documents/46_documents_ground_truth_with_mapping.xlsx"
    # adjust_data_file(source_file=source_file, targe_file=target_file)
    # audit_file_path: str = r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/audited_file_phase2_with_mapping.xlsx"
    # audit_data_sheet: str = "Sheet1"
    # verify_file_path: str = r"/data/aus_prospectus/output/mapping_data/total/merged/merged_mapping_data_info_17_documents_by_text_20250303171140.xlsx"
    # verify_data_sheet: str = "total_data"
    audit_file_path: str = r"/data/aus_prospectus/ground_truth/phase2_file/46_documents/46_documents_ground_truth_with_mapping.xlsx"
    audit_data_sheet: str = "Sheet1"
    verify_file_path: str = r"/data/aus_prospectus/output/mapping_data/total/merged/merged_mapping_data_info_46_documents_by_text_20250306171226.xlsx"
    verify_data_sheet: str = "total_data"
    # verify_document_list_file: str = "./sample_documents/aus_prospectus_29_documents_sample.txt"
    verify_document_list_file_list = [None, "./sample_documents/aus_prospectus_29_documents_sample.txt", "./sample_documents/aus_prospectus_17_documents_sample.txt"]
    # False restricts the per-provider output to the five core data points.
    is_for_all = False
    # for verify_document_list_file in verify_document_list_file_list:
    #     calculate_metrics_based_db_data_file(audit_file_path=audit_file_path,
    #                                          audit_data_sheet=audit_data_sheet,
    #                                          verify_file_path=verify_file_path,
    #                                          verify_data_sheet=verify_data_sheet,
    #                                          verify_document_list_file = verify_document_list_file,
    #                                          is_for_all=is_for_all)
    for verify_document_list_file in verify_document_list_file_list:
        calculate_metrics_by_provider(audit_file_path=audit_file_path,
                                      audit_data_sheet=audit_data_sheet,
                                      verify_file_path=verify_file_path,
                                      verify_data_sheet=verify_data_sheet,
                                      verify_document_list_file=verify_document_list_file,
                                      is_for_all=is_for_all)

    # set_mapping_to_17_documents_data()
    # set_mapping_to_ravi_data()
    # calculate_metrics_based_audit_file(is_strict=True)
    # calculate_metrics_based_audit_file(is_strict=False)
    # remove_ter_ogc_performance_fee_annotation()
    # batch_run_documents()
    # transform_pdf_2_image()
    # ground_truth_file = "./test_metrics/ground_truth.xlsx"
    # prediction_file = "./test_metrics/prediction.xlsx"
    # calc_metrics(ground_truth_file, prediction_file)
    # pdf_file = r"./data/emea_ar/pdf/532438210.pdf"
    # page_list = [25, 26, 27, 28, 29]
    # output_folder = r"./data/emea_ar/output/pdf_part/"
    # output_part_of_pages(pdf_file, page_list, output_folder)