support SAR

This commit is contained in:
blade 2025-09-02 12:44:34 +08:00
parent e990db26a6
commit 6ee64abaf5
9 changed files with 296 additions and 104 deletions

View File

@ -73,7 +73,7 @@ OKX_MONITOR_CONFIG = {
US_STOCK_MONITOR_CONFIG = { US_STOCK_MONITOR_CONFIG = {
"volume_monitor":{ "volume_monitor":{
"symbols": ["QQQ", "TQQQ", "MSFT", "AAPL", "GOOG", "NVDA", "META", "AMZN", "TSLA", "AVGO"], "symbols": ["QQQ", "TQQQ", "MSFT", "AAPL", "GOOG", "NVDA", "META", "AMZN", "TSLA", "AVGO"],
"bars": ["5", "15m", "30m", "1H"], "bars": ["5m", "15m", "30m", "1H"],
"initial_date": "2015-08-31 00:00:00" "initial_date": "2015-08-31 00:00:00"
} }
} }

View File

@ -137,6 +137,64 @@ class MetricsCalculation:
] = "死叉" ] = "死叉"
return df return df
def sar(self, df: pd.DataFrame, acceleration=0.02, maximum=0.2):
"""
计算SAR抛物线转向指标
Args:
df: 包含high, low, close列的DataFrame
acceleration: 加速因子默认0.02控制SAR值随价格变化的加速程度
maximum: 最大加速因子默认0.2设定加速因子的上限防止过度增加
参数说明
- acceleration=0.02: 标准设置适合大多数市场
- maximum=0.2: 标准设置防止SAR过度敏感
- 对于高波动性市场如加密货币可适当增加acceleration到0.03-0.04
- 对于低波动性市场可降低acceleration到0.015-0.02
"""
logger.info(f"计算SAR指标acceleration={acceleration}, maximum={maximum}")
# 初始化sar和sar_signal列
df["sar"] = np.nan
df["sar_signal"] = ""
df["sar"] = tb.SAR(
df["high"], df["low"], acceleration=acceleration, maximum=maximum
)
# sar_position = df["sar"] > df["close"]
# df.loc[
# sar_position[
# (sar_position == True) & (sar_position.shift() == False)
# ].index,
# "sar_signal",
# ] = "SAR多头"
# df.loc[
# sar_position[
# (sar_position == False) & (sar_position.shift() == True)
# ].index,
# "sar_signal",
# ] = "SAR空头"
# df.loc[sar_position[sar_position == False].index, "sar_signal"] = "SAR观望"
# 生成交易信号
# SAR多头: SAR < close
# SAR空头: SAR > close
# SAR观望: SAR == close 或 SAR为NaN
df["sar_signal"] = np.where(
df["sar"].isna(),
"SAR观望",
np.where(
df["sar"] < df["close"],
"SAR多头",
np.where(df["sar"] > df["close"], "SAR空头", "SAR观望"),
),
)
# 确保sar列为float类型
df["sar"] = df["sar"].astype(float)
# 确保sar_signal列为str类型
df["sar_signal"] = df["sar_signal"].astype(str)
return df
def set_kdj_pattern(self, df: pd.DataFrame): def set_kdj_pattern(self, df: pd.DataFrame):
""" """
设置每一根K线数据对应的KDJ形态超买超卖情况 设置每一根K线数据对应的KDJ形态超买超卖情况
@ -209,7 +267,7 @@ class MetricsCalculation:
# 震荡条件已经在初始化时设置,无需额外处理 # 震荡条件已经在初始化时设置,无需额外处理
data["ma_long_short"] = "震荡" data["ma_long_short"] = "震荡"
data = self._trend_strength_method(data) data = self._trend_strength_method(data)
# 计算各均线偏离度的标准差和均值 # 计算各均线偏离度的标准差和均值
data["ma_divergence"] = "未知" data["ma_divergence"] = "未知"
ma_diffs = data[ ma_diffs = data[
@ -421,7 +479,7 @@ class MetricsCalculation:
df["ma30"] = df["close"].rolling(window=30).mean().dropna() df["ma30"] = df["close"].rolling(window=30).mean().dropna()
df["ma_cross"] = "" df["ma_cross"] = ""
# 定义均线交叉检测函数 # 定义均线交叉检测函数
def detect_cross(short_ma, long_ma, short_name, long_name): def detect_cross(short_ma, long_ma, short_name, long_name):
"""检测均线交叉""" """检测均线交叉"""
@ -429,22 +487,22 @@ class MetricsCalculation:
cross_up = (position == True) & (position.shift() == False) cross_up = (position == True) & (position.shift() == False)
cross_down = (position == False) & (position.shift() == True) cross_down = (position == False) & (position.shift() == True)
return cross_up, cross_down return cross_up, cross_down
# 检测所有均线交叉 # 检测所有均线交叉
crosses = {} crosses = {}
# MA5与其他均线的交叉 # MA5与其他均线的交叉
ma5_ma10_up, ma5_ma10_down = detect_cross("ma5", "ma10", "5", "10") ma5_ma10_up, ma5_ma10_down = detect_cross("ma5", "ma10", "5", "10")
ma5_ma20_up, ma5_ma20_down = detect_cross("ma5", "ma20", "5", "20") ma5_ma20_up, ma5_ma20_down = detect_cross("ma5", "ma20", "5", "20")
ma5_ma30_up, ma5_ma30_down = detect_cross("ma5", "ma30", "5", "30") ma5_ma30_up, ma5_ma30_down = detect_cross("ma5", "ma30", "5", "30")
# MA10与其他均线的交叉 # MA10与其他均线的交叉
ma10_ma20_up, ma10_ma20_down = detect_cross("ma10", "ma20", "10", "20") ma10_ma20_up, ma10_ma20_down = detect_cross("ma10", "ma20", "10", "20")
ma10_ma30_up, ma10_ma30_down = detect_cross("ma10", "ma30", "10", "30") ma10_ma30_up, ma10_ma30_down = detect_cross("ma10", "ma30", "10", "30")
# MA20与MA30的交叉 # MA20与MA30的交叉
ma20_ma30_up, ma20_ma30_down = detect_cross("ma20", "ma30", "20", "30") ma20_ma30_up, ma20_ma30_down = detect_cross("ma20", "ma30", "20", "30")
# 存储上穿信号 # 存储上穿信号
crosses["5上穿10"] = ma5_ma10_up crosses["5上穿10"] = ma5_ma10_up
crosses["5上穿20"] = ma5_ma20_up crosses["5上穿20"] = ma5_ma20_up
@ -452,7 +510,7 @@ class MetricsCalculation:
crosses["10上穿20"] = ma10_ma20_up crosses["10上穿20"] = ma10_ma20_up
crosses["10上穿30"] = ma10_ma30_up crosses["10上穿30"] = ma10_ma30_up
crosses["20上穿30"] = ma20_ma30_up crosses["20上穿30"] = ma20_ma30_up
# 存储下穿信号 # 存储下穿信号
crosses["10下穿5"] = ma5_ma10_down crosses["10下穿5"] = ma5_ma10_down
crosses["20下穿10"] = ma10_ma20_down crosses["20下穿10"] = ma10_ma20_down
@ -460,22 +518,22 @@ class MetricsCalculation:
crosses["30下穿20"] = ma20_ma30_down crosses["30下穿20"] = ma20_ma30_down
crosses["30下穿10"] = ma10_ma30_down crosses["30下穿10"] = ma10_ma30_down
crosses["30下穿5"] = ma5_ma30_down crosses["30下穿5"] = ma5_ma30_down
# 分析每个时间点的交叉组合 # 分析每个时间点的交叉组合
for idx in df.index: for idx in df.index:
current_crosses = [] current_crosses = []
# 检查当前时间点的所有交叉信号 # 检查当前时间点的所有交叉信号
for cross_name, cross_signal in crosses.items(): for cross_name, cross_signal in crosses.items():
if cross_signal.loc[idx]: if cross_signal.loc[idx]:
current_crosses.append(cross_name) current_crosses.append(cross_name)
# 根据交叉类型组合信号 # 根据交叉类型组合信号
if len(current_crosses) > 0: if len(current_crosses) > 0:
# 分离上穿和下穿信号 # 分离上穿和下穿信号
up_crosses = [c for c in current_crosses if "上穿" in c] up_crosses = [c for c in current_crosses if "上穿" in c]
down_crosses = [c for c in current_crosses if "下穿" in c] down_crosses = [c for c in current_crosses if "下穿" in c]
# 组合信号 # 组合信号
if len(up_crosses) > 1: if len(up_crosses) > 1:
# 多个上穿信号 # 多个上穿信号
@ -486,7 +544,7 @@ class MetricsCalculation:
else: else:
# 单个交叉信号 # 单个交叉信号
df.loc[idx, "ma_cross"] = current_crosses[0] df.loc[idx, "ma_cross"] = current_crosses[0]
return df return df
def rsi(self, df: pd.DataFrame): def rsi(self, df: pd.DataFrame):
@ -726,13 +784,21 @@ class MetricsCalculation:
) # 下影线长度 ) # 下影线长度
# 计算实体占比 # 计算实体占比
df["open_close_fill"] = df["open_close_diff"] / df["high_low_diff"].replace(0, np.nan) df["open_close_fill"] = df["open_close_diff"] / df["high_low_diff"].replace(
0, np.nan
)
df["open_close_fill"] = df["open_close_fill"].fillna(1.0) # 处理除零情况 df["open_close_fill"] = df["open_close_fill"].fillna(1.0) # 处理除零情况
# 计算影线占比 # 计算影线占比
df["upper_shadow_ratio"] = df["high_close_diff"] / df["high_low_diff"].replace(0, np.nan) df["upper_shadow_ratio"] = df["high_close_diff"] / df["high_low_diff"].replace(
df["lower_shadow_ratio"] = df["low_close_diff"] / df["high_low_diff"].replace(0, np.nan) 0, np.nan
df["upper_shadow_ratio"] = df["upper_shadow_ratio"].fillna(0) # 无波动时影线占比为 0 )
df["lower_shadow_ratio"] = df["low_close_diff"] / df["high_low_diff"].replace(
0, np.nan
)
df["upper_shadow_ratio"] = df["upper_shadow_ratio"].fillna(
0
) # 无波动时影线占比为 0
df["lower_shadow_ratio"] = df["lower_shadow_ratio"].fillna(0) df["lower_shadow_ratio"] = df["lower_shadow_ratio"].fillna(0)
# 初始化k_shape列 # 初始化k_shape列
@ -760,15 +826,23 @@ class MetricsCalculation:
) )
# 计算滚动窗口内 price_range_ratio 和 price_range_zscore 的分位数 # 计算滚动窗口内 price_range_ratio 和 price_range_zscore 的分位数
df["price_range_ratio_p75"] = df["price_range_ratio"].rolling(window=window_size, min_periods=1).quantile(0.75) df["price_range_ratio_p75"] = (
df["price_range_zscore_p75"] = df["price_range_zscore"].rolling(window=window_size, min_periods=1).quantile(0.75) df["price_range_ratio"]
.rolling(window=window_size, min_periods=1)
.quantile(0.75)
)
df["price_range_zscore_p75"] = (
df["price_range_zscore"]
.rolling(window=window_size, min_periods=1)
.quantile(0.75)
)
# 识别“一字”形态波动极小Z 分数 < -1.0 或 price_range_ratio < 0.05%)且无影线 # 识别“一字”形态波动极小Z 分数 < -1.0 或 price_range_ratio < 0.05%)且无影线
one_line_condition = ( one_line_condition = (
((df["price_range_zscore"] < -1.0) | (df["price_range_ratio"] < 0.05)) & ((df["price_range_zscore"] < -1.0) | (df["price_range_ratio"] < 0.05))
(df["upper_shadow_ratio"] <= 0.01) & # 上影线极小或无 & (df["upper_shadow_ratio"] <= 0.01) # 上影线极小或无
(df["lower_shadow_ratio"] <= 0.01) & # 下影线极小或无 & (df["lower_shadow_ratio"] <= 0.01) # 下影线极小或无
(df["open_close_diff"] / df["close"] < 0.0005) # 开收盘价差小于0.05% & (df["open_close_diff"] / df["close"] < 0.0005) # 开收盘价差小于0.05%
) )
df.loc[one_line_condition, "k_shape"] = "一字" df.loc[one_line_condition, "k_shape"] = "一字"
@ -857,10 +931,18 @@ class MetricsCalculation:
& (df["open_close_fill"] <= 0.55) & (df["open_close_fill"] <= 0.55)
& (df["k_shape"] != "一字") & (df["k_shape"] != "一字")
) )
df.loc[small_body_condition_2 df.loc[
& (df["upper_shadow_ratio"] >= 0.25) & (df["k_shape"] == "未知"), "k_shape"] = "长上影线纺锤体" small_body_condition_2
df.loc[small_body_condition_2 & (df["upper_shadow_ratio"] >= 0.25)
& (df["lower_shadow_ratio"] >= 0.25) & (df["k_shape"] == "未知"), "k_shape"] = "长下影线纺锤体" & (df["k_shape"] == "未知"),
"k_shape",
] = "长上影线纺锤体"
df.loc[
small_body_condition_2
& (df["lower_shadow_ratio"] >= 0.25)
& (df["k_shape"] == "未知"),
"k_shape",
] = "长下影线纺锤体"
df.loc[small_body_condition_2 & (df["k_shape"] == "未知"), "k_shape"] = "小实体" df.loc[small_body_condition_2 & (df["k_shape"] == "未知"), "k_shape"] = "小实体"
# 大实体实体占比55%-90% # 大实体实体占比55%-90%
@ -873,16 +955,20 @@ class MetricsCalculation:
# 识别“超大实体”形态:实体占比 75%-90%,价格波动显著,且非“一字”或“大实体” # 识别“超大实体”形态:实体占比 75%-90%,价格波动显著,且非“一字”或“大实体”
super_large_body_condition = ( super_large_body_condition = (
(df["open_close_fill"] > 0.75) & (df["open_close_fill"] > 0.75)
(df["open_close_fill"] <= 1) & & (df["open_close_fill"] <= 1)
(df["price_range_ratio"] >= df["price_range_ratio_p75"]) & # 价格波动范围超过75th分位数 & (
(df["k_shape"] != "一字") df["price_range_ratio"] >= df["price_range_ratio_p75"]
) # 价格波动范围超过75th分位数
& (df["k_shape"] != "一字")
) )
df.loc[super_large_body_condition, "k_shape"] = "超大实体" df.loc[super_large_body_condition, "k_shape"] = "超大实体"
# 光头光脚:实体占比>90%(非一字情况) # 光头光脚:实体占比>90%(非一字情况)
bald_body_condition = (df["open_close_fill"] > 0.9) & (df["k_shape"] != "一字") bald_body_condition = (df["open_close_fill"] > 0.9) & (df["k_shape"] != "一字")
df.loc[bald_body_condition & (df["k_shape"] == "超大实体"), "k_shape"] = "超大实体+光头光脚" df.loc[bald_body_condition & (df["k_shape"] == "超大实体"), "k_shape"] = (
"超大实体+光头光脚"
)
df.loc[bald_body_condition & (df["k_shape"] == "未知"), "k_shape"] = "光头光脚" df.loc[bald_body_condition & (df["k_shape"] == "未知"), "k_shape"] = "光头光脚"
# 清理临时列 # 清理临时列
@ -911,7 +997,7 @@ class MetricsCalculation:
def set_ma_long_short_advanced(self, data: pd.DataFrame, method="weighted_voting"): def set_ma_long_short_advanced(self, data: pd.DataFrame, method="weighted_voting"):
""" """
高级均线多空判定方法提供多种科学的判定策略 高级均线多空判定方法提供多种科学的判定策略
Args: Args:
data: 包含均线数据的DataFrame data: 包含均线数据的DataFrame
method: 判定方法 method: 判定方法
@ -922,7 +1008,7 @@ class MetricsCalculation:
- "hybrid": 混合方法 - "hybrid": 混合方法
""" """
logger.info(f"使用{method}方法设置均线多空") logger.info(f"使用{method}方法设置均线多空")
if method == "weighted_voting": if method == "weighted_voting":
return self._weighted_voting_method(data) return self._weighted_voting_method(data)
elif method == "trend_strength": elif method == "trend_strength":
@ -936,20 +1022,20 @@ class MetricsCalculation:
else: else:
logger.warning(f"未知的方法: {method},使用默认加权投票方法") logger.warning(f"未知的方法: {method},使用默认加权投票方法")
return self._weighted_voting_method(data) return self._weighted_voting_method(data)
def _weighted_voting_method(self, data: pd.DataFrame): def _weighted_voting_method(self, data: pd.DataFrame):
"""加权投票机制:短期均线权重更高""" """加权投票机制:短期均线权重更高"""
# 权重设置:短期均线权重更高 # 权重设置:短期均线权重更高
weights = { weights = {
"ma5_close_diff": 0.4, # 40%权重 "ma5_close_diff": 0.4, # 40%权重
"ma10_close_diff": 0.3, # 30%权重 "ma10_close_diff": 0.3, # 30%权重
"ma20_close_diff": 0.2, # 20%权重 "ma20_close_diff": 0.2, # 20%权重
"ma30_close_diff": 0.1 # 10%权重 "ma30_close_diff": 0.1, # 10%权重
} }
# 计算加权得分 # 计算加权得分
weighted_score = sum(data[col] * weight for col, weight in weights.items()) weighted_score = sum(data[col] * weight for col, weight in weights.items())
# 动态阈值:基于历史分布 # 动态阈值:基于历史分布
window_size = min(50, len(data) // 4) window_size = min(50, len(data) // 4)
if window_size > 10: if window_size > 10:
@ -960,69 +1046,78 @@ class MetricsCalculation:
else: else:
long_threshold = 0.3 long_threshold = 0.3
short_threshold = -0.3 short_threshold = -0.3
# 判定逻辑 # 判定逻辑
data.loc[weighted_score > long_threshold, "ma_long_short"] = "" data.loc[weighted_score > long_threshold, "ma_long_short"] = ""
data.loc[weighted_score < short_threshold, "ma_long_short"] = "" data.loc[weighted_score < short_threshold, "ma_long_short"] = ""
return data return data
def _trend_strength_method(self, data: pd.DataFrame): def _trend_strength_method(self, data: pd.DataFrame):
"""趋势强度评估:考虑偏离幅度和趋势持续性""" """趋势强度评估:考虑偏离幅度和趋势持续性"""
# 计算趋势强度(考虑偏离幅度) # 计算趋势强度(考虑偏离幅度)
trend_strength = data["ma_close_avg"] trend_strength = data["ma_close_avg"]
# 计算趋势持续性(连续同向的周期数) # 计算趋势持续性(连续同向的周期数)
trend_persistence = self._calculate_trend_persistence(data) trend_persistence = self._calculate_trend_persistence(data)
# 综合评分 # 综合评分
strength_threshold = 0.5 strength_threshold = 0.5
persistence_threshold = 3 # 至少连续3个周期 persistence_threshold = 3 # 至少连续3个周期
long_condition = (trend_strength > strength_threshold) & (trend_persistence >= persistence_threshold) long_condition = (trend_strength > strength_threshold) & (
short_condition = (trend_strength < -strength_threshold) & (trend_persistence >= persistence_threshold) trend_persistence >= persistence_threshold
)
short_condition = (trend_strength < -strength_threshold) & (
trend_persistence >= persistence_threshold
)
data.loc[long_condition, "ma_long_short"] = "" data.loc[long_condition, "ma_long_short"] = ""
data.loc[short_condition, "ma_long_short"] = "" data.loc[short_condition, "ma_long_short"] = ""
return data return data
def _ma_alignment_method(self, data: pd.DataFrame): def _ma_alignment_method(self, data: pd.DataFrame):
"""均线排列分析:检查均线的排列顺序和间距""" """均线排列分析:检查均线的排列顺序和间距"""
# 检查均线排列顺序 # 检查均线排列顺序
ma_alignment_score = 0 ma_alignment_score = 0
# 多头排列MA5 > MA10 > MA20 > MA30 # 多头排列MA5 > MA10 > MA20 > MA30
bullish_alignment = ( bullish_alignment = (
(data["ma5_close_diff"] > data["ma10_close_diff"]) & (data["ma5_close_diff"] > data["ma10_close_diff"])
(data["ma10_close_diff"] > data["ma20_close_diff"]) & & (data["ma10_close_diff"] > data["ma20_close_diff"])
(data["ma20_close_diff"] > data["ma30_close_diff"]) & (data["ma20_close_diff"] > data["ma30_close_diff"])
) )
# 空头排列MA5 < MA10 < MA20 < MA30 # 空头排列MA5 < MA10 < MA20 < MA30
bearish_alignment = ( bearish_alignment = (
(data["ma5_close_diff"] < data["ma10_close_diff"]) & (data["ma5_close_diff"] < data["ma10_close_diff"])
(data["ma10_close_diff"] < data["ma20_close_diff"]) & & (data["ma10_close_diff"] < data["ma20_close_diff"])
(data["ma20_close_diff"] < data["ma30_close_diff"]) & (data["ma20_close_diff"] < data["ma30_close_diff"])
) )
# 计算均线间距的合理性 # 计算均线间距的合理性
ma_spacing = self._calculate_ma_spacing(data) ma_spacing = self._calculate_ma_spacing(data)
# 综合判定 # 综合判定
long_condition = bullish_alignment & (ma_spacing > 0.2) long_condition = bullish_alignment & (ma_spacing > 0.2)
short_condition = bearish_alignment & (ma_spacing > 0.2) short_condition = bearish_alignment & (ma_spacing > 0.2)
data.loc[long_condition, "ma_long_short"] = "" data.loc[long_condition, "ma_long_short"] = ""
data.loc[short_condition, "ma_long_short"] = "" data.loc[short_condition, "ma_long_short"] = ""
return data return data
def _statistical_method(self, data: pd.DataFrame): def _statistical_method(self, data: pd.DataFrame):
"""统计分布方法基于历史分位数和Z-score""" """统计分布方法基于历史分位数和Z-score"""
# 计算各均线偏离度的Z-score # 计算各均线偏离度的Z-score
ma_cols = ["ma5_close_diff", "ma10_close_diff", "ma20_close_diff", "ma30_close_diff"] ma_cols = [
"ma5_close_diff",
"ma10_close_diff",
"ma20_close_diff",
"ma30_close_diff",
]
# 使用滚动窗口计算Z-score # 使用滚动窗口计算Z-score
window_size = min(30, len(data) // 4) window_size = min(30, len(data) // 4)
if window_size > 10: if window_size > 10:
@ -1031,44 +1126,46 @@ class MetricsCalculation:
rolling_mean = data[col].rolling(window=window_size).mean() rolling_mean = data[col].rolling(window=window_size).mean()
rolling_std = data[col].rolling(window=window_size).std() rolling_std = data[col].rolling(window=window_size).std()
z_scores[col] = (data[col] - rolling_mean) / rolling_std z_scores[col] = (data[col] - rolling_mean) / rolling_std
# 计算综合Z-score # 计算综合Z-score
avg_z_score = z_scores.mean(axis=1) avg_z_score = z_scores.mean(axis=1)
# 基于Z-score判定 # 基于Z-score判定
long_condition = avg_z_score > 0.5 long_condition = avg_z_score > 0.5
short_condition = avg_z_score < -0.5 short_condition = avg_z_score < -0.5
data.loc[long_condition, "ma_long_short"] = "" data.loc[long_condition, "ma_long_short"] = ""
data.loc[short_condition, "ma_long_short"] = "" data.loc[short_condition, "ma_long_short"] = ""
return data return data
def _hybrid_method(self, data: pd.DataFrame): def _hybrid_method(self, data: pd.DataFrame):
"""混合方法:结合多种判定策略""" """混合方法:结合多种判定策略"""
# 1. 加权投票得分 # 1. 加权投票得分
weights = {"ma5_close_diff": 0.4, "ma10_close_diff": 0.3, weights = {
"ma20_close_diff": 0.2, "ma30_close_diff": 0.1} "ma5_close_diff": 0.4,
"ma10_close_diff": 0.3,
"ma20_close_diff": 0.2,
"ma30_close_diff": 0.1,
}
weighted_score = sum(data[col] * weight for col, weight in weights.items()) weighted_score = sum(data[col] * weight for col, weight in weights.items())
# 2. 均线排列得分 # 2. 均线排列得分
alignment_score = ( alignment_score = (
(data["ma5_close_diff"] >= data["ma10_close_diff"]) * 0.25 + (data["ma5_close_diff"] >= data["ma10_close_diff"]) * 0.25
(data["ma10_close_diff"] >= data["ma20_close_diff"]) * 0.25 + + (data["ma10_close_diff"] >= data["ma20_close_diff"]) * 0.25
(data["ma20_close_diff"] >= data["ma30_close_diff"]) * 0.25 + + (data["ma20_close_diff"] >= data["ma30_close_diff"]) * 0.25
(data["ma_close_avg"] > 0) * 0.25 + (data["ma_close_avg"] > 0) * 0.25
) )
# 3. 趋势强度得分 # 3. 趋势强度得分
strength_score = data["ma_close_avg"].abs() strength_score = data["ma_close_avg"].abs()
# 4. 综合评分 # 4. 综合评分
composite_score = ( composite_score = (
weighted_score * 0.4 + weighted_score * 0.4 + alignment_score * 0.3 + strength_score * 0.3
alignment_score * 0.3 +
strength_score * 0.3
) )
# 动态阈值 # 动态阈值
window_size = min(50, len(data) // 4) window_size = min(50, len(data) // 4)
if window_size > 10: if window_size > 10:
@ -1079,38 +1176,44 @@ class MetricsCalculation:
else: else:
long_threshold = 0.4 long_threshold = 0.4
short_threshold = -0.4 short_threshold = -0.4
# 判定 # 判定
long_condition = composite_score > long_threshold long_condition = composite_score > long_threshold
short_condition = composite_score < short_threshold short_condition = composite_score < short_threshold
data.loc[long_condition, "ma_long_short"] = "" data.loc[long_condition, "ma_long_short"] = ""
data.loc[short_condition, "ma_long_short"] = "" data.loc[short_condition, "ma_long_short"] = ""
return data return data
def _calculate_trend_persistence(self, data: pd.DataFrame): def _calculate_trend_persistence(self, data: pd.DataFrame):
"""计算趋势持续性""" """计算趋势持续性"""
trend_persistence = pd.Series(0, index=data.index) trend_persistence = pd.Series(0, index=data.index)
for i in range(1, len(data)): for i in range(1, len(data)):
if data["ma_close_avg"].iloc[i] > 0 and data["ma_close_avg"].iloc[i-1] > 0: if (
trend_persistence.iloc[i] = trend_persistence.iloc[i-1] + 1 data["ma_close_avg"].iloc[i] > 0
elif data["ma_close_avg"].iloc[i] < 0 and data["ma_close_avg"].iloc[i-1] < 0: and data["ma_close_avg"].iloc[i - 1] > 0
trend_persistence.iloc[i] = trend_persistence.iloc[i-1] + 1 ):
trend_persistence.iloc[i] = trend_persistence.iloc[i - 1] + 1
elif (
data["ma_close_avg"].iloc[i] < 0
and data["ma_close_avg"].iloc[i - 1] < 0
):
trend_persistence.iloc[i] = trend_persistence.iloc[i - 1] + 1
else: else:
trend_persistence.iloc[i] = 0 trend_persistence.iloc[i] = 0
return trend_persistence return trend_persistence
def _calculate_ma_spacing(self, data: pd.DataFrame): def _calculate_ma_spacing(self, data: pd.DataFrame):
"""计算均线间距的合理性""" """计算均线间距的合理性"""
# 计算相邻均线之间的间距 # 计算相邻均线之间的间距
spacing_5_10 = abs(data["ma5_close_diff"] - data["ma10_close_diff"]) spacing_5_10 = abs(data["ma5_close_diff"] - data["ma10_close_diff"])
spacing_10_20 = abs(data["ma10_close_diff"] - data["ma20_close_diff"]) spacing_10_20 = abs(data["ma10_close_diff"] - data["ma20_close_diff"])
spacing_20_30 = abs(data["ma20_close_diff"] - data["ma30_close_diff"]) spacing_20_30 = abs(data["ma20_close_diff"] - data["ma30_close_diff"])
# 平均间距 # 平均间距
avg_spacing = (spacing_5_10 + spacing_10_20 + spacing_20_30) / 3 avg_spacing = (spacing_5_10 + spacing_10_20 + spacing_20_30) / 3
return avg_spacing return avg_spacing

View File

@ -44,6 +44,8 @@ class DBMarketData:
"kdj_j", "kdj_j",
"kdj_signal", "kdj_signal",
"kdj_pattern", "kdj_pattern",
"sar",
"sar_signal",
"ma5", "ma5",
"ma10", "ma10",
"ma20", "ma20",
@ -471,7 +473,7 @@ class DBMarketData:
return self.db_manager.query_data(sql, condition_dict, return_multi=False) return self.db_manager.query_data(sql, condition_dict, return_multi=False)
def query_market_data_by_symbol_bar(self, symbol: str, bar: str, start: str, end: str): def query_market_data_by_symbol_bar(self, symbol: str, bar: str, start: str = None, end: str = None):
""" """
根据交易对和K线周期查询数据 根据交易对和K线周期查询数据
:param symbol: 交易对 :param symbol: 交易对

View File

@ -134,13 +134,20 @@ class ORBStrategy:
return int(max_shares) # 股数取整 return int(max_shares) # 股数取整
def generate_orb_signals(self): def generate_orb_signals(self, direction: str = None, by_sar: bool = False):
""" """
生成ORB策略信号每日仅1次交易机会 生成ORB策略信号每日仅1次交易机会
- 第一根5分钟K线确定开盘区间High1, Low1 - 第一根5分钟K线确定开盘区间High1, Low1
- 第二根5分钟K线根据第一根K线方向生成多空信号 - 第二根5分钟K线根据第一根K线方向生成多空信号
:param direction: 方向None=自动Long=多头Short=空头
:param by_sar: 是否根据SAR指标生成信号True=False=
""" """
logger.info("开始生成ORB策略信号") direction_desc = "既做多又做空"
if direction == "Long":
direction_desc = "做多"
elif direction == "Short":
direction_desc = "做空"
logger.info(f"开始生成ORB策略信号{direction_desc}根据SAR指标{by_sar}")
if self.data is None: if self.data is None:
raise ValueError("请先调用fetch_intraday_data获取数据") raise ValueError("请先调用fetch_intraday_data获取数据")
@ -164,16 +171,16 @@ class ORBStrategy:
entry_time = second_candle.date_time # entry时间 entry_time = second_candle.date_time # entry时间
# 生成信号第一根K线方向决定多空排除十字星open1 == close1 # 生成信号第一根K线方向决定多空排除十字星open1 == close1
if open1 < close1: if open1 < close1 and (direction == "Long" or direction is None):
# 第一根K线收涨→多头信号 # 第一根K线收涨→多头信号
signal = "Long" signal = "Long"
stop_price = low1 # 多头止损=第一根K线最低价 stop_price = low1 # 多头止损=第一根K线最低价
elif open1 > close1: elif open1 > close1 and (direction == "Short" or direction is None):
# 第一根K线收跌→空头信号 # 第一根K线收跌→空头信号
signal = "Short" signal = "Short"
stop_price = high1 # 空头止损=第一根K线最高价 stop_price = high1 # 空头止损=第一根K线最高价
else: else:
# 十字星→无信号 # 与direction不一致或十字星→无信号
signal = None signal = None
stop_price = None stop_price = None

View File

@ -512,7 +512,7 @@ def test_send_huge_volume_data_to_wechat():
if __name__ == "__main__": if __name__ == "__main__":
# test_send_huge_volume_data_to_wechat() # test_send_huge_volume_data_to_wechat()
# batch_initial_detect_volume_spike(threshold=2.0) # batch_initial_detect_volume_spike(threshold=2.0)
batch_update_volume_spike(threshold=2.0, is_us_stock=True) batch_update_volume_spike(threshold=2.0, is_us_stock=False)
# huge_volume_main = HugeVolumeMain(threshold=2.0) # huge_volume_main = HugeVolumeMain(threshold=2.0)
# huge_volume_main.batch_next_periods_rise_or_fall(output_excel=True) # huge_volume_main.batch_next_periods_rise_or_fall(output_excel=True)
# data_file_path = "./output/huge_volume_statistics/next_periods_rise_or_fall_stat_20250731200304.xlsx" # data_file_path = "./output/huge_volume_statistics/next_periods_rise_or_fall_stat_20250731200304.xlsx"

View File

@ -252,6 +252,8 @@ class MarketDataMain:
data["kdj_j"] = None data["kdj_j"] = None
data["kdj_signal"] = None data["kdj_signal"] = None
data["kdj_pattern"] = None data["kdj_pattern"] = None
data["sar"] = None
data["sar_signal"] = None
data["ma5"] = None data["ma5"] = None
data["ma10"] = None data["ma10"] = None
data["ma20"] = None data["ma20"] = None
@ -301,6 +303,8 @@ class MarketDataMain:
kdj_j DOUBLE DEFAULT NULL COMMENT 'KDJ指标J值', kdj_j DOUBLE DEFAULT NULL COMMENT 'KDJ指标J值',
kdj_signal VARCHAR(15) DEFAULT NULL COMMENT 'KDJ金叉死叉信号', kdj_signal VARCHAR(15) DEFAULT NULL COMMENT 'KDJ金叉死叉信号',
kdj_pattern varchar(25) DEFAULT NULL COMMENT 'KDJ超买超卖徘徊', kdj_pattern varchar(25) DEFAULT NULL COMMENT 'KDJ超买超卖徘徊',
sar DOUBLE DEFAULT NULL COMMENT 'SAR指标',
sar_signal VARCHAR(15) DEFAULT NULL COMMENT 'SAR多头SAR空头SAR观望',
ma5 DOUBLE DEFAULT NULL COMMENT '5移动平均线', ma5 DOUBLE DEFAULT NULL COMMENT '5移动平均线',
ma10 DOUBLE DEFAULT NULL COMMENT '10移动平均线', ma10 DOUBLE DEFAULT NULL COMMENT '10移动平均线',
ma20 DOUBLE DEFAULT NULL COMMENT '20移动平均线', ma20 DOUBLE DEFAULT NULL COMMENT '20移动平均线',
@ -331,6 +335,7 @@ class MarketDataMain:
data = metrics_calculation.pre_close(data) data = metrics_calculation.pre_close(data)
data = metrics_calculation.macd(data) data = metrics_calculation.macd(data)
data = metrics_calculation.kdj(data) data = metrics_calculation.kdj(data)
data = metrics_calculation.sar(data)
data = metrics_calculation.set_kdj_pattern(data) data = metrics_calculation.set_kdj_pattern(data)
data = metrics_calculation.update_macd_divergence_column_simple(data) data = metrics_calculation.update_macd_divergence_column_simple(data)
data = metrics_calculation.ma5102030(data) data = metrics_calculation.ma5102030(data)

View File

@ -33,6 +33,8 @@ CREATE TABLE IF NOT EXISTS crypto_market_data (
kdj_j DOUBLE DEFAULT NULL COMMENT 'KDJ指标J值', kdj_j DOUBLE DEFAULT NULL COMMENT 'KDJ指标J值',
kdj_signal VARCHAR(15) DEFAULT NULL COMMENT 'KDJ金叉死叉信号', kdj_signal VARCHAR(15) DEFAULT NULL COMMENT 'KDJ金叉死叉信号',
kdj_pattern varchar(25) DEFAULT NULL COMMENT 'KDJ超买超卖徘徊', kdj_pattern varchar(25) DEFAULT NULL COMMENT 'KDJ超买超卖徘徊',
sar DOUBLE DEFAULT NULL COMMENT 'SAR指标值',
sar_signal VARCHAR(15) DEFAULT NULL COMMENT 'SAR多头空头信号',
ma5 DOUBLE DEFAULT NULL COMMENT '5移动平均线', ma5 DOUBLE DEFAULT NULL COMMENT '5移动平均线',
ma10 DOUBLE DEFAULT NULL COMMENT '10移动平均线', ma10 DOUBLE DEFAULT NULL COMMENT '10移动平均线',
ma20 DOUBLE DEFAULT NULL COMMENT '20移动平均线', ma20 DOUBLE DEFAULT NULL COMMENT '20移动平均线',
@ -64,3 +66,7 @@ ALTER TABLE crypto_market_data MODIFY COLUMN ma_cross VARCHAR(150) DEFAULT NULL
--date_time_us字段 --date_time_us字段
ALTER TABLE crypto_market_data ADD COLUMN date_time_us VARCHAR(50) NULL COMMENT '美国时间格式的日期时间' AFTER date_time; ALTER TABLE crypto_market_data ADD COLUMN date_time_us VARCHAR(50) NULL COMMENT '美国时间格式的日期时间' AFTER date_time;
--SAR相关字段
ALTER TABLE crypto_market_data ADD COLUMN sar DOUBLE DEFAULT NULL COMMENT 'SAR指标值' AFTER kdj_pattern;
ALTER TABLE crypto_market_data ADD COLUMN sar_signal VARCHAR(15) DEFAULT NULL COMMENT 'SAR多头空头信号' AFTER sar;

69
update_data_main.py Normal file
View File

@ -0,0 +1,69 @@
import pandas as pd
from core.db.db_market_data import DBMarketData
from core.biz.metrics_calculation import MetricsCalculation
from config import MYSQL_CONFIG, US_STOCK_MONITOR_CONFIG, OKX_MONITOR_CONFIG
import core.logger as logging
logger = logging.logger
class UpdateDataMain:
def __init__(self):
mysql_user = MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "")
if not mysql_password:
raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_market_data = DBMarketData(self.db_url)
self.metrics_calculation = MetricsCalculation()
def batch_update_data(self, is_us_stock: bool = False):
"""
批量更新数据
"""
if is_us_stock:
symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get("symbols", [])
bars = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get("bars", ["5m", "15m", "1H", "1D"])
else:
symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get("symbols", [])
bars = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get("bars", ["5m", "15m", "1H", "1D"])
for symbol in symbols:
for bar in bars:
self.update_data(symbol, bar)
def update_data(self, symbol: str, bar: str):
"""
更新数据
"""
logger.info(f"开始更新数据: {symbol} {bar}")
data = self.db_market_data.query_market_data_by_symbol_bar(symbol, bar)
logger.info(f"查询数据完成: {symbol} {bar},共有{len(data)}条数据")
data = pd.DataFrame(data)
data.sort_values(by="timestamp", inplace=True)
data = self.update_date_time_us(data)
logger.info("更新SAR指标")
data = self.metrics_calculation.sar(data)
logger.info("更新SAR指标完成")
logger.info(f"开始保存数据: {symbol} {bar}")
self.db_market_data.insert_data_to_mysql(data)
logger.info(f"保存数据完成: {symbol} {bar}")
def update_date_time_us(self, data: pd.DataFrame):
"""
更新日期时间
"""
logger.info(f"开始更新美东日期时间: {data.shape[0]}条数据")
dt_us_series = pd.to_datetime(data['timestamp'].astype(int), unit='ms', utc=True, errors='coerce').dt.tz_convert('America/New_York')
data['date_time_us'] = dt_us_series.dt.strftime('%Y-%m-%d %H:%M:%S')
logger.info(f"更新美东日期时间完成: {data.shape[0]}条数据")
return data
if __name__ == "__main__":
update_data_main = UpdateDataMain()
update_data_main.batch_update_data(is_us_stock=True)
update_data_main.batch_update_data(is_us_stock=False)