1. only get name mapping data from document mapping

2. Compare name mapping metrics between Ravi's and mine.
This commit is contained in:
Blade He 2025-01-27 12:29:49 -06:00
parent 350550d1b0
commit 47c41e492f
5 changed files with 1409 additions and 149 deletions

View File

@ -69,7 +69,8 @@ def emea_ar_data_extract():
output_extract_data_folder=output_extract_data_folder, output_extract_data_folder=output_extract_data_folder,
output_mapping_data_folder=output_mapping_data_folder, output_mapping_data_folder=output_mapping_data_folder,
extract_way=extract_way, extract_way=extract_way,
drilldown_folder=drilldown_folder) drilldown_folder=drilldown_folder,
compare_with_provider=False)
doc_data_from_gpt, annotation_list = emea_ar_parsing.extract_data(re_run=re_run_extract_data) doc_data_from_gpt, annotation_list = emea_ar_parsing.extract_data(re_run=re_run_extract_data)
doc_mapping_data = emea_ar_parsing.mapping_data( doc_mapping_data = emea_ar_parsing.mapping_data(
data_from_gpt=doc_data_from_gpt, re_run=re_run_mapping_data data_from_gpt=doc_data_from_gpt, re_run=re_run_mapping_data

File diff suppressed because it is too large Load Diff

View File

@ -32,24 +32,24 @@ from openai import AzureOpenAI
ABB_JSON = dict() ABB_JSON = dict()
def get_abb_json(): def get_abb_json(doc_source: str = "aus_prospectus"):
global ABB_JSON global ABB_JSON
if len(ABB_JSON.keys()) == 0: if len(ABB_JSON.keys()) == 0:
with open("./configuration/aus_prospectus/abbreviation_records.json", "r") as file: with open(f"./configuration/{doc_source}/abbreviation_records.json", "r") as file:
# Load the JSON and convert keys to lowercase # Load the JSON and convert keys to lowercase
ABB_JSON = {key.lower(): value for key, value in json.load(file).items()} ABB_JSON = {key.lower(): value for key, value in json.load(file).items()}
def get_abbre_format_str(fundname): def get_abbre_format_str(fundname, doc_source: str = "aus_prospectus"):
"""Replaces abbreviations in a fund name with their expanded forms.""" """Replaces abbreviations in a fund name with their expanded forms."""
# Convert fund name to lowercase while matching # Convert fund name to lowercase while matching
f_list = fundname.lower().split() f_list = fundname.lower().split()
get_abb_json() get_abb_json(doc_source)
updated_doc_fname_words = [ABB_JSON.get(word, word).lower() for word in f_list] updated_doc_fname_words = [ABB_JSON.get(word, word).lower() for word in f_list]
return " ".join(updated_doc_fname_words) return " ".join(updated_doc_fname_words)
def replace_abbrevs_in_fundnames(fund_names_list): def replace_abbrevs_in_fundnames(fund_names_list, doc_source: str = "aus_prospectus"):
"""Replaces abbreviations in a list of fund names.""" """Replaces abbreviations in a list of fund names."""
return [get_abbre_format_str(fund_name) for fund_name in fund_names_list] return [get_abbre_format_str(fund_name, doc_source) for fund_name in fund_names_list]
### STEP 2 - Remove Stopwords ### STEP 2 - Remove Stopwords
@ -440,7 +440,7 @@ def format_response(doc_id, pred_fund, db_fund, clean_pred_name, clean_db_name,
return dt return dt
def final_function_to_match(doc_id, pred_list, db_list, provider_name): def final_function_to_match(doc_id, pred_list, db_list, provider_name, doc_source: str = "aus_prospectus"):
final_result = {} final_result = {}
df_data = [] df_data = []
unmatched_pred_list = pred_list.copy() unmatched_pred_list = pred_list.copy()
@ -466,8 +466,8 @@ def final_function_to_match(doc_id, pred_list, db_list, provider_name):
# unmatched_pred_list.remove(pred_list[index]) # unmatched_pred_list.remove(pred_list[index])
else: else:
### STEP-1 Abbreviation replacement ### STEP-1 Abbreviation replacement
cleaned_pred_name1 = replace_abbrevs_in_fundnames([pred_fund])[0] cleaned_pred_name1 = replace_abbrevs_in_fundnames([pred_fund], doc_source)[0]
cleaned_db_list1 = replace_abbrevs_in_fundnames(db_list) cleaned_db_list1 = replace_abbrevs_in_fundnames(db_list, doc_source)
# print("--> ",cleaned_db_list1, cleaned_pred_name1) # print("--> ",cleaned_db_list1, cleaned_pred_name1)
step1_result, matched_index, all_scores1_, all_matched_fund_names1_ = get_fund_match_final_score(cleaned_db_list1, cleaned_pred_name1) step1_result, matched_index, all_scores1_, all_matched_fund_names1_ = get_fund_match_final_score(cleaned_db_list1, cleaned_pred_name1)
# print(f"\nStep 1 - Abbreviation Replacement Result: {step1_result}") # print(f"\nStep 1 - Abbreviation Replacement Result: {step1_result}")
@ -617,11 +617,11 @@ def final_function_to_match(doc_id, pred_list, db_list, provider_name):
# print("==>>> DB LIST: ",unmatched_db_list) # print("==>>> DB LIST: ",unmatched_db_list)
# print("==>>> PRED LIST: ",unmatched_pred_list) # print("==>>> PRED LIST: ",unmatched_pred_list)
if len(unmatched_pred_list)!=0: if len(unmatched_pred_list)!=0:
cleaned_unmatched_pred_list = replace_abbrevs_in_fundnames(unmatched_pred_list) cleaned_unmatched_pred_list = replace_abbrevs_in_fundnames(unmatched_pred_list, doc_source)
cleaned_unmatched_pred_list = remove_stopwords_nltk(cleaned_unmatched_pred_list) cleaned_unmatched_pred_list = remove_stopwords_nltk(cleaned_unmatched_pred_list)
cleaned_unmatched_pred_list = remove_special_characters(cleaned_unmatched_pred_list) cleaned_unmatched_pred_list = remove_special_characters(cleaned_unmatched_pred_list)
cleaned_unmatched_db_list = replace_abbrevs_in_fundnames(unmatched_db_list) cleaned_unmatched_db_list = replace_abbrevs_in_fundnames(unmatched_db_list, doc_source)
cleaned_unmatched_db_list = remove_stopwords_nltk(cleaned_unmatched_db_list) cleaned_unmatched_db_list = remove_stopwords_nltk(cleaned_unmatched_db_list)
cleaned_unmatched_db_list = remove_special_characters(cleaned_unmatched_db_list) cleaned_unmatched_db_list = remove_special_characters(cleaned_unmatched_db_list)
prompt_context = f""" prompt_context = f"""

View File

@ -1,6 +1,7 @@
import os import os
import json import json
import pandas as pd import pandas as pd
from copy import deepcopy
from utils.biz_utils import get_most_similar_name, remove_common_word from utils.biz_utils import get_most_similar_name, remove_common_word
from utils.sql_query_util import ( from utils.sql_query_util import (
query_document_fund_mapping, query_document_fund_mapping,
@ -18,14 +19,18 @@ class DataMapping:
raw_document_data_list: list, raw_document_data_list: list,
document_mapping_info_df: pd.DataFrame, document_mapping_info_df: pd.DataFrame,
output_data_folder: str, output_data_folder: str,
doc_source: str = "emea_ar" doc_source: str = "emea_ar",
compare_with_provider: bool = True
): ):
self.doc_id = doc_id self.doc_id = doc_id
self.datapoints = datapoints self.datapoints = datapoints
self.doc_source = doc_source self.doc_source = doc_source
self.compare_with_provider = compare_with_provider
self.raw_document_data_list = raw_document_data_list self.raw_document_data_list = raw_document_data_list
if document_mapping_info_df is None or len(document_mapping_info_df) == 0: if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False) self.document_mapping_info_df = query_document_fund_mapping(
doc_id, rerun=False
)
else: else:
self.document_mapping_info_df = document_mapping_info_df self.document_mapping_info_df = document_mapping_info_df
@ -44,7 +49,9 @@ class DataMapping:
def set_mapping_data_by_db(self, document_mapping_info_df: pd.DataFrame): def set_mapping_data_by_db(self, document_mapping_info_df: pd.DataFrame):
logger.info("Setting document mapping data") logger.info("Setting document mapping data")
if document_mapping_info_df is None or len(document_mapping_info_df) == 0: if document_mapping_info_df is None or len(document_mapping_info_df) == 0:
self.document_mapping_info_df = query_document_fund_mapping(self.doc_id, rerun=False) self.document_mapping_info_df = query_document_fund_mapping(
self.doc_id, rerun=False
)
else: else:
self.document_mapping_info_df = document_mapping_info_df self.document_mapping_info_df = document_mapping_info_df
if len(self.document_mapping_info_df) == 0: if len(self.document_mapping_info_df) == 0:
@ -92,26 +99,27 @@ class DataMapping:
def get_provider_mapping(self): def get_provider_mapping(self):
if len(self.document_mapping_info_df) == 0: if len(self.document_mapping_info_df) == 0:
return pd.DataFrame() return pd.DataFrame()
provider_id_list = ( provider_id_list = self.document_mapping_info_df["ProviderId"].unique().tolist()
self.document_mapping_info_df["ProviderId"].unique().tolist()
)
provider_mapping_list = [] provider_mapping_list = []
for provider_id in provider_id_list: for provider_id in provider_id_list:
provider_mapping_list.append(query_investment_by_provider(provider_id, rerun=False)) provider_mapping_list.append(
query_investment_by_provider(provider_id, rerun=False)
)
provider_mapping_df = pd.concat(provider_mapping_list) provider_mapping_df = pd.concat(provider_mapping_list)
provider_mapping_df = provider_mapping_df.drop_duplicates() provider_mapping_df = provider_mapping_df.drop_duplicates()
provider_mapping_df.reset_index(drop=True, inplace=True) provider_mapping_df.reset_index(drop=True, inplace=True)
return provider_mapping_df return provider_mapping_df
def mapping_raw_data_entrance(self): def mapping_raw_data_entrance(self):
if self.doc_source == "emear_ar": if self.doc_source == "emea_ar":
return self.mapping_raw_data() return self.mapping_raw_data()
elif self.doc_source == "aus_prospectus": elif self.doc_source == "aus_prospectus":
return self.mapping_raw_data_aus() return self.mapping_raw_data_generic()
else: else:
return self.mapping_raw_data() return self.mapping_raw_data()
# return self.mapping_raw_data_generic()
def mapping_raw_data_aus(self): def mapping_raw_data_generic(self):
logger.info(f"Mapping raw data for AUS Prospectus document {self.doc_id}") logger.info(f"Mapping raw data for AUS Prospectus document {self.doc_id}")
mapped_data_list = [] mapped_data_list = []
# Generate raw name based on fund name and share name by integrate_share_name # Generate raw name based on fund name and share name by integrate_share_name
@ -128,7 +136,9 @@ class DataMapping:
raw_share_name = raw_data.get("share_name", "") raw_share_name = raw_data.get("share_name", "")
raw_data_keys = list(raw_data.keys()) raw_data_keys = list(raw_data.keys())
if len(raw_share_name) > 0: if len(raw_share_name) > 0:
integrated_share_name = self.integrate_share_name(raw_fund_name, raw_share_name) integrated_share_name = self.integrate_share_name(
raw_fund_name, raw_share_name
)
if integrated_share_name not in share_raw_name_list: if integrated_share_name not in share_raw_name_list:
share_raw_name_list.append(integrated_share_name) share_raw_name_list.append(integrated_share_name)
for datapoint in self.datapoints: for datapoint in self.datapoints:
@ -144,7 +154,7 @@ class DataMapping:
"investment_type": 1, "investment_type": 1,
"investment_id": "", "investment_id": "",
"investment_name": "", "investment_name": "",
"similarity": 0 "similarity": 0,
} }
mapped_data_list.append(mapped_data) mapped_data_list.append(mapped_data)
else: else:
@ -162,19 +172,23 @@ class DataMapping:
"value": raw_data[datapoint], "value": raw_data[datapoint],
"investment_type": 33, "investment_type": 33,
"investment_id": "", "investment_id": "",
"investment_name": "" "investment_name": "",
} }
mapped_data_list.append(mapped_data) mapped_data_list.append(mapped_data)
# Mapping raw data with database # Mapping raw data with database
iter_count = 30 iter_count = 60
fund_match_result = {} fund_match_result = {}
if len(fund_raw_name_list) > 0: if len(fund_raw_name_list) > 0:
fund_match_result = self.get_raw_name_db_match_result(fund_raw_name_list, "fund", iter_count) fund_match_result = self.get_raw_name_db_match_result(
logger.info(f"Fund match result: \n{fund_match_result}") fund_raw_name_list, "fund", iter_count
)
# logger.info(f"Fund match result: \n{fund_match_result}")
share_match_result = {} share_match_result = {}
if len(share_raw_name_list) > 0: if len(share_raw_name_list) > 0:
share_match_result = self.get_raw_name_db_match_result(share_raw_name_list, "share", iter_count) share_match_result = self.get_raw_name_db_match_result(
logger.info(f"Share match result: \n{share_match_result}") share_raw_name_list, "share", iter_count
)
# logger.info(f"Share match result: \n{share_match_result}")
for mapped_data in mapped_data_list: for mapped_data in mapped_data_list:
investment_type = mapped_data["investment_type"] investment_type = mapped_data["investment_type"]
@ -182,9 +196,14 @@ class DataMapping:
if investment_type == 33: if investment_type == 33:
if fund_match_result.get(raw_name) is not None: if fund_match_result.get(raw_name) is not None:
matched_db_fund_name = fund_match_result[raw_name] matched_db_fund_name = fund_match_result[raw_name]
if matched_db_fund_name is not None and len(matched_db_fund_name) > 0: if (
matched_db_fund_name is not None
and len(matched_db_fund_name) > 0
):
# get FundId from self.doc_fund_mapping # get FundId from self.doc_fund_mapping
find_fund_df = self.doc_fund_mapping[self.doc_fund_mapping["FundName"] == matched_db_fund_name] find_fund_df = self.doc_fund_mapping[
self.doc_fund_mapping["FundName"] == matched_db_fund_name
]
if find_fund_df is not None and len(find_fund_df) > 0: if find_fund_df is not None and len(find_fund_df) > 0:
fund_id = find_fund_df["FundId"].values[0] fund_id = find_fund_df["FundId"].values[0]
mapped_data["investment_id"] = fund_id mapped_data["investment_id"] = fund_id
@ -193,9 +212,15 @@ class DataMapping:
if investment_type == 1: if investment_type == 1:
if share_match_result.get(raw_name) is not None: if share_match_result.get(raw_name) is not None:
matched_db_share_name = share_match_result[raw_name] matched_db_share_name = share_match_result[raw_name]
if matched_db_share_name is not None and len(matched_db_share_name) > 0: if (
matched_db_share_name is not None
and len(matched_db_share_name) > 0
):
# get SecId from self.doc_fund_class_mapping # get SecId from self.doc_fund_class_mapping
find_share_df = self.doc_fund_class_mapping[self.doc_fund_class_mapping["ShareClassName"] == matched_db_share_name] find_share_df = self.doc_fund_class_mapping[
self.doc_fund_class_mapping["ShareClassName"]
== matched_db_share_name
]
if find_share_df is not None and len(find_share_df) > 0: if find_share_df is not None and len(find_share_df) > 0:
share_id = find_share_df["SecId"].values[0] share_id = find_share_df["SecId"].values[0]
mapped_data["investment_id"] = share_id mapped_data["investment_id"] = share_id
@ -205,26 +230,64 @@ class DataMapping:
self.output_mapping_file(mapped_data_list) self.output_mapping_file(mapped_data_list)
return mapped_data_list return mapped_data_list
def get_raw_name_db_match_result(self, raw_name_list, investment_type: str, iter_count: int = 30): def get_raw_name_db_match_result(
self, raw_name_list, investment_type: str, iter_count: int = 30
):
# split raw_name_list into several parts which each part is with 30 elements # split raw_name_list into several parts which each part is with 30 elements
# The reason to split is to avoid invoke token limitation issues from CahtGPT # The reason to split is to avoid invoke token limitation issues from CahtGPT
raw_name_list_parts = [raw_name_list[i:i + iter_count] raw_name_list_parts = [
for i in range(0, len(raw_name_list), iter_count)] raw_name_list[i : i + iter_count]
for i in range(0, len(raw_name_list), iter_count)
]
all_match_result = {} all_match_result = {}
doc_fund_name_list = deepcopy(self.doc_fund_name_list)
doc_share_name_list = deepcopy(self.doc_share_name_list)
for raw_name_list in raw_name_list_parts: for raw_name_list in raw_name_list_parts:
if investment_type == "fund": if investment_type == "fund":
match_result = final_function_to_match(doc_id=self.doc_id, match_result, doc_fund_name_list = self.get_final_function_to_match(
pred_list=raw_name_list, raw_name_list, doc_fund_name_list
db_list=self.doc_fund_name_list, )
provider_name=self.provider_name)
else: else:
match_result = final_function_to_match(doc_id=self.doc_id, match_result, doc_share_name_list = self.get_final_function_to_match(
pred_list=raw_name_list, raw_name_list, doc_share_name_list
db_list=self.doc_share_name_list, )
provider_name=self.provider_name)
all_match_result.update(match_result) all_match_result.update(match_result)
return all_match_result return all_match_result
def get_final_function_to_match(self, raw_name_list, db_name_list):
if len(db_name_list) == 0:
match_result = {}
for raw_name in raw_name_list:
match_result[raw_name] = ""
else:
match_result = final_function_to_match(
doc_id=self.doc_id,
pred_list=raw_name_list,
db_list=db_name_list,
provider_name=self.provider_name,
doc_source=self.doc_source
)
matched_name_list = list(match_result.values())
db_name_list = self.remove_matched_names(db_name_list, matched_name_list)
return match_result, db_name_list
def remove_matched_names(self, target_name_list: list, matched_name_list: list):
if len(matched_name_list) == 0:
return target_name_list
matched_name_list = list(set(matched_name_list))
matched_name_list = [
value for value in matched_name_list if value is not None and len(value) > 0
]
for matched_name in matched_name_list:
if (
matched_name is not None
and len(matched_name) > 0
and matched_name in target_name_list
):
target_name_list.remove(matched_name)
return target_name_list
def mapping_raw_data(self): def mapping_raw_data(self):
""" """
doc_id, page_index, datapoint, value, doc_id, page_index, datapoint, value,
@ -245,9 +308,14 @@ class DataMapping:
if raw_fund_name is None or len(raw_fund_name) == 0: if raw_fund_name is None or len(raw_fund_name) == 0:
continue continue
raw_share_name = raw_data.get("share_name", "") raw_share_name = raw_data.get("share_name", "")
if len(self.doc_fund_name_list) == 0 and len(self.provider_fund_name_list) == 0: if (
len(self.doc_fund_name_list) == 0
and len(self.provider_fund_name_list) == 0
):
if len(raw_share_name) > 0: if len(raw_share_name) > 0:
integrated_share_name = self.integrate_share_name(raw_fund_name, raw_share_name) integrated_share_name = self.integrate_share_name(
raw_fund_name, raw_share_name
)
raw_data_keys = list(raw_data.keys()) raw_data_keys = list(raw_data.keys())
for datapoint in self.datapoints: for datapoint in self.datapoints:
if datapoint in raw_data_keys: if datapoint in raw_data_keys:
@ -262,7 +330,7 @@ class DataMapping:
"investment_type": 1, "investment_type": 1,
"investment_id": "", "investment_id": "",
"investment_name": "", "investment_name": "",
"similarity": 0 "similarity": 0,
} }
mapped_data_list.append(mapped_data) mapped_data_list.append(mapped_data)
else: else:
@ -279,13 +347,15 @@ class DataMapping:
"value": raw_data[datapoint], "value": raw_data[datapoint],
"investment_type": 33, "investment_type": 33,
"investment_id": "", "investment_id": "",
"investment_name": "" "investment_name": "",
} }
mapped_data_list.append(mapped_data) mapped_data_list.append(mapped_data)
else: else:
raw_name = "" raw_name = ""
if raw_share_name is not None and len(raw_share_name) > 0: if raw_share_name is not None and len(raw_share_name) > 0:
raw_name = self.integrate_share_name(raw_fund_name, raw_share_name) raw_name = self.integrate_share_name(
raw_fund_name, raw_share_name
)
if mapped_share_cache.get(raw_name) is not None: if mapped_share_cache.get(raw_name) is not None:
investment_info = mapped_share_cache[raw_name] investment_info = mapped_share_cache[raw_name]
else: else:
@ -298,13 +368,19 @@ class DataMapping:
) )
fund_id = fund_info["id"] fund_id = fund_info["id"]
mapped_fund_cache[raw_fund_name] = fund_info mapped_fund_cache[raw_fund_name] = fund_info
investment_info = {}
if len(fund_id) > 0:
investment_info = self.mapping_unique_raw_data(fund_id=fund_id,
raw_fund_name=raw_fund_name,
raw_data_list=raw_data_list)
if investment_info.get("id", None) is None or len(investment_info.get("id", "")) == 0:
investment_info = self.matching_with_database( investment_info = self.matching_with_database(
raw_name=raw_name, raw_name=raw_name,
raw_share_name=raw_share_name, raw_share_name=raw_share_name,
raw_fund_name=raw_fund_name, raw_fund_name=raw_fund_name,
parent_id=fund_id, parent_id=fund_id,
matching_type="share", matching_type="share",
process_cache=process_cache process_cache=process_cache,
) )
mapped_share_cache[raw_name] = investment_info mapped_share_cache[raw_name] = investment_info
elif raw_fund_name is not None and len(raw_fund_name) > 0: elif raw_fund_name is not None and len(raw_fund_name) > 0:
@ -322,7 +398,7 @@ class DataMapping:
"id": "", "id": "",
"legal_name": "", "legal_name": "",
"investment_type": -1, "investment_type": -1,
"similarity": 0 "similarity": 0,
} }
raw_data_keys = list(raw_data.keys()) raw_data_keys = list(raw_data.keys())
@ -339,13 +415,35 @@ class DataMapping:
"investment_type": investment_info["investment_type"], "investment_type": investment_info["investment_type"],
"investment_id": investment_info["id"], "investment_id": investment_info["id"],
"investment_name": investment_info["legal_name"], "investment_name": investment_info["legal_name"],
"similarity": investment_info["similarity"] "similarity": investment_info["similarity"],
} }
mapped_data_list.append(mapped_data) mapped_data_list.append(mapped_data)
self.output_mapping_file(mapped_data_list) self.output_mapping_file(mapped_data_list)
return mapped_data_list return mapped_data_list
def mapping_unique_raw_data(self, fund_id: str, raw_fund_name: str, raw_data_list: list):
share_count = 0
for raw_data in raw_data_list:
fund_name = raw_data.get("fund_name", "")
share_name = raw_data.get("share_name", "")
if fund_name == raw_fund_name and share_name is not None and len(share_name) > 0:
share_count += 1
if share_count > 1:
break
data_info = {}
if share_count == 1:
doc_compare_mapping = self.doc_fund_class_mapping[
self.doc_fund_class_mapping["FundId"] == fund_id
]
if len(doc_compare_mapping) == 1:
data_info["id"] = doc_compare_mapping["SecId"].values[0]
data_info["legal_name"] = doc_compare_mapping["ShareClassName"].values[0]
data_info["investment_type"] = 1
data_info["similarity"] = 1
return data_info
def output_mapping_file(self, mapped_data_list: list): def output_mapping_file(self, mapped_data_list: list):
json_data_file = os.path.join( json_data_file = os.path.join(
self.output_data_json_folder, f"{self.doc_id}.json" self.output_data_json_folder, f"{self.doc_id}.json"
@ -390,7 +488,7 @@ class DataMapping:
raw_fund_name: str = None, raw_fund_name: str = None,
parent_id: str = None, parent_id: str = None,
matching_type: str = "fund", matching_type: str = "fund",
process_cache: dict = {} process_cache: dict = {},
): ):
if len(self.doc_fund_name_list) == 0 and len(self.provider_fund_name_list) == 0: if len(self.doc_fund_name_list) == 0 and len(self.provider_fund_name_list) == 0:
data_info["id"] = "" data_info["id"] = ""
@ -417,8 +515,9 @@ class DataMapping:
doc_compare_mapping = self.doc_fund_class_mapping[ doc_compare_mapping = self.doc_fund_class_mapping[
self.doc_fund_class_mapping["FundId"] == parent_id self.doc_fund_class_mapping["FundId"] == parent_id
] ]
provider_compare_mapping = self.provider_fund_class_mapping\ provider_compare_mapping = self.provider_fund_class_mapping[
[self.provider_fund_class_mapping["FundId"] == parent_id] self.provider_fund_class_mapping["FundId"] == parent_id
]
if len(doc_compare_mapping) == 0: if len(doc_compare_mapping) == 0:
if len(provider_compare_mapping) == 0: if len(provider_compare_mapping) == 0:
doc_compare_name_list = self.doc_share_name_list doc_compare_name_list = self.doc_share_name_list
@ -436,8 +535,9 @@ class DataMapping:
doc_compare_mapping["ShareClassName"].unique().tolist() doc_compare_mapping["ShareClassName"].unique().tolist()
) )
if len(provider_compare_mapping) == 0 or \ if len(provider_compare_mapping) == 0 or len(
len(provider_compare_mapping) < len(doc_compare_mapping): provider_compare_mapping
) < len(doc_compare_mapping):
provider_compare_name_list = doc_compare_name_list provider_compare_name_list = doc_compare_name_list
provider_compare_mapping = doc_compare_mapping provider_compare_mapping = doc_compare_mapping
else: else:
@ -464,11 +564,15 @@ class DataMapping:
share_name=raw_share_name, share_name=raw_share_name,
fund_name=raw_fund_name, fund_name=raw_fund_name,
matching_type=matching_type, matching_type=matching_type,
process_cache=process_cache) process_cache=process_cache,
)
if matching_type == "fund": if matching_type == "fund":
threshold = 0.7 threshold = 0.7
else: else:
if self.compare_with_provider:
threshold = 0.9 threshold = 0.9
else:
threshold = 0.6
if max_similarity is not None and max_similarity >= threshold: if max_similarity is not None and max_similarity >= threshold:
data_info["id"] = doc_compare_mapping[ data_info["id"] = doc_compare_mapping[
doc_compare_mapping[compare_name_dp] == max_similarity_name doc_compare_mapping[compare_name_dp] == max_similarity_name
@ -479,6 +583,7 @@ class DataMapping:
if data_info.get("id", None) is None or len(data_info.get("id", "")) == 0: if data_info.get("id", None) is None or len(data_info.get("id", "")) == 0:
# set pre_common_word_list, reason: the document mapping for same fund maybe different with provider mapping # set pre_common_word_list, reason: the document mapping for same fund maybe different with provider mapping
# the purpose is to get the most common word list, to improve the similarity. # the purpose is to get the most common word list, to improve the similarity.
if self.compare_with_provider:
max_similarity_name, max_similarity = get_most_similar_name( max_similarity_name, max_similarity = get_most_similar_name(
raw_name, raw_name,
provider_compare_name_list, provider_compare_name_list,
@ -486,7 +591,7 @@ class DataMapping:
fund_name=raw_fund_name, fund_name=raw_fund_name,
matching_type=matching_type, matching_type=matching_type,
pre_common_word_list=pre_common_word_list, pre_common_word_list=pre_common_word_list,
process_cache=process_cache process_cache=process_cache,
) )
threshold = 0.7 threshold = 0.7
if matching_type == "share": if matching_type == "share":
@ -503,7 +608,8 @@ class DataMapping:
else: else:
if len(doc_compare_name_list) == 1: if len(doc_compare_name_list) == 1:
data_info["id"] = doc_compare_mapping[ data_info["id"] = doc_compare_mapping[
doc_compare_mapping[compare_name_dp] == doc_compare_name_list[0] doc_compare_mapping[compare_name_dp]
== doc_compare_name_list[0]
][compare_id_dp].values[0] ][compare_id_dp].values[0]
data_info["legal_name"] = doc_compare_name_list[0] data_info["legal_name"] = doc_compare_name_list[0]
data_info["similarity"] = 1 data_info["similarity"] = 1
@ -511,6 +617,10 @@ class DataMapping:
data_info["id"] = "" data_info["id"] = ""
data_info["legal_name"] = "" data_info["legal_name"] = ""
data_info["similarity"] = 0 data_info["similarity"] = 0
else:
data_info["id"] = ""
data_info["legal_name"] = ""
data_info["similarity"] = 0
data_info["investment_type"] = investment_type data_info["investment_type"] = investment_type
else: else:
data_info["id"] = "" data_info["id"] = ""

133
main.py
View File

@ -31,11 +31,14 @@ class EMEA_AR_Parsing:
output_mapping_data_folder: str = r"/data/emea_ar/output/mapping_data/docs/", output_mapping_data_folder: str = r"/data/emea_ar/output/mapping_data/docs/",
extract_way: str = "text", extract_way: str = "text",
drilldown_folder: str = r"/data/emea_ar/output/drilldown/", drilldown_folder: str = r"/data/emea_ar/output/drilldown/",
compare_with_provider: bool = True
) -> None: ) -> None:
self.doc_id = doc_id self.doc_id = doc_id
self.doc_source = doc_source self.doc_source = doc_source
self.pdf_folder = pdf_folder self.pdf_folder = pdf_folder
os.makedirs(self.pdf_folder, exist_ok=True) os.makedirs(self.pdf_folder, exist_ok=True)
self.compare_with_provider = compare_with_provider
self.pdf_file = self.download_pdf() self.pdf_file = self.download_pdf()
self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False) self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False)
@ -76,7 +79,7 @@ class EMEA_AR_Parsing:
self.pdf_file, self.pdf_file,
self.document_mapping_info_df, self.document_mapping_info_df,
self.doc_source, self.doc_source,
output_pdf_text_folder output_pdf_text_folder,
) )
self.page_text_dict = self.filter_pages.page_text_dict self.page_text_dict = self.filter_pages.page_text_dict
@ -87,7 +90,9 @@ class EMEA_AR_Parsing:
drilldown_folder = r"/data/emea_ar/output/drilldown/" drilldown_folder = r"/data/emea_ar/output/drilldown/"
os.makedirs(drilldown_folder, exist_ok=True) os.makedirs(drilldown_folder, exist_ok=True)
self.drilldown_folder = drilldown_folder self.drilldown_folder = drilldown_folder
misc_config_file = os.path.join(f"./configuration/{doc_source}/", "misc_config.json") misc_config_file = os.path.join(
f"./configuration/{doc_source}/", "misc_config.json"
)
if os.path.exists(misc_config_file): if os.path.exists(misc_config_file):
with open(misc_config_file, "r", encoding="utf-8") as f: with open(misc_config_file, "r", encoding="utf-8") as f:
misc_config = json.load(f) misc_config = json.load(f)
@ -278,7 +283,8 @@ class EMEA_AR_Parsing:
data_from_gpt, data_from_gpt,
self.document_mapping_info_df, self.document_mapping_info_df,
self.output_mapping_data_folder, self.output_mapping_data_folder,
self.doc_source self.doc_source,
compare_with_provider=self.compare_with_provider
) )
return data_mapping.mapping_raw_data_entrance() return data_mapping.mapping_raw_data_entrance()
@ -334,6 +340,7 @@ def mapping_data(
output_mapping_data_folder=output_mapping_folder, output_mapping_data_folder=output_mapping_folder,
extract_way=extract_way, extract_way=extract_way,
drilldown_folder=drilldown_folder, drilldown_folder=drilldown_folder,
compare_with_provider=False
) )
doc_data_from_gpt, annotation_list = emea_ar_parsing.extract_data( doc_data_from_gpt, annotation_list = emea_ar_parsing.extract_data(
re_run=re_run_extract_data re_run=re_run_extract_data
@ -502,18 +509,28 @@ def batch_start_job(
writer, index=False, sheet_name="extract_data" writer, index=False, sheet_name="extract_data"
) )
if document_mapping_file is not None and len(document_mapping_file) > 0 and os.path.exists(document_mapping_file): if (
document_mapping_file is not None
and len(document_mapping_file) > 0
and os.path.exists(document_mapping_file)
):
try: try:
merged_total_data_folder = os.path.join(output_mapping_total_folder, "merged/") merged_total_data_folder = os.path.join(
output_mapping_total_folder, "merged/"
)
os.makedirs(merged_total_data_folder, exist_ok=True) os.makedirs(merged_total_data_folder, exist_ok=True)
data_file_base_name = os.path.basename(output_file) data_file_base_name = os.path.basename(output_file)
output_merged_data_file_path = os.path.join(merged_total_data_folder, "merged_" + data_file_base_name) output_merged_data_file_path = os.path.join(
merge_output_data_aus_prospectus(output_file, document_mapping_file, output_merged_data_file_path) merged_total_data_folder, "merged_" + data_file_base_name
)
merge_output_data_aus_prospectus(
output_file, document_mapping_file, output_merged_data_file_path
)
except Exception as e: except Exception as e:
logger.error(f"Error: {e}") logger.error(f"Error: {e}")
if calculate_metrics: if calculate_metrics:
prediction_sheet_name = "total_mapping_data" prediction_sheet_name = "data_in_doc_mapping"
ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx" ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx"
ground_truth_sheet_name = "mapping_data" ground_truth_sheet_name = "mapping_data"
metrics_output_folder = r"/data/emea_ar/output/metrics/" metrics_output_folder = r"/data/emea_ar/output/metrics/"
@ -770,11 +787,11 @@ def test_auto_generate_instructions():
def test_data_extraction_metrics(): def test_data_extraction_metrics():
data_type = "data_extraction" data_type = "document_mapping_in_db"
# prediction_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_88_documents_by_image_20240920033929.xlsx" # prediction_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_88_documents_by_image_20240920033929.xlsx"
prediction_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_88_documents_by_text_20240922152517.xlsx" prediction_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_51_documents_by_text_20250127104008.xlsx"
# prediction_file = r"/data/emea_ar/output/mapping_data/docs/by_text/excel/481475385.xlsx" # prediction_file = r"/data/emea_ar/output/mapping_data/docs/by_text/excel/481475385.xlsx"
prediction_sheet_name = "mapping_data" prediction_sheet_name = "data_in_doc_mapping"
ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx" ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx"
ground_truth_sheet_name = "mapping_data" ground_truth_sheet_name = "mapping_data"
metrics_output_folder = r"/data/emea_ar/output/metrics/" metrics_output_folder = r"/data/emea_ar/output/metrics/"
@ -1017,7 +1034,7 @@ def batch_run_documents(
) )
re_run_extract_data = False re_run_extract_data = False
re_run_mapping_data = True re_run_mapping_data = True
force_save_total_data = True force_save_total_data = False
calculate_metrics = False calculate_metrics = False
extract_way = "text" extract_way = "text"
@ -1194,13 +1211,17 @@ def merge_output_data_aus_prospectus(
): ):
# TODO: merge output data for aus prospectus, plan to realize it on 2025-01-16 # TODO: merge output data for aus prospectus, plan to realize it on 2025-01-16
data_df = pd.read_excel(data_file_path, sheet_name="total_mapping_data") data_df = pd.read_excel(data_file_path, sheet_name="total_mapping_data")
document_mapping_df = pd.read_excel(document_mapping_file, sheet_name="document_mapping") document_mapping_df = pd.read_excel(
document_mapping_file, sheet_name="document_mapping"
)
# set doc_id to be string type # set doc_id to be string type
data_df["doc_id"] = data_df["doc_id"].astype(str) data_df["doc_id"] = data_df["doc_id"].astype(str)
document_mapping_df["DocumentId"] = document_mapping_df["DocumentId"].astype(str) document_mapping_df["DocumentId"] = document_mapping_df["DocumentId"].astype(str)
doc_id_list = data_df["doc_id"].unique().tolist() doc_id_list = data_df["doc_id"].unique().tolist()
datapoint_keyword_config_file = r"./configuration/aus_prospectus/datapoint_name.json" datapoint_keyword_config_file = (
r"./configuration/aus_prospectus/datapoint_name.json"
)
with open(datapoint_keyword_config_file, "r", encoding="utf-8") as f: with open(datapoint_keyword_config_file, "r", encoding="utf-8") as f:
datapoint_keyword_config = json.load(f) datapoint_keyword_config = json.load(f)
datapoint_name_list = list(datapoint_keyword_config.keys()) datapoint_name_list = list(datapoint_keyword_config.keys())
@ -1212,7 +1233,9 @@ def merge_output_data_aus_prospectus(
"EffectiveDate" "EffectiveDate"
].values[0] ].values[0]
)[0:10] )[0:10]
share_doc_data_df = data_df[(data_df["doc_id"] == doc_id) & (data_df["investment_type"] == 1)] share_doc_data_df = data_df[
(data_df["doc_id"] == doc_id) & (data_df["investment_type"] == 1)
]
exist_raw_name_list = [] exist_raw_name_list = []
for index, row in share_doc_data_df.iterrows(): for index, row in share_doc_data_df.iterrows():
doc_id = str(row["doc_id"]) doc_id = str(row["doc_id"])
@ -1228,7 +1251,9 @@ def merge_output_data_aus_prospectus(
fund_id = "" fund_id = ""
fund_legal_name = "" fund_legal_name = ""
if share_class_id != "": if share_class_id != "":
record_row = document_mapping_df[document_mapping_df["FundClassId"] == share_class_id] record_row = document_mapping_df[
document_mapping_df["FundClassId"] == share_class_id
]
if len(record_row) > 0: if len(record_row) > 0:
fund_id = record_row["FundId"].values[0] fund_id = record_row["FundId"].values[0]
fund_legal_name = record_row["FundLegalName"].values[0] fund_legal_name = record_row["FundLegalName"].values[0]
@ -1265,16 +1290,16 @@ def merge_output_data_aus_prospectus(
doc_data_list.append(data) doc_data_list.append(data)
# find data from total_data_list by raw_name # find data from total_data_list by raw_name
for data in doc_data_list: for data in doc_data_list:
if ( if data["raw_name"] == raw_name:
data["raw_name"] == raw_name
):
update_key = datapoint update_key = datapoint
data[update_key] = value data[update_key] = value
if page_index not in data["page_index"]: if page_index not in data["page_index"]:
data["page_index"].append(page_index) data["page_index"].append(page_index)
break break
fund_doc_data_df = data_df[(data_df["doc_id"] == doc_id) & (data_df["investment_type"] == 33)] fund_doc_data_df = data_df[
(data_df["doc_id"] == doc_id) & (data_df["investment_type"] == 33)
]
for index, row in fund_doc_data_df.iterrows(): for index, row in fund_doc_data_df.iterrows():
doc_id = str(row["doc_id"]) doc_id = str(row["doc_id"])
page_index = int(row["page_index"]) page_index = int(row["page_index"])
@ -1289,8 +1314,9 @@ def merge_output_data_aus_prospectus(
exist = False exist = False
if fund_id != "": if fund_id != "":
for data in doc_data_list: for data in doc_data_list:
if (fund_id != "" and data["fund_id"] == fund_id) or \ if (fund_id != "" and data["fund_id"] == fund_id) or (
(data["raw_fund_name"] == raw_fund_name): data["raw_fund_name"] == raw_fund_name
):
update_key = datapoint update_key = datapoint
data[update_key] = value data[update_key] = value
if page_index not in data["page_index"]: if page_index not in data["page_index"]:
@ -1323,6 +1349,7 @@ def merge_output_data_aus_prospectus(
if __name__ == "__main__": if __name__ == "__main__":
# test_data_extraction_metrics()
# data_file_path = r"/data/aus_prospectus/output/mapping_data/total/mapping_data_info_11_documents_by_text_20250116220811.xlsx" # data_file_path = r"/data/aus_prospectus/output/mapping_data/total/mapping_data_info_11_documents_by_text_20250116220811.xlsx"
# document_mapping_file_path = r"/data/aus_prospectus/basic_information/11_documents/document_mapping.xlsx" # document_mapping_file_path = r"/data/aus_prospectus/basic_information/11_documents/document_mapping.xlsx"
# merged_total_data_folder = r'/data/aus_prospectus/output/mapping_data/total/merged/' # merged_total_data_folder = r'/data/aus_prospectus/output/mapping_data/total/merged/'
@ -1348,9 +1375,11 @@ if __name__ == "__main__":
# special_doc_id_list = ["553242411"] # special_doc_id_list = ["553242411"]
doc_source = "aus_prospectus" doc_source = "emea_ar"
if doc_source == "aus_prospectus": if doc_source == "aus_prospectus":
document_sample_file = r"./sample_documents/aus_prospectus_100_documents_multi_fund_sample.txt" document_sample_file = (
r"./sample_documents/aus_prospectus_100_documents_multi_fund_sample.txt"
)
with open(document_sample_file, "r", encoding="utf-8") as f: with open(document_sample_file, "r", encoding="utf-8") as f:
special_doc_id_list = [doc_id.strip() for doc_id in f.readlines()] special_doc_id_list = [doc_id.strip() for doc_id in f.readlines()]
document_mapping_file = r"/data/aus_prospectus/basic_information/from_2024_documents/aus_100_document_prospectus_multi_fund.xlsx" document_mapping_file = r"/data/aus_prospectus/basic_information/from_2024_documents/aus_100_document_prospectus_multi_fund.xlsx"
@ -1397,7 +1426,61 @@ if __name__ == "__main__":
drilldown_folder=drilldown_folder, drilldown_folder=drilldown_folder,
) )
elif doc_source == "emea_ar": elif doc_source == "emea_ar":
special_doc_id_list = ["553242408"] special_doc_id_list = [
"292989214",
"316237292",
"321733631",
"323390570",
"327956364",
"333207452",
"334718372",
"344636875",
"362246081",
"366179419",
"380945052",
"382366116",
"387202452",
"389171486",
"391456740",
"391736837",
"394778487",
"401684600",
"402113224",
"402181770",
"402397014",
"405803396",
"445102363",
"445256897",
"448265376",
"449555622",
"449623976",
"458291624",
"458359181",
"463081566",
"469138353",
"471641628",
"476492237",
"478585901",
"478586066",
"479042264",
"479793787",
"481475385",
"483617247",
"486378555",
"486383912",
"492121213",
"497497599",
"502693599",
"502821436",
"503194284",
"506559375",
"507967525",
"508854243",
"509845549",
"520879048",
"529925114",
]
special_doc_id_list = ["471641628"]
batch_run_documents( batch_run_documents(
doc_source=doc_source, special_doc_id_list=special_doc_id_list doc_source=doc_source, special_doc_id_list=special_doc_id_list
) )