201 lines
7.7 KiB
Python
201 lines
7.7 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
测试DBTradeData类的功能
|
||
验证根据crypto_trade_data.sql表结构创建的代码是否正常工作
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
from core.db_trade_data import DBTradeData
|
||
import logging
|
||
|
||
# 配置日志
|
||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
|
||
logger = logging.getLogger(__name__)
|
||
|
||
def test_db_trade_data():
|
||
"""测试DBTradeData类的功能"""
|
||
|
||
# 数据库连接URL(请根据实际情况修改)
|
||
db_url = "mysql+pymysql://username:password@localhost:3306/database_name"
|
||
|
||
try:
|
||
# 创建DBTradeData实例
|
||
db_trade_data = DBTradeData(db_url)
|
||
|
||
logger.info("✅ DBTradeData实例创建成功")
|
||
logger.info(f"📊 表名: {db_trade_data.table_name}")
|
||
logger.info(f"📋 字段数量: {len(db_trade_data.columns)}")
|
||
logger.info(f"📋 字段列表: {db_trade_data.columns}")
|
||
|
||
# 验证字段是否与SQL表结构匹配
|
||
expected_columns = [
|
||
"symbol", "ts", "date_time", "tradeId", "side", "sz", "px", "create_time"
|
||
]
|
||
|
||
if db_trade_data.columns == expected_columns:
|
||
logger.info("✅ 字段列表与SQL表结构完全匹配")
|
||
else:
|
||
logger.error("❌ 字段列表与SQL表结构不匹配")
|
||
logger.error(f"期望字段: {expected_columns}")
|
||
logger.error(f"实际字段: {db_trade_data.columns}")
|
||
return False
|
||
|
||
# 测试私有方法
|
||
logger.info("🔍 测试私有方法...")
|
||
if hasattr(db_trade_data, '_process_time_parameter'):
|
||
logger.info("✅ 私有方法 _process_time_parameter 存在")
|
||
else:
|
||
logger.error("❌ 私有方法 _process_time_parameter 不存在")
|
||
return False
|
||
|
||
if hasattr(db_trade_data, '_build_query_conditions'):
|
||
logger.info("✅ 私有方法 _build_query_conditions 存在")
|
||
else:
|
||
logger.error("❌ 私有方法 _build_query_conditions 不存在")
|
||
return False
|
||
|
||
# 测试查询方法(不实际连接数据库,只验证方法存在)
|
||
methods_to_test = [
|
||
"insert_data_to_mysql",
|
||
"insert_data_to_mysql_fast",
|
||
"insert_data_to_mysql_chunk",
|
||
"insert_data_to_mysql_simple",
|
||
"query_latest_data",
|
||
"query_data_by_tradeId",
|
||
"query_trade_data_by_symbol",
|
||
"query_trade_data_by_side",
|
||
"query_buy_trades",
|
||
"query_sell_trades",
|
||
"get_trade_statistics",
|
||
"get_volume_price_analysis",
|
||
"get_recent_trades",
|
||
"get_trades_by_price_range",
|
||
"get_trades_by_volume_range"
|
||
]
|
||
|
||
logger.info("🔍 验证所有查询方法是否存在...")
|
||
for method_name in methods_to_test:
|
||
if hasattr(db_trade_data, method_name):
|
||
logger.info(f"✅ 方法 {method_name} 存在")
|
||
else:
|
||
logger.error(f"❌ 方法 {method_name} 不存在")
|
||
return False
|
||
|
||
# 测试类型提示
|
||
logger.info("🔍 验证类型提示...")
|
||
import inspect
|
||
for method_name in methods_to_test:
|
||
method = getattr(db_trade_data, method_name)
|
||
if method_name.startswith('query_') or method_name.startswith('get_'):
|
||
sig = inspect.signature(method)
|
||
if sig.return_annotation != inspect.Signature.empty:
|
||
logger.info(f"✅ 方法 {method_name} 有返回类型提示")
|
||
else:
|
||
logger.warning(f"⚠️ 方法 {method_name} 缺少返回类型提示")
|
||
|
||
logger.info("🎉 所有测试通过!DBTradeData类创建成功")
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 测试失败: {str(e)}")
|
||
return False
|
||
|
||
def show_class_methods():
|
||
"""显示DBTradeData类的所有方法"""
|
||
logger.info("📚 DBTradeData类的方法列表:")
|
||
|
||
methods = [
|
||
("_process_time_parameter", "私有方法:处理时间参数"),
|
||
("_build_query_conditions", "私有方法:构建查询条件"),
|
||
("insert_data_to_mysql", "标准插入数据到MySQL"),
|
||
("insert_data_to_mysql_fast", "快速插入数据(使用executemany)"),
|
||
("insert_data_to_mysql_chunk", "分块插入数据(适合大数据量)"),
|
||
("insert_data_to_mysql_simple", "简单插入数据(使用to_sql)"),
|
||
("query_latest_data", "查询最新交易数据"),
|
||
("query_data_by_tradeId", "根据交易ID查询数据"),
|
||
("query_trade_data_by_symbol", "根据交易对查询数据"),
|
||
("query_trade_data_by_side", "根据交易方向查询数据"),
|
||
("query_buy_trades", "查询买入交易记录"),
|
||
("query_sell_trades", "查询卖出交易记录"),
|
||
("get_trade_statistics", "获取交易统计信息"),
|
||
("get_volume_price_analysis", "获取成交量价格分析"),
|
||
("get_recent_trades", "获取最近的交易记录"),
|
||
("get_trades_by_price_range", "根据价格范围查询交易"),
|
||
("get_trades_by_volume_range", "根据成交量范围查询交易")
|
||
]
|
||
|
||
for method_name, description in methods:
|
||
logger.info(f" • {method_name}: {description}")
|
||
|
||
def show_optimization_benefits():
|
||
"""显示代码优化的好处"""
|
||
logger.info("🚀 DBTradeData类设计亮点:")
|
||
benefits = [
|
||
"✅ 添加了完整的类型提示,提高代码可读性和IDE支持",
|
||
"✅ 提取了重复的时间处理逻辑为私有方法 _process_time_parameter",
|
||
"✅ 提取了重复的查询条件构建逻辑为私有方法 _build_query_conditions",
|
||
"✅ 提供了丰富的查询方法,支持多种查询场景",
|
||
"✅ 支持按交易对、交易方向、时间范围等维度查询",
|
||
"✅ 提供了统计分析和价格成交量分析功能",
|
||
"✅ 支持价格范围和成交量范围查询",
|
||
"✅ 符合PEP 8代码风格指南"
|
||
]
|
||
|
||
for benefit in benefits:
|
||
logger.info(f" {benefit}")
|
||
|
||
def show_usage_examples():
|
||
"""显示使用示例"""
|
||
logger.info("📝 使用示例:")
|
||
examples = [
|
||
"# 创建实例",
|
||
"db_trade_data = DBTradeData('mysql+pymysql://user:pass@localhost/db')",
|
||
"",
|
||
"# 查询最新交易",
|
||
"latest = db_trade_data.query_latest_data('BTC-USDT')",
|
||
"",
|
||
"# 查询买入交易",
|
||
"buy_trades = db_trade_data.query_buy_trades('BTC-USDT', start='2024-01-01')",
|
||
"",
|
||
"# 获取交易统计",
|
||
"stats = db_trade_data.get_trade_statistics('BTC-USDT')",
|
||
"",
|
||
"# 价格范围查询",
|
||
"trades = db_trade_data.get_trades_by_price_range(50000, 60000, 'BTC-USDT')",
|
||
"",
|
||
"# 成交量分析",
|
||
"analysis = db_trade_data.get_volume_price_analysis('BTC-USDT')"
|
||
]
|
||
|
||
for example in examples:
|
||
logger.info(f" {example}")
|
||
|
||
if __name__ == "__main__":
|
||
logger.info("🚀 开始测试DBTradeData类...")
|
||
|
||
# 显示类的方法列表
|
||
show_class_methods()
|
||
print()
|
||
|
||
# 显示优化亮点
|
||
show_optimization_benefits()
|
||
print()
|
||
|
||
# 显示使用示例
|
||
show_usage_examples()
|
||
print()
|
||
|
||
# 运行测试
|
||
success = test_db_trade_data()
|
||
|
||
if success:
|
||
logger.info("🎯 测试完成,所有功能正常!")
|
||
logger.info("💡 提示:请根据实际数据库配置修改db_url参数")
|
||
logger.info("📈 代码已成功创建,提供了完整的交易数据管理功能!")
|
||
else:
|
||
logger.error("💥 测试失败,请检查代码!")
|
||
sys.exit(1) |