diff --git a/.gitignore b/.gitignore index 34c5c08..972b63d 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ /utils/__pycache__ /__pycache__/*.pyc /core/__pycache__/*.pyc +/test_calc_metrics.py +/test_metrics diff --git a/main.py b/main.py index 2468ee2..e6ca41d 100644 --- a/main.py +++ b/main.py @@ -714,7 +714,84 @@ def test_replace_abbrevation(): logger.info(f"Original text: {text}, replaced text: {result}") +def test_calculate_metrics(): + from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score + data_file = r"/data/emea_ar/ground_truth/data_extraction/verify/mapping_data_info_30_documents_all_4_datapoints_20241106_verify_mapping.xlsx" + mapping_file = r"/data/emea_ar/basic_information/English\sample_doc/emea_doc_with_all_4_dp/doc_ar_data_with_all_4_dp.xlsx" + + data_df = pd.read_excel(data_file, sheet_name="data_in_doc_mapping") + data_df = data_df[data_df["check"].isin([0, 1])] + data_df.fillna("", inplace=True) + data_df.reset_index(drop=True, inplace=True) + + mapping_df = pd.read_excel(mapping_file, sheet_name="doc_ar_data_in_db") + mapping_fund_id = mapping_df["FundId"].unique().tolist() + mapping_share_id = mapping_df["FundClassId"].unique().tolist() + mapping_id_list = mapping_fund_id + mapping_share_id + # filter data_df whether investment_id in mapping_id_list + filter_data_df = data_df[(data_df["investment_id"].isin(mapping_id_list)) | (data_df["investment_id"] == "")] + + # Investment mapping data + mapping_metrics = get_sub_metrics(filter_data_df, "investment_mapping") + logger.info(f"Investment mapping metrics: {mapping_metrics}") + + # tor data + tor_data_df = data_df[data_df["datapoint"] == "tor"] + tor_metrics = get_sub_metrics(tor_data_df, "tor") + logger.info(f"TOR metrics: {tor_metrics}") + + # ter data + ter_data_df = data_df[data_df["datapoint"] == "ter"] + ter_metrics = get_sub_metrics(ter_data_df, "ter") + logger.info(f"TER metrics: {ter_metrics}") + + # ogc data + ogc_data_df = data_df[data_df["datapoint"] == "ogc"] + ogc_metrics = get_sub_metrics(ogc_data_df, "ogc") + logger.info(f"OGC metrics: {ogc_metrics}") + + # performance_fee data + performance_fee_data_df = data_df[data_df["datapoint"] == "performance_fee"] + performance_fee_metrics = get_sub_metrics(performance_fee_data_df, "performance_fee") + logger.info(f"Performance fee metrics: {performance_fee_metrics}") + + metrics_df = pd.DataFrame([mapping_metrics, tor_metrics, ter_metrics, ogc_metrics, performance_fee_metrics]) + metrics_df.reset_index(drop=True, inplace=True) + output_folder = r"/data/emea_ar/ground_truth/data_extraction/verify/" + output_metrics_file = os.path.join(output_folder, + r"mapping_data_info_30_documents_all_4_datapoints_roughly_metrics.xlsx") + with pd.ExcelWriter(output_metrics_file) as writer: + metrics_df.to_excel(writer, index=False, sheet_name="metrics") + + +def get_sub_metrics(data_df: pd.DataFrame, data_point: str) -> dict: + from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score + gt_list = [1] * len(data_df) + pre_list = data_df["check"].tolist() + # convert pre_list member to be integer + pre_list = [int(pre) for pre in pre_list] + zero_pre_list = [int(pre) for pre in pre_list if int(pre) == 0] + gt_list += zero_pre_list + pre_list += [1] * len(zero_pre_list) + # calculate metrics + accuracy = accuracy_score(gt_list, pre_list) + precision = precision_score(gt_list, pre_list) + recall = recall_score(gt_list, pre_list) + f1 = f1_score(gt_list, pre_list) + support = len(data_df) + + metrics = { + "DataPoint": data_point, + "F1": f1, + "Precision": precision, + "Recall": recall, + "Accuracy": accuracy, + "Support": support + } + return metrics + if __name__ == "__main__": + test_calculate_metrics() # test_replace_abbrevation() # test_translate_pdf() pdf_folder = r"/data/emea_ar/pdf/" @@ -1110,32 +1187,32 @@ if __name__ == "__main__": "546046730", "546919329" ] - special_doc_id_list = ["506326520"] + # special_doc_id_list = ["507928179"] output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/" output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/" re_run_extract_data = False - re_run_mapping_data = True + re_run_mapping_data = False force_save_total_data = False calculate_metrics = False extract_ways = ["text"] # pdf_folder = r"/data/emea_ar/small_pdf/" pdf_folder = r"/data/emea_ar/pdf/" - for extract_way in extract_ways: - batch_start_job( - pdf_folder, - page_filter_ground_truth_file, - output_extract_data_child_folder, - output_mapping_child_folder, - output_extract_data_total_folder, - output_mapping_total_folder, - extract_way, - special_doc_id_list, - re_run_extract_data, - re_run_mapping_data, - force_save_total_data=force_save_total_data, - calculate_metrics=calculate_metrics, - ) + # for extract_way in extract_ways: + # batch_start_job( + # pdf_folder, + # page_filter_ground_truth_file, + # output_extract_data_child_folder, + # output_mapping_child_folder, + # output_extract_data_total_folder, + # output_mapping_total_folder, + # extract_way, + # special_doc_id_list, + # re_run_extract_data, + # re_run_mapping_data, + # force_save_total_data=force_save_total_data, + # calculate_metrics=calculate_metrics, + # ) # test_data_extraction_metrics() # test_mapping_raw_name()