From dd6701f18c3f38ba2da3293433e2dd1692a8a089 Mon Sep 17 00:00:00 2001 From: Blade He Date: Wed, 25 Sep 2024 15:15:38 -0500 Subject: [PATCH] Optimize the investment mapping algorithm and implement investment mapping metrics --- core/data_mapping.py | 16 ++- core/metrics.py | 333 +++++++++++++++++++++++++++++++++---------- main.py | 161 ++++++++++++++------- utils/biz_utils.py | 17 ++- 4 files changed, 389 insertions(+), 138 deletions(-) diff --git a/core/data_mapping.py b/core/data_mapping.py index 4d0b55e..870bc64 100644 --- a/core/data_mapping.py +++ b/core/data_mapping.py @@ -1,7 +1,7 @@ import os import json import pandas as pd -from utils.biz_utils import get_most_similar_name +from utils.biz_utils import get_most_similar_name, remove_common_word from utils.sql_query_util import ( query_document_fund_mapping, query_investment_by_provider, @@ -270,8 +270,12 @@ class DataMapping: else: if parent_id is not None and len(parent_id) > 0: # filter self.doc_fund_class_mapping by parent_id as FundId - doc_compare_mapping = None - doc_compare_name_list = None + doc_compare_mapping = self.doc_fund_class_mapping[ + self.doc_fund_class_mapping["FundId"] == parent_id + ] + doc_compare_name_list = ( + doc_compare_mapping["ShareClassName"].unique().tolist() + ) provider_compare_mapping = self.provider_fund_class_mapping[ self.provider_fund_class_mapping["FundId"] == parent_id @@ -290,7 +294,9 @@ class DataMapping: data_info = {"name": raw_name} if len(provider_compare_name_list) > 0: + pre_common_word_list = [] if doc_compare_name_list is not None and len(doc_compare_name_list) > 0: + _, pre_common_word_list = remove_common_word(doc_compare_name_list) max_similarity_name, max_similarity = get_most_similar_name( raw_name, doc_compare_name_list) if max_similarity is not None and max_similarity >= 0.9: data_info["similarity"] = max_similarity if data_info.get("id", None) is None or len(data_info.get("id", "")) == 0: + # pass pre_common_word_list because the document mapping for the same fund may differ from the provider mapping + # reusing the document-side common word list improves the similarity comparison.
max_similarity_name, max_similarity = get_most_similar_name( - raw_name, provider_compare_name_list + raw_name, provider_compare_name_list, pre_common_word_list=pre_common_word_list ) if max_similarity is not None and max_similarity >= 0.5: data_info["id"] = provider_compare_mapping[ diff --git a/core/metrics.py b/core/metrics.py index f49e40a..a6a3146 100644 --- a/core/metrics.py +++ b/core/metrics.py @@ -94,6 +94,9 @@ class Metrics: performance_fee_true = [] performance_fee_pred = [] + + investment_mapping_true = [] + investment_mapping_pred = [] missing_error_list = [] data_point_list = ["tor", "ter", "ogc", "performance_fee"] @@ -157,86 +160,123 @@ class Metrics: performance_fee_pred.extend(pred_data) missing_error_list.extend(missing_error_data) elif self.data_type == "investment_mapping": - pass + prediction_doc_id_list = prediction_df["doc_id"].unique().tolist() + ground_truth_doc_id_list = ground_truth_df["doc_id"].unique().tolist() + # get intersection of doc_id_list + doc_id_list = list( + set(prediction_doc_id_list) & set(ground_truth_doc_id_list) + ) + # order by doc_id + doc_id_list.sort() + + for doc_id in doc_id_list: + prediction_data = prediction_df[prediction_df["doc_id"] == doc_id] + ground_truth_data = ground_truth_df[ground_truth_df["doc_id"] == doc_id] + for data_point in data_point_list: + true_data, pred_data, missing_error_data = self.get_investment_mapping_true_pred_data( + doc_id, ground_truth_data, prediction_data, data_point + ) + investment_mapping_true.extend(true_data) + investment_mapping_pred.extend(pred_data) + missing_error_list.extend(missing_error_data) metrics_list = [] - for data_point in data_point_list: - if data_point == "tor": - precision, recall, f1 = self.get_specific_metrics(tor_true, tor_pred) - tor_support = self.get_support_number(tor_true) - metrics_list.append( - { - "Data_Point": data_point, - "Precision": precision, - "Recall": recall, - "F1": f1, - "Support": tor_support, - } - ) - logger.info( - f"TOR Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {tor_support}" - ) - elif data_point == "ter": - precision, recall, f1 = self.get_specific_metrics(ter_true, ter_pred) - ter_support = self.get_support_number(ter_true) - metrics_list.append( - { - "Data_Point": data_point, - "Precision": precision, - "Recall": recall, - "F1": f1, - "Support": ter_support, - } - ) - logger.info( - f"TER Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {ter_support}" - ) - elif data_point == "ogc": - precision, recall, f1 = self.get_specific_metrics(ogc_true, ogc_pred) - ogc_support = self.get_support_number(ogc_true) - metrics_list.append( - { - "Data_Point": data_point, - "Precision": precision, - "Recall": recall, - "F1": f1, - "Support": ogc_support, - } - ) - logger.info( - f"OGC Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {ogc_support}" - ) - elif data_point == "performance_fee": - precision, recall, f1 = self.get_specific_metrics( - performance_fee_true, performance_fee_pred - ) - performance_fee_support = self.get_support_number(performance_fee_true) - metrics_list.append( - { - "Data_Point": data_point, - "Precision": precision, - "Recall": recall, - "F1": f1, - "Support": performance_fee_support, - } - ) - logger.info( - f"Performance Fee Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {performance_fee_support}" - ) + if self.data_type == "investment_mapping": + if len(investment_mapping_true) == 0 and len(investment_mapping_pred) == 0: + investment_mapping_true.append(1) + 
investment_mapping_pred.append(1) + precision, recall, f1 = self.get_specific_metrics(investment_mapping_true, investment_mapping_pred) + investment_mapping_support = self.get_support_number(investment_mapping_true) + metrics_list.append( + { + "Data_Point": "Investment Mapping", + "Precision": precision, + "Recall": recall, + "F1": f1, + "Support": investment_mapping_support, + } + ) + logger.info( + f"Investment mapping Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {investment_mapping_support}" + ) + else: + for data_point in data_point_list: + if data_point == "tor": + precision, recall, f1 = self.get_specific_metrics(tor_true, tor_pred) + tor_support = self.get_support_number(tor_true) + metrics_list.append( + { + "Data_Point": data_point, + "Precision": precision, + "Recall": recall, + "F1": f1, + "Support": tor_support, + } + ) + logger.info( + f"TOR Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {tor_support}" + ) + elif data_point == "ter": + precision, recall, f1 = self.get_specific_metrics(ter_true, ter_pred) + ter_support = self.get_support_number(ter_true) + metrics_list.append( + { + "Data_Point": data_point, + "Precision": precision, + "Recall": recall, + "F1": f1, + "Support": ter_support, + } + ) + logger.info( + f"TER Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {ter_support}" + ) + elif data_point == "ogc": + precision, recall, f1 = self.get_specific_metrics(ogc_true, ogc_pred) + ogc_support = self.get_support_number(ogc_true) + metrics_list.append( + { + "Data_Point": data_point, + "Precision": precision, + "Recall": recall, + "F1": f1, + "Support": ogc_support, + } + ) + logger.info( + f"OGC Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {ogc_support}" + ) + elif data_point == "performance_fee": + precision, recall, f1 = self.get_specific_metrics( + performance_fee_true, performance_fee_pred + ) + performance_fee_support = self.get_support_number(performance_fee_true) + metrics_list.append( + { + "Data_Point": data_point, + "Precision": precision, + "Recall": recall, + "F1": f1, + "Support": performance_fee_support, + } + ) + logger.info( + f"Performance Fee Precision: {precision}, Recall: {recall}, F1: {f1}, Support: {performance_fee_support}" + ) - # get average metrics - precision_list = [metric["Precision"] for metric in metrics_list] - recall_list = [metric["Recall"] for metric in metrics_list] - f1_list = [metric["F1"] for metric in metrics_list] - metrics_list.append( - { - "Data_Point": "Average", - "Precision": sum(precision_list) / len(precision_list), - "Recall": sum(recall_list) / len(recall_list), - "F1": sum(f1_list) / len(f1_list), - "Support": sum([metric["Support"] for metric in metrics_list]), - } - ) + # get average metrics + precision_list = [metric["Precision"] for metric in metrics_list] + recall_list = [metric["Recall"] for metric in metrics_list] + f1_list = [metric["F1"] for metric in metrics_list] + metrics_list.append( + { + "Data_Point": "Average", + "Precision": sum(precision_list) / len(precision_list), + "Recall": sum(recall_list) / len(recall_list), + "F1": sum(f1_list) / len(f1_list), + "Support": sum([metric["Support"] for metric in metrics_list]), + } + ) return missing_error_list, metrics_list def get_support_number(self, true_data: list): @@ -490,6 +530,145 @@ class Metrics: return true_data, pred_data, missing_error_data + def get_investment_mapping_true_pred_data( + self, + doc_id, + ground_truth_data: pd.DataFrame, + prediction_data: pd.DataFrame, + data_point: str, + ): + 
dp_prediction = prediction_data[prediction_data["datapoint"] == data_point] + dp_prediction = self.modify_data(dp_prediction) + pred_simple_raw_names = dp_prediction["simple_raw_name"].unique().tolist() + pred_simple_name_unique_words_list = ( + dp_prediction["simple_name_unique_words"].unique().tolist() + ) + + dp_ground_truth = ground_truth_data[ + ground_truth_data["datapoint"] == data_point + ] + dp_ground_truth = self.modify_data(dp_ground_truth) + gt_simple_raw_names = dp_ground_truth["simple_raw_name"].unique().tolist() + gt_simple_name_unique_words_list = ( + dp_ground_truth["simple_name_unique_words"].unique().tolist() + ) + + compare_data_list = [] + for index, ground_truth in dp_ground_truth.iterrows(): + gt_page_index = ground_truth["page_index"] + gt_raw_name = ground_truth["raw_name"] + gt_simple_raw_name = ground_truth["simple_raw_name"] + gt_simple_name_unique_words = ground_truth["simple_name_unique_words"] + gt_investment_type = ground_truth["investment_type"] + + find_raw_name_in_pred = [ + pred_raw_name + for pred_raw_name in pred_simple_raw_names + if ( + gt_simple_raw_name in pred_raw_name + or pred_raw_name in gt_simple_raw_name + ) + and pred_raw_name.endswith(gt_simple_raw_name.split()[-1]) + ] + + if ( + gt_simple_name_unique_words in pred_simple_name_unique_words_list + or len(find_raw_name_in_pred) > 0 + ): + # get the ground truth data with the same unique words + if gt_simple_name_unique_words in pred_simple_name_unique_words_list: + pred_data_df = dp_prediction[ + dp_prediction["simple_name_unique_words"] + == gt_simple_name_unique_words + ] + if len(pred_data_df) > 1: + if ( + len(pred_data_df[pred_data_df["page_index"] == gt_page_index]) + == 0 + ): + pred_data = pred_data_df.iloc[0] + else: + pred_data = pred_data_df[ + pred_data_df["page_index"] == gt_page_index + ].iloc[0] + elif len(pred_data_df) == 1: + pred_data = pred_data_df.iloc[0] + else: + pred_data = None + else: + pred_data_df = dp_prediction[ + dp_prediction["simple_raw_name"] == find_raw_name_in_pred[0] + ] + if len(pred_data_df) > 1: + if ( + len(pred_data_df[pred_data_df["page_index"] == gt_page_index]) + == 0 + ): + pred_data = pred_data_df.iloc[0] + else: + pred_data = pred_data_df[ + pred_data_df["page_index"] == gt_page_index + ].iloc[0] + elif len(pred_data_df) == 1: + pred_data = pred_data_df.iloc[0] + else: + pred_data = None + if pred_data is not None: + compare_data = {"raw_name": gt_raw_name, + "investment_type": gt_investment_type, + "gt_investment_id": ground_truth["investment_id"], + "gt_investment_name": ground_truth["investment_name"], + "pred_investment_id": pred_data["investment_id"], + "pred_investment_name": pred_data["investment_name"]} + compare_data_list.append(compare_data) + + true_data = [] + pred_data = [] + missing_error_data = [] + + for compare_data in compare_data_list: + if compare_data["gt_investment_id"] == compare_data["pred_investment_id"]: + true_data.append(1) + pred_data.append(1) + else: + true_data.append(1) + pred_data.append(0) + error_data = { + "doc_id": doc_id, + "data_point": data_point, + "raw_name": compare_data["raw_name"], + "investment_type": compare_data["investment_type"], + "error_type": "mapping missing", + "error_id": compare_data["pred_investment_id"], + "error_name": compare_data["pred_investment_name"], + "correct_id": compare_data["gt_investment_id"], + "correct_name": compare_data["gt_investment_name"] + } + missing_error_data.append(error_data) + + for index, prediction in dp_prediction.iterrows(): + pred_raw_name = 
prediction["raw_name"] + pred_investment_id = prediction["investment_id"] + pred_investment_name = prediction["investment_name"] + pred_investment_type = prediction["investment_type"] + gt_data_df = dp_ground_truth[dp_ground_truth["investment_id"] == pred_investment_id] + if len(gt_data_df) == 0: + true_data.append(0) + pred_data.append(1) + error_data = { + "doc_id": doc_id, + "data_point": data_point, + "raw_name": pred_raw_name, + "investment_type": pred_investment_type, + "error_type": "mapping incorrect", + "error_id": pred_investment_id, + "error_name": pred_investment_name, + "correct_id": "", + "correct_name": "" + } + missing_error_data.append(error_data) + return true_data, pred_data, missing_error_data + def modify_data(self, data: pd.DataFrame): data["simple_raw_name"] = "" data["simple_name_unique_words"] = "" diff --git a/main.py b/main.py index d299a96..15527af 100644 --- a/main.py +++ b/main.py @@ -27,13 +27,15 @@ class EMEA_AR_Parsing: os.makedirs(self.pdf_folder, exist_ok=True) self.pdf_file = self.download_pdf() self.document_mapping_info_df = query_document_fund_mapping(doc_id, rerun=False) - + if extract_way is None or len(extract_way) == 0: extract_way = "text" self.extract_way = extract_way self.output_extract_image_folder = None if self.extract_way == "image": - self.output_extract_image_folder = r"/data/emea_ar/output/extract_data/images/" + self.output_extract_image_folder = ( + r"/data/emea_ar/output/extract_data/images/" + ) os.makedirs(self.output_extract_image_folder, exist_ok=True) if output_extract_data_folder is None or len(output_extract_data_folder) == 0: @@ -41,7 +43,9 @@ class EMEA_AR_Parsing: if not output_extract_data_folder.endswith("/"): output_extract_data_folder = f"{output_extract_data_folder}/" if extract_way is not None and len(extract_way) > 0: - output_extract_data_folder = f"{output_extract_data_folder}by_{extract_way}/" + output_extract_data_folder = ( + f"{output_extract_data_folder}by_{extract_way}/" + ) self.output_extract_data_folder = output_extract_data_folder os.makedirs(self.output_extract_data_folder, exist_ok=True) @@ -50,7 +54,9 @@ class EMEA_AR_Parsing: if not output_mapping_data_folder.endswith("/"): output_mapping_data_folder = f"{output_mapping_data_folder}/" if extract_way is not None and len(extract_way) > 0: - output_mapping_data_folder = f"{output_mapping_data_folder}by_{extract_way}/" + output_mapping_data_folder = ( + f"{output_mapping_data_folder}by_{extract_way}/" + ) self.output_mapping_data_folder = output_mapping_data_folder os.makedirs(self.output_mapping_data_folder, exist_ok=True) @@ -75,8 +81,10 @@ class EMEA_AR_Parsing: datapoints.remove("doc_id") return datapoints - def extract_data(self, - re_run: bool = False,) -> list: + def extract_data( + self, + re_run: bool = False, + ) -> list: if not re_run: output_data_json_folder = os.path.join( self.output_extract_data_folder, "json/" @@ -100,7 +108,7 @@ class EMEA_AR_Parsing: self.datapoints, self.document_mapping_info_df, extract_way=self.extract_way, - output_image_folder=self.output_extract_image_folder + output_image_folder=self.output_extract_image_folder, ) data_from_gpt = data_extraction.extract_data() return data_from_gpt @@ -144,18 +152,18 @@ def filter_pages(doc_id: str, pdf_folder: str) -> None: def extract_data( - doc_id: str, - pdf_folder: str, + doc_id: str, + pdf_folder: str, output_data_folder: str, extract_way: str = "text", - re_run: bool = False + re_run: bool = False, ) -> None: logger.info(f"Extract EMEA AR data for doc_id: {doc_id}") 
emea_ar_parsing = EMEA_AR_Parsing( - doc_id, - pdf_folder, + doc_id, + pdf_folder, output_extract_data_folder=output_data_folder, - extract_way=extract_way + extract_way=extract_way, ) data_from_gpt = emea_ar_parsing.extract_data(re_run) return data_from_gpt @@ -284,19 +292,22 @@ def batch_start_job( result_extract_data_list.extend(doc_data_from_gpt) result_mapping_data_list.extend(doc_mapping_data_list) - if force_save_total_data or (special_doc_id_list is None or len(special_doc_id_list) == 0): + if force_save_total_data or ( + special_doc_id_list is None or len(special_doc_id_list) == 0 + ): result_extract_data_df = pd.DataFrame(result_extract_data_list) result_extract_data_df.reset_index(drop=True, inplace=True) result_mappingdata_df = pd.DataFrame(result_mapping_data_list) result_mappingdata_df.reset_index(drop=True, inplace=True) - + logger.info(f"Saving extract data to {output_extract_data_total_folder}") + unique_doc_ids = result_extract_data_df["doc_id"].unique().tolist() os.makedirs(output_extract_data_total_folder, exist_ok=True) time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime()) output_file = os.path.join( output_extract_data_total_folder, - f"extract_data_info_{len(pdf_files)}_documents_by_{extract_way}_{time_stamp}.xlsx", + f"extract_data_info_{len(unique_doc_ids)}_documents_by_{extract_way}_{time_stamp}.xlsx", ) with pd.ExcelWriter(output_file) as writer: result_extract_data_df.to_excel( @@ -304,11 +315,12 @@ def batch_start_job( ) logger.info(f"Saving mapping data to {output_mapping_total_folder}") + unique_doc_ids = result_mappingdata_df["doc_id"].unique().tolist() os.makedirs(output_mapping_total_folder, exist_ok=True) time_stamp = time.strftime("%Y%m%d%H%M%S", time.localtime()) output_file = os.path.join( output_mapping_total_folder, - f"mapping_data_info_{len(pdf_files)}_documents_by_{extract_way}_{time_stamp}.xlsx", + f"mapping_data_info_{len(unique_doc_ids)}_documents_by_{extract_way}_{time_stamp}.xlsx", ) with pd.ExcelWriter(output_file) as writer: result_mappingdata_df.to_excel( @@ -317,18 +329,30 @@ def batch_start_job( result_extract_data_df.to_excel( writer, index=False, sheet_name="extract_data" ) - + prediction_sheet_name = "mapping_data" ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx" ground_truth_sheet_name = "mapping_data" metrics_output_folder = r"/data/emea_ar/output/metrics/" + + logger.info(f"Calculating metrics for data extraction") missing_error_list, metrics_list, metrics_file = get_metrics( - "data_extraction", - output_file, + "data_extraction", + output_file, prediction_sheet_name, ground_truth_file, ground_truth_sheet_name, - metrics_output_folder + metrics_output_folder, + ) + + logger.info(f"Calculating metrics for investment mapping") + missing_error_list, metrics_list, metrics_file = get_metrics( + "investment_mapping", + output_file, + prediction_sheet_name, + ground_truth_file, + ground_truth_sheet_name, + metrics_output_folder, ) @@ -530,14 +554,32 @@ def test_data_extraction_metrics(): ground_truth_sheet_name = "mapping_data" metrics_output_folder = r"/data/emea_ar/output/metrics/" missing_error_list, metrics_list, metrics_file = get_metrics( - data_type, - prediction_file, + data_type, + prediction_file, prediction_sheet_name, ground_truth_file, ground_truth_sheet_name, - metrics_output_folder + metrics_output_folder, ) + +def test_mapping_raw_name(): + doc_id = "292989214" + raw_name = "ENBD Saudi Arabia Equity Fund Class A USD Accumulation" + output_folder = 
r"/data/emea_ar/output/mapping_data/docs/by_text/" + data_mapping = DataMapping( + doc_id, + datapoints=None, + raw_document_data_list=None, + document_mapping_info_df=None, + output_data_folder=output_folder, + ) + mapping_info = data_mapping.matching_with_database( + raw_name=raw_name, parent_id="FS0000B4A7", matching_type="share" + ) + print(mapping_info) + + if __name__ == "__main__": pdf_folder = r"/data/emea_ar/small_pdf/" page_filter_ground_truth_file = ( @@ -560,7 +602,7 @@ if __name__ == "__main__": output_extract_data_child_folder = r"/data/emea_ar/output/extract_data/docs/" output_extract_data_total_folder = r"/data/emea_ar/output/extract_data/total/" - + # batch_extract_data( # pdf_folder, # page_filter_ground_truth_file, @@ -572,34 +614,57 @@ if __name__ == "__main__": # doc_id = "476492237" # extract_way = "image" - # extract_data(doc_id, - # pdf_folder, + # extract_data(doc_id, + # pdf_folder, # output_extract_data_child_folder, # extract_way, # re_run_extract_data) - + # special_doc_id_list = ["505174428", "510326848", "349679479"] - special_doc_id_list = [] + check_mapping_doc_id_list = [ + "458359181", + "486383912", + "529925114", + "391456740", + "391736837", + "497497599", + "327956364", + "479793787", + "334718372", + "321733631", + "507967525", + "478585901", + "366179419", + "509845549", + "323390570", + "344636875", + "445256897", + "508854243", + "520879048", + "463081566", + ] + special_doc_id_list = check_mapping_doc_id_list output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/" output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/" re_run_extract_data = False - re_run_mapping_data = False - force_save_total_data = False - + re_run_mapping_data = True + force_save_total_data = True + extract_ways = ["text"] - # for extract_way in extract_ways: - # batch_start_job( - # pdf_folder, - # page_filter_ground_truth_file, - # output_extract_data_child_folder, - # output_mapping_child_folder, - # output_extract_data_total_folder, - # output_mapping_total_folder, - # extract_way, - # special_doc_id_list, - # re_run_extract_data, - # re_run_mapping_data, - # force_save_total_data=force_save_total_data, - # ) - - test_data_extraction_metrics() + for extract_way in extract_ways: + batch_start_job( + pdf_folder, + page_filter_ground_truth_file, + output_extract_data_child_folder, + output_mapping_child_folder, + output_extract_data_total_folder, + output_mapping_total_folder, + extract_way, + special_doc_id_list, + re_run_extract_data, + re_run_mapping_data, + force_save_total_data=force_save_total_data, + ) + + # test_data_extraction_metrics() + # test_mapping_raw_name() diff --git a/utils/biz_utils.py b/utils/biz_utils.py index 9b3a782..285ff9d 100644 --- a/utils/biz_utils.py +++ b/utils/biz_utils.py @@ -23,7 +23,7 @@ def clean_text(text: str) -> str: return text -def get_most_similar_name(text: str, name_list: list): +def get_most_similar_name(text: str, name_list: list, pre_common_word_list: list = None) -> str: """ Get the most similar fund name from fund_name_list by jacard similarity """ @@ -40,6 +40,9 @@ def get_most_similar_name(text: str, name_list: list): common_word_list = [] if len(name_list) > 1: _, common_word_list = remove_common_word(copy_fund_name_list) + if pre_common_word_list is not None and len(pre_common_word_list) > 0: + common_word_list.extend([word for word in pre_common_word_list + if word not in common_word_list]) text = text.strip() text = remove_special_characters(text) @@ -61,17 +64,13 @@ def 
get_most_similar_name(text: str, name_list: list): # remove word in fund_name_list for i in range(len(copy_fund_name_list)): temp_splits = copy_fund_name_list[i].split() - for temp in temp_splits: - if remove_special_characters(temp).lower() == word: - copy_fund_name_list[i] = re.sub(r'\s+', ' ', - copy_fund_name_list[i].replace(temp, ' ')) + copy_fund_name_list[i] = ' '.join([split for split in temp_splits + if remove_special_characters(split).lower() != word]) for i in range(len(copy_fund_name_list)): temp_splits = copy_fund_name_list[i].split() - for temp in temp_splits: - if remove_special_characters(temp).lower() in ['fund', 'portfolio', 'class', 'share', 'shares']: - copy_fund_name_list[i] = \ - re.sub(r'\s+', ' ', copy_fund_name_list[i].replace(temp, ' ')) + copy_fund_name_list[i] = ' '.join([split for split in temp_splits + if remove_special_characters(split).lower() not in ['fund', 'portfolio', 'class', 'share', 'shares']]) final_splits = [] for split in new_splits: if split.lower() not in ['fund', 'portfolio', 'class', 'share', 'shares']:
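The matching idea behind get_most_similar_name and the new pre_common_word_list argument can be summarised in a small standalone sketch. This is an illustrative approximation only, not the utils.biz_utils code: the helper names (jaccard, common_words, most_similar) are hypothetical, and the real function also strips special characters and filler words such as 'fund' and 'class' as the hunks above show. It demonstrates the core effect of this patch: words shared by every candidate name, plus words already known to be common on the document side, are removed before the Jaccard comparison, so the distinctive share-class tokens drive the score.

def jaccard(a: set, b: set) -> float:
    # Jaccard similarity of two token sets; 0.0 when both sets are empty.
    union = a | b
    return len(a & b) / len(union) if union else 0.0

def common_words(names: list) -> set:
    # Words that appear in every candidate name (e.g. the shared umbrella fund name).
    token_sets = [set(n.lower().split()) for n in names]
    return set.intersection(*token_sets) if token_sets else set()

def most_similar(raw_name: str, candidates: list, pre_common: set = frozenset()):
    # Drop words common to all candidates, plus any words already known to be common
    # from the document-side mapping (the role of pre_common_word_list), then pick the
    # candidate with the highest Jaccard similarity to the raw name.
    drop = common_words(candidates) | set(pre_common)
    raw_tokens = set(raw_name.lower().split()) - drop
    best_name, best_score = None, 0.0
    for cand in candidates:
        score = jaccard(raw_tokens, set(cand.lower().split()) - drop)
        if score > best_score:
            best_name, best_score = cand, score
    return best_name, best_score

# Usage (the raw name comes from test_mapping_raw_name above; the candidate names are made up):
name, score = most_similar(
    "ENBD Saudi Arabia Equity Fund Class A USD Accumulation",
    ["ENBD Saudi Arabia Equity Fund A USD Acc", "ENBD Saudi Arabia Equity Fund B USD Acc"],
)
# Picks the "A USD Acc" class: after stripping the shared tokens, only the distinguishing
# share-class words remain to be compared.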