crypto_quant/play.py

176 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from core.biz.quant_trader import QuantTrader
from core.biz.strategy import QuantStrategy
from config import MYSQL_CONFIG
from sqlalchemy import create_engine, exc, text
import pandas as pd
# Configure the root logger at import time: INFO level, records formatted as
# "<timestamp> <level>: <message>". Every logging.* call below relies on this.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')
def main() -> None:
    """Run the interactive console menu of the trading system.

    Loads the API credentials and settings from ``config``, builds one
    ``QuantTrader`` and one ``QuantStrategy`` with them, then loops over a
    numbered menu read from standard input until the user enters ``0``.

    The function returns early (after logging an error) when ``config``
    cannot be imported or when ``API_KEY`` still holds the placeholder
    value ``"your_api_key_here"``.
    """
    logging.info("=== 比特币量化交易系统 ===")
    # Imported here rather than at module level so that a failing import is
    # reported through the logger instead of an uncaught traceback.
    try:
        from config import API_KEY, SECRET_KEY, PASSPHRASE, SANDBOX, TRADING_CONFIG, TIME_CONFIG
    except ImportError:
        logging.error("找不到config.py文件请确保配置文件存在")
        return

    # Refuse to start while the placeholder API key is still in place.
    if API_KEY == "your_api_key_here":
        logging.error("请先在config.py中配置你的OKX API密钥\n1. 登录OKX官网\n2. 进入API管理页面\n3. 创建API Key、Secret Key和Passphrase\n4. 将密钥填入config.py文件中的相应位置")
        return

    # Trader and strategy are built from the same credentials and settings.
    symbol = TRADING_CONFIG.get("symbol", "BTC-USDT")
    position_size = TRADING_CONFIG.get("position_size", 0.001)
    trader = QuantTrader(
        API_KEY, SECRET_KEY, PASSPHRASE,
        sandbox=SANDBOX,
        symbol=symbol,
        position_size=position_size
    )
    strategy = QuantStrategy(
        API_KEY, SECRET_KEY, PASSPHRASE,
        sandbox=SANDBOX,
        symbol=symbol,
        position_size=position_size
    )

    # Menu loop: one action per iteration, "0" leaves the loop.
    while True:
        logging.info("\n请选择操作:\n1. 查看账户余额\n2. 查看当前价格\n3. 执行移动平均线策略\n4. 执行RSI策略\n5. 执行网格交易策略\n6. 运行策略循环\n7. 买入测试\n8. 卖出测试\n9. 获取最小交易量\n0. 退出")
        choice = input("请输入选择 (0-9): ").strip()
        if choice == '0':
            logging.info("退出程序")
            break
        elif choice == '1':
            trader.get_account_balance()
        elif choice == '2':
            trader.get_current_price()
        elif choice == '3':
            sma_short_period = TRADING_CONFIG.get("sma_short_period", 5)
            sma_long_period = TRADING_CONFIG.get("sma_long_period", 20)
            strategy.simple_moving_average_strategy(sma_short_period, sma_long_period)
        elif choice == '4':
            period = TRADING_CONFIG.get("rsi_period", 14)
            rsi_oversold = TRADING_CONFIG.get("rsi_oversold", 30)
            rsi_overbought = TRADING_CONFIG.get("rsi_overbought", 70)
            strategy.rsi_strategy(period, rsi_oversold, rsi_overbought)
        elif choice == '5':
            grid_levels = TRADING_CONFIG.get("grid_levels", 5)
            grid_range = TRADING_CONFIG.get("grid_range", 0.02)
            strategy.grid_trading_strategy(grid_levels, grid_range)
        elif choice == '6':
            strategy_name = input("选择策略 (sma/rsi/grid): ").strip()
            interval = TIME_CONFIG.get("strategy_interval", 30)
            strategy.run_strategy_loop(strategy_name, interval, TRADING_CONFIG)
        elif choice == '7':
            _place_test_order(trader, 'buy', "买入")
        elif choice == '8':
            _place_test_order(trader, 'sell', "卖出")
        elif choice == '9':
            trader.get_minimun_order_size()
        else:
            logging.warning("无效选择,请重新输入")


def _place_test_order(trader, side: str, verb: str, default_size: float = 0.01) -> None:
    """Prompt for a quantity and submit exactly one market order.

    :param trader: object exposing ``place_market_order(side, size)``
    :param side: order side passed through to the trader (``'buy'`` / ``'sell'``)
    :param verb: word for the side used in the prompt and log messages
    :param default_size: quantity used when the input is empty or not a number
    """
    raw_size = input(f"请输入{verb}数量: ")
    if not raw_size:
        logging.info(f"默认{verb}{default_size}BTC")
        trader.place_market_order(side, default_size)
        return

    # Only the parse is guarded: a ValueError raised while placing the order
    # must not be mistaken for bad input and trigger a second, default order.
    try:
        size = float(raw_size)
    except ValueError:
        logging.warning(f"输入无效,默认{verb}{default_size}BTC")
        trader.place_market_order(side, default_size)
        return

    logging.info(f"{verb}{size}BTC")
    trader.place_market_order(side, size)
def test_query():
    """Run a fixed aggregate query against MySQL and export the result.

    Connection settings are read from ``MYSQL_CONFIG``. The query selects,
    per symbol, the minimum and maximum ``date_time`` of the ``5m`` rows in
    ``okx.crypto_binance_data``. A non-empty result is printed as a
    DataFrame and written to ``./data/binance/crypto_binance_data_5m.xlsx``.

    :raises ValueError: if ``MYSQL_CONFIG`` holds no (or an empty) password
    :return: always ``None``; failures during query/export are printed
    """
    from pathlib import Path
    from urllib.parse import quote_plus

    mysql_user = MYSQL_CONFIG.get("user", "xch")
    mysql_password = MYSQL_CONFIG.get("password", "")
    if not mysql_password:
        raise ValueError("MySQL password is not set")
    mysql_host = MYSQL_CONFIG.get("host", "localhost")
    mysql_port = MYSQL_CONFIG.get("port", 3306)
    mysql_database = MYSQL_CONFIG.get("database", "okx")

    # Credentials are percent-encoded: a raw '@', ':' or '/' in the password
    # would otherwise be parsed as part of the URL structure.
    db_url = (
        f"mysql+pymysql://{quote_plus(str(mysql_user))}:{quote_plus(str(mysql_password))}"
        f"@{mysql_host}:{mysql_port}/{mysql_database}"
    )
    db_engine = create_engine(
        db_url,
        pool_size=25,  # connections kept in the pool
        max_overflow=10,  # extra connections allowed beyond pool_size
        pool_timeout=30,  # seconds to wait for a free connection
        pool_recycle=60,  # seconds after which a connection is recycled
    )
    sql = "SELECT symbol, min(date_time), max(date_time) FROM okx.crypto_binance_data where bar='5m' group by symbol;"
    output_path = "./data/binance/crypto_binance_data_5m.xlsx"
    condition_dict = {}
    return_multi = True
    try:
        result = query_data(db_engine, sql, condition_dict, return_multi)
        if result:
            data = pd.DataFrame(result)
            data.columns = ["symbol", "min_date_time", "max_date_time"]
            print(data)
            # Make sure the target directory exists before exporting to Excel.
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            data.to_excel(output_path, index=False)
    except Exception as e:
        print(f"查询数据出错: {e}")
        return None
    finally:
        # Release the pooled connections held by this one-off engine.
        db_engine.dispose()
def transform_data_type(data: dict):
    """
    Replace every top-level ``Decimal`` value in ``data`` with its ``float``.

    The dict is modified in place and the same object is returned; values of
    any other type are left untouched.
    """
    from decimal import Decimal

    as_floats = {
        field: float(number)
        for field, number in data.items()
        if isinstance(number, Decimal)
    }
    data.update(as_floats)
    return data
def query_data(db_engine, sql: str, condition_dict: dict, return_multi: bool = True):
    """
    Execute a SQL statement and return the rows as plain dicts.

    :param db_engine: engine object whose ``connect()`` yields the connection
    :param sql: SQL text, wrapped in ``text()`` before execution
    :param condition_dict: parameters passed to ``execute`` with the statement
    :param return_multi: truthy -> fetch all rows, falsy -> fetch one row
    :return: list of dicts (all rows) or a single dict (one row), each passed
        through ``transform_data_type``; ``None`` when nothing was fetched or
        when any exception occurred (the error is printed, not raised)
    """
    try:
        with db_engine.connect() as conn:
            cursor = conn.execute(text(sql), condition_dict)

            if not return_multi:
                row = cursor.fetchone()
                if not row:
                    return None
                return transform_data_type(dict(row._mapping))

            rows = cursor.fetchall()
            if not rows:
                return None
            return [transform_data_type(dict(item._mapping)) for item in rows]
    except Exception as e:
        print(f"查询数据出错: {e}")
        return None
if __name__ == "__main__":
    # The interactive menu entry point is currently commented out; running
    # this file executes the database query/export in test_query() instead.
    # main()
    test_query()