#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Smoke tests for the HugeVolume class.

Verifies that the code still works after being updated to follow the SQL
table structure and db_huge_volume_data.py.
"""

import sys
import os

# Make the directory containing this script importable so that the `core`
# package import below can be resolved when the file is run directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from core.huge_volume import HugeVolume
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
logger = logging.getLogger(__name__)

# Columns the tests expect detect_huge_volume() to add when check_price=False.
_VOLUME_FIELDS = (
    'volume_ma', 'volume_std', 'volume_threshold',
    'huge_volume', 'volume_ratio', 'spike_intensity',
)

# Columns the tests expect _calculate_percentile_indicators() to add.
_PERCENTILE_INDICATOR_FIELDS = (
    'close_80_percentile', 'close_20_percentile',
    'close_90_percentile', 'close_10_percentile',
    'price_80_high', 'price_20_low',
    'price_90_high', 'price_10_low',
)

# Columns the tests expect _calculate_volume_price_spikes() to add.
_SPIKE_FIELDS = ('volume_80_20_price_spike', 'volume_90_10_price_spike')


def _missing_columns(frame, expected) -> list:
    """Return the names in *expected* that are not columns of *frame*, in order."""
    return [name for name in expected if name not in frame.columns]


def create_test_data() -> pd.DataFrame:
    """Build a deterministic 100-row synthetic 1-minute candle frame.

    The random seed is fixed, so repeated calls draw the same random
    numbers. Every 10th row gets a volume of 3-5x the base volume; the
    remaining rows get 0.5-1.5x.

    Returns:
        A DataFrame with the columns symbol, bar, timestamp, date_time,
        open, high, low, close, volume, volCcy and volCCyQuote.
    """
    np.random.seed(42)
    n_records = 100

    # One timestamp per minute starting at 2024-01-01 (naive datetime, so the
    # epoch value depends on the local timezone of the machine).
    base_time = datetime(2024, 1, 1)
    timestamps = [
        int((base_time + timedelta(minutes=i)).timestamp())
        for i in range(n_records)
    ]

    # Price path: base price scaled by the running sum of normally
    # distributed returns (standard deviation 2%).
    base_price = 50000
    price_changes = np.random.normal(0, 0.02, n_records)
    prices = (base_price * (1 + np.cumsum(price_changes))).tolist()

    # Volume series: exactly one uniform draw per row, in row order, so the
    # random stream is consumed in a fixed sequence.
    base_volume = 1000
    volumes = []
    for i in range(n_records):
        if i % 10 == 0:
            multiplier = np.random.uniform(3, 5)      # spike: 3-5x base volume
        else:
            multiplier = np.random.uniform(0.5, 1.5)  # regular volume
        volumes.append(base_volume * multiplier)

    return pd.DataFrame({
        'symbol': ['BTC-USDT'] * n_records,
        'bar': ['1m'] * n_records,
        'timestamp': timestamps,
        'date_time': [
            datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S')
            for ts in timestamps
        ],
        'open': prices,
        'high': [p * (1 + np.random.uniform(0, 0.01)) for p in prices],
        'low': [p * (1 - np.random.uniform(0, 0.01)) for p in prices],
        'close': prices,
        'volume': volumes,
        'volCcy': volumes,
        'volCCyQuote': [v * p for v, p in zip(volumes, prices)],
    })


def test_huge_volume_detection() -> bool:
    """Exercise HugeVolume.detect_huge_volume with and without price checks.

    Returns:
        True when both calls return a non-None result containing all the
        expected columns, False otherwise.
    """
    logger.info("🧪 开始测试巨量交易检测功能...")

    test_data = create_test_data()
    logger.info(f"📊 创建测试数据: {len(test_data)} 条记录")

    huge_volume = HugeVolume()

    # --- Basic detection (no price check) ---
    result_basic = huge_volume.detect_huge_volume(
        data=test_data.copy(),
        window_size=20,
        threshold=2.0,
        check_price=False,
        only_output_huge_volume=False
    )
    if result_basic is None:
        logger.error("❌ 基本巨量检测失败")
        return False
    logger.info("✅ 基本巨量检测功能正常")

    # Validate the columns BEFORE reading any of them, so a missing column is
    # reported as a test failure instead of surfacing as a KeyError.
    missing_fields = _missing_columns(result_basic, _VOLUME_FIELDS)
    if missing_fields:
        logger.error(f"❌ 缺少字段: {missing_fields}")
        return False
    logger.info(f"📈 检测到 {result_basic['huge_volume'].sum()} 条巨量交易记录")
    logger.info("✅ 所有必要字段都已生成")

    # --- Detection including the price check ---
    result_with_price = huge_volume.detect_huge_volume(
        data=test_data.copy(),
        window_size=20,
        threshold=2.0,
        check_price=True,
        only_output_huge_volume=False
    )
    if result_with_price is None:
        logger.error("❌ 包含价格检查的巨量检测失败")
        return False
    logger.info("✅ 包含价格检查的巨量检测功能正常")

    missing_percentile_fields = _missing_columns(
        result_with_price, _PERCENTILE_INDICATOR_FIELDS + _SPIKE_FIELDS
    )
    if missing_percentile_fields:
        logger.error(f"❌ 缺少分位数字段: {missing_percentile_fields}")
        return False
    logger.info("✅ 所有分位数字段都已生成")

    # Report how many rows are flagged for each indicator column.
    logger.info("📊 统计结果:")
    summary_columns = (
        ("巨量交易", 'huge_volume'),
        ("价格80%高点", 'price_80_high'),
        ("价格20%低点", 'price_20_low'),
        ("价格90%高点", 'price_90_high'),
        ("价格10%低点", 'price_10_low'),
        ("80/20量价尖峰", 'volume_80_20_price_spike'),
        ("90/10量价尖峰", 'volume_90_10_price_spike'),
    )
    for label, column in summary_columns:
        logger.info(f" - {label}: {result_with_price[column].sum()}")

    return True


def test_next_periods_analysis() -> bool:
    """Exercise HugeVolume.next_periods_rise_or_fall on detected data.

    Returns:
        False when the preliminary detection returns None, True once the
        analysis call has completed and its results have been logged.
    """
    logger.info("🧪 开始测试未来周期分析功能...")

    test_data = create_test_data()

    # The analysis consumes the output of the detection step.
    huge_volume = HugeVolume()
    result = huge_volume.detect_huge_volume(
        data=test_data.copy(),
        window_size=20,
        threshold=2.0,
        check_price=True,
        only_output_huge_volume=False
    )
    if result is None:
        logger.error("❌ 巨量检测失败,无法进行未来周期分析")
        return False

    processed_data, stats_data = huge_volume.next_periods_rise_or_fall(
        data=result.copy(),
        periods=[3, 5],
        output_excel=False
    )

    logger.info("✅ 未来周期分析功能正常")
    logger.info(f"📊 处理后的数据: {len(processed_data)} 条记录")
    logger.info(f"📈 统计结果: {len(stats_data)} 条统计记录")

    if len(stats_data) > 0:
        logger.info("📋 统计结果示例:")
        for _, row in stats_data.head().iterrows():
            logger.info(f" - {row['price_type']}, 周期{row['next_period']}: "
                        f"上涨率{row['rise_ratio']:.2%}, 下跌率{row['fall_ratio']:.2%}")

    return True


def test_private_methods() -> bool:
    """Check that HugeVolume's private helpers exist and add their columns.

    Covers _calculate_percentile_indicators and
    _calculate_volume_price_spikes.

    Returns:
        True when both helpers exist and their results contain every
        expected column, False otherwise.
    """
    logger.info("🧪 开始测试私有方法...")

    huge_volume = HugeVolume()
    test_data = create_test_data()

    # --- Percentile indicator calculation ---
    if not hasattr(huge_volume, '_calculate_percentile_indicators'):
        logger.error("❌ 私有方法 _calculate_percentile_indicators 不存在")
        return False
    logger.info("✅ 私有方法 _calculate_percentile_indicators 存在")

    result = huge_volume._calculate_percentile_indicators(test_data.copy(), 20)
    missing_fields = _missing_columns(result, _PERCENTILE_INDICATOR_FIELDS)
    if missing_fields:
        logger.error(f"❌ _calculate_percentile_indicators 缺少字段: {missing_fields}")
        return False
    logger.info("✅ _calculate_percentile_indicators 方法工作正常")

    # --- Volume/price spike calculation ---
    if not hasattr(huge_volume, '_calculate_volume_price_spikes'):
        logger.error("❌ 私有方法 _calculate_volume_price_spikes 不存在")
        return False
    logger.info("✅ 私有方法 _calculate_volume_price_spikes 存在")

    # The spike helper is fed the percentile output, with every row marked as
    # a huge-volume row to simulate detected spikes.
    data_with_percentiles = huge_volume._calculate_percentile_indicators(test_data.copy(), 20)
    data_with_percentiles['huge_volume'] = 1

    result = huge_volume._calculate_volume_price_spikes(data_with_percentiles)
    missing_spike_fields = _missing_columns(result, _SPIKE_FIELDS)
    if missing_spike_fields:
        logger.error(f"❌ _calculate_volume_price_spikes 缺少字段: {missing_spike_fields}")
        return False
    logger.info("✅ _calculate_volume_price_spikes 方法工作正常")

    return True


def show_optimization_benefits() -> None:
    """Log a fixed list of highlights describing the HugeVolume refactoring."""
    logger.info("🚀 HugeVolume类优化亮点:")
    benefits = [
        "✅ 添加了完整的类型提示,提高代码可读性和IDE支持",
        "✅ 提取了重复的分位数计算逻辑为私有方法 _calculate_percentile_indicators",
        "✅ 提取了重复的量价尖峰计算逻辑为私有方法 _calculate_volume_price_spikes",
        "✅ 支持80/20和90/10两种分位数分析",
        "✅ 字段名称与SQL表结构完全匹配",
        "✅ 增强了未来周期分析功能,支持多种分位数类型",
        "✅ 改进了错误处理和边界条件检查",
        "✅ 符合PEP 8代码风格指南"
    ]
    for benefit in benefits:
        logger.info(f" {benefit}")


def main() -> None:
    """Run every test in order, log the outcome, and exit with 1 on failure.

    A test fails when it returns a falsy value or raises an exception; the
    remaining tests still run in either case.
    """
    logger.info("🚀 开始测试HugeVolume类...")

    show_optimization_benefits()
    print()

    tests = [
        ("私有方法测试", test_private_methods),
        ("巨量交易检测测试", test_huge_volume_detection),
        ("未来周期分析测试", test_next_periods_analysis)
    ]

    all_passed = True
    for test_name, test_func in tests:
        logger.info(f"🔍 开始 {test_name}...")
        try:
            passed = test_func()
        except Exception as e:
            # Top-level boundary: keep running the other tests, but log the
            # full traceback so the failure can be diagnosed.
            logger.exception(f"❌ {test_name} 异常: {e}")
            all_passed = False
        else:
            if passed:
                logger.info(f"✅ {test_name} 通过")
            else:
                logger.error(f"❌ {test_name} 失败")
                all_passed = False
        print()

    if all_passed:
        logger.info("🎯 所有测试通过!HugeVolume类更新成功")
        logger.info("📈 代码已成功优化,提高了可维护性和可读性!")
    else:
        logger.error("💥 部分测试失败,请检查代码!")
        sys.exit(1)


if __name__ == "__main__":
    main()