Optimize investment mapping algorithm.
1. Get the proper currency when multiple currencies exist in the share name, e.g. CHF EUR. 2. The default currency should be based on the scenario: USD or EUR. 3. Removing special chars should be based on \W instead of [^a-zA-Z0-9\s].
This commit is contained in:
parent
5b9f9416de
commit
f1c0290588
93
main.py
93
main.py
|
|
@ -717,7 +717,7 @@ def test_replace_abbrevation():
|
||||||
def test_calculate_metrics():
|
def test_calculate_metrics():
|
||||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
||||||
data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
|
data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
|
||||||
mapping_file = r"/data/emea_ar/basic_information/English\sample_doc/emea_doc_with_all_4_dp/doc_ar_data_with_all_4_dp.xlsx"
|
mapping_file = r"/data/emea_ar/basic_information/English/sample_doc/emea_doc_with_all_4_dp/doc_ar_data_with_all_4_dp.xlsx"
|
||||||
|
|
||||||
data_df = pd.read_excel(data_file, sheet_name="data_in_doc_mapping")
|
data_df = pd.read_excel(data_file, sheet_name="data_in_doc_mapping")
|
||||||
data_df = data_df[data_df["check"].isin([0, 1])]
|
data_df = data_df[data_df["check"].isin([0, 1])]
|
||||||
|
|
@ -736,22 +736,22 @@ def test_calculate_metrics():
|
||||||
logger.info(f"Investment mapping metrics: {mapping_metrics}")
|
logger.info(f"Investment mapping metrics: {mapping_metrics}")
|
||||||
|
|
||||||
# tor data
|
# tor data
|
||||||
tor_data_df = data_df[data_df["datapoint"] == "tor"]
|
tor_data_df = filter_data_df[filter_data_df["datapoint"] == "tor"]
|
||||||
tor_metrics = get_sub_metrics(tor_data_df, "tor")
|
tor_metrics = get_sub_metrics(tor_data_df, "tor")
|
||||||
logger.info(f"TOR metrics: {tor_metrics}")
|
logger.info(f"TOR metrics: {tor_metrics}")
|
||||||
|
|
||||||
# ter data
|
# ter data
|
||||||
ter_data_df = data_df[data_df["datapoint"] == "ter"]
|
ter_data_df = filter_data_df[filter_data_df["datapoint"] == "ter"]
|
||||||
ter_metrics = get_sub_metrics(ter_data_df, "ter")
|
ter_metrics = get_sub_metrics(ter_data_df, "ter")
|
||||||
logger.info(f"TER metrics: {ter_metrics}")
|
logger.info(f"TER metrics: {ter_metrics}")
|
||||||
|
|
||||||
# ogc data
|
# ogc data
|
||||||
ogc_data_df = data_df[data_df["datapoint"] == "ogc"]
|
ogc_data_df = filter_data_df[filter_data_df["datapoint"] == "ogc"]
|
||||||
ogc_metrics = get_sub_metrics(ogc_data_df, "ogc")
|
ogc_metrics = get_sub_metrics(ogc_data_df, "ogc")
|
||||||
logger.info(f"OGC metrics: {ogc_metrics}")
|
logger.info(f"OGC metrics: {ogc_metrics}")
|
||||||
|
|
||||||
# performance_fee data
|
# performance_fee data
|
||||||
performance_fee_data_df = data_df[data_df["datapoint"] == "performance_fee"]
|
performance_fee_data_df = filter_data_df[filter_data_df["datapoint"] == "performance_fee"]
|
||||||
performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee")
|
performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee")
|
||||||
logger.info(f"Performance fee metrics: {performance_fee_metrics}")
|
logger.info(f"Performance fee metrics: {performance_fee_metrics}")
|
||||||
|
|
||||||
|
|
@ -770,9 +770,11 @@ def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict:
|
||||||
pre_list = data_df["check"].tolist()
|
pre_list = data_df["check"].tolist()
|
||||||
# convert pre_list member to be integer
|
# convert pre_list member to be integer
|
||||||
pre_list = [int(pre) for pre in pre_list]
|
pre_list = [int(pre) for pre in pre_list]
|
||||||
zero_pre_list = [int(pre) for pre in pre_list if int(pre) == 0]
|
|
||||||
gt_list += zero_pre_list
|
for index, row in data_df.iterrows():
|
||||||
pre_list += [1] * len(zero_pre_list)
|
if row["check"] == 0 and len(row["investment_id"].strip()) > 0:
|
||||||
|
pre_list.append(1)
|
||||||
|
gt_list.append(0)
|
||||||
# calculate metrics
|
# calculate metrics
|
||||||
accuracy = accuracy_score(gt_list, pre_list)
|
accuracy = accuracy_score(gt_list, pre_list)
|
||||||
precision = precision_score(gt_list, pre_list)
|
precision = precision_score(gt_list, pre_list)
|
||||||
|
|
@ -790,8 +792,45 @@ def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict:
|
||||||
}
|
}
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
def replace_rerun_data(new_data_file: str, original_data_file: str):
|
||||||
|
data_in_doc_mapping_sheet = "data_in_doc_mapping"
|
||||||
|
total_mapping_data_sheet = "total_mapping_data"
|
||||||
|
extract_data_sheet = "extract_data"
|
||||||
|
|
||||||
|
new_data_in_doc_mapping = pd.read_excel(new_data_file, sheet_name=data_in_doc_mapping_sheet)
|
||||||
|
new_total_mapping_data = pd.read_excel(new_data_file, sheet_name=total_mapping_data_sheet)
|
||||||
|
new_extract_data = pd.read_excel(new_data_file, sheet_name=extract_data_sheet)
|
||||||
|
|
||||||
|
document_list = new_data_in_doc_mapping["doc_id"].unique().tolist()
|
||||||
|
|
||||||
|
original_data_in_doc_mapping = pd.read_excel(original_data_file, sheet_name=data_in_doc_mapping_sheet)
|
||||||
|
original_total_mapping_data = pd.read_excel(original_data_file, sheet_name=total_mapping_data_sheet)
|
||||||
|
original_extract_data = pd.read_excel(original_data_file, sheet_name=extract_data_sheet)
|
||||||
|
|
||||||
|
# remove data in original data by document_list
|
||||||
|
original_data_in_doc_mapping = original_data_in_doc_mapping[~original_data_in_doc_mapping["doc_id"].isin(document_list)]
|
||||||
|
original_total_mapping_data = original_total_mapping_data[~original_total_mapping_data["doc_id"].isin(document_list)]
|
||||||
|
original_extract_data = original_extract_data[~original_extract_data["doc_id"].isin(document_list)]
|
||||||
|
|
||||||
|
# merge new data to original data
|
||||||
|
new_data_in_doc_mapping = pd.concat([original_data_in_doc_mapping, new_data_in_doc_mapping])
|
||||||
|
new_data_in_doc_mapping.reset_index(drop=True, inplace=True)
|
||||||
|
new_total_mapping_data = pd.concat([original_total_mapping_data, new_total_mapping_data])
|
||||||
|
new_total_mapping_data.reset_index(drop=True, inplace=True)
|
||||||
|
new_extract_data = pd.concat([original_extract_data, new_extract_data])
|
||||||
|
new_extract_data.reset_index(drop=True, inplace=True)
|
||||||
|
|
||||||
|
with pd.ExcelWriter(original_data_file) as writer:
|
||||||
|
new_data_in_doc_mapping.to_excel(writer, index=False, sheet_name=data_in_doc_mapping_sheet)
|
||||||
|
new_total_mapping_data.to_excel(writer, index=False, sheet_name=total_mapping_data_sheet)
|
||||||
|
new_extract_data.to_excel(writer, index=False, sheet_name=extract_data_sheet)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# test_calculate_metrics()
|
# new_data_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_15_documents_by_text_20241121154243.xlsx"
|
||||||
|
# original_data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
|
||||||
|
# replace_rerun_data(new_data_file, original_data_file)
|
||||||
|
test_calculate_metrics()
|
||||||
# test_replace_abbrevation()
|
# test_replace_abbrevation()
|
||||||
# test_translate_pdf()
|
# test_translate_pdf()
|
||||||
pdf_folder = r"/data/emea_ar/pdf/"
|
pdf_folder = r"/data/emea_ar/pdf/"
|
||||||
|
|
@ -1203,32 +1242,32 @@ if __name__ == "__main__":
|
||||||
"501380497",
|
"501380497",
|
||||||
"514636959",
|
"514636959",
|
||||||
"508981020"]
|
"508981020"]
|
||||||
special_doc_id_list = ["514636953"]
|
# special_doc_id_list = ["514636952"]
|
||||||
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
|
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
|
||||||
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
|
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
|
||||||
re_run_extract_data = False
|
re_run_extract_data = False
|
||||||
re_run_mapping_data = True
|
re_run_mapping_data = False
|
||||||
force_save_total_data = False
|
force_save_total_data = True
|
||||||
calculate_metrics = False
|
calculate_metrics = False
|
||||||
|
|
||||||
extract_ways = ["text"]
|
extract_ways = ["text"]
|
||||||
# pdf_folder = r"/data/emea_ar/small_pdf/"
|
# pdf_folder = r"/data/emea_ar/small_pdf/"
|
||||||
pdf_folder = r"/data/emea_ar/pdf/"
|
pdf_folder = r"/data/emea_ar/pdf/"
|
||||||
for extract_way in extract_ways:
|
# for extract_way in extract_ways:
|
||||||
batch_start_job(
|
# batch_start_job(
|
||||||
pdf_folder,
|
# pdf_folder,
|
||||||
page_filter_ground_truth_file,
|
# page_filter_ground_truth_file,
|
||||||
output_extract_data_child_folder,
|
# output_extract_data_child_folder,
|
||||||
output_mapping_child_folder,
|
# output_mapping_child_folder,
|
||||||
output_extract_data_total_folder,
|
# output_extract_data_total_folder,
|
||||||
output_mapping_total_folder,
|
# output_mapping_total_folder,
|
||||||
extract_way,
|
# extract_way,
|
||||||
special_doc_id_list,
|
# special_doc_id_list,
|
||||||
re_run_extract_data,
|
# re_run_extract_data,
|
||||||
re_run_mapping_data,
|
# re_run_mapping_data,
|
||||||
force_save_total_data=force_save_total_data,
|
# force_save_total_data=force_save_total_data,
|
||||||
calculate_metrics=calculate_metrics,
|
# calculate_metrics=calculate_metrics,
|
||||||
)
|
# )
|
||||||
|
|
||||||
# test_data_extraction_metrics()
|
# test_data_extraction_metrics()
|
||||||
# test_mapping_raw_name()
|
# test_mapping_raw_name()
|
||||||
|
|
|
||||||
|
|
@ -464,6 +464,7 @@ def get_share_feature_from_text(text: str):
|
||||||
count += 1
|
count += 1
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_currency_from_text(text: str):
|
def get_currency_from_text(text: str):
|
||||||
if text is None or len(text.strip()) == 0:
|
if text is None or len(text.strip()) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
@ -479,7 +480,16 @@ def get_currency_from_text(text: str):
|
||||||
count += 1
|
count += 1
|
||||||
if len(currency_list) > 1:
|
if len(currency_list) > 1:
|
||||||
# remove the first currency from currency list
|
# remove the first currency from currency list
|
||||||
currency_list.pop(0)
|
if currency_list[0] in ['USD', 'EUR']:
|
||||||
|
currency_list.pop(0)
|
||||||
|
else:
|
||||||
|
remove_currency = None
|
||||||
|
for currency in currency_list:
|
||||||
|
if currency in ['USD', 'EUR']:
|
||||||
|
remove_currency = currency
|
||||||
|
break
|
||||||
|
if remove_currency is not None:
|
||||||
|
currency_list.remove(remove_currency)
|
||||||
return currency_list[0]
|
return currency_list[0]
|
||||||
elif len(currency_list) == 1:
|
elif len(currency_list) == 1:
|
||||||
return currency_list[0]
|
return currency_list[0]
|
||||||
|
|
@ -563,19 +573,26 @@ def update_for_currency(text: str, share_name: str, compare_list: list):
|
||||||
else:
|
else:
|
||||||
# return text, share_name, compare_list
|
# return text, share_name, compare_list
|
||||||
pass
|
pass
|
||||||
|
default_currency = 'USD'
|
||||||
if with_currency:
|
if with_currency:
|
||||||
share_name_split = share_name.split()
|
share_name_split = share_name.split()
|
||||||
share_name_currency_list = []
|
share_name_currency = get_currency_from_text(share_name)
|
||||||
for split in share_name_split:
|
if share_name_currency is not None and share_name_currency in total_currency_list:
|
||||||
if split.upper() in total_currency_list and split.upper() not in share_name_currency_list:
|
for split in share_name_split:
|
||||||
share_name_currency_list.append(split)
|
if split in total_currency_list and split != share_name_currency:
|
||||||
if len(share_name_currency_list) > 1 and 'USD' in share_name_currency_list:
|
default_currency = split
|
||||||
new_share_name = ' '.join([split for split in share_name_split if split.upper() != 'USD'])
|
break
|
||||||
|
new_share_name = ' '.join([split for split in share_name_split
|
||||||
|
if split not in total_currency_list
|
||||||
|
or (split == share_name_currency)])
|
||||||
if share_name in text:
|
if share_name in text:
|
||||||
text = text.replace(share_name, new_share_name)
|
text = text.replace(share_name, new_share_name)
|
||||||
else:
|
else:
|
||||||
text = ' '.join([split for split in text.split() if split.upper() != 'USD'])
|
text = ' '.join([split for split in text.split()
|
||||||
|
if split not in total_currency_list
|
||||||
|
or (split == share_name_currency)])
|
||||||
share_name = new_share_name
|
share_name = new_share_name
|
||||||
|
|
||||||
for c_i in range(len(compare_list)):
|
for c_i in range(len(compare_list)):
|
||||||
compare = compare_list[c_i]
|
compare = compare_list[c_i]
|
||||||
compare_share_part = get_share_part_list([compare])[0]
|
compare_share_part = get_share_part_list([compare])[0]
|
||||||
|
|
@ -584,8 +601,8 @@ def update_for_currency(text: str, share_name: str, compare_list: list):
|
||||||
for split in compare_share_part_split:
|
for split in compare_share_part_split:
|
||||||
if split.upper() in total_currency_list and split.upper() not in compare_share_part_currency_list:
|
if split.upper() in total_currency_list and split.upper() not in compare_share_part_currency_list:
|
||||||
compare_share_part_currency_list.append(split)
|
compare_share_part_currency_list.append(split)
|
||||||
if len(compare_share_part_currency_list) > 1 and 'USD' in compare_share_part_currency_list:
|
if len(compare_share_part_currency_list) > 1 and default_currency in compare_share_part_currency_list:
|
||||||
compare_share_part_split = [split for split in compare_share_part_split if split.upper() != 'USD']
|
compare_share_part_split = [split for split in compare_share_part_split if split.upper() != default_currency]
|
||||||
new_compare_share_part = ' '.join(compare_share_part_split)
|
new_compare_share_part = ' '.join(compare_share_part_split)
|
||||||
compare_list[c_i] = compare.replace(compare_share_part, new_compare_share_part)
|
compare_list[c_i] = compare.replace(compare_share_part, new_compare_share_part)
|
||||||
return text, share_name, compare_list
|
return text, share_name, compare_list
|
||||||
|
|
@ -672,7 +689,7 @@ def split_words_without_space(text: str):
|
||||||
|
|
||||||
|
|
||||||
def remove_special_characters(text):
|
def remove_special_characters(text):
|
||||||
text = re.sub(r'[^a-zA-Z0-9\s]', ' ', text)
|
text = re.sub(r'\W', ' ', text)
|
||||||
text = re.sub(r'\s+', ' ', text)
|
text = re.sub(r'\s+', ' ', text)
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
return text
|
return text
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue