""" 均线交叉检测优化算法测试脚本 测试优化后的ma5102030方法,验证多均线同时交叉的检测效果 """ import pandas as pd import numpy as np import matplotlib.pyplot as plt from core.biz.metrics_calculation import MetricsCalculation import logging import os # plt支持中文 plt.rcParams['font.family'] = ['SimHei'] # 设置日志 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') def generate_test_data_with_crosses(n=200): """生成包含多个均线交叉的测试数据""" np.random.seed(42) # 生成价格数据,包含明显的趋势变化 price = 100 prices = [] for i in range(n): if i < 50: # 第一阶段:下跌趋势 price -= 0.5 + np.random.normal(0, 0.3) elif i < 100: # 第二阶段:震荡 price += np.random.normal(0, 0.5) elif i < 150: # 第三阶段:强势上涨 price += 1.0 + np.random.normal(0, 0.3) else: # 第四阶段:回调 price -= 0.3 + np.random.normal(0, 0.4) prices.append(max(price, 50)) # 确保价格不会太低 # 创建DataFrame data = pd.DataFrame({ 'timestamp': pd.date_range('2023-01-01', periods=n, freq='H'), 'close': prices, 'open': [p * (1 + np.random.normal(0, 0.01)) for p in prices], 'high': [p * (1 + abs(np.random.normal(0, 0.02))) for p in prices], 'low': [p * (1 - abs(np.random.normal(0, 0.02))) for p in prices], 'volume': np.random.randint(1000, 10000, n) }) return data def test_ma_cross_optimization(): """测试优化后的均线交叉检测""" print("=== 均线交叉检测优化测试 ===\n") # 生成测试数据 data = generate_test_data_with_crosses(200) print(f"生成测试数据: {len(data)} 条记录") # 初始化指标计算器 metrics = MetricsCalculation() # 计算均线 data = metrics.ma5102030(data) data = metrics.calculate_ma_price_percent(data) # 分析交叉信号 cross_signals = data[data['ma_cross'] != ''] print(f"\n检测到 {len(cross_signals)} 个交叉信号") if len(cross_signals) > 0: print("\n交叉信号详情:") for idx, row in cross_signals.iterrows(): print(f"时间: {row['timestamp']}, 信号: {row['ma_cross']}") # 统计不同类型的交叉 cross_types = {} for signal in data['ma_cross'].unique(): if signal != '': count = (data['ma_cross'] == signal).sum() cross_types[signal] = count print(f"\n交叉类型统计:") for cross_type, count in sorted(cross_types.items()): print(f"{cross_type}: {count} 次") return data def visualize_ma_crosses(data): """可视化均线交叉信号""" print("\n=== 生成均线交叉可视化图表 ===") # 选择数据 plot_data = data.copy() # 创建图表 fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10)) fig.suptitle('均线交叉检测优化效果', fontsize=16) # 价格和均线图 ax1.plot(plot_data.index, plot_data['close'], label='价格', alpha=0.7, linewidth=1) ax1.plot(plot_data.index, plot_data['ma5'], label='MA5', alpha=0.8, linewidth=1.5) ax1.plot(plot_data.index, plot_data['ma10'], label='MA10', alpha=0.8, linewidth=1.5) ax1.plot(plot_data.index, plot_data['ma20'], label='MA20', alpha=0.8, linewidth=1.5) ax1.plot(plot_data.index, plot_data['ma30'], label='MA30', alpha=0.8, linewidth=1.5) # 标记交叉点 cross_points = plot_data[plot_data['ma_cross'] != ''] for idx, row in cross_points.iterrows(): ax1.scatter(idx, row['close'], color='red', s=100, alpha=0.8, zorder=5) ax1.annotate(row['ma_cross'], (idx, row['close']), xytext=(10, 10), textcoords='offset points', fontsize=8, bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7)) ax1.set_title('价格、均线和交叉信号') ax1.legend() ax1.grid(True, alpha=0.3) # 均线偏离度图 ax2.plot(plot_data.index, plot_data['ma5_close_diff'], label='MA5偏离度', alpha=0.7) ax2.plot(plot_data.index, plot_data['ma10_close_diff'], label='MA10偏离度', alpha=0.7) ax2.plot(plot_data.index, plot_data['ma20_close_diff'], label='MA20偏离度', alpha=0.7) ax2.plot(plot_data.index, plot_data['ma30_close_diff'], label='MA30偏离度', alpha=0.7) ax2.axhline(y=0, color='black', linestyle='--', alpha=0.5) # 标记交叉点 for idx, row in cross_points.iterrows(): ax2.scatter(idx, 0, color='red', s=100, alpha=0.8, zorder=5) ax2.set_title('均线偏离度和交叉信号') ax2.legend() ax2.grid(True, alpha=0.3) plt.tight_layout() folder = "./output/algorithm/" os.makedirs(folder, exist_ok=True) file_name = f"{folder}/ma_cross_optimization.png" plt.savefig(file_name, dpi=300, bbox_inches='tight') print("图表已保存为: ma_cross_optimization.png") plt.show() def analyze_cross_combinations(data): """分析交叉组合的效果""" print("\n=== 交叉组合分析 ===") # 获取所有交叉信号 cross_data = data[data['ma_cross'] != ''].copy() if len(cross_data) == 0: print("未检测到交叉信号") return # 分析组合信号 combination_signals = cross_data[cross_data['ma_cross'].str.contains(',')] single_signals = cross_data[~cross_data['ma_cross'].str.contains(',')] print(f"组合交叉信号: {len(combination_signals)} 个") print(f"单一交叉信号: {len(single_signals)} 个") if len(combination_signals) > 0: print("\n组合交叉信号详情:") for idx, row in combination_signals.iterrows(): print(f"时间: {row['timestamp']}, 组合信号: {row['ma_cross']}") # 分析上穿和下穿信号 up_cross_signals = cross_data[cross_data['ma_cross'].str.contains('上穿')] down_cross_signals = cross_data[cross_data['ma_cross'].str.contains('下穿')] print(f"\n上穿信号: {len(up_cross_signals)} 个") print(f"下穿信号: {len(down_cross_signals)} 个") # 统计各种交叉类型 print(f"\n详细交叉类型统计:") cross_type_counts = {} for signal in cross_data['ma_cross'].unique(): if signal != '': count = (cross_data['ma_cross'] == signal).sum() cross_type_counts[signal] = count # 按类型分组显示 up_cross_types = {k: v for k, v in cross_type_counts.items() if '上穿' in k} down_cross_types = {k: v for k, v in cross_type_counts.items() if '下穿' in k} print(f"\n上穿信号类型:") for cross_type, count in sorted(up_cross_types.items()): print(f" {cross_type}: {count} 次") print(f"\n下穿信号类型:") for cross_type, count in sorted(down_cross_types.items()): print(f" {cross_type}: {count} 次") # 分析信号强度 print(f"\n信号强度分析:") print(f"总交叉信号: {len(cross_data)}") print(f"组合信号占比: {len(combination_signals)/len(cross_data)*100:.1f}%") print(f"单一信号占比: {len(single_signals)/len(cross_data)*100:.1f}%") print(f"上穿信号占比: {len(up_cross_signals)/len(cross_data)*100:.1f}%") print(f"下穿信号占比: {len(down_cross_signals)/len(cross_data)*100:.1f}%") def main(): """主函数""" print("开始测试均线交叉检测优化...") # 测试优化算法 data = test_ma_cross_optimization() # 分析交叉组合 analyze_cross_combinations(data) # 可视化结果 try: visualize_ma_crosses(data) except Exception as e: print(f"可视化失败: {e}") print("\n=== 测试完成 ===") print("\n优化效果:") print("1. 能够检测多个均线同时交叉的情况") print("2. 更好地识别趋势转变的关键时刻") print("3. 提供更丰富的技术分析信息") print("4. 减少信号噪音,提高信号质量") print("5. 支持完整的均线交叉类型:5上穿10/20/30,10上穿20/30,20上穿30") print("6. 支持对应的下穿信号:10下穿5,20下穿10/5,30下穿20/10/5") print("7. 使用更清晰的上穿/下穿命名规范") if __name__ == "__main__": main()