# crypto_quant/core/biz/quant_trader.py
#
# NOTE: the source viewer flagged this file as containing ambiguous Unicode
# characters (characters that might be confused with other characters),
# presumably from the Chinese text in log messages; if that is intentional
# the warning can be safely ignored.

import okx.Account as Account
import okx.Trade as Trade
import okx.MarketData as Market
import okx.PublicData as Public
import pandas as pd
import core.logger as logging
from typing import Optional
# Shared project logger. `logging` here is the project's `core.logger`
# module (imported under that alias above), not the standard library.
logger = logging.logger
class QuantTrader:
    """Convenience wrapper around the OKX SDK for a single trading pair.

    Holds one client per OKX API group (account, trade, market, public) and
    exposes helpers for balances, prices, candlesticks, spot market orders,
    and short positions on the derived swap instrument ``f"{symbol}-SWAP"``.
    """

    def __init__(self,
                 api_key: str,
                 secret_key: str,
                 passphrase: str,
                 sandbox: bool = True,
                 symbol: str = "BTC-USDT",
                 position_size: float = 0.001):
        """Initialise the crypto quant trader.

        Args:
            api_key: OKX API key.
            secret_key: OKX secret key.
            passphrase: OKX API passphrase.
            sandbox: Whether to use the sandbox (demo) environment; testing
                there first is recommended.
            symbol: Spot instrument id such as ``"BTC-USDT"``. The swap
                instrument is derived as ``f"{symbol}-SWAP"`` and the base
                currency as the part before the first ``-``.
            position_size: Default position size; stored on the instance
                (not read by any method of this class).
        """
        self.api_key = api_key
        self.secret_key = secret_key
        self.passphrase = passphrase
        flag = "1" if sandbox else "0"  # "0": live environment, "1": sandbox environment
        # Every OKX client is built from the same credentials and environment flag.
        credentials = dict(
            api_key=api_key, api_secret_key=secret_key, passphrase=passphrase,
            flag=flag,
        )
        self.account_api = Account.AccountAPI(**credentials)
        self.trade_api = Trade.TradeAPI(**credentials)
        self.market_api = Market.MarketAPI(**credentials)
        self.public_api = Public.PublicAPI(**credentials)
        self.symbol = symbol
        self.symbol_swap = f'{symbol}-SWAP'
        self.symbol_prefix = symbol.split('-')[0]
        self.position_size = position_size

    @staticmethod
    def _to_float(value, default: float = 0.0) -> float:
        """Convert an OKX numeric field to float.

        OKX sends numbers as strings and may send ``""`` for fields that do
        not apply; ``None`` and ``""`` both map to *default*.
        """
        if value is None or value == "":
            return default
        return float(value)

    @staticmethod
    def _candles_to_frame(rows) -> pd.DataFrame:
        """Build a DataFrame from raw OKX candlestick rows.

        Each row is expected to carry 9 fields in the order of the columns
        below. OHLC and volume are converted to numbers (invalid -> NaN) and
        the millisecond epoch timestamp to datetimes (invalid -> NaT). The
        remaining columns are left as returned by the API.
        """
        df = pd.DataFrame(rows, columns=[
            'timestamp', 'open', 'high', 'low', 'close',
            # NOTE: "volCCyQuote" differs in casing from OKX's field name
            # (volCcyQuote); kept as-is because callers may index it.
            'volume', 'volCcy', "volCCyQuote", "confirm"
        ])
        for col in ['open', 'high', 'low', 'close', 'volume']:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        # Parse via to_numeric (int64) rather than astype(int): on platforms
        # where int is 32-bit, ms epoch values overflow and the cast raises.
        df['timestamp'] = pd.to_datetime(
            pd.to_numeric(df['timestamp'], errors='coerce'),
            unit='ms', errors='coerce',
        )
        return df

    def get_account_balance(self) -> dict:
        """Fetch available balances for USDT and the pair's base currency.

        Returns:
            Mapping of currency code -> available balance (``availBal``) for
            the tracked currencies found in the response; currencies absent
            from the response are absent from the mapping. Returns ``{}``
            when the request fails or raises.
        """
        # NOTE(review): self.symbol_swap is an instrument id, not a currency
        # code, so it presumably never matches a `ccy`; kept for compatibility.
        tracked = ('USDT', self.symbol_prefix, self.symbol_swap)
        try:
            search_result = self.account_api.get_account_balance()
            if search_result.get('code') != '0':
                logger.error(f"获取余额失败: {search_result}")
                return {}
            result = {}
            for balance in search_result.get('data', []):
                for detail in balance.get('details', []):
                    ccy = detail.get('ccy')
                    if ccy not in tracked:
                        continue
                    logger.info(f"{ccy}余额: {detail.get('availBal')}")
                    result[ccy] = self._to_float(detail.get('availBal'))
            return result
        except Exception as e:
            logger.error(f"获取余额异常: {e}")
            return {}

    def get_current_price(self, symbol: Optional[str] = None) -> Optional[float]:
        """Return the last traded price of *symbol* (default ``self.symbol``).

        Returns ``None`` if the request fails, the ticker payload has no
        ``last`` field, or an exception is raised.
        """
        try:
            if symbol is None:
                symbol = self.symbol
                symbol_prefix = self.symbol_prefix
            else:
                symbol_prefix = symbol.split('-')[0]
            result = self.market_api.get_ticker(instId=symbol)
            if result.get('code') != '0':
                logger.error(f"获取价格失败: {result}")
                return None
            data = result.get('data', [])
            if not data or 'last' not in data[0]:
                logger.error(f"ticker数据格式异常: {data}")
                return None
            price = float(data[0]['last'])
            logger.info(f"当前{symbol_prefix}价格: ${price:,.2f}")
            return price
        except Exception as e:
            logger.error(f"获取价格异常: {e}")
            return None

    def get_kline_data(self, symbol: Optional[str] = None, after: Optional[str] = None,
                       before: Optional[str] = None, bar: str = '1m',
                       limit: int = 100) -> Optional[pd.DataFrame]:
        """Fetch candlestick (K-line) data as a DataFrame.

        Args:
            symbol: Instrument id; defaults to ``self.symbol``.
            after: Optional pagination cursor forwarded to OKX ``after``.
            before: Optional pagination cursor forwarded to OKX ``before``.
            bar: Candle granularity, e.g. ``'1m'``.
            limit: Number of candles to request; ``0`` sends an empty value
                so the API default applies.

        Returns:
            A DataFrame (see ``_candles_to_frame``), or ``None`` when the
            request fails, returns no rows, or raises.
        """
        if symbol is None:
            symbol = self.symbol
        try:
            result = self.market_api.get_candlesticks(
                instId=symbol,
                # An empty string means "not provided" for these parameters.
                after='' if after is None else after,
                before='' if before is None else before,
                bar=bar,
                limit='' if limit == 0 else str(limit),
            )
            if result.get('code') != '0':
                logger.error(f"获取K线数据失败: {result}")
                return None
            data = result.get('data', [])
            if not data:
                logger.warning("K线数据为空")
                return None
            return self._candles_to_frame(data)
        except Exception as e:
            logger.error(f"获取K线数据异常: {e}")
            return None

    def place_market_order(self, side: str, size: float) -> Optional[str]:
        """Place a spot market order on ``self.symbol`` in cash mode.

        Args:
            side: ``'buy'`` or ``'sell'``.
            size: Order size in base currency. A sell is sent with this size
                after checking the base-currency balance. A buy is raised to
                the instrument's ``minSz`` if smaller, converted to a USDT
                amount at the last price, and that USDT amount is sent as
                ``sz``.

        Returns:
            The OKX order id, or ``None`` if validation or the request failed.
        """
        if side not in ('buy', 'sell'):
            logger.error(f"不支持的下单方向: {side}")
            return None
        balance = self.get_account_balance()
        try:
            if side == 'sell':
                # Missing entry means no balance reported (or the lookup failed and was logged).
                symbol_balance = balance.get(self.symbol_prefix, 0.0)
                if symbol_balance < size:
                    logger.error(f"{self.symbol_prefix}余额不足,目前余额: {symbol_balance}")
                    return None
                result = self.trade_api.place_order(
                    instId=self.symbol,
                    tdMode='cash',
                    side=side,
                    ordType='market',
                    sz=str(size)
                )
                if result.get('code') == '0':
                    logger.info(f"下单成功: {side} {size} {self.symbol_prefix}")
                    return result['data'][0]['ordId']
                logger.error(f"下单失败: {result}")
                return None

            # side == 'buy'
            instrument_result = self.public_api.get_instruments(instType="SPOT", instId=self.symbol)
            instrument_data = instrument_result.get("data", [])
            if not instrument_data:
                logger.error(f"未获取到合约信息: {instrument_result}")
                return None
            min_sz = self._to_float(instrument_data[0].get("minSz"))
            if size < min_sz:
                size = min_sz
            last_price = self.get_current_price()
            if last_price is None:
                return None  # get_current_price has already logged the reason
            usdt_amount = float(last_price * size)
            usdt_balance = balance.get('USDT', 0.0)
            if usdt_balance < usdt_amount:
                logger.error(f"USDT余额不足目前余额: {usdt_balance}")
                return None
            result = self.trade_api.place_order(
                instId=self.symbol,
                tdMode="cash",
                side=side,
                ordType="market",
                sz=str(usdt_amount)
            )
            if result.get('code') == '0':
                logger.info(f"下单成功: {side} {usdt_amount} USDT")
                return result['data'][0]['ordId']
            logger.error(f"下单失败: {result}")
            return None
        except Exception as e:
            logger.error(f"下单异常: {e}")
            return None

    def set_leverage(self, leverage: int = 1, mgn_mode: str = "cross",
                     ccy: str = "USDT", posSide: str = "short") -> bool:
        """Set the leverage multiplier for ``self.symbol_swap``.

        Returns:
            ``True`` if OKX accepted the request, ``False`` if it was
            rejected or the request raised.
        """
        try:
            result = self.account_api.set_leverage(
                lever=str(leverage),
                mgnMode=mgn_mode,
                instId=self.symbol_swap,
                ccy=ccy,
                posSide=posSide
            )
        except Exception as e:
            logger.error(f"设置杠杆异常: {e}")
            return False
        ok = result.get("code") == "0"
        if ok:
            logger.info(f"设置杠杆倍数 {leverage}x 成功")
        else:
            logger.error(f"设置杠杆失败: {result.get('msg', result)}")
        return ok

    def calculate_margin(self, quantity: int = 10, leverage: int = 1, slot: float = 0.01,
                         buffer_ratio: float = 0.3) -> Optional[tuple]:
        """Estimate the margin needed to open *quantity* swap contracts.

        Args:
            quantity: Number of contracts.
            leverage: Leverage multiplier; must be positive.
            slot: Base-currency amount per contract (default 0.01, i.e. 0.01
                BTC per contract for the default pair).
            buffer_ratio: Extra fraction of margin added as a safety buffer.

        Returns:
            ``(recommended_margin, price)`` where *price* is the swap's last
            price, or ``None`` if the price is unavailable or *leverage* is
            not positive.
        """
        if leverage <= 0:
            logger.error(f"杠杆倍数必须为正数: {leverage}")
            return None
        price = self.get_current_price(self.symbol_swap)
        if not price:
            return None
        contract_value = quantity * slot * price
        initial_margin = contract_value / leverage
        recommended_margin = initial_margin * (1 + buffer_ratio)
        logger.info(f"开仓{self.symbol_swap}价格: {price:.2f} USDT")
        logger.info(f"合约总价值: {contract_value:.2f} USDT")
        logger.info(f"初始保证金: {initial_margin:.2f} USDT")
        logger.info(f"推荐保证金 (含 {buffer_ratio*100}% 缓冲): {recommended_margin:.2f} USDT")
        return recommended_margin, price

    def place_short_order(self, td_mode: str = "cross", quantity: int = 10, leverage: int = 1,
                          slot: float = 0.01, buffer_ratio: float = 0.3):
        """Open a short position (market sell) on ``self.symbol_swap``.

        Checks the USDT balance against ``calculate_margin``, sets leverage,
        then sends a market sell with ``posSide="short"``.

        Returns:
            ``(order_id, entry_price)`` on success, ``(None, None)`` otherwise.
        """
        margin_data = self.calculate_margin(quantity, leverage, slot, buffer_ratio)
        if not margin_data:
            logger.error("无法计算保证金,终止下单")
            return None, None
        required_margin, entry_price = margin_data
        balance = self.get_account_balance()
        avail_bal = balance.get('USDT')
        if avail_bal is None or avail_bal < required_margin:
            logger.error(f"保证金不足,需至少 {required_margin:.2f} USDT当前余额: {avail_bal}")
            return None, None
        if not self.set_leverage(leverage, mgn_mode=td_mode, ccy="USDT", posSide="short"):
            return None, None
        order_data = {
            "instId": self.symbol_swap,
            "tdMode": td_mode,
            "ccy": "USDT",
            "side": "sell",
            "posSide": "short",
            "ordType": "market",
            "sz": str(quantity),
        }
        try:
            result = self.trade_api.place_order(**order_data)
        except Exception as e:
            logger.error(f"开空单异常: {e}")
            return None, None
        if result.get("code") == "0":
            logger.info(f"开空单成功订单ID: {result['data'][0]['ordId']}")
            return result["data"][0]["ordId"], entry_price
        logger.error(f"开空单失败: {result.get('msg', result)}")
        return None, None

    def close_short_order(self, td_mode: str = "cross", quantity: float = 10) -> bool:
        """Close (buy back) *quantity* contracts of the short position.

        Returns:
            ``True`` if OKX accepted the order, ``False`` if it was rejected
            or the request raised.
        """
        order_data = {
            "instId": self.symbol_swap,
            "tdMode": td_mode,
            "ccy": "USDT",
            "side": "buy",
            "posSide": "short",
            "ordType": "market",
            "sz": str(quantity),
        }
        try:
            result = self.trade_api.place_order(**order_data)
        except Exception as e:
            logger.error(f"平空单异常: {e}")
            return False
        if result.get("code") == "0":
            logger.info(f"平空单成功订单ID: {result['data'][0]['ordId']}")
            return True
        logger.error(f"平空单失败: {result.get('msg', result)}")
        return False

    def get_minimun_order_size(self) -> None:
        """Log the spot instrument's minimum order size and size increment.

        Logs ``minSz`` and ``lotSz`` for ``self.symbol``. Nothing is returned;
        failures are logged. (The method name keeps its original spelling so
        existing callers keep working.)
        """
        try:
            result = self.public_api.get_instruments(instType="SPOT", instId=self.symbol)
            if result.get("code") != "0":
                logger.error(f"错误: {result.get('msg', result)}")
                return
            data = result.get("data") or []
            if not data:
                logger.error(f"未获取到合约信息: {result}")
                return
            instrument = data[0]
            min_sz = self._to_float(instrument.get("minSz"))
            lot_sz = self._to_float(instrument.get("lotSz"))
            logger.info(f"最小交易量 (minSz): {min_sz} {self.symbol_prefix}")
            logger.info(f"交易量精度 (lotSz): {lot_sz} {self.symbol_prefix}")
        except Exception as e:
            logger.error(f"异常: {str(e)}")