crypto_quant/core/trade/orb_trade.py

905 lines
37 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import yfinance as yf
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from openpyxl import Workbook
from openpyxl.drawing.image import Image
import openpyxl
from openpyxl.styles import Font
from PIL import Image as PILImage
from datetime import datetime, timedelta
import core.logger as logging
from config import OKX_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE
from core.db.db_market_data import DBMarketData
from core.db.db_binance_data import DBBinanceData
from core.db.db_huge_volume_data import DBHugeVolumeData
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
# seaborn支持中文
plt.rcParams["font.family"] = ["SimHei"]
logger = logging.logger
class ORBStrategy:
def __init__(
self,
symbol: str,
bar: str,
start_date: str,
end_date: str,
initial_capital=25000,
max_leverage=4,
risk_per_trade=0.01,
commission_per_share=0.0005,
profit_target_multiple=10,
is_us_stock=False,
is_binance=False,
direction=None,
by_sar=False,
symbol_bar_data=None,
symbol_1h_data=None,
price_range_mean_as_R=False,
by_big_k=False,
by_1h_k=False,
):
"""
初始化ORB策略参数
ORB策略说明
1. 每天仅1次交易机会多头或空头,排除十字星open1 == close1
2. 第一根5分钟K线确定开盘区间High1, Low1
3. 第二根5分钟K线根据第一根K线方向生成多空信号open1<close1为多头open1>close1为空头
entry_price=第二根K线开盘价stop_price=第一根K线最低价多头或第一根K线最高价空头
4. 多头:跌破止损→止损;突破止盈→止盈
5. 空头:突破止损→止损;跌破止盈→止盈
6. 止损/止盈:根据$R计算$R=|entry_price-stop_price|
7. 盈利目标10R即10*$R
8. 账户净值曲线:账户价值与市场价格
:param symbol: 股票代码
:param bar: K线周期
:param start_date: 开始日期
:param end_date: 结束日期
:param initial_capital: 初始账户资金(美元)
:param max_leverage: 最大杠杆倍数默认4倍符合FINRA规定
:param risk_per_trade: 单次交易风险比例默认1%
:param commission_per_share: 每股交易佣金美元默认0.0005
:param profit_target_multiple: 盈利目标倍数默认10倍$R即10R
:param is_us_stock: 是否是美股
:param is_binance: 是否是Binance
:param direction: 方向None=自动Long=多头Short=空头
:param by_sar: 是否根据SAR指标生成信号True=是False=否
:param symbol_bar_data: 5分钟K线数据
:param symbol_1h_data: 1小时K线数据
:param price_range_mean_as_R: 是否将价格振幅均值作为$RTrue=是False=否
:param by_big_k: 是否根据K线实体部分亦即abs(open-close)超过high-low的50%True=是False=否
:param by_1h_k: 是否根据1小时K线True=是False=否
"""
logger.info(
f"初始化ORB策略参数股票代码={symbol}K线周期={bar},开始日期={start_date},结束日期={end_date},初始账户资金={initial_capital},最大杠杆倍数={max_leverage},单次交易风险比例={risk_per_trade},每股交易佣金={commission_per_share}"
)
self.symbol = symbol
self.bar = bar
self.start_date = start_date
self.end_date = end_date
self.initial_capital = initial_capital
self.max_leverage = max_leverage
self.risk_per_trade = risk_per_trade
self.commission_per_share = commission_per_share
self.profit_target_multiple = profit_target_multiple
self.data = None # 存储K线数据
self.trades = [] # 存储交易记录
self.equity_curve = None # 存储账户净值曲线
mysql_user = MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "")
if not mysql_password:
raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx")
self.is_us_stock = is_us_stock
self.is_binance = is_binance
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
if self.is_binance:
self.db_market_data = DBBinanceData(self.db_url)
else:
self.db_market_data = DBMarketData(self.db_url)
self.output_chart_folder = r"./output/trade_sandbox/orb_strategy/chart/"
self.output_excel_folder = r"./output/trade_sandbox/orb_strategy/excel/"
os.makedirs(self.output_chart_folder, exist_ok=True)
os.makedirs(self.output_excel_folder, exist_ok=True)
self.direction = direction
self.by_sar = by_sar
self.direction_desc = "既做多又做空"
if self.direction == "Long":
self.direction_desc = "做多"
elif self.direction == "Short":
self.direction_desc = "做空"
self.sar_desc = "不考虑SAR"
if self.by_sar:
self.sar_desc = "考虑SAR"
self.symbol_bar_data = symbol_bar_data
self.symbol_1h_data = symbol_1h_data
self.price_range_mean_as_R = price_range_mean_as_R
if self.price_range_mean_as_R:
self.price_range_mean_as_R_desc = "R为振幅均值"
else:
self.price_range_mean_as_R_desc = "R为entry减stop"
self.by_big_k = by_big_k
if self.by_big_k:
self.by_big_k_desc = "K线实体过50%"
else:
self.by_big_k_desc = "无K线要求"
self.by_1h_k = by_1h_k
if self.by_1h_k:
self.by_1h_k_desc = "参照1小时K线"
else:
self.by_1h_k_desc = "不参照1小时K线"
def run(self):
"""
运行ORB策略
"""
self.fetch_intraday_data()
self.generate_orb_signals()
self.backtest()
if len(self.trades) > 0:
self.plot_equity_curve()
self.output_trade_summary()
return (
self.symbol_bar_data,
self.symbol_1h_data,
self.trades_df,
self.trades_summary_df,
)
def fetch_intraday_data(self):
"""
获取日内5分钟K线数据需yfinance支持部分数据可能有延迟
:param ticker: 股票代码如QQQ、TQQQ
:param start_date: 起始日期格式YYYY-MM-DD
:param end_date: 结束日期格式YYYY-MM-DD
:param interval: K线周期默认5分钟
"""
logger.info(
f"开始获取{self.symbol}数据:{self.start_date}{self.end_date},间隔{self.bar}"
)
if self.symbol_bar_data is None or len(self.symbol_bar_data) == 0:
self.data = self.get_full_data(bar=self.bar)
self.calculate_price_range_mean()
self.symbol_bar_data = self.data.copy()
self.symbol_1h_data = self.get_full_data(bar="1H")
else:
self.data = self.symbol_bar_data.copy()
# 获取Close的mean
self.close_mean = self.data["Close"].mean()
if self.close_mean > 10000:
self.initial_capital = self.initial_capital * 10000
elif self.close_mean > 5000:
self.initial_capital = self.initial_capital * 5000
elif self.close_mean > 1000:
self.initial_capital = self.initial_capital * 1000
elif self.close_mean > 500:
self.initial_capital = self.initial_capital * 500
elif self.close_mean > 100:
self.initial_capital = self.initial_capital * 100
else:
pass
logger.info(f"收盘价均值:{self.close_mean}")
logger.info(f"初始资金调整为:{self.initial_capital}")
logger.info(
f"成功获取{self.symbol}数据:{len(self.data)}{self.bar}K线,开始日期={self.start_date},结束日期={self.end_date}"
)
def get_full_data(self, bar: str = "5m"):
"""
分段获取数据,并将数据合并为完整数据
分段依据如果end_date与start_date相差超过一年则每次取一年数据
"""
data = pd.DataFrame()
start_date = datetime.strptime(self.start_date, "%Y-%m-%d")
end_date = datetime.strptime(self.end_date, "%Y-%m-%d") + timedelta(days=1)
fields = [
"symbol",
"bar",
"date_time",
"date_time_us",
"open",
"high",
"low",
"close",
"volume",
"sar_signal",
"ma5",
"ma10",
"ma20",
"ma30",
"dif",
"macd",
]
while start_date < end_date:
current_end_date = min(start_date + timedelta(days=180), end_date)
start_date_str = start_date.strftime("%Y-%m-%d")
current_end_date_str = current_end_date.strftime("%Y-%m-%d")
logger.info(
f"获取{self.symbol}数据:{start_date_str}{current_end_date_str}"
)
current_data = self.db_market_data.query_market_data_by_symbol_bar(
self.symbol, bar, fields, start=start_date_str, end=current_end_date_str
)
if current_data is not None and len(current_data) > 0:
current_data = pd.DataFrame(current_data)
data = pd.concat([data, current_data])
start_date = current_end_date
data.drop_duplicates(inplace=True)
if self.is_us_stock:
date_time_field = "date_time_us"
else:
date_time_field = "date_time"
data.sort_values(by=date_time_field, inplace=True)
data.reset_index(drop=True, inplace=True)
# 保留核心列:开盘价、最高价、最低价、收盘价、成交量
data["Open"] = data["open"]
data["High"] = data["high"]
data["Low"] = data["low"]
data["Close"] = data["close"]
data["Volume"] = data["volume"]
data[date_time_field] = pd.to_datetime(data[date_time_field])
# data["Date"]为日期不包括时分秒即date_time如果是2025-01-01 10:00:00则Date为2025-01-01
data["Date"] = data[date_time_field].dt.date
# 将Date转换为datetime64[ns]类型以确保类型一致
data["Date"] = pd.to_datetime(data["Date"])
# 最小data["Date"]
self.start_date = data["Date"].min().strftime("%Y-%m-%d")
# 最大data["Date"]
self.end_date = data["Date"].max().strftime("%Y-%m-%d")
data = data[
[
"symbol",
"bar",
"Date",
date_time_field,
"Open",
"High",
"Low",
"Close",
"Volume",
"sar_signal",
"ma5",
"ma10",
"ma20",
"ma30",
"dif",
"macd",
]
]
data.rename(columns={date_time_field: "date_time"}, inplace=True)
return data
def calculate_shares(self, account_value, entry_price, stop_price, risk_assumed):
"""
根据ORB公式计算交易股数
:param account_value: 当前账户价值(美元)
:param entry_price: 交易entry价格第二根5分钟K线开盘价
:param stop_price: 止损价格(多头=第一根K线最低价空头=第一根K线最高价
:param risk_assumed: 风险金额($R,根据price_range_mean_as_R决定
:return: 整数股数Shares
"""
logger.info(
f"开始计算交易股数:账户价值={account_value}entry价格={entry_price},止损价格={stop_price}"
)
# 计算单交易风险金额($R
risk_per_trade_dollar = risk_assumed # 风险金额取绝对值
if risk_per_trade_dollar <= 0:
return 0 # 无风险时不交易
# 公式1基于风险预算的最大股数风险控制优先
shares_risk = (account_value * self.risk_per_trade) / risk_per_trade_dollar
# 公式2基于杠杆限制的最大股数杠杆约束
shares_leverage = (self.max_leverage * account_value) / entry_price
# 取两者最小值(满足风险和杠杆双重约束)
max_shares = min(shares_risk, shares_leverage)
# 扣除佣金影响(简化计算:假设佣金从可用资金中扣除)
commission_cost = max_shares * self.commission_per_share
if (account_value - commission_cost) < 0:
return 0 # 扣除佣金后资金不足,不交易
return int(max_shares) # 股数取整
def generate_orb_signals(self):
"""
生成ORB策略信号每日仅1次交易机会
- 第一根5分钟K线确定开盘区间High1, Low1
- 第二根5分钟K线根据第一根K线方向生成多空信号
"""
logger.info(
f"开始生成ORB策略信号{self.direction_desc}根据SAR指标{self.by_sar}{self.by_1h_k_desc}"
)
if self.data is None:
raise ValueError("请先调用fetch_intraday_data获取数据")
signals = []
# 按日期分组处理每日数据
for date, daily_data in self.data.groupby("Date"):
daily_data = daily_data.sort_index() # 按时间排序
if len(daily_data) < 2:
continue # 当日K线不足2根跳过
# 第一根5分钟K线开盘区间
first_candle = daily_data.iloc[0]
current_date = first_candle["Date"]
high1 = first_candle["High"]
low1 = first_candle["Low"]
open1 = first_candle["Open"]
close1 = first_candle["Close"]
sar_signal = first_candle["sar_signal"]
if high1 == low1:
continue
if self.by_big_k:
if (abs(open1 - close1) / (high1 - low1)) < 0.5:
continue
ma5_1h = None
ma10_1h = None
dif_1h = None
macd_1h = None
if self.by_1h_k:
if self.symbol_1h_data is None or len(self.symbol_1h_data) == 0:
continue
if len(self.symbol_1h_data) < 2:
continue
symbol_1h_date_data = self.symbol_1h_data[
self.symbol_1h_data["Date"] == current_date
]
if len(symbol_1h_date_data) > 0:
first_candle_1h = symbol_1h_date_data.iloc[0]
ma5_1h = first_candle_1h["ma5"]
ma10_1h = first_candle_1h["ma10"]
dif_1h = first_candle_1h["dif"]
macd_1h = first_candle_1h["macd"]
# 第二根5分钟K线entry信号
second_candle = daily_data.iloc[1]
entry_price = second_candle["Open"] # entry价格=第二根K线开盘价
entry_time = second_candle.date_time # entry时间
# 生成信号第一根K线方向决定多空排除十字星open1 == close1
if (
open1 < close1
and (self.direction == "Long" or self.direction is None)
and ((self.by_sar and sar_signal == "SAR多头") or not self.by_sar)
and (
(
self.by_1h_k
and (
ma5_1h is not None
and ma10_1h is not None
and ma5_1h > ma10_1h
)
and (
dif_1h is not None
and macd_1h is not None
and (dif_1h > 0 or macd_1h > 0)
)
)
or not self.by_1h_k
)
):
# 第一根K线收涨→多头信号
signal = "Long"
stop_price = low1 # 多头止损=第一根K线最低价
elif (
open1 > close1
and (self.direction == "Short" or self.direction is None)
and ((self.by_sar and sar_signal == "SAR空头") or not self.by_sar)
and (
(
self.by_1h_k
and (
ma5_1h is not None
and ma10_1h is not None
and ma5_1h < ma10_1h
)
and (
dif_1h is not None
and macd_1h is not None
and (dif_1h < 0 or macd_1h < 0)
)
)
or not self.by_1h_k
)
):
# 第一根K线收跌→空头信号
signal = "Short"
stop_price = high1 # 空头止损=第一根K线最高价
else:
# 与direction不一致或十字星→无信号
signal = None
stop_price = None
signals.append(
{
"Date": date,
"EntryTime": entry_time,
"Signal": signal,
"EntryPrice": entry_price,
"StopPrice": stop_price,
"High1": high1,
"Low1": low1,
}
)
# 将信号合并到原始数据
signals_df = pd.DataFrame(signals)
# 确保Date列类型一致将Date转换为datetime64[ns]类型
signals_df["Date"] = pd.to_datetime(signals_df["Date"])
# 使用merge而不是join来合并数据,根据signals_df的EntryTime与self.data的date_time进行匹配
# TODO: 这里需要优化
self.data = self.data.merge(
signals_df, left_on="date_time", right_on="EntryTime", how="left"
)
# 将Date_x和Date_y合并为Date
self.data["Date"] = self.data["Date_x"].combine_first(self.data["Date_y"])
# 删除Date_x和Date_y
self.data.drop(columns=["Date_x", "Date_y"], inplace=True)
logger.info(
f"生成信号完成:共{len(signals_df)}个交易日,其中多头{sum(signals_df['Signal']=='Long')}次,空头{sum(signals_df['Signal']=='Short')}"
)
def calculate_price_range_mean(self):
"""
计算价格振幅均值,振幅为最高价与最低价之差
计算价格振幅标准差
要求用滑动窗口: window_size=100计算均值每次计算都包含当前行
返回一个新列,列名为"PriceRangeMean"
"""
self.data["PriceRange"] = self.data["High"] - self.data["Low"]
self.data["PriceRangeMean"] = self.data["PriceRange"].rolling(window=100).mean()
self.data["PriceRangeStd"] = self.data["PriceRange"].rolling(window=100).std()
def backtest(self):
"""
回测ORB策略
:param profit_target_multiple: 盈利目标倍数默认10倍$R即10R
"""
logger.info(f"开始回测ORB策略盈利目标倍数={self.profit_target_multiple}")
if "Signal" not in self.data.columns:
raise ValueError("请先调用generate_orb_signals生成策略信号")
account_value = self.initial_capital # 初始账户价值
current_position = None # 当前持仓None=空仓Long/Short=持仓)
equity_history = [account_value] # 净值历史
trade_id = 0 # 交易ID
# 按时间遍历数据每日仅处理第二根K线后的信号
for date, daily_data in self.data.groupby("Date"):
daily_data = daily_data.sort_index()
if len(daily_data) < 2:
continue
# 获取当日信号第二根K线的信号
signal_row = (
daily_data[~pd.isna(daily_data["Signal"])].iloc[0]
if sum(~pd.isna(daily_data["Signal"])) > 0
else None
)
if signal_row is None:
# 无信号→当日不交易,净值保持不变
equity_history.append(account_value)
continue
# 提取信号参数
signal = signal_row["Signal"]
if pd.isna(signal):
continue
entry_price = signal_row["EntryPrice"]
stop_price = signal_row["StopPrice"]
high1 = signal_row["High1"]
low1 = signal_row["Low1"]
price_range = signal_row["PriceRange"]
price_range_mean = signal_row["PriceRangeMean"]
price_range_std = signal_row["PriceRangeStd"]
# 计算$R
if (
self.price_range_mean_as_R
and price_range_mean is not None
and price_range_mean > 0
):
risk_assumed = price_range_mean
else:
risk_assumed = abs(entry_price - stop_price)
profit_target = (
entry_price + (risk_assumed * self.profit_target_multiple)
if signal == "Long"
else entry_price - (risk_assumed * self.profit_target_multiple)
)
# 计算交易股数
shares = self.calculate_shares(
account_value, entry_price, stop_price, risk_assumed
)
if shares == 0:
# 股数为0→不交易
equity_history.append(account_value)
continue
# 计算佣金(买入/卖出各收一次)
total_commission = shares * self.commission_per_share * 2 # 往返佣金
# 模拟日内持仓:寻找止损/止盈触发点,或当日收盘平仓
daily_prices = daily_data[
daily_data.date_time > signal_row.date_time
] # 从entry时间开始遍历
exit_price = None
exit_time = None
exit_reason = None
for idx, (time, row) in enumerate(daily_prices.iterrows()):
high = row["High"]
low = row["Low"]
close = row["Close"]
# 检查止损/止盈条件
if signal == "Long":
# 多头:跌破止损→止损;突破止盈→止盈
if low <= stop_price:
exit_price = stop_price
exit_reason = "Stop Loss"
exit_time = row["date_time"]
break
elif high >= profit_target:
exit_price = profit_target
exit_reason = f"Profit Target ({self.profit_target_multiple}R)"
exit_time = row["date_time"]
break
elif signal == "Short":
# 空头:突破止损→止损;跌破止盈→止盈
if high >= stop_price:
exit_price = stop_price
exit_reason = "Stop Loss"
exit_time = row["date_time"]
break
elif low <= profit_target:
exit_price = profit_target
exit_reason = f"Profit Target ({self.profit_target_multiple}R)"
exit_time = row["date_time"]
break
# 若未触发止损/止盈,当日收盘平仓
if exit_price is None:
exit_price = daily_prices.iloc[-1]["Close"]
exit_reason = "End of Day (EoD)"
exit_time = daily_prices.iloc[-1].date_time
initial_account_value = account_value
# 计算盈亏
if signal == "Long":
profit_loss = (exit_price - entry_price) * shares - total_commission
else: # Short
profit_loss = (entry_price - exit_price) * shares - total_commission
# 计算盈亏百分比profit_loss除以当期初始资金
profit_loss_percentage = (profit_loss / initial_account_value) * 100
# 更新账户价值
account_value += profit_loss
account_value = max(account_value, 0) # 账户价值不能为负
# 记录交易
self.trades.append(
{
"TradeID": trade_id,
"Direction": self.direction_desc,
"BySar": self.sar_desc,
"PriceRangeMeanAsR": self.price_range_mean_as_R_desc,
"ByBigK": self.by_big_k_desc,
"By1hK": self.by_1h_k_desc,
"Symbol": self.symbol,
"Bar": self.bar,
"Date": date,
"Signal": signal,
"EntryTime": signal_row.date_time.strftime("%Y-%m-%d %H:%M:%S"),
"EntryPrice": entry_price,
"ExitTime": exit_time.strftime("%Y-%m-%d %H:%M:%S"),
"ExitPrice": exit_price,
"Shares": shares,
"RiskAssumed": risk_assumed,
"ProfitLoss": profit_loss,
"ProfitLossPercentage": profit_loss_percentage,
"ExitReason": exit_reason,
"AccountValueInitial": initial_account_value,
"AccountValueAfter": account_value,
}
)
# 记录净值
equity_history.append(account_value)
trade_id += 1
if len(self.trades) == 0:
logger.info("没有交易")
self.trades_df = pd.DataFrame()
self.initial_trade_summary()
return
# 生成净值曲线
self.create_equity_curve()
# 输出回测结果
self.trades_df = pd.DataFrame(self.trades)
self.trades_df.sort_values(by="ExitTime", inplace=True)
total_return = (
(account_value - self.initial_capital) / self.initial_capital * 100
)
win_rate = (
(self.trades_df["ProfitLoss"] > 0).sum() / len(self.trades_df) * 100
if len(self.trades_df) > 0
else 0
)
# 计算盈亏比
profit_sum = self.trades_df[self.trades_df["ProfitLoss"] > 0][
"ProfitLoss"
].sum()
loss_sum = abs(
self.trades_df[self.trades_df["ProfitLoss"] < 0]["ProfitLoss"].sum()
)
if loss_sum == 0:
profit_loss_ratio = float("inf")
else:
profit_loss_ratio = (profit_sum / loss_sum) * 100
first_entry_price = self.trades_df.iloc[0]["EntryPrice"]
last_exit_price = self.trades_df.iloc[-1]["ExitPrice"]
natural_return = (last_exit_price - first_entry_price) / first_entry_price * 100
self.initial_trade_summary()
if len(self.trades_df) > 0:
logger.info("\n" + "=" * 50)
logger.info("ORB策略回测结果")
logger.info("=" * 50)
logger.info(f"股票代码:{self.symbol}")
logger.info(f"K线周期{self.bar}")
logger.info(f"开始日期:{self.start_date}")
logger.info(f"结束日期:{self.end_date}")
logger.info(f"盈利目标倍数:{self.profit_target_multiple}")
logger.info(f"初始资金:${self.initial_capital:,.2f}")
logger.info(f"最终资金:${account_value:,.2f}")
self.trades_summary["最终资金$"] = account_value
logger.info(f"总收益率:{total_return:.2f}%")
self.trades_summary["总收益率%"] = total_return
logger.info(f"自然收益率:{natural_return:.2f}%")
self.trades_summary["自然收益率%"] = natural_return
logger.info(f"总交易次数:{len(self.trades_df)}")
self.trades_summary["总交易次数"] = len(self.trades_df)
logger.info(f"盈亏比:{profit_loss_ratio:.2f}%")
self.trades_summary["盈亏比%"] = profit_loss_ratio
logger.info(f"胜率:{win_rate:.2f}%")
self.trades_summary["胜率%"] = win_rate
logger.info(f"平均每笔盈亏:${self.trades_df['ProfitLoss'].mean():.2f}")
self.trades_summary["平均每笔盈亏$"] = self.trades_df["ProfitLoss"].mean()
logger.info(f"最大单笔盈利:${self.trades_df['ProfitLoss'].max():.2f}")
self.trades_summary["最大单笔盈利$"] = self.trades_df["ProfitLoss"].max()
logger.info(f"最大单笔亏损:${abs(self.trades_df['ProfitLoss'].min()):.2f}")
self.trades_summary["最大单笔亏损$"] = abs(
self.trades_df["ProfitLoss"].min()
)
else:
logger.info("没有交易")
self.trades_summary_df = pd.DataFrame([self.trades_summary])
def initial_trade_summary(self):
"""
初始化交易总结
"""
self.trades_summary = {}
self.trades_summary["方向"] = self.direction_desc
self.trades_summary["根据SAR"] = self.sar_desc
self.trades_summary["R算法"] = self.price_range_mean_as_R_desc
self.trades_summary["K线条件"] = self.by_big_k_desc
self.trades_summary["1小时K线条件"] = self.by_1h_k_desc
self.trades_summary["股票代码"] = self.symbol
self.trades_summary["K线周期"] = self.bar
self.trades_summary["开始日期"] = self.start_date
self.trades_summary["结束日期"] = self.end_date
self.trades_summary["盈利目标倍数"] = self.profit_target_multiple
self.trades_summary["初始资金$"] = self.initial_capital
self.trades_summary["最终资金$"] = self.initial_capital
self.trades_summary["总收益率%"] = 0
self.trades_summary["自然收益率%"] = 0
self.trades_summary["总交易次数"] = 0
self.trades_summary["盈亏比%"] = 0
self.trades_summary["胜率%"] = 0
self.trades_summary["平均每笔盈亏$"] = 0
self.trades_summary["最大单笔盈利$"] = 0
self.trades_summary["最大单笔亏损$"] = 0
def create_equity_curve(self):
"""
创建账户净值曲线
"""
equity_curve_list = []
# 将self.data.index[0].Date的值转换为字符串且格式为YYYY-MM-DD
first_date = self.data.iloc[0].date_time.strftime("%Y-%m-%d %H:%M:%S")
first_open = float(self.data.iloc[0].Open)
equity_curve_list.append(
{
"DateTime": first_date,
"AccountValue": self.initial_capital,
"MarketPrice": first_open,
}
)
for trade in self.trades:
equity_curve_list.append(
{
"DateTime": trade["ExitTime"],
"AccountValue": trade["AccountValueAfter"],
"MarketPrice": trade["ExitPrice"],
}
)
self.equity_curve = pd.DataFrame(equity_curve_list)
self.equity_curve.sort_values(by="DateTime", inplace=True)
self.equity_curve.reset_index(drop=True, inplace=True)
def plot_equity_curve(self):
"""
绘制账户净值曲线
"""
logger.info("开始绘制账户净值曲线")
if self.equity_curve is None:
raise ValueError("请先调用backtest进行回测")
# seaborn风格设置
sns.set_theme(style="whitegrid")
# plt.rcParams['font.family'] = "SimHei"
plt.rcParams["font.sans-serif"] = ["SimHei"] # 也可直接用字体名
plt.rcParams["font.size"] = 11 # 设置字体大小
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
symbol = self.data.iloc[0].symbol
bar = self.data.iloc[0].bar
first_account_value = self.equity_curve.iloc[0]["AccountValue"]
first_market_price = self.equity_curve.iloc[0]["MarketPrice"]
account_value_to_1 = self.equity_curve["AccountValue"] / first_account_value
market_price_to_1 = self.equity_curve["MarketPrice"] / first_market_price
plt.figure(figsize=(12, 6))
plt.plot(
self.equity_curve["DateTime"],
account_value_to_1,
label="账户价值",
color="blue",
linewidth=2,
marker="o",
markersize=4,
)
plt.plot(
self.equity_curve["DateTime"],
market_price_to_1,
label="市场价格",
color="green",
linewidth=2,
marker="s",
markersize=4,
)
plt.title(
f"{symbol} {bar} {self.direction_desc} {self.sar_desc} {self.price_range_mean_as_R_desc} {self.by_big_k_desc} {self.by_1h_k_desc}",
fontsize=14,
fontweight="bold",
)
plt.xlabel("时间", fontsize=12)
plt.ylabel("涨跌变化", fontsize=12)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
# 设置x轴标签避免matplotlib警告
# 选择合适的时间间隔显示标签,避免过于密集
if len(self.equity_curve) > 30:
# 如果数据点较多,选择间隔显示,但确保第一条和最后一条始终显示
step = max(1, len(self.equity_curve) // 30)
# 创建标签索引列表,确保包含首尾数据
label_indices = [0] # 第一条
# 添加中间间隔的标签
for i in range(step, len(self.equity_curve) - 1, step):
label_indices.append(i)
# 添加最后一条(如果还没有包含的话)
if len(self.equity_curve) - 1 not in label_indices:
label_indices.append(len(self.equity_curve) - 1)
# 设置x轴标签
plt.xticks(
self.equity_curve["DateTime"].iloc[label_indices],
self.equity_curve["DateTime"].iloc[label_indices],
rotation=45,
ha="right",
fontsize=10,
)
else:
# 如果数据点较少,全部显示
plt.xticks(
self.equity_curve["DateTime"],
self.equity_curve["DateTime"],
rotation=45,
ha="right",
fontsize=10,
)
plt.tight_layout()
self.chart_save_path = f"{self.output_chart_folder}/{symbol}_{bar}_{self.direction_desc}_{self.sar_desc}_{self.price_range_mean_as_R_desc}_{self.by_big_k_desc}_{self.by_1h_k_desc}_orb.png"
plt.savefig(self.chart_save_path, dpi=150, bbox_inches="tight")
plt.close()
def output_trade_summary(self):
"""
输出交易明细交易总结与Chart图片到Excel
"""
start_date = self.start_date.replace("-", "")
end_date = self.end_date.replace("-", "")
output_file_name = f"orb_{self.symbol}_{self.bar}_{start_date}_{end_date}_{self.direction_desc}_{self.sar_desc}_{self.price_range_mean_as_R_desc}_{self.by_big_k_desc}_{self.by_1h_k_desc}.xlsx"
output_file_path = os.path.join(self.output_excel_folder, output_file_name)
logger.info(f"导出{output_file_path}")
with pd.ExcelWriter(output_file_path) as writer:
self.trades_df.to_excel(writer, sheet_name="交易明细", index=False)
self.trades_summary_df.to_excel(writer, sheet_name="交易总结", index=False)
if os.path.exists(self.chart_save_path):
charts_dict = {"账户净值曲线": self.chart_save_path}
self.output_chart_to_excel(output_file_path, charts_dict)
def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict):
"""
输出Excel文件包含所有图表
charts_dict: 图表数据字典,格式为:
{
"sheet_name": {
"chart_name": "chart_path"
}
}
"""
logger.info(f"将图表输出到{excel_file_path}")
# 打开已经存在的Excel文件
wb = openpyxl.load_workbook(excel_file_path)
for sheet_name, chart_path in charts_dict.items():
try:
ws = wb.create_sheet(title=sheet_name)
row_offset = 1
# Insert chart image
img = Image(chart_path)
ws.add_image(img, f"A{row_offset}")
except Exception as e:
logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
continue
# Save Excel file
wb.save(excel_file_path)
logger.info(f"图表已输出到{excel_file_path}")
# ------------------- 策略示例回测QQQ的ORB策略2016-2023 -------------------
if __name__ == "__main__":
# 初始化ORB策略
orb_strategy = ORBStrategy(
symbol="ETH-USDT",
bar="5m",
start_date="2025-05-15",
end_date="2025-08-20",
initial_capital=25000,
max_leverage=4,
risk_per_trade=0.01,
commission_per_share=0.0005,
profit_target_multiple=10,
is_us_stock=False,
direction=None,
by_sar=False,
)
orb_strategy.run()