optimize mapping logic

This commit is contained in:
Blade He 2024-09-27 16:39:56 -05:00
parent 39cd53dc33
commit 3aa596ea33
4 changed files with 150 additions and 41 deletions

View File

@ -441,7 +441,7 @@ class DataExtraction:
fund_name_line = "" fund_name_line = ""
half_line = rest_context[half_len].strip() half_line = rest_context[half_len].strip()
max_similarity_fund_name, max_similarity = get_most_similar_name( max_similarity_fund_name, max_similarity = get_most_similar_name(
half_line, self.provider_fund_name_list half_line, self.provider_fund_name_list, matching_type="fund"
) )
if max_similarity < 0.2: if max_similarity < 0.2:
# get the fund name line text from the first half # get the fund name line text from the first half
@ -457,7 +457,7 @@ class DataExtraction:
continue continue
max_similarity_fund_name, max_similarity = get_most_similar_name( max_similarity_fund_name, max_similarity = get_most_similar_name(
line_text, self.provider_fund_name_list line_text, self.provider_fund_name_list, matching_type="fund"
) )
if max_similarity >= 0.2: if max_similarity >= 0.2:
fund_name_line = line_text fund_name_line = line_text

View File

@ -108,6 +108,7 @@ class DataMapping:
mapped_data_list = [] mapped_data_list = []
mapped_fund_cache = {} mapped_fund_cache = {}
mapped_share_cache = {} mapped_share_cache = {}
process_cache = {}
for page_data in self.raw_document_data_list: for page_data in self.raw_document_data_list:
doc_id = page_data.get("doc_id", "") doc_id = page_data.get("doc_id", "")
page_index = page_data.get("page_index", "") page_index = page_data.get("page_index", "")
@ -166,12 +167,16 @@ class DataMapping:
fund_id = fund_info["id"] fund_id = fund_info["id"]
else: else:
fund_info = self.matching_with_database( fund_info = self.matching_with_database(
raw_fund_name, "fund" raw_name=raw_fund_name, matching_type="fund"
) )
fund_id = fund_info["id"] fund_id = fund_info["id"]
mapped_fund_cache[raw_fund_name] = fund_info mapped_fund_cache[raw_fund_name] = fund_info
investment_info = self.matching_with_database( investment_info = self.matching_with_database(
raw_name, fund_id, "share" raw_name=raw_name,
raw_share_name=raw_share_name,
parent_id=fund_id,
matching_type="share",
process_cache=process_cache
) )
mapped_share_cache[raw_name] = investment_info mapped_share_cache[raw_name] = investment_info
elif raw_fund_name is not None and len(raw_fund_name) > 0: elif raw_fund_name is not None and len(raw_fund_name) > 0:
@ -180,7 +185,7 @@ class DataMapping:
investment_info = mapped_fund_cache[raw_fund_name] investment_info = mapped_fund_cache[raw_fund_name]
else: else:
investment_info = self.matching_with_database( investment_info = self.matching_with_database(
raw_name, "fund" raw_name=raw_fund_name, matching_type="fund"
) )
mapped_fund_cache[raw_fund_name] = investment_info mapped_fund_cache[raw_fund_name] = investment_info
else: else:
@ -246,7 +251,12 @@ class DataMapping:
return raw_name return raw_name
def matching_with_database( def matching_with_database(
self, raw_name: str, parent_id: str = None, matching_type: str = "fund" self,
raw_name: str,
raw_share_name: str = None,
parent_id: str = None,
matching_type: str = "fund",
process_cache: dict = {}
): ):
if len(self.doc_fund_name_list) == 0 and len(self.provider_fund_name_list) == 0: if len(self.doc_fund_name_list) == 0 and len(self.provider_fund_name_list) == 0:
data_info["id"] = "" data_info["id"] = ""
@ -298,7 +308,11 @@ class DataMapping:
if doc_compare_name_list is not None and len(doc_compare_name_list) > 0: if doc_compare_name_list is not None and len(doc_compare_name_list) > 0:
_, pre_common_word_list = remove_common_word(doc_compare_name_list) _, pre_common_word_list = remove_common_word(doc_compare_name_list)
max_similarity_name, max_similarity = get_most_similar_name( max_similarity_name, max_similarity = get_most_similar_name(
raw_name, doc_compare_name_list) raw_name,
doc_compare_name_list,
share_name=raw_share_name,
matching_type=matching_type,
process_cache=process_cache)
if max_similarity is not None and max_similarity >= 0.9: if max_similarity is not None and max_similarity >= 0.9:
data_info["id"] = doc_compare_mapping[ data_info["id"] = doc_compare_mapping[
doc_compare_mapping[compare_name_dp] == max_similarity_name doc_compare_mapping[compare_name_dp] == max_similarity_name
@ -310,12 +324,20 @@ class DataMapping:
# set pre_common_word_list, reason: the document mapping for same fund maybe different with provider mapping # set pre_common_word_list, reason: the document mapping for same fund maybe different with provider mapping
# the purpose is to get the most common word list, to improve the similarity. # the purpose is to get the most common word list, to improve the similarity.
max_similarity_name, max_similarity = get_most_similar_name( max_similarity_name, max_similarity = get_most_similar_name(
raw_name, provider_compare_name_list, pre_common_word_list=pre_common_word_list raw_name,
provider_compare_name_list,
share_name=raw_share_name,
matching_type=matching_type,
pre_common_word_list=pre_common_word_list,
process_cache=process_cache
) )
threshold = 0.7 threshold = 0.7
if matching_type == "share": if matching_type == "share":
threshold = 0.5 threshold = 0.5
if max_similarity is not None and max_similarity >= threshold: round_similarity = 0
if max_similarity is not None and isinstance(max_similarity, float):
round_similarity = round(max_similarity, 1)
if round_similarity is not None and round_similarity >= threshold:
data_info["id"] = provider_compare_mapping[ data_info["id"] = provider_compare_mapping[
provider_compare_mapping[compare_name_dp] == max_similarity_name provider_compare_mapping[compare_name_dp] == max_similarity_name
][compare_id_dp].values[0] ][compare_id_dp].values[0]

37
main.py
View File

@ -335,15 +335,15 @@ def batch_start_job(
ground_truth_sheet_name = "mapping_data" ground_truth_sheet_name = "mapping_data"
metrics_output_folder = r"/data/emea_ar/output/metrics/" metrics_output_folder = r"/data/emea_ar/output/metrics/"
logger.info(f"Calculating metrics for data extraction") # logger.info(f"Calculating metrics for data extraction")
missing_error_list, metrics_list, metrics_file = get_metrics( # missing_error_list, metrics_list, metrics_file = get_metrics(
"data_extraction", # "data_extraction",
output_file, # output_file,
prediction_sheet_name, # prediction_sheet_name,
ground_truth_file, # ground_truth_file,
ground_truth_sheet_name, # ground_truth_sheet_name,
metrics_output_folder, # metrics_output_folder,
) # )
# logger.info(f"Calculating metrics for investment mapping by actual document mapping") # logger.info(f"Calculating metrics for investment mapping by actual document mapping")
# missing_error_list, metrics_list, metrics_file = get_metrics( # missing_error_list, metrics_list, metrics_file = get_metrics(
@ -446,7 +446,7 @@ def get_metrics(
ground_truth_sheet_name=ground_truth_sheet_name, ground_truth_sheet_name=ground_truth_sheet_name,
output_folder=output_folder, output_folder=output_folder,
) )
missing_error_list, metrics_list, metrics_file = metrics.get_metrics(strict_model=True) missing_error_list, metrics_list, metrics_file = metrics.get_metrics(strict_model=False)
return missing_error_list, metrics_list, metrics_file return missing_error_list, metrics_list, metrics_file
@ -574,8 +574,8 @@ def test_data_extraction_metrics():
def test_mapping_raw_name(): def test_mapping_raw_name():
doc_id = "481475385" doc_id = "382366116"
raw_name = "Emerging Markets Fund Y-DIST Shares (USD)" raw_name = "SPARINVEST SICAV - ETHICAL EMERGING MARKETS VALUE EUR I"
output_folder = r"/data/emea_ar/output/mapping_data/docs/by_text/" output_folder = r"/data/emea_ar/output/mapping_data/docs/by_text/"
data_mapping = DataMapping( data_mapping = DataMapping(
doc_id, doc_id,
@ -584,10 +584,13 @@ def test_mapping_raw_name():
document_mapping_info_df=None, document_mapping_info_df=None,
output_data_folder=output_folder, output_data_folder=output_folder,
) )
process_cache = {}
mapping_info = data_mapping.matching_with_database( mapping_info = data_mapping.matching_with_database(
raw_name=raw_name, raw_name=raw_name,
raw_share_name=None,
parent_id=None, parent_id=None,
matching_type="share" matching_type="share",
process_cache=process_cache
) )
print(mapping_info) print(mapping_info)
@ -677,7 +680,7 @@ if __name__ == "__main__":
"333207452", "333207452",
"334718372", "334718372",
"344636875", "344636875",
"349679479", # "349679479",
"362246081", "362246081",
"366179419", "366179419",
"380945052", "380945052",
@ -693,12 +696,12 @@ if __name__ == "__main__":
] ]
# special_doc_id_list = check_mapping_doc_id_list # special_doc_id_list = check_mapping_doc_id_list
special_doc_id_list = check_db_mapping_doc_id_list special_doc_id_list = check_db_mapping_doc_id_list
# special_doc_id_list = ["382366116"] special_doc_id_list = ["402397014"]
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/" output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/" output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
re_run_extract_data = False re_run_extract_data = False
re_run_mapping_data = False re_run_mapping_data = True
force_save_total_data = True force_save_total_data = False
extract_ways = ["text"] extract_ways = ["text"]
for extract_way in extract_ways: for extract_way in extract_ways:

View File

@ -1,4 +1,5 @@
import re import re
from utils.logger import logger
from copy import deepcopy from copy import deepcopy
from traceback import print_exc from traceback import print_exc
@ -48,7 +49,9 @@ total_currency_list = [
"XFO", "XFO",
] ]
share_features = ['Accumulation', 'Income', 'Distribution', 'Investor', 'Institutional', 'Capitalisation', 'Admin', 'Advantage'] share_features_full_name = ['Accumulation', 'Income', 'Distribution', 'Dividend', 'Investor', 'Institutional', 'Admin', 'Advantage']
share_features_abbrevation = ['Acc', 'Inc', 'Dist', 'Div', 'Inv', 'Inst', 'Adm', 'Adv']
def add_slash_to_text_as_regex(text: str): def add_slash_to_text_as_regex(text: str):
if text is None or len(text) == 0: if text is None or len(text) == 0:
@ -72,7 +75,12 @@ def clean_text(text: str) -> str:
return text return text
def get_most_similar_name(text: str, name_list: list, pre_common_word_list: list = None) -> str: def get_most_similar_name(text: str,
name_list: list,
share_name: str = None,
matching_type="share",
pre_common_word_list: list = None,
process_cache: dict = None) -> str:
""" """
Get the most similar fund name from fund_name_list by jacard similarity Get the most similar fund name from fund_name_list by jacard similarity
""" """
@ -134,9 +142,33 @@ def get_most_similar_name(text: str, name_list: list, pre_common_word_list: list
max_similarity = 0 max_similarity = 0
max_similarity_full_name = None max_similarity_full_name = None
text = remove_special_characters(text) text = remove_special_characters(text)
if matching_type == "share":
text, copy_name_list = update_for_currency(text, copy_name_list) text, copy_name_list = update_for_currency(text, copy_name_list)
text_currencty = get_currency_from_text(text) text_currency = None
text_feature = None
text_share_short_name = None
if matching_type == "share" and text is not None and len(text.strip()) > 0:
if process_cache is not None and isinstance(process_cache, dict):
if process_cache.get(text, None) is not None:
cache = process_cache.get(text)
text_share_short_name = cache.get("share_short_name")
text_feature = cache.get("share_feature")
text_currency = cache.get("share_currency")
else:
text_share_short_name = get_share_short_name_from_text(text)
text_feature = get_share_feature_from_text(text) text_feature = get_share_feature_from_text(text)
text_currency = get_currency_from_text(text)
process_cache[text] = {
"share_short_name": text_share_short_name,
"share_feature": text_feature,
"share_currency": text_currency
}
else:
text_share_short_name = get_share_short_name_from_text(share_name)
text_feature = get_share_feature_from_text(share_name)
text_currency = get_currency_from_text(share_name)
# logger.info(f"Source text: {text}, candidate names count: {len(copy_name_list)}")
for full_name, copy_name in zip(name_list , copy_name_list): for full_name, copy_name in zip(name_list , copy_name_list):
copy_name = remove_special_characters(copy_name) copy_name = remove_special_characters(copy_name)
copy_name = split_words_without_space(copy_name) copy_name = split_words_without_space(copy_name)
@ -151,14 +183,40 @@ def get_most_similar_name(text: str, name_list: list, pre_common_word_list: list
if similarity_2 > similarity: if similarity_2 > similarity:
similarity = similarity_2 similarity = similarity_2
if similarity > max_similarity: if similarity > max_similarity:
copy_name_currency = get_currency_from_text(copy_name) if matching_type == "share":
if text_currencty is not None and copy_name_currency is not None: if process_cache is not None and isinstance(process_cache, dict):
if text_currencty != copy_name_currency: if process_cache.get(copy_name, None) is not None:
continue cache = process_cache.get(copy_name)
copy_name_short_name = cache.get("share_short_name")
copy_name_feature = cache.get("share_feature")
copy_name_currency = cache.get("share_currency")
else:
copy_name_short_name = get_share_short_name_from_text(copy_name)
copy_name_feature = get_share_feature_from_text(copy_name) copy_name_feature = get_share_feature_from_text(copy_name)
if text_feature is not None and copy_name_feature is not None: copy_name_currency = get_currency_from_text(copy_name)
process_cache[copy_name] = {
"share_short_name": copy_name_short_name,
"share_feature": copy_name_feature,
"share_currency": copy_name_currency
}
else:
copy_name_short_name = get_share_short_name_from_text(copy_name)
copy_name_feature = get_share_feature_from_text(copy_name)
copy_name_currency = get_currency_from_text(copy_name)
if text_currency is not None and len(text_currency) > 0 and \
copy_name_currency is not None and len(copy_name_currency) > 0:
if text_currency != copy_name_currency:
continue
if text_feature is not None and len(text_feature) > 0 and \
copy_name_feature is not None and len(copy_name_feature) > 0:
if text_feature != copy_name_feature: if text_feature != copy_name_feature:
continue continue
if matching_type == "share":
if text_share_short_name is not None and len(text_share_short_name) > 0 and \
copy_name_short_name is not None and len(copy_name_short_name) > 0:
if text_share_short_name != copy_name_short_name:
continue
max_similarity = similarity max_similarity = similarity
max_similarity_full_name = full_name max_similarity_full_name = full_name
if max_similarity == 1: if max_similarity == 1:
@ -171,16 +229,38 @@ def get_most_similar_name(text: str, name_list: list, pre_common_word_list: list
print_exc() print_exc()
return None, 0.0 return None, 0.0
def get_share_short_name_from_text(text: str):
    """
    Get the share class short name from the tail of a share name.

    Only the last four whitespace-separated tokens are inspected, starting
    from the end of the text. Tokens that are a share feature full name
    (case-insensitive, see share_features_full_name) or an exact entry of
    total_currency_list are skipped. The first remaining token that is at
    most three characters long and unchanged by upper() is returned.
    Digit-only tokens satisfy that test as well.

    Returns None when text is None/blank or when no token qualifies.
    """
    if text is None or len(text.strip()) == 0:
        return None

    known_features = {name.lower() for name in share_features_full_name}
    # Reverse first, then keep four tokens: same window as counting to 4
    # while walking the reversed token list.
    tail_tokens = text.strip().split()[::-1][:4]

    for token in tail_tokens:
        is_feature = token.lower() in known_features
        is_currency = token in total_currency_list
        if is_feature or is_currency:
            continue
        if len(token) <= 3 and token == token.upper():
            return token.upper()
    return None
def get_share_feature_from_text(text: str):
    """
    Get the share feature word from the tail of a share name.

    The text is stripped and lower-cased, then its last four
    whitespace-separated tokens are inspected starting from the end.
    The first token that equals a lower-cased entry of
    share_features_full_name is returned, in lower case.

    Returns None when text is None/blank or when none of the inspected
    tokens is a known feature.
    """
    if text is None or len(text.strip()) == 0:
        return None

    known_features = {name.lower() for name in share_features_full_name}
    lowered_tokens = text.strip().lower().split()

    # Only the four tokens closest to the end of the name are considered.
    for token in lowered_tokens[::-1][:4]:
        if token in known_features:
            return token
    return None
def get_currency_from_text(text: str): def get_currency_from_text(text: str):
@ -189,9 +269,13 @@ def get_currency_from_text(text: str):
text = text.strip() text = text.strip()
text = text.lower() text = text.lower()
text_split = text.split() text_split = text.split()
count = 0
for split in text_split[::-1]: for split in text_split[::-1]:
if count == 4:
break
if split.upper() in total_currency_list: if split.upper() in total_currency_list:
return split return split
count += 1
return None return None