# crypto_quant/core/trade_data.py

# 121 lines
# 4.6 KiB
# Python
# Raw Blame History

# This file contains ambiguous Unicode characters

# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import time
from datetime import datetime, timedelta
import logging
from typing import Optional
import pandas as pd
import okx.MarketData as Market
from core.utils import datetime_to_timestamp, timestamp_to_datetime
from core.db_trade_data import DBTradeData
# Configure the root logger when this module is imported: INFO level and
# "<time> <LEVEL>: <message>" formatted records.
logging.basicConfig(
    format='%(asctime)s %(levelname)s: %(message)s',
    level=logging.INFO,
)
class TradeData:
    """Download historical trades from the OKX market API and persist them.

    Pages are requested through ``okx.MarketData.MarketAPI`` and every
    non-empty, normalised page is handed to
    ``DBTradeData.insert_data_to_mysql``.
    """

    # Total number of attempts made for a single page request before giving up.
    _MAX_ATTEMPTS = 4
    # Seconds to wait between failed attempts for the same page.
    _RETRY_DELAY_SECONDS = 3
    # Seconds to wait between two successful page requests.
    _PAGE_DELAY_SECONDS = 0.5
    # Column order of the frames that are stored and returned.
    _OUTPUT_COLUMNS = [
        "symbol", "ts", "date_time", "tradeId", "side", "sz", "px", "create_time",
    ]

    def __init__(self,
                 api_key: str,
                 secret_key: str,
                 passphrase: str,
                 sandbox: bool = True,
                 db_url: Optional[str] = None):
        """Create the market API client and the database writer.

        :param api_key: OKX API key.
        :param secret_key: OKX API secret key.
        :param passphrase: OKX API passphrase.
        :param sandbox: when true the client is created with flag "1"
            (sandbox environment), otherwise "0" (live environment).
        :param db_url: database URL forwarded unchanged to ``DBTradeData``.
        """
        flag = "1" if sandbox else "0"  # 0: live environment, 1: sandbox environment
        self.market_api = Market.MarketAPI(
            api_key=api_key, api_secret_key=secret_key, passphrase=passphrase,
            flag=flag
        )
        self.db_url = db_url
        self.db_trade_data = DBTradeData(self.db_url)

    def get_history_trades(self, symbol: str, start_ts: int, end_ts: int,
                           limit: int = 100) -> Optional[pd.DataFrame]:
        """Fetch historical trades for ``symbol`` and store them page by page.

        Paging starts at ``end_ts`` and walks backwards in time until a page
        reaches ``start_ts``, the API returns no data, the API reports an
        error code, or a page cannot be fetched after all retries. Trades
        collected before an early stop are still returned (they have already
        been written to the database).

        :param symbol: instrument id, e.g. ``BTC-USDT``.
        :param start_ts: start of the time range, millisecond timestamp.
        :param end_ts: end of the time range, millisecond timestamp.
        :param limit: number of trades requested per page.
        :return: DataFrame with the columns ``symbol``, ``ts`` (trade time in
            milliseconds), ``date_time``, ``tradeId``, ``side`` (buy/sell),
            ``sz`` (trade size), ``px`` (trade price) and ``create_time``;
            ``None`` when nothing inside the range was found or any
            unexpected error occurred (the error is logged, not raised).
        """
        try:
            frames = []
            received_page = False
            # Start from the latest time; per the original author's note,
            # OKX's ``after`` cursor returns data earlier than the given value.
            after = end_ts
            while True:
                result = self._fetch_trades_page(symbol, after, limit)
                if result is None:
                    # Never fall through with a missing or stale response:
                    # that would re-process (and re-insert) the previous page.
                    logging.error("Giving up: no response after %d attempts",
                                  self._MAX_ATTEMPTS)
                    break
                if result["code"] != "0":
                    logging.error(f"Error: {result['msg']}")
                    break
                trades = result["data"]
                if not trades:
                    break
                received_page = True

                # The first element is treated as the newest trade and the
                # last element as the oldest one of the page.
                from_time = int(trades[-1]["ts"])
                to_time = int(trades[0]["ts"])
                from_date_time = timestamp_to_datetime(from_time)
                to_date_time = timestamp_to_datetime(to_time)
                logging.info(f"获得交易数据,最早时间: {from_date_time} 最近时间: {to_date_time}")

                df = self._build_trades_frame(trades, start_ts, end_ts)
                if not df.empty:
                    # Skip pages whose trades all fall outside the range so
                    # that no empty frame is sent to the database.
                    self.db_trade_data.insert_data_to_mysql(df)
                    frames.append(df)

                # Stop once the oldest trade of this page reaches the range start.
                if from_time <= start_ts:
                    break
                # Move the cursor to fetch earlier data.
                # NOTE(review): if OKX treats ``after`` as exclusive, stepping
                # back by 1 ms skips trades stamped exactly ``from_time - 1``
                # -- confirm against the OKX API documentation.
                after = from_time - 1
                time.sleep(self._PAGE_DELAY_SECONDS)

            if frames:
                return pd.concat(frames)
            if received_page:
                logging.info("No trades found in the specified time range.")
            return None
        except Exception as e:
            # Boundary handler: callers receive None instead of an exception.
            logging.error(f"获取历史交易数据失败: {e}")
            return None

    def _fetch_trades_page(self, symbol: str, after: int, limit: int):
        """Request one page of trades, retrying a bounded number of times.

        Both raised exceptions and empty (falsy) responses count as failed
        attempts. Returns the response, or ``None`` when every attempt failed.
        """
        for attempt in range(1, self._MAX_ATTEMPTS + 1):
            try:
                result = self.market_api.get_history_trades(
                    instId=symbol,      # instrument id, e.g. BTC-USDT
                    type=2,             # presumably "paginate by timestamp" -- confirm in OKX docs
                    limit=str(limit),   # page size
                    after=str(after),   # fetch earlier data
                )
            except Exception as e:
                logging.error(f"请求出错: {e}")
                result = None
            if result:
                return result
            if attempt < self._MAX_ATTEMPTS:
                time.sleep(self._RETRY_DELAY_SECONDS)
        return None

    def _build_trades_frame(self, trades, start_ts: int, end_ts: int) -> pd.DataFrame:
        """Convert one raw page of trades into the normalised, range-filtered frame."""
        df = pd.DataFrame(trades)
        # int64 explicitly: millisecond timestamps do not fit into 32 bits.
        df["ts"] = df["ts"].astype("int64")
        # Keep only trades inside [start_ts, end_ts]; copy so the column
        # assignments below modify an independent frame, not a slice.
        df = df[(df["ts"] >= start_ts) & (df["ts"] <= end_ts)].copy()
        df["sz"] = df["sz"].astype(float)
        df["px"] = df["px"].astype(float)
        df = df.rename(columns={"instId": "symbol"})
        df["date_time"] = df["ts"].apply(timestamp_to_datetime)
        df["tradeId"] = df["tradeId"].astype(str)
        df["create_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return df[self._OUTPUT_COLUMNS]