crypto_quant/trade_sandbox_main.py

307 lines
13 KiB
Python
Raw Normal View History

2025-08-20 03:33:13 +00:00
import core.logger as logging
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import PercentFormatter
from datetime import datetime
import re
from openpyxl import Workbook
from openpyxl.drawing.image import Image
import openpyxl
from openpyxl.styles import Font
from PIL import Image as PILImage
from config import MONITOR_CONFIG
from core.trade.mean_reversion_sandbox import MeanReversionSandbox
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
# Use the SimHei font family so matplotlib/seaborn figures can render Chinese text.
plt.rcParams["font.family"] = ["SimHei"]
# Project logger object exposed by core.logger (imported above under the alias "logging").
logger = logging.logger
class MeanReversionSandboxMain:
    """Batch driver for the mean-reversion trading sandbox.

    For every configured solution it runs ``MeanReversionSandbox.trade_sandbox``
    for each symbol/bar pair, concatenates the returned trades, derives
    per-(symbol, bar) statistics, writes both tables to an Excel workbook and
    embeds bar charts of the statistics into that workbook.
    """

    # Column order of the frame returned by ``statistic_data``.
    _STAT_COLUMNS = (
        "solution",
        "symbol",
        "bar",
        "take_profit_count",
        "take_profit_ratio",
        "stop_loss_count",
        "stop_loss_ratio",
        "profit_pct_gt_0_count",
        "profit_pct_gt_0_ratio",
        "profit_pct_lt_0_count",
        "profit_pct_lt_0_ratio",
        "profit_pct_max",
        "profit_pct_min",
        "profit_pct_mean",
        "profit_pct_gt_0_mean",
        "profit_pct_lt_0_mean",
    )

    # Statistics that each get their own chart image in ``draw_chart``.
    _CHART_FIELDS = (
        "take_profit_ratio",
        "stop_loss_ratio",
        "profit_pct_mean",
        "profit_pct_gt_0_mean",
        "profit_pct_lt_0_mean",
    )

    def __init__(self, start_date: str, end_date: str, window_size: int, only_5m: bool = False, solution_list: list = None):
        """Store the run parameters and make sure the output directory exists.

        Args:
            start_date: Start of the evaluated period; forwarded unchanged to
                ``MeanReversionSandbox.trade_sandbox``.
            end_date: End of the evaluated period; forwarded unchanged as well.
            window_size: Window size forwarded to ``trade_sandbox``.
            only_5m: When true, only the ``"5m"`` bar is evaluated and charts
                use a single subplot; otherwise the bars come from
                ``MONITOR_CONFIG["volume_monitor"]["bars"]``.
            solution_list: Solution names to evaluate. ``None`` selects
                ``solution_1`` .. ``solution_3``.
        """
        monitor_config = MONITOR_CONFIG.get("volume_monitor", {})
        self.symbols = monitor_config.get("symbols", ["XCH-USDT"])
        self.only_5m = only_5m
        if only_5m:
            self.bars = ["5m"]
        else:
            self.bars = monitor_config.get("bars", ["5m", "15m", "30m", "1H"])
        if solution_list is None:
            self.solution_list = ["solution_1", "solution_2", "solution_3"]
        else:
            # Copy so later changes to the caller's list do not affect this run.
            self.solution_list = list(solution_list)
        self.start_date = start_date
        self.end_date = end_date
        self.window_size = window_size
        self.save_path = "./output/trade_sandbox/mean_reversion/"
        os.makedirs(self.save_path, exist_ok=True)

    def batch_mean_reversion_sandbox(self):
        """Run every solution over all symbols/bars and export the results.

        For each solution the non-empty trade frames are concatenated, sorted
        by ``buy_timestamp``, summarised with ``statistic_data`` and written to
        ``<save_path>/<solution>/excel/<solution>_<timestamp>.xlsx``; the charts
        produced by ``draw_chart`` are then added to that workbook.

        A solution that yields no trades is skipped; the remaining solutions
        are still processed.
        """
        logger.info("开始批量计算均值回归交易策略")
        logger.info(
            f"开始时间: {self.start_date}, 结束时间: {self.end_date}, 窗口大小: {self.window_size}"
        )
        for solution in self.solution_list:
            data_list = []
            for symbol in self.symbols:
                for bar in self.bars:
                    data = self.mean_reversion(symbol, bar, solution)
                    if data is not None and len(data) > 0:
                        data_list.append(data)
            if not data_list:
                # Previously this returned and silently dropped every
                # solution that came after the empty one.
                logger.info(f"策略: {solution} 没有交易数据, 跳过")
                continue

            total_data = pd.concat(data_list)
            total_data.sort_values(by="buy_timestamp", ascending=True, inplace=True)
            total_data.reset_index(drop=True, inplace=True)
            stat_data = self.statistic_data(total_data)

            excel_save_path = os.path.join(self.save_path, solution, "excel")
            os.makedirs(excel_save_path, exist_ok=True)
            date_time_str = datetime.now().strftime("%Y%m%d%H%M%S")
            excel_file_path = os.path.join(
                excel_save_path, f"{solution}_{date_time_str}.xlsx"
            )
            with pd.ExcelWriter(excel_file_path) as writer:
                total_data.to_excel(writer, sheet_name="total_data", index=False)
                stat_data.to_excel(writer, sheet_name="stat_data", index=False)

            chart_dict = {}
            self.draw_chart(stat_data, chart_dict)
            self.output_chart_to_excel(excel_file_path, chart_dict)

    def mean_reversion(self, symbol: str, bar: str, solution: str):
        """Run one solution for one symbol/bar and return the sandbox result.

        Returns whatever ``MeanReversionSandbox.trade_sandbox`` returns; the
        caller treats ``None`` or an empty result as "no trades".
        """
        mean_reversion_sandbox = MeanReversionSandbox(solution)
        return mean_reversion_sandbox.trade_sandbox(
            symbol, bar, self.window_size, self.start_date, self.end_date
        )

    @staticmethod
    def _summarize_group(group: pd.DataFrame) -> dict:
        """Compute the statistics row for the trades of one (symbol, bar) group.

        ``group`` must be non-empty and contain the columns ``solution``,
        ``symbol``, ``bar``, ``sell_type`` and ``profit_pct``. Ratios are
        percentages of the group size rounded to 4 decimals. The mean of the
        winning (or losing) trades is NaN when there are none.
        """
        total = len(group)
        profit_pct = group["profit_pct"]
        winners = profit_pct[profit_pct > 0]
        losers = profit_pct[profit_pct < 0]
        take_profit_count = int((group["sell_type"] == "止盈").sum())
        stop_loss_count = int((group["sell_type"] == "止损").sum())

        def pct_of_total(count: int) -> float:
            return round((count / total) * 100, 4)

        return {
            "solution": group["solution"].iloc[0],
            "symbol": group["symbol"].iloc[0],
            "bar": group["bar"].iloc[0],
            "take_profit_count": take_profit_count,
            "take_profit_ratio": pct_of_total(take_profit_count),
            "stop_loss_count": stop_loss_count,
            "stop_loss_ratio": pct_of_total(stop_loss_count),
            "profit_pct_gt_0_count": len(winners),
            "profit_pct_gt_0_ratio": pct_of_total(len(winners)),
            "profit_pct_lt_0_count": len(losers),
            "profit_pct_lt_0_ratio": pct_of_total(len(losers)),
            "profit_pct_max": profit_pct.max(),
            "profit_pct_min": profit_pct.min(),
            "profit_pct_mean": profit_pct.mean(),
            "profit_pct_gt_0_mean": winners.mean(),
            "profit_pct_lt_0_mean": losers.mean(),
        }

    def statistic_data(self, data: pd.DataFrame):
        """Aggregate trades into one statistics row per (symbol, bar).

        Args:
            data: Trade records with at least the columns ``solution``,
                ``symbol``, ``bar``, ``sell_type`` and ``profit_pct``.

        Returns:
            A DataFrame with the columns listed in ``_STAT_COLUMNS``, sorted by
            ``bar`` then ``symbol``. Empty (but with those columns) when
            ``data`` is ``None`` or has no rows.
        """
        if data is None or data.empty:
            # Without this guard the sort below fails on a column-less frame.
            return pd.DataFrame(columns=list(self._STAT_COLUMNS))

        rows = []
        for _, group in data.groupby(["symbol", "bar"]):
            row = self._summarize_group(group)
            logger.info(
                f"策略: {row['solution']}, symbol: {row['symbol']}, bar: {row['bar']}, "
                f"profit_pct>0的次数: {row['profit_pct_gt_0_count']}, "
                f"profit_pct<0的次数: {row['profit_pct_lt_0_count']}, "
                f"profit_pct最大值: {row['profit_pct_max']}, "
                f"profit_pct最小值: {row['profit_pct_min']}, "
                f"profit_pct平均值: {row['profit_pct_mean']}, "
                f"profit_pct>0的平均值: {row['profit_pct_gt_0_mean']}, "
                f"profit_pct<0的平均值: {row['profit_pct_lt_0_mean']}"
            )
            rows.append(row)

        stat_data = pd.DataFrame(rows, columns=list(self._STAT_COLUMNS))
        stat_data.sort_values(by=["bar", "symbol"], inplace=True)
        stat_data.reset_index(drop=True, inplace=True)
        return stat_data

    def draw_chart(self, stat_data: pd.DataFrame, chart_dict: dict):
        """Render one PNG per statistic and record the paths in ``chart_dict``.

        Each image holds one bar-chart subplot per bar interval (x axis:
        symbol, y axis: the statistic, blue gradient palette). Ratio fields are
        shown as percentages on a fixed 0-100 axis. Images are saved to
        ``<save_path>/<solution>/chart/<solution>_<field>.png``.

        Args:
            stat_data: Output of ``statistic_data`` for a single solution; the
                solution name is read from its first row.
            chart_dict: Mutated in place; gains the key ``"<solution>_chart"``
                mapping each statistic name to the saved image path.
        """
        if stat_data is None or stat_data.empty:
            logger.info("统计数据为空, 跳过绘制图表")
            return

        # set_theme resets rcParams, so the font settings are re-applied after it.
        sns.set_theme(style="whitegrid")
        plt.rcParams["font.sans-serif"] = ["SimHei"]
        plt.rcParams["font.size"] = 11
        plt.rcParams["axes.unicode_minus"] = False
        plt.rcParams["figure.dpi"] = 150
        plt.rcParams["savefig.dpi"] = 150

        solution = stat_data["solution"].iloc[0]
        save_path = os.path.join(self.save_path, solution, "chart")
        os.makedirs(save_path, exist_ok=True)

        # Keep the configured bar order; fall back to whatever the data holds.
        present_bars = set(stat_data["bar"].unique())
        bars_in_order = [b for b in getattr(self, "bars", []) if b in present_bars]
        if not bars_in_order:
            bars_in_order = list(stat_data["bar"].unique())

        # Default layout: a single plot for only_5m, otherwise a 2x2 grid.
        # Grow to two columns with enough rows if there are more bars than cells.
        n_rows, n_cols = (1, 1) if self.only_5m else (2, 2)
        if len(bars_in_order) > n_rows * n_cols:
            n_cols = 2
            n_rows = (len(bars_in_order) + 1) // 2
        fig_height = 10 if n_rows <= 2 else 5 * n_rows

        palette_name = "Blues_d"
        sheet_name = f"{solution}_chart"
        chart_dict[sheet_name] = {}

        for y_axis_field in self._CHART_FIELDS:
            is_ratio = "ratio" in y_axis_field
            # squeeze=False always yields a 2-D array of axes, even for 1x1.
            fig, axs = plt.subplots(
                n_rows, n_cols, figsize=(10, fig_height), squeeze=False
            )
            flat_axes = axs.ravel()

            for ax, bar in zip(flat_axes, bars_in_order):
                bar_data = stat_data[stat_data["bar"] == bar].copy()
                bar_data.sort_values(by=y_axis_field, ascending=False, inplace=True)
                bar_data.reset_index(drop=True, inplace=True)
                colors = sns.color_palette(palette_name, n_colors=len(bar_data))
                sns.barplot(
                    x="symbol",
                    y=y_axis_field,
                    data=bar_data,
                    palette=colors,
                    ax=ax,
                )
                # Value label above each bar; x position is the bar's rank.
                for position, value in enumerate(bar_data[y_axis_field]):
                    if pd.isna(value):
                        # e.g. no winning trades -> mean is NaN, nothing to label.
                        continue
                    label = f"{value:.2f}%" if is_ratio else f"{value:.4f}"
                    ax.text(
                        position,
                        value,
                        label,
                        ha="center",
                        va="bottom",
                        fontsize=9,
                        fontweight="bold",
                    )
                ax.set_ylabel(y_axis_field)
                ax.set_xlabel("symbol")
                ax.set_title(f"{solution} {bar}")
                if is_ratio:
                    ax.yaxis.set_major_formatter(PercentFormatter(100))
                    ax.set_ylim(0, 100)
                for tick_label in ax.get_xticklabels():
                    tick_label.set_rotation(45)
                    tick_label.set_horizontalalignment("right")

            # Hide grid cells that have no bar interval assigned.
            for unused_ax in flat_axes[len(bars_in_order):]:
                unused_ax.axis("off")

            fig.tight_layout()
            image_path = os.path.join(save_path, f"{solution}_{y_axis_field}.png")
            fig.savefig(image_path)
            plt.close(fig)
            chart_dict[sheet_name][y_axis_field] = image_path

    def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict):
        """Append chart images to an existing Excel workbook.

        One worksheet is created per key of ``charts_dict``; inside it every
        chart gets a bold title cell followed by the image, stacked vertically.

        Args:
            excel_file_path: Path of an existing ``.xlsx`` file; it is loaded,
                extended and saved back to the same path.
            charts_dict: ``{"sheet_name": {"chart_name": "chart_path"}}``.

        A failure while building one sheet is logged and the remaining sheets
        are still processed.
        """
        logger.info(f"将图表输出到{excel_file_path}")

        # Approximate pixel height of a default worksheet row:
        # 15 points per row * 1.333 pixels per point (about 20 pixels).
        pixels_per_point = 1.333
        points_per_row = 15
        pixels_per_row = points_per_row * pixels_per_point

        wb = openpyxl.load_workbook(excel_file_path)
        for sheet_name, chart_data_dict in charts_dict.items():
            try:
                ws = wb.create_sheet(title=sheet_name)
                row_offset = 1
                for chart_name, chart_path in chart_data_dict.items():
                    # Only the pixel height is needed to reserve rows.
                    with PILImage.open(chart_path) as pil_image:
                        _, height_px = pil_image.size
                    # Reserve at least 10 rows even for small images.
                    chart_rows = max(10, int(height_px / pixels_per_row))

                    title_cell = ws[f"A{row_offset}"]
                    title_cell.value = chart_name
                    title_cell.font = Font(bold=True, size=12)
                    row_offset += 2  # title row plus one spacer row

                    ws.add_image(Image(chart_path), f"A{row_offset}")
                    row_offset += chart_rows + 5  # image rows plus padding
            except Exception as e:
                logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
                continue

        wb.save(excel_file_path)
        logger.info(f"Chart saved as {excel_file_path}")
if __name__ == "__main__":
    # Script entry point: evaluate solution_3 on the 5m bar only, from the
    # fixed start date up to the current local time, with a window of 100.
    start_date = "2025-05-15 00:00:00"
    end_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    solution_list = ["solution_3"]
    mean_reversion_sandbox_main = MeanReversionSandboxMain(
        start_date=start_date,
        end_date=end_date,
        window_size=100,
        only_5m=True,
        solution_list=solution_list,
    )
    mean_reversion_sandbox_main.batch_mean_reversion_sandbox()