# crypto_quant/core/media/truth_social_retriever.py

import json
import os
import re
import time
from datetime import datetime
from typing import Optional

import dashscope
import pandas as pd
import pytz
import requests
from bs4 import BeautifulSoup

import core.logger as logging
from config import ALI_API_KEY, COIN_MYSQL_CONFIG, TRUTH_SOCIAL_API, WECHAT_CONFIG
from core.db.db_truth_social_content import DBTruthSocialContent
from core.wechat import Wechat

logger = logging.logger


class TruthSocialRetriever:
    def __init__(self) -> None:
        self.api_key = TRUTH_SOCIAL_API.get("api_key", "")
        self.user_info = TRUTH_SOCIAL_API.get("user_id", {})
        mysql_user = COIN_MYSQL_CONFIG.get("user", "xch")
        mysql_password = COIN_MYSQL_CONFIG.get("password", "")
        if not mysql_password:
            raise ValueError("MySQL password is not set")
        mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost")
        mysql_port = COIN_MYSQL_CONFIG.get("port", 3306)
        mysql_database = COIN_MYSQL_CONFIG.get("database", "okx")
        self.db_url = (
            f"mysql+pymysql://{mysql_user}:{mysql_password}"
            f"@{mysql_host}:{mysql_port}/{mysql_database}"
        )
        self.db_truth_social_content = DBTruthSocialContent(self.db_url)
        trump_key = WECHAT_CONFIG.get("trump_key", "")
        self.wechat = Wechat(trump_key) if trump_key else None
        self.save_path = "./output/media/truth_social/"
        os.makedirs(self.save_path, exist_ok=True)
        self.ali_api_key = ALI_API_KEY
        instruction_file = "./instructions/media_article_instructions.json"
        with open(instruction_file, "r", encoding="utf-8") as f:
            self.instruction = json.load(f)
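
    # Illustrative config shape (an assumption, not taken from the repo):
    # TRUTH_SOCIAL_API = {
    #     "api_key": "<scrapecreators api key>",
    #     "user_id": {"realDonaldTrump": "<numeric user id>"},
    # }
    # get_user_posts iterates this {handle: user_id} mapping.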

    def get_user_id_from_page(self, handle="realDonaldTrump"):
        """Scrape a profile page and try to extract the numeric user id."""
        url = f"https://truthsocial.com/@{handle}"
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
        }  # mimic a browser
        response = requests.get(url, headers=headers, timeout=30)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, "html.parser")
        # Look for embedded JSON (Truth Social embeds it in data attributes
        # or script tags).
        scripts = soup.find_all("script")
        for script in scripts:
            if script.string and "id" in script.string and handle in script.string:
                # Naive extraction; a stricter JSON parse may be needed in practice.
                match = re.search(r'"id"\s*:\s*"(\d+)"', script.string)
                if match:
                    return match.group(1)
        return None
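
    # Example (hypothetical session; requires network access):
    #   retriever = TruthSocialRetriever()
    #   uid = retriever.get_user_id_from_page("realDonaldTrump")  # "123..." or None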

    def get_user_posts(self, limit: Optional[int] = None):
        """
        Fetch the latest Truth Social posts for each configured user.

        API pricing (ScrapeCreators):
        - Free tier: 100 requests.
        - $47: 25,000 requests (about 86.8 days at one request every 5 minutes).
        - $497: 500,000 requests (about 1,736 days at one request every 5 minutes).

        Args:
            limit: maximum number of posts to keep per user (the API returns
                20 per page by default; more can be fetched via pagination).

        Note:
            Returns nothing; new posts are pushed to WeChat Work and inserted
            into MySQL.
        """
headers = {"x-api-key": self.api_key, "Content-Type": "application/json"}
for user_name, user_id in self.user_info.items():
params = {
"handle": user_name, # 用户名
"user_id": user_id, # 可选,用户 ID
"next_max_id": None, # 分页时设置为上一次响应的 max_id
"trim": "false", # 保留完整内容
}
url = "https://api.scrapecreators.com/v1/truthsocial/user/posts"
logger.info(f"Searching contents for user: {user_name}")
try:
response = requests.get(url, headers=headers, params=params)
response.raise_for_status() # 检查 HTTP 错误
data = response.json()
# 提取帖子列表(假设响应中 'posts' 是键,根据实际文档调整)
if limit is not None and isinstance(limit, int):
posts = data.get("posts", [])[:limit]
else:
posts = data.get("posts", [])
results = []
if posts:
logger.info(f"获取{user_name}帖子: {len(posts)}")
for post in posts:
result = {}
result["article_id"] = post.get("id")
result["user_id"] = user_id
result["user_name"] = user_name
datetime_text = post.get("created_at")
datetime_dict = self.transform_datetime(datetime_text)
timestamp_ms = datetime_dict["timestamp_ms"]
result["timestamp"] = timestamp_ms
beijing_time_str = datetime_dict["beijing_time_str"]
result["date_time"] = beijing_time_str
result["text"] = post.get("text", "无内容")
media_attachments = post.get("media_attachments", [])
result["media_url"] = ""
result["media_type"] = ""
result["media_thumbnail"] = ""
if media_attachments:
for media_attachment in media_attachments:
result["media_url"] = media_attachment.get("url")
result["media_type"] = media_attachment.get("type")
result["media_thumbnail"] = media_attachment.get(
"preview_url"
)
break
results.append(result)
else:
print("获取帖子失败,请检查 API 密钥或网络。")
                if len(results) > 0:
                    # Previously the raw results were also dumped to a JSON file:
                    # user_path = os.path.join(self.save_path, user_name)
                    # os.makedirs(user_path, exist_ok=True)
                    # now_date_time = datetime.now().strftime("%Y%m%d%H%M%S")
                    # json_file_name = os.path.join(user_path, f"{user_name}_{now_date_time}.json")
                    # with open(json_file_name, "w", encoding="utf-8") as f:
                    #     json.dump(results, f, ensure_ascii=False, indent=2)
                    # logger.info(f"Saved {len(results)} records to: {json_file_name}")
                    result_df = pd.DataFrame(results)
                    result_df = self.remove_duplicate_posts(result_df)
                    result_df["analysis_result"] = ""
                    result_df["analysis_token"] = 0
                    if len(result_df) > 0:
                        result_df = self.send_wechat_message(result_df)
                        result_df = result_df[
                            [
                                "article_id",
                                "user_id",
                                "user_name",
                                "timestamp",
                                "date_time",
                                "text",
                                "analysis_result",
                                "analysis_token",
                                "media_url",
                                "media_type",
                                "media_thumbnail",
                            ]
                        ]
                        self.db_truth_social_content.insert_data_to_mysql(result_df)
                        logger.info(f"Inserted {len(result_df)} rows into the database")
                    else:
                        logger.info("Nothing to insert or send to WeChat Work")
            except requests.exceptions.RequestException as e:
                logger.error(f"Request error: {e}")
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error: {e}")

    def send_message_by_json_file(self, json_file_name: str):
        with open(json_file_name, "r", encoding="utf-8") as f:
            results = json.load(f)
        result_df = pd.DataFrame(results)
        result_df = self.remove_duplicate_posts(result_df)
        if len(result_df) > 0:
            self.send_wechat_message(result_df)
        else:
            logger.info("Nothing to send to WeChat Work")

    def remove_duplicate_posts(self, result_df: pd.DataFrame):
        try:
            duplicate_index_list = []
            for index, row in result_df.iterrows():
                article_id = row["article_id"]
                exist_data = self.db_truth_social_content.query_data_by_article_id(
                    article_id
                )
                if exist_data:
                    duplicate_index_list.append(index)
            # Drop rows that already exist in the database.
            result_df = result_df.drop(duplicate_index_list)
            result_df.sort_values(by="timestamp", ascending=True, inplace=True)
            result_df.reset_index(drop=True, inplace=True)
            logger.info(f"{len(result_df)} rows remain after dropping duplicates")
        except Exception as e:
            result_df = pd.DataFrame([])
            logger.error(f"Failed to drop duplicate rows: {e}")
        return result_df
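
    # Design note: the dedupe above issues one DB query per row, which is fine
    # for the ~20 posts a single API page returns. A bulk lookup (e.g. one
    # "article_id IN (...)" query, if DBTruthSocialContent ever grows such a
    # helper) would cut round-trips for larger backfills.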

    def send_wechat_message(self, result_df: pd.DataFrame):
        if self.wechat is None:
            logger.error("WeChat Work client is not initialized")
            return result_df
        for index, row in result_df.iterrows():
            try:
                date_time = row["date_time"]
                text = row["text"]
                media_thumbnail = row["media_thumbnail"]
                if media_thumbnail and len(media_thumbnail) > 0:
                    # Posts with media are forwarded as an image only.
                    self.wechat.send_image(media_thumbnail)
                else:
                    contents = []
                    contents.append("## Trump post")
                    contents.append(text)
                    contents.append("## Posted at")
                    contents.append(date_time)
                    mark_down_text = "\n\n".join(contents)
                    analysis_result, analysis_token = self.analyze_truth_social_content(
                        text
                    )
                    result_df.at[index, "analysis_result"] = analysis_result
                    result_df.at[index, "analysis_token"] = analysis_token
                    analysis_text = f"\n\n## Analysis\n\n{analysis_result}"
                    analysis_text += f"\n\n## Analysis tokens\n\n{analysis_token}"
                    # WeChat Work markdown messages are capped at 4096 bytes.
                    if self.calculate_bytes(mark_down_text + analysis_text) > 4096:
                        self.wechat.send_markdown(mark_down_text)
                        if self.calculate_bytes(analysis_text) > 4096:
                            # Split roughly in half by character count; this only
                            # approximates the byte limit for multi-byte text.
                            half = len(analysis_text) // 2
                            analysis_1st = analysis_text[:half].strip()
                            analysis_2nd = analysis_text[half:].strip()
                            self.wechat.send_markdown(
                                f"## Analysis, part 1\n\n{analysis_1st}"
                            )
                            self.wechat.send_markdown(
                                f"## Analysis, part 2\n\n{analysis_2nd}"
                            )
                        else:
                            self.wechat.send_markdown(f"## Analysis\n\n{analysis_text}")
                    else:
                        self.wechat.send_markdown(mark_down_text + analysis_text)
            except Exception as e:
                logger.error(f"Failed to send WeChat Work message: {e}")
                continue
        return result_df

    def calculate_bytes(self, text: str):
        return len(text.encode("utf-8"))
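
    # e.g. calculate_bytes("abc") == 3 but calculate_bytes("分析") == 6:
    # CJK characters take 3 bytes in UTF-8, so the 4096-byte WeChat limit
    # is reached well before len() suggests.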

    def analyze_truth_social_content(self, text: str):
        try:
            context = text
            instructions = self.instruction.get("Instructions", "")
            output = self.instruction.get("Output", "")
            prompt = (
                f"# Context\n\n{context}\n\n"
                f"# Instructions\n\n{instructions}\n\n"
                f"# Output\n\n{output}"
            )
            response = dashscope.Generation.call(
                api_key=self.ali_api_key,
                model="qwen-plus",
                messages=[{"role": "user", "content": prompt}],
                enable_search=True,
                search_options={"forced_search": True},  # force web search
                result_format="message",
            )
            response_contents = (
                response.get("output", {})
                .get("choices", [])[0]
                .get("message", {})
                .get("content", "")
            )
            # Total token usage for this call.
            token = response.get("usage", {}).get("total_tokens", 0)
            return response_contents, token
        except Exception as e:
            logger.error(f"Failed to analyze post: {e}")
            # Return the same shape as the success path so callers can unpack.
            return "", 0

    def transform_datetime(self, datetime_text: str):
        utc_time = datetime.strptime(datetime_text, "%Y-%m-%dT%H:%M:%S.%fZ").replace(
            tzinfo=pytz.UTC
        )
        # 1. Unix timestamp in milliseconds.
        timestamp_ms = int(utc_time.timestamp() * 1000)
        # 2. Beijing time in ISO 8601 format with a +08:00 offset.
        beijing_tz = pytz.timezone("Asia/Shanghai")
        beijing_time = utc_time.astimezone(beijing_tz)
        beijing_time_str = beijing_time.strftime("%Y-%m-%dT%H:%M:%S%z")
        # Insert a colon into the offset (e.g. +0800 -> +08:00).
        beijing_time_str = beijing_time_str[:-2] + ":" + beijing_time_str[-2:]
        return {"timestamp_ms": timestamp_ms, "beijing_time_str": beijing_time_str}
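

# Minimal usage sketch (an assumption about how this module is driven; running
# it makes real API calls and requires valid credentials in config.py):
if __name__ == "__main__":
    retriever = TruthSocialRetriever()
    retriever.get_user_posts(limit=20)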