dc-ml-emea-ar/calc_metrics.py


import os
from time import sleep
import pandas as pd
from glob import glob
from tqdm import tqdm
import numpy as np
from datetime import datetime
import re
import json
import traceback
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import requests
import fitz
from copy import deepcopy
from utils.similarity import Similarity
from core.auz_nz.hybrid_solution_script import final_function_to_match
def calc_metrics(ground_truth_file: str, prediction_file: str):
"""
Calculate metrics by comparing ground truth and prediction files
"""
if not os.path.exists(ground_truth_file):
raise FileNotFoundError(f"File not found: {ground_truth_file}")
if not os.path.exists(prediction_file):
raise FileNotFoundError(f"File not found: {prediction_file}")
ground_truth_df = pd.read_excel(ground_truth_file)
prediction_df = pd.read_excel(prediction_file)
gt_auum_list = []
pred_auum_list = []
gt_tor_list = []
pred_tor_list = []
columns = ["fund_name", "auum", "tor"]
    # Recall pass: for each ground-truth row, check whether the prediction file
    # contains the same fund with matching values.
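    # Label encoding: every ground-truth row contributes a 1 to the gt list; a
    # matching prediction appends 1 to the pred list (true positive), otherwise 0
    # (false negative).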
for gt_index, gt_row in ground_truth_df.iterrows():
gt_fund_name = gt_row["fund_name"]
gt_auum = gt_row["auum"]
gt_tor = gt_row["tor"]
find_auum_flag = False
find_tor_flag = False
for pred_index, pred_row in prediction_df.iterrows():
pred_fund_name = pred_row["fund_name"]
pred_auum = pred_row["auum"]
pred_tor = pred_row["tor"]
if gt_fund_name == pred_fund_name:
if gt_auum == pred_auum:
find_auum_flag = True
if gt_tor == pred_tor:
find_tor_flag = True
break
if find_auum_flag:
gt_auum_list.append(1)
pred_auum_list.append(1)
else:
gt_auum_list.append(1)
pred_auum_list.append(0)
if find_tor_flag:
gt_tor_list.append(1)
pred_tor_list.append(1)
else:
gt_tor_list.append(1)
pred_tor_list.append(0)
    # Precision pass: for each prediction row, check whether the ground-truth file
    # contains the same fund with matching values.
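    # Prediction rows with no ground-truth match contribute gt=0 / pred=1 (false
    # positives); matched rows were already counted in the recall pass above.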
for pred_index, pred_row in prediction_df.iterrows():
pred_fund_name = pred_row["fund_name"]
pred_auum = pred_row["auum"]
pred_tor = pred_row["tor"]
find_auum_flag = False
find_tor_flag = False
for gt_index, gt_row in ground_truth_df.iterrows():
gt_fund_name = gt_row["fund_name"]
gt_auum = gt_row["auum"]
gt_tor = gt_row["tor"]
if pred_fund_name == gt_fund_name:
if pred_auum == gt_auum:
find_auum_flag = True
if pred_tor == gt_tor:
find_tor_flag = True
break
if not find_auum_flag:
gt_auum_list.append(0)
pred_auum_list.append(1)
if not find_tor_flag:
gt_tor_list.append(0)
pred_tor_list.append(1)
precision_auum = precision_score(gt_auum_list, pred_auum_list)
recall_auum = recall_score(gt_auum_list, pred_auum_list)
f1_auum = f1_score(gt_auum_list, pred_auum_list)
accuracy_auum = accuracy_score(gt_auum_list, pred_auum_list)
precision_tor = precision_score(gt_tor_list, pred_tor_list)
recall_tor = recall_score(gt_tor_list, pred_tor_list)
f1_tor = f1_score(gt_tor_list, pred_tor_list)
accuracy_tor = accuracy_score(gt_tor_list, pred_tor_list)
print(f"AUUM Support: {sum(gt_auum_list)}")
print(f"F1 AUUM: {f1_auum}")
print(f"Precision AUUM: {precision_auum}")
print(f"Recall AUUM: {recall_auum}")
print(f"Accuracy AUUM: {accuracy_auum}\n")
print(f"TOR Support: {sum(gt_tor_list)}")
print(f"F1 TOR: {f1_tor}")
print(f"Precision TOR: {precision_tor}")
print(f"Recall TOR: {recall_tor}")
print(f"Accuracy TOR: {accuracy_tor}")
def transform_pdf_2_image():
"""
Transform pdf to image.
"""
import fitz
folder = r"/Users/bhe/OneDrive - MORNINGSTAR INC/Personal Document/US_Life/pay/"
pdf_file = r"Pay_Date_2025-02-14.pdf"
pdf_path = os.path.join(folder, pdf_file)
pdf_doc = fitz.open(pdf_path)
pdf_file_pure_name = pdf_file.replace(".pdf", "")
for page_num in range(pdf_doc.page_count):
page = pdf_doc.load_page(page_num)
image = page.get_pixmap(dpi=300)
image_path = os.path.join(folder, f"{pdf_file_pure_name}_{page_num}.png")
image.save(image_path)
def invoke_api_demo(doc_id: str = "407881493"):
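    """
    Invoke the emea_ar extraction API for a single document and save the JSON
    response under /data/emea_ar/output/extract_data_by_api/.
    """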
headers = {"connection": "keep-alive", "content-type": "application/json"}
data = {
"doc_id": doc_id,
}
print(f"Start to invoke API for document: {doc_id}")
# url = 'https://internal-ts00006-stg-dcms-gpt-765982576.us-east-1.elb.amazonaws.com/automation/api/model/us_ar'
url = "http://127.0.0.1:8080/automation/api/model/emea_ar"
try:
response = requests.post(url, json=data, headers=headers)
print("API response status code: {0}".format(response.status_code))
json_data = json.loads(response.text)
print(json_data)
data_folder = r"/data/emea_ar/output/extract_data_by_api/"
os.makedirs(data_folder, exist_ok=True)
json_file = os.path.join(data_folder, f"{doc_id}.json")
with open(json_file, "w", encoding="utf-8") as f:
json.dump(json_data, f, indent=4)
except Exception as e:
print("Meet exception: {0}".format(e))
def batch_run_documents():
document_id_list = [
"292989214",
"316237292",
"321733631",
"323390570",
"327956364",
"333207452",
"334718372",
"344636875",
"362246081",
"366179419",
"380945052",
"382366116",
"387202452",
"389171486",
"391456740",
"391736837",
"394778487",
"401684600",
"402113224",
"402181770",
"402397014",
"405803396",
"445102363",
"445256897",
"448265376",
"449555622",
"449623976",
"458291624",
"458359181",
"463081566",
"469138353",
"471641628",
"476492237",
"478585901",
"478586066",
"479042264",
"479793787",
"481475385",
"483617247",
"486378555",
"486383912",
"492121213",
"497497599",
"502693599",
"502821436",
"503194284",
"506559375",
"507967525",
"508854243",
"509845549",
"520879048",
"529925114",
]
for doc_id in document_id_list:
invoke_api_demo(doc_id)
def remove_ter_ogc_performance_fee_annotation():
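    """
    Remove annotations whose data_point is ter, ogc or performance_fee from every
    extracted JSON file and save the files in place.
    """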
data_folder = r"/data/emea_ar/output/extract_data_by_api/"
os.makedirs(data_folder, exist_ok=True)
    # get all JSON files from the folder
json_files = glob(os.path.join(data_folder, "*.json"))
remove_dp_list = ["ter", "ogc", "performance_fee"]
for json_file in json_files:
with open(json_file, "r", encoding="utf-8") as f:
json_data = json.load(f)
annotation_data_list = json_data["annotation_data"]
remove_data_list = []
for annotation_data in annotation_data_list:
if annotation_data["data_point"] in remove_dp_list:
remove_data_list.append(annotation_data)
if len(remove_data_list) > 0:
for remove_data in remove_data_list:
if remove_data in annotation_data_list:
annotation_data_list.remove(remove_data)
with open(json_file, "w", encoding="utf-8") as f:
json.dump(json_data, f, indent=4)
def output_part_of_pages(pdf_file: str, page_list: list, output_folder: str):
"""
Output part of pages from a pdf file to new pdf file.
:param pdf_file: str, the path of the pdf file.
:param page_list: list, the page number list.
:param output_folder: str, the output folder.
"""
pdf_doc = fitz.open(pdf_file)
pdf_file_pure_name = os.path.basename(pdf_file).replace(".pdf", "")
new_pdf = fitz.open()
print(f"output pages: {page_list} for {pdf_file_pure_name}")
for page_index in page_list:
new_pdf.insert_pdf(pdf_doc, from_page=page_index, to_page=page_index)
if output_folder is None or len(output_folder) == 0:
output_folder = r"./data/emea_ar/output/pdf_part/"
os.makedirs(output_folder, exist_ok=True)
new_pdf.save(os.path.join(output_folder, f"{pdf_file_pure_name}_part.pdf"))
def calculate_metrics_based_audit_file(is_strict: bool = False):
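    """
    Compare the audited ground-truth workbook with the extracted (verify) workbook
    fund by fund (exact or Jaccard-similar fund names) and write per-data-point
    precision/recall/F1/accuracy to an Excel file. When is_strict is True,
    extracted funds with no audited counterpart are also counted as false positives.
    """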
print("Start to calculate metrics based on audit file and verify file...")
audit_file_path = (
r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/Audited file_phase2.xlsx"
)
audit_data_sheets = ["Mayank - revised ", "Prathamesh - Revised"]
audit_fields = [
"doc_id",
"fund_name",
"management_fee_and_costs",
"management_fee",
"performance_fee",
"performance_fee_costs",
"buy_spread",
"sell_spread",
"minimum_initial_investment",
"recoverable_expenses",
"indirect_costs"
]
audit_data_list = []
for audit_data_sheet in audit_data_sheets:
sub_audit_data_df = pd.read_excel(audit_file_path, sheet_name=audit_data_sheet)
sub_audit_data_df = sub_audit_data_df[audit_fields]
audit_data_list.append(sub_audit_data_df)
audit_data_df = pd.concat(audit_data_list, ignore_index=True)
audit_data_df = audit_data_df.drop_duplicates()
audit_data_df.fillna("", inplace=True)
audit_data_df.reset_index(drop=True, inplace=True)
verify_file_path = r"/data/aus_prospectus/output/mapping_data/total/merged/merged_mapping_data_info_17_documents_by_text_20250205134704.xlsx"
verify_data_sheet = "total_data"
verify_fields = [
"DocumentId",
"raw_fund_name",
"management_fee_and_costs",
"management_fee",
"performance_fee",
"performance_fee_costs",
"buy_spread",
"sell_spread",
"minimum_initial_investment",
"recoverable_expenses",
"indirect_costs"
]
verify_data_df = pd.read_excel(verify_file_path, sheet_name=verify_data_sheet)
verify_data_df = verify_data_df[verify_fields]
verify_data_df = verify_data_df.drop_duplicates()
verify_data_df = verify_data_df.rename(columns={"DocumentId": "doc_id", "raw_fund_name": "fund_name"})
verify_data_df.fillna("", inplace=True)
verify_data_df.reset_index(drop=True, inplace=True)
if len(audit_data_df) == 0 or len(verify_data_df) == 0:
print("No data to calculate metrics.")
return
# Calculate metrics
gt_management_fee_and_costs_list = []
pred_management_fee_and_costs_list = []
gt_management_fee_list = []
pred_management_fee_list = []
gt_performance_fee_list = []
pred_performance_fee_list = []
gt_performance_fee_costs_list = []
pred_performance_fee_costs_list = []
gt_buy_spread_list = []
pred_buy_spread_list = []
gt_sell_spread_list = []
pred_sell_spread_list = []
gt_minimum_initial_investment_list = []
pred_minimum_initial_investment_list = []
gt_recoverable_expenses_list = []
pred_recoverable_expenses_list = []
gt_indirect_costs_list = []
pred_indirect_costs_list = []
document_id_list = audit_data_df["doc_id"].unique().tolist()
print(f"Total document count: {len(document_id_list)}")
print("Construct ground truth and prediction data...")
similarity = Similarity()
for document_id in document_id_list:
doc_audit_data = audit_data_df[audit_data_df["doc_id"] == document_id]
doc_verify_data = verify_data_df[verify_data_df["doc_id"] == document_id]
for index, row in doc_audit_data.iterrows():
fund_name = row["fund_name"]
fund_name_split = fund_name.lower().split()
management_fee_and_costs = str(row["management_fee_and_costs"])
management_fee = str(row["management_fee"])
performance_fee = str(row["performance_fee"])
performance_fee_costs = str(row["performance_fee_costs"])
buy_spread = str(row["buy_spread"])
sell_spread = str(row["sell_spread"])
minimum_initial_investment = str(row["minimum_initial_investment"])
recoverable_expenses = str(row["recoverable_expenses"])
indirect_costs = str(row["indirect_costs"])
find_flag = False
for idx, r in doc_verify_data.iterrows():
v_fund_name = r["fund_name"]
if fund_name == v_fund_name:
find_flag = True
else:
v_fund_name_split = v_fund_name.lower().split()
name_similarity = similarity.jaccard_similarity(fund_name_split, v_fund_name_split)
if name_similarity > 0.8:
find_flag = True
if find_flag:
v_management_fee_and_costs = str(r["management_fee_and_costs"])
v_management_fee = str(r["management_fee"])
v_performance_fee = str(r["performance_fee"])
v_performance_fee_costs = str(r["performance_fee_costs"])
v_buy_spread = str(r["buy_spread"])
v_sell_spread = str(r["sell_spread"])
v_minimum_initial_investment = str(r["minimum_initial_investment"])
v_recoverable_expenses = str(r["recoverable_expenses"])
v_indirect_costs = str(r["indirect_costs"])
get_gt_pred_by_compare_values(management_fee_and_costs, v_management_fee_and_costs, gt_management_fee_and_costs_list, pred_management_fee_and_costs_list)
get_gt_pred_by_compare_values(management_fee, v_management_fee, gt_management_fee_list, pred_management_fee_list)
get_gt_pred_by_compare_values(performance_fee, v_performance_fee, gt_performance_fee_list, pred_performance_fee_list)
get_gt_pred_by_compare_values(performance_fee_costs, v_performance_fee_costs, gt_performance_fee_costs_list, pred_performance_fee_costs_list)
get_gt_pred_by_compare_values(buy_spread, v_buy_spread, gt_buy_spread_list, pred_buy_spread_list)
get_gt_pred_by_compare_values(sell_spread, v_sell_spread, gt_sell_spread_list, pred_sell_spread_list)
get_gt_pred_by_compare_values(minimum_initial_investment, v_minimum_initial_investment, gt_minimum_initial_investment_list, pred_minimum_initial_investment_list)
get_gt_pred_by_compare_values(recoverable_expenses, v_recoverable_expenses, gt_recoverable_expenses_list, pred_recoverable_expenses_list)
get_gt_pred_by_compare_values(indirect_costs, v_indirect_costs, gt_indirect_costs_list, pred_indirect_costs_list)
break
if not find_flag:
if management_fee_and_costs is not None and len(management_fee_and_costs) > 0:
gt_management_fee_and_costs_list.append(1)
pred_management_fee_and_costs_list.append(0)
if management_fee is not None and len(management_fee) > 0:
gt_management_fee_list.append(1)
pred_management_fee_list.append(0)
if performance_fee is not None and len(performance_fee) > 0:
gt_performance_fee_list.append(1)
pred_performance_fee_list.append(0)
if performance_fee_costs is not None and len(performance_fee_costs) > 0:
gt_performance_fee_costs_list.append(1)
pred_performance_fee_costs_list.append(0)
if buy_spread is not None and len(buy_spread) > 0:
gt_buy_spread_list.append(1)
pred_buy_spread_list.append(0)
if sell_spread is not None and len(sell_spread) > 0:
gt_sell_spread_list.append(1)
pred_sell_spread_list.append(0)
if minimum_initial_investment is not None and len(minimum_initial_investment) > 0:
gt_minimum_initial_investment_list.append(1)
pred_minimum_initial_investment_list.append(0)
if recoverable_expenses is not None and len(recoverable_expenses) > 0:
gt_recoverable_expenses_list.append(1)
pred_recoverable_expenses_list.append(0)
if indirect_costs is not None and len(indirect_costs) > 0:
gt_indirect_costs_list.append(1)
pred_indirect_costs_list.append(0)
if is_strict:
for idx, r in doc_verify_data.iterrows():
v_fund_name = r["fund_name"]
find_flag = False
for index, row in doc_audit_data.iterrows():
fund_name = row["fund_name"]
if fund_name == v_fund_name:
find_flag = True
else:
v_fund_name_split = v_fund_name.lower().split()
fund_name_split = fund_name.lower().split()
name_similarity = similarity.jaccard_similarity(fund_name_split, v_fund_name_split)
if name_similarity > 0.8:
find_flag = True
if find_flag:
break
if not find_flag:
v_management_fee_and_costs = str(r["management_fee_and_costs"])
v_management_fee = str(r["management_fee"])
v_performance_fee = str(r["performance_fee"])
v_performance_fee_costs = str(r["performance_fee_costs"])
v_buy_spread = str(r["buy_spread"])
v_sell_spread = str(r["sell_spread"])
v_minimum_initial_investment = str(r["minimum_initial_investment"])
v_recoverable_expenses = str(r["recoverable_expenses"])
v_indirect_costs = str(r["indirect_costs"])
if v_management_fee_and_costs is not None and len(v_management_fee_and_costs) > 0:
gt_management_fee_and_costs_list.append(0)
pred_management_fee_and_costs_list.append(1)
if v_management_fee is not None and len(v_management_fee) > 0:
gt_management_fee_list.append(0)
pred_management_fee_list.append(1)
if v_performance_fee is not None and len(v_performance_fee) > 0:
gt_performance_fee_list.append(0)
pred_performance_fee_list.append(1)
if v_performance_fee_costs is not None and len(v_performance_fee_costs) > 0:
gt_performance_fee_costs_list.append(0)
pred_performance_fee_costs_list.append(1)
if v_buy_spread is not None and len(v_buy_spread) > 0:
gt_buy_spread_list.append(0)
pred_buy_spread_list.append(1)
if v_sell_spread is not None and len(v_sell_spread) > 0:
gt_sell_spread_list.append(0)
pred_sell_spread_list.append(1)
if v_minimum_initial_investment is not None and len(v_minimum_initial_investment) > 0:
gt_minimum_initial_investment_list.append(0)
pred_minimum_initial_investment_list.append(1)
if v_recoverable_expenses is not None and len(v_recoverable_expenses) > 0:
gt_recoverable_expenses_list.append(0)
pred_recoverable_expenses_list.append(1)
if v_indirect_costs is not None and len(v_indirect_costs) > 0:
gt_indirect_costs_list.append(0)
pred_indirect_costs_list.append(1)
# calculate metrics
print("Calculate metrics...")
precision_management_fee_and_costs = precision_score(gt_management_fee_and_costs_list, pred_management_fee_and_costs_list)
recall_management_fee_and_costs = recall_score(gt_management_fee_and_costs_list, pred_management_fee_and_costs_list)
f1_management_fee_and_costs = f1_score(gt_management_fee_and_costs_list, pred_management_fee_and_costs_list)
accuracy_management_fee_and_costs = accuracy_score(gt_management_fee_and_costs_list, pred_management_fee_and_costs_list)
support_management_fee_and_costs = sum(gt_management_fee_and_costs_list)
precision_management_fee = precision_score(gt_management_fee_list, pred_management_fee_list)
recall_management_fee = recall_score(gt_management_fee_list, pred_management_fee_list)
f1_management_fee = f1_score(gt_management_fee_list, pred_management_fee_list)
accuracy_management_fee = accuracy_score(gt_management_fee_list, pred_management_fee_list)
support_management_fee = sum(gt_management_fee_list)
precision_performance_fee = precision_score(gt_performance_fee_list, pred_performance_fee_list)
recall_performance_fee = recall_score(gt_performance_fee_list, pred_performance_fee_list)
f1_performance_fee = f1_score(gt_performance_fee_list, pred_performance_fee_list)
accuracy_performance_fee = accuracy_score(gt_performance_fee_list, pred_performance_fee_list)
support_performance_fee = sum(gt_performance_fee_list)
precision_performance_fee_costs = precision_score(gt_performance_fee_costs_list, pred_performance_fee_costs_list)
recall_performance_fee_costs = recall_score(gt_performance_fee_costs_list, pred_performance_fee_costs_list)
f1_performance_fee_costs = f1_score(gt_performance_fee_costs_list, pred_performance_fee_costs_list)
accuracy_performance_fee_costs = accuracy_score(gt_performance_fee_costs_list, pred_performance_fee_costs_list)
support_performance_fee_costs = sum(gt_performance_fee_costs_list)
precision_buy_spread = precision_score(gt_buy_spread_list, pred_buy_spread_list)
recall_buy_spread = recall_score(gt_buy_spread_list, pred_buy_spread_list)
f1_buy_spread = f1_score(gt_buy_spread_list, pred_buy_spread_list)
accuracy_buy_spread = accuracy_score(gt_buy_spread_list, pred_buy_spread_list)
support_buy_spread = sum(gt_buy_spread_list)
precision_sell_spread = precision_score(gt_sell_spread_list, pred_sell_spread_list)
recall_sell_spread = recall_score(gt_sell_spread_list, pred_sell_spread_list)
f1_sell_spread = f1_score(gt_sell_spread_list, pred_sell_spread_list)
accuracy_sell_spread = accuracy_score(gt_sell_spread_list, pred_sell_spread_list)
    support_sell_spread = sum(gt_sell_spread_list)
precision_minimum_initial_investment = precision_score(gt_minimum_initial_investment_list, pred_minimum_initial_investment_list)
recall_minimum_initial_investment = recall_score(gt_minimum_initial_investment_list, pred_minimum_initial_investment_list)
f1_minimum_initial_investment = f1_score(gt_minimum_initial_investment_list, pred_minimum_initial_investment_list)
accuracy_minimum_initial_investment = accuracy_score(gt_minimum_initial_investment_list, pred_minimum_initial_investment_list)
support_minimum_initial_investment = sum(gt_minimum_initial_investment_list)
precision_recoverable_expenses = precision_score(gt_recoverable_expenses_list, pred_recoverable_expenses_list)
recall_recoverable_expenses = recall_score(gt_recoverable_expenses_list, pred_recoverable_expenses_list)
f1_recoverable_expenses = f1_score(gt_recoverable_expenses_list, pred_recoverable_expenses_list)
accuracy_recoverable_expenses = accuracy_score(gt_recoverable_expenses_list, pred_recoverable_expenses_list)
support_recoverable_expenses = sum(gt_recoverable_expenses_list)
precision_indirect_costs = precision_score(gt_indirect_costs_list, pred_indirect_costs_list)
recall_indirect_costs = recall_score(gt_indirect_costs_list, pred_indirect_costs_list)
f1_indirect_costs = f1_score(gt_indirect_costs_list, pred_indirect_costs_list)
accuracy_indirect_costs = accuracy_score(gt_indirect_costs_list, pred_indirect_costs_list)
support_indirect_costs = sum(gt_indirect_costs_list)
metrics_data = [{"item": "management_fee_and_costs", "precision": precision_management_fee_and_costs, "recall": recall_management_fee_and_costs, "f1": f1_management_fee_and_costs, "accuracy": accuracy_management_fee_and_costs, "support": support_management_fee_and_costs},
{"item": "management_fee", "precision": precision_management_fee, "recall": recall_management_fee, "f1": f1_management_fee, "accuracy": accuracy_management_fee, "support": support_management_fee},
{"item": "performance_fee", "precision": precision_performance_fee, "recall": recall_performance_fee, "f1": f1_performance_fee, "accuracy": accuracy_performance_fee, "support": support_performance_fee},
{"item": "performance_fee_costs", "precision": precision_performance_fee_costs, "recall": recall_performance_fee_costs, "f1": f1_performance_fee_costs, "accuracy": accuracy_performance_fee_costs, "support": support_performance_fee_costs},
{"item": "buy_spread", "precision": precision_buy_spread, "recall": recall_buy_spread, "f1": f1_buy_spread, "accuracy": accuracy_buy_spread, "support": support_buy_spread},
{"item": "sell_spread", "precision": precision_sell_spread, "recall": recall_sell_spread, "f1": f1_sell_spread, "accuracy": accuracy_sell_spread, "support": support_buy_spread},
{"item": "minimum_initial_investment", "precision": precision_minimum_initial_investment, "recall": recall_minimum_initial_investment, "f1": f1_minimum_initial_investment, "accuracy": accuracy_minimum_initial_investment, "support": support_minimum_initial_investment},
{"item": "recoverable_expenses", "precision": precision_recoverable_expenses, "recall": recall_recoverable_expenses, "f1": f1_recoverable_expenses, "accuracy": accuracy_recoverable_expenses, "support": support_recoverable_expenses},
{"item": "indirect_costs", "precision": precision_indirect_costs, "recall": recall_indirect_costs, "f1": f1_indirect_costs, "accuracy": accuracy_indirect_costs, "support": support_indirect_costs}]
metrics_data_df = pd.DataFrame(metrics_data)
    average_precision = metrics_data_df["precision"].mean()
average_recall = metrics_data_df["recall"].mean()
average_f1 = metrics_data_df["f1"].mean()
average_accuracy = metrics_data_df["accuracy"].mean()
sum_support = metrics_data_df["support"].sum()
    metrics_data.append({"item": "average_score", "precision": average_precision, "recall": average_recall, "f1": average_f1, "accuracy": average_accuracy, "support": sum_support})
metrics_data_df = pd.DataFrame(metrics_data)
metrics_data_df = metrics_data_df[['item', 'f1', 'precision', 'recall', 'accuracy', 'support']]
# output metrics data to Excel file
print("Output metrics data to Excel file...")
output_folder = r"/data/aus_prospectus/output/metrics_data/"
os.makedirs(output_folder, exist_ok=True)
verify_file_name = os.path.basename(verify_file_path).replace(".xlsx", "")
if is_strict:
verify_file_name = f"metrics_{verify_file_name}_revised_strict.xlsx"
else:
verify_file_name = f"metrics_{verify_file_name}_revised_not_strict.xlsx"
output_file = os.path.join(output_folder, verify_file_name)
with pd.ExcelWriter(output_file) as writer:
metrics_data_df.to_excel(writer, index=False)
def calculate_metrics_based_db_data_file(audit_file_path: str = r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/audited_file_phase2_with_mapping.xlsx",
audit_data_sheet: str = "Sheet1",
verify_file_path: str = r"/data/aus_prospectus/output/mapping_data/total/merged/merged_mapping_data_info_17_documents_by_text_20250303171140.xlsx",
verify_data_sheet: str = "total_data",
verify_document_list_file: str = None,
is_for_all: bool = False
):
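    """
    Compare DB-mapped ground-truth data with extracted data, matched by doc_id and
    sec_id, and write per-data-point metrics plus a row-level message sheet to an
    Excel file. is_for_all additionally scores performance fee, interposed vehicle
    performance fee cost, buy/sell spread and total annual dollar based charges;
    verify_document_list_file optionally restricts scoring to the listed document ids.
    """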
print("Start to calculate metrics based on DB data file and extracted file...")
audit_data_df = pd.DataFrame()
verify_data_df = pd.DataFrame()
audit_fields = [
"DocumentId",
"FundLegalName",
"FundId",
"FundClassLegalName",
"FundClassId",
"management_fee_and_costs",
"management_fee",
"administration_fees",
"minimum_initial_investment",
"benchmark_name",
"performance_fee",
"interposed_vehicle_performance_fee_cost",
"buy_spread",
"sell_spread",
"total_annual_dollar_based_charges"
# "withdrawal_fee",
# "switching_fee",
# "activity_fee",
]
audit_data_df = pd.read_excel(audit_file_path, sheet_name=audit_data_sheet)
audit_data_df = audit_data_df[audit_fields]
audit_data_df = audit_data_df.drop_duplicates()
audit_data_df = audit_data_df.rename(columns={"DocumentId": "doc_id",
"FundLegalName": "fund_name",
"FundId": "fund_id",
"FundClassLegalName": "sec_name",
"FundClassId": "sec_id"})
audit_data_df.fillna("", inplace=True)
audit_data_df.reset_index(drop=True, inplace=True)
# verify_file_path = r"/data/aus_prospectus/output/mapping_data/total/merged/merged_mapping_data_info_17_documents_by_text_20250205134704.xlsx"
# ravi_verify_file_path = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees_with_mapping.xlsx"
# verify_file_path = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees_with_mapping.xlsx"
verify_fields = [
"DocumentId",
"raw_fund_name",
"fund_id",
"fund_name",
"raw_share_name",
"sec_id",
"sec_name",
"management_fee_and_costs",
"management_fee",
"administration_fees",
"minimum_initial_investment",
"benchmark_name",
"performance_fee",
"interposed_vehicle_performance_fee_cost",
"buy_spread",
"sell_spread",
"total_annual_dollar_based_charges"
# "withdrawal_fee",
# "switching_fee",
# "activity_fee"
]
verify_data_df = pd.read_excel(verify_file_path, sheet_name=verify_data_sheet)
# ravi_verify_data_df = pd.read_excel(ravi_verify_file_path, sheet_name=verify_data_sheet)
# only get raw_verify_data_df data which sec_id is equal with sec_id in ravi_verify_data_df
# verify_data_df = raw_verify_data_df[raw_verify_data_df["sec_id"].isin(ravi_verify_data_df["sec_id"])]
verify_data_df = verify_data_df[verify_fields]
verify_data_df = verify_data_df.drop_duplicates()
verify_data_df = verify_data_df.rename(columns={"DocumentId": "doc_id"})
verify_data_df.fillna("", inplace=True)
verify_data_df.reset_index(drop=True, inplace=True)
if len(audit_data_df) == 0 or len(verify_data_df) == 0:
print("No data to calculate metrics.")
return
# Calculate metrics
gt_management_fee_and_costs_list = []
pred_management_fee_and_costs_list = []
gt_management_fee_list = []
pred_management_fee_list = []
gt_administration_fees_list = []
pred_administration_fees_list = []
gt_minimum_initial_investment_list = []
pred_minimum_initial_investment_list = []
gt_benchmark_name_list = []
pred_benchmark_name_list = []
if is_for_all:
gt_performance_fee_list = []
pred_performance_fee_list = []
gt_interposed_vehicle_performance_fee_cost_list = []
pred_interposed_vehicle_performance_fee_cost_list = []
gt_buy_spread_list = []
pred_buy_spread_list = []
gt_sell_spread_list = []
pred_sell_spread_list = []
gt_total_annual_dollar_based_charges_list = []
pred_total_annual_dollar_based_charges_list = []
# gt_performance_fee_costs_list = []
# pred_performance_fee_costs_list = []
# gt_buy_spread_list = []
# pred_buy_spread_list = []
# gt_sell_spread_list = []
# pred_sell_spread_list = []
# gt_withdrawal_fee_list = []
# pred_withdrawal_fee_list = []
# gt_switching_fee_list = []
# pred_switching_fee_list = []
# gt_activity_fee_list = []
# pred_activity_fee_list = []
if verify_document_list_file is not None:
with open(verify_document_list_file, "r", encoding="utf-8") as f:
verify_document_list = f.readlines()
verify_document_list = [int(doc_id.strip()) for doc_id in verify_document_list]
if len(verify_document_list) > 0:
verify_data_df = verify_data_df[verify_data_df["doc_id"].isin(verify_document_list)]
document_id_list = verify_data_df["doc_id"].unique().tolist()
print(f"Total document count: {len(document_id_list)}")
print("Construct ground truth and prediction data...")
# similarity = Similarity()
message_list = []
for document_id in document_id_list:
doc_audit_data = audit_data_df[audit_data_df["doc_id"] == document_id]
audit_sec_id_list = [doc_sec_id for doc_sec_id
in doc_audit_data["sec_id"].unique().tolist()
if len(doc_sec_id) > 0]
        # select verify rows whose doc_id matches document_id and whose sec_id is in audit_sec_id_list
doc_verify_data = verify_data_df[(verify_data_df["doc_id"] == document_id) & (verify_data_df["sec_id"].isin(audit_sec_id_list))]
for index, row in doc_audit_data.iterrows():
fund_name = row["fund_name"]
sec_id = row["sec_id"]
management_fee_and_costs = str(row["management_fee_and_costs"])
management_fee = str(row["management_fee"])
administration_fees = str(row["administration_fees"])
minimum_initial_investment = str(row["minimum_initial_investment"])
benchmark_name = str(row["benchmark_name"])
if is_for_all:
performance_fee = str(row["performance_fee"])
interposed_vehicle_performance_fee_cost = str(row["interposed_vehicle_performance_fee_cost"])
buy_spread = str(row["buy_spread"])
sell_spread = str(row["sell_spread"])
total_annual_dollar_based_charges = str(row["total_annual_dollar_based_charges"])
            # take the first verify row whose sec_id matches the audited sec_id
doc_verify_sec_data = doc_verify_data[doc_verify_data["sec_id"] == sec_id]
if len(doc_verify_sec_data) == 0:
continue
doc_verify_sec_row = doc_verify_sec_data.iloc[0]
raw_fund_name = doc_verify_sec_row["raw_fund_name"]
v_management_fee_and_costs = str(doc_verify_sec_row["management_fee_and_costs"])
v_management_fee = str(doc_verify_sec_row["management_fee"])
v_administration_fees = str(doc_verify_sec_row["administration_fees"])
v_minimum_initial_investment = str(doc_verify_sec_row["minimum_initial_investment"])
v_benchmark_name = str(doc_verify_sec_row["benchmark_name"])
if is_for_all:
v_performance_fee = str(doc_verify_sec_row["performance_fee"])
v_interposed_vehicle_performance_fee_cost = str(doc_verify_sec_row["interposed_vehicle_performance_fee_cost"])
v_buy_spread = str(doc_verify_sec_row["buy_spread"])
v_sell_spread = str(doc_verify_sec_row["sell_spread"])
v_total_annual_dollar_based_charges = str(doc_verify_sec_row["total_annual_dollar_based_charges"])
# v_performance_fee_costs = str(doc_verify_sec_row["performance_fee_costs"])
# v_buy_spread = str(doc_verify_sec_row["buy_spread"])
# v_sell_spread = str(doc_verify_sec_row["sell_spread"])
# v_withdrawal_fee = str(doc_verify_sec_row["withdrawal_fee"])
# v_switching_fee = str(doc_verify_sec_row["switching_fee"])
# v_activity_fee = str(doc_verify_sec_row["activity_fee"])
message = get_gt_pred_by_compare_values(management_fee_and_costs, v_management_fee_and_costs, gt_management_fee_and_costs_list, pred_management_fee_and_costs_list, data_point="management_fee_and_costs")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "management_fee_and_costs"))
message = get_gt_pred_by_compare_values(management_fee, v_management_fee, gt_management_fee_list, pred_management_fee_list, data_point="management_fee")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "management_fee"))
message = get_gt_pred_by_compare_values(administration_fees, v_administration_fees, gt_administration_fees_list, pred_administration_fees_list, data_point="administration_fees")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "administration_fees"))
message = get_gt_pred_by_compare_values(minimum_initial_investment, v_minimum_initial_investment, gt_minimum_initial_investment_list, pred_minimum_initial_investment_list, data_point="minimum_initial_investment")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "minimum_initial_investment"))
message = get_gt_pred_by_compare_values(benchmark_name, v_benchmark_name, gt_benchmark_name_list, pred_benchmark_name_list, data_point="benchmark_name")
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "benchmark_name"))
if is_for_all:
message = get_gt_pred_by_compare_values(performance_fee, v_performance_fee, gt_performance_fee_list, pred_performance_fee_list)
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "performance_fee"))
message = get_gt_pred_by_compare_values(interposed_vehicle_performance_fee_cost, v_interposed_vehicle_performance_fee_cost,
gt_interposed_vehicle_performance_fee_cost_list, pred_interposed_vehicle_performance_fee_cost_list)
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "interposed_vehicle_performance_fee_cost"))
message = get_gt_pred_by_compare_values(buy_spread, v_buy_spread, gt_buy_spread_list, pred_buy_spread_list)
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "buy_spread"))
message = get_gt_pred_by_compare_values(sell_spread, v_sell_spread, gt_sell_spread_list, pred_sell_spread_list)
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "sell_spread"))
message = get_gt_pred_by_compare_values(total_annual_dollar_based_charges, v_total_annual_dollar_based_charges,
gt_total_annual_dollar_based_charges_list, pred_total_annual_dollar_based_charges_list)
message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "total_annual_dollar_based_charges"))
# message = get_gt_pred_by_compare_values(withdrawal_fee, v_withdrawal_fee, gt_withdrawal_fee_list, pred_withdrawal_fee_list)
# message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "withdrawal_fee"))
# message = get_gt_pred_by_compare_values(switching_fee, v_switching_fee, gt_switching_fee_list, pred_switching_fee_list)
# message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "switching_fee"))
# message = get_gt_pred_by_compare_values(activity_fee, v_activity_fee, gt_activity_fee_list, pred_activity_fee_list)
# message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "activity_fee"))
message_data_df = pd.DataFrame(message_list)
message_data_df = message_data_df[['doc_id', 'sec_id', 'raw_fund_name', 'fund_legal_name', 'data_point', 'gt_value', 'pred_value', 'error']]
# order by doc_id, raw_fund_name, data_point
message_data_df = message_data_df.sort_values(by=['doc_id', 'raw_fund_name', 'data_point'])
message_data_df.reset_index(drop=True, inplace=True)
# calculate metrics
print("Calculate metrics...")
precision_management_fee_and_costs = precision_score(gt_management_fee_and_costs_list, pred_management_fee_and_costs_list)
recall_management_fee_and_costs = recall_score(gt_management_fee_and_costs_list, pred_management_fee_and_costs_list)
f1_management_fee_and_costs = f1_score(gt_management_fee_and_costs_list, pred_management_fee_and_costs_list)
accuracy_management_fee_and_costs = accuracy_score(gt_management_fee_and_costs_list, pred_management_fee_and_costs_list)
support_management_fee_and_costs = sum(gt_management_fee_and_costs_list)
precision_management_fee = precision_score(gt_management_fee_list, pred_management_fee_list)
recall_management_fee = recall_score(gt_management_fee_list, pred_management_fee_list)
f1_management_fee = f1_score(gt_management_fee_list, pred_management_fee_list)
accuracy_management_fee = accuracy_score(gt_management_fee_list, pred_management_fee_list)
support_management_fee = sum(gt_management_fee_list)
precision_administration_fees = precision_score(gt_administration_fees_list, pred_administration_fees_list)
recall_administration_fees = recall_score(gt_administration_fees_list, pred_administration_fees_list)
f1_administration_fees = f1_score(gt_administration_fees_list, pred_administration_fees_list)
accuracy_administration_fees = accuracy_score(gt_administration_fees_list, pred_administration_fees_list)
support_administration_fees = sum(gt_administration_fees_list)
    precision_minimum_initial_investment = precision_score(gt_minimum_initial_investment_list, pred_minimum_initial_investment_list)
    recall_minimum_initial_investment = recall_score(gt_minimum_initial_investment_list, pred_minimum_initial_investment_list)
    f1_minimum_initial_investment = f1_score(gt_minimum_initial_investment_list, pred_minimum_initial_investment_list)
    accuracy_minimum_initial_investment = accuracy_score(gt_minimum_initial_investment_list, pred_minimum_initial_investment_list)
    support_minimum_initial_investment = sum(gt_minimum_initial_investment_list)
precision_benchmark_name = precision_score(gt_benchmark_name_list, pred_benchmark_name_list)
recall_benchmark_name = recall_score(gt_benchmark_name_list, pred_benchmark_name_list)
f1_benchmark_name = f1_score(gt_benchmark_name_list, pred_benchmark_name_list)
accuracy_benchmark_name = accuracy_score(gt_benchmark_name_list, pred_benchmark_name_list)
support_benchmark_name = sum(gt_benchmark_name_list)
if is_for_all:
precision_performance_fee = precision_score(gt_performance_fee_list, pred_performance_fee_list)
recall_performance_fee = recall_score(gt_performance_fee_list, pred_performance_fee_list)
f1_performance_fee = f1_score(gt_performance_fee_list, pred_performance_fee_list)
accuracy_performance_fee = accuracy_score(gt_performance_fee_list, pred_performance_fee_list)
support_performance_fee = sum(gt_performance_fee_list)
precision_interposed_vehicle_performance_fee_cost = precision_score(gt_interposed_vehicle_performance_fee_cost_list, pred_interposed_vehicle_performance_fee_cost_list)
recall_interposed_vehicle_performance_fee_cost = recall_score(gt_interposed_vehicle_performance_fee_cost_list, pred_interposed_vehicle_performance_fee_cost_list)
f1_interposed_vehicle_performance_fee_cost = f1_score(gt_interposed_vehicle_performance_fee_cost_list, pred_interposed_vehicle_performance_fee_cost_list)
accuracy_interposed_vehicle_performance_fee_cost = accuracy_score(gt_interposed_vehicle_performance_fee_cost_list, pred_interposed_vehicle_performance_fee_cost_list)
support_interposed_vehicle_performance_fee_cost = sum(gt_interposed_vehicle_performance_fee_cost_list)
precision_buy_spread = precision_score(gt_buy_spread_list, pred_buy_spread_list)
recall_buy_spread = recall_score(gt_buy_spread_list, pred_buy_spread_list)
f1_buy_spread = f1_score(gt_buy_spread_list, pred_buy_spread_list)
accuracy_buy_spread = accuracy_score(gt_buy_spread_list, pred_buy_spread_list)
support_buy_spread = sum(gt_buy_spread_list)
precision_sell_spread = precision_score(gt_sell_spread_list, pred_sell_spread_list)
recall_sell_spread = recall_score(gt_sell_spread_list, pred_sell_spread_list)
f1_sell_spread = f1_score(gt_sell_spread_list, pred_sell_spread_list)
accuracy_sell_spread = accuracy_score(gt_sell_spread_list, pred_sell_spread_list)
        support_sell_spread = sum(gt_sell_spread_list)
precision_total_annual_dollar_based_charges = precision_score(gt_total_annual_dollar_based_charges_list, pred_total_annual_dollar_based_charges_list)
recall_total_annual_dollar_based_charges = recall_score(gt_total_annual_dollar_based_charges_list, pred_total_annual_dollar_based_charges_list)
f1_total_annual_dollar_based_charges = f1_score(gt_total_annual_dollar_based_charges_list, pred_total_annual_dollar_based_charges_list)
accuracy_total_annual_dollar_based_charges = accuracy_score(gt_total_annual_dollar_based_charges_list, pred_total_annual_dollar_based_charges_list)
support_total_annual_dollar_based_charges = sum(gt_total_annual_dollar_based_charges_list)
# precision_withdrawal_fee = precision_score(gt_withdrawal_fee_list, pred_withdrawal_fee_list)
# recall_withdrawal_fee = recall_score(gt_withdrawal_fee_list, pred_withdrawal_fee_list)
# f1_withdrawal_fee = f1_score(gt_withdrawal_fee_list, pred_withdrawal_fee_list)
# accuracy_withdrawal_fee = accuracy_score(gt_withdrawal_fee_list, pred_withdrawal_fee_list)
# support_withdrawal_fee = sum(gt_withdrawal_fee_list)
# precision_switching_fee = precision_score(gt_switching_fee_list, pred_switching_fee_list)
# recall_switching_fee = recall_score(gt_switching_fee_list, pred_switching_fee_list)
# f1_switching_fee = f1_score(gt_switching_fee_list, pred_switching_fee_list)
# accuracy_switching_fee = accuracy_score(gt_switching_fee_list, pred_switching_fee_list)
# support_switching_fee = sum(gt_switching_fee_list)
# precision_activity_fee = precision_score(gt_activity_fee_list, pred_activity_fee_list)
# recall_activity_fee = recall_score(gt_activity_fee_list, pred_activity_fee_list)
# f1_activity_fee = f1_score(gt_activity_fee_list, pred_activity_fee_list)
# accuracy_activity_fee = accuracy_score(gt_activity_fee_list, pred_activity_fee_list)
# support_activity_fee = sum(gt_activity_fee_list)
if is_for_all:
metrics_data = [{"item": "management_fee_and_costs", "precision": precision_management_fee_and_costs, "recall": recall_management_fee_and_costs, "f1": f1_management_fee_and_costs, "accuracy": accuracy_management_fee_and_costs, "support": support_management_fee_and_costs},
{"item": "management_fee", "precision": precision_management_fee, "recall": recall_management_fee, "f1": f1_management_fee, "accuracy": accuracy_management_fee, "support": support_management_fee},
{"item": "administration_fees", "precision": precision_administration_fees, "recall": recall_administration_fees, "f1": f1_administration_fees, "accuracy": accuracy_administration_fees, "support": support_administration_fees},
{"item": "minimum_initial_investment", "precision": precision_miminimum_initial_investment, "recall": recall_miminimum_initial_investment, "f1": f1_miminimum_initial_investment, "accuracy": accuracy_miminimum_initial_investment, "support": support_miminimum_initial_investment},
{"item": "benchmark_name", "precision": precision_benchmark_name, "recall": recall_benchmark_name, "f1": f1_benchmark_name, "accuracy": accuracy_benchmark_name, "support": support_benchmark_name},
{"item": "performance_fee", "precision": precision_performance_fee, "recall": recall_performance_fee, "f1": f1_performance_fee, "accuracy": accuracy_performance_fee, "support": support_performance_fee},
{"item": "interposed_vehicle_performance_fee_cost", "precision": precision_interposed_vehicle_performance_fee_cost, "recall": recall_interposed_vehicle_performance_fee_cost,
"f1": f1_interposed_vehicle_performance_fee_cost, "accuracy": accuracy_interposed_vehicle_performance_fee_cost, "support": support_interposed_vehicle_performance_fee_cost},
{"item": "buy_spread", "precision": precision_buy_spread, "recall": recall_buy_spread, "f1": f1_buy_spread, "accuracy": accuracy_buy_spread, "support": support_buy_spread},
{"item": "sell_spread", "precision": precision_sell_spread, "recall": recall_sell_spread, "f1": f1_sell_spread, "accuracy": accuracy_sell_spread, "support": support_buy_spread},
{"item": "total_annual_dollar_based_charges", "precision": precision_total_annual_dollar_based_charges, "recall": recall_total_annual_dollar_based_charges,
"f1": f1_total_annual_dollar_based_charges, "accuracy": accuracy_total_annual_dollar_based_charges, "support": support_total_annual_dollar_based_charges}
# {"item": "buy_spread", "precision": precision_buy_spread, "recall": recall_buy_spread, "f1": f1_buy_spread, "accuracy": accuracy_buy_spread, "support": support_buy_spread},
# {"item": "sell_spread", "precision": precision_sell_spread, "recall": recall_sell_spread, "f1": f1_sell_spread, "accuracy": accuracy_sell_spread, "support": support_buy_spread},
# {"item": "withdrawal_fee", "precision": precision_withdrawal_fee, "recall": recall_withdrawal_fee, "f1": f1_withdrawal_fee, "accuracy": accuracy_withdrawal_fee, "support": support_withdrawal_fee},
# {"item": "switching_fee", "precision": precision_switching_fee, "recall": recall_switching_fee, "f1": f1_switching_fee, "accuracy": accuracy_switching_fee, "support": support_switching_fee},
# {"item": "activity_fee", "precision": precision_activity_fee, "recall": recall_activity_fee, "f1": f1_activity_fee, "accuracy": accuracy_activity_fee, "support": support_activity_fee}
]
else:
metrics_data = [{"item": "management_fee_and_costs", "precision": precision_management_fee_and_costs, "recall": recall_management_fee_and_costs, "f1": f1_management_fee_and_costs, "accuracy": accuracy_management_fee_and_costs, "support": support_management_fee_and_costs},
{"item": "management_fee", "precision": precision_management_fee, "recall": recall_management_fee, "f1": f1_management_fee, "accuracy": accuracy_management_fee, "support": support_management_fee},
{"item": "administration_fees", "precision": precision_administration_fees, "recall": recall_administration_fees, "f1": f1_administration_fees, "accuracy": accuracy_administration_fees, "support": support_administration_fees},
{"item": "minimum_initial_investment", "precision": precision_miminimum_initial_investment, "recall": recall_miminimum_initial_investment, "f1": f1_miminimum_initial_investment, "accuracy": accuracy_miminimum_initial_investment, "support": support_miminimum_initial_investment},
{"item": "benchmark_name", "precision": precision_benchmark_name, "recall": recall_benchmark_name, "f1": f1_benchmark_name, "accuracy": accuracy_benchmark_name, "support": support_benchmark_name}
]
metrics_data_df = pd.DataFrame(metrics_data)
    average_precision = metrics_data_df["precision"].mean()
average_recall = metrics_data_df["recall"].mean()
average_f1 = metrics_data_df["f1"].mean()
average_accuracy = metrics_data_df["accuracy"].mean()
sum_support = metrics_data_df["support"].sum()
    metrics_data.append({"item": "average_score", "precision": average_precision, "recall": average_recall, "f1": average_f1, "accuracy": average_accuracy, "support": sum_support})
metrics_data_df = pd.DataFrame(metrics_data)
metrics_data_df = metrics_data_df[['item', 'f1', 'precision', 'recall', 'accuracy', 'support']]
# output metrics data to Excel file
print("Output metrics data to Excel file...")
output_folder = r"/data/aus_prospectus/output/metrics_data/"
os.makedirs(output_folder, exist_ok=True)
verify_file_name = os.path.basename(verify_file_path).replace(".xlsx", "")
metrics_file_name = f"metrics_{verify_file_name}_{len(document_id_list)}_documents_4_dps_not_strict.xlsx"
output_file = os.path.join(output_folder, metrics_file_name)
with pd.ExcelWriter(output_file) as writer:
metrics_data_df.to_excel(writer, index=False, sheet_name="metrics_data")
message_data_df.to_excel(writer, index=False, sheet_name="message_data")
def generate_message(message: dict, doc_id: str, sec_id: str, fund_legal_name: str, raw_fund_name: str, datapoint: str):
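    """
    Attach the data point name and document/fund identifiers to a comparison
    message dict and return it.
    """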
message["data_point"] = datapoint
message["fund_legal_name"] = fund_legal_name
message["raw_fund_name"] = raw_fund_name
message["sec_id"] = sec_id
message["doc_id"] = str(doc_id)
return message
def get_gt_pred_by_compare_values(gt_value, pred_value, gt_list, pred_list, data_point: str = ""):
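    """
    Compare a ground-truth value with a predicted value and append binary labels
    to gt_list / pred_list for precision/recall scoring. A non-empty ground-truth
    value contributes a positive ground-truth label (pred 1 on a match, 0 otherwise);
    a non-empty prediction without a matching ground-truth value contributes a
    false-positive pair. Returns a message dict describing any mismatch.
    """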
message = {"gt_value": gt_value, "pred_value": pred_value, "error": ""}
if gt_value is not None and len(str(gt_value)) > 0:
gt_list.append(1)
gt_equal_pred = is_equal(gt_value, pred_value, data_point)
if gt_equal_pred:
pred_list.append(1)
else:
pred_list.append(0)
message["error"] = "pred_value is not equal to gt_value"
if pred_value is not None and len(str(pred_value)) > 0:
pred_list.append(1)
gt_list.append(0)
else:
if pred_value is not None and len(str(pred_value)) > 0:
gt_list.append(0)
pred_list.append(1)
message["error"] = "gt_value is empty, but pred_value is not empty"
# else:
# gt_list.append(1)
# pred_list.append(1)
return message
def is_equal(gt_value, pred_value, data_point: str = ""):
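    """
    Return True when the two non-empty values are equal; benchmark_name values are
    cleaned and compared by containment, with a token-level Jaccard similarity
    above 0.8 as a final fallback.
    """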
if gt_value is not None and len(str(gt_value)) > 0 and \
pred_value is not None and len(str(pred_value)) > 0:
if gt_value == pred_value:
return True
if data_point == "benchmark_name":
gt_value = clean_text(gt_value)
pred_value = clean_text(pred_value)
if gt_value == pred_value or gt_value in pred_value or pred_value in gt_value:
return True
similarity = Similarity()
jacard_score = similarity.jaccard_similarity(gt_value.lower().split(), pred_value.lower().split())
if jacard_score > 0.8:
return True
return False
def clean_text(text: str):
if text is None or len(text) == 0:
return text
text = re.sub(r"\W", " ", text)
text = re.sub(r"\s+", " ", text)
return text
def set_mapping_to_raw_name_data(data_file_path: str = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees.xlsx",
data_sheet: str = "Sheet1",
raw_name_column: str = "raw_share_name",
mapping_file_path: str = r"/data/aus_prospectus/basic_information/from_2024_documents/aus_100_document_prospectus_multi_fund.xlsx",
mapping_sheet: str = "document_mapping",
raw_name_mapping_column: str = None,
output_file_path: str = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees_with_mapping.xlsx"):
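    """
    Enrich extracted rows with DB identifiers (fund_id, fund_name, sec_id, sec_name).
    When raw_name_mapping_column is "FundLegalName" the raw names are matched exactly
    against that column; otherwise raw share names are matched against
    FundClassLegalName via the hybrid matching function. Writes the enriched table
    to output_file_path.
    """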
data_df = pd.read_excel(data_file_path, sheet_name=data_sheet)
data_df["fund_id"] = ""
data_df["fund_name"] = ""
data_df["sec_id"] = ""
data_df["sec_name"] = ""
mapping_data = pd.read_excel(mapping_file_path, sheet_name=mapping_sheet)
doc_id_list = data_df["doc_id"].unique().tolist()
for doc_id in doc_id_list:
doc_data = data_df[data_df["doc_id"] == doc_id]
raw_name_list = doc_data[raw_name_column].unique().tolist()
doc_mapping_data = mapping_data[mapping_data["DocumentId"] == doc_id]
if len(doc_mapping_data) == 0:
continue
provider_name = doc_mapping_data["CompanyName"].values[0]
if raw_name_mapping_column is not None and raw_name_mapping_column == "FundLegalName":
doc_db_name_list = doc_mapping_data[raw_name_mapping_column].unique().tolist()
for raw_name in raw_name_list:
find_df = doc_mapping_data[doc_mapping_data[raw_name_mapping_column] == raw_name]
if find_df is not None and len(find_df) == 1:
sec_id = find_df["FundClassId"].values[0]
sec_name = find_df["FundClassLegalName"].values[0]
fund_id = find_df["FundId"].values[0]
fund_name = find_df["FundLegalName"].values[0]
                    # update the rows in data_df whose raw name matches this raw_name
data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_name), "sec_id"] = sec_id
data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_name), "sec_name"] = sec_name
data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_name), "fund_id"] = fund_id
data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_name), "fund_name"] = fund_name
else:
doc_db_name_list = doc_mapping_data["FundClassLegalName"].unique().tolist()
all_match_result = get_raw_name_db_match_result(doc_id,
provider_name,
raw_name_list,
doc_db_name_list,
iter_count=60)
for raw_share_name in raw_name_list:
if all_match_result.get(raw_share_name) is not None:
matched_db_share_name = all_match_result[raw_share_name]
if (
matched_db_share_name is not None
and len(matched_db_share_name) > 0
):
                            # look up identifiers for the matched DB share-class name in doc_mapping_data
find_share_df = doc_mapping_data[doc_mapping_data["FundClassLegalName"] == matched_db_share_name]
if find_share_df is not None and len(find_share_df) > 0:
sec_id = find_share_df["FundClassId"].values[0]
fund_id = find_share_df["FundId"].values[0]
fund_name = find_share_df["FundLegalName"].values[0]
                                # update the rows in data_df whose raw name matches this raw_share_name
data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_share_name), "sec_id"] = sec_id
data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_share_name), "sec_name"] = matched_db_share_name
data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_share_name), "fund_id"] = fund_id
data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_share_name), "fund_name"] = fund_name
try:
data_df = data_df[["doc_id",
"raw_fund_name",
"fund_id",
"fund_name",
"raw_share_name",
"sec_id",
"sec_name",
"management_fee_and_costs",
"management_fee",
"administration_fees",
"minimum_initial_investment",
"benchmark_name",
"performance_fee",
"performance_fee_charged",
"buy_spread",
"sell_spread",
"total_annual_dollar_based_charges",
"interposed_vehicle_performance_fee_cost",
"establishment_fee",
"contribution_fee",
"withdrawal_fee",
"exit_fee",
"switching_fee",
"activity_fee",
"hurdle_rate",
"analyst_name"
]]
except Exception as e:
print(e)
with open(output_file_path, "wb") as file:
data_df.to_excel(file, index=False)
def get_raw_name_db_match_result(
doc_id: str, provider_name: str, raw_name_list: list, doc_share_name_list: list, iter_count: int = 30
):
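    """
    Match raw names against DB share-class names chunk by chunk, removing
    already-matched DB names between chunks, and return the combined match result.
    """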
    # Split raw_name_list into chunks of iter_count elements.
    # Splitting avoids hitting ChatGPT token limits in a single request.
raw_name_list_parts = [
raw_name_list[i : i + iter_count]
for i in range(0, len(raw_name_list), iter_count)
]
all_match_result = {}
doc_share_name_list = deepcopy(doc_share_name_list)
for raw_name_list in raw_name_list_parts:
match_result, doc_share_name_list = get_final_function_to_match(
doc_id, provider_name, raw_name_list, doc_share_name_list
)
all_match_result.update(match_result)
return all_match_result
def get_final_function_to_match(doc_id, provider_name, raw_name_list, db_name_list):
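    """
    Run the hybrid name-matching function for one chunk of raw names and return the
    match result together with the DB names that remain unmatched.
    """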
if len(db_name_list) == 0:
match_result = {}
for raw_name in raw_name_list:
match_result[raw_name] = ""
else:
match_result = final_function_to_match(
doc_id=doc_id,
pred_list=raw_name_list,
db_list=db_name_list,
provider_name=provider_name,
doc_source="aus_prospectus"
)
matched_name_list = list(match_result.values())
db_name_list = remove_matched_names(db_name_list, matched_name_list)
return match_result, db_name_list
def remove_matched_names(target_name_list: list, matched_name_list: list):
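    """
    Remove names that have already been matched from target_name_list and return it.
    """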
if len(matched_name_list) == 0:
return target_name_list
matched_name_list = list(set(matched_name_list))
matched_name_list = [
value for value in matched_name_list if value is not None and len(value) > 0
]
for matched_name in matched_name_list:
if (
matched_name is not None
and len(matched_name) > 0
and matched_name in target_name_list
):
target_name_list.remove(matched_name)
return target_name_list
def set_mapping_to_ravi_data():
data_file_path = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees.xlsx"
data_sheet = "Sheet1"
mapping_file_path = r"/data/aus_prospectus/basic_information/from_2024_documents/aus_100_document_prospectus_multi_fund.xlsx"
mapping_sheet = "document_mapping"
output_file_path = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees_with_mapping.xlsx"
    set_mapping_to_raw_name_data(data_file_path=data_file_path, data_sheet=data_sheet, mapping_file_path=mapping_file_path, mapping_sheet=mapping_sheet, output_file_path=output_file_path)
def set_mapping_to_data_side_documents_data():
# data_file_path = r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/Audited file_phase2.xlsx"
# data_sheet = "all"
# mapping_file_path = r"/data/aus_prospectus/basic_information/17_documents/aus_prospectus_17_documents_mapping.xlsx"
# mapping_sheet = "document_mapping"
# output_file_path = r"/data/aus_prospectus/output/ravi_100_documents/audited_file_phase2_with_mapping.xlsx"
data_file_path = r"/data/aus_prospectus/ground_truth/phase2_file/46_documents/46_documents_ground_truth.xlsx"
data_sheet = "ground_truth"
raw_name_column = "raw_share_name"
mapping_file_path = r"/data/aus_prospectus/basic_information/46_documents/aus_prospectus_46_documents_mapping.xlsx"
mapping_sheet = "document_mapping"
raw_name_mapping_column = None
output_file_path = r"/data/aus_prospectus/ground_truth/phase2_file/46_documents/46_documents_ground_truth_with_mapping.xlsx"
set_mapping_to_raw_name_data(data_file_path=data_file_path,
data_sheet=data_sheet,
raw_name_column=raw_name_column,
mapping_file_path=mapping_file_path,
mapping_sheet=mapping_sheet,
raw_name_mapping_column=raw_name_mapping_column,
output_file_path=output_file_path)
def adjust_data_file(source_file: str,
                     target_file: str):
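    """
    Replace the rows in the target workbook whose DocumentId appears in the source
    workbook with the corresponding source rows, then save the target workbook.
    """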
source_data = pd.read_excel(source_file, sheet_name="Sheet1")
source_doc_id_list = source_data["DocumentId"].unique().tolist()
    target_data = pd.read_excel(target_file, sheet_name="Sheet1")
    # remove rows from target_data whose DocumentId appears in source_doc_id_list
    target_data = target_data[~target_data["DocumentId"].isin(source_doc_id_list)]
# concat source_data and target_data
target_data = pd.concat([source_data, target_data], ignore_index=True)
    with open(target_file, "wb") as file:
target_data.to_excel(file, index=False)
if __name__ == "__main__":
# adjust_column_order()
# set_mapping_to_data_side_documents_data()
# source_file = r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/audited_file_phase2_with_mapping.xlsx"
# target_file = r"/data/aus_prospectus/ground_truth/phase2_file/46_documents/46_documents_ground_truth_with_mapping.xlsx"
    # adjust_data_file(source_file=source_file, target_file=target_file)
# audit_file_path: str = r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/audited_file_phase2_with_mapping.xlsx"
# audit_data_sheet: str = "Sheet1"
# verify_file_path: str = r"/data/aus_prospectus/output/mapping_data/total/merged/merged_mapping_data_info_17_documents_by_text_20250303171140.xlsx"
# verify_data_sheet: str = "total_data"
audit_file_path: str = r"/data/aus_prospectus/ground_truth/phase2_file/46_documents/46_documents_ground_truth_with_mapping.xlsx"
audit_data_sheet: str = "Sheet1"
verify_file_path: str = r"/data/aus_prospectus/output/mapping_data/total/merged/merged_mapping_data_info_46_documents_by_text_20250306171226.xlsx"
verify_data_sheet: str = "total_data"
# verify_document_list_file: str = "./sample_documents/aus_prospectus_29_documents_sample.txt"
verify_document_list_file_list = [None, "./sample_documents/aus_prospectus_29_documents_sample.txt", "./sample_documents/aus_prospectus_17_documents_sample.txt"]
for verify_document_list_file in verify_document_list_file_list:
calculate_metrics_based_db_data_file(audit_file_path=audit_file_path,
audit_data_sheet=audit_data_sheet,
verify_file_path=verify_file_path,
verify_data_sheet=verify_data_sheet,
verify_document_list_file = verify_document_list_file)
# set_mapping_to_17_documents_data()
# set_mapping_to_ravi_data()
# calculate_metrics_based_audit_file(is_strict=True)
# calculate_metrics_based_audit_file(is_strict=False)
# remove_ter_ogc_performance_fee_annotation()
# batch_run_documents()
# transform_pdf_2_image()
# ground_truth_file = "./test_metrics/ground_truth.xlsx"
# prediction_file = "./test_metrics/prediction.xlsx"
# calc_metrics(ground_truth_file, prediction_file)
# pdf_file = r"./data/emea_ar/pdf/532438210.pdf"
# page_list = [25, 26, 27, 28, 29]
# output_folder = r"./data/emea_ar/output/pdf_part/"
# output_part_of_pages(pdf_file, page_list, output_folder)