crypto_quant/test_ma_cross_optimization.py

232 lines
8.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
均线交叉检测优化算法测试脚本
测试优化后的ma5102030方法验证多均线同时交叉的检测效果
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from core.biz.metrics_calculation import MetricsCalculation
import logging
import os
# plt支持中文
plt.rcParams['font.family'] = ['SimHei']
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def generate_test_data_with_crosses(n=200):
"""生成包含多个均线交叉的测试数据"""
np.random.seed(42)
# 生成价格数据,包含明显的趋势变化
price = 100
prices = []
for i in range(n):
if i < 50:
# 第一阶段:下跌趋势
price -= 0.5 + np.random.normal(0, 0.3)
elif i < 100:
# 第二阶段:震荡
price += np.random.normal(0, 0.5)
elif i < 150:
# 第三阶段:强势上涨
price += 1.0 + np.random.normal(0, 0.3)
else:
# 第四阶段:回调
price -= 0.3 + np.random.normal(0, 0.4)
prices.append(max(price, 50)) # 确保价格不会太低
# 创建DataFrame
data = pd.DataFrame({
'timestamp': pd.date_range('2023-01-01', periods=n, freq='H'),
'close': prices,
'open': [p * (1 + np.random.normal(0, 0.01)) for p in prices],
'high': [p * (1 + abs(np.random.normal(0, 0.02))) for p in prices],
'low': [p * (1 - abs(np.random.normal(0, 0.02))) for p in prices],
'volume': np.random.randint(1000, 10000, n)
})
return data
def test_ma_cross_optimization():
"""测试优化后的均线交叉检测"""
print("=== 均线交叉检测优化测试 ===\n")
# 生成测试数据
data = generate_test_data_with_crosses(200)
print(f"生成测试数据: {len(data)} 条记录")
# 初始化指标计算器
metrics = MetricsCalculation()
# 计算均线
data = metrics.ma5102030(data)
data = metrics.calculate_ma_price_percent(data)
# 分析交叉信号
cross_signals = data[data['ma_cross'] != '']
print(f"\n检测到 {len(cross_signals)} 个交叉信号")
if len(cross_signals) > 0:
print("\n交叉信号详情:")
for idx, row in cross_signals.iterrows():
print(f"时间: {row['timestamp']}, 信号: {row['ma_cross']}")
# 统计不同类型的交叉
cross_types = {}
for signal in data['ma_cross'].unique():
if signal != '':
count = (data['ma_cross'] == signal).sum()
cross_types[signal] = count
print(f"\n交叉类型统计:")
for cross_type, count in sorted(cross_types.items()):
print(f"{cross_type}: {count}")
return data
def visualize_ma_crosses(data):
"""可视化均线交叉信号"""
print("\n=== 生成均线交叉可视化图表 ===")
# 选择数据
plot_data = data.copy()
# 创建图表
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))
fig.suptitle('均线交叉检测优化效果', fontsize=16)
# 价格和均线图
ax1.plot(plot_data.index, plot_data['close'], label='价格', alpha=0.7, linewidth=1)
ax1.plot(plot_data.index, plot_data['ma5'], label='MA5', alpha=0.8, linewidth=1.5)
ax1.plot(plot_data.index, plot_data['ma10'], label='MA10', alpha=0.8, linewidth=1.5)
ax1.plot(plot_data.index, plot_data['ma20'], label='MA20', alpha=0.8, linewidth=1.5)
ax1.plot(plot_data.index, plot_data['ma30'], label='MA30', alpha=0.8, linewidth=1.5)
# 标记交叉点
cross_points = plot_data[plot_data['ma_cross'] != '']
for idx, row in cross_points.iterrows():
ax1.scatter(idx, row['close'], color='red', s=100, alpha=0.8, zorder=5)
ax1.annotate(row['ma_cross'], (idx, row['close']),
xytext=(10, 10), textcoords='offset points',
fontsize=8, bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7))
ax1.set_title('价格、均线和交叉信号')
ax1.legend()
ax1.grid(True, alpha=0.3)
# 均线偏离度图
ax2.plot(plot_data.index, plot_data['ma5_close_diff'], label='MA5偏离度', alpha=0.7)
ax2.plot(plot_data.index, plot_data['ma10_close_diff'], label='MA10偏离度', alpha=0.7)
ax2.plot(plot_data.index, plot_data['ma20_close_diff'], label='MA20偏离度', alpha=0.7)
ax2.plot(plot_data.index, plot_data['ma30_close_diff'], label='MA30偏离度', alpha=0.7)
ax2.axhline(y=0, color='black', linestyle='--', alpha=0.5)
# 标记交叉点
for idx, row in cross_points.iterrows():
ax2.scatter(idx, 0, color='red', s=100, alpha=0.8, zorder=5)
ax2.set_title('均线偏离度和交叉信号')
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
folder = "./output/algorithm/"
os.makedirs(folder, exist_ok=True)
file_name = f"{folder}/ma_cross_optimization.png"
plt.savefig(file_name, dpi=300, bbox_inches='tight')
print("图表已保存为: ma_cross_optimization.png")
plt.show()
def analyze_cross_combinations(data):
"""分析交叉组合的效果"""
print("\n=== 交叉组合分析 ===")
# 获取所有交叉信号
cross_data = data[data['ma_cross'] != ''].copy()
if len(cross_data) == 0:
print("未检测到交叉信号")
return
# 分析组合信号
combination_signals = cross_data[cross_data['ma_cross'].str.contains('')]
single_signals = cross_data[~cross_data['ma_cross'].str.contains('')]
print(f"组合交叉信号: {len(combination_signals)}")
print(f"单一交叉信号: {len(single_signals)}")
if len(combination_signals) > 0:
print("\n组合交叉信号详情:")
for idx, row in combination_signals.iterrows():
print(f"时间: {row['timestamp']}, 组合信号: {row['ma_cross']}")
# 分析上穿和下穿信号
up_cross_signals = cross_data[cross_data['ma_cross'].str.contains('上穿')]
down_cross_signals = cross_data[cross_data['ma_cross'].str.contains('下穿')]
print(f"\n上穿信号: {len(up_cross_signals)}")
print(f"下穿信号: {len(down_cross_signals)}")
# 统计各种交叉类型
print(f"\n详细交叉类型统计:")
cross_type_counts = {}
for signal in cross_data['ma_cross'].unique():
if signal != '':
count = (cross_data['ma_cross'] == signal).sum()
cross_type_counts[signal] = count
# 按类型分组显示
up_cross_types = {k: v for k, v in cross_type_counts.items() if '上穿' in k}
down_cross_types = {k: v for k, v in cross_type_counts.items() if '下穿' in k}
print(f"\n上穿信号类型:")
for cross_type, count in sorted(up_cross_types.items()):
print(f" {cross_type}: {count}")
print(f"\n下穿信号类型:")
for cross_type, count in sorted(down_cross_types.items()):
print(f" {cross_type}: {count}")
# 分析信号强度
print(f"\n信号强度分析:")
print(f"总交叉信号: {len(cross_data)}")
print(f"组合信号占比: {len(combination_signals)/len(cross_data)*100:.1f}%")
print(f"单一信号占比: {len(single_signals)/len(cross_data)*100:.1f}%")
print(f"上穿信号占比: {len(up_cross_signals)/len(cross_data)*100:.1f}%")
print(f"下穿信号占比: {len(down_cross_signals)/len(cross_data)*100:.1f}%")
def main():
"""主函数"""
print("开始测试均线交叉检测优化...")
# 测试优化算法
data = test_ma_cross_optimization()
# 分析交叉组合
analyze_cross_combinations(data)
# 可视化结果
try:
visualize_ma_crosses(data)
except Exception as e:
print(f"可视化失败: {e}")
print("\n=== 测试完成 ===")
print("\n优化效果:")
print("1. 能够检测多个均线同时交叉的情况")
print("2. 更好地识别趋势转变的关键时刻")
print("3. 提供更丰富的技术分析信息")
print("4. 减少信号噪音,提高信号质量")
print("5. 支持完整的均线交叉类型5上穿10/20/3010上穿20/3020上穿30")
print("6. 支持对应的下穿信号10下穿520下穿10/530下穿20/10/5")
print("7. 使用更清晰的上穿/下穿命名规范")
if __name__ == "__main__":
main()