""" 均线多空判定方法测试脚本 本脚本用于测试和比较不同均线多空判定方法的有效性。 """ import pandas as pd import numpy as np import matplotlib.pyplot as plt from core.db.db_market_data import DBMarketData from core.biz.metrics_calculation import MetricsCalculation import logging from config import OKX_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE # plt支持中文 plt.rcParams['font.family'] = ['SimHei'] # 设置日志 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') def get_real_data(symbol, bar, start, end): mysql_user = MYSQL_CONFIG.get("user", "xch") mysql_password = MYSQL_CONFIG.get("password", "") if not mysql_password: raise ValueError("MySQL password is not set") mysql_host = MYSQL_CONFIG.get("host", "localhost") mysql_port = MYSQL_CONFIG.get("port", 3306) mysql_database = MYSQL_CONFIG.get("database", "okx") db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" db_market_data = DBMarketData(db_url) data = db_market_data.query_market_data_by_symbol_bar( symbol, bar, start, end ) if data is None: logging.warning( f"获取行情数据失败: {symbol} {bar} 从 {start} 到 {end}" ) return None else: if len(data) == 0: logging.warning( f"获取行情数据为空: {symbol} {bar} 从 {start} 到 {end}" ) return None else: if isinstance(data, list): data = pd.DataFrame(data) elif isinstance(data, dict): data = pd.DataFrame([data]) return data def generate_test_data(n=1000): """生成测试数据""" np.random.seed(42) # 生成价格数据 price = 100 prices = [] for i in range(n): # 添加趋势和随机波动 trend = np.sin(i * 0.1) * 2 # 周期性趋势 noise = np.random.normal(0, 1) # 随机噪声 price += trend + noise prices.append(price) # 创建DataFrame data = pd.DataFrame({ 'timestamp': pd.date_range('2023-01-01', periods=n, freq='H'), 'close': prices, 'open': [p * (1 + np.random.normal(0, 0.01)) for p in prices], 'high': [p * (1 + abs(np.random.normal(0, 0.02))) for p in prices], 'low': [p * (1 - abs(np.random.normal(0, 0.02))) for p in prices], 'volume': np.random.randint(1000, 10000, n) }) return data def test_ma_methods(): """测试不同的均线多空判定方法""" print("=== 均线多空判定方法测试 ===\n") # 生成测试数据 data = get_real_data("BTC-USDT", "15m", "2025-07-01 00:00:00", "2025-08-07 00:00:00") # data = generate_test_data(1000) print(f"生成测试数据: {len(data)} 条记录") # 初始化指标计算器 metrics = MetricsCalculation() # 计算均线 data = metrics.ma5102030(data) data = metrics.calculate_ma_price_percent(data) # 测试不同方法 methods = [ "weighted_voting", "trend_strength", "ma_alignment", "statistical", "hybrid" ] results = {} for method in methods: print(f"\n--- 测试方法: {method} ---") # 复制数据避免相互影响 test_data = data.copy() # 应用方法 test_data = metrics.set_ma_long_short_advanced(test_data, method=method) # 统计结果 long_count = (test_data['ma_long_short'] == '多').sum() short_count = (test_data['ma_long_short'] == '空').sum() neutral_count = (test_data['ma_long_short'] == '震荡').sum() results[method] = { 'long': long_count, 'short': short_count, 'neutral': neutral_count, 'data': test_data } print(f"多头信号: {long_count} ({long_count/len(test_data)*100:.1f}%)") print(f"空头信号: {short_count} ({short_count/len(test_data)*100:.1f}%)") print(f"震荡信号: {neutral_count} ({neutral_count/len(test_data)*100:.1f}%)") return results, data def compare_methods(results, original_data): """比较不同方法的结果""" print("\n=== 方法比较分析 ===\n") # 创建比较表格 comparison_data = [] for method, result in results.items(): total = result['long'] + result['short'] + result['neutral'] comparison_data.append({ '方法': method, '多头信号': f"{result['long']} ({result['long']/total*100:.1f}%)", '空头信号': f"{result['short']} ({result['short']/total*100:.1f}%)", '震荡信号': f"{result['neutral']} ({result['neutral']/total*100:.1f}%)", '信号总数': result['long'] + result['short'] }) comparison_df = pd.DataFrame(comparison_data) print("信号分布比较:") print(comparison_df.to_string(index=False)) # 分析信号一致性 print("\n信号一致性分析:") methods = list(results.keys()) for i in range(len(methods)): for j in range(i+1, len(methods)): method1, method2 = methods[i], methods[j] data1 = results[method1]['data']['ma_long_short'] data2 = results[method2]['data']['ma_long_short'] # 计算一致性 agreement = (data1 == data2).sum() agreement_rate = agreement / len(data1) * 100 print(f"{method1} vs {method2}: {agreement_rate:.1f}% 一致") def visualize_results(results, original_data): """可视化结果""" print("\n=== 生成可视化图表 ===") # 选择前2000个数据点进行可视化 plot_data = original_data.head(1000).copy() # 添加不同方法的结果 for method, result in results.items(): plot_data[f'{method}_signal'] = result['data'].head(1000)['ma_long_short'] # 创建图表 fig, axes = plt.subplots(4, 2, figsize=(15, 16)) fig.suptitle('均线多空判定方法比较', fontsize=16) # 价格和均线图 ax1 = axes[0, 0] ax1.plot(plot_data.index, plot_data['close'], label='价格', alpha=0.7) ax1.plot(plot_data.index, plot_data['ma5'], label='MA5', alpha=0.8) ax1.plot(plot_data.index, plot_data['ma10'], label='MA10', alpha=0.8) ax1.plot(plot_data.index, plot_data['ma20'], label='MA20', alpha=0.8) ax1.plot(plot_data.index, plot_data['ma30'], label='MA30', alpha=0.8) ax1.set_title('价格和均线') ax1.legend() ax1.grid(True, alpha=0.3) # 均线偏离度图 ax2 = axes[0, 1] ax2.plot(plot_data.index, plot_data['ma5_close_diff'], label='MA5偏离度', alpha=0.7) ax2.plot(plot_data.index, plot_data['ma10_close_diff'], label='MA10偏离度', alpha=0.7) ax2.plot(plot_data.index, plot_data['ma20_close_diff'], label='MA20偏离度', alpha=0.7) ax2.plot(plot_data.index, plot_data['ma30_close_diff'], label='MA30偏离度', alpha=0.7) ax2.axhline(y=0, color='black', linestyle='--', alpha=0.5) ax2.set_title('均线偏离度') ax2.legend() ax2.grid(True, alpha=0.3) # 不同方法的信号图 methods = list(results.keys()) colors = ['red', 'blue', 'green', 'orange', 'purple'] for i, method in enumerate(methods): print(f"绘制{method}方法信号") row = 1 + (i // 2) # 从第2行开始,每行2个图 col = i % 2 # 0或1 ax = axes[row, col] # 绘制信号 long_signals = plot_data[plot_data[f'{method}_signal'] == '多'].index short_signals = plot_data[plot_data[f'{method}_signal'] == '空'].index ax.plot(plot_data.index, plot_data['close'], color='gray', alpha=0.5, label='价格') ax.scatter(long_signals, plot_data.loc[long_signals, 'close'], color='red', marker='^', s=50, label='多头信号', alpha=0.8) ax.scatter(short_signals, plot_data.loc[short_signals, 'close'], color='blue', marker='v', s=50, label='空头信号', alpha=0.8) ax.set_title(f'{method}方法信号') ax.legend() ax.grid(True, alpha=0.3) plt.tight_layout() plt.savefig('ma_methods_comparison.png', dpi=300, bbox_inches='tight') print("图表已保存为: ma_methods_comparison.png") plt.show() def main(): """主函数""" print("开始测试均线多空判定方法...") # 测试方法 results, original_data = test_ma_methods() # 比较结果 compare_methods(results, original_data) # 可视化结果 try: visualize_results(results, original_data) except Exception as e: print(f"可视化失败: {e}") print("\n=== 测试完成 ===") print("\n建议:") print("1. 加权投票方法适合大多数市场环境") print("2. 趋势强度方法适合趋势明显的市场") print("3. 均线排列方法适合技术分析") print("4. 统计方法适合波动较大的市场") print("5. 混合方法综合了多种优势,但计算复杂度较高") if __name__ == "__main__": main()