From 2cd4f5f787e45e932b2d5e294dfc0b09696c7153 Mon Sep 17 00:00:00 2001 From: Blade He Date: Fri, 7 Mar 2025 15:02:12 -0600 Subject: [PATCH] Supplement provider information to ground truth data Calculate metrics based on providers Integrate "merge" data algorithm for AUS Prospectus final outputs --- calc_metrics.py | 586 ++++++++++++++++++++++++++-------------- core/data_extraction.py | 2 +- core/data_mapping.py | 175 +++++++++++- main.py | 43 ++- prepare_data.py | 288 +++++++++++++++++++- 5 files changed, 854 insertions(+), 240 deletions(-) diff --git a/calc_metrics.py b/calc_metrics.py index 41de9e1..b136fb8 100644 --- a/calc_metrics.py +++ b/calc_metrics.py @@ -11,9 +11,7 @@ import traceback from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score import requests import fitz -from copy import deepcopy from utils.similarity import Similarity -from core.auz_nz.hybrid_solution_script import final_function_to_match def calc_metrics(ground_truth_file: str, prediction_file: str): @@ -891,6 +889,8 @@ def calculate_metrics_based_db_data_file(audit_file_path: str = r"/data/aus_pros output_folder = r"/data/aus_prospectus/output/metrics_data/" os.makedirs(output_folder, exist_ok=True) verify_file_name = os.path.basename(verify_file_path).replace(".xlsx", "") + if is_for_all: + verify_file_name = f"{verify_file_name}_all" metrics_file_name = f"metrics_{verify_file_name}_{len(document_id_list)}_documents_4_dps_not_strict.xlsx" output_file = os.path.join(output_folder, metrics_file_name) with pd.ExcelWriter(output_file) as writer: @@ -898,6 +898,369 @@ message_data_df.to_excel(writer, index=False, sheet_name="message_data") +def calculate_metrics_by_provider(audit_file_path: str = r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/audited_file_phase2_with_mapping.xlsx", + audit_data_sheet: str = "Sheet1", + verify_file_path: str = r"/data/aus_prospectus/output/mapping_data/total/merged/merged_mapping_data_info_17_documents_by_text_20250303171140.xlsx", + verify_data_sheet: str = "total_data", + verify_document_list_file: str = None, + is_for_all: bool = False + ): + print("Start to calculate metrics by provider based on DB data file and extracted file...") + + audit_fields = [ + "DocumentId", + "provider_id", + "provider_name", + "FundLegalName", + "FundId", + "FundClassLegalName", + "FundClassId", + "management_fee_and_costs", + "management_fee", + "administration_fees", + "minimum_initial_investment", + "benchmark_name", + "performance_fee", + "interposed_vehicle_performance_fee_cost", + "buy_spread", + "sell_spread", + "total_annual_dollar_based_charges" + ] + audit_data_df = pd.read_excel(audit_file_path, sheet_name=audit_data_sheet) + audit_data_df = audit_data_df[audit_fields] + audit_data_df = audit_data_df.drop_duplicates() + audit_data_df = audit_data_df.rename(columns={"DocumentId": "doc_id", + "FundLegalName": "fund_name", + "FundId": "fund_id", + "FundClassLegalName": "sec_name", + "FundClassId": "sec_id"}) + audit_data_df.fillna("", inplace=True) + audit_data_df.reset_index(drop=True, inplace=True) + + verify_fields = [ + "DocumentId", + "raw_fund_name", + "fund_id", + "fund_name", + "raw_share_name", + "sec_id", + "sec_name", + "management_fee_and_costs", + "management_fee", + "administration_fees", + "minimum_initial_investment", + "benchmark_name", + "performance_fee", +
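# NOTE: the fee fields below are only compared when is_for_all is True +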
"interposed_vehicle_performance_fee_cost", + "buy_spread", + "sell_spread", + "total_annual_dollar_based_charges" + ] + verify_data_df = pd.read_excel(verify_file_path, sheet_name=verify_data_sheet) + verify_data_df = verify_data_df[verify_fields] + verify_data_df = verify_data_df.drop_duplicates() + verify_data_df = verify_data_df.rename(columns={"DocumentId": "doc_id"}) + verify_data_df.fillna("", inplace=True) + verify_data_df.reset_index(drop=True, inplace=True) + + if len(audit_data_df) == 0 or len(verify_data_df) == 0: + print("No data to calculate metrics.") + return + + # Calculate metrics + if verify_document_list_file is not None: + with open(verify_document_list_file, "r", encoding="utf-8") as f: + verify_document_list = f.readlines() + verify_document_list = [int(doc_id.strip()) for doc_id in verify_document_list] + if len(verify_document_list) > 0: + verify_data_df = verify_data_df[verify_data_df["doc_id"].isin(verify_document_list)] + document_id_list = verify_data_df["doc_id"].unique().tolist() + + print(f"Total document count: {len(document_id_list)}") + print("Construct ground truth and prediction data...") + # similarity = Similarity() + message_list = [] + provider_gt_pred_data = {} + for document_id in document_id_list: + doc_audit_data = audit_data_df[audit_data_df["doc_id"] == document_id] + provider_id = doc_audit_data["provider_id"].iloc[0] + provider_name = doc_audit_data["provider_name"].iloc[0] + if provider_id not in list(provider_gt_pred_data.keys()): + provider_gt_pred_data[provider_id] = {"provider_name": provider_name, + "gt_management_fee_and_costs_list": [], + "pred_management_fee_and_costs_list": [], + "gt_management_fee_list": [], + "pred_management_fee_list": [], + "gt_administration_fees_list": [], + "pred_administration_fees_list": [], + "gt_minimum_initial_investment_list": [], + "pred_minimum_initial_investment_list": [], + "gt_benchmark_name_list": [], + "pred_benchmark_name_list": []} + if is_for_all: + provider_gt_pred_data[provider_id].update({"gt_performance_fee_list": [], + "pred_performance_fee_list": [], + "gt_interposed_vehicle_performance_fee_cost_list": [], + "pred_interposed_vehicle_performance_fee_cost_list": [], + "gt_buy_spread_list": [], + "pred_buy_spread_list": [], + "gt_sell_spread_list": [], + "pred_sell_spread_list": [], + "gt_total_annual_dollar_based_charges_list": [], + "pred_total_annual_dollar_based_charges_list": []}) + audit_sec_id_list = [doc_sec_id for doc_sec_id + in doc_audit_data["sec_id"].unique().tolist() + if len(doc_sec_id) > 0] + # get doc_verify_data which doc_id is same as document_id and sec_id in audit_sec_id_list + doc_verify_data = verify_data_df[(verify_data_df["doc_id"] == document_id) & (verify_data_df["sec_id"].isin(audit_sec_id_list))] + for index, row in doc_audit_data.iterrows(): + fund_name = row["fund_name"] + sec_id = row["sec_id"] + management_fee_and_costs = str(row["management_fee_and_costs"]) + management_fee = str(row["management_fee"]) + administration_fees = str(row["administration_fees"]) + minimum_initial_investment = str(row["minimum_initial_investment"]) + benchmark_name = str(row["benchmark_name"]) + if is_for_all: + performance_fee = str(row["performance_fee"]) + interposed_vehicle_performance_fee_cost = str(row["interposed_vehicle_performance_fee_cost"]) + buy_spread = str(row["buy_spread"]) + sell_spread = str(row["sell_spread"]) + total_annual_dollar_based_charges = str(row["total_annual_dollar_based_charges"]) + + # get the first row which sec_id in doc_verify_data is same as 
sec_id + doc_verify_sec_data = doc_verify_data[doc_verify_data["sec_id"] == sec_id] + if len(doc_verify_sec_data) == 0: + continue + doc_verify_sec_row = doc_verify_sec_data.iloc[0] + raw_fund_name = doc_verify_sec_row["raw_fund_name"] + v_management_fee_and_costs = str(doc_verify_sec_row["management_fee_and_costs"]) + v_management_fee = str(doc_verify_sec_row["management_fee"]) + v_administration_fees = str(doc_verify_sec_row["administration_fees"]) + v_minimum_initial_investment = str(doc_verify_sec_row["minimum_initial_investment"]) + v_benchmark_name = str(doc_verify_sec_row["benchmark_name"]) + if is_for_all: + v_performance_fee = str(doc_verify_sec_row["performance_fee"]) + v_interposed_vehicle_performance_fee_cost = str(doc_verify_sec_row["interposed_vehicle_performance_fee_cost"]) + v_buy_spread = str(doc_verify_sec_row["buy_spread"]) + v_sell_spread = str(doc_verify_sec_row["sell_spread"]) + v_total_annual_dollar_based_charges = str(doc_verify_sec_row["total_annual_dollar_based_charges"]) + + message = get_gt_pred_by_compare_values(management_fee_and_costs, + v_management_fee_and_costs, + provider_gt_pred_data[provider_id]["gt_management_fee_and_costs_list"], + provider_gt_pred_data[provider_id]["pred_management_fee_and_costs_list"], + data_point="management_fee_and_costs") + message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "management_fee_and_costs")) + message = get_gt_pred_by_compare_values(management_fee, + v_management_fee, + provider_gt_pred_data[provider_id]["gt_management_fee_list"], + provider_gt_pred_data[provider_id]["pred_management_fee_list"], + data_point="management_fee") + message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "management_fee")) + message = get_gt_pred_by_compare_values(administration_fees, + v_administration_fees, + provider_gt_pred_data[provider_id]["gt_administration_fees_list"], + provider_gt_pred_data[provider_id]["pred_administration_fees_list"], + data_point="administration_fees") + message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "administration_fees")) + message = get_gt_pred_by_compare_values(minimum_initial_investment, + v_minimum_initial_investment, + provider_gt_pred_data[provider_id]["gt_minimum_initial_investment_list"], + provider_gt_pred_data[provider_id]["pred_minimum_initial_investment_list"], + data_point="minimum_initial_investment") + message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "minimum_initial_investment")) + message = get_gt_pred_by_compare_values(benchmark_name, + v_benchmark_name, + provider_gt_pred_data[provider_id]["gt_benchmark_name_list"], + provider_gt_pred_data[provider_id]["pred_benchmark_name_list"], + data_point="benchmark_name") + message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "benchmark_name")) + if is_for_all: + message = get_gt_pred_by_compare_values(performance_fee, + v_performance_fee, + provider_gt_pred_data[provider_id]["gt_performance_fee_list"], + provider_gt_pred_data[provider_id]["pred_performance_fee_list"], + data_point="performance_fee") + message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "performance_fee")) + message = get_gt_pred_by_compare_values(interposed_vehicle_performance_fee_cost, + v_interposed_vehicle_performance_fee_cost, + provider_gt_pred_data[provider_id]["gt_interposed_vehicle_performance_fee_cost_list"], + 
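# gt/pred lists collect binary match flags per provider; sklearn scores them once all documents are processed +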
provider_gt_pred_data[provider_id]["pred_interposed_vehicle_performance_fee_cost_list"], + data_point="interposed_vehicle_performance_fee_cost") + message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "interposed_vehicle_performance_fee_cost")) + message = get_gt_pred_by_compare_values(buy_spread, + v_buy_spread, + provider_gt_pred_data[provider_id]["gt_buy_spread_list"], + provider_gt_pred_data[provider_id]["pred_buy_spread_list"], + data_point="buy_spread") + message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "buy_spread")) + message = get_gt_pred_by_compare_values(sell_spread, + v_sell_spread, + provider_gt_pred_data[provider_id]["gt_sell_spread_list"], + provider_gt_pred_data[provider_id]["pred_sell_spread_list"], + data_point="sell_spread") + message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "sell_spread")) + message = get_gt_pred_by_compare_values(total_annual_dollar_based_charges, + v_total_annual_dollar_based_charges, + provider_gt_pred_data[provider_id]["gt_total_annual_dollar_based_charges_list"], + provider_gt_pred_data[provider_id]["pred_total_annual_dollar_based_charges_list"], + data_point="total_annual_dollar_based_charges") + message_list.append(generate_message(message, document_id, sec_id, fund_name, raw_fund_name, "total_annual_dollar_based_charges")) + + message_data_df = pd.DataFrame(message_list) + message_data_df = message_data_df[['doc_id', 'sec_id', 'raw_fund_name', 'fund_legal_name', 'data_point', 'gt_value', 'pred_value', 'error']] + # order by doc_id, raw_fund_name, data_point + message_data_df = message_data_df.sort_values(by=['doc_id', 'raw_fund_name', 'data_point']) + message_data_df.reset_index(drop=True, inplace=True) + + # calculate metrics + print("Calculate metrics...") + provider_metrics_list = [] + for provider_id, gt_pred_data in provider_gt_pred_data.items(): + provider_name = gt_pred_data["provider_name"] + precision_management_fee_and_costs = precision_score(gt_pred_data["gt_management_fee_and_costs_list"], + gt_pred_data["pred_management_fee_and_costs_list"]) + recall_management_fee_and_costs = recall_score(gt_pred_data["gt_management_fee_and_costs_list"], gt_pred_data["pred_management_fee_and_costs_list"]) + f1_management_fee_and_costs = f1_score(gt_pred_data["gt_management_fee_and_costs_list"], gt_pred_data["pred_management_fee_and_costs_list"]) + accuracy_management_fee_and_costs = accuracy_score(gt_pred_data["gt_management_fee_and_costs_list"], gt_pred_data["pred_management_fee_and_costs_list"]) + support_management_fee_and_costs = sum(gt_pred_data["gt_management_fee_and_costs_list"]) + + precision_management_fee = precision_score(gt_pred_data["gt_management_fee_list"], gt_pred_data["pred_management_fee_list"]) + recall_management_fee = recall_score(gt_pred_data["gt_management_fee_list"], gt_pred_data["pred_management_fee_list"]) + f1_management_fee = f1_score(gt_pred_data["gt_management_fee_list"], gt_pred_data["pred_management_fee_list"]) + accuracy_management_fee = accuracy_score(gt_pred_data["gt_management_fee_list"], gt_pred_data["pred_management_fee_list"]) + support_management_fee = sum(gt_pred_data["gt_management_fee_list"]) + + precision_administration_fees = precision_score(gt_pred_data["gt_administration_fees_list"], gt_pred_data["pred_administration_fees_list"]) + recall_administration_fees = recall_score(gt_pred_data["gt_administration_fees_list"], gt_pred_data["pred_administration_fees_list"]) + 
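# support below is the count of positive ground-truth flags, i.e. the sum of the binary gt list +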
f1_administration_fees = f1_score(gt_pred_data["gt_administration_fees_list"], gt_pred_data["pred_administration_fees_list"]) + accuracy_administration_fees = accuracy_score(gt_pred_data["gt_administration_fees_list"], gt_pred_data["pred_administration_fees_list"]) + support_administration_fees = sum(gt_pred_data["gt_administration_fees_list"]) + + precision_minimum_initial_investment = precision_score(gt_pred_data["gt_minimum_initial_investment_list"], + gt_pred_data["pred_minimum_initial_investment_list"]) + recall_minimum_initial_investment = recall_score(gt_pred_data["gt_minimum_initial_investment_list"], + gt_pred_data["pred_minimum_initial_investment_list"]) + f1_minimum_initial_investment = f1_score(gt_pred_data["gt_minimum_initial_investment_list"], + gt_pred_data["pred_minimum_initial_investment_list"]) + accuracy_minimum_initial_investment = accuracy_score(gt_pred_data["gt_minimum_initial_investment_list"], + gt_pred_data["pred_minimum_initial_investment_list"]) + support_minimum_initial_investment = sum(gt_pred_data["gt_minimum_initial_investment_list"]) + + precision_benchmark_name = precision_score(gt_pred_data["gt_benchmark_name_list"], + gt_pred_data["pred_benchmark_name_list"]) + recall_benchmark_name = recall_score(gt_pred_data["gt_benchmark_name_list"], + gt_pred_data["pred_benchmark_name_list"]) + f1_benchmark_name = f1_score(gt_pred_data["gt_benchmark_name_list"], + gt_pred_data["pred_benchmark_name_list"]) + accuracy_benchmark_name = accuracy_score(gt_pred_data["gt_benchmark_name_list"], + gt_pred_data["pred_benchmark_name_list"]) + support_benchmark_name = sum(gt_pred_data["gt_benchmark_name_list"]) + + if is_for_all: + precision_performance_fee = precision_score(gt_pred_data["gt_performance_fee_list"], + gt_pred_data["pred_performance_fee_list"]) + recall_performance_fee = recall_score(gt_pred_data["gt_performance_fee_list"], + gt_pred_data["pred_performance_fee_list"]) + f1_performance_fee = f1_score(gt_pred_data["gt_performance_fee_list"], + gt_pred_data["pred_performance_fee_list"]) + accuracy_performance_fee = accuracy_score(gt_pred_data["gt_performance_fee_list"], + gt_pred_data["pred_performance_fee_list"]) + support_performance_fee = sum(gt_pred_data["gt_performance_fee_list"]) + + precision_interposed_vehicle_performance_fee_cost = precision_score(gt_pred_data["gt_interposed_vehicle_performance_fee_cost_list"], + gt_pred_data["pred_interposed_vehicle_performance_fee_cost_list"]) + recall_interposed_vehicle_performance_fee_cost = recall_score(gt_pred_data["gt_interposed_vehicle_performance_fee_cost_list"], + gt_pred_data["pred_interposed_vehicle_performance_fee_cost_list"]) + f1_interposed_vehicle_performance_fee_cost = f1_score(gt_pred_data["gt_interposed_vehicle_performance_fee_cost_list"], + gt_pred_data["pred_interposed_vehicle_performance_fee_cost_list"]) + accuracy_interposed_vehicle_performance_fee_cost = accuracy_score(gt_pred_data["gt_interposed_vehicle_performance_fee_cost_list"], + gt_pred_data["pred_interposed_vehicle_performance_fee_cost_list"]) + support_interposed_vehicle_performance_fee_cost = sum(gt_pred_data["gt_interposed_vehicle_performance_fee_cost_list"]) + + precision_buy_spread = precision_score(gt_pred_data["gt_buy_spread_list"], + gt_pred_data["pred_buy_spread_list"]) + recall_buy_spread = recall_score(gt_pred_data["gt_buy_spread_list"], + gt_pred_data["pred_buy_spread_list"]) + f1_buy_spread = f1_score(gt_pred_data["gt_buy_spread_list"], + gt_pred_data["pred_buy_spread_list"]) +
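# accuracy also credits correct absences (0/0 pairs), unlike precision/recall +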
accuracy_buy_spread = accuracy_score(gt_pred_data["gt_buy_spread_list"], + gt_pred_data["pred_buy_spread_list"]) + support_buy_spread = sum(gt_pred_data["gt_buy_spread_list"]) + + precision_sell_spread = precision_score(gt_pred_data["gt_sell_spread_list"], + gt_pred_data["pred_sell_spread_list"]) + recall_sell_spread = recall_score(gt_pred_data["gt_sell_spread_list"], + gt_pred_data["pred_sell_spread_list"]) + f1_sell_spread = f1_score(gt_pred_data["gt_sell_spread_list"], + gt_pred_data["pred_sell_spread_list"]) + accuracy_sell_spread = accuracy_score(gt_pred_data["gt_sell_spread_list"], + gt_pred_data["pred_sell_spread_list"]) + support_sell_spread = sum(gt_pred_data["gt_sell_spread_list"]) + + precision_total_annual_dollar_based_charges = precision_score(gt_pred_data["gt_total_annual_dollar_based_charges_list"], + gt_pred_data["pred_total_annual_dollar_based_charges_list"]) + recall_total_annual_dollar_based_charges = recall_score(gt_pred_data["gt_total_annual_dollar_based_charges_list"], + gt_pred_data["pred_total_annual_dollar_based_charges_list"]) + f1_total_annual_dollar_based_charges = f1_score(gt_pred_data["gt_total_annual_dollar_based_charges_list"], + gt_pred_data["pred_total_annual_dollar_based_charges_list"]) + accuracy_total_annual_dollar_based_charges = accuracy_score(gt_pred_data["gt_total_annual_dollar_based_charges_list"], + gt_pred_data["pred_total_annual_dollar_based_charges_list"]) + support_total_annual_dollar_based_charges = sum(gt_pred_data["gt_total_annual_dollar_based_charges_list"]) + + if is_for_all: + metrics_data = [{"provider_id": provider_id, "provider_name": provider_name, "item": "management_fee_and_costs", "precision": precision_management_fee_and_costs, "recall": recall_management_fee_and_costs, "f1": f1_management_fee_and_costs, "accuracy": accuracy_management_fee_and_costs, "support": support_management_fee_and_costs}, + {"provider_id": provider_id, "provider_name": provider_name, "item": "management_fee", "precision": precision_management_fee, "recall": recall_management_fee, "f1": f1_management_fee, "accuracy": accuracy_management_fee, "support": support_management_fee}, + {"provider_id": provider_id, "provider_name": provider_name, "item": "administration_fees", "precision": precision_administration_fees, "recall": recall_administration_fees, "f1": f1_administration_fees, "accuracy": accuracy_administration_fees, "support": support_administration_fees}, + {"provider_id": provider_id, "provider_name": provider_name, "item": "minimum_initial_investment", "precision": precision_minimum_initial_investment, "recall": recall_minimum_initial_investment, "f1": f1_minimum_initial_investment, "accuracy": accuracy_minimum_initial_investment, "support": support_minimum_initial_investment}, + {"provider_id": provider_id, "provider_name": provider_name, "item": "benchmark_name", "precision": precision_benchmark_name, "recall": recall_benchmark_name, "f1": f1_benchmark_name, "accuracy": accuracy_benchmark_name, "support": support_benchmark_name}, + {"provider_id": provider_id, "provider_name": provider_name, "item": "performance_fee", "precision": precision_performance_fee, "recall": recall_performance_fee, "f1": f1_performance_fee, "accuracy": accuracy_performance_fee, "support": support_performance_fee}, + {"provider_id": provider_id, "provider_name": provider_name, "item": "interposed_vehicle_performance_fee_cost", "precision": precision_interposed_vehicle_performance_fee_cost, "recall": recall_interposed_vehicle_performance_fee_cost, + "f1": 
f1_interposed_vehicle_performance_fee_cost, "accuracy": accuracy_interposed_vehicle_performance_fee_cost, "support": support_interposed_vehicle_performance_fee_cost}, + {"provider_id": provider_id, "provider_name": provider_name, "item": "buy_spread", "precision": precision_buy_spread, "recall": recall_buy_spread, "f1": f1_buy_spread, "accuracy": accuracy_buy_spread, "support": support_buy_spread}, + {"provider_id": provider_id, "provider_name": provider_name, "item": "sell_spread", "precision": precision_sell_spread, "recall": recall_sell_spread, "f1": f1_sell_spread, "accuracy": accuracy_sell_spread, "support": support_sell_spread}, + {"provider_id": provider_id, "provider_name": provider_name, "item": "total_annual_dollar_based_charges", "precision": precision_total_annual_dollar_based_charges, "recall": recall_total_annual_dollar_based_charges, + "f1": f1_total_annual_dollar_based_charges, "accuracy": accuracy_total_annual_dollar_based_charges, "support": support_total_annual_dollar_based_charges} + ] + else: + metrics_data = [{"provider_id": provider_id, "provider_name": provider_name, "item": "management_fee_and_costs", "precision": precision_management_fee_and_costs, "recall": recall_management_fee_and_costs, "f1": f1_management_fee_and_costs, "accuracy": accuracy_management_fee_and_costs, "support": support_management_fee_and_costs}, + {"provider_id": provider_id, "provider_name": provider_name, "item": "management_fee", "precision": precision_management_fee, "recall": recall_management_fee, "f1": f1_management_fee, "accuracy": accuracy_management_fee, "support": support_management_fee}, + {"provider_id": provider_id, "provider_name": provider_name, "item": "administration_fees", "precision": precision_administration_fees, "recall": recall_administration_fees, "f1": f1_administration_fees, "accuracy": accuracy_administration_fees, "support": support_administration_fees}, + {"provider_id": provider_id, "provider_name": provider_name, "item": "minimum_initial_investment", "precision": precision_minimum_initial_investment, "recall": recall_minimum_initial_investment, "f1": f1_minimum_initial_investment, "accuracy": accuracy_minimum_initial_investment, "support": support_minimum_initial_investment}, + {"provider_id": provider_id, "provider_name": provider_name, "item": "benchmark_name", "precision": precision_benchmark_name, "recall": recall_benchmark_name, "f1": f1_benchmark_name, "accuracy": accuracy_benchmark_name, "support": support_benchmark_name} + ] + metrics_data_df = pd.DataFrame(metrics_data) + average_precision = metrics_data_df["precision"].mean() + average_recall = metrics_data_df["recall"].mean() + average_f1 = metrics_data_df["f1"].mean() + average_accuracy = metrics_data_df["accuracy"].mean() + sum_support = metrics_data_df["support"].sum() + metrics_data.append({"provider_id": provider_id, "provider_name": provider_name, "item": "average_score", "precision": average_precision, "recall": average_recall, "f1": average_f1, "accuracy": average_accuracy, "support": sum_support}) + metrics_data_df = pd.DataFrame(metrics_data) + metrics_data_df = metrics_data_df[["provider_id", "provider_name", "item", "f1", "precision", "recall", "accuracy", "support"]] + provider_metrics_list.append(metrics_data_df) + + all_provider_metrics_df = pd.concat(provider_metrics_list) + all_provider_metrics_df.reset_index(drop=True, inplace=True) + + # output metrics data to Excel file + print("Output metrics data to Excel file...") + output_folder = 
r"/data/aus_prospectus/output/metrics_data/" + os.makedirs(output_folder, exist_ok=True) + verify_file_name = os.path.basename(verify_file_path).replace(".xlsx", "") + if is_for_all: + verify_file_name = f"{verify_file_name}_all" + metrics_file_name = f"metrics_{verify_file_name}_{len(document_id_list)}_documents_for_providers.xlsx" + output_file = os.path.join(output_folder, metrics_file_name) + with pd.ExcelWriter(output_file) as writer: + all_provider_metrics_df.to_excel(writer, index=False, sheet_name="metrics_data") + message_data_df.to_excel(writer, index=False, sheet_name="message_data") + + + def generate_message(message: dict, doc_id: str, sec_id: str, fund_legal_name: str, raw_fund_name: str, datapoint: str): message["data_point"] = datapoint message["fund_legal_name"] = fund_legal_name @@ -954,203 +1317,6 @@ def clean_text(text: str): text = re.sub(r"\W", " ", text) text = re.sub(r"\s+", " ", text) return text - - -def set_mapping_to_raw_name_data(data_file_path: str = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees.xlsx", - data_sheet: str = "Sheet1", - raw_name_column: str = "raw_share_name", - mapping_file_path: str = r"/data/aus_prospectus/basic_information/from_2024_documents/aus_100_document_prospectus_multi_fund.xlsx", - mapping_sheet: str = "document_mapping", - raw_name_mapping_column: str = None, - output_file_path: str = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees_with_mapping.xlsx"): - data_df = pd.read_excel(data_file_path, sheet_name=data_sheet) - data_df["fund_id"] = "" - data_df["fund_name"] = "" - data_df["sec_id"] = "" - data_df["sec_name"] = "" - - mapping_data = pd.read_excel(mapping_file_path, sheet_name=mapping_sheet) - - doc_id_list = data_df["doc_id"].unique().tolist() - for doc_id in doc_id_list: - doc_data = data_df[data_df["doc_id"] == doc_id] - raw_name_list = doc_data[raw_name_column].unique().tolist() - - doc_mapping_data = mapping_data[mapping_data["DocumentId"] == doc_id] - if len(doc_mapping_data) == 0: - continue - provider_name = doc_mapping_data["CompanyName"].values[0] - if raw_name_mapping_column is not None and raw_name_mapping_column == "FundLegalName": - doc_db_name_list = doc_mapping_data[raw_name_mapping_column].unique().tolist() - for raw_name in raw_name_list: - find_df = doc_mapping_data[doc_mapping_data[raw_name_mapping_column] == raw_name] - if find_df is not None and len(find_df) == 1: - sec_id = find_df["FundClassId"].values[0] - sec_name = find_df["FundClassLegalName"].values[0] - fund_id = find_df["FundId"].values[0] - fund_name = find_df["FundLegalName"].values[0] - # update doc_data which raw_share_name is same as raw_share_name - data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_name), "sec_id"] = sec_id - data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_name), "sec_name"] = sec_name - data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_name), "fund_id"] = fund_id - data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_name), "fund_name"] = fund_name - else: - doc_db_name_list = doc_mapping_data["FundClassLegalName"].unique().tolist() - all_match_result = get_raw_name_db_match_result(doc_id, - provider_name, - raw_name_list, - doc_db_name_list, - iter_count=60) - for raw_share_name in raw_name_list: - if all_match_result.get(raw_share_name) is not None: - matched_db_share_name = all_match_result[raw_share_name] - if ( - matched_db_share_name is not None - and 
len(matched_db_share_name) > 0 - ): - # get SecId from self.doc_fund_class_mapping - find_share_df = doc_mapping_data[doc_mapping_data["FundClassLegalName"] == matched_db_share_name] - if find_share_df is not None and len(find_share_df) > 0: - sec_id = find_share_df["FundClassId"].values[0] - fund_id = find_share_df["FundId"].values[0] - fund_name = find_share_df["FundLegalName"].values[0] - # update doc_data which raw_share_name is same as raw_share_name - data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_share_name), "sec_id"] = sec_id - data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_share_name), "sec_name"] = matched_db_share_name - data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_share_name), "fund_id"] = fund_id - data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_share_name), "fund_name"] = fund_name - try: - data_df = data_df[["doc_id", - "raw_fund_name", - "fund_id", - "fund_name", - "raw_share_name", - "sec_id", - "sec_name", - "management_fee_and_costs", - "management_fee", - "administration_fees", - "minimum_initial_investment", - "benchmark_name", - "performance_fee", - "performance_fee_charged", - "buy_spread", - "sell_spread", - "total_annual_dollar_based_charges", - "interposed_vehicle_performance_fee_cost", - "establishment_fee", - "contribution_fee", - "withdrawal_fee", - "exit_fee", - "switching_fee", - "activity_fee", - "hurdle_rate", - "analyst_name" - ]] - except Exception as e: - print(e) - - with open(output_file_path, "wb") as file: - data_df.to_excel(file, index=False) - - -def get_raw_name_db_match_result( - doc_id: str, provider_name: str, raw_name_list: list, doc_share_name_list: list, iter_count: int = 30 - ): - # split raw_name_list into several parts which each part is with 30 elements - # The reason to split is to avoid invoke token limitation issues from CahtGPT - raw_name_list_parts = [ - raw_name_list[i : i + iter_count] - for i in range(0, len(raw_name_list), iter_count) - ] - all_match_result = {} - doc_share_name_list = deepcopy(doc_share_name_list) - for raw_name_list in raw_name_list_parts: - match_result, doc_share_name_list = get_final_function_to_match( - doc_id, provider_name, raw_name_list, doc_share_name_list - ) - all_match_result.update(match_result) - return all_match_result - -def get_final_function_to_match(doc_id, provider_name, raw_name_list, db_name_list): - if len(db_name_list) == 0: - match_result = {} - for raw_name in raw_name_list: - match_result[raw_name] = "" - else: - match_result = final_function_to_match( - doc_id=doc_id, - pred_list=raw_name_list, - db_list=db_name_list, - provider_name=provider_name, - doc_source="aus_prospectus" - ) - matched_name_list = list(match_result.values()) - db_name_list = remove_matched_names(db_name_list, matched_name_list) - return match_result, db_name_list - -def remove_matched_names(target_name_list: list, matched_name_list: list): - if len(matched_name_list) == 0: - return target_name_list - - matched_name_list = list(set(matched_name_list)) - matched_name_list = [ - value for value in matched_name_list if value is not None and len(value) > 0 - ] - for matched_name in matched_name_list: - if ( - matched_name is not None - and len(matched_name) > 0 - and matched_name in target_name_list - ): - target_name_list.remove(matched_name) - return target_name_list - - -def set_mapping_to_ravi_data(): - data_file_path = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees.xlsx" 
- data_sheet = "Sheet1" - mapping_file_path = r"/data/aus_prospectus/basic_information/from_2024_documents/aus_100_document_prospectus_multi_fund.xlsx" - mapping_sheet = "document_mapping" - output_file_path = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees_with_mapping.xlsx" - set_mapping_to_raw_name_data(data_file_path, data_sheet, mapping_file_path, mapping_sheet, output_file_path) - - -def set_mapping_to_data_side_documents_data(): - # data_file_path = r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/Audited file_phase2.xlsx" - # data_sheet = "all" - # mapping_file_path = r"/data/aus_prospectus/basic_information/17_documents/aus_prospectus_17_documents_mapping.xlsx" - # mapping_sheet = "document_mapping" - # output_file_path = r"/data/aus_prospectus/output/ravi_100_documents/audited_file_phase2_with_mapping.xlsx" - - data_file_path = r"/data/aus_prospectus/ground_truth/phase2_file/46_documents/46_documents_ground_truth.xlsx" - data_sheet = "ground_truth" - raw_name_column = "raw_share_name" - mapping_file_path = r"/data/aus_prospectus/basic_information/46_documents/aus_prospectus_46_documents_mapping.xlsx" - mapping_sheet = "document_mapping" - raw_name_mapping_column = None - output_file_path = r"/data/aus_prospectus/ground_truth/phase2_file/46_documents/46_documents_ground_truth_with_mapping.xlsx" - set_mapping_to_raw_name_data(data_file_path=data_file_path, - data_sheet=data_sheet, - raw_name_column=raw_name_column, - mapping_file_path=mapping_file_path, - mapping_sheet=mapping_sheet, - raw_name_mapping_column=raw_name_mapping_column, - output_file_path=output_file_path) - - -def adjust_data_file(source_file: str, - targe_file: str): - source_data = pd.read_excel(source_file, sheet_name="Sheet1") - source_doc_id_list = source_data["DocumentId"].unique().tolist() - - target_data = pd.read_excel(targe_file, sheet_name="Sheet1") - #remove target_data which doc_id is in source_doc_id_list - target_data = target_data[~target_data["DocumentId"].isin(source_doc_id_list)] - # concat source_data and target_data - target_data = pd.concat([source_data, target_data], ignore_index=True) - with open(targe_file, "wb") as file: - target_data.to_excel(file, index=False) if __name__ == "__main__": @@ -1172,12 +1338,24 @@ if __name__ == "__main__": verify_data_sheet: str = "total_data" # verify_document_list_file: str = "./sample_documents/aus_prospectus_29_documents_sample.txt" verify_document_list_file_list = [None, "./sample_documents/aus_prospectus_29_documents_sample.txt", "./sample_documents/aus_prospectus_17_documents_sample.txt"] + is_for_all = False + # for verify_document_list_file in verify_document_list_file_list: + # calculate_metrics_based_db_data_file(audit_file_path=audit_file_path, + # audit_data_sheet=audit_data_sheet, + # verify_file_path=verify_file_path, + # verify_data_sheet=verify_data_sheet, + # verify_document_list_file = verify_document_list_file, + # is_for_all=is_for_all) + for verify_document_list_file in verify_document_list_file_list: - calculate_metrics_based_db_data_file(audit_file_path=audit_file_path, - audit_data_sheet=audit_data_sheet, - verify_file_path=verify_file_path, - verify_data_sheet=verify_data_sheet, - verify_document_list_file = verify_document_list_file) + calculate_metrics_by_provider(audit_file_path=audit_file_path, + audit_data_sheet=audit_data_sheet, + verify_file_path=verify_file_path, + verify_data_sheet=verify_data_sheet, + verify_document_list_file = verify_document_list_file, + is_for_all=is_for_all) + + # 
set_mapping_to_17_documents_data() # set_mapping_to_ravi_data() diff --git a/core/data_extraction.py b/core/data_extraction.py index 715546b..1998d61 100644 --- a/core/data_extraction.py +++ b/core/data_extraction.py @@ -576,7 +576,7 @@ class DataExtraction: previous_page_datapoints = [] previous_page_fund_name = None for page_num, page_text in self.page_text_dict.items(): - # if page_num != 24: + # if page_num != 21: # continue if page_num in handled_page_num_list: continue diff --git a/core/data_mapping.py b/core/data_mapping.py index 8578c1c..e50798e 100644 --- a/core/data_mapping.py +++ b/core/data_mapping.py @@ -228,7 +228,180 @@ class DataMapping: mapped_data["similarity"] = 1 self.output_mapping_file(mapped_data_list) - return mapped_data_list + + if self.doc_source == "aus_prospectus": + output_data_folder_splits = self.output_data_excel_folder.split("output") + if len(output_data_folder_splits) == 2: + merged_data_folder = f'{output_data_folder_splits[0]}output/merged_data/docs/' + os.makedirs(merged_data_folder, exist_ok=True) + + merged_data_json_folder = os.path.join(merged_data_folder, "json/") + os.makedirs(merged_data_json_folder, exist_ok=True) + + merged_data_excel_folder = os.path.join(merged_data_folder, "excel/") + os.makedirs(merged_data_excel_folder, exist_ok=True) + merged_data_list = self.merge_output_data_aus_prospectus(mapped_data_list, + merged_data_json_folder, + merged_data_excel_folder) + return merged_data_list + else: + return mapped_data_list + + def merge_output_data_aus_prospectus(self, + mapped_data_list: list, + merged_data_json_folder: str, + merged_data_excel_folder: str): + # TODO: merge output data for aus prospectus, plan to realize it on 2025-01-16 + if mapped_data_list is None or len(mapped_data_list) == 0: + return + if merged_data_json_folder is None or len(merged_data_json_folder) == 0: + return + if merged_data_excel_folder is None or len(merged_data_excel_folder) == 0: + return + mapping_data_df = pd.DataFrame(mapped_data_list) + mapping_data_df.reset_index(drop=True, inplace=True) + mapping_data_df.fillna("", inplace=True) + + document_mapping_df = self.document_mapping_info_df + document_mapping_df.fillna("", inplace=True) + + datapoint_keyword_config_file = ( + f"./configuration/{self.doc_source}/datapoint_name.json" + ) + with open(datapoint_keyword_config_file, "r", encoding="utf-8") as f: + datapoint_keyword_config = json.load(f) + datapoint_name_list = list(datapoint_keyword_config.keys()) + total_data_list = [] + + doc_date = str(document_mapping_df["EffectiveDate"].values[0])[0:10] + share_doc_data_df = mapping_data_df[(mapping_data_df["investment_type"] == 1)] + exist_raw_name_list = [] + for index, row in share_doc_data_df.iterrows(): + doc_id = str(row["doc_id"]) + page_index = int(row["page_index"]) + raw_fund_name = str(row["raw_fund_name"]) + raw_share_name = str(row["raw_share_name"]) + raw_name = str(row["raw_name"]) + datapoint = str(row["datapoint"]) + value = row["value"] + investment_type = row["investment_type"] + share_class_id = row["investment_id"] + share_class_legal_name = row["investment_name"] + fund_id = "" + fund_legal_name = "" + if share_class_id != "": + record_row = document_mapping_df[document_mapping_df["SecId"] == share_class_id] + if len(record_row) > 0: + fund_id = record_row["FundId"].values[0] + fund_legal_name = record_row["FundName"].values[0] + + exist = False + for exist_raw_name_info in exist_raw_name_list: + exist_raw_name = exist_raw_name_info["raw_name"] + exist_investment_type = 
exist_raw_name_info["investment_type"] + exist_investment_id = exist_raw_name_info["investment_id"] + if ( + exist_raw_name == raw_name + and exist_investment_type == investment_type + ) or (len(exist_investment_id) > 0 and exist_investment_id == share_class_id): + exist = True + break + if not exist: + data = { + "DocumentId": doc_id, + "raw_fund_name": raw_fund_name, + "raw_share_name": raw_share_name, + "raw_name": raw_name, + "fund_id": fund_id, + "fund_name": fund_legal_name, + "sec_id": share_class_id, + "sec_name": share_class_legal_name, + "EffectiveDate": doc_date, + "page_index": [], + "RawName": raw_name, + } + for datapoint_name in datapoint_name_list: + data[datapoint_name] = "" + exist_raw_name_list.append( + {"raw_name": raw_name, "investment_type": investment_type, "investment_id": share_class_id} + ) + total_data_list.append(data) + # find data from total_data_list by raw_name + for data in total_data_list: + if data["raw_name"] == raw_name: + update_key = datapoint + data[update_key] = value + if page_index not in data["page_index"]: + data["page_index"].append(page_index) + break + if len(share_class_id) > 0 and data["sec_id"] == share_class_id: + update_key = datapoint + if len(str(data[update_key])) == 0: + data[update_key] = value + if page_index not in data["page_index"]: + data["page_index"].append(page_index) + break + + fund_doc_data_df = mapping_data_df[(mapping_data_df["investment_type"] == 33)] + fund_doc_data_df.fillna("", inplace=True) + for index, row in fund_doc_data_df.iterrows(): + doc_id = str(row["doc_id"]) + page_index = int(row["page_index"]) + raw_fund_name = str(row["raw_fund_name"]) + raw_share_name = "" + raw_name = str(row["raw_name"]) + datapoint = str(row["datapoint"]) + value = row["value"] + fund_id = row["investment_id"] + fund_legal_name = row["investment_name"] + exist = False + if fund_id != "": + for data in total_data_list: + if (fund_id != "" and data["fund_id"] == fund_id) or ( + data["raw_fund_name"] == raw_fund_name + ): + update_key = datapoint + data[update_key] = value + if page_index not in data["page_index"]: + data["page_index"].append(page_index) + exist = True + else: + for data in total_data_list: + if data["raw_name"] == raw_name: + update_key = datapoint + data[update_key] = value + if page_index not in data["page_index"]: + data["page_index"].append(page_index) + exist = True + if not exist: + data = { + "DocumentId": doc_id, + "raw_fund_name": raw_fund_name, + "raw_share_name": "", + "raw_name": raw_name, + "fund_id": fund_id, + "fund_name": fund_legal_name, + "sec_id": "", + "sec_name": "", + "EffectiveDate": doc_date, + "page_index": [page_index], + "RawName": raw_name, + } + for datapoint_name in datapoint_name_list: + data[datapoint_name] = "" + data[datapoint] = value + total_data_list.append(data) + total_data_df = pd.DataFrame(total_data_list) + total_data_df.fillna("", inplace=True) + + merged_data_excel_file = os.path.join(merged_data_excel_folder, f"merged_{self.doc_id}.xlsx") + with pd.ExcelWriter(merged_data_excel_file) as writer: + total_data_df.to_excel(writer, index=False, sheet_name="merged_data") + + merged_data_json_file = os.path.join(merged_data_json_folder, f"merged_{self.doc_id}.json") + with open(merged_data_json_file, "w", encoding="utf-8") as f: + json.dump(total_data_list, f, ensure_ascii=False, indent=4) + return total_data_list def get_raw_name_db_match_result( self, raw_name_list, investment_type: str, iter_count: int = 30 diff --git a/main.py b/main.py index 64fe6e0..2a145aa 100644 --- 
a/main.py +++ b/main.py @@ -499,7 +499,17 @@ def batch_start_job( ) logger.info(f"Saving mapping data to {output_mapping_total_folder}") - unique_doc_ids = result_mappingdata_df["doc_id"].unique().tolist() + result_mappingdata_df_columns = list(result_mappingdata_df.columns) + doc_id_column = "" + if "doc_id" in result_mappingdata_df_columns: + doc_id_column = "doc_id" + if "DocumentId" in result_mappingdata_df_columns: + doc_id_column = "DocumentId" + + if doc_id_column == "": + logger.error(f"Cannot find doc_id column in mapping data") + return + unique_doc_ids = result_mappingdata_df[doc_id_column].unique().tolist() os.makedirs(output_mapping_total_folder, exist_ok=True) time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime()) file_name = f"mapping_data_info_{len(unique_doc_ids)}_documents_by_{extract_way}_{time_stamp}.xlsx" @@ -507,11 +517,11 @@ def batch_start_job( file_name = f"{total_data_prefix}_{file_name}" output_file = os.path.join(output_mapping_total_folder, file_name) - doc_mapping_data_in_db = only_output_mapping_data_in_db(result_mappingdata_df) + # doc_mapping_data_in_db = only_output_mapping_data_in_db(result_mappingdata_df) with pd.ExcelWriter(output_file) as writer: - doc_mapping_data_in_db.to_excel( - writer, index=False, sheet_name="data_in_doc_mapping" - ) + # doc_mapping_data_in_db.to_excel( + # writer, index=False, sheet_name="data_in_doc_mapping" + # ) result_mappingdata_df.to_excel( writer, index=False, sheet_name="total_mapping_data" ) @@ -519,27 +529,6 @@ def batch_start_job( writer, index=False, sheet_name="extract_data" ) - if ( - doc_source == "aus_prospectus" - and document_mapping_file is not None - and len(document_mapping_file) > 0 - and os.path.exists(document_mapping_file) - ): - try: - merged_total_data_folder = os.path.join( - output_mapping_total_folder, "merged/" - ) - os.makedirs(merged_total_data_folder, exist_ok=True) - data_file_base_name = os.path.basename(output_file) - output_merged_data_file_path = os.path.join( - merged_total_data_folder, "merged_" + data_file_base_name - ) - merge_output_data_aus_prospectus( - output_file, document_mapping_file, output_merged_data_file_path - ) - except Exception as e: - logger.error(f"Error: {e}") - if calculate_metrics: prediction_sheet_name = "data_in_doc_mapping" ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx" @@ -1527,7 +1516,7 @@ if __name__ == "__main__": document_mapping_file = r"/data/aus_prospectus/basic_information/46_documents/aus_prospectus_46_documents_mapping.xlsx" # special_doc_id_list: list = ["410899007", "539266880", "539266817", # "539261734", "539266893"] - # special_doc_id_list: list = ["401212184"] + # special_doc_id_list: list = ["539266880"] pdf_folder: str = r"/data/aus_prospectus/pdf/" output_pdf_text_folder: str = r"/data/aus_prospectus/output/pdf_text/" output_extract_data_child_folder: str = ( diff --git a/prepare_data.py b/prepare_data.py index cf4de36..72b9b1b 100644 --- a/prepare_data.py +++ b/prepare_data.py @@ -8,10 +8,12 @@ import re import time import traceback import json_repair +from copy import deepcopy from utils.logger import logger from utils.pdf_download import download_pdf_from_documents_warehouse from utils.pdf_util import PDFUtil +from core.auz_nz.hybrid_solution_script import final_function_to_match def get_unique_docids_from_doc_provider_data(doc_provider_file_path: str): @@ -1463,18 +1465,290 @@ def prepare_multi_fund_aus_prospectus_document(data_folder: str = r"/data/aus_pr with 
open(output_sample_document_file, "w") as f: + for doc_id in document_id_list: + f.write(f"{doc_id}\n") + + +def set_mapping_to_ravi_data(): + data_file_path = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees.xlsx" + data_sheet = "Sheet1" + mapping_file_path = r"/data/aus_prospectus/basic_information/from_2024_documents/aus_100_document_prospectus_multi_fund.xlsx" + mapping_sheet = "document_mapping" + output_file_path = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees_with_mapping.xlsx" + set_mapping_to_raw_name_data(data_file_path=data_file_path, data_sheet=data_sheet, mapping_file_path=mapping_file_path, mapping_sheet=mapping_sheet, output_file_path=output_file_path) + + +def set_mapping_to_data_side_documents_data(): + # data_file_path = r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/Audited file_phase2.xlsx" + # data_sheet = "all" + # mapping_file_path = r"/data/aus_prospectus/basic_information/17_documents/aus_prospectus_17_documents_mapping.xlsx" + # mapping_sheet = "document_mapping" + # output_file_path = r"/data/aus_prospectus/output/ravi_100_documents/audited_file_phase2_with_mapping.xlsx" + + data_file_path = r"/data/aus_prospectus/ground_truth/phase2_file/46_documents/46_documents_ground_truth.xlsx" + data_sheet = "ground_truth" + raw_name_column = "raw_share_name" + mapping_file_path = r"/data/aus_prospectus/basic_information/46_documents/aus_prospectus_46_documents_mapping.xlsx" + mapping_sheet = "document_mapping" + raw_name_mapping_column = None + output_file_path = r"/data/aus_prospectus/ground_truth/phase2_file/46_documents/46_documents_ground_truth_with_mapping.xlsx" + set_mapping_to_raw_name_data(data_file_path=data_file_path, + data_sheet=data_sheet, + raw_name_column=raw_name_column, + mapping_file_path=mapping_file_path, + mapping_sheet=mapping_sheet, + raw_name_mapping_column=raw_name_mapping_column, + output_file_path=output_file_path) + + +def set_mapping_to_raw_name_data(data_file_path: str = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees.xlsx", + data_sheet: str = "Sheet1", + raw_name_column: str = "raw_share_name", + mapping_file_path: str = r"/data/aus_prospectus/basic_information/from_2024_documents/aus_100_document_prospectus_multi_fund.xlsx", + mapping_sheet: str = "document_mapping", + raw_name_mapping_column: str = None, + output_file_path: str = r"/data/aus_prospectus/output/ravi_100_documents/AUS_Extracted_Fees_with_mapping.xlsx"): + data_df = pd.read_excel(data_file_path, sheet_name=data_sheet) + data_df["provider_id"] = "" + data_df["provider_name"] = "" + data_df["fund_id"] = "" + data_df["fund_name"] = "" + data_df["sec_id"] = "" + data_df["sec_name"] = "" + + mapping_data = pd.read_excel(mapping_file_path, sheet_name=mapping_sheet) + + doc_id_list = data_df["doc_id"].unique().tolist() + for doc_id in doc_id_list: + doc_data = data_df[data_df["doc_id"] == doc_id] + raw_name_list = doc_data[raw_name_column].unique().tolist() + + doc_mapping_data = mapping_data[mapping_data["DocumentId"] == doc_id] + if len(doc_mapping_data) == 0: + continue + provider_id = doc_mapping_data["CompanyId"].values[0] + provider_name = doc_mapping_data["CompanyName"].values[0] + data_df.loc[(data_df["doc_id"] == doc_id), "provider_id"] = provider_id + data_df.loc[(data_df["doc_id"] == doc_id), "provider_name"] = provider_name + if raw_name_mapping_column is not None and raw_name_mapping_column == "FundLegalName": + doc_db_name_list = doc_mapping_data[raw_name_mapping_column].unique().tolist() + for raw_name in raw_name_list: +
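# only unambiguous matches (exactly one mapping row for the raw name) are written back +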
find_df = doc_mapping_data[doc_mapping_data[raw_name_mapping_column] == raw_name] + if find_df is not None and len(find_df) == 1: + sec_id = find_df["FundClassId"].values[0] + sec_name = find_df["FundClassLegalName"].values[0] + fund_id = find_df["FundId"].values[0] + fund_name = find_df["FundLegalName"].values[0] + # update the rows of this document whose raw name matches + data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_name), "sec_id"] = sec_id + data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_name), "sec_name"] = sec_name + data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_name), "fund_id"] = fund_id + data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_name), "fund_name"] = fund_name + else: + doc_db_name_list = doc_mapping_data["FundClassLegalName"].unique().tolist() + all_match_result = get_raw_name_db_match_result(doc_id, + provider_name, + raw_name_list, + doc_db_name_list, + iter_count=60) + for raw_share_name in raw_name_list: + if all_match_result.get(raw_share_name) is not None: + matched_db_share_name = all_match_result[raw_share_name] + if ( + matched_db_share_name is not None + and len(matched_db_share_name) > 0 + ): + # look up SecId in the document fund-class mapping + find_share_df = doc_mapping_data[doc_mapping_data["FundClassLegalName"] == matched_db_share_name] + if find_share_df is not None and len(find_share_df) > 0: + sec_id = find_share_df["FundClassId"].values[0] + fund_id = find_share_df["FundId"].values[0] + fund_name = find_share_df["FundLegalName"].values[0] + # update the rows of this document whose raw name matches + data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_share_name), "sec_id"] = sec_id + data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_share_name), "sec_name"] = matched_db_share_name + data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_share_name), "fund_id"] = fund_id + data_df.loc[(data_df["doc_id"] == doc_id) & (data_df[raw_name_column] == raw_share_name), "fund_name"] = fund_name + try: + data_df = data_df[["doc_id", + "provider_id", + "provider_name", + "raw_fund_name", + "fund_id", + "fund_name", + "raw_share_name", + "sec_id", + "sec_name", + "management_fee_and_costs", + "management_fee", + "administration_fees", + "minimum_initial_investment", + "benchmark_name", + "performance_fee", + "performance_fee_charged", + "buy_spread", + "sell_spread", + "total_annual_dollar_based_charges", + "interposed_vehicle_performance_fee_cost", + "establishment_fee", + "contribution_fee", + "withdrawal_fee", + "exit_fee", + "switching_fee", + "activity_fee", + "hurdle_rate", + "analyst_name" + ]] + except Exception as e: + print(e) + + with open(output_file_path, "wb") as file: + data_df.to_excel(file, index=False) + + +def get_raw_name_db_match_result( + doc_id: str, provider_name: str, raw_name_list: list, doc_share_name_list: list, iter_count: int = 30 + ): + # split raw_name_list into chunks of iter_count elements + # chunking avoids hitting ChatGPT token limits in a single call + raw_name_list_parts = [ + raw_name_list[i : i + iter_count] + for i in range(0, len(raw_name_list), iter_count) + ] + all_match_result = {} + doc_share_name_list = deepcopy(doc_share_name_list) + for raw_name_list in raw_name_list_parts: + match_result, doc_share_name_list = get_final_function_to_match( + doc_id, provider_name, raw_name_list, 
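# matched DB names are removed from doc_share_name_list after each chunk, so a name is assigned at most once +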
doc_share_name_list + ) + all_match_result.update(match_result) + return all_match_result + +def get_final_function_to_match(doc_id, provider_name, raw_name_list, db_name_list): + if len(db_name_list) == 0: + match_result = {} + for raw_name in raw_name_list: + match_result[raw_name] = "" + else: + match_result = final_function_to_match( + doc_id=doc_id, + pred_list=raw_name_list, + db_list=db_name_list, + provider_name=provider_name, + doc_source="aus_prospectus" + ) + matched_name_list = list(match_result.values()) + db_name_list = remove_matched_names(db_name_list, matched_name_list) + return match_result, db_name_list + +def remove_matched_names(target_name_list: list, matched_name_list: list): + if len(matched_name_list) == 0: + return target_name_list + + matched_name_list = list(set(matched_name_list)) + matched_name_list = [ + value for value in matched_name_list if value is not None and len(value) > 0 + ] + for matched_name in matched_name_list: + if ( + matched_name is not None + and len(matched_name) > 0 + and matched_name in target_name_list + ): + target_name_list.remove(matched_name) + return target_name_list + + +def adjust_data_file(source_file: str, + target_file: str): + source_data = pd.read_excel(source_file, sheet_name="Sheet1") + source_doc_id_list = source_data["DocumentId"].unique().tolist() + + target_data = pd.read_excel(target_file, sheet_name="Sheet1") + # remove target_data rows whose DocumentId is in source_doc_id_list + target_data = target_data[~target_data["DocumentId"].isin(source_doc_id_list)] + # concat source_data and target_data + target_data = pd.concat([source_data, target_data], ignore_index=True) + with open(target_file, "wb") as file: + target_data.to_excel(file, index=False) + + +def set_provider_to_ground_truth(ground_truth_file: str, + ground_truth_sheet: str, + document_mapping_file: str, + document_mapping_sheet: str): + ground_truth_df = pd.read_excel(ground_truth_file, sheet_name=ground_truth_sheet) + ground_truth_df["provider_id"] = "" + ground_truth_df["provider_name"] = "" + + mapping_data = pd.read_excel(document_mapping_file, sheet_name=document_mapping_sheet) + + doc_id_list = ground_truth_df["DocumentId"].unique().tolist() + for doc_id in doc_id_list: + doc_mapping_data = mapping_data[mapping_data["DocumentId"] == doc_id] + if len(doc_mapping_data) == 0: + continue + provider_id = doc_mapping_data["CompanyId"].values[0] + provider_name = doc_mapping_data["CompanyName"].values[0] + ground_truth_df.loc[(ground_truth_df["DocumentId"] == doc_id), "provider_id"] = provider_id + ground_truth_df.loc[(ground_truth_df["DocumentId"] == doc_id), "provider_name"] = provider_name + try: + ground_truth_df = ground_truth_df[["DocumentId", + "provider_id", + "provider_name", + "raw_fund_name", + "FundId", + "FundLegalName", + "raw_share_name", + "FundClassId", + "FundClassLegalName", + "management_fee_and_costs", + "management_fee", + "administration_fees", + "minimum_initial_investment", + "benchmark_name", + "performance_fee", + "performance_fee_charged", + "buy_spread", + "sell_spread", + "total_annual_dollar_based_charges", + "interposed_vehicle_performance_fee_cost", + "establishment_fee", + "contribution_fee", + "withdrawal_fee", + "exit_fee", + "switching_fee", + "activity_fee", + "hurdle_rate", + "analyst_name" + ]] + except Exception as e: + print(e) + + with open(ground_truth_file, "wb") as file: + ground_truth_df.to_excel(file, index=False) - if __name__ == "__main__": + set_provider_to_ground_truth( + 
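# stamps provider_id / provider_name from the mapping sheet onto the ground-truth workbook, rewriting it in place +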
ground_truth_file=r"/data/aus_prospectus/ground_truth/phase2_file/46_documents/46_documents_ground_truth_with_mapping.xlsx", + ground_truth_sheet="Sheet1", + document_mapping_file=r"/data/aus_prospectus/basic_information/46_documents/aus_prospectus_46_documents_mapping.xlsx", + document_mapping_sheet="document_mapping" + ) + + # set_mapping_to_data_side_documents_data() + + # source_file = r"/data/aus_prospectus/ground_truth/phase2_file/17_documents/audited_file_phase2_with_mapping.xlsx" + # target_file = r"/data/aus_prospectus/ground_truth/phase2_file/46_documents/46_documents_ground_truth_with_mapping.xlsx" + # adjust_data_file(source_file=source_file, target_file=target_file) + + # pdf_exist() # prepare_multi_fund_aus_prospectus_document() - merge_aus_document_prospectus_data(aus_data_folder=r"/data/aus_prospectus/basic_information/17_documents/", - aus_document_mapping_file="aus_prospectus_17_documents_mapping.xlsx", - aus_prospectus_data_file="aus_prospectus_data_17_documents_secid.xlsx", - document_mapping_sheet="document_mapping", - output_file="aus_prospectus_17_documents_data.xlsx", - output_sheet="aus_document_prospectus") + # merge_aus_document_prospectus_data(aus_data_folder=r"/data/aus_prospectus/basic_information/17_documents/", + # aus_document_mapping_file="aus_prospectus_17_documents_mapping.xlsx", + # aus_prospectus_data_file="aus_prospectus_data_17_documents_secid.xlsx", + # document_mapping_sheet="document_mapping", + # output_file="aus_prospectus_17_documents_data.xlsx", + # output_sheet="aus_document_prospectus") folder = r"/data/emea_ar/basic_information/English/sample_doc/emea_11_06_case/" file_name = "doc_ar_data_for_emea_11_06.xlsx" # get_document_with_all_4_data_points(folder, file_name, None)
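Reviewer note on the metrics flow added in calc_metrics.py: per provider, each datapoint comparison contributes one binary flag to a gt list and a pred list, and sklearn then scores those lists. Below is a minimal, self-contained sketch of that aggregation pattern; the provider ids and flag values are illustrative, and zero_division=0 is an addition (not in this patch) that silences sklearn's warning when a provider has no positive flags.

from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import pandas as pd

# toy flags: 1 = value matched the audited ground truth, 0 = mismatch/missing
provider_flags = {
    "P001": {"gt": [1, 1, 0, 1], "pred": [1, 0, 0, 1]},
    "P002": {"gt": [1, 1, 1], "pred": [1, 1, 0]},
}

rows = []
for provider_id, flags in provider_flags.items():
    gt, pred = flags["gt"], flags["pred"]
    rows.append({
        "provider_id": provider_id,
        "precision": precision_score(gt, pred, zero_division=0),
        "recall": recall_score(gt, pred, zero_division=0),
        "f1": f1_score(gt, pred, zero_division=0),
        "accuracy": accuracy_score(gt, pred),
        "support": sum(gt),  # positive ground-truth flags, as in the patch
    })

print(pd.DataFrame(rows))

As in calculate_metrics_by_provider, support is the positive-flag count per datapoint, and a per-provider average_score row can be appended from the column means.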