diff --git a/config.py b/config.py index 835ce7b..fbf098f 100644 --- a/config.py +++ b/config.py @@ -73,7 +73,7 @@ OKX_MONITOR_CONFIG = { US_STOCK_MONITOR_CONFIG = { "volume_monitor":{ "symbols": ["QQQ", "TQQQ", "MSFT", "AAPL", "GOOG", "NVDA", "META", "AMZN", "TSLA", "AVGO"], - "bars": ["5m"], + "bars": ["5", "15m", "30m", "1H"], "initial_date": "2015-08-31 00:00:00" } } diff --git a/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc b/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc index db4662d..ad1c914 100644 Binary files a/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc and b/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc differ diff --git a/core/trade/__pycache__/orb_trade.cpython-312.pyc b/core/trade/__pycache__/orb_trade.cpython-312.pyc index 741d5d7..59f0225 100644 Binary files a/core/trade/__pycache__/orb_trade.cpython-312.pyc and b/core/trade/__pycache__/orb_trade.cpython-312.pyc differ diff --git a/core/trade/ma_break_statistics.py b/core/trade/ma_break_statistics.py index 27ffe2f..8721002 100644 --- a/core/trade/ma_break_statistics.py +++ b/core/trade/ma_break_statistics.py @@ -12,7 +12,7 @@ from openpyxl.drawing.image import Image import openpyxl from openpyxl.styles import Font from PIL import Image as PILImage -from config import OKX_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE +from config import OKX_MONITOR_CONFIG, US_STOCK_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE from core.db.db_market_data import DBMarketData from core.db.db_huge_volume_data import DBHugeVolumeData from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp @@ -32,7 +32,7 @@ class MaBreakStatistics: 之间的涨跌幅 """ - def __init__(self): + def __init__(self, is_us_stock: bool = False): mysql_user = MYSQL_CONFIG.get("user", "xch") mysql_password = MYSQL_CONFIG.get("password", "") if not mysql_password: @@ -44,10 +44,20 @@ class MaBreakStatistics: self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_market_data = DBMarketData(self.db_url) self.db_huge_volume_data = DBHugeVolumeData(self.db_url) - self.symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get( + if is_us_stock: + self.symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get( + "symbols", ["QQQ"] + ) + else: + self.symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get( "symbols", ["XCH-USDT"] ) - self.bars = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get( + if is_us_stock: + self.bars = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get( + "bars", ["5m"] + ) + else: + self.bars = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get( "bars", ["5m", "15m", "30m", "1H"] ) self.stats_output_dir = "./output/trade_sandbox/ma_strategy/excel/" diff --git a/core/trade/orb_trade.py b/core/trade/orb_trade.py index d4759d0..4d247dc 100644 --- a/core/trade/orb_trade.py +++ b/core/trade/orb_trade.py @@ -1,4 +1,5 @@ import yfinance as yf +import os import pandas as pd import numpy as np import matplotlib.pyplot as plt @@ -22,15 +23,28 @@ class ORBStrategy: max_leverage=4, risk_per_trade=0.01, commission_per_share=0.0005, + is_us_stock=False, ): """ 初始化ORB策略参数 + ORB策略说明: + 1. 每天仅1次交易机会,多头或空头,排除十字星:open1 == close1 + 2. 第一根5分钟K线:确定开盘区间(High1, Low1) + 3. 第二根5分钟K线:根据第一根K线方向生成多空信号,open1close1为空头 + entry_price=第二根K线开盘价,stop_price=第一根K线最低价(多头)或第一根K线最高价(空头) + 4. 多头:跌破止损→止损;突破止盈→止盈 + 5. 空头:突破止损→止损;跌破止盈→止盈 + 6. 止损/止盈:根据$R计算,$R=|entry_price-stop_price| + 7. 盈利目标:10R,即10*$R + 8. 账户净值曲线:账户价值与市场价格 :param initial_capital: 初始账户资金(美元) :param max_leverage: 最大杠杆倍数(默认4倍,符合FINRA规定) :param risk_per_trade: 单次交易风险比例(默认1%) :param commission_per_share: 每股交易佣金(美元,默认0.0005) """ - logger.info(f"初始化ORB策略参数:初始账户资金={initial_capital},最大杠杆倍数={max_leverage},单次交易风险比例={risk_per_trade},每股交易佣金={commission_per_share}") + logger.info( + f"初始化ORB策略参数:初始账户资金={initial_capital},最大杠杆倍数={max_leverage},单次交易风险比例={risk_per_trade},每股交易佣金={commission_per_share}" + ) self.initial_capital = initial_capital self.max_leverage = max_leverage self.risk_per_trade = risk_per_trade @@ -48,6 +62,9 @@ class ORBStrategy: self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_market_data = DBMarketData(self.db_url) + self.is_us_stock = is_us_stock + self.output_chart_folder = r"./output/trade_sandbox/orb_strategy/chart/" + os.makedirs(self.output_chart_folder, exist_ok=True) def fetch_intraday_data(self, symbol, start_date, end_date, interval="5m"): """ @@ -72,14 +89,20 @@ class ORBStrategy: data["Low"] = data["low"] data["Close"] = data["close"] data["Volume"] = data["volume"] - # 将data["date_time"]从字符串类型转换为日期 - data["date_time"] = pd.to_datetime(data["date_time"]) + if self.is_us_stock: + date_time_field = "date_time_us" + else: + date_time_field = "date_time" + data[date_time_field] = pd.to_datetime(data[date_time_field]) # data["Date"]为日期,不包括时分秒,即date_time如果是2025-01-01 10:00:00,则Date为2025-01-01 - data["Date"] = data["date_time"].dt.date + data["Date"] = data[date_time_field].dt.date # 将Date转换为datetime64[ns]类型以确保类型一致 data["Date"] = pd.to_datetime(data["Date"]) - self.data = data[["Date", "date_time", "Open", "High", "Low", "Close", "Volume"]].copy() + self.data = data[ + ["symbol", "bar", "Date", date_time_field, "Open", "High", "Low", "Close", "Volume"] + ].copy() + self.data.rename(columns={date_time_field: "date_time"}, inplace=True) logger.info(f"成功获取{symbol}数据:{len(self.data)}根{interval}K线") def calculate_shares(self, account_value, entry_price, stop_price): @@ -90,7 +113,9 @@ class ORBStrategy: :param stop_price: 止损价格(多头=第一根K线最低价,空头=第一根K线最高价) :return: 整数股数(Shares) """ - logger.info(f"开始计算交易股数:账户价值={account_value},entry价格={entry_price},止损价格={stop_price}") + logger.info( + f"开始计算交易股数:账户价值={account_value},entry价格={entry_price},止损价格={stop_price}" + ) # 计算单交易风险金额($R) risk_per_trade_dollar = abs(entry_price - stop_price) # 风险金额取绝对值 if risk_per_trade_dollar <= 0: @@ -149,7 +174,7 @@ class ORBStrategy: stop_price = high1 # 空头止损=第一根K线最高价 else: # 十字星→无信号 - signal = "None" + signal = None stop_price = None signals.append( @@ -167,10 +192,12 @@ class ORBStrategy: # 将信号合并到原始数据 signals_df = pd.DataFrame(signals) # 确保Date列类型一致,将Date转换为datetime64[ns]类型 - signals_df['Date'] = pd.to_datetime(signals_df['Date']) + signals_df["Date"] = pd.to_datetime(signals_df["Date"]) # 使用merge而不是join来合并数据,根据signals_df的EntryTime与self.data的date_time进行匹配 # TODO: 这里需要优化 - self.data = self.data.merge(signals_df, left_on="date_time", right_on="EntryTime", how="left") + self.data = self.data.merge( + signals_df, left_on="date_time", right_on="EntryTime", how="left" + ) # 将Date_x和Date_y合并为Date self.data["Date"] = self.data["Date_x"].combine_first(self.data["Date_y"]) # 删除Date_x和Date_y @@ -214,6 +241,7 @@ class ORBStrategy: signal = signal_row["Signal"] if pd.isna(signal): continue + entry_price = signal_row["EntryPrice"] stop_price = signal_row["StopPrice"] high1 = signal_row["High1"] @@ -254,24 +282,24 @@ class ORBStrategy: if low <= stop_price: exit_price = stop_price exit_reason = "Stop Loss" - exit_time = time + exit_time = row["date_time"] break elif high >= profit_target: exit_price = profit_target exit_reason = "Profit Target (10R)" - exit_time = time + exit_time = row["date_time"] break elif signal == "Short": # 空头:突破止损→止损;跌破止盈→止盈 if high >= stop_price: exit_price = stop_price exit_reason = "Stop Loss" - exit_time = time + exit_time = row["date_time"] break elif low <= profit_target: exit_price = profit_target exit_reason = "Profit Target (10R)" - exit_time = time + exit_time = row["date_time"] break # 若未触发止损/止盈,当日收盘平仓 @@ -280,12 +308,16 @@ class ORBStrategy: exit_reason = "End of Day (EoD)" exit_time = daily_prices.iloc[-1].date_time + initial_account_value = account_value # 计算盈亏 if signal == "Long": profit_loss = (exit_price - entry_price) * shares - total_commission else: # Short profit_loss = (entry_price - exit_price) * shares - total_commission + # 计算盈亏百分比,profit_loss除以当期初始资金 + profit_loss_percentage = (profit_loss / initial_account_value) * 100 + # 更新账户价值 account_value += profit_loss account_value = max(account_value, 0) # 账户价值不能为负 @@ -296,14 +328,16 @@ class ORBStrategy: "TradeID": trade_id, "Date": date, "Signal": signal, - "EntryTime": signal_row.date_time, + "EntryTime": signal_row.date_time.strftime("%Y-%m-%d %H:%M:%S"), "EntryPrice": entry_price, - "ExitTime": exit_time, + "ExitTime": exit_time.strftime("%Y-%m-%d %H:%M:%S"), "ExitPrice": exit_price, "Shares": shares, "RiskAssumed": risk_assumed, "ProfitLoss": profit_loss, + "ProfitLossPercentage": profit_loss_percentage, "ExitReason": exit_reason, + "AccountValueInitial": initial_account_value, "AccountValueAfter": account_value, } ) @@ -313,12 +347,7 @@ class ORBStrategy: trade_id += 1 # 生成净值曲线 - self.equity_curve = pd.Series( - equity_history, - index=pd.date_range( - start=self.data.index[0].date(), periods=len(equity_history), freq="D" - ), - ) + self.create_equity_curve() # 输出回测结果 trades_df = pd.DataFrame(self.trades) @@ -330,6 +359,13 @@ class ORBStrategy: if len(trades_df) > 0 else 0 ) + # 计算盈亏比 + profit_sum = trades_df[trades_df["ProfitLoss"] > 0]["ProfitLoss"].sum() + loss_sum = abs(trades_df[trades_df["ProfitLoss"] < 0]["ProfitLoss"].sum()) + if loss_sum == 0: + profit_loss_ratio = float('inf') + else: + profit_loss_ratio = (profit_sum / loss_sum) * 100 logger.info("\n" + "=" * 50) logger.info("ORB策略回测结果") @@ -338,12 +374,40 @@ class ORBStrategy: logger.info(f"最终资金:${account_value:,.2f}") logger.info(f"总收益率:{total_return:.2f}%") logger.info(f"总交易次数:{len(trades_df)}") + logger.info(f"盈亏比:{profit_loss_ratio:.2f}%") logger.info(f"胜率:{win_rate:.2f}%") if len(trades_df) > 0: logger.info(f"平均每笔盈亏:${trades_df['ProfitLoss'].mean():.2f}") logger.info(f"最大单笔盈利:${trades_df['ProfitLoss'].max():.2f}") logger.info(f"最大单笔亏损:${trades_df['ProfitLoss'].min():.2f}") + def create_equity_curve(self): + """ + 创建账户净值曲线 + """ + equity_curve_list = [] + # 将self.data.index[0].Date的值转换为字符串,且格式为YYYY-MM-DD + first_date = self.data.iloc[0].date_time.strftime("%Y-%m-%d %H:%M:%S") + first_open = float(self.data.iloc[0].Open) + equity_curve_list.append( + { + "DateTime": first_date, + "AccountValue": self.initial_capital, + "MarketPrice": first_open, + } + ) + for trade in self.trades: + equity_curve_list.append( + { + "DateTime": trade["ExitTime"], + "AccountValue": trade["AccountValueAfter"], + "MarketPrice": trade["ExitPrice"], + } + ) + self.equity_curve = pd.DataFrame(equity_curve_list) + self.equity_curve.sort_values(by="DateTime", inplace=True) + self.equity_curve.reset_index(drop=True, inplace=True) + def plot_equity_curve(self): """ 绘制账户净值曲线 @@ -351,7 +415,7 @@ class ORBStrategy: logger.info("开始绘制账户净值曲线") if self.equity_curve is None: raise ValueError("请先调用backtest进行回测") - + # seaborn风格设置 sns.set_theme(style="whitegrid") # plt.rcParams['font.family'] = "SimHei" @@ -359,22 +423,51 @@ class ORBStrategy: plt.rcParams["font.size"] = 11 # 设置字体大小 plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题 + symbol = self.data.iloc[0].symbol + bar = self.data.iloc[0].bar + first_account_value = self.equity_curve.iloc[0]["AccountValue"] + first_market_price = self.equity_curve.iloc[0]["MarketPrice"] + account_value_to_1 = self.equity_curve["AccountValue"] / first_account_value + market_price_to_1 = self.equity_curve["MarketPrice"] / first_market_price plt.figure(figsize=(12, 6)) - plt.plot( - self.equity_curve.index, - self.equity_curve.values, - label="ORB策略净值", - color="blue", - ) - plt.axhline( - y=self.initial_capital, color="red", linestyle="--", label="初始资金" - ) - plt.title("ORB策略账户净值曲线", fontsize=14) - plt.xlabel("日期", fontsize=12) - plt.ylabel("账户价值(美元)", fontsize=12) - plt.legend() + plt.plot(self.equity_curve["DateTime"], account_value_to_1, label="账户价值", color='blue', linewidth=2, marker='o', markersize=4) + plt.plot(self.equity_curve["DateTime"], market_price_to_1, label="市场价格", color='green', linewidth=2, marker='s', markersize=4) + plt.title(f"ORB策略账户净值曲线 {symbol} {bar}", fontsize=14, fontweight='bold') + plt.xlabel("时间", fontsize=12) + plt.ylabel("涨跌变化", fontsize=12) + plt.legend(fontsize=11) plt.grid(True, alpha=0.3) - plt.show() + + # 设置x轴标签,避免matplotlib警告 + # 选择合适的时间间隔显示标签,避免过于密集 + if len(self.equity_curve) > 30: + # 如果数据点较多,选择间隔显示,但确保第一条和最后一条始终显示 + step = max(1, len(self.equity_curve) // 30) + + # 创建标签索引列表,确保包含首尾数据 + label_indices = [0] # 第一条 + + # 添加中间间隔的标签 + for i in range(step, len(self.equity_curve) - 1, step): + label_indices.append(i) + + # 添加最后一条(如果还没有包含的话) + if len(self.equity_curve) - 1 not in label_indices: + label_indices.append(len(self.equity_curve) - 1) + + # 设置x轴标签 + plt.xticks(self.equity_curve["DateTime"].iloc[label_indices], + self.equity_curve["DateTime"].iloc[label_indices], + rotation=45, ha='right', fontsize=10) + else: + # 如果数据点较少,全部显示 + plt.xticks(self.equity_curve["DateTime"], + self.equity_curve["DateTime"], + rotation=45, ha='right', fontsize=10) + plt.tight_layout() + save_path = f"{self.output_chart_folder}/{symbol}_{bar}_orb_strategy_equity_curve.png" + plt.savefig(save_path, dpi=150, bbox_inches='tight') + plt.close() # ------------------- 策略示例:回测QQQ的ORB策略(2016-2023) ------------------- diff --git a/market_data_main.py b/market_data_main.py index 03c6963..f8c2e3a 100644 --- a/market_data_main.py +++ b/market_data_main.py @@ -109,14 +109,16 @@ class MarketDataMain: # 如果bar为1D, 则end_time_ts与start_time_ts相差超过10天,则按照10天为单位 # 获取数据,直到end_time_ts threshold = None - if bar in ["5m", "15m", "30m"]: + if bar in ["5m", "15m", "30m", "1H"]: if self.is_us_stock: if bar == "5m": threshold = 86400000 * 4 elif bar == "15m": - threshold = 86400000 * 4 * 3 + threshold = 86400000 * 6 elif bar == "30m": - threshold = 86400000 * 4 * 6 + threshold = 86400000 * 12 + elif bar == "1H": + threshold = 86400000 * 24 else: threshold = 86400000 elif bar in ["1H", "4H"]: diff --git a/orb_trade_main.py b/orb_trade_main.py index d14a0d7..5421938 100644 --- a/orb_trade_main.py +++ b/orb_trade_main.py @@ -1,27 +1,34 @@ from core.trade.orb_trade import ORBStrategy +from config import US_STOCK_MONITOR_CONFIG +import core.logger as logging + +logger = logging.logger def main(): - # 初始化ORB策略 - orb_strategy = ORBStrategy( - initial_capital=25000, - max_leverage=4, - risk_per_trade=0.01, - commission_per_share=0.0005, - ) + symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get("symbols", ["QQQ"]) + for symbol in symbols: + logger.info(f"开始回测 {symbol}") + # 初始化ORB策略 + orb_strategy = ORBStrategy( + initial_capital=25000, + max_leverage=4, + risk_per_trade=0.01, + commission_per_share=0.0005, + is_us_stock=True, + ) + # 1. 获取QQQ的5分钟日内数据(2024-2025,注意:yfinance免费版可能限制历史日内数据,建议用专业数据源) + orb_strategy.fetch_intraday_data( + symbol=symbol, start_date="2024-11-30", end_date="2025-08-30", interval="5m" + ) - # 1. 获取QQQ的5分钟日内数据(2024-2025,注意:yfinance免费版可能限制历史日内数据,建议用专业数据源) - orb_strategy.fetch_intraday_data( - symbol="ETH-USDT", start_date="2025-05-15", end_date="2025-08-20", interval="5m" - ) + # 2. 生成ORB策略信号 + orb_strategy.generate_orb_signals() - # 2. 生成ORB策略信号 - orb_strategy.generate_orb_signals() + # 3. 回测策略(盈利目标10R) + orb_strategy.backtest(profit_target_multiple=10) - # 3. 回测策略(盈利目标10R) - orb_strategy.backtest(profit_target_multiple=10) - - # 4. 绘制净值曲线 - orb_strategy.plot_equity_curve() + # 4. 绘制净值曲线 + orb_strategy.plot_equity_curve() if __name__ == "__main__": diff --git a/trade_ma_strategy_main.py b/trade_ma_strategy_main.py index 128e88e..927c85a 100644 --- a/trade_ma_strategy_main.py +++ b/trade_ma_strategy_main.py @@ -25,8 +25,8 @@ from config import ( logger = logging.logger class TradeMaStrategyMain: - def __init__(self): - self.ma_break_statistics = MaBreakStatistics() + def __init__(self, is_us_stock: bool = False): + self.ma_break_statistics = MaBreakStatistics(is_us_stock=is_us_stock) def batch_ma_break_statistics(self): """ @@ -59,5 +59,5 @@ class TradeMaStrategyMain: if __name__ == "__main__": - trade_ma_strategy_main = TradeMaStrategyMain() + trade_ma_strategy_main = TradeMaStrategyMain(is_us_stock=True) trade_ma_strategy_main.batch_ma_break_statistics() \ No newline at end of file