crypto_quant/test_ma_methods.py

259 lines
9.1 KiB
Python
Raw Normal View History

"""
均线多空判定方法测试脚本
本脚本用于测试和比较不同均线多空判定方法的有效性
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from core.db.db_market_data import DBMarketData
from core.biz.metrics_calculation import MetricsCalculation
import logging
2025-08-31 03:20:59 +00:00
from config import OKX_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE
# plt支持中文
plt.rcParams['font.family'] = ['SimHei']
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def get_real_data(symbol, bar, start, end):
mysql_user = MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "")
if not mysql_password:
raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx")
db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
db_market_data = DBMarketData(db_url)
data = db_market_data.query_market_data_by_symbol_bar(
symbol, bar, start, end
)
if data is None:
logging.warning(
f"获取行情数据失败: {symbol} {bar}{start}{end}"
)
return None
else:
if len(data) == 0:
logging.warning(
f"获取行情数据为空: {symbol} {bar}{start}{end}"
)
return None
else:
if isinstance(data, list):
data = pd.DataFrame(data)
elif isinstance(data, dict):
data = pd.DataFrame([data])
return data
def generate_test_data(n=1000):
"""生成测试数据"""
np.random.seed(42)
# 生成价格数据
price = 100
prices = []
for i in range(n):
# 添加趋势和随机波动
trend = np.sin(i * 0.1) * 2 # 周期性趋势
noise = np.random.normal(0, 1) # 随机噪声
price += trend + noise
prices.append(price)
# 创建DataFrame
data = pd.DataFrame({
'timestamp': pd.date_range('2023-01-01', periods=n, freq='H'),
'close': prices,
'open': [p * (1 + np.random.normal(0, 0.01)) for p in prices],
'high': [p * (1 + abs(np.random.normal(0, 0.02))) for p in prices],
'low': [p * (1 - abs(np.random.normal(0, 0.02))) for p in prices],
'volume': np.random.randint(1000, 10000, n)
})
return data
def test_ma_methods():
"""测试不同的均线多空判定方法"""
print("=== 均线多空判定方法测试 ===\n")
# 生成测试数据
data = get_real_data("BTC-USDT", "15m", "2025-07-01 00:00:00", "2025-08-07 00:00:00")
# data = generate_test_data(1000)
print(f"生成测试数据: {len(data)} 条记录")
# 初始化指标计算器
metrics = MetricsCalculation()
# 计算均线
data = metrics.ma5102030(data)
data = metrics.calculate_ma_price_percent(data)
# 测试不同方法
methods = [
"weighted_voting",
"trend_strength",
"ma_alignment",
"statistical",
"hybrid"
]
results = {}
for method in methods:
print(f"\n--- 测试方法: {method} ---")
# 复制数据避免相互影响
test_data = data.copy()
# 应用方法
test_data = metrics.set_ma_long_short_advanced(test_data, method=method)
# 统计结果
long_count = (test_data['ma_long_short'] == '').sum()
short_count = (test_data['ma_long_short'] == '').sum()
neutral_count = (test_data['ma_long_short'] == '震荡').sum()
results[method] = {
'long': long_count,
'short': short_count,
'neutral': neutral_count,
'data': test_data
}
print(f"多头信号: {long_count} ({long_count/len(test_data)*100:.1f}%)")
print(f"空头信号: {short_count} ({short_count/len(test_data)*100:.1f}%)")
print(f"震荡信号: {neutral_count} ({neutral_count/len(test_data)*100:.1f}%)")
return results, data
def compare_methods(results, original_data):
"""比较不同方法的结果"""
print("\n=== 方法比较分析 ===\n")
# 创建比较表格
comparison_data = []
for method, result in results.items():
total = result['long'] + result['short'] + result['neutral']
comparison_data.append({
'方法': method,
'多头信号': f"{result['long']} ({result['long']/total*100:.1f}%)",
'空头信号': f"{result['short']} ({result['short']/total*100:.1f}%)",
'震荡信号': f"{result['neutral']} ({result['neutral']/total*100:.1f}%)",
'信号总数': result['long'] + result['short']
})
comparison_df = pd.DataFrame(comparison_data)
print("信号分布比较:")
print(comparison_df.to_string(index=False))
# 分析信号一致性
print("\n信号一致性分析:")
methods = list(results.keys())
for i in range(len(methods)):
for j in range(i+1, len(methods)):
method1, method2 = methods[i], methods[j]
data1 = results[method1]['data']['ma_long_short']
data2 = results[method2]['data']['ma_long_short']
# 计算一致性
agreement = (data1 == data2).sum()
agreement_rate = agreement / len(data1) * 100
print(f"{method1} vs {method2}: {agreement_rate:.1f}% 一致")
def visualize_results(results, original_data):
"""可视化结果"""
print("\n=== 生成可视化图表 ===")
# 选择前2000个数据点进行可视化
plot_data = original_data.head(1000).copy()
# 添加不同方法的结果
for method, result in results.items():
plot_data[f'{method}_signal'] = result['data'].head(1000)['ma_long_short']
# 创建图表
fig, axes = plt.subplots(4, 2, figsize=(15, 16))
fig.suptitle('均线多空判定方法比较', fontsize=16)
# 价格和均线图
ax1 = axes[0, 0]
ax1.plot(plot_data.index, plot_data['close'], label='价格', alpha=0.7)
ax1.plot(plot_data.index, plot_data['ma5'], label='MA5', alpha=0.8)
ax1.plot(plot_data.index, plot_data['ma10'], label='MA10', alpha=0.8)
ax1.plot(plot_data.index, plot_data['ma20'], label='MA20', alpha=0.8)
ax1.plot(plot_data.index, plot_data['ma30'], label='MA30', alpha=0.8)
ax1.set_title('价格和均线')
ax1.legend()
ax1.grid(True, alpha=0.3)
# 均线偏离度图
ax2 = axes[0, 1]
ax2.plot(plot_data.index, plot_data['ma5_close_diff'], label='MA5偏离度', alpha=0.7)
ax2.plot(plot_data.index, plot_data['ma10_close_diff'], label='MA10偏离度', alpha=0.7)
ax2.plot(plot_data.index, plot_data['ma20_close_diff'], label='MA20偏离度', alpha=0.7)
ax2.plot(plot_data.index, plot_data['ma30_close_diff'], label='MA30偏离度', alpha=0.7)
ax2.axhline(y=0, color='black', linestyle='--', alpha=0.5)
ax2.set_title('均线偏离度')
ax2.legend()
ax2.grid(True, alpha=0.3)
# 不同方法的信号图
methods = list(results.keys())
colors = ['red', 'blue', 'green', 'orange', 'purple']
for i, method in enumerate(methods):
print(f"绘制{method}方法信号")
row = 1 + (i // 2) # 从第2行开始每行2个图
col = i % 2 # 0或1
ax = axes[row, col]
# 绘制信号
long_signals = plot_data[plot_data[f'{method}_signal'] == ''].index
short_signals = plot_data[plot_data[f'{method}_signal'] == ''].index
ax.plot(plot_data.index, plot_data['close'], color='gray', alpha=0.5, label='价格')
ax.scatter(long_signals, plot_data.loc[long_signals, 'close'],
color='red', marker='^', s=50, label='多头信号', alpha=0.8)
ax.scatter(short_signals, plot_data.loc[short_signals, 'close'],
color='blue', marker='v', s=50, label='空头信号', alpha=0.8)
ax.set_title(f'{method}方法信号')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('ma_methods_comparison.png', dpi=300, bbox_inches='tight')
print("图表已保存为: ma_methods_comparison.png")
plt.show()
def main():
"""主函数"""
print("开始测试均线多空判定方法...")
# 测试方法
results, original_data = test_ma_methods()
# 比较结果
compare_methods(results, original_data)
# 可视化结果
try:
visualize_results(results, original_data)
except Exception as e:
print(f"可视化失败: {e}")
print("\n=== 测试完成 ===")
print("\n建议:")
print("1. 加权投票方法适合大多数市场环境")
print("2. 趋势强度方法适合趋势明显的市场")
print("3. 均线排列方法适合技术分析")
print("4. 统计方法适合波动较大的市场")
print("5. 混合方法综合了多种优势,但计算复杂度较高")
if __name__ == "__main__":
main()