diff --git a/core/data_extraction.py b/core/data_extraction.py
index fa230a8..6817d6b 100644
--- a/core/data_extraction.py
+++ b/core/data_extraction.py
@@ -6,9 +6,9 @@ import fitz
 import pandas as pd
 from utils.gpt_utils import chat
 from utils.pdf_util import PDFUtil
-from utils.sql_query_util import query_document_fund_mapping
+from utils.sql_query_util import query_document_fund_mapping, query_investment_by_provider
 from utils.logger import logger
-from utils.biz_utils import add_slash_to_text_as_regex, clean_text
+from utils.biz_utils import add_slash_to_text_as_regex, clean_text, get_most_similar_name
 
 
 class DataExtraction:
@@ -44,6 +44,13 @@ class DataExtraction:
             self.document_mapping_info_df = query_document_fund_mapping(doc_id)
         else:
             self.document_mapping_info_df = document_mapping_info_df
+        self.provider_mapping_df = self.get_provider_mapping()
+        if len(self.provider_mapping_df) == 0:
+            self.provider_fund_name_list = []
+        else:
+            self.provider_fund_name_list = (
+                self.provider_mapping_df["FundName"].unique().tolist()
+            )
         self.datapoint_page_info = datapoint_page_info
         self.page_nums_with_datapoints = self.get_page_nums_from_datapoint_page_info()
         self.datapoints = datapoints
@@ -53,6 +60,20 @@ class DataExtraction:
         self.extract_way = extract_way
         self.output_image_folder = output_image_folder
 
+    def get_provider_mapping(self):
+        if len(self.document_mapping_info_df) == 0:
+            return pd.DataFrame()
+        provider_id_list = (
+            self.document_mapping_info_df["ProviderId"].unique().tolist()
+        )
+        provider_mapping_list = []
+        for provider_id in provider_id_list:
+            provider_mapping_list.append(query_investment_by_provider(provider_id))
+        provider_mapping_df = pd.concat(provider_mapping_list)
+        provider_mapping_df = provider_mapping_df.drop_duplicates()
+        provider_mapping_df.reset_index(drop=True, inplace=True)
+        return provider_mapping_df
+
     def get_pdf_image_base64(self, page_index: int) -> dict:
         pdf_util = PDFUtil(self.pdf_file)
         return pdf_util.extract_image_from_page(page_index=page_index,
@@ -403,6 +424,9 @@ class DataExtraction:
         split_context = re.split(r"\n", page_text)
         split_context = [text.strip() for text in split_context if len(text.strip()) > 0]
 
+        if len(split_context) < 10:
+            return {"data": []}
+
         split_context_len = len(split_context)
         top_10_context = split_context[:10]
         rest_context = split_context[10:]
@@ -411,12 +435,37 @@ class DataExtraction:
         # the member of half_len should not start with number
         # reverse iterate the list by half_len
         half_len_list = [i for i in range(half_len)]
-        for index in reversed(half_len_list):
-            first_letter = rest_context[index].strip()[0]
-            if not first_letter.isnumeric() and first_letter not in [".", "(", ")", "-"]:
-                half_len = index
-                break
+        fund_name_line = ""
+        half_line = rest_context[half_len].strip()
+        max_similarity_fund_name, max_similarity = get_most_similar_name(
+            half_line, self.provider_fund_name_list
+        )
+        if max_similarity < 0.2:
+            # get the fund name line text from the first half
+            for index in reversed(half_len_list):
+                line_text = rest_context[index].strip()
+                if len(line_text) == 0:
+                    continue
+                line_text_split = line_text.split()
+                if len(line_text_split) < 3:
+                    continue
+                first_word = line_text_split[0]
+                if first_word.lower() == "class":
+                    continue
+
+                max_similarity_fund_name, max_similarity = get_most_similar_name(
+                    line_text, self.provider_fund_name_list
+                )
+                if max_similarity >= 0.2:
+                    fund_name_line = line_text
+                    break
+        else:
+            fund_name_line = half_line
+            half_len += 1
+        if fund_name_line == "":
+            return {"data": []}
+
         logger.info(f"Split first part from 0 to {half_len}")
         split_first_part = "\n".join(split_context[:half_len])
         first_part = '\n'.join(split_first_part)
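The fund-name search above leans on `get_most_similar_name` and a 0.2 similarity cut-off, but that helper lives in `utils/biz_utils` and is not part of this diff. A minimal sketch of the contract the new code appears to assume: the closest candidate plus a score in [0, 1], with token-level Jaccard standing in purely for illustration; the real helper may compute similarity differently.

```python
# Hypothetical sketch of the assumed get_most_similar_name contract.
def get_most_similar_name_sketch(line: str, candidates: list) -> tuple:
    def jaccard(left: str, right: str) -> float:
        a, b = set(left.lower().split()), set(right.lower().split())
        if not a or not b:
            return 0.0
        return len(a & b) / len(a | b)

    best_name, best_score = "", 0.0
    for candidate in candidates:
        score = jaccard(line, candidate)
        if score > best_score:
            best_name, best_score = candidate, score
    return best_name, best_score

# A row like "Global Equity Fund Class A USD" clears the 0.2 threshold against
# a provider list containing "Global Equity Fund" (Jaccard 3/6 = 0.5), while a
# numeric line like "1.25 0.30 0.05" scores 0.0 and is skipped.
```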
logger.info(f"Split first part from 0 to {half_len}") split_first_part = "\n".join(split_context[:half_len]) first_part = '\n'.join(split_first_part) @@ -435,7 +484,7 @@ class DataExtraction: logger.info(f"Split second part from {half_len} to {split_context_len}") split_second_part = "\n".join(split_context[half_len:]) - second_part = header + '\n' + split_second_part + second_part = header + "\n" + fund_name_line + "\n" + split_second_part second_instructions = self.get_instructions_by_datapoints( second_part, page_datapoints, need_exclude, exclude_data ) @@ -456,6 +505,30 @@ class DataExtraction: for first_data in first_part_data_list: if first_data in second_part_data_list: second_part_data_list.remove(first_data) + else: + # if the first part data is with same fund name and share name, + # remove the second part data + first_data_dp = [key for key in list(first_data.keys()) + if key not in ["fund name", "share name"]] + # order the data points + first_data_dp.sort() + first_fund_name = first_data.get("fund name", "") + first_share_name = first_data.get("share name", "") + if len(first_fund_name) > 0 and len(first_share_name) > 0: + remove_second_list = [] + for second_data in second_part_data_list: + second_fund_name = second_data.get("fund name", "") + second_share_name = second_data.get("share name", "") + if first_fund_name == second_fund_name and \ + first_share_name == second_share_name: + second_data_dp = [key for key in list(second_data.keys()) + if key not in ["fund name", "share name"]] + second_data_dp.sort() + if first_data_dp == second_data_dp: + remove_second_list.append(second_data) + for remove_second in remove_second_list: + if remove_second in second_part_data_list: + second_part_data_list.remove(remove_second) data_list = first_part_data_list + second_part_data_list extract_data = {"data": data_list} @@ -486,6 +559,22 @@ class DataExtraction: for remove_data in remove_list: if remove_data in data_list: data_list.remove(remove_data) + # check performance_fee + for data in data_list: + performance_fee = data.get("performance_fee", None) + if performance_fee is not None: + performance_fee = float(performance_fee) + if performance_fee > 3 and performance_fee % 2.5 == 0: + data.pop("performance_fee") + remove_list = [] + for data in data_list: + keys = [key for key in list(data.keys()) + if key not in ["fund name", "share name"]] + if len(keys) == 0: + remove_list.append(data) + for remove_data in remove_list: + if remove_data in data_list: + data_list.remove(remove_data) # update "fund name" to be "fund_name" # update "share name" to be "share_name" new_data_list = [] diff --git a/core/metrics.py b/core/metrics.py index 3fac16d..1481f74 100644 --- a/core/metrics.py +++ b/core/metrics.py @@ -3,7 +3,7 @@ import pandas as pd import time import json from sklearn.metrics import precision_score, recall_score, f1_score -from utils.biz_utils import get_unique_words_text +from utils.biz_utils import get_unique_words_text, get_beginning_common_words from utils.logger import logger @@ -293,24 +293,18 @@ class Metrics: prediction_data: pd.DataFrame, data_point: str, ): + dp_prediction = prediction_data[prediction_data["datapoint"] == data_point] + dp_prediction = self.modify_data(dp_prediction) + pred_simple_raw_names = dp_prediction["simple_raw_name"].unique().tolist() + pred_simple_name_unique_words_list = dp_prediction["simple_name_unique_words"].unique().tolist() + dp_ground_truth = ground_truth_data[ ground_truth_data["datapoint"] == data_point ] - dp_prediction = 
diff --git a/core/metrics.py b/core/metrics.py
index 3fac16d..1481f74 100644
--- a/core/metrics.py
+++ b/core/metrics.py
@@ -3,7 +3,7 @@ import pandas as pd
 import time
 import json
 from sklearn.metrics import precision_score, recall_score, f1_score
-from utils.biz_utils import get_unique_words_text
+from utils.biz_utils import get_unique_words_text, get_beginning_common_words
 from utils.logger import logger
 
 
@@ -293,24 +293,18 @@ class Metrics:
         prediction_data: pd.DataFrame,
         data_point: str,
     ):
+        dp_prediction = prediction_data[prediction_data["datapoint"] == data_point]
+        dp_prediction = self.modify_data(dp_prediction)
+        pred_simple_raw_names = dp_prediction["simple_raw_name"].unique().tolist()
+        pred_simple_name_unique_words_list = dp_prediction["simple_name_unique_words"].unique().tolist()
+
         dp_ground_truth = ground_truth_data[
             ground_truth_data["datapoint"] == data_point
         ]
-        dp_prediction = prediction_data[prediction_data["datapoint"] == data_point]
-
-        # add new column to store unique words for dp_ground_truth
-        dp_ground_truth["unique_words"] = dp_ground_truth["raw_name"].apply(
-            get_unique_words_text
-        )
-        ground_truth_unique_words_list = dp_ground_truth["unique_words"].unique().tolist()
-        ground_truth_raw_names = dp_ground_truth["raw_name"].unique().tolist()
-        # add new column to store unique words for dp_prediction
-        dp_prediction["unique_words"] = dp_prediction["raw_name"].apply(
-            get_unique_words_text
-        )
-        pred_unique_words_list = dp_prediction["unique_words"].unique().tolist()
-        pred_raw_names = dp_prediction["raw_name"].unique().tolist()
-
+        dp_ground_truth = self.modify_data(dp_ground_truth)
+        gt_simple_raw_names = dp_ground_truth["simple_raw_name"].unique().tolist()
+        gt_simple_name_unique_words_list = dp_ground_truth["simple_name_unique_words"].unique().tolist()
+
         true_data = []
         pred_data = []
 
@@ -320,28 +314,53 @@ class Metrics:
             true_data.append(1)
             pred_data.append(1)
             return true_data, pred_data, missing_error_data
-
+
         for index, prediction in dp_prediction.iterrows():
             pred_page_index = prediction["page_index"]
             pred_raw_name = prediction["raw_name"]
-            pred_unique_words = prediction["unique_words"]
+            pred_simple_raw_name = prediction["simple_raw_name"]
+            pred_simple_name_unique_words = prediction["simple_name_unique_words"]
             pred_data_point_value = prediction["value"]
             pred_investment_type = prediction["investment_type"]
-            find_raw_name_in_gt = [gt_raw_name for gt_raw_name in ground_truth_raw_names
-                                   if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
-            if pred_unique_words in ground_truth_unique_words_list or len(find_raw_name_in_gt) > 0:
+            find_raw_name_in_gt = [gt_raw_name for gt_raw_name in gt_simple_raw_names
+                                   if (gt_raw_name in pred_simple_raw_name or pred_simple_raw_name in gt_raw_name)
+                                   and gt_raw_name.endswith(pred_raw_name.split()[-1])]
+            if pred_simple_name_unique_words in gt_simple_name_unique_words_list or \
+                    len(find_raw_name_in_gt) > 0:
                 # get the ground truth data with the same unique words
-                if pred_unique_words in ground_truth_unique_words_list:
-                    gt_data = dp_ground_truth[
-                        dp_ground_truth["unique_words"] == pred_unique_words
-                    ].iloc[0]
+                if pred_simple_name_unique_words in gt_simple_name_unique_words_list:
+                    gt_data_df = dp_ground_truth[
+                        dp_ground_truth["simple_name_unique_words"] == pred_simple_name_unique_words
+                    ]
+                    if len(gt_data_df) > 1:
+                        if len(gt_data_df[gt_data_df["page_index"] == pred_page_index]) == 0:
+                            gt_data = gt_data_df.iloc[0]
+                        else:
+                            gt_data = gt_data_df[gt_data_df["page_index"] == pred_page_index].iloc[0]
+                    elif len(gt_data_df) == 1:
+                        gt_data = gt_data_df.iloc[0]
+                    else:
+                        gt_data = None
                 else:
-                    gt_data = dp_ground_truth[
-                        dp_ground_truth["raw_name"] == find_raw_name_in_gt[0]
-                    ].iloc[0]
-                gt_data_point_value = gt_data["value"]
-                if pred_data_point_value == gt_data_point_value:
+                    gt_data_df = dp_ground_truth[
+                        dp_ground_truth["simple_raw_name"] == find_raw_name_in_gt[0]
+                    ]
+                    if len(gt_data_df) > 1:
+                        if len(gt_data_df[gt_data_df["page_index"] == pred_page_index]) == 0:
+                            gt_data = gt_data_df.iloc[0]
+                        else:
+                            gt_data = gt_data_df[gt_data_df["page_index"] == pred_page_index].iloc[0]
+                    elif len(gt_data_df) == 1:
+                        gt_data = gt_data_df.iloc[0]
+                    else:
+                        gt_data = None
+                if gt_data is None:
+                    gt_data_point_value = None
+                else:
+                    gt_data_point_value = gt_data["value"]
+                if gt_data_point_value is not None and \
+                        pred_data_point_value == gt_data_point_value:
                     true_data.append(1)
                     pred_data.append(1)
                 else:
@@ -376,14 +395,16 @@ class Metrics:
         for index, ground_truth in dp_ground_truth.iterrows():
             gt_page_index = ground_truth["page_index"]
             gt_raw_name = ground_truth["raw_name"]
-            gt_unique_words = ground_truth["unique_words"]
+            gt_simple_raw_name = ground_truth["simple_raw_name"]
+            gt_simple_name_unique_words = ground_truth["simple_name_unique_words"]
             gt_data_point_value = ground_truth["value"]
             gt_investment_type = ground_truth["investment_type"]
-            find_raw_name_in_pred = [pred_raw_name for pred_raw_name in pred_raw_names
-                                     if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
+            find_raw_name_in_pred = [pred_raw_name for pred_raw_name in pred_simple_raw_names
+                                     if (gt_simple_raw_name in pred_raw_name or pred_raw_name in gt_simple_raw_name)
+                                     and pred_raw_name.endswith(gt_raw_name.split()[-1])]
 
-            if gt_unique_words not in pred_unique_words_list and \
+            if gt_simple_name_unique_words not in pred_simple_name_unique_words_list and \
                     len(find_raw_name_in_pred) == 0:
                 true_data.append(1)
                 pred_data.append(0)
@@ -400,6 +421,26 @@ class Metrics:
                     missing_error_data.append(error_data)
 
         return true_data, pred_data, missing_error_data
+
+    def modify_data(self, data: pd.DataFrame):
+        data["simple_raw_name"] = ""
+        data["simple_name_unique_words"] = ""
+        page_index_list = data["page_index"].unique().tolist()
+        for page_index in page_index_list:
+            page_data = data[data["page_index"] == page_index]
+            raw_name_list = page_data["raw_name"].unique().tolist()
+            beginning_common_words = get_beginning_common_words(raw_name_list)
+            for raw_name in raw_name_list:
+                if beginning_common_words is not None and len(beginning_common_words) > 0:
+                    simple_raw_name = raw_name.replace(beginning_common_words, "").strip()
+                else:
+                    simple_raw_name = raw_name
+                # set simple_raw_name for every row on this page with this raw_name
+                data.loc[(data["page_index"] == page_index) & (data["raw_name"] == raw_name),
+                         "simple_raw_name"] = simple_raw_name
+                data.loc[(data["page_index"] == page_index) & (data["raw_name"] == raw_name),
+                         "simple_name_unique_words"] = get_unique_words_text(simple_raw_name)
+        return data
 
     def get_specific_metrics(self, true_data: list, pred_data: list):
         precision = precision_score(true_data, pred_data)
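To see what `modify_data` does to raw names, here is a worked example with invented umbrella and fund names, using the `get_beginning_common_words` helper this diff adds to `utils/biz_utils`:

```python
from utils.biz_utils import get_beginning_common_words

raw_names = [
    "ABC Umbrella SICAV - Global Equity Fund Class A",
    "ABC Umbrella SICAV - Emerging Markets Bond Fund Class I",
]
prefix = get_beginning_common_words(raw_names)
# prefix == "ABC Umbrella SICAV -"
simple = [name.replace(prefix, "").strip() for name in raw_names]
# simple == ["Global Equity Fund Class A", "Emerging Markets Bond Fund Class I"]
# Stripping the shared umbrella prefix keeps name matching from being dominated
# by words that every share class on the page has in common.
```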
diff --git a/main.py b/main.py
index a566652..2285f79 100644
--- a/main.py
+++ b/main.py
@@ -523,7 +523,7 @@ def test_auto_generate_instructions():
 def test_data_extraction_metrics():
     data_type = "data_extraction"
     prediction_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_88_documents_20240919120502.xlsx"
-    # prediction_file = r"/data/emea_ar/output/mapping_data/docs/by_text/excel/321733631.xlsx"
+    # prediction_file = r"/data/emea_ar/output/mapping_data/docs/by_text/excel/509350496.xlsx"
     prediction_sheet_name = "mapping_data"
     ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx"
     ground_truth_sheet_name = "mapping_data"
@@ -577,26 +577,26 @@ if __name__ == "__main__":
     #                 extract_way,
     #                 re_run_extract_data)
 
-    special_doc_id_list = ["476492237"]
+    special_doc_id_list = []
     output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
     output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
     re_run_mapping_data = True
     force_save_total_data = False
-    extract_ways = ["text"]
-    # for extract_way in extract_ways:
-    #     batch_start_job(
-    #         pdf_folder,
-    #         page_filter_ground_truth_file,
-    #         output_extract_data_child_folder,
-    #         output_mapping_child_folder,
-    #         output_extract_data_total_folder,
-    #         output_mapping_total_folder,
-    #         extract_way,
-    #         special_doc_id_list,
-    #         re_run_extract_data,
-    #         re_run_mapping_data,
-    #         force_save_total_data=force_save_total_data,
-    #     )
+    extract_ways = ["text", "image"]
+    for extract_way in extract_ways:
+        batch_start_job(
+            pdf_folder,
+            page_filter_ground_truth_file,
+            output_extract_data_child_folder,
+            output_mapping_child_folder,
+            output_extract_data_total_folder,
+            output_mapping_total_folder,
+            extract_way,
+            special_doc_id_list,
+            re_run_extract_data,
+            re_run_mapping_data,
+            force_save_total_data=force_save_total_data,
+        )
 
-    test_data_extraction_metrics()
+    # test_data_extraction_metrics()
diff --git a/utils/biz_utils.py b/utils/biz_utils.py
index 073aefe..571f265 100644
--- a/utils/biz_utils.py
+++ b/utils/biz_utils.py
@@ -205,6 +205,31 @@ def get_jacard_similarity(text_left,
     else:
         return 0
 
+
+def get_beginning_common_words(text_list: list):
+    """
+    Return the run of leading words shared by every text in text_list
+    """
+    if text_list is None or len(text_list) < 2:
+        return ""
+
+    common_words_list = []
+    first_text_split = text_list[0].split()
+    for w_i, word in enumerate(first_text_split):
+        all_same = True
+        for text in text_list[1:]:
+            text_split = text.split()
+            if w_i >= len(text_split):
+                all_same = False
+                break
+            if text_split[w_i] != word:
+                all_same = False
+                break
+        if all_same:
+            common_words_list.append(word)
+        else:
+            break
+
+    return ' '.join(common_words_list).strip()
+
 
 def replace_abbrevation(text: str):
     if text is None or len(text.strip()) == 0:
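Finally, a quick self-contained check of `get_beginning_common_words` edge cases as defined above (example names are invented; the single-name case assumes the `return ""` shown in the function):

```python
from utils.biz_utils import get_beginning_common_words

print(get_beginning_common_words(["Alpha Fund A", "Alpha Fund B"]))  # "Alpha Fund"
print(get_beginning_common_words(["Alpha Fund", "Beta Fund"]))       # "" (no shared leading words)
print(get_beginning_common_words(["Alpha Fund"]))                    # "" (fewer than two names)
print(get_beginning_common_words(["Alpha", "Alpha Fund"]))           # "Alpha"
```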