crypto_quant/test_ma_methods.py

259 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
均线多空判定方法测试脚本
本脚本用于测试和比较不同均线多空判定方法的有效性。
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from core.db.db_market_data import DBMarketData
from core.biz.metrics_calculation import MetricsCalculation
import logging
from config import OKX_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE
# plt支持中文
plt.rcParams['font.family'] = ['SimHei']
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def get_real_data(symbol, bar, start, end):
mysql_user = MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "")
if not mysql_password:
raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx")
db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
db_market_data = DBMarketData(db_url)
data = db_market_data.query_market_data_by_symbol_bar(
symbol, bar, start, end
)
if data is None:
logging.warning(
f"获取行情数据失败: {symbol} {bar}{start}{end}"
)
return None
else:
if len(data) == 0:
logging.warning(
f"获取行情数据为空: {symbol} {bar}{start}{end}"
)
return None
else:
if isinstance(data, list):
data = pd.DataFrame(data)
elif isinstance(data, dict):
data = pd.DataFrame([data])
return data
def generate_test_data(n=1000):
"""生成测试数据"""
np.random.seed(42)
# 生成价格数据
price = 100
prices = []
for i in range(n):
# 添加趋势和随机波动
trend = np.sin(i * 0.1) * 2 # 周期性趋势
noise = np.random.normal(0, 1) # 随机噪声
price += trend + noise
prices.append(price)
# 创建DataFrame
data = pd.DataFrame({
'timestamp': pd.date_range('2023-01-01', periods=n, freq='H'),
'close': prices,
'open': [p * (1 + np.random.normal(0, 0.01)) for p in prices],
'high': [p * (1 + abs(np.random.normal(0, 0.02))) for p in prices],
'low': [p * (1 - abs(np.random.normal(0, 0.02))) for p in prices],
'volume': np.random.randint(1000, 10000, n)
})
return data
def test_ma_methods():
"""测试不同的均线多空判定方法"""
print("=== 均线多空判定方法测试 ===\n")
# 生成测试数据
data = get_real_data("BTC-USDT", "15m", "2025-07-01 00:00:00", "2025-08-07 00:00:00")
# data = generate_test_data(1000)
print(f"生成测试数据: {len(data)} 条记录")
# 初始化指标计算器
metrics = MetricsCalculation()
# 计算均线
data = metrics.ma5102030(data)
data = metrics.calculate_ma_price_percent(data)
# 测试不同方法
methods = [
"weighted_voting",
"trend_strength",
"ma_alignment",
"statistical",
"hybrid"
]
results = {}
for method in methods:
print(f"\n--- 测试方法: {method} ---")
# 复制数据避免相互影响
test_data = data.copy()
# 应用方法
test_data = metrics.set_ma_long_short_advanced(test_data, method=method)
# 统计结果
long_count = (test_data['ma_long_short'] == '').sum()
short_count = (test_data['ma_long_short'] == '').sum()
neutral_count = (test_data['ma_long_short'] == '震荡').sum()
results[method] = {
'long': long_count,
'short': short_count,
'neutral': neutral_count,
'data': test_data
}
print(f"多头信号: {long_count} ({long_count/len(test_data)*100:.1f}%)")
print(f"空头信号: {short_count} ({short_count/len(test_data)*100:.1f}%)")
print(f"震荡信号: {neutral_count} ({neutral_count/len(test_data)*100:.1f}%)")
return results, data
def compare_methods(results, original_data):
"""比较不同方法的结果"""
print("\n=== 方法比较分析 ===\n")
# 创建比较表格
comparison_data = []
for method, result in results.items():
total = result['long'] + result['short'] + result['neutral']
comparison_data.append({
'方法': method,
'多头信号': f"{result['long']} ({result['long']/total*100:.1f}%)",
'空头信号': f"{result['short']} ({result['short']/total*100:.1f}%)",
'震荡信号': f"{result['neutral']} ({result['neutral']/total*100:.1f}%)",
'信号总数': result['long'] + result['short']
})
comparison_df = pd.DataFrame(comparison_data)
print("信号分布比较:")
print(comparison_df.to_string(index=False))
# 分析信号一致性
print("\n信号一致性分析:")
methods = list(results.keys())
for i in range(len(methods)):
for j in range(i+1, len(methods)):
method1, method2 = methods[i], methods[j]
data1 = results[method1]['data']['ma_long_short']
data2 = results[method2]['data']['ma_long_short']
# 计算一致性
agreement = (data1 == data2).sum()
agreement_rate = agreement / len(data1) * 100
print(f"{method1} vs {method2}: {agreement_rate:.1f}% 一致")
def visualize_results(results, original_data):
"""可视化结果"""
print("\n=== 生成可视化图表 ===")
# 选择前2000个数据点进行可视化
plot_data = original_data.head(1000).copy()
# 添加不同方法的结果
for method, result in results.items():
plot_data[f'{method}_signal'] = result['data'].head(1000)['ma_long_short']
# 创建图表
fig, axes = plt.subplots(4, 2, figsize=(15, 16))
fig.suptitle('均线多空判定方法比较', fontsize=16)
# 价格和均线图
ax1 = axes[0, 0]
ax1.plot(plot_data.index, plot_data['close'], label='价格', alpha=0.7)
ax1.plot(plot_data.index, plot_data['ma5'], label='MA5', alpha=0.8)
ax1.plot(plot_data.index, plot_data['ma10'], label='MA10', alpha=0.8)
ax1.plot(plot_data.index, plot_data['ma20'], label='MA20', alpha=0.8)
ax1.plot(plot_data.index, plot_data['ma30'], label='MA30', alpha=0.8)
ax1.set_title('价格和均线')
ax1.legend()
ax1.grid(True, alpha=0.3)
# 均线偏离度图
ax2 = axes[0, 1]
ax2.plot(plot_data.index, plot_data['ma5_close_diff'], label='MA5偏离度', alpha=0.7)
ax2.plot(plot_data.index, plot_data['ma10_close_diff'], label='MA10偏离度', alpha=0.7)
ax2.plot(plot_data.index, plot_data['ma20_close_diff'], label='MA20偏离度', alpha=0.7)
ax2.plot(plot_data.index, plot_data['ma30_close_diff'], label='MA30偏离度', alpha=0.7)
ax2.axhline(y=0, color='black', linestyle='--', alpha=0.5)
ax2.set_title('均线偏离度')
ax2.legend()
ax2.grid(True, alpha=0.3)
# 不同方法的信号图
methods = list(results.keys())
colors = ['red', 'blue', 'green', 'orange', 'purple']
for i, method in enumerate(methods):
print(f"绘制{method}方法信号")
row = 1 + (i // 2) # 从第2行开始每行2个图
col = i % 2 # 0或1
ax = axes[row, col]
# 绘制信号
long_signals = plot_data[plot_data[f'{method}_signal'] == ''].index
short_signals = plot_data[plot_data[f'{method}_signal'] == ''].index
ax.plot(plot_data.index, plot_data['close'], color='gray', alpha=0.5, label='价格')
ax.scatter(long_signals, plot_data.loc[long_signals, 'close'],
color='red', marker='^', s=50, label='多头信号', alpha=0.8)
ax.scatter(short_signals, plot_data.loc[short_signals, 'close'],
color='blue', marker='v', s=50, label='空头信号', alpha=0.8)
ax.set_title(f'{method}方法信号')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('ma_methods_comparison.png', dpi=300, bbox_inches='tight')
print("图表已保存为: ma_methods_comparison.png")
plt.show()
def main():
"""主函数"""
print("开始测试均线多空判定方法...")
# 测试方法
results, original_data = test_ma_methods()
# 比较结果
compare_methods(results, original_data)
# 可视化结果
try:
visualize_results(results, original_data)
except Exception as e:
print(f"可视化失败: {e}")
print("\n=== 测试完成 ===")
print("\n建议:")
print("1. 加权投票方法适合大多数市场环境")
print("2. 趋势强度方法适合趋势明显的市场")
print("3. 均线排列方法适合技术分析")
print("4. 统计方法适合波动较大的市场")
print("5. 混合方法综合了多种优势,但计算复杂度较高")
if __name__ == "__main__":
main()