a little change

This commit is contained in:
Blade He 2024-11-18 16:13:24 -06:00
parent a42c0b5c2b
commit 8223ca9a5c
2 changed files with 96 additions and 17 deletions

2
.gitignore vendored
View File

@ -3,3 +3,5 @@
/utils/__pycache__ /utils/__pycache__
/__pycache__/*.pyc /__pycache__/*.pyc
/core/__pycache__/*.pyc /core/__pycache__/*.pyc
/test_calc_metrics.py
/test_metrics

111
main.py
View File

@ -714,7 +714,84 @@ def test_replace_abbrevation():
logger.info(f"Original text: {text}, replaced text: {result}") logger.info(f"Original text: {text}, replaced text: {result}")
def test_calculate_metrics():
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx"
mapping_file = r"/data/emea_ar/basic_information/English\sample_doc/emea_doc_with_all_4_dp/doc_ar_data_with_all_4_dp.xlsx"
data_df = pd.read_excel(data_file, sheet_name="data_in_doc_mapping")
data_df = data_df[data_df["check"].isin([0, 1])]
data_df.fillna("", inplace=True)
data_df.reset_index(drop=True, inplace=True)
mapping_df = pd.read_excel(mapping_file, sheet_name="doc_ar_data_in_db")
mapping_fund_id = mapping_df["FundId"].unique().tolist()
mapping_share_id = mapping_df["FundClassId"].unique().tolist()
mapping_id_list = mapping_fund_id + mapping_share_id
# filter data_df whether investment_id in mapping_id_list
filter_data_df = data_df[(data_df["investment_id"].isin(mapping_id_list)) | (data_df["investment_id"] == "")]
# Investment mapping data
mapping_metrics = get_sub_metrics(filter_data_df, "investment_mapping")
logger.info(f"Investment mapping metrics: {mapping_metrics}")
# tor data
tor_data_df = data_df[data_df["datapoint"] == "tor"]
tor_metrics = get_sub_metrics(tor_data_df, "tor")
logger.info(f"TOR metrics: {tor_metrics}")
# ter data
ter_data_df = data_df[data_df["datapoint"] == "ter"]
ter_metrics = get_sub_metrics(ter_data_df, "ter")
logger.info(f"TER metrics: {ter_metrics}")
# ogc data
ogc_data_df = data_df[data_df["datapoint"] == "ogc"]
ogc_metrics = get_sub_metrics(ogc_data_df, "ogc")
logger.info(f"OGC metrics: {ogc_metrics}")
# performance_fee data
performance_fee_data_df = data_df[data_df["datapoint"] == "performance_fee"]
performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee")
logger.info(f"Performance fee metrics: {performance_fee_metrics}")
metrics_df = pd.DataFrame([mapping_metrics, tor_metrics, ter_metrics, ogc_metrics, performance_fee_metrics])
metrics_df.reset_index(drop=True, inplace=True)
output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/"
output_metrics_file = os.path.join(output_folder,
r"mapping_data_info_30_documents_all_4_datapoints_roughly_metrics.xlsx")
with pd.ExcelWriter(output_metrics_file) as writer:
metrics_df.to_excel(writer, index=False, sheet_name="metrics")
def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
gt_list = [1] * len(data_df)
pre_list = data_df["check"].tolist()
# convert pre_list member to be integer
pre_list = [int(pre) for pre in pre_list]
zero_pre_list = [int(pre) for pre in pre_list if int(pre) == 0]
gt_list += zero_pre_list
pre_list += [1] * len(zero_pre_list)
# calculate metrics
accuracy = accuracy_score(gt_list, pre_list)
precision = precision_score(gt_list, pre_list)
recall = recall_score(gt_list, pre_list)
f1 = f1_score(gt_list, pre_list)
support = len(data_df)
metrics = {
"DataPoint": data_point,
"F1": f1,
"Precision": precision,
"Recall": recall,
"Accuracy": accuracy,
"Support": support
}
return metrics
if __name__ == "__main__": if __name__ == "__main__":
test_calculate_metrics()
# test_replace_abbrevation() # test_replace_abbrevation()
# test_translate_pdf() # test_translate_pdf()
pdf_folder = r"/data/emea_ar/pdf/" pdf_folder = r"/data/emea_ar/pdf/"
@ -1110,32 +1187,32 @@ if __name__ == "__main__":
"546046730", "546046730",
"546919329" "546919329"
] ]
special_doc_id_list = ["506326520"] # special_doc_id_list = ["507928179"]
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/" output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/" output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
re_run_extract_data = False re_run_extract_data = False
re_run_mapping_data = True re_run_mapping_data = False
force_save_total_data = False force_save_total_data = False
calculate_metrics = False calculate_metrics = False
extract_ways = ["text"] extract_ways = ["text"]
# pdf_folder = r"/data/emea_ar/small_pdf/" # pdf_folder = r"/data/emea_ar/small_pdf/"
pdf_folder = r"/data/emea_ar/pdf/" pdf_folder = r"/data/emea_ar/pdf/"
for extract_way in extract_ways: # for extract_way in extract_ways:
batch_start_job( # batch_start_job(
pdf_folder, # pdf_folder,
page_filter_ground_truth_file, # page_filter_ground_truth_file,
output_extract_data_child_folder, # output_extract_data_child_folder,
output_mapping_child_folder, # output_mapping_child_folder,
output_extract_data_total_folder, # output_extract_data_total_folder,
output_mapping_total_folder, # output_mapping_total_folder,
extract_way, # extract_way,
special_doc_id_list, # special_doc_id_list,
re_run_extract_data, # re_run_extract_data,
re_run_mapping_data, # re_run_mapping_data,
force_save_total_data=force_save_total_data, # force_save_total_data=force_save_total_data,
calculate_metrics=calculate_metrics, # calculate_metrics=calculate_metrics,
) # )
# test_data_extraction_metrics() # test_data_extraction_metrics()
# test_mapping_raw_name() # test_mapping_raw_name()