import core.logger as logging
from core.db.db_truth_social_content import DBTruthSocialContent
from config import TRUTH_SOCIAL_API, COIN_MYSQL_CONFIG, WECHAT_CONFIG, ALI_API_KEY
from core.wechat import Wechat

import requests
import json
import os
from bs4 import BeautifulSoup
import time
from datetime import datetime
import pytz
import pandas as pd
import dashscope

logger = logging.logger


class TruthSocialRetriever:
    def __init__(self) -> None:
        self.api_key = TRUTH_SOCIAL_API.get("api_key", "")
        self.media_config_list = TRUTH_SOCIAL_API.get("media_config", [])
        # self.user_info = TRUTH_SOCIAL_API.get("user_id", {})
        mysql_user = COIN_MYSQL_CONFIG.get("user", "xch")
        mysql_password = COIN_MYSQL_CONFIG.get("password", "")
        if not mysql_password:
            raise ValueError("MySQL password is not set")
        mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost")
        mysql_port = COIN_MYSQL_CONFIG.get("port", 3306)
        mysql_database = COIN_MYSQL_CONFIG.get("database", "okx")

        self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
        self.db_truth_social_content = DBTruthSocialContent(self.db_url)

        trump_key = WECHAT_CONFIG.get("trump_key", "")
        if trump_key:
            self.wechat = Wechat(trump_key)
        else:
            self.wechat = None

        self.save_path = r"./output/media/truth_social/"
        os.makedirs(self.save_path, exist_ok=True)

        self.ali_api_key = ALI_API_KEY
        text_instruction_file = r"./instructions/media_article_instructions.json"
        with open(text_instruction_file, "r", encoding="utf-8") as f:
            self.text_instruction = json.load(f)

        image_instruction_file = r"./instructions/media_image_instructions.json"
        with open(image_instruction_file, "r", encoding="utf-8") as f:
            self.image_instruction = json.load(f)

        image_post_instruction_file = r"./instructions/media_image_post_instructions.json"
        with open(image_post_instruction_file, "r", encoding="utf-8") as f:
            self.image_post_instruction = json.load(f)

        text_image_post_instruction_file = r"./instructions/media_article_image_post_instructions.json"
        with open(text_image_post_instruction_file, "r", encoding="utf-8") as f:
            self.text_image_post_instruction = json.load(f)

    def get_user_id_from_page(self, handle="realDonaldTrump"):
        url = f"https://truthsocial.com/@{handle}"
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
        }  # pretend to be a regular browser

        response = requests.get(url, headers=headers)
        response.raise_for_status()

        soup = BeautifulSoup(response.text, "html.parser")
        # Look for the embedded JSON (Truth Social uses data attributes or script tags)
        scripts = soup.find_all("script")
        for script in scripts:
            if script.string and "id" in script.string and handle in script.string:
                # Simple extraction (matching the JSON may require a regex)
                import re

                match = re.search(r'"id"\s*:\s*"(\d+)"', script.string)
                if match:
                    return match.group(1)
        return None

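    # Usage note (illustrative, requires network access): this helper can resolve a
    # handle (the default is "realDonaldTrump") to the numeric account ID that the
    # "user_info" entries under TRUTH_SOCIAL_API["media_config"] presumably expect:
    #   retriever = TruthSocialRetriever()
    #   user_id = retriever.get_user_id_from_page("realDonaldTrump")
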
    def get_user_posts(self, limit: int = None):
        """
        Fetch a user's latest posts from Truth Social.
        Free tier: 100 requests.
        Paid tiers:
            $47: 25,000 requests (about 86.8 days at one run every 5 minutes)
            $497: 500,000 requests (about 1736 days at one run every 5 minutes)
        Parameters:
        - limit: maximum number of posts (the API returns 20 by default; more can be fetched via pagination).

        Returns:
        - the list of posts (JSON).
        """
        headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}

        for media_config in self.media_config_list:
            media_name = media_config.get("media_name", "")
            logger.info(f"Fetching posts from {media_name}")
            base_url = media_config.get("base_url", "")
            user_info = media_config.get("user_info", {})
            for user_name, user_details in user_info.items():
                user_id = user_details.get("id", "")
                user_full_name = user_details.get("full_name", "")

                params = {
                    "handle": user_name,  # user handle
                    "user_id": user_id,  # optional user ID
                    "next_max_id": None,  # set to the previous response's max_id when paginating
                    "trim": "false",  # keep full content
                }

                logger.info(f"Searching contents for user: {user_name}")
                try:
                    response = requests.get(base_url, headers=headers, params=params)
                    response.raise_for_status()  # raise on HTTP errors
                    data = response.json()

                    # Extract the post list (assumes a 'posts' key; adjust to the actual API docs)
                    if limit is not None and isinstance(limit, int):
                        posts = data.get("posts", [])[:limit]
                    else:
                        posts = data.get("posts", [])

                    if len(posts) == 1:
                        # Only one post returned: try one more page
                        try:
                            max_id = posts[0].get("id")
                            params["next_max_id"] = max_id
                            response = requests.get(base_url, headers=headers, params=params)
                            response.raise_for_status()  # raise on HTTP errors
                            data = response.json()
                            posts.extend(data.get("posts", []))
                        except Exception as e:
                            logger.error(f"Failed to fetch the next page of posts: {e}")

                    results = []
                    if posts:
                        logger.info(f"Fetched {len(posts)} posts for {user_name}")
                        for post in posts:
                            result = {}
                            result["article_id"] = post.get("id")
                            result["user_id"] = user_id
                            result["user_name"] = user_name
                            datetime_text = post.get("created_at")
                            datetime_dict = self.transform_datetime(datetime_text)
                            timestamp_ms = datetime_dict["timestamp_ms"]
                            result["timestamp"] = timestamp_ms
                            beijing_time_str = datetime_dict["beijing_time_str"]
                            result["date_time"] = beijing_time_str
                            result["text"] = post.get("text", "")
                            media_attachments = post.get("media_attachments", [])
                            result["media_url"] = ""
                            result["media_type"] = ""
                            result["media_thumbnail"] = ""
                            if media_attachments:
                                # Keep only the first attachment
                                for media_attachment in media_attachments:
                                    result["media_url"] = media_attachment.get("url")
                                    result["media_type"] = media_attachment.get("type")
                                    result["media_thumbnail"] = media_attachment.get(
                                        "preview_url"
                                    )
                                    break
                            results.append(result)
                    else:
                        logger.warning("Failed to fetch posts; check the API key or network.")

                    if len(results) > 0:
                        result_df = pd.DataFrame(results)
                        result_df = self.remove_duplicate_posts(result_df)

                        if len(result_df) > 0:
                            result_df["analysis_result"] = ""
                            result_df["analysis_token"] = 0
                            result_df = self.send_wechat_message(result_df, user_full_name)
                            result_df = result_df[
                                [
                                    "article_id",
                                    "user_id",
                                    "user_name",
                                    "timestamp",
                                    "date_time",
                                    "text",
                                    "analysis_result",
                                    "analysis_token",
                                    "media_url",
                                    "media_type",
                                    "media_thumbnail",
                                ]
                            ]
                            self.db_truth_social_content.insert_data_to_mysql(result_df)
                            logger.info(f"Inserted {len(result_df)} rows into the database")
                        else:
                            logger.info("No data to insert into the database or send to WeChat Work")
                except requests.exceptions.RequestException as e:
                    logger.error(f"Request error: {e}")
                except json.JSONDecodeError as e:
                    logger.error(f"JSON decode error: {e}")

    def send_message_by_json_file(self, json_file_name: str):
        with open(json_file_name, "r", encoding="utf-8") as f:
            results = json.load(f)
        result_df = pd.DataFrame(results)
        result_df = self.remove_duplicate_posts(result_df)
        if len(result_df) > 0:
            # user_full_name falls back to the empty-string default of send_wechat_message
            self.send_wechat_message(result_df)
        else:
            logger.info("No data to send to WeChat Work")

    def remove_duplicate_posts(self, result_df: pd.DataFrame):
        try:
            duplicate_index_list = []
            for index, row in result_df.iterrows():
                article_id = row["article_id"]
                exist_data = self.db_truth_social_content.query_data_by_article_id(
                    article_id
                )
                if exist_data:
                    duplicate_index_list.append(index)
            # Drop rows whose article_id already exists in the database
            result_df = result_df.drop(duplicate_index_list)
            result_df.sort_values(by="timestamp", ascending=True, inplace=True)
            result_df.reset_index(drop=True, inplace=True)
            logger.info(f"{len(result_df)} rows remain after removing duplicates")
        except Exception as e:
            result_df = pd.DataFrame([])
            logger.error(f"Failed to remove duplicate rows: {e}")
        return result_df

    def send_wechat_message(self, result_df: pd.DataFrame, user_full_name: str = ""):
        if self.wechat is None:
            logger.error("WeChat Work client is not initialized")
            return result_df
        for index, row in result_df.iterrows():
            try:
                date_time = row["date_time"]
                text = row["text"]
                media_thumbnail = row["media_thumbnail"]
                if len(text) > 0:
                    if media_thumbnail and len(media_thumbnail) > 0:
                        # Post with both text and an image
                        contents = []
                        contents.append(f"## {user_full_name} post")
                        contents.append(text)
                        contents.append("## Post time")
                        contents.append(date_time)
                        mark_down_text = "\n\n".join(contents)
                        self.wechat.send_markdown(mark_down_text)
                        response, image_path, base64_str, md5_str = self.wechat.send_image(media_thumbnail)
                        image_format = "jpg"
                        if image_path is not None and len(image_path) > 0:
                            image_format = image_path.split(".")[-1]
                            if image_format == "jpeg":
                                image_format = "jpg"
                        analysis_result, analysis_token = self.analyze_truth_social_content(
                            text=mark_down_text,
                            image_stream=base64_str,
                            image_format=image_format,
                            media_type="hybrid",
                            user_full_name=user_full_name,
                        )
                        if analysis_result is not None and len(analysis_result) > 0:
                            result_df.at[index, "analysis_result"] = analysis_result
                            result_df.at[index, "analysis_token"] = analysis_token
                        else:
                            result_df.at[index, "analysis_result"] = ""
                            result_df.at[index, "analysis_token"] = 0
                        analysis_text = f"\n\n## Analysis of the text and image above\n\n{analysis_result}"
                        analysis_text += f"\n\n## Tokens used for the analysis above\n\n{analysis_token}"
                        self.wechat.send_markdown(analysis_text)
                    else:
                        # Text-only post
                        contents = []
                        contents.append(f"## {user_full_name} post")
                        contents.append(text)
                        contents.append("## Post time")
                        contents.append(date_time)
                        mark_down_text = "\n\n".join(contents)
                        analysis_result, analysis_token = self.analyze_truth_social_content(
                            text=text,
                            image_stream=None,
                            image_format=None,
                            media_type="text",
                            user_full_name=user_full_name,
                        )
                        result_df.at[index, "analysis_result"] = analysis_result
                        result_df.at[index, "analysis_token"] = analysis_token
                        analysis_text = f"\n\n## Analysis result\n\n{analysis_result}"
                        analysis_text += f"\n\n## Analysis tokens\n\n{analysis_token}"
                        # WeChat Work markdown messages are capped at 4096 bytes, so long content is split
                        if self.calculate_bytes(mark_down_text + analysis_text) > 4096:
                            self.wechat.send_markdown(mark_down_text)
                            if self.calculate_bytes(analysis_text) > 4096:
                                half_analysis_text_length = len(analysis_text) // 2
                                analysis_1st = analysis_text[:half_analysis_text_length].strip()
                                analysis_2nd = analysis_text[half_analysis_text_length:].strip()
                                self.wechat.send_markdown(
                                    f"## Analysis result, part 1\n\n{analysis_1st}"
                                )
                                self.wechat.send_markdown(
                                    f"## Analysis result, part 2\n\n{analysis_2nd}"
                                )
                            else:
                                self.wechat.send_markdown(f"## Analysis result\n\n{analysis_text}")
                        else:
                            self.wechat.send_markdown(mark_down_text + analysis_text)
                elif media_thumbnail and len(media_thumbnail) > 0:
                    # Image-only post
                    response, image_path, base64_str, md5_str = self.wechat.send_image(media_thumbnail)
                    image_format = "jpg"
                    if image_path is not None and len(image_path) > 0:
                        image_format = image_path.split(".")[-1]
                        if image_format == "jpeg":
                            image_format = "jpg"
                    analysis_result, analysis_token = self.analyze_truth_social_content(
                        text="",
                        image_stream=base64_str,
                        image_format=image_format,
                        media_type="image",
                        user_full_name=user_full_name,
                    )
                    if analysis_result is not None and len(analysis_result) > 0:
                        result_df.at[index, "analysis_result"] = analysis_result
                        result_df.at[index, "analysis_token"] = analysis_token
                    else:
                        result_df.at[index, "analysis_result"] = ""
                        result_df.at[index, "analysis_token"] = 0
                    analysis_text = f"\n\n## Analysis of the image above\n\n{analysis_result}"
                    analysis_text += f"\n\n## Tokens used for the image analysis above\n\n{analysis_token}"
                    self.wechat.send_markdown(analysis_text)
                else:
                    continue
            except Exception as e:
                logger.error(f"Failed to send the WeChat Work message: {e}")
                continue
        return result_df

    def calculate_bytes(self, text: str):
        return len(text.encode("utf-8"))

    def analyze_truth_social_content(self, text: str, image_stream: str, image_format: str, media_type: str, user_full_name: str):
        try:
            token = 0
            if text is None:
                text = ""
            image_text = ""
            if media_type in ["image", "hybrid"]:
                if image_stream is None or len(image_stream) == 0:
                    return "", 0
                instructions = self.image_instruction.get("Instructions", "")
                output = self.image_instruction.get("Output", "")
                prompt = f"# Instructions\n\n{instructions}\n\n# Output\n\n{output}"
                messages_local = [
                    {
                        "role": "user",
                        "content": [
                            {"image": f"data:image/{image_format};base64,{image_stream}"},  # base64-encoded image
                            {"text": prompt},  # image-analysis prompt
                        ],
                    }
                ]
                response = dashscope.MultiModalConversation.call(
                    api_key=self.ali_api_key,
                    model='qwen-vl-plus',
                    messages=messages_local,
                )
                if response.status_code == 200:
                    image_text = (
                        response.get("output", {})
                        .get("choices", [])[0]
                        .get("message", {})
                        .get("content", "")
                    )
                    temp_image_text = ""
                    if isinstance(image_text, list):
                        for item in image_text:
                            if isinstance(item, dict):
                                temp_image_text += item.get("text", "") + "\n\n"
                            elif isinstance(item, str):
                                temp_image_text += item + "\n\n"
                            else:
                                pass
                        image_text = temp_image_text.strip()
                    token = response.get("usage", {}).get("total_tokens", 0)
                else:
                    text = f"{response.code} {response.message} Unable to analyze the image"
                    token = 0

                text += image_text

            context = text
            if media_type == "text":
                instructions = self.text_instruction.get("Instructions", "").format(user_full_name)
                output = self.text_instruction.get("Output", "")
                prompt = f"# Context\n\n{context}\n\n# Instructions\n\n{instructions}\n\n# Output\n\n{output}"
            elif media_type == "image":
                instructions = self.image_post_instruction.get("Instructions", "").format(user_full_name)
                output = self.image_post_instruction.get("Output", "")
                prompt = f"# Context\n\n{context}\n\n# Instructions\n\n{instructions}\n\n# Output\n\n{output}"
            elif media_type == "hybrid":
                instructions = self.text_image_post_instruction.get("Instructions", "").format(user_full_name)
                output = self.text_image_post_instruction.get("Output", "").format(user_full_name)
                prompt = f"# Context\n\n{context}\n\n# Instructions\n\n{instructions}\n\n# Output\n\n{output}"
            else:
                # Fallback so that prompt is always defined for unexpected media_type values
                prompt = f"# Context\n\n{context}"
            response = dashscope.Generation.call(
                api_key=self.ali_api_key,
                model="qwen-plus",
                messages=[{"role": "user", "content": prompt}],
                enable_search=True,
                search_options={"forced_search": True},  # force a web search
                result_format="message",
            )

            # Collect the response text and token usage
            if response.status_code == 200:
                response_contents = (
                    response.get("output", {})
                    .get("choices", [])[0]
                    .get("message", {})
                    .get("content", "")
                )
                token += response.get("usage", {}).get("total_tokens", 0)
            else:
                response_contents = f"{response.code} {response.message}"
                token = 0
            return response_contents, token
        except Exception as e:
            logger.error(f"Failed to analyze the post: {e}")
            return "", 0

    def transform_datetime(self, datetime_text: str):
        utc_time = datetime.strptime(datetime_text, "%Y-%m-%dT%H:%M:%S.%fZ").replace(
            tzinfo=pytz.UTC
        )

        # 1. Convert to a timestamp in milliseconds
        timestamp_ms = int(utc_time.timestamp() * 1000)
        # 2. Convert to Beijing time (ISO 8601 format with a +08:00 offset)
        beijing_tz = pytz.timezone("Asia/Shanghai")
        beijing_time = utc_time.astimezone(beijing_tz)
        beijing_time_str = beijing_time.strftime("%Y-%m-%dT%H:%M:%S%z")
        # Insert a colon into the timezone offset (e.g. +0800 -> +08:00)
        beijing_time_str = beijing_time_str[:-2] + ":" + beijing_time_str[-2:]
        result = {"timestamp_ms": timestamp_ms, "beijing_time_str": beijing_time_str}
        return result
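

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal way to run the retriever on a schedule, assuming config.py provides
# valid TRUTH_SOCIAL_API, COIN_MYSQL_CONFIG, WECHAT_CONFIG and ALI_API_KEY values.
# The 300-second interval matches the "one run every 5 minutes" budget mentioned
# in the get_user_posts docstring; adjust it to your API quota.
if __name__ == "__main__":
    retriever = TruthSocialRetriever()
    # Sanity-check the timestamp conversion: 2025-01-01T12:00:00Z is
    # 2025-01-01T20:00:00+08:00 Beijing time, i.e. 1735732800000 ms since the epoch.
    logger.info(retriever.transform_datetime("2025-01-01T12:00:00.000Z"))
    while True:
        try:
            retriever.get_user_posts()
        except Exception as e:
            logger.error(f"Scheduled run failed: {e}")
        time.sleep(300)  # poll every 5 minutes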