optimize data extraction and metrics calculation algorithms

Blade He 2024-09-19 22:45:08 -05:00
parent 48dc8690c3
commit c4985ac75f
4 changed files with 215 additions and 60 deletions

View File

@@ -6,9 +6,9 @@ import fitz
import pandas as pd
from utils.gpt_utils import chat
from utils.pdf_util import PDFUtil
from utils.sql_query_util import query_document_fund_mapping
from utils.sql_query_util import query_document_fund_mapping, query_investment_by_provider
from utils.logger import logger
from utils.biz_utils import add_slash_to_text_as_regex, clean_text
from utils.biz_utils import add_slash_to_text_as_regex, clean_text, get_most_similar_name
class DataExtraction:
@@ -44,6 +44,13 @@ class DataExtraction:
self.document_mapping_info_df = query_document_fund_mapping(doc_id)
else:
self.document_mapping_info_df = document_mapping_info_df
self.provider_mapping_df = self.get_provider_mapping()
if len(self.provider_mapping_df) == 0:
self.provider_fund_name_list = []
else:
self.provider_fund_name_list = (
self.provider_mapping_df["FundName"].unique().tolist()
)
self.datapoint_page_info = datapoint_page_info
self.page_nums_with_datapoints = self.get_page_nums_from_datapoint_page_info()
self.datapoints = datapoints
@@ -53,6 +60,20 @@ class DataExtraction:
self.extract_way = extract_way
self.output_image_folder = output_image_folder
def get_provider_mapping(self):
if len(self.document_mapping_info_df) == 0:
return pd.DataFrame()
provider_id_list = (
self.document_mapping_info_df["ProviderId"].unique().tolist()
)
provider_mapping_list = []
for provider_id in provider_id_list:
provider_mapping_list.append(query_investment_by_provider(provider_id))
provider_mapping_df = pd.concat(provider_mapping_list)
provider_mapping_df = provider_mapping_df.drop_duplicates()
provider_mapping_df.reset_index(drop=True, inplace=True)
return provider_mapping_df
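A minimal sketch of the per-provider aggregation pattern used by get_provider_mapping, with hypothetical frames standing in for query_investment_by_provider results:

    import pandas as pd

    # hypothetical stand-ins for two providers' query_investment_by_provider frames
    frames = [
        pd.DataFrame({"ProviderId": [1, 1], "FundName": ["Alpha Fund", "Beta Fund"]}),
        pd.DataFrame({"ProviderId": [2, 2], "FundName": ["Beta Fund", "Gamma Fund"]}),
    ]
    merged = pd.concat(frames).drop_duplicates().reset_index(drop=True)
    fund_names = merged["FundName"].unique().tolist()
    # ['Alpha Fund', 'Beta Fund', 'Gamma Fund']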
def get_pdf_image_base64(self, page_index: int) -> dict:
pdf_util = PDFUtil(self.pdf_file)
return pdf_util.extract_image_from_page(page_index=page_index,
@@ -403,6 +424,9 @@ class DataExtraction:
split_context = re.split(r"\n", page_text)
split_context = [text.strip() for text in split_context
if len(text.strip()) > 0]
if len(split_context) < 10:
return {"data": []}
split_context_len = len(split_context)
top_10_context = split_context[:10]
rest_context = split_context[10:]
@@ -411,12 +435,37 @@
# the line at index half_len should not start with a number
# iterate the indices below half_len in reverse
half_len_list = [i for i in range(half_len)]
for index in reversed(half_len_list):
first_letter = rest_context[index].strip()[0]
if not first_letter.isnumeric() and first_letter not in [".", "(", ")", "-"]:
half_len = index
break
fund_name_line = ""
half_line = rest_context[half_len].strip()
max_similarity_fund_name, max_similarity = get_most_similar_name(
half_line, self.provider_fund_name_list
)
if max_similarity < 0.2:
# get the fund name line text from the first half
for index in reversed(half_len_list):
line_text = rest_context[index].strip()
if len(line_text) == 0:
continue
line_text_split = line_text.split()
if len(line_text_split) < 3:
continue
first_word = line_text_split[0]
if first_word.lower() == "class":
continue
max_similarity_fund_name, max_similarity = get_most_similar_name(
line_text, self.provider_fund_name_list
)
if max_similarity >= 0.2:
fund_name_line = line_text
break
else:
fund_name_line = half_line
half_len += 1
if fund_name_line == "":
return {"data": []}
logger.info(f"Split first part from 0 to {half_len}")
split_first_part = "\n".join(split_context[:half_len])
first_part = '\n'.join(split_first_part)
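The fund-name search above depends on get_most_similar_name and its 0.2 threshold; that helper is not part of this diff, so the following is only a plausible sketch (using difflib), not the actual utils.biz_utils implementation:

    from difflib import SequenceMatcher

    def get_most_similar_name(text: str, candidates: list) -> tuple:
        # hypothetical sketch: return the most similar candidate and its score
        best_name, best_score = "", 0.0
        for name in candidates:
            score = SequenceMatcher(None, text.lower(), name.lower()).ratio()
            if score > best_score:
                best_name, best_score = name, score
        return best_name, best_score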
@@ -435,7 +484,7 @@
logger.info(f"Split second part from {half_len} to {split_context_len}")
split_second_part = "\n".join(split_context[half_len:])
second_part = header + '\n' + split_second_part
second_part = header + "\n" + fund_name_line + "\n" + split_second_part
second_instructions = self.get_instructions_by_datapoints(
second_part, page_datapoints, need_exclude, exclude_data
)
@@ -456,6 +505,30 @@
for first_data in first_part_data_list:
if first_data in second_part_data_list:
second_part_data_list.remove(first_data)
else:
# if a first-part row has the same fund name, share name, and
# datapoint keys as a second-part row, drop the second-part row
first_data_dp = [key for key in list(first_data.keys())
                 if key not in ["fund name", "share name"]]
# sort the datapoint keys so the two key lists compare directly
first_data_dp.sort()
first_fund_name = first_data.get("fund name", "")
first_share_name = first_data.get("share name", "")
if len(first_fund_name) > 0 and len(first_share_name) > 0:
remove_second_list = []
for second_data in second_part_data_list:
second_fund_name = second_data.get("fund name", "")
second_share_name = second_data.get("share name", "")
if first_fund_name == second_fund_name and \
first_share_name == second_share_name:
second_data_dp = [key for key in list(second_data.keys())
if key not in ["fund name", "share name"]]
second_data_dp.sort()
if first_data_dp == second_data_dp:
remove_second_list.append(second_data)
for remove_second in remove_second_list:
if remove_second in second_part_data_list:
second_part_data_list.remove(remove_second)
data_list = first_part_data_list + second_part_data_list
extract_data = {"data": data_list}
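A worked illustration of the cross-part deduplication rule above, using hypothetical rows:

    first_part_data_list = [
        {"fund name": "Alpha Fund", "share name": "Class A", "ongoing_charge": "0.45"}
    ]
    second_part_data_list = [
        {"fund name": "Alpha Fund", "share name": "Class A", "ongoing_charge": "0.50"}
    ]
    # same fund name, share name, and datapoint keys -> the second-part row is
    # dropped even though the values differ; the first split's answer wins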
@@ -486,6 +559,22 @@
for remove_data in remove_list:
if remove_data in data_list:
data_list.remove(remove_data)
# check performance_fee
for data in data_list:
performance_fee = data.get("performance_fee", None)
if performance_fee is not None:
performance_fee = float(performance_fee)
if performance_fee > 3 and performance_fee % 2.5 == 0:
data.pop("performance_fee")
remove_list = []
for data in data_list:
keys = [key for key in list(data.keys())
if key not in ["fund name", "share name"]]
if len(keys) == 0:
remove_list.append(data)
for remove_data in remove_list:
if remove_data in data_list:
data_list.remove(remove_data)
# update "fund name" to be "fund_name"
# update "share name" to be "share_name"
new_data_list = []
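The performance_fee filter above drops values greater than 3 that are exact multiples of 2.5, treating them as extraction noise; for example:

    # 0.2 and 2.5 are kept; 5.0, 7.5, and 20.0 satisfy both conditions and are dropped
    for fee in ("0.2", "2.5", "5.0", "7.5", "20.0"):
        value = float(fee)
        dropped = value > 3 and value % 2.5 == 0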

View File

@@ -3,7 +3,7 @@ import pandas as pd
import time
import json
from sklearn.metrics import precision_score, recall_score, f1_score
from utils.biz_utils import get_unique_words_text
from utils.biz_utils import get_unique_words_text, get_beginning_common_words
from utils.logger import logger
@@ -293,24 +293,18 @@ class Metrics:
prediction_data: pd.DataFrame,
data_point: str,
):
dp_prediction = prediction_data[prediction_data["datapoint"] == data_point]
dp_prediction = self.modify_data(dp_prediction)
pred_simple_raw_names = dp_prediction["simple_raw_name"].unique().tolist()
pred_simple_name_unique_words_list = dp_prediction["simple_name_unique_words"].unique().tolist()
dp_ground_truth = ground_truth_data[
ground_truth_data["datapoint"] == data_point
]
dp_prediction = prediction_data[prediction_data["datapoint"] == data_point]
# add new column to store unique words for dp_ground_truth
dp_ground_truth["unique_words"] = dp_ground_truth["raw_name"].apply(
get_unique_words_text
)
ground_truth_unique_words_list = dp_ground_truth["unique_words"].unique().tolist()
ground_truth_raw_names = dp_ground_truth["raw_name"].unique().tolist()
# add new column to store unique words for dp_prediction
dp_prediction["unique_words"] = dp_prediction["raw_name"].apply(
get_unique_words_text
)
pred_unique_words_list = dp_prediction["unique_words"].unique().tolist()
pred_raw_names = dp_prediction["raw_name"].unique().tolist()
dp_ground_truth = self.modify_data(dp_ground_truth)
gt_simple_raw_names = dp_ground_truth["simple_raw_name"].unique().tolist()
gt_simple_name_unique_words_list = dp_ground_truth["simple_name_unique_words"].unique().tolist()
true_data = []
pred_data = []
@@ -320,28 +314,53 @@
true_data.append(1)
pred_data.append(1)
return true_data, pred_data, missing_error_data
for index, prediction in dp_prediction.iterrows():
pred_page_index = prediction["page_index"]
pred_raw_name = prediction["raw_name"]
pred_unique_words = prediction["unique_words"]
pred_simple_raw_name = prediction["simple_raw_name"]
pred_simple_name_unique_words = prediction["simple_name_unique_words"]
pred_data_point_value = prediction["value"]
pred_investment_type = prediction["investment_type"]
find_raw_name_in_gt = [gt_raw_name for gt_raw_name in ground_truth_raw_names
if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
if pred_unique_words in ground_truth_unique_words_list or len(find_raw_name_in_gt) > 0:
find_raw_name_in_gt = [gt_raw_name for gt_raw_name in gt_simple_raw_names
if (gt_raw_name in pred_simple_raw_name or pred_simple_raw_name in gt_raw_name)
and gt_raw_name.endswith(pred_raw_name.split()[-1])]
if pred_simple_name_unique_words in gt_simple_name_unique_words_list or \
len(find_raw_name_in_gt) > 0:
# get the ground truth data with the same unique words
if pred_unique_words in ground_truth_unique_words_list:
gt_data = dp_ground_truth[
dp_ground_truth["unique_words"] == pred_unique_words
].iloc[0]
if pred_simple_name_unique_words in gt_simple_name_unique_words_list:
gt_data_df = dp_ground_truth[
dp_ground_truth["simple_name_unique_words"] == pred_simple_name_unique_words
]
if len(gt_data_df) > 1:
if len(gt_data_df[gt_data_df["page_index"] == pred_page_index]) == 0:
gt_data = gt_data_df.iloc[0]
else:
gt_data = gt_data_df[gt_data_df["page_index"] == pred_page_index].iloc[0]
elif len(gt_data_df) == 1:
gt_data = gt_data_df.iloc[0]
else:
gt_data = None
else:
gt_data = dp_ground_truth[
dp_ground_truth["raw_name"] == find_raw_name_in_gt[0]
].iloc[0]
gt_data_point_value = gt_data["value"]
if pred_data_point_value == gt_data_point_value:
gt_data_df = dp_ground_truth[
dp_ground_truth["simple_raw_name"] == find_raw_name_in_gt[0]
]
if len(gt_data_df) > 1:
if len(gt_data_df[gt_data_df["page_index"] == pred_page_index]) == 0:
gt_data = gt_data_df.iloc[0]
else:
gt_data = gt_data_df[gt_data_df["page_index"] == pred_page_index].iloc[0]
elif len(gt_data_df) == 1:
gt_data = gt_data_df.iloc[0]
else:
gt_data = None
if gt_data is None:
gt_data_point_value = None
else:
gt_data_point_value = gt_data["value"]
if gt_data_point_value is not None and \
pred_data_point_value == gt_data_point_value:
true_data.append(1)
pred_data.append(1)
else:
@@ -376,14 +395,16 @@
for index, ground_truth in dp_ground_truth.iterrows():
gt_page_index = ground_truth["page_index"]
gt_raw_name = ground_truth["raw_name"]
gt_unique_words = ground_truth["unique_words"]
gt_simple_raw_name = ground_truth["simple_raw_name"]
gt_simple_name_unique_words = ground_truth["simple_name_unique_words"]
gt_data_point_value = ground_truth["value"]
gt_investment_type = ground_truth["investment_type"]
find_raw_name_in_pred = [pred_raw_name for pred_raw_name in pred_raw_names
if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
find_raw_name_in_pred = [pred_raw_name for pred_raw_name in pred_simple_raw_names
if (gt_simple_raw_name in pred_raw_name or pred_raw_name in gt_simple_raw_name)
and pred_raw_name.endswith(gt_raw_name.split()[-1])]
if gt_unique_words not in pred_unique_words_list and \
if gt_simple_name_unique_words not in pred_simple_name_unique_words_list and \
len(find_raw_name_in_pred) == 0:
true_data.append(1)
pred_data.append(0)
@@ -400,6 +421,26 @@
missing_error_data.append(error_data)
return true_data, pred_data, missing_error_data
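The revised matching rule can be illustrated with hypothetical names: containment is tested on the prefix-stripped (simple) names, and the share-class suffix of the raw name must also agree:

    gt_simple = "Bond Fund Class A"           # ground-truth simple_raw_name
    pred_simple = "Global Bond Fund Class A"  # prediction simple_raw_name
    pred_raw = "Alpha Funds SICAV Global Bond Fund Class A"
    match = (gt_simple in pred_simple or pred_simple in gt_simple) \
        and gt_simple.endswith(pred_raw.split()[-1])
    # True: the simple names contain each other and both end in share class "A"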
def modify_data(self, data: pd.DataFrame):
    # work on a copy so the filtered input frame is not mutated in place
    data = data.copy()
    data["simple_raw_name"] = ""
    data["simple_name_unique_words"] = ""
    page_index_list = data["page_index"].unique().tolist()
    for page_index in page_index_list:
        page_data = data[data["page_index"] == page_index]
        raw_name_list = page_data["raw_name"].unique().tolist()
        beginning_common_words = get_beginning_common_words(raw_name_list)
        for raw_name in raw_name_list:
            if beginning_common_words and raw_name.startswith(beginning_common_words):
                # strip only the shared leading words, not later occurrences
                simple_raw_name = raw_name[len(beginning_common_words):].strip()
            else:
                simple_raw_name = raw_name
            # set simple_raw_name for rows with the same page and raw_name
            data.loc[(data["page_index"] == page_index) & (data["raw_name"] == raw_name),
                     "simple_raw_name"] = simple_raw_name
            data.loc[(data["page_index"] == page_index) & (data["raw_name"] == raw_name),
                     "simple_name_unique_words"] = get_unique_words_text(simple_raw_name)
    return data
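For example (hypothetical rows from a single page), modify_data strips the fund-family prefix that every name on the page shares:

    import pandas as pd

    df = pd.DataFrame({
        "page_index": [12, 12],
        "raw_name": ["Alpha Funds SICAV Bond Fund A", "Alpha Funds SICAV Equity Fund B"],
    })
    # get_beginning_common_words finds "Alpha Funds SICAV", so simple_raw_name
    # becomes "Bond Fund A" and "Equity Fund B" for the two rows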
def get_specific_metrics(self, true_data: list, pred_data: list):
precision = precision_score(true_data, pred_data)

main.py
View File

@@ -523,7 +523,7 @@ def test_auto_generate_instructions():
def test_data_extraction_metrics():
data_type = "data_extraction"
prediction_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_88_documents_20240919120502.xlsx"
# prediction_file = r"/data/emea_ar/output/mapping_data/docs/by_text/excel/321733631.xlsx"
# prediction_file = r"/data/emea_ar/output/mapping_data/docs/by_text/excel/509350496.xlsx"
prediction_sheet_name = "mapping_data"
ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx"
ground_truth_sheet_name = "mapping_data"
@@ -577,26 +577,26 @@ if __name__ == "__main__":
# extract_way,
# re_run_extract_data)
special_doc_id_list = ["476492237"]
special_doc_id_list = []
output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
re_run_mapping_data = True
force_save_total_data = False
extract_ways = ["text"]
# for extract_way in extract_ways:
# batch_start_job(
# pdf_folder,
# page_filter_ground_truth_file,
# output_extract_data_child_folder,
# output_mapping_child_folder,
# output_extract_data_total_folder,
# output_mapping_total_folder,
# extract_way,
# special_doc_id_list,
# re_run_extract_data,
# re_run_mapping_data,
# force_save_total_data=force_save_total_data,
# )
extract_ways = ["text", "image"]
for extract_way in extract_ways:
batch_start_job(
pdf_folder,
page_filter_ground_truth_file,
output_extract_data_child_folder,
output_mapping_child_folder,
output_extract_data_total_folder,
output_mapping_total_folder,
extract_way,
special_doc_id_list,
re_run_extract_data,
re_run_mapping_data,
force_save_total_data=force_save_total_data,
)
test_data_extraction_metrics()
# test_data_extraction_metrics()

View File

@@ -205,6 +205,31 @@ def get_jacard_similarity(text_left,
else:
return 0
def get_beginning_common_words(text_list: list):
"""
Get the beginning common words in text_list
"""
if text_list is None or len(text_list) < 2:
return []
common_words_list = []
first_text_split = text_list[0].split()
for w_i, word in enumerate(first_text_split):
all_same = True
for text in text_list[1:]:
text_split = text.split()
if w_i >= len(text_split):
all_same = False
break
if text_split[w_i] != word:
all_same = False
break
if all_same:
common_words_list.append(word)
else:
break
return ' '.join(common_words_list).strip()
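For example:

    names = ["Alpha Funds SICAV Bond Fund", "Alpha Funds SICAV Equity Fund"]
    get_beginning_common_words(names)              # -> "Alpha Funds SICAV"
    get_beginning_common_words(["only one name"])  # -> "" (needs at least two)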
def replace_abbrevation(text: str):
if text is None or len(text.strip()) == 0: