optimize data extract, metrics calculation algorithm

parent 48dc8690c3
commit c4985ac75f
@@ -6,9 +6,9 @@ import fitz
 import pandas as pd
 from utils.gpt_utils import chat
 from utils.pdf_util import PDFUtil
-from utils.sql_query_util import query_document_fund_mapping
+from utils.sql_query_util import query_document_fund_mapping, query_investment_by_provider
 from utils.logger import logger
-from utils.biz_utils import add_slash_to_text_as_regex, clean_text
+from utils.biz_utils import add_slash_to_text_as_regex, clean_text, get_most_similar_name


 class DataExtraction:
@@ -44,6 +44,13 @@ class DataExtraction:
             self.document_mapping_info_df = query_document_fund_mapping(doc_id)
         else:
             self.document_mapping_info_df = document_mapping_info_df
+        self.provider_mapping_df = self.get_provider_mapping()
+        if len(self.provider_mapping_df) == 0:
+            self.provider_fund_name_list = []
+        else:
+            self.provider_fund_name_list = (
+                self.provider_mapping_df["FundName"].unique().tolist()
+            )
         self.datapoint_page_info = datapoint_page_info
         self.page_nums_with_datapoints = self.get_page_nums_from_datapoint_page_info()
         self.datapoints = datapoints
@@ -53,6 +60,20 @@ class DataExtraction:
         self.extract_way = extract_way
         self.output_image_folder = output_image_folder

+    def get_provider_mapping(self):
+        if len(self.document_mapping_info_df) == 0:
+            return pd.DataFrame()
+        provider_id_list = (
+            self.document_mapping_info_df["ProviderId"].unique().tolist()
+        )
+        provider_mapping_list = []
+        for provider_id in provider_id_list:
+            provider_mapping_list.append(query_investment_by_provider(provider_id))
+        provider_mapping_df = pd.concat(provider_mapping_list)
+        provider_mapping_df = provider_mapping_df.drop_duplicates()
+        provider_mapping_df.reset_index(drop=True, inplace=True)
+        return provider_mapping_df
+
     def get_pdf_image_base64(self, page_index: int) -> dict:
         pdf_util = PDFUtil(self.pdf_file)
         return pdf_util.extract_image_from_page(page_index=page_index,
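A quick illustration of what the pd.concat / drop_duplicates / reset_index chain in get_provider_mapping produces. The ProviderId and FundName columns come from the diff above; the rows are made up for the example:

    import pandas as pd

    # Stand-ins for two query_investment_by_provider() results (illustrative rows).
    df_a = pd.DataFrame({"ProviderId": [1, 1], "FundName": ["Alpha Fund", "Beta Fund"]})
    df_b = pd.DataFrame({"ProviderId": [1], "FundName": ["Beta Fund"]})

    merged = pd.concat([df_a, df_b])             # stacks results, keeps original indexes
    merged = merged.drop_duplicates()            # the duplicated "Beta Fund" row collapses
    merged.reset_index(drop=True, inplace=True)  # clean 0..n-1 index
    print(merged["FundName"].unique().tolist())  # ['Alpha Fund', 'Beta Fund']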
@@ -403,6 +424,9 @@ class DataExtraction:
         split_context = re.split(r"\n", page_text)
         split_context = [text.strip() for text in split_context
                          if len(text.strip()) > 0]
+        if len(split_context) < 10:
+            return {"data": []}
+
         split_context_len = len(split_context)
         top_10_context = split_context[:10]
         rest_context = split_context[10:]
@@ -411,11 +435,36 @@ class DataExtraction:
         # the member of half_len should not start with number
         # reverse iterate the list by half_len
         half_len_list = [i for i in range(half_len)]

+        fund_name_line = ""
+        half_line = rest_context[half_len].strip()
+        max_similarity_fund_name, max_similarity = get_most_similar_name(
+            half_line, self.provider_fund_name_list
+        )
+        if max_similarity < 0.2:
+            # get the fund name line text from the first half
             for index in reversed(half_len_list):
                 first_letter = rest_context[index].strip()[0]
                 if not first_letter.isnumeric() and first_letter not in [".", "(", ")", "-"]:
                     half_len = index
+                    line_text = rest_context[index].strip()
+                    if len(line_text) == 0:
+                        continue
+                    line_text_split = line_text.split()
+                    if len(line_text_split) < 3:
+                        continue
+                    first_word = line_text_split[0]
+                    if first_word.lower() == "class":
+                        continue
+
+                    max_similarity_fund_name, max_similarity = get_most_similar_name(
+                        line_text, self.provider_fund_name_list
+                    )
+                    if max_similarity >= 0.2:
+                        fund_name_line = line_text
+                        break
+        else:
+            fund_name_line = half_line
+            half_len += 1
+        if fund_name_line == "":
+            return {"data": []}

         logger.info(f"Split first part from 0 to {half_len}")
         split_first_part = "\n".join(split_context[:half_len])
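The new splitting logic leans on get_most_similar_name from utils.biz_utils, which this diff imports but does not show. From the call sites, it takes a text line and a list of candidate fund names and returns the best match plus a similarity score that is compared against the 0.2 threshold. A rough difflib-based stand-in with the same contract, an assumption for illustration rather than the project's actual implementation:

    from difflib import SequenceMatcher

    def get_most_similar_name(text: str, name_list: list) -> tuple:
        # Returns (best_name, best_ratio); ("", 0.0) when name_list is empty.
        best_name, best_ratio = "", 0.0
        for name in name_list:
            ratio = SequenceMatcher(None, text.lower(), name.lower()).ratio()
            if ratio > best_ratio:
                best_name, best_ratio = name, ratio
        return best_name, best_ratio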
@@ -435,7 +484,7 @@ class DataExtraction:

         logger.info(f"Split second part from {half_len} to {split_context_len}")
         split_second_part = "\n".join(split_context[half_len:])
-        second_part = header + '\n' + split_second_part
+        second_part = header + "\n" + fund_name_line + "\n" + split_second_part
         second_instructions = self.get_instructions_by_datapoints(
             second_part, page_datapoints, need_exclude, exclude_data
         )
@@ -456,6 +505,30 @@ class DataExtraction:
         for first_data in first_part_data_list:
             if first_data in second_part_data_list:
                 second_part_data_list.remove(first_data)
+            else:
+                # if a first part item has the same fund name and share name,
+                # remove the matching second part item
+                first_data_dp = [key for key in list(first_data.keys())
+                                 if key not in ["fund name", "share name"]]
+                # order the data points
+                first_data_dp.sort()
+                first_fund_name = first_data.get("fund name", "")
+                first_share_name = first_data.get("share name", "")
+                if len(first_fund_name) > 0 and len(first_share_name) > 0:
+                    remove_second_list = []
+                    for second_data in second_part_data_list:
+                        second_fund_name = second_data.get("fund name", "")
+                        second_share_name = second_data.get("share name", "")
+                        if first_fund_name == second_fund_name and \
+                                first_share_name == second_share_name:
+                            second_data_dp = [key for key in list(second_data.keys())
+                                              if key not in ["fund name", "share name"]]
+                            second_data_dp.sort()
+                            if first_data_dp == second_data_dp:
+                                remove_second_list.append(second_data)
+                    for remove_second in remove_second_list:
+                        if remove_second in second_part_data_list:
+                            second_part_data_list.remove(remove_second)

         data_list = first_part_data_list + second_part_data_list
         extract_data = {"data": data_list}
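To make the de-duplication rule concrete: a second-part row is dropped only when its fund name, share name, and sorted list of data-point keys all match a first-part row; the extracted values themselves are not compared, so the first part wins on conflicts. A toy example (the performance_fee values are made up):

    first = {"fund name": "Alpha Fund", "share name": "Class A", "performance_fee": "10"}
    second = {"fund name": "Alpha Fund", "share name": "Class A", "performance_fee": "15"}

    dp_keys = lambda d: sorted(k for k in d if k not in ["fund name", "share name"])
    # Same fund/share and same data-point keys -> second is removed despite the
    # differing values.
    assert dp_keys(first) == dp_keys(second) == ["performance_fee"]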
@@ -486,6 +559,22 @@ class DataExtraction:
         for remove_data in remove_list:
             if remove_data in data_list:
                 data_list.remove(remove_data)
+        # check performance_fee
+        for data in data_list:
+            performance_fee = data.get("performance_fee", None)
+            if performance_fee is not None:
+                performance_fee = float(performance_fee)
+                if performance_fee > 3 and performance_fee % 2.5 == 0:
+                    data.pop("performance_fee")
+        remove_list = []
+        for data in data_list:
+            keys = [key for key in list(data.keys())
+                    if key not in ["fund name", "share name"]]
+            if len(keys) == 0:
+                remove_list.append(data)
+        for remove_data in remove_list:
+            if remove_data in data_list:
+                data_list.remove(remove_data)
         # update "fund name" to be "fund_name"
         # update "share name" to be "share_name"
         new_data_list = []
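The new performance_fee check drops any value above 3 that is an exact multiple of 2.5; since 2.5 and its multiples are exactly representable floats, the modulo test is reliable here:

    for fee in [1.5, 3.0, 5.0, 7.5, 10.0]:
        flagged = fee > 3 and fee % 2.5 == 0
        print(fee, "dropped" if flagged else "kept")
    # 1.5 kept, 3.0 kept, 5.0 dropped, 7.5 dropped, 10.0 dropped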
103 core/metrics.py
@@ -3,7 +3,7 @@ import pandas as pd
 import time
 import json
 from sklearn.metrics import precision_score, recall_score, f1_score
-from utils.biz_utils import get_unique_words_text
+from utils.biz_utils import get_unique_words_text, get_beginning_common_words
 from utils.logger import logger


@@ -293,23 +293,17 @@ class Metrics:
         prediction_data: pd.DataFrame,
         data_point: str,
     ):
+        dp_prediction = prediction_data[prediction_data["datapoint"] == data_point]
+        dp_prediction = self.modify_data(dp_prediction)
+        pred_simple_raw_names = dp_prediction["simple_raw_name"].unique().tolist()
+        pred_simple_name_unique_words_list = dp_prediction["simple_name_unique_words"].unique().tolist()
+
         dp_ground_truth = ground_truth_data[
             ground_truth_data["datapoint"] == data_point
         ]
-        dp_prediction = prediction_data[prediction_data["datapoint"] == data_point]
-
-        # add new column to store unique words for dp_ground_truth
-        dp_ground_truth["unique_words"] = dp_ground_truth["raw_name"].apply(
-            get_unique_words_text
-        )
-        ground_truth_unique_words_list = dp_ground_truth["unique_words"].unique().tolist()
-        ground_truth_raw_names = dp_ground_truth["raw_name"].unique().tolist()
-        # add new column to store unique words for dp_prediction
-        dp_prediction["unique_words"] = dp_prediction["raw_name"].apply(
-            get_unique_words_text
-        )
-        pred_unique_words_list = dp_prediction["unique_words"].unique().tolist()
-        pred_raw_names = dp_prediction["raw_name"].unique().tolist()
+        dp_ground_truth = self.modify_data(dp_ground_truth)
+        gt_simple_raw_names = dp_ground_truth["simple_raw_name"].unique().tolist()
+        gt_simple_name_unique_words_list = dp_ground_truth["simple_name_unique_words"].unique().tolist()

         true_data = []
         pred_data = []
@@ -324,24 +318,49 @@ class Metrics:
         for index, prediction in dp_prediction.iterrows():
             pred_page_index = prediction["page_index"]
             pred_raw_name = prediction["raw_name"]
-            pred_unique_words = prediction["unique_words"]
+            pred_simple_raw_name = prediction["simple_raw_name"]
+            pred_simple_name_unique_words = prediction["simple_name_unique_words"]
             pred_data_point_value = prediction["value"]
             pred_investment_type = prediction["investment_type"]

-            find_raw_name_in_gt = [gt_raw_name for gt_raw_name in ground_truth_raw_names
-                                   if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
-            if pred_unique_words in ground_truth_unique_words_list or len(find_raw_name_in_gt) > 0:
+            find_raw_name_in_gt = [gt_raw_name for gt_raw_name in gt_simple_raw_names
+                                   if (gt_raw_name in pred_simple_raw_name or pred_simple_raw_name in gt_raw_name)
+                                   and gt_raw_name.endswith(pred_raw_name.split()[-1])]
+            if pred_simple_name_unique_words in gt_simple_name_unique_words_list or \
+                    len(find_raw_name_in_gt) > 0:
                 # get the ground truth data with the same unique words
-                if pred_unique_words in ground_truth_unique_words_list:
-                    gt_data = dp_ground_truth[
-                        dp_ground_truth["unique_words"] == pred_unique_words
-                    ].iloc[0]
+                if pred_simple_name_unique_words in gt_simple_name_unique_words_list:
+                    gt_data_df = dp_ground_truth[
+                        dp_ground_truth["simple_name_unique_words"] == pred_simple_name_unique_words
+                    ]
+                    if len(gt_data_df) > 1:
+                        if len(gt_data_df[gt_data_df["page_index"] == pred_page_index]) == 0:
+                            gt_data = gt_data_df.iloc[0]
+                        else:
+                            gt_data = gt_data_df[gt_data_df["page_index"] == pred_page_index].iloc[0]
+                    elif len(gt_data_df) == 1:
+                        gt_data = gt_data_df.iloc[0]
+                    else:
+                        gt_data = None
                 else:
+                    gt_data_df = dp_ground_truth[
+                        dp_ground_truth["simple_raw_name"] == find_raw_name_in_gt[0]
+                    ]
+                    if len(gt_data_df) > 1:
+                        if len(gt_data_df[gt_data_df["page_index"] == pred_page_index]) == 0:
+                            gt_data = gt_data_df.iloc[0]
+                        else:
+                            gt_data = gt_data_df[gt_data_df["page_index"] == pred_page_index].iloc[0]
+                    elif len(gt_data_df) == 1:
+                        gt_data = gt_data_df.iloc[0]
+                    else:
+                        gt_data = None
+                if gt_data is None:
+                    gt_data_point_value = None
+                else:
-                    gt_data = dp_ground_truth[
-                        dp_ground_truth["raw_name"] == find_raw_name_in_gt[0]
-                    ].iloc[0]
                     gt_data_point_value = gt_data["value"]
-                if pred_data_point_value == gt_data_point_value:
+                if gt_data_point_value is not None and \
+                        pred_data_point_value == gt_data_point_value:
                     true_data.append(1)
                     pred_data.append(1)
                 else:
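The tightened matching above adds an endswith condition on top of the old substring containment: a ground-truth candidate must also end with the last token of the prediction's raw name, which keeps different share classes of the same fund from cross-matching. Illustrative names:

    gt_simple_raw_names = ["Global Equity Acc", "Global Equity Inc"]
    pred_simple_raw_name = "Global Equity"
    pred_raw_name = "Alpha Global Equity Acc"

    matches = [gt for gt in gt_simple_raw_names
               if (gt in pred_simple_raw_name or pred_simple_raw_name in gt)
               and gt.endswith(pred_raw_name.split()[-1])]
    print(matches)  # ['Global Equity Acc'] -- "Global Equity Inc" passes containment but fails endswith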
@@ -376,14 +395,16 @@ class Metrics:
         for index, ground_truth in dp_ground_truth.iterrows():
             gt_page_index = ground_truth["page_index"]
             gt_raw_name = ground_truth["raw_name"]
-            gt_unique_words = ground_truth["unique_words"]
+            gt_simple_raw_name = ground_truth["simple_raw_name"]
+            gt_simple_name_unique_words = ground_truth["simple_name_unique_words"]
             gt_data_point_value = ground_truth["value"]
             gt_investment_type = ground_truth["investment_type"]

-            find_raw_name_in_pred = [pred_raw_name for pred_raw_name in pred_raw_names
-                                     if gt_raw_name in pred_raw_name or pred_raw_name in gt_raw_name]
+            find_raw_name_in_pred = [pred_raw_name for pred_raw_name in pred_simple_raw_names
+                                     if (gt_simple_raw_name in pred_raw_name or pred_raw_name in gt_simple_raw_name)
+                                     and pred_raw_name.endswith(gt_raw_name.split()[-1])]

-            if gt_unique_words not in pred_unique_words_list and \
+            if gt_simple_name_unique_words not in pred_simple_name_unique_words_list and \
                     len(find_raw_name_in_pred) == 0:
                 true_data.append(1)
                 pred_data.append(0)
@@ -401,6 +422,26 @@ class Metrics:

         return true_data, pred_data, missing_error_data

+    def modify_data(self, data: pd.DataFrame):
+        data["simple_raw_name"] = ""
+        data["simple_name_unique_words"] = ""
+        page_index_list = data["page_index"].unique().tolist()
+        for page_index in page_index_list:
+            page_data = data[data["page_index"] == page_index]
+            raw_name_list = page_data["raw_name"].unique().tolist()
+            beginning_common_words = get_beginning_common_words(raw_name_list)
+            for raw_name in raw_name_list:
+                if beginning_common_words is not None and len(beginning_common_words) > 0:
+                    simple_raw_name = raw_name.replace(beginning_common_words, "").strip()
+                else:
+                    simple_raw_name = raw_name
+                # set simple_raw_name for rows with the same page and raw_name
+                data.loc[(data["page_index"] == page_index) & (data["raw_name"] == raw_name),
+                         "simple_raw_name"] = simple_raw_name
+                data.loc[(data["page_index"] == page_index) & (data["raw_name"] == raw_name),
+                         "simple_name_unique_words"] = get_unique_words_text(simple_raw_name)
+        return data
+
     def get_specific_metrics(self, true_data: list, pred_data: list):
         precision = precision_score(true_data, pred_data)
         recall = recall_score(true_data, pred_data)
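modify_data normalizes fund names page by page: it strips the words that every raw_name on a page begins with, so later comparisons only see the distinguishing tail (typically the share class). With the helper added below in utils/biz_utils.py (illustrative names):

    raw_names = ["Alpha Global Fund Class A", "Alpha Global Fund Class I"]
    # get_beginning_common_words(raw_names) -> "Alpha Global Fund Class"
    # simple_raw_name therefore becomes "A" and "I", and simple_name_unique_words
    # is recomputed from those stripped names via get_unique_words_text.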
36 main.py
@@ -523,7 +523,7 @@ def test_auto_generate_instructions():
 def test_data_extraction_metrics():
     data_type = "data_extraction"
+    prediction_file = r"/data/emea_ar/output/mapping_data/total/mapping_data_info_88_documents_20240919120502.xlsx"
     # prediction_file = r"/data/emea_ar/output/mapping_data/docs/by_text/excel/321733631.xlsx"
     # prediction_file = r"/data/emea_ar/output/mapping_data/docs/by_text/excel/509350496.xlsx"
     prediction_sheet_name = "mapping_data"
     ground_truth_file = r"/data/emea_ar/ground_truth/data_extraction/mapping_data_info_73_documents.xlsx"
     ground_truth_sheet_name = "mapping_data"
@@ -577,26 +577,26 @@ if __name__ == "__main__":
     #     extract_way,
     #     re_run_extract_data)

-    special_doc_id_list = ["476492237"]
+    special_doc_id_list = []
     output_mapping_child_folder = r"/data/emea_ar/output/mapping_data/docs/"
     output_mapping_total_folder = r"/data/emea_ar/output/mapping_data/total/"
     re_run_mapping_data = True
     force_save_total_data = False

-    extract_ways = ["text"]
-    # for extract_way in extract_ways:
-    #     batch_start_job(
-    #         pdf_folder,
-    #         page_filter_ground_truth_file,
-    #         output_extract_data_child_folder,
-    #         output_mapping_child_folder,
-    #         output_extract_data_total_folder,
-    #         output_mapping_total_folder,
-    #         extract_way,
-    #         special_doc_id_list,
-    #         re_run_extract_data,
-    #         re_run_mapping_data,
-    #         force_save_total_data=force_save_total_data,
-    #     )
+    extract_ways = ["text", "image"]
+    for extract_way in extract_ways:
+        batch_start_job(
+            pdf_folder,
+            page_filter_ground_truth_file,
+            output_extract_data_child_folder,
+            output_mapping_child_folder,
+            output_extract_data_total_folder,
+            output_mapping_total_folder,
+            extract_way,
+            special_doc_id_list,
+            re_run_extract_data,
+            re_run_mapping_data,
+            force_save_total_data=force_save_total_data,
+        )

-    test_data_extraction_metrics()
+    # test_data_extraction_metrics()
utils/biz_utils.py
@@ -205,6 +205,31 @@ def get_jacard_similarity(text_left,
     else:
         return 0

+def get_beginning_common_words(text_list: list):
+    """
+    Get the beginning common words in text_list
+    """
+    if text_list is None or len(text_list) < 2:
+        return ""
+
+    common_words_list = []
+    first_text_split = text_list[0].split()
+    for w_i, word in enumerate(first_text_split):
+        all_same = True
+        for text in text_list[1:]:
+            text_split = text.split()
+            if w_i >= len(text_split):
+                all_same = False
+                break
+            if text_split[w_i] != word:
+                all_same = False
+                break
+        if all_same:
+            common_words_list.append(word)
+        else:
+            break
+
+    return ' '.join(common_words_list).strip()
+
 def replace_abbrevation(text: str):
     if text is None or len(text.strip()) == 0:
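A short usage sketch for the new helper:

    print(get_beginning_common_words(["Alpha Fund A Acc", "Alpha Fund B Inc"]))
    # -> "Alpha Fund"
    print(get_beginning_common_words(["Alpha Fund A Acc"]))
    # -> "" (fewer than two names, so there is nothing to compare)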