# Listing metadata: 314 lines, 13 KiB, Python.
import logging
from decimal import Decimal
from typing import Optional

import okx.Account as Account
import okx.MarketData as Market
import okx.PublicData as Public
import okx.Trade as Trade
import pandas as pd

# Configure the root logger at import time: INFO level, each record prefixed
# with its timestamp and level name.
logging.basicConfig(
    format='%(asctime)s %(levelname)s: %(message)s',
    level=logging.INFO,
)

class QuantTrader:
    """Trading helper around the OKX python SDK for one spot pair and its swap.

    The spot instrument is ``symbol`` (e.g. ``"BTC-USDT"``); the perpetual swap
    instrument is derived as ``f"{symbol}-SWAP"``. The query/order methods are
    written to log failures and return a sentinel (``None``, ``False``, ``{}``
    or ``(None, None)``) rather than raise.
    """

    def __init__(self,
                 api_key: str,
                 secret_key: str,
                 passphrase: str,
                 sandbox: bool = True,
                 symbol: str = "BTC-USDT",
                 position_size: float = 0.001):
        """Initialise the trader and its OKX REST API clients.

        Args:
            api_key: OKX API key.
            secret_key: OKX secret key.
            passphrase: OKX API passphrase.
            sandbox: When True use the demo-trading environment (SDK flag
                ``"1"``), otherwise live trading (flag ``"0"``). Test in the
                sandbox first.
            symbol: Spot instrument id in ``"<BASE>-<QUOTE>"`` form.
            position_size: Default order size kept on the instance; no method
                of this class reads it.
        """
        self.api_key = api_key
        self.secret_key = secret_key
        self.passphrase = passphrase

        # OKX SDK flag: "0" = live trading, "1" = demo (sandbox) trading.
        flag = "1" if sandbox else "0"
        client_kwargs = dict(api_key=api_key, api_secret_key=secret_key,
                             passphrase=passphrase, flag=flag)
        self.account_api = Account.AccountAPI(**client_kwargs)
        self.trade_api = Trade.TradeAPI(**client_kwargs)
        self.market_api = Market.MarketAPI(**client_kwargs)
        self.public_api = Public.PublicAPI(**client_kwargs)

        self.symbol = symbol
        self.symbol_swap = f'{symbol}-SWAP'
        # Base currency, e.g. "BTC" for "BTC-USDT".
        self.symbol_prefix = symbol.split('-')[0]
        self.position_size = position_size

    @staticmethod
    def _format_sz(value) -> str:
        """Render a numeric order size as a plain fixed-point string.

        ``str(float)`` switches to scientific notation below 1e-4
        (``str(0.00001) == '1e-05'``), which is not a plain decimal number and
        is presumably rejected by the exchange. Going through ``Decimal(str(x))``
        keeps the same shortest digits but always in fixed-point form.
        """
        return format(Decimal(str(value)), 'f')

    @staticmethod
    def _order_error_message(result: dict) -> str:
        """Return the most specific error text from an OKX order response.

        For order endpoints the per-order reason lives in ``data[0]['sMsg']``
        while the top-level ``msg`` is often generic or empty; fall back to
        ``msg`` and finally to the whole response.
        """
        data = result.get('data') or []
        if data and isinstance(data[0], dict) and data[0].get('sMsg'):
            return str(data[0]['sMsg'])
        return str(result.get('msg') or result)

    def get_account_balance(self) -> dict:
        """Return available balances for USDT and the base currency.

        Returns:
            A dict mapping currency code (``'USDT'``, ``self.symbol_prefix`` and,
            if ever reported, ``self.symbol_swap``) to its ``availBal`` as a
            float. Currencies absent from the response are absent from the
            dict; an empty/missing ``availBal`` counts as 0.0. Returns ``{}``
            on an API error or exception.
        """
        # NOTE(review): OKX balance details are keyed by currency, so the
        # swap instrument id probably never matches; kept for compatibility.
        tracked = ('USDT', self.symbol_prefix, self.symbol_swap)
        try:
            response = self.account_api.get_account_balance()
            if response.get('code') != '0':
                logging.error(f"获取余额失败: {response}")
                return {}
            balances = {}
            for account in response.get('data', []):
                for detail in account.get('details', []):
                    ccy = detail.get('ccy')
                    if ccy not in tracked:
                        continue
                    avail = detail.get('availBal')
                    logging.info(f"{ccy}余额: {avail}")
                    # OKX may send "" for an empty field; treat it as zero
                    # instead of letting float("") abort the whole lookup.
                    balances[ccy] = float(avail or 0)
            return balances
        except Exception as e:
            logging.error(f"获取余额异常: {e}")
            return {}

    def get_current_price(self, symbol: Optional[str] = None) -> Optional[float]:
        """Return the last traded price of ``symbol`` (default ``self.symbol``).

        Returns None (after logging) on an API error, a ticker without a
        ``last`` field, or any exception.
        """
        try:
            if symbol is None:
                symbol = self.symbol
                symbol_prefix = self.symbol_prefix
            else:
                symbol_prefix = symbol.split('-')[0]
            result = self.market_api.get_ticker(instId=symbol)
            if result.get('code') != '0':
                logging.error(f"获取价格失败: {result}")
                return None
            data = result.get('data', [])
            if not data or 'last' not in data[0]:
                logging.error(f"ticker数据格式异常: {data}")
                return None
            price = float(data[0]['last'])
            logging.info(f"当前{symbol_prefix}价格: ${price:,.2f}")
            return price
        except Exception as e:
            logging.error(f"获取价格异常: {e}")
            return None

    def get_kline_data(self, symbol: Optional[str] = None, after: Optional[str] = None,
                       before: Optional[str] = None, bar: str = '1m',
                       limit: int = 100) -> Optional[pd.DataFrame]:
        """Fetch candlesticks for ``symbol`` (default ``self.symbol``).

        Args:
            symbol: Instrument id.
            after: Optional pagination cursor forwarded to OKX (None = unset).
            before: Optional pagination cursor forwarded to OKX (None = unset).
            bar: Candle size, e.g. ``'1m'``.
            limit: Number of candles; ``0`` means "use the API default".

        Returns:
            A DataFrame with columns timestamp, open, high, low, close, volume,
            volCcy, volCCyQuote, confirm, rows in the order the API returned
            them. open..volume are numeric and timestamp is parsed from epoch
            milliseconds. Returns None if the request failed or had no rows.
        """
        if symbol is None:
            symbol = self.symbol
        params = {
            'instId': symbol,
            'bar': bar,
            # Previously limit == 0 meant "API default" only when a cursor
            # was given and was sent literally as "0" otherwise; treat it
            # uniformly.
            'limit': '' if limit == 0 else str(limit),
        }
        if after is not None or before is not None:
            params['after'] = after if after is not None else ''
            params['before'] = before if before is not None else ''
        try:
            result = self.market_api.get_candlesticks(**params)
            if result.get('code') != '0':
                logging.error(f"获取K线数据失败: {result}")
                return None
            data = result.get('data', [])
            if not data:
                logging.warning("K线数据为空")
                return None
            df = pd.DataFrame(data, columns=[
                'timestamp', 'open', 'high', 'low', 'close',
                'volume', 'volCcy', "volCCyQuote", "confirm"
            ])
            for col in ['open', 'high', 'low', 'close', 'volume']:
                df[col] = pd.to_numeric(df[col], errors='coerce')
            df['timestamp'] = pd.to_datetime(df['timestamp'].astype(int), unit='ms', errors='coerce')
            return df
        except Exception as e:
            logging.error(f"获取K线数据异常: {e}")
            return None

    def place_market_order(self, side: str, size: float) -> Optional[str]:
        """Place a spot market order on ``self.symbol`` in cash mode.

        Args:
            side: ``'buy'`` or ``'sell'``.
            size: Order size in base currency. A sell sends it as-is after
                checking the base-currency balance. A buy raises it to the
                instrument's ``minSz`` if smaller, converts it to a quote
                (USDT) amount at the last price, checks the USDT balance and
                sends that amount as ``sz``.

        Returns:
            The OKX order id on success, otherwise None.
        """
        if side not in ('buy', 'sell'):
            logging.error(f"不支持的下单方向: {side}")
            return None

        balance = self.get_account_balance()
        try:
            if side == 'sell':
                available = balance.get(self.symbol_prefix)
                # A missing entry means no usable balance; report it as
                # insufficient rather than failing on None < float.
                if available is None or available < size:
                    logging.error(f"{self.symbol_prefix}余额不足,目前余额: {available}")
                    return None
                sz = self._format_sz(size)
                success_msg = f"下单成功: {side} {size} {self.symbol_prefix}"
            else:
                instrument_result = self.public_api.get_instruments(instType="SPOT", instId=self.symbol)
                instrument_data = instrument_result.get("data", [])
                if not instrument_data:
                    logging.error(f"未获取到合约信息: {instrument_result}")
                    return None
                min_sz = float(instrument_data[0].get("minSz") or 0)
                size = max(size, min_sz)

                last_price = self.get_current_price(self.symbol)
                if last_price is None:
                    logging.error("无法获取最新价格,终止下单")
                    return None
                usdt_amount = float(last_price * size)

                available = balance.get('USDT')
                if available is None or available < usdt_amount:
                    logging.error(f"USDT余额不足,目前余额: {available}")
                    return None
                # Market buys are sized in quote currency here, so sz is the
                # USDT amount, not the base size.
                sz = self._format_sz(usdt_amount)
                success_msg = f"下单成功: {side} {usdt_amount} USDT"

            result = self.trade_api.place_order(
                instId=self.symbol,
                tdMode='cash',
                side=side,
                ordType='market',
                sz=sz,
            )
            if result.get('code') == '0':
                logging.info(success_msg)
                return result['data'][0]['ordId']
            logging.error(f"下单失败: {result}")
            return None
        except Exception as e:
            logging.error(f"下单异常: {e}")
            return None

    def set_leverage(self, leverage: int = 1, mgn_mode: str = "cross", ccy: str = "USDT",
                     posSide: str = "short") -> bool:
        """Set the leverage for ``self.symbol_swap``.

        Args:
            leverage: Leverage multiple, sent as a string.
            mgn_mode: Margin mode forwarded as ``mgnMode``.
            ccy: Currency forwarded as ``ccy``.
            posSide: Position side forwarded as ``posSide``.

        Returns:
            True if OKX answered with code ``"0"``; False on any other code or
            if the request raised.
        """
        try:
            result = self.account_api.set_leverage(
                lever=str(leverage),
                mgnMode=mgn_mode,
                instId=self.symbol_swap,
                ccy=ccy,
                posSide=posSide,
            )
        except Exception as e:
            logging.error(f"设置杠杆异常: {e}")
            return False
        ok = result.get("code") == "0"
        if ok:
            logging.info(f"设置杠杆倍数 {leverage}x 成功")
        else:
            logging.error(f"设置杠杆失败: {result.get('msg', result)}")
        return ok

    def calculate_margin(self, quantity: int = 10, leverage: int = 1, slot: float = 0.01,
                         buffer_ratio: float = 0.3) -> Optional[tuple]:
        """Estimate the margin needed to open ``quantity`` swap contracts.

        Args:
            quantity: Number of contracts.
            leverage: Leverage multiple; must be positive.
            slot: Base-currency amount per contract (the default assumes
                0.01 BTC per contract).
            buffer_ratio: Extra fraction added on top of the initial margin.

        Returns:
            ``(recommended_margin, price)`` in USDT, where
            ``recommended_margin = quantity * slot * price / leverage *
            (1 + buffer_ratio)``; None if leverage is not positive or the
            swap price could not be fetched.
        """
        if leverage <= 0:
            logging.error(f"杠杆倍数必须为正数: {leverage}")
            return None
        price = self.get_current_price(self.symbol_swap)
        if not price:
            return None
        contract_value = quantity * slot * price
        initial_margin = contract_value / leverage
        recommended_margin = initial_margin * (1 + buffer_ratio)
        logging.info(f"开仓{self.symbol_swap}价格: {price:.2f} USDT")
        logging.info(f"合约总价值: {contract_value:.2f} USDT")
        logging.info(f"初始保证金: {initial_margin:.2f} USDT")
        logging.info(f"推荐保证金 (含 {buffer_ratio*100}% 缓冲): {recommended_margin:.2f} USDT")
        return recommended_margin, price

    def place_short_order(self, td_mode: str = "cross", quantity: int = 10, leverage: int = 1,
                          slot: float = 0.01, buffer_ratio: float = 0.3) -> tuple:
        """Open a short position on ``self.symbol_swap`` with a market sell.

        The method estimates margin via ``calculate_margin``, checks the USDT
        balance, sets leverage (``posSide="short"``) and then submits the order.

        Args:
            td_mode: Trade mode; also used as the leverage margin mode.
            quantity: Number of contracts, sent as ``sz``.
            leverage: Leverage multiple (see ``calculate_margin``).
            slot: Base amount per contract (see ``calculate_margin``).
            buffer_ratio: Margin buffer (see ``calculate_margin``).

        Returns:
            ``(order_id, entry_price)`` on success, where ``entry_price`` is
            the last price fetched before ordering (not the fill price);
            ``(None, None)`` otherwise.
        """
        margin_data = self.calculate_margin(quantity, leverage, slot, buffer_ratio)
        if not margin_data:
            logging.error("无法计算保证金,终止下单")
            return None, None
        required_margin, entry_price = margin_data

        avail_bal = self.get_account_balance().get('USDT')
        if avail_bal is None or avail_bal < required_margin:
            logging.error(f"保证金不足,需至少 {required_margin:.2f} USDT,当前余额: {avail_bal}")
            return None, None

        if not self.set_leverage(leverage, mgn_mode=td_mode, ccy="USDT", posSide="short"):
            return None, None

        order_data = {
            "instId": self.symbol_swap,
            "tdMode": td_mode,
            "ccy": "USDT",
            "side": "sell",
            "posSide": "short",
            "ordType": "market",
            "sz": self._format_sz(quantity),
        }
        try:
            result = self.trade_api.place_order(**order_data)
        except Exception as e:
            logging.error(f"开空单异常: {e}")
            return None, None
        if result.get("code") == "0":
            order_id = result["data"][0]["ordId"]
            logging.info(f"开空单成功,订单ID: {order_id}")
            return order_id, entry_price
        logging.error(f"开空单失败: {self._order_error_message(result)}")
        return None, None

    def close_short_order(self, td_mode: str = "cross", quantity: float = 10) -> bool:
        """Close ``quantity`` contracts of a short swap position (market buy).

        Returns:
            True if OKX accepted the order (code ``"0"``), False otherwise,
            including when the request raised.
        """
        order_data = {
            "instId": self.symbol_swap,
            "tdMode": td_mode,
            "ccy": "USDT",
            "side": "buy",
            "posSide": "short",
            "ordType": "market",
            "sz": self._format_sz(quantity),
        }
        try:
            result = self.trade_api.place_order(**order_data)
        except Exception as e:
            logging.error(f"平空单异常: {e}")
            return False
        if result.get("code") == "0":
            logging.info(f"平空单成功,订单ID: {result['data'][0]['ordId']}")
            return True
        logging.error(f"平空单失败: {self._order_error_message(result)}")
        return False

    def get_minimun_order_size(self) -> Optional[tuple]:
        """Log and return the spot instrument's minimum size and size step.

        The method name keeps its original spelling for existing callers;
        ``get_minimum_order_size`` is a correctly spelled alias.

        Returns:
            ``(min_sz, lot_sz)`` as floats in base currency, or None if the
            request failed, returned no instrument, or raised.
        """
        try:
            result = self.public_api.get_instruments(instType="SPOT", instId=self.symbol)
            if result.get("code") != "0":
                logging.error(f"错误: {result.get('msg', result)}")
                return None
            data = result.get("data") or []
            if not data:
                # Previously data[0] raised IndexError on an empty list.
                logging.error(f"错误: {result}")
                return None
            instrument = data[0]
            min_sz = float(instrument.get("minSz") or 0)
            lot_sz = float(instrument.get("lotSz") or 0)
            logging.info(f"最小交易量 (minSz): {min_sz} {self.symbol_prefix}")
            logging.info(f"交易量精度 (lotSz): {lot_sz} {self.symbol_prefix}")
            return min_sz, lot_sz
        except Exception as e:
            logging.error(f"异常: {str(e)}")
            return None

    # Correctly spelled alias of get_minimun_order_size.
    get_minimum_order_size = get_minimun_order_size