crypto_quant/trade_sandbox_main.py

311 lines
13 KiB
Python
Raw Normal View History

2025-08-20 03:33:13 +00:00
import core.logger as logging
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import PercentFormatter
2025-09-16 06:31:15 +00:00
from datetime import datetime, timezone, timedelta
from core.utils import get_current_date_time
2025-08-20 03:33:13 +00:00
import re
from openpyxl import Workbook
from openpyxl.drawing.image import Image
import openpyxl
from openpyxl.styles import Font
from PIL import Image as PILImage
2025-08-31 03:20:59 +00:00
from config import OKX_MONITOR_CONFIG
2025-08-20 03:33:13 +00:00
from core.trade.mean_reversion_sandbox import MeanReversionSandbox
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
# Use the SimHei font so that seaborn/matplotlib figures can render Chinese text.
plt.rcParams["font.family"] = ["SimHei"]
# Module-level logger object exposed by the project's core.logger module
# (imported above under the name `logging`).
logger = logging.logger
class MeanReversionSandboxMain:
    """Batch driver for the mean-reversion trading sandbox.

    For every configured solution it runs ``MeanReversionSandbox`` over all
    symbol / bar combinations, aggregates per-(symbol, bar) statistics,
    renders bar charts as PNG files and writes the data plus the charts into
    one Excel workbook per solution.
    """

    def __init__(
        self,
        start_date: str,
        end_date: str,
        window_size: int,
        only_5m: bool = False,
        solution_list: list = None,
    ):
        """Store the run parameters and create the output directory.

        Args:
            start_date: Start of the back-test range, forwarded unchanged to
                ``MeanReversionSandbox.trade_sandbox``.
            end_date: End of the back-test range, forwarded unchanged.
            window_size: Window size forwarded to ``trade_sandbox``.
            only_5m: When True only the "5m" bar is processed; otherwise the
                bars come from ``OKX_MONITOR_CONFIG["volume_monitor"]["bars"]``.
            solution_list: Names of the solutions to run. Defaults to
                ``["solution_1", "solution_2", "solution_3"]`` when None.
        """
        volume_monitor_config = OKX_MONITOR_CONFIG.get("volume_monitor", {})
        self.symbols = volume_monitor_config.get("symbols", ["XCH-USDT"])
        self.only_5m = only_5m
        if only_5m:
            self.bars = ["5m"]
        else:
            self.bars = volume_monitor_config.get(
                "bars", ["5m", "15m", "30m", "1H"]
            )
        if solution_list is None:
            self.solution_list = ["solution_1", "solution_2", "solution_3"]
        else:
            self.solution_list = solution_list
        self.start_date = start_date
        self.end_date = end_date
        self.window_size = window_size
        self.save_path = "./output/trade_sandbox/mean_reversion/"
        os.makedirs(self.save_path, exist_ok=True)

    def batch_mean_reversion_sandbox(self):
        """Run every solution over all symbols and bars and export the results.

        For each solution the non-empty per-(symbol, bar) results are
        concatenated, sorted by ``buy_timestamp``, summarised with
        ``statistic_data`` and written to
        ``<save_path>/<solution>/excel/<solution>_<datetime>.xlsx``; the charts
        produced by ``draw_chart`` are then embedded into that workbook.

        A solution that yields no data is skipped (and logged) instead of
        aborting the whole batch, so the remaining solutions are still
        processed.

        Returns:
            None.
        """
        logger.info("开始批量计算均值回归交易策略")
        logger.info(
            f"开始时间: {self.start_date}, 结束时间: {self.end_date}, 窗口大小: {self.window_size}"
        )
        for solution in self.solution_list:
            data_list = []
            for symbol in self.symbols:
                for bar in self.bars:
                    data = self.mean_reversion(symbol, bar, solution)
                    if data is not None and len(data) > 0:
                        data_list.append(data)
            if not data_list:
                # Nothing to report for this solution; keep going with the
                # others rather than returning from the whole batch.
                logger.info(f"策略: {solution} 没有任何交易数据, 跳过")
                continue

            total_data = pd.concat(data_list)
            total_data.sort_values(by="buy_timestamp", ascending=True, inplace=True)
            total_data.reset_index(drop=True, inplace=True)
            stat_data = self.statistic_data(total_data)

            excel_save_path = os.path.join(self.save_path, solution, "excel")
            os.makedirs(excel_save_path, exist_ok=True)
            date_time_str = get_current_date_time()
            excel_file_path = os.path.join(
                excel_save_path, f"{solution}_{date_time_str}.xlsx"
            )
            with pd.ExcelWriter(excel_file_path) as writer:
                total_data.to_excel(writer, sheet_name="total_data", index=False)
                stat_data.to_excel(writer, sheet_name="stat_data", index=False)

            chart_dict = {}
            self.draw_chart(stat_data, chart_dict)
            self.output_chart_to_excel(excel_file_path, chart_dict)

    def mean_reversion(self, symbol: str, bar: str, solution: str):
        """Run the mean-reversion sandbox for one symbol / bar / solution.

        Returns:
            Whatever ``MeanReversionSandbox.trade_sandbox`` returns; callers in
            this class treat it as a DataFrame-like object that may be None or
            empty.
        """
        mean_reversion_sandbox = MeanReversionSandbox(solution)
        return mean_reversion_sandbox.trade_sandbox(
            symbol, bar, self.window_size, self.start_date, self.end_date
        )

    def statistic_data(self, data: pd.DataFrame):
        """Aggregate trade records into one statistics row per (symbol, bar).

        Args:
            data: Trade records. The columns read here are ``symbol``, ``bar``,
                ``solution``, ``sell_type`` and ``profit_pct``.

        Returns:
            A DataFrame sorted by ``bar`` then ``symbol`` holding, per group,
            the take-profit / stop-loss counts and ratios (in percent), the
            counts and ratios of positive / negative ``profit_pct``, and the
            sum, mean, max, min and the positive-only / negative-only means of
            ``profit_pct``. An empty DataFrame is returned when ``data``
            contains no groups.
        """
        data_list = []
        for _, group in data.groupby(["symbol", "bar"]):
            total = len(group)
            profit_pct = group["profit_pct"]
            winners = profit_pct[profit_pct > 0]
            losers = profit_pct[profit_pct < 0]

            solution = group["solution"].iloc[0]
            symbol_name = group["symbol"].iloc[0]
            bar_name = group["bar"].iloc[0]

            # "止盈" marks a take-profit exit, "止损" a stop-loss exit.
            take_profit_count = len(group[group["sell_type"] == "止盈"])
            stop_loss_count = len(group[group["sell_type"] == "止损"])
            profit_pct_gt_0_count = len(winners)
            profit_pct_lt_0_count = len(losers)

            take_profit_ratio = round((take_profit_count / total) * 100, 4)
            stop_loss_ratio = round((stop_loss_count / total) * 100, 4)
            profit_pct_gt_0_ratio = round((profit_pct_gt_0_count / total) * 100, 4)
            profit_pct_lt_0_ratio = round((profit_pct_lt_0_count / total) * 100, 4)

            profit_pct_max = profit_pct.max()
            profit_pct_min = profit_pct.min()
            profit_pct_mean = profit_pct.mean()
            profit_pct_sum = profit_pct.sum()
            # These means are NaN when the group has no winning / losing trade.
            profit_pct_gt_0_mean = winners.mean()
            profit_pct_lt_0_mean = losers.mean()

            logger.info(
                f"策略: {solution}, symbol: {symbol_name}, bar: {bar_name}, "
                f"profit_pct>0的次数: {profit_pct_gt_0_count}, "
                f"profit_pct<0的次数: {profit_pct_lt_0_count}, "
                f"profit_pct最大值: {profit_pct_max}, "
                f"profit_pct最小值: {profit_pct_min}, "
                f"profit_pct平均值: {profit_pct_mean}, "
                f"profit_pct>0的平均值: {profit_pct_gt_0_mean}, "
                f"profit_pct<0的平均值: {profit_pct_lt_0_mean}"
            )
            data_list.append(
                {
                    "solution": solution,
                    "symbol": symbol_name,
                    "bar": bar_name,
                    "profit_pct_sum": profit_pct_sum,
                    "take_profit_count": take_profit_count,
                    "take_profit_ratio": take_profit_ratio,
                    "stop_loss_count": stop_loss_count,
                    "stop_loss_ratio": stop_loss_ratio,
                    "profit_pct_gt_0_count": profit_pct_gt_0_count,
                    "profit_pct_gt_0_ratio": profit_pct_gt_0_ratio,
                    "profit_pct_lt_0_count": profit_pct_lt_0_count,
                    "profit_pct_lt_0_ratio": profit_pct_lt_0_ratio,
                    "profit_pct_mean": profit_pct_mean,
                    "profit_pct_max": profit_pct_max,
                    "profit_pct_min": profit_pct_min,
                    "profit_pct_gt_0_mean": profit_pct_gt_0_mean,
                    "profit_pct_lt_0_mean": profit_pct_lt_0_mean,
                }
            )

        stat_data = pd.DataFrame(data_list)
        if stat_data.empty:
            # No groups means no "bar"/"symbol" columns to sort by.
            return stat_data
        stat_data.sort_values(by=["bar", "symbol"], inplace=True)
        stat_data.reset_index(drop=True, inplace=True)
        return stat_data

    def draw_chart(self, stat_data: pd.DataFrame, chart_dict: dict):
        """Render one PNG per statistic field and register the files.

        For every field in a fixed list (ratios, sum and means of
        ``profit_pct``) a figure is drawn with one subplot per bar; each
        subplot is a bar chart with ``symbol`` on the x axis, sorted by the
        field in descending order and coloured with the "Blues_d" palette.
        Figures are saved under ``<save_path>/<solution>/chart/``.

        The subplot grid is 1x1 when ``only_5m`` is set and at most one bar is
        present, otherwise two columns with at least two rows, growing with
        the number of bars so that more than four bars no longer overflow the
        grid.

        Args:
            stat_data: Output of ``statistic_data`` for a single solution (the
                solution name is read from its first row).
            chart_dict: Mutated in place; receives
                ``{"<solution>_chart": {field_name: png_path}}``.
        """
        sns.set_theme(style="whitegrid")
        plt.rcParams["font.sans-serif"] = ["SimHei"]  # a font name also works here
        plt.rcParams["font.size"] = 11
        plt.rcParams["axes.unicode_minus"] = False
        plt.rcParams["figure.dpi"] = 150
        plt.rcParams["savefig.dpi"] = 150

        solution = stat_data["solution"].iloc[0]
        save_path = os.path.join(self.save_path, solution, "chart")
        os.makedirs(save_path, exist_ok=True)

        # Keep the configured bar order, restricted to bars that have data;
        # fall back to whatever bars the data contains.
        available_bars = list(stat_data["bar"].unique())
        bars_in_order = [
            b for b in getattr(self, "bars", []) if b in available_bars
        ]
        if not bars_in_order:
            bars_in_order = available_bars

        n_bars = len(bars_in_order)
        if self.only_5m and n_bars <= 1:
            n_rows, n_cols = 1, 1
        else:
            n_cols = 2
            n_rows = max(2, (n_bars + 1) // 2)
        # 10x10 for the 1x1 and 2x2 layouts; taller when extra rows are needed.
        figsize = (10, 5 * max(2, n_rows))

        palette_name = "Blues_d"
        y_axis_fields = [
            "take_profit_ratio",
            "stop_loss_ratio",
            "profit_pct_sum",
            "profit_pct_mean",
            "profit_pct_gt_0_mean",
            "profit_pct_lt_0_mean",
        ]
        sheet_name = f"{solution}_chart"
        chart_dict[sheet_name] = {}

        for y_axis_field in y_axis_fields:
            is_ratio = "ratio" in y_axis_field
            # squeeze=False always yields a 2-D axes array, even for 1x1.
            fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
            axes_flat = axs.ravel()

            for ax, bar in zip(axes_flat, bars_in_order):
                bar_data = stat_data[stat_data["bar"] == bar].copy()
                bar_data.sort_values(by=y_axis_field, ascending=False, inplace=True)
                bar_data.reset_index(drop=True, inplace=True)
                colors = sns.color_palette(palette_name, n_colors=len(bar_data))
                sns.barplot(
                    x="symbol",
                    y=y_axis_field,
                    data=bar_data,
                    palette=colors,
                    ax=ax,
                )
                # Annotate each bar with its value.
                for i, value in enumerate(bar_data[y_axis_field]):
                    if pd.isna(value):
                        # No bar is drawn for NaN (e.g. a mean over zero
                        # trades), so there is nothing to label.
                        continue
                    label = f"{value:.2f}%" if is_ratio else f"{value:.4f}"
                    ax.text(
                        i,
                        value,
                        label,
                        ha="center",
                        # Put the label beyond the bar's end on either side
                        # of zero.
                        va="bottom" if value >= 0 else "top",
                        fontsize=9,
                        fontweight="bold",
                    )
                ax.set_ylabel(y_axis_field)
                ax.set_xlabel("symbol")
                ax.set_title(f"{solution} {bar}")
                if is_ratio:
                    ax.yaxis.set_major_formatter(PercentFormatter(100))
                    ax.set_ylim(0, 100)
                for tick_label in ax.get_xticklabels():
                    tick_label.set_rotation(45)
                    tick_label.set_horizontalalignment("right")

            # Hide subplots that have no bar assigned.
            for unused_ax in axes_flat[n_bars:]:
                unused_ax.axis("off")

            fig.tight_layout()
            file_name = f"{solution}_{y_axis_field}.png"
            chart_path = os.path.join(save_path, file_name)
            fig.savefig(chart_path)
            plt.close(fig)
            chart_dict[sheet_name][y_axis_field] = chart_path

    def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict):
        """Embed chart images into an existing Excel workbook.

        Args:
            excel_file_path: Path of an existing workbook; it is loaded,
                extended and saved back to the same path.
            charts_dict: Mapping of the form
                ``{"sheet_name": {"chart_name": "chart_path"}}``. One new
                sheet is created per key, and each chart is placed below a
                bold title, stacked vertically.

        A failure while building one sheet is logged and does not prevent the
        remaining sheets from being written.
        """
        logger.info(f"将图表输出到{excel_file_path}")
        # Approximate conversion used to stack the images: 1 point is about
        # 1.333 pixels and the default row height is 15 points, i.e. roughly
        # 20 pixels per row.
        pixels_per_point = 1.333
        points_per_row = 15
        pixels_per_row = points_per_row * pixels_per_point

        wb = openpyxl.load_workbook(excel_file_path)
        for sheet_name, chart_data_dict in charts_dict.items():
            try:
                ws = wb.create_sheet(title=sheet_name)
                row_offset = 1
                for chart_name, chart_path in chart_data_dict.items():
                    # Only the pixel height is needed to reserve rows.
                    with PILImage.open(chart_path) as pil_img:
                        height_px = pil_img.size[1]
                    # Reserve at least 10 rows, even for small charts.
                    chart_rows = max(10, int(height_px / pixels_per_row))

                    title_cell = ws[f"A{row_offset}"]
                    title_cell.value = chart_name
                    title_cell.font = Font(bold=True, size=12)
                    row_offset += 2  # title row plus one spacer row

                    ws.add_image(Image(chart_path), f"A{row_offset}")
                    # Chart height plus 5 rows of padding before the next one.
                    row_offset += chart_rows + 5
            except Exception as e:
                # Deliberate best-effort: report and move on to the next sheet.
                logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
        wb.save(excel_file_path)
        print(f"Chart saved as {excel_file_path}")
if __name__ == "__main__":
    # Back-test only "solution_3" on the 5-minute bar, from 2025-05-15 up to
    # the value returned by get_current_date_time(), with a window of 100.
    sandbox_runner = MeanReversionSandboxMain(
        start_date="2025-05-15 00:00:00",
        end_date=get_current_date_time(),
        window_size=100,
        only_5m=True,
        solution_list=["solution_3"],
    )
    sandbox_runner.batch_mean_reversion_sandbox()