crypto_quant/test_huge_volume.py

293 lines
11 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试HugeVolume类的功能
验证根据SQL表结构和db_huge_volume_data.py更新后的代码是否正常工作
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from core.biz.huge_volume import HugeVolume
import logging
# Configure logging: INFO level, each record rendered as "<time> <level>: <message>".
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
# Module-level logger shared by every test function in this file.
logger = logging.getLogger(__name__)
def create_test_data(n_records: int = 100) -> pd.DataFrame:
    """Build a deterministic synthetic candle frame for the tests.

    The global NumPy RNG is re-seeded with 42 on every call, so repeated
    calls with the same ``n_records`` return identical frames.  Every 10th
    row (index 0, 10, 20, ...) gets a volume of 3-5x ``base_volume`` to
    simulate a huge-volume bar; all other rows get 0.5-1.5x.

    Args:
        n_records: Number of rows to generate.  Defaults to 100, the value
            that used to be hard-coded.

    Returns:
        A DataFrame with the columns ``symbol``, ``bar``, ``timestamp``,
        ``date_time``, ``open``, ``high``, ``low``, ``close``, ``volume``,
        ``volCcy`` and ``volCCyQuote``.
    """
    np.random.seed(42)

    # One timestamp per minute, starting at a naive 2024-01-01 00:00.
    base_time = datetime(2024, 1, 1)
    timestamps = [
        int((base_time + timedelta(minutes=i)).timestamp())
        for i in range(n_records)
    ]

    # Price path: cumulative sum of normal steps with a 2% standard deviation.
    base_price = 50000
    price_changes = np.random.normal(0, 0.02, n_records)
    # cumsum produces every prefix sum in a single pass instead of re-summing
    # a growing slice for each row.
    prices = list(base_price * (1 + np.cumsum(price_changes)))

    # Volumes are drawn one at a time so the RNG stream stays in the same
    # order as before: one uniform draw per row, bounds chosen by row index.
    base_volume = 1000
    volumes = []
    for i in range(n_records):
        if i % 10 == 0:
            # Every 10th bar is a spike of 3-5x the normal volume.
            multiplier = np.random.uniform(3, 5)
        else:
            multiplier = np.random.uniform(0.5, 1.5)
        volumes.append(base_volume * multiplier)

    # The 'high' draws are consumed before the 'low' draws (dict literal
    # evaluation order), which keeps the generated values reproducible.
    return pd.DataFrame({
        'symbol': ['BTC-USDT'] * n_records,
        'bar': ['1m'] * n_records,
        'timestamp': timestamps,
        'date_time': [
            datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S')
            for ts in timestamps
        ],
        'open': prices,
        'high': [p * (1 + np.random.uniform(0, 0.01)) for p in prices],
        'low': [p * (1 - np.random.uniform(0, 0.01)) for p in prices],
        'close': prices,
        'volume': volumes,
        'volCcy': volumes,
        'volCCyQuote': [v * p for v, p in zip(volumes, prices)],
    })
def test_huge_volume_detection():
    """Exercise ``HugeVolume.detect_huge_volume`` with and without price checks.

    Returns:
        True when both detection runs return a result that contains every
        expected column; False as soon as a run returns None or an expected
        column is missing.
    """
    logger.info("🧪 开始测试巨量交易检测功能...")

    # Build the synthetic input.
    frame = create_test_data()
    logger.info(f"📊 创建测试数据: {len(frame)} 条记录")

    detector = HugeVolume()

    # Pass 1: volume-only detection (price check disabled).
    basic = detector.detect_huge_volume(
        data=frame.copy(),
        window_size=20,
        threshold=2.0,
        check_price=False,
        only_output_huge_volume=False,
    )
    if basic is None:
        logger.error("❌ 基本巨量检测失败")
        return False
    logger.info("✅ 基本巨量检测功能正常")
    logger.info(f"📈 检测到 {basic['huge_volume'].sum()} 条巨量交易记录")

    # Columns the basic pass is expected to add.
    basic_columns = (
        'volume_ma', 'volume_std', 'volume_threshold',
        'huge_volume', 'volume_ratio', 'spike_intensity',
    )
    absent = [name for name in basic_columns if name not in basic.columns]
    if absent:
        logger.error(f"❌ 缺少字段: {absent}")
        return False
    logger.info("✅ 所有必要字段都已生成")

    # Pass 2: detection with the price check enabled.
    with_price = detector.detect_huge_volume(
        data=frame.copy(),
        window_size=20,
        threshold=2.0,
        check_price=True,
        only_output_huge_volume=False,
    )
    if with_price is None:
        logger.error("❌ 包含价格检查的巨量检测失败")
        return False
    logger.info("✅ 包含价格检查的巨量检测功能正常")

    # Percentile-related columns expected when the price check is on.
    percentile_columns = (
        'close_80_percentile', 'close_20_percentile',
        'close_90_percentile', 'close_10_percentile',
        'price_80_high', 'price_20_low', 'price_90_high', 'price_10_low',
        'volume_80_20_price_spike', 'volume_90_10_price_spike',
    )
    absent_percentiles = [
        name for name in percentile_columns if name not in with_price.columns
    ]
    if absent_percentiles:
        logger.error(f"❌ 缺少分位数字段: {absent_percentiles}")
        return False
    logger.info("✅ 所有分位数字段都已生成")

    # Sum each indicator column first, then report them in a fixed order.
    reported = (
        ("巨量交易", 'huge_volume'),
        ("价格80%高点", 'price_80_high'),
        ("价格20%低点", 'price_20_low'),
        ("价格90%高点", 'price_90_high'),
        ("价格10%低点", 'price_10_low'),
        ("80/20量价尖峰", 'volume_80_20_price_spike'),
        ("90/10量价尖峰", 'volume_90_10_price_spike'),
    )
    totals = [(label, with_price[column].sum()) for label, column in reported]
    logger.info("📊 统计结果:")
    for label, total in totals:
        logger.info(f" - {label}: {total}")

    return True
def test_next_periods_analysis():
    """Exercise ``HugeVolume.next_periods_rise_or_fall`` on detected data.

    Detection is run first (with the price check enabled) and its output is
    fed into the next-period analysis for periods 3 and 5.

    Returns:
        False when detection returns None; True otherwise.
    """
    logger.info("🧪 开始测试未来周期分析功能...")

    # Build the synthetic input.
    frame = create_test_data()

    # Detection has to run first; its output feeds the period analysis.
    detector = HugeVolume()
    detection_options = {
        'window_size': 20,
        'threshold': 2.0,
        'check_price': True,
        'only_output_huge_volume': False,
    }
    detected = detector.detect_huge_volume(data=frame.copy(), **detection_options)
    if detected is None:
        logger.error("❌ 巨量检测失败,无法进行未来周期分析")
        return False

    # Run the next-period analysis without writing an Excel file.
    processed_data, stats_data = detector.next_periods_rise_or_fall(
        data=detected.copy(),
        periods=[3, 5],
        output_excel=False,
    )
    logger.info("✅ 未来周期分析功能正常")
    logger.info(f"📊 处理后的数据: {len(processed_data)} 条记录")
    logger.info(f"📈 统计结果: {len(stats_data)} 条统计记录")

    if len(stats_data) == 0:
        return True

    # Log the first few statistics rows as a sample.
    logger.info("📋 统计结果示例:")
    for _, stat in stats_data.head().iterrows():
        heading = f" - {stat['price_type']}, 周期{stat['next_period']}: "
        ratios = f"上涨率{stat['rise_ratio']:.2%}, 下跌率{stat['fall_ratio']:.2%}"
        logger.info(heading + ratios)
    return True
def test_private_methods():
    """Check the two private helper methods on ``HugeVolume``.

    Verifies that ``_calculate_percentile_indicators`` and
    ``_calculate_volume_price_spikes`` exist and that their results contain
    the expected columns.

    Returns:
        True when both helpers exist and produce every expected column;
        False on the first missing method or missing column.
    """
    logger.info("🧪 开始测试私有方法...")
    detector = HugeVolume()
    frame = create_test_data()

    # --- percentile indicator helper ---
    if not hasattr(detector, '_calculate_percentile_indicators'):
        logger.error("❌ 私有方法 _calculate_percentile_indicators 不存在")
        return False
    logger.info("✅ 私有方法 _calculate_percentile_indicators 存在")

    percentile_result = detector._calculate_percentile_indicators(frame.copy(), 20)
    percentile_columns = (
        'close_80_percentile', 'close_20_percentile',
        'close_90_percentile', 'close_10_percentile',
        'price_80_high', 'price_20_low', 'price_90_high', 'price_10_low',
    )
    absent = [
        name for name in percentile_columns
        if name not in percentile_result.columns
    ]
    if absent:
        logger.error(f"❌ _calculate_percentile_indicators 缺少字段: {absent}")
        return False
    logger.info("✅ _calculate_percentile_indicators 方法工作正常")

    # --- volume/price spike helper ---
    if not hasattr(detector, '_calculate_volume_price_spikes'):
        logger.error("❌ 私有方法 _calculate_volume_price_spikes 不存在")
        return False
    logger.info("✅ 私有方法 _calculate_volume_price_spikes 存在")

    # The spike helper is fed percentile output with every row flagged as
    # huge volume.
    spike_input = detector._calculate_percentile_indicators(frame.copy(), 20)
    spike_input['huge_volume'] = 1
    spike_result = detector._calculate_volume_price_spikes(spike_input)

    spike_columns = ('volume_80_20_price_spike', 'volume_90_10_price_spike')
    absent_spikes = [
        name for name in spike_columns if name not in spike_result.columns
    ]
    if absent_spikes:
        logger.error(f"❌ _calculate_volume_price_spikes 缺少字段: {absent_spikes}")
        return False
    logger.info("✅ _calculate_volume_price_spikes 方法工作正常")

    return True
def show_optimization_benefits():
    """Log a fixed list of highlights describing the HugeVolume refactoring."""
    logger.info("🚀 HugeVolume类优化亮点:")
    # One INFO line per highlight, emitted in the order listed here.
    highlights = (
        "✅ 添加了完整的类型提示提高代码可读性和IDE支持",
        "✅ 提取了重复的分位数计算逻辑为私有方法 _calculate_percentile_indicators",
        "✅ 提取了重复的量价尖峰计算逻辑为私有方法 _calculate_volume_price_spikes",
        "✅ 支持80/20和90/10两种分位数分析",
        "✅ 字段名称与SQL表结构完全匹配",
        "✅ 增强了未来周期分析功能,支持多种分位数类型",
        "✅ 改进了错误处理和边界条件检查",
        "✅ 符合PEP 8代码风格指南",
    )
    for highlight in highlights:
        logger.info(f" {highlight}")
def main():
    """Run every test in this file and exit non-zero if any of them fails.

    Each test function is called in turn; a falsy return value or a raised
    exception marks the run as failed.  Exceptions are logged together with
    their traceback so the failing line can be located, and the remaining
    tests still run.

    Raises:
        SystemExit: With exit code 1 when at least one test failed or raised.
    """
    logger.info("🚀 开始测试HugeVolume类...")

    # Show the refactoring highlights before running anything.
    show_optimization_benefits()
    print()

    # (display name, callable) pairs, executed in this order.
    tests = [
        ("私有方法测试", test_private_methods),
        ("巨量交易检测测试", test_huge_volume_detection),
        ("未来周期分析测试", test_next_periods_analysis),
    ]

    all_passed = True
    for test_name, test_func in tests:
        logger.info(f"🔍 开始 {test_name}...")
        try:
            passed = test_func()
        except Exception as e:
            # Top-level boundary: keep going with the remaining tests, but
            # record the traceback rather than only the exception message.
            logger.exception(f"{test_name} 异常: {str(e)}")
            all_passed = False
        else:
            if passed:
                logger.info(f"{test_name} 通过")
            else:
                logger.error(f"{test_name} 失败")
                all_passed = False
        # Blank line on stdout to visually separate the tests.
        print()

    if all_passed:
        logger.info("🎯 所有测试通过HugeVolume类更新成功")
        logger.info("📈 代码已成功优化,提高了可维护性和可读性!")
    else:
        logger.error("💥 部分测试失败,请检查代码!")
        sys.exit(1)
if __name__ == "__main__":
main()