Optimize investment mapping algorithm.
1. Pick the correct currency when a share name contains multiple currencies, e.g. CHF EUR. 2. The default currency should be based on the scenario: USD or EUR. 3. Removing special characters should be based on \W instead of [^a-zA-Z0-9\s].
This commit is contained in:
parent
5b9f9416de
commit
f1c0290588
93
main.py
93
main.py
|
|
@ -717,7 +717,7 @@ def test_replace_abbrevation():
|
|||
def test_calculate_metrics():
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
||||
data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
|
||||
mapping_file = r"/data/emea_ar/basic_information/English\sample_doc/emea_doc_with_all_4_dp/doc_ar_data_with_all_4_dp.xlsx"
|
||||
mapping_file = r"/data/emea_ar/basic_information/English/sample_doc/emea_doc_with_all_4_dp/doc_ar_data_with_all_4_dp.xlsx"
|
||||
|
||||
data_df = pd.read_excel(data_file, sheet_name="data_in_doc_mapping")
|
||||
data_df = data_df[data_df["check"].isin([0, 1])]
|
||||
|
|
@ -736,22 +736,22 @@ def test_calculate_metrics():
|
|||
logger.info(f"Investment mapping metrics: {mapping_metrics}")
|
||||
|
||||
# tor data
|
||||
tor_data_df = data_df[data_df["datapoint"] == "tor"]
|
||||
tor_data_df = filter_data_df[filter_data_df["datapoint"] == "tor"]
|
||||
tor_metrics = get_sub_metrics(tor_data_df, "tor")
|
||||
logger.info(f"TOR metrics: {tor_metrics}")
|
||||
|
||||
# ter data
|
||||
ter_data_df = data_df[data_df["datapoint"] == "ter"]
|
||||
ter_data_df = filter_data_df[filter_data_df["datapoint"] == "ter"]
|
||||
ter_metrics = get_sub_metrics(ter_data_df, "ter")
|
||||
logger.info(f"TER metrics: {ter_metrics}")
|
||||
|
||||
# ogc data
|
||||
ogc_data_df = data_df[data_df["datapoint"] == "ogc"]
|
||||
ogc_data_df = filter_data_df[filter_data_df["datapoint"] == "ogc"]
|
||||
ogc_metrics = get_sub_metrics(ogc_data_df, "ogc")
|
||||
logger.info(f"OGC metrics: {ogc_metrics}")
|
||||
|
||||
# performance_fee data
|
||||
performance_fee_data_df = data_df[data_df["datapoint"] == "performance_fee"]
|
||||
performance_fee_data_df = filter_data_df[filter_data_df["datapoint"] == "performance_fee"]
|
||||
performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee")
|
||||
logger.info(f"Performance fee metrics: {performance_fee_metrics}")
|
||||
|
||||
|
|
@ -770,9 +770,11 @@ def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict:
|
|||
pre_list = data_df["check"].tolist()
|
||||
# convert pre_list member to be integer
|
||||
pre_list = [int(pre) for pre in pre_list]
|
||||
zero_pre_list = [int(pre) for pre in pre_list if int(pre) == 0]
|
||||
gt_list += zero_pre_list
|
||||
pre_list += [1] * len(zero_pre_list)
|
||||
|
||||
for index, row in data_df.iterrows():
|
||||
if row["check"] == 0 and len(row["investment_id"].strip()) > 0:
|
||||
pre_list.append(1)
|
||||
gt_list.append(0)
|
||||
# calculate metrics
|
||||
accuracy = accuracy_score(gt_list, pre_list)
|
||||
precision = precision_score(gt_list, pre_list)
|
||||
|
|
@ -790,8 +792,45 @@ def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict:
|
|||
}
|
||||
return metrics
|
||||
|
||||
def replace_rerun_data(new_data_file: str, original_data_file: str):
    """Merge re-run results into the original workbook, replacing stale rows.

    For every sheet, rows in the original workbook whose ``doc_id`` appears in
    the new workbook's ``data_in_doc_mapping`` sheet are dropped and replaced
    by the corresponding rows from the new workbook. The original workbook is
    overwritten in place.

    Args:
        new_data_file: Path to the Excel file containing the re-run data.
        original_data_file: Path to the Excel file to update in place.
    """
    # The three sheets share identical replace logic, so process them uniformly.
    sheet_names = [
        "data_in_doc_mapping",
        "total_mapping_data",
        "extract_data",
    ]

    new_sheets = {
        name: pd.read_excel(new_data_file, sheet_name=name) for name in sheet_names
    }

    # Documents present in the re-run output; their old rows must be removed.
    document_list = new_sheets["data_in_doc_mapping"]["doc_id"].unique().tolist()

    merged_sheets = {}
    for name in sheet_names:
        original_df = pd.read_excel(original_data_file, sheet_name=name)
        # Drop stale rows for re-run documents, then append the fresh rows.
        original_df = original_df[~original_df["doc_id"].isin(document_list)]
        # ignore_index=True replaces the concat + reset_index(drop=True) pair.
        merged_sheets[name] = pd.concat([original_df, new_sheets[name]], ignore_index=True)

    with pd.ExcelWriter(original_data_file) as writer:
        for name, merged_df in merged_sheets.items():
            merged_df.to_excel(writer, index=False, sheet_name=name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_calculate_metrics()
|
||||
# new_data_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_15_documents_by_text_20241121154243.xlsx"
|
||||
# original_data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
|
||||
# replace_rerun_data(new_data_file, original_data_file)
|
||||
test_calculate_metrics()
|
||||
# test_replace_abbrevation()
|
||||
# test_translate_pdf()
|
||||
pdf_folder = r"/data/emea_ar/pdf/"
|
||||
|
|
@ -1203,32 +1242,32 @@ if __name__ == "__main__":
|
|||
"501380497",
|
||||
"514636959",
|
||||
"508981020"]
|
||||
special_doc_id_list = ["514636953"]
|
||||
# special_doc_id_list = ["514636952"]
|
||||
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
|
||||
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
|
||||
re_run_extract_data = False
|
||||
re_run_mapping_data = True
|
||||
force_save_total_data = False
|
||||
re_run_mapping_data = False
|
||||
force_save_total_data = True
|
||||
calculate_metrics = False
|
||||
|
||||
extract_ways = ["text"]
|
||||
# pdf_folder = r"/data/emea_ar/small_pdf/"
|
||||
pdf_folder = r"/data/emea_ar/pdf/"
|
||||
for extract_way in extract_ways:
|
||||
batch_start_job(
|
||||
pdf_folder,
|
||||
page_filter_ground_truth_file,
|
||||
output_extract_data_child_folder,
|
||||
output_mapping_child_folder,
|
||||
output_extract_data_total_folder,
|
||||
output_mapping_total_folder,
|
||||
extract_way,
|
||||
special_doc_id_list,
|
||||
re_run_extract_data,
|
||||
re_run_mapping_data,
|
||||
force_save_total_data=force_save_total_data,
|
||||
calculate_metrics=calculate_metrics,
|
||||
)
|
||||
# for extract_way in extract_ways:
|
||||
# batch_start_job(
|
||||
# pdf_folder,
|
||||
# page_filter_ground_truth_file,
|
||||
# output_extract_data_child_folder,
|
||||
# output_mapping_child_folder,
|
||||
# output_extract_data_total_folder,
|
||||
# output_mapping_total_folder,
|
||||
# extract_way,
|
||||
# special_doc_id_list,
|
||||
# re_run_extract_data,
|
||||
# re_run_mapping_data,
|
||||
# force_save_total_data=force_save_total_data,
|
||||
# calculate_metrics=calculate_metrics,
|
||||
# )
|
||||
|
||||
# test_data_extraction_metrics()
|
||||
# test_mapping_raw_name()
|
||||
|
|
|
|||
|
|
@ -464,6 +464,7 @@ def get_share_feature_from_text(text: str):
|
|||
count += 1
|
||||
return None
|
||||
|
||||
|
||||
def get_currency_from_text(text: str):
|
||||
if text is None or len(text.strip()) == 0:
|
||||
return None
|
||||
|
|
@ -479,7 +480,16 @@ def get_currency_from_text(text: str):
|
|||
count += 1
|
||||
if len(currency_list) > 1:
|
||||
# remove the first currency from currency list
|
||||
currency_list.pop(0)
|
||||
if currency_list[0] in ['USD', 'EUR']:
|
||||
currency_list.pop(0)
|
||||
else:
|
||||
remove_currency = None
|
||||
for currency in currency_list:
|
||||
if currency in ['USD', 'EUR']:
|
||||
remove_currency = currency
|
||||
break
|
||||
if remove_currency is not None:
|
||||
currency_list.remove(remove_currency)
|
||||
return currency_list[0]
|
||||
elif len(currency_list) == 1:
|
||||
return currency_list[0]
|
||||
|
|
@ -563,19 +573,26 @@ def update_for_currency(text: str, share_name: str, compare_list: list):
|
|||
else:
|
||||
# return text, share_name, compare_list
|
||||
pass
|
||||
default_currency = 'USD'
|
||||
if with_currency:
|
||||
share_name_split = share_name.split()
|
||||
share_name_currency_list = []
|
||||
for split in share_name_split:
|
||||
if split.upper() in total_currency_list and split.upper() not in share_name_currency_list:
|
||||
share_name_currency_list.append(split)
|
||||
if len(share_name_currency_list) > 1 and 'USD' in share_name_currency_list:
|
||||
new_share_name = ' '.join([split for split in share_name_split if split.upper() != 'USD'])
|
||||
share_name_currency = get_currency_from_text(share_name)
|
||||
if share_name_currency is not None and share_name_currency in total_currency_list:
|
||||
for split in share_name_split:
|
||||
if split in total_currency_list and split != share_name_currency:
|
||||
default_currency = split
|
||||
break
|
||||
new_share_name = ' '.join([split for split in share_name_split
|
||||
if split not in total_currency_list
|
||||
or (split == share_name_currency)])
|
||||
if share_name in text:
|
||||
text = text.replace(share_name, new_share_name)
|
||||
else:
|
||||
text = ' '.join([split for split in text.split() if split.upper() != 'USD'])
|
||||
text = ' '.join([split for split in text.split()
|
||||
if split not in total_currency_list
|
||||
or (split == share_name_currency)])
|
||||
share_name = new_share_name
|
||||
|
||||
for c_i in range(len(compare_list)):
|
||||
compare = compare_list[c_i]
|
||||
compare_share_part = get_share_part_list([compare])[0]
|
||||
|
|
@ -584,8 +601,8 @@ def update_for_currency(text: str, share_name: str, compare_list: list):
|
|||
for split in compare_share_part_split:
|
||||
if split.upper() in total_currency_list and split.upper() not in compare_share_part_currency_list:
|
||||
compare_share_part_currency_list.append(split)
|
||||
if len(compare_share_part_currency_list) > 1 and 'USD' in compare_share_part_currency_list:
|
||||
compare_share_part_split = [split for split in compare_share_part_split if split.upper() != 'USD']
|
||||
if len(compare_share_part_currency_list) > 1 and default_currency in compare_share_part_currency_list:
|
||||
compare_share_part_split = [split for split in compare_share_part_split if split.upper() != default_currency]
|
||||
new_compare_share_part = ' '.join(compare_share_part_split)
|
||||
compare_list[c_i] = compare.replace(compare_share_part, new_compare_share_part)
|
||||
return text, share_name, compare_list
|
||||
|
|
@ -672,7 +689,7 @@ def split_words_without_space(text: str):
|
|||
|
||||
|
||||
def remove_special_characters(text):
    """Replace non-word characters in *text* with spaces and collapse whitespace.

    Uses ``\W`` (anything outside ``[a-zA-Z0-9_]`` plus Unicode word chars), so
    underscores and accented letters are preserved, per the commit intent of
    switching from ``[^a-zA-Z0-9\s]`` to ``\W``.

    Args:
        text: Input string to clean.

    Returns:
        The cleaned string with single spaces and no leading/trailing whitespace.
    """
    # Map every non-word character (punctuation, symbols) to a space.
    text = re.sub(r'\W', ' ', text)
    # Collapse runs of whitespace produced above into single spaces.
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
|
||||
|
|
|
|||
Loading…
Reference in New Issue