import logging
from decimal import Decimal
from typing import Dict, Optional, Tuple

import okx.Account as Account
import okx.MarketData as Market
import okx.PublicData as Public
import okx.Trade as Trade
import pandas as pd

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')


class QuantTrader:
    """Spot and perpetual-swap trading helper built on the ``okx`` Python SDK.

    Wraps the SDK's Account, Trade, MarketData and PublicData clients for one
    instrument: ``symbol`` (spot, e.g. ``BTC-USDT``) and the derived swap id
    ``symbol_swap`` (``<symbol>-SWAP``). Most methods log failures and return
    ``None`` / ``{}`` / ``False`` rather than raising.
    """

    def __init__(self, api_key: str, secret_key: str, passphrase: str, sandbox: bool = True,
                 symbol: str = "BTC-USDT", position_size: float = 0.001):
        """Create the OKX API clients for one trading pair.

        Args:
            api_key: OKX API key.
            secret_key: OKX secret key.
            passphrase: OKX API passphrase.
            sandbox: Use the sandbox environment when True (recommended for
                testing first); live trading when False.
            symbol: Spot instrument id, e.g. ``"BTC-USDT"``. The swap id is
                derived as ``f"{symbol}-SWAP"`` and the base currency as the
                part before the first ``-``.
            position_size: Stored as ``self.position_size``; not read by any
                method of this class.
        """
        self.api_key = api_key
        self.secret_key = secret_key
        self.passphrase = passphrase

        flag = "1" if sandbox else "0"  # SDK flag: "0" = live trading, "1" = sandbox
        credentials = dict(api_key=api_key, api_secret_key=secret_key, passphrase=passphrase, flag=flag)
        self.account_api = Account.AccountAPI(**credentials)
        self.trade_api = Trade.TradeAPI(**credentials)
        self.market_api = Market.MarketAPI(**credentials)
        self.public_api = Public.PublicAPI(**credentials)

        self.symbol = symbol
        self.symbol_swap = f'{symbol}-SWAP'
        self.symbol_prefix = symbol.split('-')[0]
        self.position_size = position_size

    @staticmethod
    def _to_float(value, default: float = 0.0) -> float:
        """Convert an API numeric string to float; missing or empty values become ``default``."""
        if value is None or value == '':
            return default
        return float(value)

    @staticmethod
    def _format_size(value: float) -> str:
        """Render an order size as plain decimal text.

        ``str(0.00001)`` is ``'1e-05'``. Sizes go to the API as strings, so
        scientific notation is avoided here. Trailing zeros are dropped
        (``10.0`` -> ``'10'``).
        """
        return format(Decimal(str(value)).normalize(), 'f')

    def get_account_balance(self) -> Dict[str, float]:
        """Return available balances for the currencies this trader tracks.

        Scans every ``details`` entry of the account-balance response and
        collects ``availBal`` for ``USDT``, the base currency
        (``self.symbol_prefix``) and ``self.symbol_swap``.

        Returns:
            Mapping of currency -> available balance. Currencies absent from
            the response are absent from the mapping. Returns ``{}`` on API
            error or exception.
        """
        # NOTE(review): ``ccy`` values look like currency codes, so the swap id
        # probably never matches; kept so the returned keys stay the same.
        tracked = ('USDT', self.symbol_prefix, self.symbol_swap)
        try:
            search_result = self.account_api.get_account_balance()
            if search_result.get('code') != '0':
                logging.error(f"获取余额失败: {search_result}")
                return {}
            result: Dict[str, float] = {}
            for balance in search_result.get('data', []):
                for detail in balance.get('details', []):
                    ccy = detail.get('ccy')
                    if ccy in tracked:
                        logging.info(f"{ccy}余额: {detail.get('availBal')}")
                        result[ccy] = self._to_float(detail.get('availBal'))
            return result
        except Exception as e:
            logging.error(f"获取余额异常: {e}")
            return {}

    def get_current_price(self, symbol: Optional[str] = None) -> Optional[float]:
        """Return the last traded price of ``symbol`` (default: ``self.symbol``).

        Returns None, after logging the reason, on API error, malformed
        ticker data or exception.
        """
        try:
            if symbol is None:
                symbol = self.symbol
                symbol_prefix = self.symbol_prefix
            else:
                symbol_prefix = symbol.split('-')[0]
            result = self.market_api.get_ticker(instId=symbol)
            if result.get('code') != '0':
                logging.error(f"获取价格失败: {result}")
                return None
            data = result.get('data', [])
            if not data or 'last' not in data[0]:
                logging.error(f"ticker数据格式异常: {data}")
                return None
            price = float(data[0]['last'])
            logging.info(f"当前{symbol_prefix}价格: ${price:,.2f}")
            return price
        except Exception as e:
            logging.error(f"获取价格异常: {e}")
            return None

    def get_kline_data(self, bar: str = '1m', limit: int = 100) -> Optional[pd.DataFrame]:
        """Fetch candlesticks for ``self.symbol`` as a DataFrame.

        Rows keep the order returned by the API (not re-sorted here). The
        ``open``/``high``/``low``/``close``/``volume`` columns are numeric.
        ``timestamp`` is converted from epoch milliseconds to datetime;
        unparsable values become NaN/NaT.

        Returns:
            The DataFrame, or None on empty data, API error or exception.
        """
        try:
            result = self.market_api.get_candlesticks(instId=self.symbol, bar=bar, limit=str(limit))
            if result.get('code') != '0':
                logging.error(f"获取K线数据失败: {result}")
                return None
            data = result.get('data', [])
            if not data:
                logging.warning("K线数据为空")
                return None
            # NOTE: "volCCyQuote" capitalisation is kept as-is so existing
            # callers that read this column keep working.
            df = pd.DataFrame(data, columns=[
                'timestamp', 'open', 'high', 'low', 'close', 'volume', 'volCcy', "volCCyQuote", "confirm"
            ])
            for col in ['open', 'high', 'low', 'close', 'volume']:
                df[col] = pd.to_numeric(df[col], errors='coerce')
            # to_numeric yields int64 (not platform-sized int, which can be
            # 32-bit and overflow on ms timestamps) and honours errors='coerce'.
            df['timestamp'] = pd.to_datetime(
                pd.to_numeric(df['timestamp'], errors='coerce'), unit='ms', errors='coerce'
            )
            return df
        except Exception as e:
            logging.error(f"获取K线数据异常: {e}")
            return None

    def place_market_order(self, side: str, size: float) -> Optional[str]:
        """Place a spot market order on ``self.symbol``.

        Args:
            side: ``'buy'`` or ``'sell'``.
            size: Amount in the base currency. For buys it is raised to the
                instrument's ``minSz`` if smaller, then converted to a USDT
                amount at the last price.

        Returns:
            The order id, or None if the side is invalid, the balance is
            insufficient, or the API call fails.
        """
        if side not in ('buy', 'sell'):
            logging.error(f"不支持的下单方向: {side}")
            return None

        balance = self.get_account_balance()
        # A currency missing from the balance response means nothing is
        # available; comparing None with '<' would raise TypeError.
        if side == 'sell':
            return self._place_spot_sell(size, balance.get(self.symbol_prefix, 0.0))
        return self._place_spot_buy(size, balance.get('USDT', 0.0))

    def _place_spot_sell(self, size: float, symbol_balance: float) -> Optional[str]:
        """Sell ``size`` base currency at market after checking the available balance."""
        try:
            if symbol_balance < size:
                logging.error(f"{self.symbol_prefix}余额不足,目前余额: {symbol_balance}")
                return None
            result = self.trade_api.place_order(
                instId=self.symbol,
                tdMode='cash',
                side='sell',
                ordType='market',
                sz=self._format_size(size),
            )
            if result.get('code') == '0':
                logging.info(f"下单成功: sell {size} {self.symbol_prefix}")
                return result['data'][0]['ordId']
            logging.error(f"下单失败: {result}")
            return None
        except Exception as e:
            logging.error(f"下单异常: {e}")
            return None

    def _place_spot_buy(self, size: float, usdt_balance: float) -> Optional[str]:
        """Buy at least ``size`` base currency at market, paying the equivalent USDT."""
        try:
            instrument_result = self.public_api.get_instruments(instType="SPOT", instId=self.symbol)
            instrument_data = instrument_result.get("data", [])
            if not instrument_data:
                logging.error(f"未获取到合约信息: {instrument_result}")
                return None
            min_sz = self._to_float(instrument_data[0].get("minSz"))
            size = max(size, min_sz)

            last_price = self.get_current_price(self.symbol)
            if last_price is None:
                return None  # get_current_price has already logged the reason
            usdt_amount = float(last_price * size)
            if usdt_balance < usdt_amount:
                logging.error(f"USDT余额不足,目前余额: {usdt_balance}")
                return None

            # Per OKX docs, spot market buys are sized in the quote currency
            # by default (tgtCcy=quote_ccy), so the USDT amount is sent as sz.
            result = self.trade_api.place_order(
                instId=self.symbol,
                tdMode="cash",
                side="buy",
                ordType="market",
                sz=self._format_size(usdt_amount),
            )
            if result.get('code') == '0':
                logging.info(f"下单成功: buy {usdt_amount} USDT")
                return result['data'][0]['ordId']
            logging.error(f"下单失败: {result}")
            return None
        except Exception as e:
            logging.error(f"下单异常: {e}")
            return None

    def set_leverage(self, leverage: int = 1, mgn_mode: str = "cross", ccy: str = "USDT",
                     posSide: str = "short") -> bool:
        """Set the leverage for ``self.symbol_swap``.

        Returns:
            True when the API reports success (code ``"0"``); False on API
            error or exception (both are logged).
        """
        try:
            result = self.account_api.set_leverage(
                lever=str(leverage),
                mgnMode=mgn_mode,
                instId=self.symbol_swap,
                ccy=ccy,
                posSide=posSide,
            )
        except Exception as e:
            logging.error(f"设置杠杆异常: {e}")
            return False
        if result.get("code") == "0":
            logging.info(f"设置杠杆倍数 {leverage}x 成功")
            return True
        logging.error(f"设置杠杆失败: {result.get('msg', result)}")
        return False

    def calculate_margin(self, quantity: int = 10, leverage: int = 1, slot: float = 0.01,
                         buffer_ratio: float = 0.3) -> Optional[Tuple[float, float]]:
        """Estimate the margin needed to open ``quantity`` swap contracts.

        Args:
            quantity: Number of contracts.
            leverage: Leverage multiple; must be positive.
            slot: Base-currency amount per contract (the original note says
                0.01 BTC per contract).
            buffer_ratio: Extra margin on top of the initial margin, as a
                fraction (0.3 = 30%).

        Returns:
            ``(recommended_margin, price)`` in USDT, or None if the price
            cannot be fetched or ``leverage`` is not positive.
        """
        if leverage <= 0:
            logging.error(f"杠杆倍数必须大于0: {leverage}")
            return None
        price = self.get_current_price(self.symbol_swap)
        if not price:
            return None
        contract_value = quantity * slot * price  # notional = contracts x size per contract x price
        initial_margin = contract_value / leverage
        recommended_margin = initial_margin * (1 + buffer_ratio)
        logging.info(f"开仓{self.symbol_swap}价格: {price:.2f} USDT")
        logging.info(f"合约总价值: {contract_value:.2f} USDT")
        logging.info(f"初始保证金: {initial_margin:.2f} USDT")
        # ':g' avoids printing float noise such as 30.000000000000004.
        logging.info(f"推荐保证金 (含 {buffer_ratio * 100:g}% 缓冲): {recommended_margin:.2f} USDT")
        return recommended_margin, price

    def _submit_short_side_order(self, side: str, td_mode: str, quantity: float,
                                 label: str) -> Optional[str]:
        """Send a market order on the short position side of ``self.symbol_swap``; return its id or None."""
        order_data = {
            "instId": self.symbol_swap,
            "tdMode": td_mode,
            "ccy": "USDT",
            "side": side,
            "posSide": "short",
            "ordType": "market",
            "sz": self._format_size(quantity),
        }
        try:
            result = self.trade_api.place_order(**order_data)
            if result.get("code") == "0":
                order_id = result['data'][0]['ordId']
                logging.info(f"{label}成功,订单ID: {order_id}")
                return order_id
            logging.error(f"{label}失败: {result.get('msg', result)}")
        except Exception as e:
            logging.error(f"{label}异常: {e}")
        return None

    def place_short_order(self, td_mode: str = "cross", quantity: int = 10, leverage: int = 1,
                          slot: float = 0.01,
                          buffer_ratio: float = 0.3) -> Tuple[Optional[str], Optional[float]]:
        """Open a short swap position (market sell on the short side).

        Steps: compute the required margin, check the USDT balance, set the
        leverage, then submit the order.

        Returns:
            ``(order_id, entry_price)`` on success, otherwise ``(None, None)``.
        """
        margin_data = self.calculate_margin(quantity, leverage, slot, buffer_ratio)
        if not margin_data:
            logging.error("无法计算保证金,终止下单")
            return None, None
        required_margin, entry_price = margin_data

        avail_bal = self.get_account_balance().get('USDT')
        if avail_bal is None or avail_bal < required_margin:
            logging.error(f"保证金不足,需至少 {required_margin:.2f} USDT,当前余额: {avail_bal}")
            return None, None

        if not self.set_leverage(leverage, mgn_mode=td_mode, ccy="USDT", posSide="short"):
            return None, None

        order_id = self._submit_short_side_order("sell", td_mode, quantity, "开空单")
        if order_id is None:
            return None, None
        return order_id, entry_price

    def close_short_order(self, td_mode: str = "cross", quantity: float = 10) -> bool:
        """Close a short swap position (market buy on the short side).

        Returns:
            True if the close order was accepted, otherwise False.
        """
        return self._submit_short_side_order("buy", td_mode, quantity, "平空单") is not None

    def get_minimun_order_size(self) -> Optional[Tuple[float, float]]:
        """Log and return the spot instrument's minimum order size and lot size.

        The misspelled method name is kept for backward compatibility.

        Returns:
            ``(minSz, lotSz)`` in the base currency, or None on API error,
            missing instrument data or exception.
        """
        try:
            result = self.public_api.get_instruments(instType="SPOT", instId=self.symbol)
            if result.get("code") != "0":
                logging.error(f"错误: {result.get('msg', result)}")
                return None
            instruments = result.get("data") or []
            if not instruments:
                logging.error(f"未获取到合约信息: {result}")
                return None
            instrument = instruments[0]
            min_sz = self._to_float(instrument.get("minSz"))
            lot_sz = self._to_float(instrument.get("lotSz"))
            logging.info(f"最小交易量 (minSz): {min_sz} {self.symbol_prefix}")
            logging.info(f"交易量精度 (lotSz): {lot_sz} {self.symbol_prefix}")
            return min_sz, lot_sz
        except Exception as e:
            logging.error(f"异常: {str(e)}")
            return None