# 121 lines · 4.6 KiB · Python
import time
|
||
from datetime import datetime, timedelta, timezone
|
||
from core.utils import get_current_date_time
|
||
from typing import Optional
|
||
import pandas as pd
|
||
import okx.MarketData as Market
|
||
from core.utils import timestamp_to_datetime
|
||
from core.db.db_trade_data import DBTradeData
|
||
import core.logger as logging
|
||
|
||
# Module-level logger taken from the project's core.logger module.
# NOTE: `logging` here is core.logger (aliased on import), not the stdlib `logging` module.
logger = logging.logger
|
||
|
||
class TradeData:
    """Fetch OKX trade history for a symbol and persist it via DBTradeData."""

    # Maximum attempts for a single page request before giving up on it.
    _MAX_ATTEMPTS = 4
    # Seconds to wait between failed attempts for the same page.
    _RETRY_DELAY_S = 3
    # Seconds to pause between successive page requests.
    _PAGE_DELAY_S = 0.5

    def __init__(self,
                 api_key: str,
                 secret_key: str,
                 passphrase: str,
                 sandbox: bool = True,
                 db_url: Optional[str] = None):
        """Create the OKX market-data client and the trade-data DB helper.

        :param api_key: OKX API key
        :param secret_key: OKX API secret key
        :param passphrase: OKX API passphrase
        :param sandbox: True for the sandbox environment, False for live trading
        :param db_url: database URL, passed unchanged to DBTradeData
        """
        # OKX flag: "0" = live trading environment, "1" = sandbox environment
        flag = "1" if sandbox else "0"
        self.market_api = Market.MarketAPI(
            api_key=api_key, api_secret_key=secret_key, passphrase=passphrase,
            flag=flag,
        )
        self.db_url = db_url
        self.db_trade_data = DBTradeData(self.db_url)

    def _fetch_page(self, symbol: str, after: int, limit: int) -> Optional[dict]:
        """Request one page of trades older than the ``after`` cursor, with retries.

        An attempt fails if the client raises or returns a falsy response.
        Up to ``_MAX_ATTEMPTS`` attempts are made, sleeping ``_RETRY_DELAY_S``
        seconds between them.

        :return: the raw response from ``market_api.get_history_trades``,
            or None if every attempt failed.
        """
        for attempt in range(1, self._MAX_ATTEMPTS + 1):
            try:
                result = self.market_api.get_history_trades(
                    instId=symbol,  # trading pair, e.g. BTC-USDT
                    type=2,  # per OKX docs: 2 = paginate by timestamp (1 = by tradeId)
                    limit=str(limit),  # OKX returns at most 100 records per request
                    after=str(after),  # ask for records earlier than this cursor
                )
            except Exception as e:
                logger.error(f"请求出错: {e}")
                result = None
            if result:
                return result
            if attempt < self._MAX_ATTEMPTS:
                time.sleep(self._RETRY_DELAY_S)
        return None

    @staticmethod
    def _trades_to_frame(trades: list, start_ts: int, end_ts: int) -> pd.DataFrame:
        """Convert one page of raw trade records into the stored column layout.

        Only trades with ``start_ts <= ts <= end_ts`` are kept.
        Assumes each record has the keys instId, tradeId, px, sz, side and ts
        (values as strings) -- confirm against the OKX response schema.
        """
        df = pd.DataFrame(trades)
        # Explicit int64: plain `int` is 32-bit on Windows with NumPy < 2,
        # which cannot hold millisecond timestamps.
        df["ts"] = df["ts"].astype("int64")
        # .copy() so the column writes below modify a real frame, not a view of the slice.
        df = df[(df["ts"] >= start_ts) & (df["ts"] <= end_ts)].copy()
        df["sz"] = df["sz"].astype(float)
        df["px"] = df["px"].astype(float)
        df = df.rename(columns={"instId": "symbol"})
        df["date_time"] = df["ts"].apply(timestamp_to_datetime)
        df["tradeId"] = df["tradeId"].astype(str)
        df["create_time"] = get_current_date_time()
        return df[["symbol", "ts", "date_time", "tradeId", "side", "sz", "px", "create_time"]]

    def get_history_trades(self, symbol: str, start_ts: int, end_ts: int, limit: int = 100):
        """
        Fetch historical trades for ``symbol`` between ``start_ts`` and ``end_ts`` and store them.

        Pages backwards in time from ``end_ts`` using OKX's ``after`` cursor.
        Paging stops when a page reaches ``start_ts``, the API returns no data,
        returns a non-"0" code, or a page request fails after all retries.
        Each page's new in-range trades are written through
        ``self.db_trade_data.insert_data_to_mysql`` as they arrive. Rows may
        therefore already be persisted even when this method stops early or
        returns None.

        :param symbol: trading pair, e.g. BTC-USDT
        :param start_ts: start time (millisecond timestamp); earlier trades are dropped
        :param end_ts: end time (millisecond timestamp); later trades are dropped
        :param limit: number of trades requested per page
        :return: pd.DataFrame with columns:
            symbol: trading pair (e.g. BTC-USDT)
            ts: trade time as a Unix timestamp in milliseconds, e.g. 1597026383085
            date_time: ``ts`` converted with ``timestamp_to_datetime``
            tradeId: trade ID (str)
            side: trade direction (buy = buy, sell = sell)
            sz: trade size, e.g. amount of BTC (float)
            px: trade price, e.g. in USDT (float)
            create_time: ``get_current_date_time()`` when the page was processed
            Returns None if no trades were found or an unexpected error occurred.
        """
        try:
            frames = []
            seen_trade_ids = set()
            # Start at the newest end of the window; the `after` cursor walks backwards in time.
            after = end_ts

            while True:
                result = self._fetch_page(symbol, after, limit)
                if result is None:
                    # Stop rather than re-processing a stale (or undefined) response.
                    logger.error(
                        f"Giving up on {symbol} page after={after}: "
                        f"no response after {self._MAX_ATTEMPTS} attempts"
                    )
                    break
                if result.get("code") != "0":
                    logger.error(f"Error: {result.get('msg')}")
                    break

                trades = result.get("data") or []
                if not trades:
                    break

                # min/max instead of relying on the page being sorted newest-first.
                page_ts = [int(t["ts"]) for t in trades]
                from_time, to_time = min(page_ts), max(page_ts)
                logger.info(
                    f"获得交易数据,最早时间: {timestamp_to_datetime(from_time)}, "
                    f"最近时间: {timestamp_to_datetime(to_time)}"
                )

                df = self._trades_to_frame(trades, start_ts, end_ts)
                # Drop trades already stored from an overlapping previous page.
                df = df[~df["tradeId"].isin(seen_trade_ids)]
                if not df.empty:
                    seen_trade_ids.update(df["tradeId"])
                    self.db_trade_data.insert_data_to_mysql(df)
                    frames.append(df)

                # This page already reaches back to the start of the window.
                if from_time <= start_ts:
                    break

                # OKX documents `after` as "records earlier than" the cursor, so the
                # oldest ts of this page is the next cursor. Subtracting 1 would skip
                # every trade in that millisecond. Any overlap is removed through
                # seen_trade_ids. The `after - 1` fallback guarantees the cursor
                # always moves backwards, so the loop terminates.
                after = from_time if from_time < after else after - 1
                time.sleep(self._PAGE_DELAY_S)

            if not frames:
                logger.info("No trades found in the specified time range.")
                return None
            return pd.concat(frames, ignore_index=True)

        except Exception as e:
            logger.error(f"获取历史交易数据失败: {e}")
            return None
|
||
|
||
|
||
|
||
|