344 lines
14 KiB
Python
344 lines
14 KiB
Python
from core.trade.orb_trade import ORBStrategy
|
||
from config import US_STOCK_MONITOR_CONFIG, OKX_MONITOR_CONFIG
|
||
import core.logger as logging
|
||
from datetime import datetime
|
||
from openpyxl import Workbook
|
||
from openpyxl.drawing.image import Image
|
||
import openpyxl
|
||
import pandas as pd
|
||
import os
|
||
|
||
# Module-level logger shared by every function in this file. Note that
# ``logging`` here is the project's ``core.logger`` module (see the imports),
# not the standard-library ``logging`` package.
logger = logging.logger
|
||
|
||
|
||
def main():
    """Backtest the ORB strategy for every configured symbol and export a summary workbook.

    Runs ``ORBStrategy`` once per combination of market (US stock / OKX),
    trade direction (``None``, ``"Long"``, ``"Short"``) and SAR flag, collects
    the per-run trade and summary frames, and writes them together with the
    aggregates from ``statistics_summary`` to a timestamped Excel file. Chart
    images are then added to that file through ``copy_chart_to_excel``.

    Returns:
        None. When no run yields a non-empty summary, nothing is written.
    """
    is_us_stock_list = [True, False]
    bar = "5m"
    direction_list = [None, "Long", "Short"]
    by_sar_list = [False, True]
    start_date = "2024-01-01"
    end_date = datetime.now().strftime("%Y-%m-%d")
    profit_target_multiple = 10
    initial_capital = 25000
    max_leverage = 4
    risk_per_trade = 0.01
    commission_per_share = 0.0005

    trades_df_list = []
    trades_summary_df_list = []
    # Bar data returned by a run, keyed by (symbol, bar); it is handed back to
    # later runs of the same symbol/bar through ``symbol_bar_data``.
    symbol_data_cache = {}

    for is_us_stock in is_us_stock_list:
        # The symbol list depends only on the market, so resolve it once here
        # rather than inside the direction / SAR loops.
        if is_us_stock:
            symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
                "symbols", ["QQQ"]
            )
        else:
            symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
                "symbols", ["BTC-USDT"]
            )
        for direction in direction_list:
            for by_sar in by_sar_list:
                for symbol in symbols:
                    logger.info(
                        f"开始回测 {symbol}, 交易周期:{bar}, 开始日期:{start_date}, 结束日期:{end_date}, 是否是美股:{is_us_stock}, 交易方向:{direction}, 是否使用SAR:{by_sar}"
                    )
                    cache_key = (symbol, bar)
                    orb_strategy = ORBStrategy(
                        symbol=symbol,
                        bar=bar,
                        start_date=start_date,
                        end_date=end_date,
                        is_us_stock=is_us_stock,
                        direction=direction,
                        by_sar=by_sar,
                        profit_target_multiple=profit_target_multiple,
                        initial_capital=initial_capital,
                        max_leverage=max_leverage,
                        risk_per_trade=risk_per_trade,
                        commission_per_share=commission_per_share,
                        symbol_bar_data=symbol_data_cache.get(cache_key),
                    )
                    symbol_bar_data, trades_df, trades_summary_df = orb_strategy.run()
                    if symbol_bar_data is None or len(symbol_bar_data) == 0:
                        continue
                    # Keep the first data set seen for this key; later results
                    # never overwrite an existing cache entry.
                    symbol_data_cache.setdefault(cache_key, symbol_bar_data)
                    if trades_summary_df is None or len(trades_summary_df) == 0:
                        continue
                    trades_summary_df_list.append(trades_summary_df)
                    trades_df_list.append(trades_df)

    # pd.concat raises ValueError on an empty list, so stop before building
    # the workbook when no run produced a summary.
    if not trades_summary_df_list:
        logger.info("没有可用的回测结果,跳过Excel输出")
        return

    total_trades_df = pd.concat(trades_df_list)
    total_trades_summary_df = pd.concat(trades_summary_df_list)
    statistics_dict = statistics_summary(total_trades_summary_df)

    output_excel_folder = r"./output/trade_sandbox/orb_strategy/excel/summary/"
    os.makedirs(output_excel_folder, exist_ok=True)
    now_str = datetime.now().strftime("%Y%m%d%H%M%S")
    excel_file_name = f"orb_strategy_summary_{now_str}.xlsx"
    output_file_path = os.path.join(output_excel_folder, excel_file_name)

    with pd.ExcelWriter(output_file_path) as writer:
        total_trades_df.to_excel(writer, sheet_name="交易详情", index=False)
        total_trades_summary_df.to_excel(writer, sheet_name="交易总结", index=False)
        statistics_dict["statistics_summary_df"].to_excel(
            writer, sheet_name="统计总结", index=False
        )
        statistics_dict["max_total_return_record_df"].to_excel(
            writer, sheet_name="最大总收益率记录", index=False
        )
        statistics_dict["max_total_return_record_df_grouped_count"].to_excel(
            writer, sheet_name="最大总收益率记录_方向和根据SAR的组合", index=False
        )
        statistics_dict["max_total_return_record_df_direction_count"].to_excel(
            writer, sheet_name="最大总收益率记录_方向", index=False
        )
        statistics_dict["max_total_return_record_df_sar_count"].to_excel(
            writer, sheet_name="最大总收益率记录_根据SAR", index=False
        )

    chart_path = r"./output/trade_sandbox/orb_strategy/chart/"
    os.makedirs(chart_path, exist_ok=True)
    copy_chart_to_excel(chart_path, output_file_path)
    logger.info(f"交易总结已输出到{output_file_path}")
|
||
|
||
|
||
def _to_hashable_scalar(value):
    """Return *value* unchanged when it is None or a scalar, else its ``str()`` form."""
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    # Function-scope import: numpy is not imported at file level.
    import numpy as np

    if np.isscalar(value):
        return value
    # Anything else (Series, list, dict, ndarray, ...) becomes a string so it
    # can be used as a groupby key.
    return str(value)


def _count_by(records_df, keys):
    """Count rows of *records_df* per distinct combination of *keys*, largest count first."""
    if len(records_df) == 0:
        # Empty frame with the expected headers so it can still be written to Excel.
        return pd.DataFrame(columns=[*keys, "数量"])
    counts = records_df.groupby(keys, dropna=False).size().reset_index(name="数量")
    return counts.sort_values(by="数量", ascending=False).reset_index(drop=True)


def statistics_summary(trades_summary_df: pd.DataFrame):
    """Aggregate per-run backtest summaries into overall statistics.

    Args:
        trades_summary_df: One row per backtest run. The columns read here are
            ``股票代码``, ``方向``, ``根据SAR``, ``总收益率%`` and ``自然收益率%``.
            ``None`` or an empty frame is accepted and yields 0.0 ratios and
            empty result frames.

    Returns:
        dict with five DataFrames:
            ``statistics_summary_df``: a single row holding the percentage of
                runs with ``总收益率% > 0`` and with ``总收益率% > 自然收益率%``.
            ``max_total_return_record_df``: for each symbol, the run with the
                highest ``总收益率%`` (rows whose value is NaN are ignored).
            ``max_total_return_record_df_grouped_count``: how many of those
                best runs fall into each (``方向``, ``根据SAR``) combination.
            ``max_total_return_record_df_direction_count``: counts per ``方向``.
            ``max_total_return_record_df_sar_count``: counts per ``根据SAR``.
    """
    record_columns = ["股票代码", "方向", "根据SAR", "总收益率%", "自然收益率%"]
    total_rows = 0 if trades_summary_df is None else len(trades_summary_df)

    # 1. Share of runs whose total return is positive.
    # 2. Share of runs whose total return beats the natural return.
    # Both ratios are defined as 0.0 for empty input instead of dividing by zero.
    if total_rows == 0:
        total_return_gt_0_ratio = 0.0
        total_return_gt_natural_return_ratio = 0.0
    else:
        total_return = trades_summary_df["总收益率%"]
        natural_return = trades_summary_df["自然收益率%"]
        total_return_gt_0_ratio = round(
            (int((total_return > 0).sum()) / total_rows) * 100, 2
        )
        total_return_gt_natural_return_ratio = round(
            (int((total_return > natural_return).sum()) / total_rows) * 100, 2
        )
    logger.info(f"总收益率% > 0 的占比:{total_return_gt_0_ratio:.2f}%")
    logger.info(
        f"总收益率% > 自然收益率% 的占比:{total_return_gt_natural_return_ratio:.2f}%"
    )
    statistics_summary_df = pd.DataFrame(
        [
            {
                "总收益率%>0占比": total_return_gt_0_ratio,
                "总收益率%>自然收益率%占比": total_return_gt_natural_return_ratio,
            }
        ]
    )

    # Best run per symbol. sort=False keeps symbols in first-appearance order.
    max_total_return_record_list = []
    if total_rows > 0:
        for symbol, symbol_rows in trades_summary_df.groupby("股票代码", sort=False):
            # The incoming index may contain duplicate labels (the caller
            # concatenates frames), so renumber before using idxmax / loc.
            symbol_rows = symbol_rows.reset_index(drop=True)
            # Drop NaN returns so idxmax cannot fail on an all-NaN column.
            valid_returns = symbol_rows["总收益率%"].dropna()
            if valid_returns.empty:
                continue
            best_run = symbol_rows.loc[valid_returns.idxmax()]
            max_total_return_record_list.append(
                {
                    "股票代码": symbol,
                    "方向": best_run["方向"],
                    "根据SAR": best_run["根据SAR"],
                    "总收益率%": best_run["总收益率%"],
                    "自然收益率%": best_run["自然收益率%"],
                }
            )
    max_total_return_record_df = pd.DataFrame(
        max_total_return_record_list, columns=record_columns
    )

    if len(max_total_return_record_df) > 0:
        # Force the grouping keys to hashable scalars before grouping.
        for key_col in ("方向", "根据SAR"):
            max_total_return_record_df[key_col] = max_total_return_record_df[
                key_col
            ].apply(_to_hashable_scalar)

    result = {
        "statistics_summary_df": statistics_summary_df,
        "max_total_return_record_df": max_total_return_record_df,
        "max_total_return_record_df_grouped_count": _count_by(
            max_total_return_record_df, ["方向", "根据SAR"]
        ),
        "max_total_return_record_df_direction_count": _count_by(
            max_total_return_record_df, ["方向"]
        ),
        "max_total_return_record_df_sar_count": _count_by(
            max_total_return_record_df, ["根据SAR"]
        ),
    }
    return result
|
||
|
||
|
||
def copy_chart_to_excel(chart_path: str, excel_file_path: str):
    """Copy the PNG charts found in *chart_path* into the workbook at *excel_file_path*.

    Algorithm:
        1. List the files in ``chart_path`` whose name ends with ``.png``
           (extension compared case-insensitively).
        2. Build the sorted, de-duplicated symbol list from the
           ``volume_monitor`` -> ``symbols`` entries of
           ``US_STOCK_MONITOR_CONFIG`` and ``OKX_MONITOR_CONFIG``.
        3. For each symbol, select the files whose name starts with the symbol
           and pass them, sorted by name, to ``copy_chart_to_excel_sheet``.

    Args:
        chart_path: Directory that holds the chart images.
        excel_file_path: Path of the workbook the images are added to.

    Returns:
        None. Nothing happens when ``chart_path`` is not a directory or holds
        no PNG files.
    """
    if not os.path.isdir(chart_path):
        return
    # Sort once here; each per-symbol subset then inherits a stable order.
    chart_files = sorted(
        f for f in os.listdir(chart_path) if f.lower().endswith(".png")
    )
    if not chart_files:
        return

    # Union of the symbols configured for both markets.
    symbols = set(US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get("symbols", ["QQQ"]))
    symbols.update(OKX_MONITOR_CONFIG.get("volume_monitor", {}).get("symbols", ["BTC-USDT"]))

    for symbol in sorted(symbols):
        logger.info(f"开始保存{symbol}的图表")
        # NOTE(review): plain prefix match, so a symbol that is itself a prefix
        # of another configured symbol would also pick up that symbol's
        # charts. Confirm against the chart file naming scheme.
        symbol_files = [f for f in chart_files if f.startswith(symbol)]
        if not symbol_files:
            continue
        copy_chart_to_excel_sheet(chart_path, symbol_files, excel_file_path, symbol)
|
||
|
||
|
||
def copy_chart_to_excel_sheet(
    chart_path: str, chart_files: list, excel_file_path: str, symbol: str
):
    """Insert the given chart images into a ``{symbol}_chart`` sheet of an existing workbook.

    Algorithm:
        1. Open the workbook at ``excel_file_path``.
        2. Delete the ``{symbol}_chart`` sheet if it already exists, then
           create it again, so repeated calls do not accumulate images.
        3. Insert each image at 800x400, two per row: even positions are
           anchored in column A, odd positions in column L, and each image row
           starts 26 sheet rows below the previous one.
        4. Save the workbook back to the same path.

    An image that cannot be loaded or inserted is logged and skipped; its grid
    slot stays empty and the remaining images keep their positions.

    Args:
        chart_path: Directory that contains the files named in ``chart_files``.
        chart_files: Image file names, in the order they should be laid out.
        excel_file_path: Path of the workbook to modify; it must already exist.
        symbol: Symbol used to build the sheet name and the log messages.
    """
    workbook = openpyxl.load_workbook(excel_file_path)
    sheet_name = f"{symbol}_chart"
    if sheet_name in workbook.sheetnames:
        del workbook[sheet_name]
    sheet = workbook.create_sheet(title=sheet_name)

    # Two-column layout: left image anchored in column A, right image in L.
    # NOTE(review): whether column L clears an 800px-wide image depends on the
    # sheet's column widths; confirm the two images do not overlap.
    anchor_columns = ("A", "L")
    row_step = 26  # sheet rows between consecutive image rows (vertical spacing)

    for idx, chart_file in enumerate(chart_files):
        row_block, col_block = divmod(idx, 2)
        anchor_cell = f"{anchor_columns[col_block]}{1 + row_block * row_step}"
        try:
            img = Image(os.path.join(chart_path, chart_file))
            img.width = 800
            img.height = 400
            sheet.add_image(img, anchor_cell)
        except Exception as exc:
            # Best effort: one bad image must not abort the whole sheet, but
            # the failure is reported instead of being dropped silently.
            # logger.info is used because it is the only method of the project
            # logger that this file is known to rely on.
            logger.info(f"{symbol}的图表{chart_file}插入失败,已跳过: {exc}")
            continue

    workbook.save(excel_file_path)
    logger.info(f"{symbol}的图表已输出到{excel_file_path}")
|
||
|
||
|
||
def test():
    """Run one ORB backtest for BTC-USDT on 5m bars with fixed parameters.

    The value returned by ``ORBStrategy.run`` is discarded.
    """
    strategy_kwargs = dict(
        symbol="BTC-USDT",
        bar="5m",
        start_date="2024-01-01",
        end_date="2025-09-02",
        is_us_stock=False,
        direction=None,
        by_sar=True,
        profit_target_multiple=10,
        initial_capital=25000,
        max_leverage=4,
        risk_per_trade=0.01,
        commission_per_share=0.0005,
    )
    ORBStrategy(**strategy_kwargs).run()
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point. As written it only copies the chart images into one
    # specific, previously generated summary workbook; neither the full
    # backtest (main) nor the single-run check (test) is invoked here.
    copy_chart_to_excel(
        chart_path=r"./output/trade_sandbox/orb_strategy/chart/",
        excel_file_path=r"./output/trade_sandbox/orb_strategy/excel/summary/orb_strategy_summary_20250902174203.xlsx",
    )
|