From 3aa596ea330d509c46cc373429ac01a8f15223b7 Mon Sep 17 00:00:00 2001 From: Blade He Date: Fri, 27 Sep 2024 16:39:56 -0500 Subject: [PATCH] optimize mapping logic --- core/data_extraction.py | 4 +- core/data_mapping.py | 36 ++++++++++--- main.py | 37 +++++++------ utils/biz_utils.py | 114 ++++++++++++++++++++++++++++++++++------ 4 files changed, 150 insertions(+), 41 deletions(-) diff --git a/core/data_extraction.py b/core/data_extraction.py index 2679d20..d9990bf 100644 --- a/core/data_extraction.py +++ b/core/data_extraction.py @@ -441,7 +441,7 @@ class DataExtraction: fund_name_line = "" half_line = rest_context[half_len].strip() max_similarity_fund_name, max_similarity = get_most_similar_name( - half_line, self.provider_fund_name_list + half_line, self.provider_fund_name_list, matching_type="fund" ) if max_similarity < 0.2: # get the fund name line text from the first half @@ -457,7 +457,7 @@ class DataExtraction: continue max_similarity_fund_name, max_similarity = get_most_similar_name( - line_text, self.provider_fund_name_list + line_text, self.provider_fund_name_list, matching_type="fund" ) if max_similarity >= 0.2: fund_name_line = line_text diff --git a/core/data_mapping.py b/core/data_mapping.py index eaa5cdc..c385158 100644 --- a/core/data_mapping.py +++ b/core/data_mapping.py @@ -108,6 +108,7 @@ class DataMapping: mapped_data_list = [] mapped_fund_cache = {} mapped_share_cache = {} + process_cache = {} for page_data in self.raw_document_data_list: doc_id = page_data.get("doc_id", "") page_index = page_data.get("page_index", "") @@ -166,12 +167,16 @@ class DataMapping: fund_id = fund_info["id"] else: fund_info = self.matching_with_database( - raw_fund_name, "fund" + raw_name=raw_fund_name, matching_type="fund" ) fund_id = fund_info["id"] mapped_fund_cache[raw_fund_name] = fund_info investment_info = self.matching_with_database( - raw_name, fund_id, "share" + raw_name=raw_name, + raw_share_name=raw_share_name, + parent_id=fund_id, + matching_type="share", + process_cache=process_cache ) mapped_share_cache[raw_name] = investment_info elif raw_fund_name is not None and len(raw_fund_name) > 0: @@ -180,7 +185,7 @@ class DataMapping: investment_info = mapped_fund_cache[raw_fund_name] else: investment_info = self.matching_with_database( - raw_name, "fund" + raw_name=raw_fund_name, matching_type="fund" ) mapped_fund_cache[raw_fund_name] = investment_info else: @@ -246,7 +251,12 @@ class DataMapping: return raw_name def matching_with_database( - self, raw_name: str, parent_id: str = None, matching_type: str = "fund" + self, + raw_name: str, + raw_share_name: str = None, + parent_id: str = None, + matching_type: str = "fund", + process_cache: dict = {} ): if len(self.doc_fund_name_list) == 0 and len(self.provider_fund_name_list) == 0: data_info["id"] = "" @@ -298,7 +308,11 @@ class DataMapping: if doc_compare_name_list is not None and len(doc_compare_name_list) > 0: _, pre_common_word_list = remove_common_word(doc_compare_name_list) max_similarity_name, max_similarity = get_most_similar_name( - raw_name, doc_compare_name_list) + raw_name, + doc_compare_name_list, + share_name=raw_share_name, + matching_type=matching_type, + process_cache=process_cache) if max_similarity is not None and max_similarity >= 0.9: data_info["id"] = doc_compare_mapping[ doc_compare_mapping[compare_name_dp] == max_similarity_name @@ -310,12 +324,20 @@ class DataMapping: # set pre_common_word_list, reason: the document mapping for same fund maybe different with provider mapping # the purpose is to get the most common word list, to improve the similarity. max_similarity_name, max_similarity = get_most_similar_name( - raw_name, provider_compare_name_list, pre_common_word_list=pre_common_word_list + raw_name, + provider_compare_name_list, + share_name=raw_share_name, + matching_type=matching_type, + pre_common_word_list=pre_common_word_list, + process_cache=process_cache ) threshold = 0.7 if matching_type == "share": threshold = 0.5 - if max_similarity is not None and max_similarity >= threshold: + round_similarity = 0 + if max_similarity is not None and isinstance(max_similarity, float): + round_similarity = round(max_similarity, 1) + if round_similarity is not None and round_similarity >= threshold: data_info["id"] = provider_compare_mapping[ provider_compare_mapping[compare_name_dp] == max_similarity_name ][compare_id_dp].values[0] diff --git a/main.py b/main.py index 7e0caaf..4a52cfe 100644 --- a/main.py +++ b/main.py @@ -335,15 +335,15 @@ def batch_start_job( ground_truth_sheet_name = "mapping_data" metrics_output_folder = r"/data/emea_ar/output/metrics/" - logger.info(f"Calculating metrics for data extraction") - missing_error_list, metrics_list, metrics_file = get_metrics( - "data_extraction", - output_file, - prediction_sheet_name, - ground_truth_file, - ground_truth_sheet_name, - metrics_output_folder, - ) + # logger.info(f"Calculating metrics for data extraction") + # missing_error_list, metrics_list, metrics_file = get_metrics( + # "data_extraction", + # output_file, + # prediction_sheet_name, + # ground_truth_file, + # ground_truth_sheet_name, + # metrics_output_folder, + # ) # logger.info(f"Calculating metrics for investment mapping by actual document mapping") # missing_error_list, metrics_list, metrics_file = get_metrics( @@ -446,7 +446,7 @@ def get_metrics( ground_truth_sheet_name=ground_truth_sheet_name, output_folder=output_folder, ) - missing_error_list, metrics_list, metrics_file = metrics.get_metrics(strict_model=True) + missing_error_list, metrics_list, metrics_file = metrics.get_metrics(strict_model=False) return missing_error_list, metrics_list, metrics_file @@ -574,8 +574,8 @@ def test_data_extraction_metrics(): def test_mapping_raw_name(): - doc_id = "481475385" - raw_name = "Emerging Markets Fund Y-DIST Shares (USD)" + doc_id = "382366116" + raw_name = "SPARINVEST SICAV - ETHICAL EMERGING MARKETS VALUE EUR I" output_folder = r"/data/emea_ar/output/mapping_data/docs/by_text/" data_mapping = DataMapping( doc_id, @@ -584,10 +584,13 @@ def test_mapping_raw_name(): document_mapping_info_df=None, output_data_folder=output_folder, ) + process_cache = {} mapping_info = data_mapping.matching_with_database( raw_name=raw_name, + raw_share_name=None, parent_id=None, - matching_type="share" + matching_type="share", + process_cache=process_cache ) print(mapping_info) @@ -677,7 +680,7 @@ if __name__ == "__main__": "333207452", "334718372", "344636875", - "349679479", + # "349679479", "362246081", "366179419", "380945052", @@ -693,12 +696,12 @@ if __name__ == "__main__": ] # special_doc_id_list = check_mapping_doc_id_list special_doc_id_list = check_db_mapping_doc_id_list - # special_doc_id_list = ["382366116"] + special_doc_id_list = ["402397014"] output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/" output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/" re_run_extract_data = False - re_run_mapping_data = False - force_save_total_data = True + re_run_mapping_data = True + force_save_total_data = False extract_ways = ["text"] for extract_way in extract_ways: diff --git a/utils/biz_utils.py b/utils/biz_utils.py index 789cbb2..a6e3487 100644 --- a/utils/biz_utils.py +++ b/utils/biz_utils.py @@ -1,4 +1,5 @@ import re +from utils.logger import logger from copy import deepcopy from traceback import print_exc @@ -48,7 +49,9 @@ total_currency_list = [ "XFO", ] -share_features = ['Accumulation', 'Income', 'Distribution', 'Investor', 'Institutional', 'Capitalisation', 'Admin', 'Advantage'] +share_features_full_name = ['Accumulation', 'Income', 'Distribution', 'Dividend', 'Investor', 'Institutional', 'Admin', 'Advantage'] +share_features_abbrevation = ['Acc', 'Inc', 'Dist', 'Div', 'Inv', 'Inst', 'Adm', 'Adv'] + def add_slash_to_text_as_regex(text: str): if text is None or len(text) == 0: @@ -72,7 +75,12 @@ def clean_text(text: str) -> str: return text -def get_most_similar_name(text: str, name_list: list, pre_common_word_list: list = None) -> str: +def get_most_similar_name(text: str, + name_list: list, + share_name: str = None, + matching_type="share", + pre_common_word_list: list = None, + process_cache: dict = None) -> str: """ Get the most similar fund name from fund_name_list by jacard similarity """ @@ -134,9 +142,33 @@ def get_most_similar_name(text: str, name_list: list, pre_common_word_list: list max_similarity = 0 max_similarity_full_name = None text = remove_special_characters(text) - text, copy_name_list = update_for_currency(text, copy_name_list) - text_currencty = get_currency_from_text(text) - text_feature = get_share_feature_from_text(text) + if matching_type == "share": + text, copy_name_list = update_for_currency(text, copy_name_list) + text_currency = None + text_feature = None + text_share_short_name = None + if matching_type == "share" and text is not None and len(text.strip()) > 0: + if process_cache is not None and isinstance(process_cache, dict): + if process_cache.get(text, None) is not None: + cache = process_cache.get(text) + text_share_short_name = cache.get("share_short_name") + text_feature = cache.get("share_feature") + text_currency = cache.get("share_currency") + else: + text_share_short_name = get_share_short_name_from_text(text) + text_feature = get_share_feature_from_text(text) + text_currency = get_currency_from_text(text) + process_cache[text] = { + "share_short_name": text_share_short_name, + "share_feature": text_feature, + "share_currency": text_currency + } + else: + text_share_short_name = get_share_short_name_from_text(share_name) + text_feature = get_share_feature_from_text(share_name) + text_currency = get_currency_from_text(share_name) + + # logger.info(f"Source text: {text}, candidate names count: {len(copy_name_list)}") for full_name, copy_name in zip(name_list , copy_name_list): copy_name = remove_special_characters(copy_name) copy_name = split_words_without_space(copy_name) @@ -151,14 +183,40 @@ def get_most_similar_name(text: str, name_list: list, pre_common_word_list: list if similarity_2 > similarity: similarity = similarity_2 if similarity > max_similarity: - copy_name_currency = get_currency_from_text(copy_name) - if text_currencty is not None and copy_name_currency is not None: - if text_currencty != copy_name_currency: - continue - copy_name_feature = get_share_feature_from_text(copy_name) - if text_feature is not None and copy_name_feature is not None: - if text_feature != copy_name_feature: - continue + if matching_type == "share": + if process_cache is not None and isinstance(process_cache, dict): + if process_cache.get(copy_name, None) is not None: + cache = process_cache.get(copy_name) + copy_name_short_name = cache.get("share_short_name") + copy_name_feature = cache.get("share_feature") + copy_name_currency = cache.get("share_currency") + else: + copy_name_short_name = get_share_short_name_from_text(copy_name) + copy_name_feature = get_share_feature_from_text(copy_name) + copy_name_currency = get_currency_from_text(copy_name) + process_cache[copy_name] = { + "share_short_name": copy_name_short_name, + "share_feature": copy_name_feature, + "share_currency": copy_name_currency + } + else: + copy_name_short_name = get_share_short_name_from_text(copy_name) + copy_name_feature = get_share_feature_from_text(copy_name) + copy_name_currency = get_currency_from_text(copy_name) + + if text_currency is not None and len(text_currency) > 0 and \ + copy_name_currency is not None and len(copy_name_currency) > 0: + if text_currency != copy_name_currency: + continue + if text_feature is not None and len(text_feature) > 0 and \ + copy_name_feature is not None and len(copy_name_feature) > 0: + if text_feature != copy_name_feature: + continue + if matching_type == "share": + if text_share_short_name is not None and len(text_share_short_name) > 0 and \ + copy_name_short_name is not None and len(copy_name_short_name) > 0: + if text_share_short_name != copy_name_short_name: + continue max_similarity = similarity max_similarity_full_name = full_name if max_similarity == 1: @@ -171,16 +229,38 @@ def get_most_similar_name(text: str, name_list: list, pre_common_word_list: list print_exc() return None, 0.0 +def get_share_short_name_from_text(text: str): + if text is None or len(text.strip()) == 0: + return None + text = text.strip() + text_split = text.split() + temp_share_features = [feature.lower() for feature in share_features_full_name] + + count = 0 + for split in text_split[::-1]: + if count == 4: + break + if split.lower() not in temp_share_features and \ + split not in total_currency_list: + if len(split) <= 3 and split.upper() == split: + return split.upper() + count += 1 + return None + def get_share_feature_from_text(text: str): if text is None or len(text.strip()) == 0: return None text = text.strip() text = text.lower() text_split = text.split() - temp_share_features = [feature.lower() for feature in share_features] + temp_share_features = [feature.lower() for feature in share_features_full_name] + count = 0 for split in text_split[::-1]: - if split in temp_share_features: + if count == 4: + break + if split.lower() in temp_share_features: return split + count += 1 return None def get_currency_from_text(text: str): @@ -189,9 +269,13 @@ def get_currency_from_text(text: str): text = text.strip() text = text.lower() text_split = text.split() + count = 0 for split in text_split[::-1]: + if count == 4: + break if split.upper() in total_currency_list: return split + count += 1 return None