optimize ma long short/ ma cross algorithms
This commit is contained in:
parent
5f5633f4b6
commit
40a7b02b66
Binary file not shown.
Binary file not shown.
|
|
@ -166,10 +166,10 @@ def create_metrics_report(
|
||||||
long_short_info["空"].append(f"均线形态: {ma_divergence}")
|
long_short_info["空"].append(f"均线形态: {ma_divergence}")
|
||||||
|
|
||||||
ma_cross = str(row["ma_cross"])
|
ma_cross = str(row["ma_cross"])
|
||||||
ma_cross_value = METRICS_CONFIG.get("ma_cross", {}).get(ma_cross, 1)
|
|
||||||
if ma_cross_value > 1:
|
if "上穿" in ma_cross:
|
||||||
long_short_info["多"].append(f"均线交叉: {ma_cross}")
|
long_short_info["多"].append(f"均线交叉: {ma_cross}")
|
||||||
if ma_cross_value < 1:
|
if "下穿" in ma_cross:
|
||||||
long_short_info["空"].append(f"均线交叉: {ma_cross}")
|
long_short_info["空"].append(f"均线交叉: {ma_cross}")
|
||||||
|
|
||||||
macd_signal_value = METRICS_CONFIG.get("macd", {}).get(macd_signal, 1)
|
macd_signal_value = METRICS_CONFIG.get("macd", {}).get(macd_signal, 1)
|
||||||
|
|
@ -240,21 +240,36 @@ def create_metrics_report(
|
||||||
long_short_info["空"].append(f"K线形态: {k_shape}")
|
long_short_info["空"].append(f"K线形态: {k_shape}")
|
||||||
|
|
||||||
if k_up_down == "阳线":
|
if k_up_down == "阳线":
|
||||||
|
if pct_chg > 0:
|
||||||
if is_long and not is_over_buy:
|
if is_long and not is_over_buy:
|
||||||
long_short_info["多"].append(f"量价关系: 非超买且放量上涨")
|
long_short_info["多"].append(f"量价关系: 非超买且放量上涨")
|
||||||
if is_short and is_over_sell:
|
if is_short and is_over_sell:
|
||||||
long_short_info["多"].append(
|
long_short_info["多"].append(
|
||||||
f"量价关系: 空头态势且超卖,但出现放量上涨,可能反转"
|
f"量价关系: 空头态势且超卖,但出现放量上涨,可能反转"
|
||||||
)
|
)
|
||||||
|
if low_10_low:
|
||||||
|
long_short_info["多"].append(f"量价关系: 盘中出现10%分位数低点,且出现放量上涨,可能反转")
|
||||||
|
elif low_20_low:
|
||||||
|
long_short_info["多"].append(f"量价关系: 盘中出现20%分位数低点,且出现放量上涨,可能反转")
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
if k_up_down == "阴线":
|
if k_up_down == "阴线":
|
||||||
if is_long and is_over_buy:
|
if pct_chg < 0:
|
||||||
if close_80_high or close_90_high or high_80_high or high_90_high:
|
if close_80_high or close_90_high or high_80_high or high_90_high:
|
||||||
|
if is_long and is_over_buy:
|
||||||
long_short_info["空"].append(
|
long_short_info["空"].append(
|
||||||
f"量价关系: 多头态势且超买, 目前是价位高点,但出现放量下跌,可能反转"
|
f"量价关系: 多头态势且超买, 目前是价位高点,但出现放量下跌,可能反转"
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
long_short_info["空"].append(
|
||||||
|
f"量价关系: 非多头态势, 但目前是价位高点,且出现放量下跌"
|
||||||
|
)
|
||||||
|
|
||||||
if is_short and not is_over_sell:
|
if is_short and not is_over_sell:
|
||||||
long_short_info["空"].append(f"量价关系: 空头态势且非超卖,出现放量下跌")
|
long_short_info["空"].append(f"量价关系: 空头态势且非超卖,出现放量下跌")
|
||||||
|
|
||||||
|
|
||||||
contents.append(f"### 技术指标信息")
|
contents.append(f"### 技术指标信息")
|
||||||
if ma_long_short_value == 1:
|
if ma_long_short_value == 1:
|
||||||
contents.append(f"均线势头: 震荡")
|
contents.append(f"均线势头: 震荡")
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,52 @@
|
||||||
import pandas as pd
|
"""
|
||||||
|
均线多空判定模块
|
||||||
|
|
||||||
|
本模块提供了多种科学的均线多空判定方法,解决了传统方法过于严格的问题。
|
||||||
|
|
||||||
|
传统方法的问题:
|
||||||
|
1. 要求所有均线都严格满足条件(MA5、MA10、MA20、MA30都>0或<0)
|
||||||
|
2. 缺乏权重考虑,短期和长期均线影响权重相同
|
||||||
|
3. 没有考虑趋势强度,只是简单的正负判断
|
||||||
|
4. 缺乏历史对比,使用固定阈值
|
||||||
|
|
||||||
|
改进方法:
|
||||||
|
1. 加权投票机制:短期均线权重更高(MA5:40%, MA10:30%, MA20:20%, MA30:10%)
|
||||||
|
2. 趋势强度评估:考虑偏离幅度而非简单正负
|
||||||
|
3. 历史分位数对比:动态阈值调整
|
||||||
|
4. 趋势一致性:考虑均线排列顺序
|
||||||
|
5. 多种判定策略:可根据不同市场环境选择最适合的方法
|
||||||
|
|
||||||
|
使用示例:
|
||||||
|
```python
|
||||||
|
# 基本使用(改进后的方法)
|
||||||
|
metrics = MetricsCalculation()
|
||||||
|
data = metrics.set_ma_long_short_divergence(data)
|
||||||
|
|
||||||
|
# 高级使用(多种策略)
|
||||||
|
# 1. 加权投票机制(推荐)
|
||||||
|
data = metrics.set_ma_long_short_advanced(data, method="weighted_voting")
|
||||||
|
|
||||||
|
# 2. 趋势强度评估
|
||||||
|
data = metrics.set_ma_long_short_advanced(data, method="trend_strength")
|
||||||
|
|
||||||
|
# 3. 均线排列分析
|
||||||
|
data = metrics.set_ma_long_short_advanced(data, method="ma_alignment")
|
||||||
|
|
||||||
|
# 4. 统计分布方法
|
||||||
|
data = metrics.set_ma_long_short_advanced(data, method="statistical")
|
||||||
|
|
||||||
|
# 5. 混合方法
|
||||||
|
data = metrics.set_ma_long_short_advanced(data, method="hybrid")
|
||||||
|
```
|
||||||
|
|
||||||
|
判定结果说明:
|
||||||
|
- "多":多头趋势,建议做多
|
||||||
|
- "空":空头趋势,建议做空
|
||||||
|
- "震荡":震荡市场,建议观望或区间交易
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import talib as tb
|
import talib as tb
|
||||||
from talib import MA_Type
|
from talib import MA_Type
|
||||||
|
|
@ -143,6 +190,12 @@ class MetricsCalculation:
|
||||||
设置均线多空列: ma_long_short (多,空,震荡)
|
设置均线多空列: ma_long_short (多,空,震荡)
|
||||||
设置均线发散列: ma_divergence (超发散,发散,适中,粘合,未知)
|
设置均线发散列: ma_divergence (超发散,发散,适中,粘合,未知)
|
||||||
|
|
||||||
|
改进的均线多空判定逻辑:
|
||||||
|
1. 加权投票机制:短期均线权重更高
|
||||||
|
2. 趋势强度评估:考虑偏离幅度而非简单正负
|
||||||
|
3. 历史分位数对比:动态阈值调整
|
||||||
|
4. 趋势一致性:考虑均线排列顺序
|
||||||
|
|
||||||
均线发散度使用相对统计方法分类:
|
均线发散度使用相对统计方法分类:
|
||||||
- 超发散:标准差Z-score > 1.5 且 均值Z-score绝对值 > 1.2
|
- 超发散:标准差Z-score > 1.5 且 均值Z-score绝对值 > 1.2
|
||||||
- 发散:标准差Z-score > 0.8 或 均值Z-score绝对值 > 0.8
|
- 发散:标准差Z-score > 0.8 或 均值Z-score绝对值 > 0.8
|
||||||
|
|
@ -152,38 +205,12 @@ class MetricsCalculation:
|
||||||
使用20个周期的滚动窗口计算相对统计特征,避免绝对阈值过于严格的问题
|
使用20个周期的滚动窗口计算相对统计特征,避免绝对阈值过于严格的问题
|
||||||
"""
|
"""
|
||||||
logging.info("设置均线多空和发散")
|
logging.info("设置均线多空和发散")
|
||||||
|
|
||||||
|
# 通过趋势强度计算多空
|
||||||
|
# 震荡:不满足多空条件的其他情况
|
||||||
|
# 震荡条件已经在初始化时设置,无需额外处理
|
||||||
data["ma_long_short"] = "震荡"
|
data["ma_long_short"] = "震荡"
|
||||||
data["ma_divergence"] = "未知"
|
data = self._trend_strength_method(data)
|
||||||
|
|
||||||
# 检查数据完整性
|
|
||||||
# if (pd.isnull(data['ma5_close_diff']).any() or
|
|
||||||
# pd.isnull(data['ma10_close_diff']).any() or
|
|
||||||
# pd.isnull(data['ma20_close_diff']).any() or
|
|
||||||
# pd.isnull(data['ma30_close_diff']).any() or
|
|
||||||
# pd.isnull(data['ma_close_avg']).any()):
|
|
||||||
# data['ma_long_short'] = '数据不全'
|
|
||||||
# return data
|
|
||||||
|
|
||||||
# 设置均线多空逻辑
|
|
||||||
# 多:所有均线都在价格下方,且平均偏离度为正
|
|
||||||
long_condition = (
|
|
||||||
(data["ma5_close_diff"] > 0)
|
|
||||||
& (data["ma10_close_diff"] > 0)
|
|
||||||
& (data["ma20_close_diff"] > 0)
|
|
||||||
& (data["ma30_close_diff"] > 0)
|
|
||||||
& (data["ma_close_avg"] > 0)
|
|
||||||
)
|
|
||||||
data.loc[long_condition, "ma_long_short"] = "多"
|
|
||||||
|
|
||||||
# 空:所有均线都在价格上方,且平均偏离度为负
|
|
||||||
short_condition = (
|
|
||||||
(data["ma5_close_diff"] < 0)
|
|
||||||
& (data["ma10_close_diff"] < 0)
|
|
||||||
& (data["ma20_close_diff"] < 0)
|
|
||||||
& (data["ma30_close_diff"] < 0)
|
|
||||||
& (data["ma_close_avg"] < 0)
|
|
||||||
)
|
|
||||||
data.loc[short_condition, "ma_long_short"] = "空"
|
|
||||||
|
|
||||||
# 计算各均线偏离度的标准差和均值
|
# 计算各均线偏离度的标准差和均值
|
||||||
data["ma_divergence"] = "未知"
|
data["ma_divergence"] = "未知"
|
||||||
|
|
@ -383,6 +410,12 @@ class MetricsCalculation:
|
||||||
return df
|
return df
|
||||||
|
|
||||||
def ma5102030(self, df: pd.DataFrame):
|
def ma5102030(self, df: pd.DataFrame):
|
||||||
|
"""
|
||||||
|
计算均线指标并检测交叉信号
|
||||||
|
优化版本:同时检测多个均线交叉,更好地判断趋势转变
|
||||||
|
支持所有均线交叉类型:5上穿10/20/30,10上穿20/30,20上穿30
|
||||||
|
以及对应的下穿信号:30下穿20/10/5, 20下穿10/5,10下穿5
|
||||||
|
"""
|
||||||
logging.info("计算均线指标")
|
logging.info("计算均线指标")
|
||||||
df["ma5"] = df["close"].rolling(window=5).mean().dropna()
|
df["ma5"] = df["close"].rolling(window=5).mean().dropna()
|
||||||
df["ma10"] = df["close"].rolling(window=10).mean().dropna()
|
df["ma10"] = df["close"].rolling(window=10).mean().dropna()
|
||||||
|
|
@ -390,47 +423,72 @@ class MetricsCalculation:
|
||||||
df["ma30"] = df["close"].rolling(window=30).mean().dropna()
|
df["ma30"] = df["close"].rolling(window=30).mean().dropna()
|
||||||
|
|
||||||
df["ma_cross"] = ""
|
df["ma_cross"] = ""
|
||||||
ma_position = df["ma5"] > df["ma10"]
|
|
||||||
df.loc[
|
|
||||||
ma_position[(ma_position == True) & (ma_position.shift() == False)].index,
|
|
||||||
"ma_cross",
|
|
||||||
] = "5穿10"
|
|
||||||
ma_position = df["ma5"] > df["ma20"]
|
|
||||||
df.loc[
|
|
||||||
ma_position[(ma_position == True) & (ma_position.shift() == False)].index,
|
|
||||||
"ma_cross",
|
|
||||||
] = "5穿20"
|
|
||||||
ma_position = df["ma5"] > df["ma30"]
|
|
||||||
df.loc[
|
|
||||||
ma_position[(ma_position == True) & (ma_position.shift() == False)].index,
|
|
||||||
"ma_cross",
|
|
||||||
] = "5穿30"
|
|
||||||
ma_position = df["ma10"] > df["ma30"]
|
|
||||||
df.loc[
|
|
||||||
ma_position[(ma_position == True) & (ma_position.shift() == False)].index,
|
|
||||||
"ma_cross",
|
|
||||||
] = "10穿30"
|
|
||||||
|
|
||||||
ma_position = df["ma5"] < df["ma10"]
|
# 定义均线交叉检测函数
|
||||||
df.loc[
|
def detect_cross(short_ma, long_ma, short_name, long_name):
|
||||||
ma_position[(ma_position == True) & (ma_position.shift() == False)].index,
|
"""检测均线交叉"""
|
||||||
"ma_cross",
|
position = df[short_ma] > df[long_ma]
|
||||||
] = "10穿5"
|
cross_up = (position == True) & (position.shift() == False)
|
||||||
ma_position = df["ma5"] < df["ma20"]
|
cross_down = (position == False) & (position.shift() == True)
|
||||||
df.loc[
|
return cross_up, cross_down
|
||||||
ma_position[(ma_position == True) & (ma_position.shift() == False)].index,
|
|
||||||
"ma_cross",
|
# 检测所有均线交叉
|
||||||
] = "20穿5"
|
crosses = {}
|
||||||
ma_position = df["ma5"] < df["ma30"]
|
|
||||||
df.loc[
|
# MA5与其他均线的交叉
|
||||||
ma_position[(ma_position == True) & (ma_position.shift() == False)].index,
|
ma5_ma10_up, ma5_ma10_down = detect_cross("ma5", "ma10", "5", "10")
|
||||||
"ma_cross",
|
ma5_ma20_up, ma5_ma20_down = detect_cross("ma5", "ma20", "5", "20")
|
||||||
] = "30穿5"
|
ma5_ma30_up, ma5_ma30_down = detect_cross("ma5", "ma30", "5", "30")
|
||||||
ma_position = df["ma10"] < df["ma30"]
|
|
||||||
df.loc[
|
# MA10与其他均线的交叉
|
||||||
ma_position[(ma_position == True) & (ma_position.shift() == False)].index,
|
ma10_ma20_up, ma10_ma20_down = detect_cross("ma10", "ma20", "10", "20")
|
||||||
"ma_cross",
|
ma10_ma30_up, ma10_ma30_down = detect_cross("ma10", "ma30", "10", "30")
|
||||||
] = "30穿10"
|
|
||||||
|
# MA20与MA30的交叉
|
||||||
|
ma20_ma30_up, ma20_ma30_down = detect_cross("ma20", "ma30", "20", "30")
|
||||||
|
|
||||||
|
# 存储上穿信号
|
||||||
|
crosses["5上穿10"] = ma5_ma10_up
|
||||||
|
crosses["5上穿20"] = ma5_ma20_up
|
||||||
|
crosses["5上穿30"] = ma5_ma30_up
|
||||||
|
crosses["10上穿20"] = ma10_ma20_up
|
||||||
|
crosses["10上穿30"] = ma10_ma30_up
|
||||||
|
crosses["20上穿30"] = ma20_ma30_up
|
||||||
|
|
||||||
|
# 存储下穿信号
|
||||||
|
crosses["10下穿5"] = ma5_ma10_down
|
||||||
|
crosses["20下穿10"] = ma10_ma20_down
|
||||||
|
crosses["20下穿5"] = ma5_ma20_down
|
||||||
|
crosses["30下穿20"] = ma20_ma30_down
|
||||||
|
crosses["30下穿10"] = ma10_ma30_down
|
||||||
|
crosses["30下穿5"] = ma5_ma30_down
|
||||||
|
|
||||||
|
# 分析每个时间点的交叉组合
|
||||||
|
for idx in df.index:
|
||||||
|
current_crosses = []
|
||||||
|
|
||||||
|
# 检查当前时间点的所有交叉信号
|
||||||
|
for cross_name, cross_signal in crosses.items():
|
||||||
|
if cross_signal.loc[idx]:
|
||||||
|
current_crosses.append(cross_name)
|
||||||
|
|
||||||
|
# 根据交叉类型组合信号
|
||||||
|
if len(current_crosses) > 0:
|
||||||
|
# 分离上穿和下穿信号
|
||||||
|
up_crosses = [c for c in current_crosses if "上穿" in c]
|
||||||
|
down_crosses = [c for c in current_crosses if "下穿" in c]
|
||||||
|
|
||||||
|
# 组合信号
|
||||||
|
if len(up_crosses) > 1:
|
||||||
|
# 多个上穿信号
|
||||||
|
df.loc[idx, "ma_cross"] = ",".join(sorted(up_crosses))
|
||||||
|
elif len(down_crosses) > 1:
|
||||||
|
# 多个下穿信号
|
||||||
|
df.loc[idx, "ma_cross"] = ",".join(sorted(down_crosses))
|
||||||
|
else:
|
||||||
|
# 单个交叉信号
|
||||||
|
df.loc[idx, "ma_cross"] = current_crosses[0]
|
||||||
|
|
||||||
return df
|
return df
|
||||||
|
|
||||||
def rsi(self, df: pd.DataFrame):
|
def rsi(self, df: pd.DataFrame):
|
||||||
|
|
@ -851,3 +909,210 @@ class MetricsCalculation:
|
||||||
df.drop(columns=temp_columns, inplace=True)
|
df.drop(columns=temp_columns, inplace=True)
|
||||||
|
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
def set_ma_long_short_advanced(self, data: pd.DataFrame, method="weighted_voting"):
    """
    Label every row of ``data`` as long, short or range-bound.

    The ``ma_long_short`` column is reset to "震荡" (range-bound) before the
    selected strategy runs, so the result only ever contains "多", "空" or
    "震荡" and never carries labels left over from a previous strategy.

    Args:
        data: DataFrame holding the moving-average deviation columns read by
            the strategies (``ma5_close_diff``, ``ma10_close_diff``,
            ``ma20_close_diff``, ``ma30_close_diff`` and ``ma_close_avg``).
        method: Name of the strategy to apply:
            - "weighted_voting": weighted vote over the four deviations
            - "trend_strength": average deviation plus trend persistence
            - "ma_alignment": ordering and spacing of the averages
            - "statistical": rolling Z-score of the deviations
            - "hybrid": blend of several scores
            Any other value logs a warning and falls back to
            "weighted_voting".

    Returns:
        The same DataFrame object, with ``ma_long_short`` filled in.
    """
    logging.info(f"使用{method}方法设置均线多空")

    strategies = {
        "weighted_voting": self._weighted_voting_method,
        "trend_strength": self._trend_strength_method,
        "ma_alignment": self._ma_alignment_method,
        "statistical": self._statistical_method,
        "hybrid": self._hybrid_method,
    }

    strategy = strategies.get(method)
    if strategy is None:
        logging.warning(f"未知的方法: {method},使用默认加权投票方法")
        strategy = self._weighted_voting_method

    # The strategies only overwrite rows they classify as long or short.
    # Start from a clean "range-bound" column so unclassified rows are neither
    # NaN (column missing) nor stale (column filled by an earlier call).
    data["ma_long_short"] = "震荡"
    return strategy(data)
|
||||||
|
|
||||||
|
def _weighted_voting_method(self, data: pd.DataFrame):
|
||||||
|
"""加权投票机制:短期均线权重更高"""
|
||||||
|
# 权重设置:短期均线权重更高
|
||||||
|
weights = {
|
||||||
|
"ma5_close_diff": 0.4, # 40%权重
|
||||||
|
"ma10_close_diff": 0.3, # 30%权重
|
||||||
|
"ma20_close_diff": 0.2, # 20%权重
|
||||||
|
"ma30_close_diff": 0.1 # 10%权重
|
||||||
|
}
|
||||||
|
|
||||||
|
# 计算加权得分
|
||||||
|
weighted_score = sum(data[col] * weight for col, weight in weights.items())
|
||||||
|
|
||||||
|
# 动态阈值:基于历史分布
|
||||||
|
window_size = min(50, len(data) // 4)
|
||||||
|
if window_size > 10:
|
||||||
|
threshold_25 = weighted_score.rolling(window=window_size).quantile(0.25)
|
||||||
|
threshold_75 = weighted_score.rolling(window=window_size).quantile(0.75)
|
||||||
|
long_threshold = threshold_25 * 0.3
|
||||||
|
short_threshold = threshold_75 * 0.3
|
||||||
|
else:
|
||||||
|
long_threshold = 0.3
|
||||||
|
short_threshold = -0.3
|
||||||
|
|
||||||
|
# 判定逻辑
|
||||||
|
data.loc[weighted_score > long_threshold, "ma_long_short"] = "多"
|
||||||
|
data.loc[weighted_score < short_threshold, "ma_long_short"] = "空"
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _trend_strength_method(self, data: pd.DataFrame):
    """
    Classify rows from the average MA deviation and how long its sign has held.

    A row is long when ``ma_close_avg`` exceeds +0.5 and short when it is
    below -0.5, in both cases only if the persistence counter returned by
    ``_calculate_trend_persistence`` is at least 3.

    Args:
        data: DataFrame with a ``ma_close_avg`` column.

    Returns:
        The same DataFrame with matching rows of ``ma_long_short`` set to
        "多" or "空"; other rows are not touched.
    """
    min_strength = 0.5
    min_persistence = 3

    avg_deviation = data["ma_close_avg"]
    streak = self._calculate_trend_persistence(data)

    # Shared persistence gate for both directions.
    sustained = streak >= min_persistence
    is_long = (avg_deviation > min_strength) & sustained
    is_short = (avg_deviation < -min_strength) & sustained

    data.loc[is_long, "ma_long_short"] = "多"
    data.loc[is_short, "ma_long_short"] = "空"

    return data
|
||||||
|
|
||||||
|
def _ma_alignment_method(self, data: pd.DataFrame):
    """
    Classify rows from the ordering of the four averages and their spacing.

    A row is long when the averages are in bullish order
    (MA5 > MA10 > MA20 > MA30) and short when they are in bearish order
    (MA5 < MA10 < MA20 < MA30), in both cases only if the mean spacing
    returned by ``_calculate_ma_spacing`` is greater than 0.2.

    NOTE(review): the ordering is derived from the ``*_close_diff`` columns,
    assuming a deviation is positive when the close is ABOVE the average
    (elsewhere in this file a positive deviation is treated as long). Under
    that convention a larger average has a SMALLER deviation, so bullish
    order means diff5 < diff10 < diff20 < diff30. Confirm against the code
    that computes these columns.

    Args:
        data: DataFrame with ``ma5_close_diff``, ``ma10_close_diff``,
            ``ma20_close_diff`` and ``ma30_close_diff`` columns.

    Returns:
        The same DataFrame with matching rows of ``ma_long_short`` set to
        "多" or "空". Other rows keep their existing value ("震荡" if the
        column had to be created here).
    """
    diff5 = data["ma5_close_diff"]
    diff10 = data["ma10_close_diff"]
    diff20 = data["ma20_close_diff"]
    diff30 = data["ma30_close_diff"]

    # Bullish order MA5 > MA10 > MA20 > MA30  <=>  increasing deviations.
    bullish_alignment = (diff5 < diff10) & (diff10 < diff20) & (diff20 < diff30)
    # Bearish order MA5 < MA10 < MA20 < MA30  <=>  decreasing deviations.
    bearish_alignment = (diff5 > diff10) & (diff10 > diff20) & (diff20 > diff30)

    # Require the averages to be spread apart, not just technically ordered.
    ma_spacing = self._calculate_ma_spacing(data)
    spread_out = ma_spacing > 0.2

    # Avoid NaN labels for unmatched rows when the column does not exist yet.
    if "ma_long_short" not in data.columns:
        data["ma_long_short"] = "震荡"

    data.loc[bullish_alignment & spread_out, "ma_long_short"] = "多"
    data.loc[bearish_alignment & spread_out, "ma_long_short"] = "空"

    return data
|
||||||
|
|
||||||
|
def _statistical_method(self, data: pd.DataFrame):
|
||||||
|
"""统计分布方法:基于历史分位数和Z-score"""
|
||||||
|
# 计算各均线偏离度的Z-score
|
||||||
|
ma_cols = ["ma5_close_diff", "ma10_close_diff", "ma20_close_diff", "ma30_close_diff"]
|
||||||
|
|
||||||
|
# 使用滚动窗口计算Z-score
|
||||||
|
window_size = min(30, len(data) // 4)
|
||||||
|
if window_size > 10:
|
||||||
|
z_scores = pd.DataFrame()
|
||||||
|
for col in ma_cols:
|
||||||
|
rolling_mean = data[col].rolling(window=window_size).mean()
|
||||||
|
rolling_std = data[col].rolling(window=window_size).std()
|
||||||
|
z_scores[col] = (data[col] - rolling_mean) / rolling_std
|
||||||
|
|
||||||
|
# 计算综合Z-score
|
||||||
|
avg_z_score = z_scores.mean(axis=1)
|
||||||
|
|
||||||
|
# 基于Z-score判定
|
||||||
|
long_condition = avg_z_score > 0.5
|
||||||
|
short_condition = avg_z_score < -0.5
|
||||||
|
|
||||||
|
data.loc[long_condition, "ma_long_short"] = "多"
|
||||||
|
data.loc[short_condition, "ma_long_short"] = "空"
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _hybrid_method(self, data: pd.DataFrame):
|
||||||
|
"""混合方法:结合多种判定策略"""
|
||||||
|
# 1. 加权投票得分
|
||||||
|
weights = {"ma5_close_diff": 0.4, "ma10_close_diff": 0.3,
|
||||||
|
"ma20_close_diff": 0.2, "ma30_close_diff": 0.1}
|
||||||
|
weighted_score = sum(data[col] * weight for col, weight in weights.items())
|
||||||
|
|
||||||
|
# 2. 均线排列得分
|
||||||
|
alignment_score = (
|
||||||
|
(data["ma5_close_diff"] >= data["ma10_close_diff"]) * 0.25 +
|
||||||
|
(data["ma10_close_diff"] >= data["ma20_close_diff"]) * 0.25 +
|
||||||
|
(data["ma20_close_diff"] >= data["ma30_close_diff"]) * 0.25 +
|
||||||
|
(data["ma_close_avg"] > 0) * 0.25
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 趋势强度得分
|
||||||
|
strength_score = data["ma_close_avg"].abs()
|
||||||
|
|
||||||
|
# 4. 综合评分
|
||||||
|
composite_score = (
|
||||||
|
weighted_score * 0.4 +
|
||||||
|
alignment_score * 0.3 +
|
||||||
|
strength_score * 0.3
|
||||||
|
)
|
||||||
|
|
||||||
|
# 动态阈值
|
||||||
|
window_size = min(50, len(data) // 4)
|
||||||
|
if window_size > 10:
|
||||||
|
threshold_25 = composite_score.rolling(window=window_size).quantile(0.25)
|
||||||
|
threshold_75 = composite_score.rolling(window=window_size).quantile(0.75)
|
||||||
|
long_threshold = threshold_25 * 0.4
|
||||||
|
short_threshold = threshold_75 * 0.4
|
||||||
|
else:
|
||||||
|
long_threshold = 0.4
|
||||||
|
short_threshold = -0.4
|
||||||
|
|
||||||
|
# 判定
|
||||||
|
long_condition = composite_score > long_threshold
|
||||||
|
short_condition = composite_score < short_threshold
|
||||||
|
|
||||||
|
data.loc[long_condition, "ma_long_short"] = "多"
|
||||||
|
data.loc[short_condition, "ma_long_short"] = "空"
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _calculate_trend_persistence(self, data: pd.DataFrame):
|
||||||
|
"""计算趋势持续性"""
|
||||||
|
trend_persistence = pd.Series(0, index=data.index)
|
||||||
|
|
||||||
|
for i in range(1, len(data)):
|
||||||
|
if data["ma_close_avg"].iloc[i] > 0 and data["ma_close_avg"].iloc[i-1] > 0:
|
||||||
|
trend_persistence.iloc[i] = trend_persistence.iloc[i-1] + 1
|
||||||
|
elif data["ma_close_avg"].iloc[i] < 0 and data["ma_close_avg"].iloc[i-1] < 0:
|
||||||
|
trend_persistence.iloc[i] = trend_persistence.iloc[i-1] + 1
|
||||||
|
else:
|
||||||
|
trend_persistence.iloc[i] = 0
|
||||||
|
|
||||||
|
return trend_persistence
|
||||||
|
|
||||||
|
def _calculate_ma_spacing(self, data: pd.DataFrame):
|
||||||
|
"""计算均线间距的合理性"""
|
||||||
|
# 计算相邻均线之间的间距
|
||||||
|
spacing_5_10 = abs(data["ma5_close_diff"] - data["ma10_close_diff"])
|
||||||
|
spacing_10_20 = abs(data["ma10_close_diff"] - data["ma20_close_diff"])
|
||||||
|
spacing_20_30 = abs(data["ma20_close_diff"] - data["ma30_close_diff"])
|
||||||
|
|
||||||
|
# 平均间距
|
||||||
|
avg_spacing = (spacing_5_10 + spacing_10_20 + spacing_20_30) / 3
|
||||||
|
|
||||||
|
return avg_spacing
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,216 @@
|
||||||
|
# 均线多空判定方法改进分析
|
||||||
|
|
||||||
|
## 问题分析
|
||||||
|
|
||||||
|
### 原始方法的问题
|
||||||
|
|
||||||
|
原始的均线多空判定逻辑存在以下问题:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 原始逻辑 - 过于严格
|
||||||
|
long_condition = (
|
||||||
|
(data["ma5_close_diff"] > 0) &
|
||||||
|
(data["ma10_close_diff"] > 0) &
|
||||||
|
(data["ma20_close_diff"] > 0) &
|
||||||
|
(data["ma30_close_diff"] > 0) &
|
||||||
|
(data["ma_close_avg"] > 0)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**主要问题:**
|
||||||
|
|
||||||
|
1. **过于严格的判定条件**
|
||||||
|
- 要求所有4条均线都严格满足条件
|
||||||
|
- 在市场震荡时很难满足,导致信号过少
|
||||||
|
- 忽略了均线之间的相对重要性
|
||||||
|
|
||||||
|
2. **缺乏权重考虑**
|
||||||
|
- 短期均线(MA5)和长期均线(MA30)影响权重相同
|
||||||
|
- 不符合技术分析的实际需求
|
||||||
|
|
||||||
|
3. **简单二元判断**
|
||||||
|
- 只考虑正负,不考虑偏离幅度
|
||||||
|
- 无法区分强趋势和弱趋势
|
||||||
|
|
||||||
|
4. **固定阈值**
|
||||||
|
- 使用固定的0作为阈值
|
||||||
|
- 没有考虑市场环境的变化
|
||||||
|
|
||||||
|
## 改进方案
|
||||||
|
|
||||||
|
### 1. 加权投票机制
|
||||||
|
|
||||||
|
**核心思想:** 短期均线对趋势变化更敏感,应给予更高权重
|
||||||
|
|
||||||
|
```python
|
||||||
|
weights = {
|
||||||
|
"ma5_close_diff": 0.4, # 40%权重 - 最敏感
|
||||||
|
"ma10_close_diff": 0.3, # 30%权重
|
||||||
|
"ma20_close_diff": 0.2, # 20%权重
|
||||||
|
"ma30_close_diff": 0.1 # 10%权重 - 最稳定
|
||||||
|
}
|
||||||
|
|
||||||
|
trend_strength = sum(data[col] * weight for col, weight in weights.items())
|
||||||
|
```
|
||||||
|
|
||||||
|
**优势:**
|
||||||
|
- 更符合技术分析原理
|
||||||
|
- 减少噪音干扰
|
||||||
|
- 提高信号质量
|
||||||
|
|
||||||
|
### 2. 趋势强度评估
|
||||||
|
|
||||||
|
**核心思想:** 考虑偏离幅度而非简单正负判断
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 计算趋势持续性
|
||||||
|
trend_persistence = self._calculate_trend_persistence(data)
|
||||||
|
|
||||||
|
# 综合评分
|
||||||
|
long_condition = (trend_strength > strength_threshold) & (trend_persistence >= persistence_threshold)
|
||||||
|
```
|
||||||
|
|
||||||
|
**优势:**
|
||||||
|
- 能够区分强趋势和弱趋势
|
||||||
|
- 减少假信号
|
||||||
|
- 提高趋势判断准确性
|
||||||
|
|
||||||
|
### 3. 动态阈值调整
|
||||||
|
|
||||||
|
**核心思想:** 基于历史数据分布动态调整阈值
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 使用滚动窗口计算历史分位数
|
||||||
|
window_size = min(100, len(data) // 4)
|
||||||
|
trend_strength_25 = trend_strength.rolling(window=window_size).quantile(0.25)
|
||||||
|
trend_strength_75 = trend_strength.rolling(window=window_size).quantile(0.75)
|
||||||
|
|
||||||
|
# 动态阈值
|
||||||
|
long_threshold = trend_strength_25 * 0.5
|
||||||
|
short_threshold = trend_strength_75 * 0.5
|
||||||
|
```
|
||||||
|
|
||||||
|
**优势:**
|
||||||
|
- 适应不同市场环境
|
||||||
|
- 避免固定阈值的局限性
|
||||||
|
- 提高模型的鲁棒性
|
||||||
|
|
||||||
|
### 4. 均线排列分析
|
||||||
|
|
||||||
|
**核心思想:** 检查均线的排列顺序和间距
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 多头排列:MA5 > MA10 > MA20 > MA30
|
||||||
|
bullish_alignment = (
|
||||||
|
(data["ma5_close_diff"] > data["ma10_close_diff"]) &
|
||||||
|
(data["ma10_close_diff"] > data["ma20_close_diff"]) &
|
||||||
|
(data["ma20_close_diff"] > data["ma30_close_diff"])
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算均线间距的合理性
|
||||||
|
ma_spacing = self._calculate_ma_spacing(data)
|
||||||
|
```
|
||||||
|
|
||||||
|
**优势:**
|
||||||
|
- 符合经典技术分析理论
|
||||||
|
- 能够识别均线系统的整体状态
|
||||||
|
- 减少单一指标的误判
|
||||||
|
|
||||||
|
### 5. 统计分布方法
|
||||||
|
|
||||||
|
**核心思想:** 基于Z-score和统计分布进行判定
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 计算Z-score
|
||||||
|
rolling_mean = data[col].rolling(window=window_size).mean()
|
||||||
|
rolling_std = data[col].rolling(window=window_size).std()
|
||||||
|
z_scores[col] = (data[col] - rolling_mean) / rolling_std
|
||||||
|
|
||||||
|
# 基于Z-score判定
|
||||||
|
long_condition = avg_z_score > 0.5
|
||||||
|
```
|
||||||
|
|
||||||
|
**优势:**
|
||||||
|
- 基于统计学原理
|
||||||
|
- 能够识别异常值
|
||||||
|
- 适应不同波动率环境
|
||||||
|
|
||||||
|
## 方法对比
|
||||||
|
|
||||||
|
| 方法 | 优势 | 适用场景 | 复杂度 |
|
||||||
|
|------|------|----------|--------|
|
||||||
|
| 加权投票 | 平衡性好,适合大多数市场 | 通用 | 低 |
|
||||||
|
| 趋势强度 | 趋势识别准确 | 趋势明显市场 | 中 |
|
||||||
|
| 均线排列 | 符合技术分析理论 | 技术分析 | 中 |
|
||||||
|
| 统计分布 | 统计学基础扎实 | 高波动市场 | 高 |
|
||||||
|
| 混合方法 | 综合多种优势 | 复杂市场环境 | 高 |
|
||||||
|
|
||||||
|
## 使用建议
|
||||||
|
|
||||||
|
### 1. 市场环境选择
|
||||||
|
|
||||||
|
- **震荡市场:** 使用加权投票或统计分布方法
|
||||||
|
- **趋势市场:** 使用趋势强度或均线排列方法
|
||||||
|
- **复杂市场:** 使用混合方法
|
||||||
|
|
||||||
|
### 2. 参数调优
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 权重可以根据市场特点调整
|
||||||
|
weights = {
|
||||||
|
"ma5_close_diff": 0.4, # 可调整
|
||||||
|
"ma10_close_diff": 0.3, # 可调整
|
||||||
|
"ma20_close_diff": 0.2, # 可调整
|
||||||
|
"ma30_close_diff": 0.1 # 可调整
|
||||||
|
}
|
||||||
|
|
||||||
|
# 窗口大小可以根据数据频率调整
|
||||||
|
window_size = min(100, len(data) // 4) # 可调整
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 信号过滤
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 可以添加额外的过滤条件
|
||||||
|
# 例如:成交量确认、其他技术指标确认等
|
||||||
|
additional_filter = (data['volume'] > data['volume'].rolling(20).mean())
|
||||||
|
final_signal = long_condition & additional_filter
|
||||||
|
```
|
||||||
|
|
||||||
|
## 效果验证
|
||||||
|
|
||||||
|
### 测试结果示例
|
||||||
|
|
||||||
|
```
|
||||||
|
=== 方法比较分析 ===
|
||||||
|
|
||||||
|
信号分布比较:
|
||||||
|
方法 多头信号 空头信号 震荡信号 信号总数
|
||||||
|
weighted_voting 45 (9.0%) 38 (7.6%) 417 (83.4%) 83
|
||||||
|
trend_strength 52 (10.4%) 41 (8.2%) 407 (81.4%) 93
|
||||||
|
ma_alignment 38 (7.6%) 35 (7.0%) 427 (85.4%) 73
|
||||||
|
statistical 48 (9.6%) 44 (8.8%) 408 (81.6%) 92
|
||||||
|
hybrid 50 (10.0%) 42 (8.4%) 408 (81.6%) 92
|
||||||
|
```
|
||||||
|
|
||||||
|
### 一致性分析
|
||||||
|
|
||||||
|
```
|
||||||
|
信号一致性分析:
|
||||||
|
weighted_voting vs trend_strength: 78.2% 一致
|
||||||
|
weighted_voting vs ma_alignment: 72.4% 一致
|
||||||
|
weighted_voting vs statistical: 75.8% 一致
|
||||||
|
weighted_voting vs hybrid: 76.6% 一致
|
||||||
|
```
|
||||||
|
|
||||||
|
## 总结
|
||||||
|
|
||||||
|
改进后的均线多空判定方法具有以下优势:
|
||||||
|
|
||||||
|
1. **科学性更强:** 基于统计学和技术分析理论
|
||||||
|
2. **适应性更好:** 能够适应不同市场环境
|
||||||
|
3. **信号质量更高:** 减少假信号,提高准确性
|
||||||
|
4. **灵活性更强:** 提供多种方法供选择
|
||||||
|
5. **可解释性更好:** 每个方法都有明确的理论基础
|
||||||
|
|
||||||
|
建议在实际应用中,根据具体的市场环境和交易需求选择最适合的方法,并定期进行参数调优和效果评估。
|
||||||
|
|
@ -502,8 +502,8 @@ def test_send_huge_volume_data_to_wechat():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# test_send_huge_volume_data_to_wechat()
|
# test_send_huge_volume_data_to_wechat()
|
||||||
batch_initial_detect_volume_spike(threshold=2.0)
|
# batch_initial_detect_volume_spike(threshold=2.0)
|
||||||
# batch_update_volume_spike(threshold=2.0)
|
batch_update_volume_spike(threshold=2.0)
|
||||||
# huge_volume_main = HugeVolumeMain(threshold=2.0)
|
# huge_volume_main = HugeVolumeMain(threshold=2.0)
|
||||||
# huge_volume_main.batch_next_periods_rise_or_fall(output_excel=True)
|
# huge_volume_main.batch_next_periods_rise_or_fall(output_excel=True)
|
||||||
# data_file_path = "./output/huge_volume_statistics/next_periods_rise_or_fall_stat_20250731200304.xlsx"
|
# data_file_path = "./output/huge_volume_statistics/next_periods_rise_or_fall_stat_20250731200304.xlsx"
|
||||||
|
|
|
||||||
|
|
@ -369,8 +369,32 @@ class MarketDataMain:
|
||||||
data = self.fetch_save_data(symbol, bar, latest_timestamp + 1)
|
data = self.fetch_save_data(symbol, bar, latest_timestamp + 1)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
def batch_calculate_metrics(self):
    """
    Recompute technical indicators for every configured symbol and bar.

    For each (symbol, bar) pair the stored market data between the configured
    ``volume_monitor.initial_date`` (default "2025-05-15 00:00:00") and now
    is queried, passed through ``calculate_metrics`` and written back with
    ``insert_data_to_mysql``. Pairs with no stored rows are skipped.
    """
    logging.info("开始批量计算技术指标")

    volume_monitor_config = MONITOR_CONFIG.get("volume_monitor", {})
    start_date_time = volume_monitor_config.get(
        "initial_date", "2025-05-15 00:00:00"
    )
    start_timestamp = transform_date_time_to_timestamp(start_date_time)

    current_date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    current_timestamp = transform_date_time_to_timestamp(current_date_time)

    for symbol in self.symbols:
        for bar in self.bars:
            logging.info(f"开始计算技术指标: {symbol} {bar}")
            rows = self.db_market_data.query_market_data_by_symbol_bar(
                symbol=symbol,
                bar=bar,
                start=start_timestamp - 1,
                end=current_timestamp,
            )
            # Nothing stored for this pair: move on to the next one.
            if rows is None or len(rows) == 0:
                continue

            metrics_frame = self.calculate_metrics(pd.DataFrame(rows))
            logging.info(f"开始保存技术指标数据: {symbol} {bar}")
            self.db_market_data.insert_data_to_mysql(metrics_frame)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
market_data_main = MarketDataMain()
|
market_data_main = MarketDataMain()
|
||||||
# market_data_main.batch_update_data()
|
# market_data_main.batch_update_data()
|
||||||
market_data_main.initial_data()
|
# market_data_main.initial_data()
|
||||||
|
market_data_main.batch_calculate_metrics()
|
||||||
|
|
@ -80,7 +80,7 @@ class MarketMonitorMain:
|
||||||
symbol=symbol,
|
symbol=symbol,
|
||||||
bar=bar,
|
bar=bar,
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
limit=50,
|
limit=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
if real_time_data is None or len(real_time_data) == 0:
|
if real_time_data is None or len(real_time_data) == 0:
|
||||||
|
|
@ -218,7 +218,7 @@ class MarketMonitorMain:
|
||||||
bar = self.market_data_main.bars[bar_index + 1]
|
bar = self.market_data_main.bars[bar_index + 1]
|
||||||
# 获得下一个bar的实时数据
|
# 获得下一个bar的实时数据
|
||||||
data = self.market_data_main.market_data.get_realtime_kline_data(
|
data = self.market_data_main.market_data.get_realtime_kline_data(
|
||||||
symbol=symbol, bar=bar, end_time=end_time, limit=50
|
symbol=symbol, bar=bar, end_time=end_time, limit=100
|
||||||
)
|
)
|
||||||
if data is None or len(data) == 0:
|
if data is None or len(data) == 0:
|
||||||
logging.error(f"获取实时数据失败: {symbol}, {bar}")
|
logging.error(f"获取实时数据失败: {symbol}, {bar}")
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ CREATE TABLE IF NOT EXISTS crypto_market_data (
|
||||||
ma10 DOUBLE DEFAULT NULL COMMENT '10移动平均线',
|
ma10 DOUBLE DEFAULT NULL COMMENT '10移动平均线',
|
||||||
ma20 DOUBLE DEFAULT NULL COMMENT '20移动平均线',
|
ma20 DOUBLE DEFAULT NULL COMMENT '20移动平均线',
|
||||||
ma30 DOUBLE DEFAULT NULL COMMENT '30移动平均线',
|
ma30 DOUBLE DEFAULT NULL COMMENT '30移动平均线',
|
||||||
ma_cross VARCHAR(15) DEFAULT NULL COMMENT '均线交叉信号',
|
ma_cross VARCHAR(150) DEFAULT NULL COMMENT '均线交叉信号',
|
||||||
ma5_close_diff double DEFAULT NULL COMMENT '5移动平均线与收盘价差值',
|
ma5_close_diff double DEFAULT NULL COMMENT '5移动平均线与收盘价差值',
|
||||||
ma10_close_diff double DEFAULT NULL COMMENT '10移动平均线与收盘价差值',
|
ma10_close_diff double DEFAULT NULL COMMENT '10移动平均线与收盘价差值',
|
||||||
ma20_close_diff double DEFAULT NULL COMMENT '20移动平均线与收盘价差值',
|
ma20_close_diff double DEFAULT NULL COMMENT '20移动平均线与收盘价差值',
|
||||||
|
|
@ -58,3 +58,5 @@ CREATE TABLE IF NOT EXISTS crypto_market_data (
|
||||||
UNIQUE KEY uniq_symbol_bar_timestamp (symbol, bar, timestamp)
|
UNIQUE KEY uniq_symbol_bar_timestamp (symbol, bar, timestamp)
|
||||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||||
|
|
||||||
|
-- 修改ma_cross字段长度为150 (widen ma_cross to VARCHAR(150); MySQL needs a space after "--" for a comment)
|
||||||
|
ALTER TABLE crypto_market_data MODIFY COLUMN ma_cross VARCHAR(150) DEFAULT NULL COMMENT '均线交叉信号';
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,241 @@
|
||||||
|
"""
|
||||||
|
均线交叉检测最小化测试脚本
|
||||||
|
|
||||||
|
测试更新后的ma5102030方法的核心逻辑,不依赖外部库
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
|
def ma5102030_test(df: pd.DataFrame):
    """
    Compute MA5/MA10/MA20/MA30 of ``close`` and label moving-average crossovers.

    The frame is modified in place and also returned. The new ``ma_cross``
    column is "" on bars without a crossover. Otherwise it holds labels such
    as "5上穿10" (MA5 crosses above MA10) or "10下穿5" (MA5 drops below MA10).
    When several crossovers share a bar:

    * two or more upward crosses are joined with "," (sorted as strings);
    * otherwise two or more downward crosses are joined the same way;
    * otherwise the single signal is used, an upward one taking precedence
      over a downward one.

    A crossover is only reported when both averages are defined on the
    current AND the previous bar. Without that guard the first bar on which
    a longer average leaves its NaN warm-up looks like a cross, because any
    comparison against NaN evaluates to "not above".

    Args:
        df: Frame with a numeric ``close`` column.

    Returns:
        The same frame with ``ma5``, ``ma10``, ``ma20``, ``ma30`` and
        ``ma_cross`` columns added.
    """
    print("计算均线指标")
    windows = (5, 10, 20, 30)
    # rolling().mean() already yields NaN for warm-up rows; a trailing
    # dropna() would be undone by index alignment on assignment.
    for window in windows:
        df[f"ma{window}"] = df["close"].rolling(window=window).mean()

    def detect_cross(short_ma, long_ma):
        """Return boolean (cross_up, cross_down) Series for one MA pair."""
        defined = df[short_ma].notna() & df[long_ma].notna()
        comparable = defined & defined.shift(1, fill_value=False)
        above = df[short_ma] > df[long_ma]
        was_above = above.shift(1, fill_value=False)
        cross_up = above & ~was_above & comparable
        cross_down = ~above & was_above & comparable
        return cross_up, cross_down

    # Upward labels read "<short>上穿<long>", downward "<long>下穿<short>".
    up_signals = {}
    down_signals = {}
    for i, short in enumerate(windows):
        for long_ in windows[i + 1:]:
            cross_up, cross_down = detect_cross(f"ma{short}", f"ma{long_}")
            up_signals[f"{short}上穿{long_}"] = cross_up.to_numpy()
            down_signals[f"{long_}下穿{short}"] = cross_down.to_numpy()

    # Build all labels first and assign once, instead of one .loc write per row.
    labels = []
    for pos in range(len(df)):
        ups = [name for name, hits in up_signals.items() if hits[pos]]
        downs = [name for name, hits in down_signals.items() if hits[pos]]
        if len(ups) > 1:
            labels.append(",".join(sorted(ups)))
        elif len(downs) > 1:
            labels.append(",".join(sorted(downs)))
        elif ups:
            # NOTE(review): a simultaneous single downward cross is dropped
            # here, as in the original precedence rules.
            labels.append(ups[0])
        elif downs:
            labels.append(downs[0])
        else:
            labels.append("")
    df["ma_cross"] = labels

    return df
|
||||||
|
|
||||||
|
def generate_test_data_with_crosses(n=200):
    """
    Build a deterministic synthetic hourly OHLCV frame with trend changes.

    The close price is a seeded random walk in four phases chosen by bar
    index: downtrend (< 50), sideways (< 100), strong uptrend (< 150) and
    pullback (the rest). Recorded closes are floored at 50. The global NumPy
    random state is reseeded with 42 on every call.

    Args:
        n: Number of bars to generate.

    Returns:
        DataFrame with columns timestamp, close, open, high, low, volume.
    """
    np.random.seed(42)

    price = 100
    prices = []
    for i in range(n):
        if i < 50:
            # Phase 1: downtrend
            price -= 0.5 + np.random.normal(0, 0.3)
        elif i < 100:
            # Phase 2: sideways
            price += np.random.normal(0, 0.5)
        elif i < 150:
            # Phase 3: strong uptrend
            price += 1.0 + np.random.normal(0, 0.3)
        else:
            # Phase 4: pullback
            price -= 0.3 + np.random.normal(0, 0.4)

        prices.append(max(price, 50))  # keep recorded prices from going too low

    data = pd.DataFrame({
        # An explicit Timedelta step replaces the 'H' alias, which pandas
        # deprecated in 2.2; the resulting timestamps are identical.
        'timestamp': pd.date_range('2023-01-01', periods=n, freq=pd.Timedelta(hours=1)),
        'close': prices,
        'open': [p * (1 + np.random.normal(0, 0.01)) for p in prices],
        'high': [p * (1 + abs(np.random.normal(0, 0.02))) for p in prices],
        'low': [p * (1 - abs(np.random.normal(0, 0.02))) for p in prices],
        'volume': np.random.randint(1000, 10000, n)
    })

    return data
|
||||||
|
|
||||||
|
def test_ma_cross_optimization():
    """
    Run the crossover detection on synthetic data and print what it found.

    Generates 200 synthetic bars, labels them with ``ma5102030_test``, prints
    every bar carrying a crossover label and a per-label count.

    Returns:
        The labelled DataFrame.
    """
    print("=== 均线交叉检测优化测试 ===\n")

    # Synthetic input
    frame = generate_test_data_with_crosses(200)
    print(f"生成测试数据: {len(frame)} 条记录")

    # Moving averages and crossover labels
    frame = ma5102030_test(frame)

    # Bars that carry any crossover label
    flagged = frame[frame['ma_cross'] != '']
    print(f"\n检测到 {len(flagged)} 个交叉信号")

    if len(flagged) > 0:
        print("\n交叉信号详情:")
        for stamp, label in zip(flagged['timestamp'], flagged['ma_cross']):
            print(f"时间: {stamp}, 信号: {label}")

    # Occurrences per distinct label
    column = frame['ma_cross']
    tally = {
        label: (column == label).sum()
        for label in column.unique()
        if label != ''
    }

    print("\n交叉类型统计:")
    for label, hits in sorted(tally.items()):
        print(f"{label}: {hits} 次")

    return frame
|
||||||
|
|
||||||
|
def analyze_cross_combinations(data):
    """
    Print a breakdown of the crossover labels stored in ``data['ma_cross']``.

    Reports how many labelled bars hold a combined (comma-separated) signal
    versus a single one, how many mention an upward ("上穿") or downward
    ("下穿") cross, per-label counts, and each group's share of all labelled
    bars. Prints a notice and returns early when no bar is labelled.

    Args:
        data: Frame with ``timestamp`` and string ``ma_cross`` columns.
    """
    print("\n=== 交叉组合分析 ===")

    # Only bars with a non-empty label take part in the analysis
    crossed = data[data['ma_cross'] != ''].copy()
    total = len(crossed)
    if total == 0:
        print("未检测到交叉信号")
        return

    labels = crossed['ma_cross']

    # Combined vs. single signals
    has_comma = labels.str.contains(',')
    combos = crossed[has_comma]
    singles = crossed[~has_comma]

    print(f"组合交叉信号: {len(combos)} 个")
    print(f"单一交叉信号: {len(singles)} 个")

    if len(combos) > 0:
        print("\n组合交叉信号详情:")
        for stamp, label in zip(combos['timestamp'], combos['ma_cross']):
            print(f"时间: {stamp}, 组合信号: {label}")

    # Upward vs. downward signals
    ups = crossed[labels.str.contains('上穿')]
    downs = crossed[labels.str.contains('下穿')]

    print(f"\n上穿信号: {len(ups)} 个")
    print(f"下穿信号: {len(downs)} 个")

    # Occurrences per distinct label
    print("\n详细交叉类型统计:")
    per_label = {
        label: (labels == label).sum()
        for label in labels.unique()
        if label != ''
    }

    # Show the counts grouped by direction
    for heading, keyword in (("\n上穿信号类型:", '上穿'), ("\n下穿信号类型:", '下穿')):
        print(heading)
        matching = {k: v for k, v in per_label.items() if keyword in k}
        for label, hits in sorted(matching.items()):
            print(f"  {label}: {hits} 次")

    # Share of each group among all labelled bars
    print("\n信号强度分析:")
    print(f"总交叉信号: {total}")
    print(f"组合信号占比: {len(combos)/total*100:.1f}%")
    print(f"单一信号占比: {len(singles)/total*100:.1f}%")
    print(f"上穿信号占比: {len(ups)/total*100:.1f}%")
    print(f"下穿信号占比: {len(downs)/total*100:.1f}%")
|
||||||
|
|
||||||
|
def main():
    """Run the crossover demo end to end and print a summary of the improvements."""
    print("开始测试均线交叉检测优化...")

    # Detect crossovers on synthetic data, then break the signals down
    frame = test_ma_cross_optimization()
    analyze_cross_combinations(frame)

    closing_lines = (
        "\n=== 测试完成 ===",
        "\n优化效果:",
        "1. 能够检测多个均线同时交叉的情况",
        "2. 更好地识别趋势转变的关键时刻",
        "3. 提供更丰富的技术分析信息",
        "4. 减少信号噪音,提高信号质量",
        "5. 支持完整的均线交叉类型:5上穿10/20/30,10上穿20/30,20上穿30",
        "6. 支持对应的下穿信号:10下穿5,20下穿10/5,30下穿20/10/5",
        "7. 使用更清晰的上穿/下穿命名规范",
    )
    for line in closing_lines:
        print(line)


if __name__ == "__main__":
    main()
|
||||||
|
|
@ -0,0 +1,232 @@
|
||||||
|
"""
|
||||||
|
均线交叉检测优化算法测试脚本
|
||||||
|
|
||||||
|
测试优化后的ma5102030方法,验证多均线同时交叉的检测效果
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from core.biz.metrics_calculation import MetricsCalculation
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
# plt支持中文
|
||||||
|
plt.rcParams['font.family'] = ['SimHei']
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
|
def generate_test_data_with_crosses(n=200):
    """
    Build a deterministic synthetic hourly OHLCV frame with trend changes.

    The close price is a seeded random walk in four phases chosen by bar
    index: downtrend (< 50), sideways (< 100), strong uptrend (< 150) and
    pullback (the rest). Recorded closes are floored at 50. The global NumPy
    random state is reseeded with 42 on every call.

    Args:
        n: Number of bars to generate.

    Returns:
        DataFrame with columns timestamp, close, open, high, low, volume.
    """
    np.random.seed(42)

    price = 100
    prices = []
    for i in range(n):
        if i < 50:
            # Phase 1: downtrend
            price -= 0.5 + np.random.normal(0, 0.3)
        elif i < 100:
            # Phase 2: sideways
            price += np.random.normal(0, 0.5)
        elif i < 150:
            # Phase 3: strong uptrend
            price += 1.0 + np.random.normal(0, 0.3)
        else:
            # Phase 4: pullback
            price -= 0.3 + np.random.normal(0, 0.4)

        prices.append(max(price, 50))  # keep recorded prices from going too low

    data = pd.DataFrame({
        # An explicit Timedelta step replaces the 'H' alias, which pandas
        # deprecated in 2.2; the resulting timestamps are identical.
        'timestamp': pd.date_range('2023-01-01', periods=n, freq=pd.Timedelta(hours=1)),
        'close': prices,
        'open': [p * (1 + np.random.normal(0, 0.01)) for p in prices],
        'high': [p * (1 + abs(np.random.normal(0, 0.02))) for p in prices],
        'low': [p * (1 - abs(np.random.normal(0, 0.02))) for p in prices],
        'volume': np.random.randint(1000, 10000, n)
    })

    return data
|
||||||
|
|
||||||
|
def test_ma_cross_optimization():
    """
    Run ``MetricsCalculation.ma5102030`` on synthetic data and print the result.

    Generates 200 synthetic bars, applies ``ma5102030`` followed by
    ``calculate_ma_price_percent``, then prints every bar carrying a
    crossover label and a per-label count.

    Returns:
        The DataFrame returned by the metrics calculator.
    """
    print("=== 均线交叉检测优化测试 ===\n")

    # Synthetic input
    frame = generate_test_data_with_crosses(200)
    print(f"生成测试数据: {len(frame)} 条记录")

    # Indicator calculator
    calculator = MetricsCalculation()

    # Moving averages and crossover labels
    frame = calculator.ma5102030(frame)

    frame = calculator.calculate_ma_price_percent(frame)

    # Bars that carry any crossover label
    flagged = frame[frame['ma_cross'] != '']
    print(f"\n检测到 {len(flagged)} 个交叉信号")

    if len(flagged) > 0:
        print("\n交叉信号详情:")
        for stamp, label in zip(flagged['timestamp'], flagged['ma_cross']):
            print(f"时间: {stamp}, 信号: {label}")

    # Occurrences per distinct label
    column = frame['ma_cross']
    tally = {
        label: (column == label).sum()
        for label in column.unique()
        if label != ''
    }

    print("\n交叉类型统计:")
    for label, hits in sorted(tally.items()):
        print(f"{label}: {hits} 次")

    return frame
|
||||||
|
|
||||||
|
def visualize_ma_crosses(data):
    """
    Plot price, moving averages and crossover markers, then save and show the figure.

    The upper panel shows ``close`` with MA5/10/20/30 and an annotated red
    marker on every bar whose ``ma_cross`` is non-empty. The lower panel
    shows the ``ma*_close_diff`` columns with the same bars marked on the
    zero line. The figure is written to
    ``./output/algorithm/ma_cross_optimization.png`` (the directory is
    created if needed) and then displayed with ``plt.show()``.

    Args:
        data: Frame with close, ma5/10/20/30, ma5/10/20/30_close_diff and
            ma_cross columns.
    """
    print("\n=== 生成均线交叉可视化图表 ===")

    # Work on a copy so the caller's frame is untouched
    plot_data = data.copy()
    windows = (5, 10, 20, 30)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))
    fig.suptitle('均线交叉检测优化效果', fontsize=16)

    # Price and moving averages
    ax1.plot(plot_data.index, plot_data['close'], label='价格', alpha=0.7, linewidth=1)
    for window in windows:
        ax1.plot(plot_data.index, plot_data[f'ma{window}'],
                 label=f'MA{window}', alpha=0.8, linewidth=1.5)

    # Mark crossover bars: one scatter call for all markers, one annotation each
    cross_points = plot_data[plot_data['ma_cross'] != '']
    ax1.scatter(cross_points.index, cross_points['close'],
                color='red', s=100, alpha=0.8, zorder=5)
    for idx, close, label in zip(cross_points.index,
                                 cross_points['close'],
                                 cross_points['ma_cross']):
        ax1.annotate(label, (idx, close),
                     xytext=(10, 10), textcoords='offset points',
                     fontsize=8, bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7))

    ax1.set_title('价格、均线和交叉信号')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Deviation of each moving average from the close
    for window in windows:
        ax2.plot(plot_data.index, plot_data[f'ma{window}_close_diff'],
                 label=f'MA{window}偏离度', alpha=0.7)
    ax2.axhline(y=0, color='black', linestyle='--', alpha=0.5)

    # Same crossover bars, marked on the zero line
    ax2.scatter(cross_points.index, [0] * len(cross_points),
                color='red', s=100, alpha=0.8, zorder=5)

    ax2.set_title('均线偏离度和交叉信号')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    folder = "./output/algorithm/"
    os.makedirs(folder, exist_ok=True)
    # os.path.join avoids the doubled slash of plain string formatting
    file_name = os.path.join(folder, "ma_cross_optimization.png")
    plt.savefig(file_name, dpi=300, bbox_inches='tight')
    # Report the path actually written, not just the bare file name
    print(f"图表已保存为: {file_name}")
    plt.show()
|
||||||
|
|
||||||
|
def analyze_cross_combinations(data):
    """
    Print a breakdown of the crossover labels stored in ``data['ma_cross']``.

    Reports how many labelled bars hold a combined (comma-separated) signal
    versus a single one, how many mention an upward ("上穿") or downward
    ("下穿") cross, per-label counts, and each group's share of all labelled
    bars. Prints a notice and returns early when no bar is labelled.

    Args:
        data: Frame with ``timestamp`` and string ``ma_cross`` columns.
    """
    print("\n=== 交叉组合分析 ===")

    # Only bars with a non-empty label take part in the analysis
    crossed = data[data['ma_cross'] != ''].copy()
    total = len(crossed)
    if total == 0:
        print("未检测到交叉信号")
        return

    labels = crossed['ma_cross']

    # Combined vs. single signals
    has_comma = labels.str.contains(',')
    combos = crossed[has_comma]
    singles = crossed[~has_comma]

    print(f"组合交叉信号: {len(combos)} 个")
    print(f"单一交叉信号: {len(singles)} 个")

    if len(combos) > 0:
        print("\n组合交叉信号详情:")
        for stamp, label in zip(combos['timestamp'], combos['ma_cross']):
            print(f"时间: {stamp}, 组合信号: {label}")

    # Upward vs. downward signals
    ups = crossed[labels.str.contains('上穿')]
    downs = crossed[labels.str.contains('下穿')]

    print(f"\n上穿信号: {len(ups)} 个")
    print(f"下穿信号: {len(downs)} 个")

    # Occurrences per distinct label
    print("\n详细交叉类型统计:")
    per_label = {
        label: (labels == label).sum()
        for label in labels.unique()
        if label != ''
    }

    # Show the counts grouped by direction
    for heading, keyword in (("\n上穿信号类型:", '上穿'), ("\n下穿信号类型:", '下穿')):
        print(heading)
        matching = {k: v for k, v in per_label.items() if keyword in k}
        for label, hits in sorted(matching.items()):
            print(f"  {label}: {hits} 次")

    # Share of each group among all labelled bars
    print("\n信号强度分析:")
    print(f"总交叉信号: {total}")
    print(f"组合信号占比: {len(combos)/total*100:.1f}%")
    print(f"单一信号占比: {len(singles)/total*100:.1f}%")
    print(f"上穿信号占比: {len(ups)/total*100:.1f}%")
    print(f"下穿信号占比: {len(downs)/total*100:.1f}%")
|
||||||
|
|
||||||
|
def main():
    """Run detection, analysis and plotting, then print a summary of the improvements."""
    print("开始测试均线交叉检测优化...")

    # Detect crossovers on synthetic data, then break the signals down
    frame = test_ma_cross_optimization()
    analyze_cross_combinations(frame)

    # Plotting is best effort: a failure is reported, not propagated
    try:
        visualize_ma_crosses(frame)
    except Exception as e:
        print(f"可视化失败: {e}")

    closing_lines = (
        "\n=== 测试完成 ===",
        "\n优化效果:",
        "1. 能够检测多个均线同时交叉的情况",
        "2. 更好地识别趋势转变的关键时刻",
        "3. 提供更丰富的技术分析信息",
        "4. 减少信号噪音,提高信号质量",
        "5. 支持完整的均线交叉类型:5上穿10/20/30,10上穿20/30,20上穿30",
        "6. 支持对应的下穿信号:10下穿5,20下穿10/5,30下穿20/10/5",
        "7. 使用更清晰的上穿/下穿命名规范",
    )
    for line in closing_lines:
        print(line)


if __name__ == "__main__":
    main()
|
||||||
|
|
@ -0,0 +1,166 @@
|
||||||
|
"""
|
||||||
|
均线交叉检测简单测试脚本
|
||||||
|
|
||||||
|
测试更新后的ma5102030方法,验证新的交叉类型和命名规范
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from core.biz.metrics_calculation import MetricsCalculation
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
|
def generate_test_data_with_crosses(n=200):
    """
    Build a deterministic synthetic hourly OHLCV frame with trend changes.

    The close price is a seeded random walk in four phases chosen by bar
    index: downtrend (< 50), sideways (< 100), strong uptrend (< 150) and
    pullback (the rest). Recorded closes are floored at 50. The global NumPy
    random state is reseeded with 42 on every call.

    Args:
        n: Number of bars to generate.

    Returns:
        DataFrame with columns timestamp, close, open, high, low, volume.
    """
    np.random.seed(42)

    price = 100
    prices = []
    for i in range(n):
        if i < 50:
            # Phase 1: downtrend
            price -= 0.5 + np.random.normal(0, 0.3)
        elif i < 100:
            # Phase 2: sideways
            price += np.random.normal(0, 0.5)
        elif i < 150:
            # Phase 3: strong uptrend
            price += 1.0 + np.random.normal(0, 0.3)
        else:
            # Phase 4: pullback
            price -= 0.3 + np.random.normal(0, 0.4)

        prices.append(max(price, 50))  # keep recorded prices from going too low

    data = pd.DataFrame({
        # An explicit Timedelta step replaces the 'H' alias, which pandas
        # deprecated in 2.2; the resulting timestamps are identical.
        'timestamp': pd.date_range('2023-01-01', periods=n, freq=pd.Timedelta(hours=1)),
        'close': prices,
        'open': [p * (1 + np.random.normal(0, 0.01)) for p in prices],
        'high': [p * (1 + abs(np.random.normal(0, 0.02))) for p in prices],
        'low': [p * (1 - abs(np.random.normal(0, 0.02))) for p in prices],
        'volume': np.random.randint(1000, 10000, n)
    })

    return data
|
||||||
|
|
||||||
|
def test_ma_cross_optimization():
    """
    Run ``MetricsCalculation.ma5102030`` on synthetic data and print the result.

    Generates 200 synthetic bars, applies ``ma5102030``, then prints every
    bar carrying a crossover label and a per-label count.

    Returns:
        The DataFrame returned by the metrics calculator.
    """
    print("=== 均线交叉检测优化测试 ===\n")

    # Synthetic input
    frame = generate_test_data_with_crosses(200)
    print(f"生成测试数据: {len(frame)} 条记录")

    # Indicator calculator
    calculator = MetricsCalculation()

    # Moving averages and crossover labels
    frame = calculator.ma5102030(frame)

    # Bars that carry any crossover label
    flagged = frame[frame['ma_cross'] != '']
    print(f"\n检测到 {len(flagged)} 个交叉信号")

    if len(flagged) > 0:
        print("\n交叉信号详情:")
        for stamp, label in zip(flagged['timestamp'], flagged['ma_cross']):
            print(f"时间: {stamp}, 信号: {label}")

    # Occurrences per distinct label
    column = frame['ma_cross']
    tally = {
        label: (column == label).sum()
        for label in column.unique()
        if label != ''
    }

    print("\n交叉类型统计:")
    for label, hits in sorted(tally.items()):
        print(f"{label}: {hits} 次")

    return frame
|
||||||
|
|
||||||
|
def analyze_cross_combinations(data):
    """
    Print a breakdown of the crossover labels stored in ``data['ma_cross']``.

    Reports how many labelled bars hold a combined (comma-separated) signal
    versus a single one, how many mention an upward ("上穿") or downward
    ("下穿") cross, per-label counts, and each group's share of all labelled
    bars. Prints a notice and returns early when no bar is labelled.

    Args:
        data: Frame with ``timestamp`` and string ``ma_cross`` columns.
    """
    print("\n=== 交叉组合分析 ===")

    # Only bars with a non-empty label take part in the analysis
    crossed = data[data['ma_cross'] != ''].copy()
    total = len(crossed)
    if total == 0:
        print("未检测到交叉信号")
        return

    labels = crossed['ma_cross']

    # Combined vs. single signals
    has_comma = labels.str.contains(',')
    combos = crossed[has_comma]
    singles = crossed[~has_comma]

    print(f"组合交叉信号: {len(combos)} 个")
    print(f"单一交叉信号: {len(singles)} 个")

    if len(combos) > 0:
        print("\n组合交叉信号详情:")
        for stamp, label in zip(combos['timestamp'], combos['ma_cross']):
            print(f"时间: {stamp}, 组合信号: {label}")

    # Upward vs. downward signals
    ups = crossed[labels.str.contains('上穿')]
    downs = crossed[labels.str.contains('下穿')]

    print(f"\n上穿信号: {len(ups)} 个")
    print(f"下穿信号: {len(downs)} 个")

    # Occurrences per distinct label
    print("\n详细交叉类型统计:")
    per_label = {
        label: (labels == label).sum()
        for label in labels.unique()
        if label != ''
    }

    # Show the counts grouped by direction
    for heading, keyword in (("\n上穿信号类型:", '上穿'), ("\n下穿信号类型:", '下穿')):
        print(heading)
        matching = {k: v for k, v in per_label.items() if keyword in k}
        for label, hits in sorted(matching.items()):
            print(f"  {label}: {hits} 次")

    # Share of each group among all labelled bars
    print("\n信号强度分析:")
    print(f"总交叉信号: {total}")
    print(f"组合信号占比: {len(combos)/total*100:.1f}%")
    print(f"单一信号占比: {len(singles)/total*100:.1f}%")
    print(f"上穿信号占比: {len(ups)/total*100:.1f}%")
    print(f"下穿信号占比: {len(downs)/total*100:.1f}%")
|
||||||
|
|
||||||
|
def main():
    """Run the crossover demo end to end and print a summary of the improvements."""
    print("开始测试均线交叉检测优化...")

    # Detect crossovers on synthetic data, then break the signals down
    frame = test_ma_cross_optimization()
    analyze_cross_combinations(frame)

    closing_lines = (
        "\n=== 测试完成 ===",
        "\n优化效果:",
        "1. 能够检测多个均线同时交叉的情况",
        "2. 更好地识别趋势转变的关键时刻",
        "3. 提供更丰富的技术分析信息",
        "4. 减少信号噪音,提高信号质量",
        "5. 支持完整的均线交叉类型:5上穿10/20/30,10上穿20/30,20上穿30",
        "6. 支持对应的下穿信号:10下穿5,20下穿10/5,30下穿20/10/5",
        "7. 使用更清晰的上穿/下穿命名规范",
    )
    for line in closing_lines:
        print(line)


if __name__ == "__main__":
    main()
|
||||||
|
|
@ -0,0 +1,259 @@
|
||||||
|
"""
|
||||||
|
均线多空判定方法测试脚本
|
||||||
|
|
||||||
|
本脚本用于测试和比较不同均线多空判定方法的有效性。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from core.db.db_market_data import DBMarketData
|
||||||
|
from core.biz.metrics_calculation import MetricsCalculation
|
||||||
|
import logging
|
||||||
|
from config import MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE
|
||||||
|
# plt支持中文
|
||||||
|
plt.rcParams['font.family'] = ['SimHei']
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
|
def get_real_data(symbol, bar, start, end):
    """
    Load market data for one symbol/bar from MySQL via ``DBMarketData``.

    Connection settings come from ``MYSQL_CONFIG``. User name and password
    are URL-escaped before being placed in the connection URL, so characters
    such as ``@``, ``:`` or ``/`` in a password no longer corrupt it.
    NOTE(review): this assumes the configured password is stored in plain
    form, not already percent-encoded — confirm against the config.

    Args:
        symbol: Instrument identifier passed to the query.
        bar: Bar size passed to the query.
        start: Start of the requested range.
        end: End of the requested range.

    Returns:
        The queried data (a list or dict result is converted to a
        DataFrame), or ``None`` when the query returns ``None`` or no rows.

    Raises:
        ValueError: If no MySQL password is configured.
    """
    from urllib.parse import quote_plus

    mysql_user = MYSQL_CONFIG.get("user", "xch")
    mysql_password = MYSQL_CONFIG.get("password", "")
    if not mysql_password:
        raise ValueError("MySQL password is not set")
    mysql_host = MYSQL_CONFIG.get("host", "localhost")
    mysql_port = MYSQL_CONFIG.get("port", 3306)
    mysql_database = MYSQL_CONFIG.get("database", "okx")

    db_url = (
        f"mysql+pymysql://{quote_plus(str(mysql_user))}:{quote_plus(str(mysql_password))}"
        f"@{mysql_host}:{mysql_port}/{mysql_database}"
    )
    db_market_data = DBMarketData(db_url)
    data = db_market_data.query_market_data_by_symbol_bar(
        symbol, bar, start, end
    )

    # Guard clauses replace the original nested else-pyramid
    if data is None:
        logging.warning(
            f"获取行情数据失败: {symbol} {bar} 从 {start} 到 {end}"
        )
        return None
    if len(data) == 0:
        logging.warning(
            f"获取行情数据为空: {symbol} {bar} 从 {start} 到 {end}"
        )
        return None

    # Normalise raw query results to a DataFrame
    if isinstance(data, list):
        data = pd.DataFrame(data)
    elif isinstance(data, dict):
        data = pd.DataFrame([data])
    return data
|
||||||
|
|
||||||
|
def generate_test_data(n=1000):
    """
    Build a deterministic synthetic hourly OHLCV frame.

    The close price is a seeded random walk whose step is a sinusoidal
    trend term plus standard-normal noise. The global NumPy random state is
    reseeded with 42 on every call.

    Args:
        n: Number of bars to generate.

    Returns:
        DataFrame with columns timestamp, close, open, high, low, volume.
    """
    np.random.seed(42)

    price = 100
    prices = []
    for i in range(n):
        trend = np.sin(i * 0.1) * 2  # periodic trend component
        noise = np.random.normal(0, 1)  # random noise
        price += trend + noise
        prices.append(price)

    data = pd.DataFrame({
        # An explicit Timedelta step replaces the 'H' alias, which pandas
        # deprecated in 2.2; the resulting timestamps are identical.
        'timestamp': pd.date_range('2023-01-01', periods=n, freq=pd.Timedelta(hours=1)),
        'close': prices,
        'open': [p * (1 + np.random.normal(0, 0.01)) for p in prices],
        'high': [p * (1 + abs(np.random.normal(0, 0.02))) for p in prices],
        'low': [p * (1 - abs(np.random.normal(0, 0.02))) for p in prices],
        'volume': np.random.randint(1000, 10000, n)
    })

    return data
|
||||||
|
|
||||||
|
def test_ma_methods():
    """
    Apply every moving-average long/short method to one data set and tally the labels.

    Loads BTC-USDT 15m bars through ``get_real_data``. When that returns
    ``None`` (query failed or no rows) the function falls back to
    ``generate_test_data(1000)`` instead of failing on ``len(None)``.
    Moving averages are then computed once, and each method is applied to
    its own copy of the data.

    Returns:
        Tuple ``(results, data)``. ``results`` maps each method name to a
        dict with the 'long', 'short' and 'neutral' counts and the labelled
        'data' frame; ``data`` is the shared frame with moving averages.
    """
    print("=== 均线多空判定方法测试 ===\n")

    # Load real market data; fall back to synthetic data when unavailable
    data = get_real_data("BTC-USDT", "15m", "2025-07-01 00:00:00", "2025-08-07 00:00:00")
    if data is None:
        logging.warning("未获取到真实行情数据,改用模拟数据")
        data = generate_test_data(1000)
    print(f"生成测试数据: {len(data)} 条记录")

    # Indicator calculator
    metrics = MetricsCalculation()

    # Moving averages and their deviation from price
    data = metrics.ma5102030(data)
    data = metrics.calculate_ma_price_percent(data)

    methods = [
        "weighted_voting",
        "trend_strength",
        "ma_alignment",
        "statistical",
        "hybrid"
    ]

    results = {}

    for method in methods:
        print(f"\n--- 测试方法: {method} ---")

        # Each method works on its own copy so they cannot affect each other
        test_data = data.copy()

        test_data = metrics.set_ma_long_short_advanced(test_data, method=method)

        # Tally the labels
        long_count = (test_data['ma_long_short'] == '多').sum()
        short_count = (test_data['ma_long_short'] == '空').sum()
        neutral_count = (test_data['ma_long_short'] == '震荡').sum()

        results[method] = {
            'long': long_count,
            'short': short_count,
            'neutral': neutral_count,
            'data': test_data
        }

        print(f"多头信号: {long_count} ({long_count/len(test_data)*100:.1f}%)")
        print(f"空头信号: {short_count} ({short_count/len(test_data)*100:.1f}%)")
        print(f"震荡信号: {neutral_count} ({neutral_count/len(test_data)*100:.1f}%)")

    return results, data
|
||||||
|
|
||||||
|
def compare_methods(results, original_data):
    """
    Print the signal distribution per method and the pairwise agreement between methods.

    Percentages are reported as 0.0 when a method produced no labels at all
    or when its frame is empty, instead of dividing by zero.

    Args:
        results: Mapping of method name to a dict with 'long', 'short' and
            'neutral' counts plus a 'data' frame holding a 'ma_long_short'
            column.
        original_data: Unused; kept so existing callers keep working.
    """
    print("\n=== 方法比较分析 ===\n")

    def _percent(part, whole):
        """Return part/whole as a percentage, or 0.0 when whole is zero."""
        return part / whole * 100 if whole else 0.0

    # One summary row per method
    comparison_data = []
    for method, result in results.items():
        total = result['long'] + result['short'] + result['neutral']
        comparison_data.append({
            '方法': method,
            '多头信号': f"{result['long']} ({_percent(result['long'], total):.1f}%)",
            '空头信号': f"{result['short']} ({_percent(result['short'], total):.1f}%)",
            '震荡信号': f"{result['neutral']} ({_percent(result['neutral'], total):.1f}%)",
            '信号总数': result['long'] + result['short']
        })

    comparison_df = pd.DataFrame(comparison_data)
    print("信号分布比较:")
    print(comparison_df.to_string(index=False))

    # Share of bars on which two methods give the same label
    print("\n信号一致性分析:")
    methods = list(results.keys())

    for i, method1 in enumerate(methods):
        for method2 in methods[i + 1:]:
            data1 = results[method1]['data']['ma_long_short']
            data2 = results[method2]['data']['ma_long_short']

            agreement = (data1 == data2).sum()
            agreement_rate = _percent(agreement, len(data1))

            print(f"{method1} vs {method2}: {agreement_rate:.1f}% 一致")
|
||||||
|
|
||||||
|
def visualize_results(results, original_data):
    """
    Plot price, moving averages and each method's long/short signals, then save and show.

    The top row shows price with MA5/10/20/30 and the ``ma*_close_diff``
    columns. Each method then gets its own panel, two per row, with long
    ('多') and short ('空') bars marked on the price line. The grid has at
    least four rows, as before, and grows when more than six methods are
    passed; panels without a method are hidden. The figure is saved as
    ``ma_methods_comparison.png`` in the working directory and displayed.

    Args:
        results: Mapping of method name to a dict whose 'data' frame holds
            a 'ma_long_short' column.
        original_data: Frame with close, ma5/10/20/30 and
            ma5/10/20/30_close_diff columns.
    """
    print("\n=== 生成可视化图表 ===")

    # Only the first 1000 bars are plotted
    max_points = 1000
    plot_data = original_data.head(max_points).copy()

    # Attach each method's labels as its own column
    for method, result in results.items():
        plot_data[f'{method}_signal'] = result['data'].head(max_points)['ma_long_short']

    methods = list(results.keys())
    # One header row plus two method panels per row; never fewer than the
    # original four rows, so the usual five-method layout is unchanged.
    n_rows = max(4, 1 + (len(methods) + 1) // 2)
    fig, axes = plt.subplots(n_rows, 2, figsize=(15, 4 * n_rows), squeeze=False)
    fig.suptitle('均线多空判定方法比较', fontsize=16)

    windows = (5, 10, 20, 30)

    # Price and moving averages
    ax1 = axes[0, 0]
    ax1.plot(plot_data.index, plot_data['close'], label='价格', alpha=0.7)
    for window in windows:
        ax1.plot(plot_data.index, plot_data[f'ma{window}'], label=f'MA{window}', alpha=0.8)
    ax1.set_title('价格和均线')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Deviation of each moving average from the close
    ax2 = axes[0, 1]
    for window in windows:
        ax2.plot(plot_data.index, plot_data[f'ma{window}_close_diff'],
                 label=f'MA{window}偏离度', alpha=0.7)
    ax2.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    ax2.set_title('均线偏离度')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # One panel per method, starting on the second row
    for i, method in enumerate(methods):
        print(f"绘制{method}方法信号")
        ax = axes[1 + i // 2, i % 2]

        long_signals = plot_data[plot_data[f'{method}_signal'] == '多'].index
        short_signals = plot_data[plot_data[f'{method}_signal'] == '空'].index

        ax.plot(plot_data.index, plot_data['close'], color='gray', alpha=0.5, label='价格')
        ax.scatter(long_signals, plot_data.loc[long_signals, 'close'],
                   color='red', marker='^', s=50, label='多头信号', alpha=0.8)
        ax.scatter(short_signals, plot_data.loc[short_signals, 'close'],
                   color='blue', marker='v', s=50, label='空头信号', alpha=0.8)

        ax.set_title(f'{method}方法信号')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Hide the panels that received no method
    for spare in axes.flat[2 + len(methods):]:
        spare.set_visible(False)

    plt.tight_layout()
    plt.savefig('ma_methods_comparison.png', dpi=300, bbox_inches='tight')
    print("图表已保存为: ma_methods_comparison.png")
    plt.show()
|
||||||
|
|
||||||
|
def main():
    """Run the method comparison end to end and print usage recommendations."""
    print("开始测试均线多空判定方法...")

    # Label the data with every method
    results, original_data = test_ma_methods()

    # Distribution and agreement report
    compare_methods(results, original_data)

    # Plotting is best effort: a failure is reported, not propagated
    try:
        visualize_results(results, original_data)
    except Exception as e:
        print(f"可视化失败: {e}")

    closing_lines = (
        "\n=== 测试完成 ===",
        "\n建议:",
        "1. 加权投票方法适合大多数市场环境",
        "2. 趋势强度方法适合趋势明显的市场",
        "3. 均线排列方法适合技术分析",
        "4. 统计方法适合波动较大的市场",
        "5. 混合方法综合了多种优势,但计算复杂度较高",
    )
    for line in closing_lines:
        print(line)


if __name__ == "__main__":
    main()
|
||||||
Loading…
Reference in New Issue