From f1c0290588fc202402f1d73ab923693884650474 Mon Sep 17 00:00:00 2001
From: Blade He
Date: Thu, 21 Nov 2024 16:36:58 -0600
Subject: [PATCH] Optimize investment mapping algorithm.

1. Get the proper currency when multiple currencies exist in the share name,
   e.g. "CHF EUR".
2. The default currency should depend on the scenario: USD or EUR.
3. Removing special characters should be based on \W instead of
   [^a-zA-Z0-9\s].
---
 main.py            | 93 ++++++++++++++++++++++++++++++++--------------
 utils/biz_utils.py | 39 +++++++++++++------
 2 files changed, 94 insertions(+), 38 deletions(-)

diff --git a/main.py b/main.py
index 6061731..8d0a21c 100644
--- a/main.py
+++ b/main.py
@@ -717,7 +717,7 @@ def test_replace_abbrevation():
 def test_calculate_metrics():
     from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
     data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
-    mapping_file = r"/data/emea_ar/basic_information/English\sample_doc/emea_doc_with_all_4_dp/doc_ar_data_with_all_4_dp.xlsx"
+    mapping_file = r"/data/emea_ar/basic_information/English/sample_doc/emea_doc_with_all_4_dp/doc_ar_data_with_all_4_dp.xlsx"
 
     data_df = pd.read_excel(data_file, sheet_name="data_in_doc_mapping")
     data_df = data_df[data_df["check"].isin([0, 1])]
@@ -736,22 +736,22 @@ def test_calculate_metrics():
     logger.info(f"Investment mapping metrics: {mapping_metrics}")
 
     # tor data
-    tor_data_df = data_df[data_df["datapoint"] == "tor"]
+    tor_data_df = filter_data_df[filter_data_df["datapoint"] == "tor"]
     tor_metrics = get_sub_metrics(tor_data_df, "tor")
     logger.info(f"TOR metrics: {tor_metrics}")
 
     # ter data
-    ter_data_df = data_df[data_df["datapoint"] == "ter"]
+    ter_data_df = filter_data_df[filter_data_df["datapoint"] == "ter"]
     ter_metrics = get_sub_metrics(ter_data_df, "ter")
     logger.info(f"TER metrics: {ter_metrics}")
 
     # ogc data
-    ogc_data_df = data_df[data_df["datapoint"] == "ogc"]
+    ogc_data_df = filter_data_df[filter_data_df["datapoint"] == "ogc"]
     ogc_metrics = get_sub_metrics(ogc_data_df, "ogc")
     logger.info(f"OGC metrics: {ogc_metrics}")
 
     # performance_fee data
-    performance_fee_data_df = data_df[data_df["datapoint"] == "performance_fee"]
+    performance_fee_data_df = filter_data_df[filter_data_df["datapoint"] == "performance_fee"]
     performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee")
     logger.info(f"Performance fee metrics: {performance_fee_metrics}")
 
@@ -770,9 +770,11 @@ def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict:
     pre_list = data_df["check"].tolist()
     # convert pre_list member to be integer
     pre_list = [int(pre) for pre in pre_list]
-    zero_pre_list = [int(pre) for pre in pre_list if int(pre) == 0]
-    gt_list += zero_pre_list
-    pre_list += [1] * len(zero_pre_list)
+
+    for index, row in data_df.iterrows():
+        if row["check"] == 0 and len(row["investment_id"].strip()) > 0:
+            pre_list.append(1)
+            gt_list.append(0)
     # calculate metrics
     accuracy = accuracy_score(gt_list, pre_list)
     precision = precision_score(gt_list, pre_list)
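A note on the hunk above: the rewritten bookkeeping in get_sub_metrics no longer pads the lists for every `check == 0` row; only rows that were judged wrong (`check == 0`) but still produced a non-empty `investment_id` are added as explicit false positives (prediction 1 against ground truth 0). Below is a minimal sketch of that accounting on made-up data; the DataFrame contents and the all-ones initialisation of `gt_list` are assumptions for illustration, not values or code taken from the real workbook or function:

```python
import pandas as pd
from sklearn.metrics import precision_score, recall_score

# Illustrative rows only: check == 1 means the mapping was judged correct,
# check == 0 means it was judged wrong; investment_id values are hypothetical.
data_df = pd.DataFrame({
    "check": [1, 0, 0],
    "investment_id": ["F000001", "F000002", ""],
})

gt_list = [1] * len(data_df)                  # assumption: every row has a true mapping
pre_list = [int(pre) for pre in data_df["check"]]

# Rule from the hunk above: a wrong mapping that still produced an
# investment_id counts as a false positive, so add a (pred=1, gt=0) pair.
for _, row in data_df.iterrows():
    if row["check"] == 0 and len(row["investment_id"].strip()) > 0:
        pre_list.append(1)
        gt_list.append(0)

print(precision_score(gt_list, pre_list))  # 0.5  (1 TP, 1 FP)
print(recall_score(gt_list, pre_list))     # 0.333... (1 TP, 2 FN)
```

Under this reading, a row with `check == 0` and an empty `investment_id` stays a plain miss (false negative) instead of also being counted as a false positive.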
@@ -790,8 +792,45 @@
     }
     return metrics
 
+def replace_rerun_data(new_data_file: str, original_data_file: str):
+    data_in_doc_mapping_sheet = "data_in_doc_mapping"
+    total_mapping_data_sheet = "total_mapping_data"
+    extract_data_sheet = "extract_data"
+
+    new_data_in_doc_mapping = pd.read_excel(new_data_file, sheet_name=data_in_doc_mapping_sheet)
+    new_total_mapping_data = pd.read_excel(new_data_file, sheet_name=total_mapping_data_sheet)
+    new_extract_data = pd.read_excel(new_data_file, sheet_name=extract_data_sheet)
+
+    document_list = new_data_in_doc_mapping["doc_id"].unique().tolist()
+
+    original_data_in_doc_mapping = pd.read_excel(original_data_file, sheet_name=data_in_doc_mapping_sheet)
+    original_total_mapping_data = pd.read_excel(original_data_file, sheet_name=total_mapping_data_sheet)
+    original_extract_data = pd.read_excel(original_data_file, sheet_name=extract_data_sheet)
+
+    # remove data in original data by document_list
+    original_data_in_doc_mapping = original_data_in_doc_mapping[~original_data_in_doc_mapping["doc_id"].isin(document_list)]
+    original_total_mapping_data = original_total_mapping_data[~original_total_mapping_data["doc_id"].isin(document_list)]
+    original_extract_data = original_extract_data[~original_extract_data["doc_id"].isin(document_list)]
+
+    # merge new data to original data
+    new_data_in_doc_mapping = pd.concat([original_data_in_doc_mapping, new_data_in_doc_mapping])
+    new_data_in_doc_mapping.reset_index(drop=True, inplace=True)
+    new_total_mapping_data = pd.concat([original_total_mapping_data, new_total_mapping_data])
+    new_total_mapping_data.reset_index(drop=True, inplace=True)
+    new_extract_data = pd.concat([original_extract_data, new_extract_data])
+    new_extract_data.reset_index(drop=True, inplace=True)
+
+    with pd.ExcelWriter(original_data_file) as writer:
+        new_data_in_doc_mapping.to_excel(writer, index=False, sheet_name=data_in_doc_mapping_sheet)
+        new_total_mapping_data.to_excel(writer, index=False, sheet_name=total_mapping_data_sheet)
+        new_extract_data.to_excel(writer, index=False, sheet_name=extract_data_sheet)
+
+
 if __name__ == "__main__":
-    # test_calculate_metrics()
+    # new_data_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_15_documents_by_text_20241121154243.xlsx"
+    # original_data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
+    # replace_rerun_data(new_data_file, original_data_file)
+    test_calculate_metrics()
     # test_replace_abbrevation()
     # test_translate_pdf()
     pdf_folder = r"/data/emea_ar/pdf/"
@@ -1203,32 +1242,32 @@ if __name__ == "__main__":
                            "501380497",
                            "514636959",
                            "508981020"]
-    special_doc_id_list = ["514636953"]
+    # special_doc_id_list = ["514636952"]
     output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
     output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
     re_run_extract_data = False
-    re_run_mapping_data = True
-    force_save_total_data = False
+    re_run_mapping_data = False
+    force_save_total_data = True
     calculate_metrics = False
     extract_ways = ["text"]
     # pdf_folder = r"/data/emea_ar/small_pdf/"
     pdf_folder = r"/data/emea_ar/pdf/"
-    for extract_way in extract_ways:
-        batch_start_job(
-            pdf_folder,
-            page_filter_ground_truth_file,
-            output_extract_data_child_folder,
-            output_mapping_child_folder,
-            output_extract_data_total_folder,
-            output_mapping_total_folder,
-            extract_way,
-            special_doc_id_list,
-            re_run_extract_data,
-            re_run_mapping_data,
-            force_save_total_data=force_save_total_data,
-            calculate_metrics=calculate_metrics,
-        )
+    # for extract_way in extract_ways:
+    #     batch_start_job(
+    #         pdf_folder,
+    #         page_filter_ground_truth_file,
+    #         output_extract_data_child_folder,
+    #         output_mapping_child_folder,
+    #         output_extract_data_total_folder,
+    #         output_mapping_total_folder,
+    #         extract_way,
+    #         special_doc_id_list,
+    #         re_run_extract_data,
+    #         re_run_mapping_data,
+    #         force_save_total_data=force_save_total_data,
+    #         calculate_metrics=calculate_metrics,
+    #     )
     # test_data_extraction_metrics()
     # test_mapping_raw_name()
diff --git a/utils/biz_utils.py b/utils/biz_utils.py
index a1d8b7e..c56a668 100644
--- a/utils/biz_utils.py
+++ b/utils/biz_utils.py
@@ -464,6 +464,7 @@ def get_share_feature_from_text(text: str):
             count += 1
     return None
 
+
 def get_currency_from_text(text: str):
     if text is None or len(text.strip()) == 0:
         return None
@@ -479,7 +480,16 @@ def get_currency_from_text(text: str):
             count += 1
     if len(currency_list) > 1:
         # remove the first currency from currency list
-        currency_list.pop(0)
+        if currency_list[0] in ['USD', 'EUR']:
+            currency_list.pop(0)
+        else:
+            remove_currency = None
+            for currency in currency_list:
+                if currency in ['USD', 'EUR']:
+                    remove_currency = currency
+                    break
+            if remove_currency is not None:
+                currency_list.remove(remove_currency)
         return currency_list[0]
     elif len(currency_list) == 1:
         return currency_list[0]
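The hunk above implements points 1 and 2 of the commit message: when a name yields more than one currency token (e.g. "CHF EUR"), the scenario's default currency (USD or EUR) is the one discarded, so the share-class currency wins instead of always dropping the first hit. Below is a self-contained sketch of just that selection rule, separate from the real helper (which also scans the text against total_currency_list to build the token list):

```python
# Sketch of the selection rule only; pick_share_class_currency is a
# hypothetical name, not the real get_currency_from_text.
DEFAULT_CURRENCIES = ['USD', 'EUR']

def pick_share_class_currency(currency_list):
    """Pick the share-class currency from the currency tokens found in a name."""
    if not currency_list:
        return None
    if len(currency_list) == 1:
        return currency_list[0]
    picked = list(currency_list)
    if picked[0] in DEFAULT_CURRENCIES:
        # e.g. ['EUR', 'CHF'] -> drop the leading default currency
        picked.pop(0)
    else:
        # e.g. ['CHF', 'EUR'] -> drop the first default currency found
        for currency in picked:
            if currency in DEFAULT_CURRENCIES:
                picked.remove(currency)
                break
    return picked[0]

assert pick_share_class_currency(['CHF', 'EUR']) == 'CHF'
assert pick_share_class_currency(['EUR', 'CHF']) == 'CHF'
assert pick_share_class_currency(['USD']) == 'USD'
```

The update_for_currency hunk further down reuses this idea: the other currency token found next to the share-class currency becomes `default_currency` and is the one stripped from the comparison names.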
@@ -563,19 +573,26 @@ def update_for_currency(text: str, share_name: str, compare_list: list):
     else:
         # return text, share_name, compare_list
         pass
+    default_currency = 'USD'
     if with_currency:
         share_name_split = share_name.split()
-        share_name_currency_list = []
-        for split in share_name_split:
-            if split.upper() in total_currency_list and split.upper() not in share_name_currency_list:
-                share_name_currency_list.append(split)
-        if len(share_name_currency_list) > 1 and 'USD' in share_name_currency_list:
-            new_share_name = ' '.join([split for split in share_name_split if split.upper() != 'USD'])
+        share_name_currency = get_currency_from_text(share_name)
+        if share_name_currency is not None and share_name_currency in total_currency_list:
+            for split in share_name_split:
+                if split in total_currency_list and split != share_name_currency:
+                    default_currency = split
+                    break
+        new_share_name = ' '.join([split for split in share_name_split
+                                   if split not in total_currency_list
+                                   or (split == share_name_currency)])
         if share_name in text:
             text = text.replace(share_name, new_share_name)
         else:
-            text = ' '.join([split for split in text.split() if split.upper() != 'USD'])
+            text = ' '.join([split for split in text.split()
+                             if split not in total_currency_list
+                             or (split == share_name_currency)])
         share_name = new_share_name
+
     for c_i in range(len(compare_list)):
         compare = compare_list[c_i]
         compare_share_part = get_share_part_list([compare])[0]
@@ -584,8 +601,8 @@
         for split in compare_share_part_split:
             if split.upper() in total_currency_list and split.upper() not in compare_share_part_currency_list:
                 compare_share_part_currency_list.append(split)
-        if len(compare_share_part_currency_list) > 1 and 'USD' in compare_share_part_currency_list:
-            compare_share_part_split = [split for split in compare_share_part_split if split.upper() != 'USD']
+        if len(compare_share_part_currency_list) > 1 and default_currency in compare_share_part_currency_list:
+            compare_share_part_split = [split for split in compare_share_part_split if split.upper() != default_currency]
         new_compare_share_part = ' '.join(compare_share_part_split)
         compare_list[c_i] = compare.replace(compare_share_part, new_compare_share_part)
         return text, share_name, compare_list
@@ -672,7 +689,7 @@ def split_words_without_space(text: str):
 
 
 def remove_special_characters(text):
-    text = re.sub(r'[^a-zA-Z0-9\s]', ' ', text)
+    text = re.sub(r'\W', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text
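On the last hunk (point 3 of the commit message): `\W` matches any non-word character, so underscores are now kept (they are word characters) and, under Python's default Unicode semantics, accented letters in fund names survive as well, whereas `[^a-zA-Z0-9\s]` blanked both out. A quick before/after comparison; the sample string is made up purely for illustration:

```python
import re

def remove_special_characters_old(text):
    # previous behaviour: only ASCII letters, digits and whitespace survive
    text = re.sub(r'[^a-zA-Z0-9\s]', ' ', text)
    return re.sub(r'\s+', ' ', text).strip()

def remove_special_characters_new(text):
    # patched behaviour: \W keeps underscores and Unicode word characters
    text = re.sub(r'\W', ' ', text)
    return re.sub(r'\s+', ' ', text).strip()

sample = "Global Équity Fund (USD) - Class A_acc"
print(remove_special_characters_old(sample))  # Global quity Fund USD Class A acc
print(remove_special_characters_new(sample))  # Global Équity Fund USD Class A_acc
```

The punctuation (parentheses, hyphen) is still collapsed to single spaces in both versions; only the treatment of non-ASCII letters and underscores changes.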