optimize ORB strategy

This commit is contained in:
blade 2025-09-01 18:01:21 +08:00
parent a8a310ecf0
commit e990db26a6
8 changed files with 177 additions and 65 deletions

View File

@ -73,7 +73,7 @@ OKX_MONITOR_CONFIG = {
US_STOCK_MONITOR_CONFIG = {
"volume_monitor":{
"symbols": ["QQQ", "TQQQ", "MSFT", "AAPL", "GOOG", "NVDA", "META", "AMZN", "TSLA", "AVGO"],
"bars": ["5m"],
"bars": ["5", "15m", "30m", "1H"],
"initial_date": "2015-08-31 00:00:00"
}
}

View File

@ -12,7 +12,7 @@ from openpyxl.drawing.image import Image
import openpyxl
from openpyxl.styles import Font
from PIL import Image as PILImage
from config import OKX_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE
from config import OKX_MONITOR_CONFIG, US_STOCK_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE
from core.db.db_market_data import DBMarketData
from core.db.db_huge_volume_data import DBHugeVolumeData
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
@ -32,7 +32,7 @@ class MaBreakStatistics:
之间的涨跌幅
"""
def __init__(self):
def __init__(self, is_us_stock: bool = False):
mysql_user = MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "")
if not mysql_password:
@ -44,10 +44,20 @@ class MaBreakStatistics:
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_market_data = DBMarketData(self.db_url)
self.db_huge_volume_data = DBHugeVolumeData(self.db_url)
self.symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
if is_us_stock:
self.symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["QQQ"]
)
else:
self.symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["XCH-USDT"]
)
self.bars = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
if is_us_stock:
self.bars = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"bars", ["5m"]
)
else:
self.bars = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"bars", ["5m", "15m", "30m", "1H"]
)
self.stats_output_dir = "./output/trade_sandbox/ma_strategy/excel/"

View File

@ -1,4 +1,5 @@
import yfinance as yf
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
@ -22,15 +23,28 @@ class ORBStrategy:
max_leverage=4,
risk_per_trade=0.01,
commission_per_share=0.0005,
is_us_stock=False,
):
"""
初始化ORB策略参数
ORB策略说明
1. 每天仅1次交易机会多头或空头,排除十字星open1 == close1
2. 第一根5分钟K线确定开盘区间High1, Low1
3. 第二根5分钟K线根据第一根K线方向生成多空信号open1<close1为多头open1>close1为空头
entry_price=第二根K线开盘价stop_price=第一根K线最低价多头或第一根K线最高价空头
4. 多头跌破止损止损突破止盈止盈
5. 空头突破止损止损跌破止盈止盈
6. 止损/止盈根据$R计算$R=|entry_price-stop_price|
7. 盈利目标10R即10*$R
8. 账户净值曲线账户价值与市场价格
:param initial_capital: 初始账户资金美元
:param max_leverage: 最大杠杆倍数默认4倍符合FINRA规定
:param risk_per_trade: 单次交易风险比例默认1%
:param commission_per_share: 每股交易佣金美元默认0.0005
"""
logger.info(f"初始化ORB策略参数初始账户资金={initial_capital},最大杠杆倍数={max_leverage},单次交易风险比例={risk_per_trade},每股交易佣金={commission_per_share}")
logger.info(
f"初始化ORB策略参数初始账户资金={initial_capital},最大杠杆倍数={max_leverage},单次交易风险比例={risk_per_trade},每股交易佣金={commission_per_share}"
)
self.initial_capital = initial_capital
self.max_leverage = max_leverage
self.risk_per_trade = risk_per_trade
@ -48,6 +62,9 @@ class ORBStrategy:
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_market_data = DBMarketData(self.db_url)
self.is_us_stock = is_us_stock
self.output_chart_folder = r"./output/trade_sandbox/orb_strategy/chart/"
os.makedirs(self.output_chart_folder, exist_ok=True)
def fetch_intraday_data(self, symbol, start_date, end_date, interval="5m"):
"""
@ -72,14 +89,20 @@ class ORBStrategy:
data["Low"] = data["low"]
data["Close"] = data["close"]
data["Volume"] = data["volume"]
# 将data["date_time"]从字符串类型转换为日期
data["date_time"] = pd.to_datetime(data["date_time"])
if self.is_us_stock:
date_time_field = "date_time_us"
else:
date_time_field = "date_time"
data[date_time_field] = pd.to_datetime(data[date_time_field])
# data["Date"]为日期不包括时分秒即date_time如果是2025-01-01 10:00:00则Date为2025-01-01
data["Date"] = data["date_time"].dt.date
data["Date"] = data[date_time_field].dt.date
# 将Date转换为datetime64[ns]类型以确保类型一致
data["Date"] = pd.to_datetime(data["Date"])
self.data = data[["Date", "date_time", "Open", "High", "Low", "Close", "Volume"]].copy()
self.data = data[
["symbol", "bar", "Date", date_time_field, "Open", "High", "Low", "Close", "Volume"]
].copy()
self.data.rename(columns={date_time_field: "date_time"}, inplace=True)
logger.info(f"成功获取{symbol}数据:{len(self.data)}{interval}K线")
def calculate_shares(self, account_value, entry_price, stop_price):
@ -90,7 +113,9 @@ class ORBStrategy:
:param stop_price: 止损价格多头=第一根K线最低价空头=第一根K线最高价
:return: 整数股数Shares
"""
logger.info(f"开始计算交易股数:账户价值={account_value}entry价格={entry_price},止损价格={stop_price}")
logger.info(
f"开始计算交易股数:账户价值={account_value}entry价格={entry_price},止损价格={stop_price}"
)
# 计算单交易风险金额($R
risk_per_trade_dollar = abs(entry_price - stop_price) # 风险金额取绝对值
if risk_per_trade_dollar <= 0:
@ -149,7 +174,7 @@ class ORBStrategy:
stop_price = high1 # 空头止损=第一根K线最高价
else:
# 十字星→无信号
signal = "None"
signal = None
stop_price = None
signals.append(
@ -167,10 +192,12 @@ class ORBStrategy:
# 将信号合并到原始数据
signals_df = pd.DataFrame(signals)
# 确保Date列类型一致将Date转换为datetime64[ns]类型
signals_df['Date'] = pd.to_datetime(signals_df['Date'])
signals_df["Date"] = pd.to_datetime(signals_df["Date"])
# 使用merge而不是join来合并数据,根据signals_df的EntryTime与self.data的date_time进行匹配
# TODO: 这里需要优化
self.data = self.data.merge(signals_df, left_on="date_time", right_on="EntryTime", how="left")
self.data = self.data.merge(
signals_df, left_on="date_time", right_on="EntryTime", how="left"
)
# 将Date_x和Date_y合并为Date
self.data["Date"] = self.data["Date_x"].combine_first(self.data["Date_y"])
# 删除Date_x和Date_y
@ -214,6 +241,7 @@ class ORBStrategy:
signal = signal_row["Signal"]
if pd.isna(signal):
continue
entry_price = signal_row["EntryPrice"]
stop_price = signal_row["StopPrice"]
high1 = signal_row["High1"]
@ -254,24 +282,24 @@ class ORBStrategy:
if low <= stop_price:
exit_price = stop_price
exit_reason = "Stop Loss"
exit_time = time
exit_time = row["date_time"]
break
elif high >= profit_target:
exit_price = profit_target
exit_reason = "Profit Target (10R)"
exit_time = time
exit_time = row["date_time"]
break
elif signal == "Short":
# 空头:突破止损→止损;跌破止盈→止盈
if high >= stop_price:
exit_price = stop_price
exit_reason = "Stop Loss"
exit_time = time
exit_time = row["date_time"]
break
elif low <= profit_target:
exit_price = profit_target
exit_reason = "Profit Target (10R)"
exit_time = time
exit_time = row["date_time"]
break
# 若未触发止损/止盈,当日收盘平仓
@ -280,12 +308,16 @@ class ORBStrategy:
exit_reason = "End of Day (EoD)"
exit_time = daily_prices.iloc[-1].date_time
initial_account_value = account_value
# 计算盈亏
if signal == "Long":
profit_loss = (exit_price - entry_price) * shares - total_commission
else: # Short
profit_loss = (entry_price - exit_price) * shares - total_commission
# 计算盈亏百分比profit_loss除以当期初始资金
profit_loss_percentage = (profit_loss / initial_account_value) * 100
# 更新账户价值
account_value += profit_loss
account_value = max(account_value, 0) # 账户价值不能为负
@ -296,14 +328,16 @@ class ORBStrategy:
"TradeID": trade_id,
"Date": date,
"Signal": signal,
"EntryTime": signal_row.date_time,
"EntryTime": signal_row.date_time.strftime("%Y-%m-%d %H:%M:%S"),
"EntryPrice": entry_price,
"ExitTime": exit_time,
"ExitTime": exit_time.strftime("%Y-%m-%d %H:%M:%S"),
"ExitPrice": exit_price,
"Shares": shares,
"RiskAssumed": risk_assumed,
"ProfitLoss": profit_loss,
"ProfitLossPercentage": profit_loss_percentage,
"ExitReason": exit_reason,
"AccountValueInitial": initial_account_value,
"AccountValueAfter": account_value,
}
)
@ -313,12 +347,7 @@ class ORBStrategy:
trade_id += 1
# 生成净值曲线
self.equity_curve = pd.Series(
equity_history,
index=pd.date_range(
start=self.data.index[0].date(), periods=len(equity_history), freq="D"
),
)
self.create_equity_curve()
# 输出回测结果
trades_df = pd.DataFrame(self.trades)
@ -330,6 +359,13 @@ class ORBStrategy:
if len(trades_df) > 0
else 0
)
# 计算盈亏比
profit_sum = trades_df[trades_df["ProfitLoss"] > 0]["ProfitLoss"].sum()
loss_sum = abs(trades_df[trades_df["ProfitLoss"] < 0]["ProfitLoss"].sum())
if loss_sum == 0:
profit_loss_ratio = float('inf')
else:
profit_loss_ratio = (profit_sum / loss_sum) * 100
logger.info("\n" + "=" * 50)
logger.info("ORB策略回测结果")
@ -338,12 +374,40 @@ class ORBStrategy:
logger.info(f"最终资金:${account_value:,.2f}")
logger.info(f"总收益率:{total_return:.2f}%")
logger.info(f"总交易次数:{len(trades_df)}")
logger.info(f"盈亏比:{profit_loss_ratio:.2f}%")
logger.info(f"胜率:{win_rate:.2f}%")
if len(trades_df) > 0:
logger.info(f"平均每笔盈亏:${trades_df['ProfitLoss'].mean():.2f}")
logger.info(f"最大单笔盈利:${trades_df['ProfitLoss'].max():.2f}")
logger.info(f"最大单笔亏损:${trades_df['ProfitLoss'].min():.2f}")
def create_equity_curve(self):
"""
创建账户净值曲线
"""
equity_curve_list = []
# 将self.data.index[0].Date的值转换为字符串且格式为YYYY-MM-DD
first_date = self.data.iloc[0].date_time.strftime("%Y-%m-%d %H:%M:%S")
first_open = float(self.data.iloc[0].Open)
equity_curve_list.append(
{
"DateTime": first_date,
"AccountValue": self.initial_capital,
"MarketPrice": first_open,
}
)
for trade in self.trades:
equity_curve_list.append(
{
"DateTime": trade["ExitTime"],
"AccountValue": trade["AccountValueAfter"],
"MarketPrice": trade["ExitPrice"],
}
)
self.equity_curve = pd.DataFrame(equity_curve_list)
self.equity_curve.sort_values(by="DateTime", inplace=True)
self.equity_curve.reset_index(drop=True, inplace=True)
def plot_equity_curve(self):
"""
绘制账户净值曲线
@ -359,22 +423,51 @@ class ORBStrategy:
plt.rcParams["font.size"] = 11 # 设置字体大小
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
symbol = self.data.iloc[0].symbol
bar = self.data.iloc[0].bar
first_account_value = self.equity_curve.iloc[0]["AccountValue"]
first_market_price = self.equity_curve.iloc[0]["MarketPrice"]
account_value_to_1 = self.equity_curve["AccountValue"] / first_account_value
market_price_to_1 = self.equity_curve["MarketPrice"] / first_market_price
plt.figure(figsize=(12, 6))
plt.plot(
self.equity_curve.index,
self.equity_curve.values,
label="ORB策略净值",
color="blue",
)
plt.axhline(
y=self.initial_capital, color="red", linestyle="--", label="初始资金"
)
plt.title("ORB策略账户净值曲线", fontsize=14)
plt.xlabel("日期", fontsize=12)
plt.ylabel("账户价值(美元)", fontsize=12)
plt.legend()
plt.plot(self.equity_curve["DateTime"], account_value_to_1, label="账户价值", color='blue', linewidth=2, marker='o', markersize=4)
plt.plot(self.equity_curve["DateTime"], market_price_to_1, label="市场价格", color='green', linewidth=2, marker='s', markersize=4)
plt.title(f"ORB策略账户净值曲线 {symbol} {bar}", fontsize=14, fontweight='bold')
plt.xlabel("时间", fontsize=12)
plt.ylabel("涨跌变化", fontsize=12)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.show()
# 设置x轴标签避免matplotlib警告
# 选择合适的时间间隔显示标签,避免过于密集
if len(self.equity_curve) > 30:
# 如果数据点较多,选择间隔显示,但确保第一条和最后一条始终显示
step = max(1, len(self.equity_curve) // 30)
# 创建标签索引列表,确保包含首尾数据
label_indices = [0] # 第一条
# 添加中间间隔的标签
for i in range(step, len(self.equity_curve) - 1, step):
label_indices.append(i)
# 添加最后一条(如果还没有包含的话)
if len(self.equity_curve) - 1 not in label_indices:
label_indices.append(len(self.equity_curve) - 1)
# 设置x轴标签
plt.xticks(self.equity_curve["DateTime"].iloc[label_indices],
self.equity_curve["DateTime"].iloc[label_indices],
rotation=45, ha='right', fontsize=10)
else:
# 如果数据点较少,全部显示
plt.xticks(self.equity_curve["DateTime"],
self.equity_curve["DateTime"],
rotation=45, ha='right', fontsize=10)
plt.tight_layout()
save_path = f"{self.output_chart_folder}/{symbol}_{bar}_orb_strategy_equity_curve.png"
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close()
# ------------------- 策略示例回测QQQ的ORB策略2016-2023 -------------------

View File

@ -109,14 +109,16 @@ class MarketDataMain:
# 如果bar为1D, 则end_time_ts与start_time_ts相差超过10天则按照10天为单位
# 获取数据直到end_time_ts
threshold = None
if bar in ["5m", "15m", "30m"]:
if bar in ["5m", "15m", "30m", "1H"]:
if self.is_us_stock:
if bar == "5m":
threshold = 86400000 * 4
elif bar == "15m":
threshold = 86400000 * 4 * 3
threshold = 86400000 * 6
elif bar == "30m":
threshold = 86400000 * 4 * 6
threshold = 86400000 * 12
elif bar == "1H":
threshold = 86400000 * 24
else:
threshold = 86400000
elif bar in ["1H", "4H"]:

View File

@ -1,27 +1,34 @@
from core.trade.orb_trade import ORBStrategy
from config import US_STOCK_MONITOR_CONFIG
import core.logger as logging
logger = logging.logger
def main():
# 初始化ORB策略
orb_strategy = ORBStrategy(
initial_capital=25000,
max_leverage=4,
risk_per_trade=0.01,
commission_per_share=0.0005,
)
symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get("symbols", ["QQQ"])
for symbol in symbols:
logger.info(f"开始回测 {symbol}")
# 初始化ORB策略
orb_strategy = ORBStrategy(
initial_capital=25000,
max_leverage=4,
risk_per_trade=0.01,
commission_per_share=0.0005,
is_us_stock=True,
)
# 1. 获取QQQ的5分钟日内数据2024-2025注意yfinance免费版可能限制历史日内数据建议用专业数据源
orb_strategy.fetch_intraday_data(
symbol=symbol, start_date="2024-11-30", end_date="2025-08-30", interval="5m"
)
# 1. 获取QQQ的5分钟日内数据2024-2025注意yfinance免费版可能限制历史日内数据建议用专业数据源
orb_strategy.fetch_intraday_data(
symbol="ETH-USDT", start_date="2025-05-15", end_date="2025-08-20", interval="5m"
)
# 2. 生成ORB策略信号
orb_strategy.generate_orb_signals()
# 2. 生成ORB策略信号
orb_strategy.generate_orb_signals()
# 3. 回测策略盈利目标10R
orb_strategy.backtest(profit_target_multiple=10)
# 3. 回测策略盈利目标10R
orb_strategy.backtest(profit_target_multiple=10)
# 4. 绘制净值曲线
orb_strategy.plot_equity_curve()
# 4. 绘制净值曲线
orb_strategy.plot_equity_curve()
if __name__ == "__main__":

View File

@ -25,8 +25,8 @@ from config import (
logger = logging.logger
class TradeMaStrategyMain:
def __init__(self):
self.ma_break_statistics = MaBreakStatistics()
def __init__(self, is_us_stock: bool = False):
self.ma_break_statistics = MaBreakStatistics(is_us_stock=is_us_stock)
def batch_ma_break_statistics(self):
"""
@ -59,5 +59,5 @@ class TradeMaStrategyMain:
if __name__ == "__main__":
trade_ma_strategy_main = TradeMaStrategyMain()
trade_ma_strategy_main = TradeMaStrategyMain(is_us_stock=True)
trade_ma_strategy_main.batch_ma_break_statistics()