2025-08-06 06:36:22 +00:00
|
|
|
|
"""
|
|
|
|
|
|
均线多空判定方法测试脚本
|
|
|
|
|
|
|
|
|
|
|
|
本脚本用于测试和比较不同均线多空判定方法的有效性。
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import pandas as pd
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
from core.db.db_market_data import DBMarketData
|
|
|
|
|
|
from core.biz.metrics_calculation import MetricsCalculation
|
|
|
|
|
|
import logging
|
2025-08-31 03:20:59 +00:00
|
|
|
|
from config import OKX_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE
|
2025-08-06 06:36:22 +00:00
|
|
|
|
# plt支持中文
|
|
|
|
|
|
plt.rcParams['font.family'] = ['SimHei']
|
|
|
|
|
|
|
|
|
|
|
|
# 设置日志
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
|
|
|
|
|
|
|
|
def get_real_data(symbol, bar, start, end):
|
|
|
|
|
|
mysql_user = MYSQL_CONFIG.get("user", "xch")
|
|
|
|
|
|
mysql_password = MYSQL_CONFIG.get("password", "")
|
|
|
|
|
|
if not mysql_password:
|
|
|
|
|
|
raise ValueError("MySQL password is not set")
|
|
|
|
|
|
mysql_host = MYSQL_CONFIG.get("host", "localhost")
|
|
|
|
|
|
mysql_port = MYSQL_CONFIG.get("port", 3306)
|
|
|
|
|
|
mysql_database = MYSQL_CONFIG.get("database", "okx")
|
|
|
|
|
|
|
|
|
|
|
|
db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
|
|
|
|
|
|
db_market_data = DBMarketData(db_url)
|
|
|
|
|
|
data = db_market_data.query_market_data_by_symbol_bar(
|
|
|
|
|
|
symbol, bar, start, end
|
|
|
|
|
|
)
|
|
|
|
|
|
if data is None:
|
|
|
|
|
|
logging.warning(
|
|
|
|
|
|
f"获取行情数据失败: {symbol} {bar} 从 {start} 到 {end}"
|
|
|
|
|
|
)
|
|
|
|
|
|
return None
|
|
|
|
|
|
else:
|
|
|
|
|
|
if len(data) == 0:
|
|
|
|
|
|
logging.warning(
|
|
|
|
|
|
f"获取行情数据为空: {symbol} {bar} 从 {start} 到 {end}"
|
|
|
|
|
|
)
|
|
|
|
|
|
return None
|
|
|
|
|
|
else:
|
|
|
|
|
|
if isinstance(data, list):
|
|
|
|
|
|
data = pd.DataFrame(data)
|
|
|
|
|
|
elif isinstance(data, dict):
|
|
|
|
|
|
data = pd.DataFrame([data])
|
|
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
|
|
|
def generate_test_data(n=1000):
|
|
|
|
|
|
"""生成测试数据"""
|
|
|
|
|
|
np.random.seed(42)
|
|
|
|
|
|
|
|
|
|
|
|
# 生成价格数据
|
|
|
|
|
|
price = 100
|
|
|
|
|
|
prices = []
|
|
|
|
|
|
for i in range(n):
|
|
|
|
|
|
# 添加趋势和随机波动
|
|
|
|
|
|
trend = np.sin(i * 0.1) * 2 # 周期性趋势
|
|
|
|
|
|
noise = np.random.normal(0, 1) # 随机噪声
|
|
|
|
|
|
price += trend + noise
|
|
|
|
|
|
prices.append(price)
|
|
|
|
|
|
|
|
|
|
|
|
# 创建DataFrame
|
|
|
|
|
|
data = pd.DataFrame({
|
|
|
|
|
|
'timestamp': pd.date_range('2023-01-01', periods=n, freq='H'),
|
|
|
|
|
|
'close': prices,
|
|
|
|
|
|
'open': [p * (1 + np.random.normal(0, 0.01)) for p in prices],
|
|
|
|
|
|
'high': [p * (1 + abs(np.random.normal(0, 0.02))) for p in prices],
|
|
|
|
|
|
'low': [p * (1 - abs(np.random.normal(0, 0.02))) for p in prices],
|
|
|
|
|
|
'volume': np.random.randint(1000, 10000, n)
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
return data
|
|
|
|
|
|
|
|
|
|
|
|
def test_ma_methods():
|
|
|
|
|
|
"""测试不同的均线多空判定方法"""
|
|
|
|
|
|
print("=== 均线多空判定方法测试 ===\n")
|
|
|
|
|
|
|
|
|
|
|
|
# 生成测试数据
|
|
|
|
|
|
data = get_real_data("BTC-USDT", "15m", "2025-07-01 00:00:00", "2025-08-07 00:00:00")
|
|
|
|
|
|
# data = generate_test_data(1000)
|
|
|
|
|
|
print(f"生成测试数据: {len(data)} 条记录")
|
|
|
|
|
|
|
|
|
|
|
|
# 初始化指标计算器
|
|
|
|
|
|
metrics = MetricsCalculation()
|
|
|
|
|
|
|
|
|
|
|
|
# 计算均线
|
|
|
|
|
|
data = metrics.ma5102030(data)
|
|
|
|
|
|
data = metrics.calculate_ma_price_percent(data)
|
|
|
|
|
|
|
|
|
|
|
|
# 测试不同方法
|
|
|
|
|
|
methods = [
|
|
|
|
|
|
"weighted_voting",
|
|
|
|
|
|
"trend_strength",
|
|
|
|
|
|
"ma_alignment",
|
|
|
|
|
|
"statistical",
|
|
|
|
|
|
"hybrid"
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
results = {}
|
|
|
|
|
|
|
|
|
|
|
|
for method in methods:
|
|
|
|
|
|
print(f"\n--- 测试方法: {method} ---")
|
|
|
|
|
|
|
|
|
|
|
|
# 复制数据避免相互影响
|
|
|
|
|
|
test_data = data.copy()
|
|
|
|
|
|
|
|
|
|
|
|
# 应用方法
|
|
|
|
|
|
test_data = metrics.set_ma_long_short_advanced(test_data, method=method)
|
|
|
|
|
|
|
|
|
|
|
|
# 统计结果
|
|
|
|
|
|
long_count = (test_data['ma_long_short'] == '多').sum()
|
|
|
|
|
|
short_count = (test_data['ma_long_short'] == '空').sum()
|
|
|
|
|
|
neutral_count = (test_data['ma_long_short'] == '震荡').sum()
|
|
|
|
|
|
|
|
|
|
|
|
results[method] = {
|
|
|
|
|
|
'long': long_count,
|
|
|
|
|
|
'short': short_count,
|
|
|
|
|
|
'neutral': neutral_count,
|
|
|
|
|
|
'data': test_data
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
print(f"多头信号: {long_count} ({long_count/len(test_data)*100:.1f}%)")
|
|
|
|
|
|
print(f"空头信号: {short_count} ({short_count/len(test_data)*100:.1f}%)")
|
|
|
|
|
|
print(f"震荡信号: {neutral_count} ({neutral_count/len(test_data)*100:.1f}%)")
|
|
|
|
|
|
|
|
|
|
|
|
return results, data
|
|
|
|
|
|
|
|
|
|
|
|
def compare_methods(results, original_data):
|
|
|
|
|
|
"""比较不同方法的结果"""
|
|
|
|
|
|
print("\n=== 方法比较分析 ===\n")
|
|
|
|
|
|
|
|
|
|
|
|
# 创建比较表格
|
|
|
|
|
|
comparison_data = []
|
|
|
|
|
|
for method, result in results.items():
|
|
|
|
|
|
total = result['long'] + result['short'] + result['neutral']
|
|
|
|
|
|
comparison_data.append({
|
|
|
|
|
|
'方法': method,
|
|
|
|
|
|
'多头信号': f"{result['long']} ({result['long']/total*100:.1f}%)",
|
|
|
|
|
|
'空头信号': f"{result['short']} ({result['short']/total*100:.1f}%)",
|
|
|
|
|
|
'震荡信号': f"{result['neutral']} ({result['neutral']/total*100:.1f}%)",
|
|
|
|
|
|
'信号总数': result['long'] + result['short']
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
comparison_df = pd.DataFrame(comparison_data)
|
|
|
|
|
|
print("信号分布比较:")
|
|
|
|
|
|
print(comparison_df.to_string(index=False))
|
|
|
|
|
|
|
|
|
|
|
|
# 分析信号一致性
|
|
|
|
|
|
print("\n信号一致性分析:")
|
|
|
|
|
|
methods = list(results.keys())
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(len(methods)):
|
|
|
|
|
|
for j in range(i+1, len(methods)):
|
|
|
|
|
|
method1, method2 = methods[i], methods[j]
|
|
|
|
|
|
data1 = results[method1]['data']['ma_long_short']
|
|
|
|
|
|
data2 = results[method2]['data']['ma_long_short']
|
|
|
|
|
|
|
|
|
|
|
|
# 计算一致性
|
|
|
|
|
|
agreement = (data1 == data2).sum()
|
|
|
|
|
|
agreement_rate = agreement / len(data1) * 100
|
|
|
|
|
|
|
|
|
|
|
|
print(f"{method1} vs {method2}: {agreement_rate:.1f}% 一致")
|
|
|
|
|
|
|
|
|
|
|
|
def visualize_results(results, original_data):
|
|
|
|
|
|
"""可视化结果"""
|
|
|
|
|
|
print("\n=== 生成可视化图表 ===")
|
|
|
|
|
|
|
|
|
|
|
|
# 选择前2000个数据点进行可视化
|
|
|
|
|
|
plot_data = original_data.head(1000).copy()
|
|
|
|
|
|
|
|
|
|
|
|
# 添加不同方法的结果
|
|
|
|
|
|
for method, result in results.items():
|
|
|
|
|
|
plot_data[f'{method}_signal'] = result['data'].head(1000)['ma_long_short']
|
|
|
|
|
|
|
|
|
|
|
|
# 创建图表
|
|
|
|
|
|
fig, axes = plt.subplots(4, 2, figsize=(15, 16))
|
|
|
|
|
|
fig.suptitle('均线多空判定方法比较', fontsize=16)
|
|
|
|
|
|
|
|
|
|
|
|
# 价格和均线图
|
|
|
|
|
|
ax1 = axes[0, 0]
|
|
|
|
|
|
ax1.plot(plot_data.index, plot_data['close'], label='价格', alpha=0.7)
|
|
|
|
|
|
ax1.plot(plot_data.index, plot_data['ma5'], label='MA5', alpha=0.8)
|
|
|
|
|
|
ax1.plot(plot_data.index, plot_data['ma10'], label='MA10', alpha=0.8)
|
|
|
|
|
|
ax1.plot(plot_data.index, plot_data['ma20'], label='MA20', alpha=0.8)
|
|
|
|
|
|
ax1.plot(plot_data.index, plot_data['ma30'], label='MA30', alpha=0.8)
|
|
|
|
|
|
ax1.set_title('价格和均线')
|
|
|
|
|
|
ax1.legend()
|
|
|
|
|
|
ax1.grid(True, alpha=0.3)
|
|
|
|
|
|
|
|
|
|
|
|
# 均线偏离度图
|
|
|
|
|
|
ax2 = axes[0, 1]
|
|
|
|
|
|
ax2.plot(plot_data.index, plot_data['ma5_close_diff'], label='MA5偏离度', alpha=0.7)
|
|
|
|
|
|
ax2.plot(plot_data.index, plot_data['ma10_close_diff'], label='MA10偏离度', alpha=0.7)
|
|
|
|
|
|
ax2.plot(plot_data.index, plot_data['ma20_close_diff'], label='MA20偏离度', alpha=0.7)
|
|
|
|
|
|
ax2.plot(plot_data.index, plot_data['ma30_close_diff'], label='MA30偏离度', alpha=0.7)
|
|
|
|
|
|
ax2.axhline(y=0, color='black', linestyle='--', alpha=0.5)
|
|
|
|
|
|
ax2.set_title('均线偏离度')
|
|
|
|
|
|
ax2.legend()
|
|
|
|
|
|
ax2.grid(True, alpha=0.3)
|
|
|
|
|
|
|
|
|
|
|
|
# 不同方法的信号图
|
|
|
|
|
|
methods = list(results.keys())
|
|
|
|
|
|
colors = ['red', 'blue', 'green', 'orange', 'purple']
|
|
|
|
|
|
|
|
|
|
|
|
for i, method in enumerate(methods):
|
|
|
|
|
|
print(f"绘制{method}方法信号")
|
|
|
|
|
|
row = 1 + (i // 2) # 从第2行开始,每行2个图
|
|
|
|
|
|
col = i % 2 # 0或1
|
|
|
|
|
|
ax = axes[row, col]
|
|
|
|
|
|
|
|
|
|
|
|
# 绘制信号
|
|
|
|
|
|
long_signals = plot_data[plot_data[f'{method}_signal'] == '多'].index
|
|
|
|
|
|
short_signals = plot_data[plot_data[f'{method}_signal'] == '空'].index
|
|
|
|
|
|
|
|
|
|
|
|
ax.plot(plot_data.index, plot_data['close'], color='gray', alpha=0.5, label='价格')
|
|
|
|
|
|
ax.scatter(long_signals, plot_data.loc[long_signals, 'close'],
|
|
|
|
|
|
color='red', marker='^', s=50, label='多头信号', alpha=0.8)
|
|
|
|
|
|
ax.scatter(short_signals, plot_data.loc[short_signals, 'close'],
|
|
|
|
|
|
color='blue', marker='v', s=50, label='空头信号', alpha=0.8)
|
|
|
|
|
|
|
|
|
|
|
|
ax.set_title(f'{method}方法信号')
|
|
|
|
|
|
ax.legend()
|
|
|
|
|
|
ax.grid(True, alpha=0.3)
|
|
|
|
|
|
|
|
|
|
|
|
plt.tight_layout()
|
|
|
|
|
|
plt.savefig('ma_methods_comparison.png', dpi=300, bbox_inches='tight')
|
|
|
|
|
|
print("图表已保存为: ma_methods_comparison.png")
|
|
|
|
|
|
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
|
"""主函数"""
|
|
|
|
|
|
print("开始测试均线多空判定方法...")
|
|
|
|
|
|
|
|
|
|
|
|
# 测试方法
|
|
|
|
|
|
results, original_data = test_ma_methods()
|
|
|
|
|
|
|
|
|
|
|
|
# 比较结果
|
|
|
|
|
|
compare_methods(results, original_data)
|
|
|
|
|
|
|
|
|
|
|
|
# 可视化结果
|
|
|
|
|
|
try:
|
|
|
|
|
|
visualize_results(results, original_data)
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
print(f"可视化失败: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
print("\n=== 测试完成 ===")
|
|
|
|
|
|
print("\n建议:")
|
|
|
|
|
|
print("1. 加权投票方法适合大多数市场环境")
|
|
|
|
|
|
print("2. 趋势强度方法适合趋势明显的市场")
|
|
|
|
|
|
print("3. 均线排列方法适合技术分析")
|
|
|
|
|
|
print("4. 统计方法适合波动较大的市场")
|
|
|
|
|
|
print("5. 混合方法综合了多种优势,但计算复杂度较高")
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|
main()
|