Optimize investment mapping algorithm.

1. Get the proper currency when multiple currencies exist in the share name, e.g. CHF EUR
2. The default currency should be chosen based on the scenario: USD or EUR.
3. Removal of special characters should be based on \W instead of [^a-zA-Z0-9\s]
This commit is contained in:
Blade He 2024-11-21 16:36:58 -06:00
parent 5b9f9416de
commit f1c0290588
2 changed files with 94 additions and 38 deletions

93
main.py
View File

@ -717,7 +717,7 @@ def test_replace_abbrevation():
def test_calculate_metrics(): def test_calculate_metrics():
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx" data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
mapping_file = r"/data/emea_ar/basic_information/English\sample_doc/emea_doc_with_all_4_dp/doc_ar_data_with_all_4_dp.xlsx" mapping_file = r"/data/emea_ar/basic_information/English/sample_doc/emea_doc_with_all_4_dp/doc_ar_data_with_all_4_dp.xlsx"
data_df = pd.read_excel(data_file, sheet_name="data_in_doc_mapping") data_df = pd.read_excel(data_file, sheet_name="data_in_doc_mapping")
data_df = data_df[data_df["check"].isin([0, 1])] data_df = data_df[data_df["check"].isin([0, 1])]
@ -736,22 +736,22 @@ def test_calculate_metrics():
logger.info(f"Investment mapping metrics: {mapping_metrics}") logger.info(f"Investment mapping metrics: {mapping_metrics}")
# tor data # tor data
tor_data_df = data_df[data_df["datapoint"] == "tor"] tor_data_df = filter_data_df[filter_data_df["datapoint"] == "tor"]
tor_metrics = get_sub_metrics(tor_data_df, "tor") tor_metrics = get_sub_metrics(tor_data_df, "tor")
logger.info(f"TOR metrics: {tor_metrics}") logger.info(f"TOR metrics: {tor_metrics}")
# ter data # ter data
ter_data_df = data_df[data_df["datapoint"] == "ter"] ter_data_df = filter_data_df[filter_data_df["datapoint"] == "ter"]
ter_metrics = get_sub_metrics(ter_data_df, "ter") ter_metrics = get_sub_metrics(ter_data_df, "ter")
logger.info(f"TER metrics: {ter_metrics}") logger.info(f"TER metrics: {ter_metrics}")
# ogc data # ogc data
ogc_data_df = data_df[data_df["datapoint"] == "ogc"] ogc_data_df = filter_data_df[filter_data_df["datapoint"] == "ogc"]
ogc_metrics = get_sub_metrics(ogc_data_df, "ogc") ogc_metrics = get_sub_metrics(ogc_data_df, "ogc")
logger.info(f"OGC metrics: {ogc_metrics}") logger.info(f"OGC metrics: {ogc_metrics}")
# performance_fee data # performance_fee data
performance_fee_data_df = data_df[data_df["datapoint"] == "performance_fee"] performance_fee_data_df = filter_data_df[filter_data_df["datapoint"] == "performance_fee"]
performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee") performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee")
logger.info(f"Performance fee metrics: {performance_fee_metrics}") logger.info(f"Performance fee metrics: {performance_fee_metrics}")
@ -770,9 +770,11 @@ def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict:
pre_list = data_df["check"].tolist() pre_list = data_df["check"].tolist()
# convert pre_list member to be integer # convert pre_list member to be integer
pre_list = [int(pre) for pre in pre_list] pre_list = [int(pre) for pre in pre_list]
zero_pre_list = [int(pre) for pre in pre_list if int(pre) == 0]
gt_list += zero_pre_list for index, row in data_df.iterrows():
pre_list += [1] * len(zero_pre_list) if row["check"] == 0 and len(row["investment_id"].strip()) > 0:
pre_list.append(1)
gt_list.append(0)
# calculate metrics # calculate metrics
accuracy = accuracy_score(gt_list, pre_list) accuracy = accuracy_score(gt_list, pre_list)
precision = precision_score(gt_list, pre_list) precision = precision_score(gt_list, pre_list)
@ -790,8 +792,45 @@ def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict:
} }
return metrics return metrics
def replace_rerun_data(new_data_file: str, original_data_file: str):
data_in_doc_mapping_sheet = "data_in_doc_mapping"
total_mapping_data_sheet = "total_mapping_data"
extract_data_sheet = "extract_data"
new_data_in_doc_mapping = pd.read_excel(new_data_file, sheet_name=data_in_doc_mapping_sheet)
new_total_mapping_data = pd.read_excel(new_data_file, sheet_name=total_mapping_data_sheet)
new_extract_data = pd.read_excel(new_data_file, sheet_name=extract_data_sheet)
document_list = new_data_in_doc_mapping["doc_id"].unique().tolist()
original_data_in_doc_mapping = pd.read_excel(original_data_file, sheet_name=data_in_doc_mapping_sheet)
original_total_mapping_data = pd.read_excel(original_data_file, sheet_name=total_mapping_data_sheet)
original_extract_data = pd.read_excel(original_data_file, sheet_name=extract_data_sheet)
# remove data in original data by document_list
original_data_in_doc_mapping = original_data_in_doc_mapping[~original_data_in_doc_mapping["doc_id"].isin(document_list)]
original_total_mapping_data = original_total_mapping_data[~original_total_mapping_data["doc_id"].isin(document_list)]
original_extract_data = original_extract_data[~original_extract_data["doc_id"].isin(document_list)]
# merge new data to original data
new_data_in_doc_mapping = pd.concat([original_data_in_doc_mapping, new_data_in_doc_mapping])
new_data_in_doc_mapping.reset_index(drop=True, inplace=True)
new_total_mapping_data = pd.concat([original_total_mapping_data, new_total_mapping_data])
new_total_mapping_data.reset_index(drop=True, inplace=True)
new_extract_data = pd.concat([original_extract_data, new_extract_data])
new_extract_data.reset_index(drop=True, inplace=True)
with pd.ExcelWriter(original_data_file) as writer:
new_data_in_doc_mapping.to_excel(writer, index=False, sheet_name=data_in_doc_mapping_sheet)
new_total_mapping_data.to_excel(writer, index=False, sheet_name=total_mapping_data_sheet)
new_extract_data.to_excel(writer, index=False, sheet_name=extract_data_sheet)
if __name__ == "__main__": if __name__ == "__main__":
# test_calculate_metrics() # new_data_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_15_documents_by_text_20241121154243.xlsx"
# original_data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
# replace_rerun_data(new_data_file, original_data_file)
test_calculate_metrics()
# test_replace_abbrevation() # test_replace_abbrevation()
# test_translate_pdf() # test_translate_pdf()
pdf_folder = r"/data/emea_ar/pdf/" pdf_folder = r"/data/emea_ar/pdf/"
@ -1203,32 +1242,32 @@ if __name__ == "__main__":
"501380497", "501380497",
"514636959", "514636959",
"508981020"] "508981020"]
special_doc_id_list = ["514636953"] # special_doc_id_list = ["514636952"]
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/" output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/" output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
re_run_extract_data = False re_run_extract_data = False
re_run_mapping_data = True re_run_mapping_data = False
force_save_total_data = False force_save_total_data = True
calculate_metrics = False calculate_metrics = False
extract_ways = ["text"] extract_ways = ["text"]
# pdf_folder = r"/data/emea_ar/small_pdf/" # pdf_folder = r"/data/emea_ar/small_pdf/"
pdf_folder = r"/data/emea_ar/pdf/" pdf_folder = r"/data/emea_ar/pdf/"
for extract_way in extract_ways: # for extract_way in extract_ways:
batch_start_job( # batch_start_job(
pdf_folder, # pdf_folder,
page_filter_ground_truth_file, # page_filter_ground_truth_file,
output_extract_data_child_folder, # output_extract_data_child_folder,
output_mapping_child_folder, # output_mapping_child_folder,
output_extract_data_total_folder, # output_extract_data_total_folder,
output_mapping_total_folder, # output_mapping_total_folder,
extract_way, # extract_way,
special_doc_id_list, # special_doc_id_list,
re_run_extract_data, # re_run_extract_data,
re_run_mapping_data, # re_run_mapping_data,
force_save_total_data=force_save_total_data, # force_save_total_data=force_save_total_data,
calculate_metrics=calculate_metrics, # calculate_metrics=calculate_metrics,
) # )
# test_data_extraction_metrics() # test_data_extraction_metrics()
# test_mapping_raw_name() # test_mapping_raw_name()

View File

@ -464,6 +464,7 @@ def get_share_feature_from_text(text: str):
count += 1 count += 1
return None return None
def get_currency_from_text(text: str): def get_currency_from_text(text: str):
if text is None or len(text.strip()) == 0: if text is None or len(text.strip()) == 0:
return None return None
@ -479,7 +480,16 @@ def get_currency_from_text(text: str):
count += 1 count += 1
if len(currency_list) > 1: if len(currency_list) > 1:
# remove the first currency from currency list # remove the first currency from currency list
currency_list.pop(0) if currency_list[0] in ['USD', 'EUR']:
currency_list.pop(0)
else:
remove_currency = None
for currency in currency_list:
if currency in ['USD', 'EUR']:
remove_currency = currency
break
if remove_currency is not None:
currency_list.remove(remove_currency)
return currency_list[0] return currency_list[0]
elif len(currency_list) == 1: elif len(currency_list) == 1:
return currency_list[0] return currency_list[0]
@ -563,19 +573,26 @@ def update_for_currency(text: str, share_name: str, compare_list: list):
else: else:
# return text, share_name, compare_list # return text, share_name, compare_list
pass pass
default_currency = 'USD'
if with_currency: if with_currency:
share_name_split = share_name.split() share_name_split = share_name.split()
share_name_currency_list = [] share_name_currency = get_currency_from_text(share_name)
for split in share_name_split: if share_name_currency is not None and share_name_currency in total_currency_list:
if split.upper() in total_currency_list and split.upper() not in share_name_currency_list: for split in share_name_split:
share_name_currency_list.append(split) if split in total_currency_list and split != share_name_currency:
if len(share_name_currency_list) > 1 and 'USD' in share_name_currency_list: default_currency = split
new_share_name = ' '.join([split for split in share_name_split if split.upper() != 'USD']) break
new_share_name = ' '.join([split for split in share_name_split
if split not in total_currency_list
or (split == share_name_currency)])
if share_name in text: if share_name in text:
text = text.replace(share_name, new_share_name) text = text.replace(share_name, new_share_name)
else: else:
text = ' '.join([split for split in text.split() if split.upper() != 'USD']) text = ' '.join([split for split in text.split()
if split not in total_currency_list
or (split == share_name_currency)])
share_name = new_share_name share_name = new_share_name
for c_i in range(len(compare_list)): for c_i in range(len(compare_list)):
compare = compare_list[c_i] compare = compare_list[c_i]
compare_share_part = get_share_part_list([compare])[0] compare_share_part = get_share_part_list([compare])[0]
@ -584,8 +601,8 @@ def update_for_currency(text: str, share_name: str, compare_list: list):
for split in compare_share_part_split: for split in compare_share_part_split:
if split.upper() in total_currency_list and split.upper() not in compare_share_part_currency_list: if split.upper() in total_currency_list and split.upper() not in compare_share_part_currency_list:
compare_share_part_currency_list.append(split) compare_share_part_currency_list.append(split)
if len(compare_share_part_currency_list) > 1 and 'USD' in compare_share_part_currency_list: if len(compare_share_part_currency_list) > 1 and default_currency in compare_share_part_currency_list:
compare_share_part_split = [split for split in compare_share_part_split if split.upper() != 'USD'] compare_share_part_split = [split for split in compare_share_part_split if split.upper() != default_currency]
new_compare_share_part = ' '.join(compare_share_part_split) new_compare_share_part = ' '.join(compare_share_part_split)
compare_list[c_i] = compare.replace(compare_share_part, new_compare_share_part) compare_list[c_i] = compare.replace(compare_share_part, new_compare_share_part)
return text, share_name, compare_list return text, share_name, compare_list
@ -672,7 +689,7 @@ def split_words_without_space(text: str):
def remove_special_characters(text): def remove_special_characters(text):
text = re.sub(r'[^a-zA-Z0-9\s]', ' ', text) text = re.sub(r'\W', ' ', text)
text = re.sub(r'\s+', ' ', text) text = re.sub(r'\s+', ' ', text)
text = text.strip() text = text.strip()
return text return text