310 lines
13 KiB
Python
310 lines
13 KiB
Python
import core.logger as logging
|
||
import os
|
||
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
from matplotlib.ticker import PercentFormatter
|
||
from datetime import datetime
|
||
import re
|
||
from openpyxl import Workbook
|
||
from openpyxl.drawing.image import Image
|
||
import openpyxl
|
||
from openpyxl.styles import Font
|
||
from PIL import Image as PILImage
|
||
from config import OKX_MONITOR_CONFIG
|
||
from core.trade.mean_reversion_sandbox import MeanReversionSandbox
|
||
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
|
||
|
||
logger = logging.logger

# Use the SimHei font so that Chinese text in seaborn/matplotlib charts is supported.
plt.rcParams.update({"font.family": ["SimHei"]})
|
||
|
||
|
||
class MeanReversionSandboxMain:
    """Batch driver for the mean-reversion trading sandbox.

    For every configured solution it back-tests every (symbol, bar) pair via
    ``MeanReversionSandbox``, aggregates per-pair statistics, writes the raw
    trades and the statistics to an Excel workbook, renders bar charts of the
    statistics and embeds those charts into the same workbook.
    """

    def __init__(
        self,
        start_date: str,
        end_date: str,
        window_size: int,
        only_5m: bool = False,
        solution_list: list = None,
    ):
        """Configure the batch run.

        Args:
            start_date: Back-test start, passed through unchanged to
                ``MeanReversionSandbox.trade_sandbox``.
            end_date: Back-test end, passed through unchanged as well.
            window_size: Window size passed through to the sandbox.
            only_5m: When True only the ``"5m"`` bar is evaluated; otherwise
                the bars come from ``OKX_MONITOR_CONFIG["volume_monitor"]["bars"]``
                (default ``["5m", "15m", "30m", "1H"]``).
            solution_list: Solutions to evaluate; defaults to
                ``["solution_1", "solution_2", "solution_3"]``.
        """
        volume_monitor_config = OKX_MONITOR_CONFIG.get("volume_monitor", {})
        self.symbols = volume_monitor_config.get("symbols", ["XCH-USDT"])
        self.only_5m = only_5m
        if only_5m:
            self.bars = ["5m"]
        else:
            self.bars = volume_monitor_config.get(
                "bars", ["5m", "15m", "30m", "1H"]
            )
        if solution_list is None:
            self.solution_list = ["solution_1", "solution_2", "solution_3"]
        else:
            self.solution_list = solution_list
        self.start_date = start_date
        self.end_date = end_date
        self.window_size = window_size
        self.save_path = "./output/trade_sandbox/mean_reversion/"
        os.makedirs(self.save_path, exist_ok=True)

    def batch_mean_reversion_sandbox(self):
        """Run the mean-reversion back-test in batch and export the results.

        For each solution in ``self.solution_list``, the trades returned by
        :meth:`mean_reversion` for every (symbol, bar) pair are concatenated
        and sorted by ``buy_timestamp``. They are then written, together with
        the per-pair statistics, to
        ``<save_path>/<solution>/excel/<solution>_<YYYYmmddHHMMSS>.xlsx``
        (sheets ``total_data`` and ``stat_data``). Finally the statistic
        charts are drawn and embedded into that workbook.

        A solution that produced no trades is skipped with a warning; the
        remaining solutions are still processed.

        Returns:
            None.
        """
        logger.info("开始批量计算均值回归交易策略")
        logger.info(
            f"开始时间: {self.start_date}, 结束时间: {self.end_date}, 窗口大小: {self.window_size}"
        )

        for solution in self.solution_list:
            data_list = []
            for symbol in self.symbols:
                for bar in self.bars:
                    data = self.mean_reversion(symbol, bar, solution)
                    if data is not None and len(data) > 0:
                        data_list.append(data)

            if not data_list:
                # Skip only this solution; returning here would silently drop
                # every solution that comes after it.
                logger.warning(f"策略 {solution} 没有交易数据, 跳过")
                continue

            total_data = pd.concat(data_list)
            total_data.sort_values(by="buy_timestamp", ascending=True, inplace=True)
            total_data.reset_index(drop=True, inplace=True)
            stat_data = self.statistic_data(total_data)

            excel_save_path = os.path.join(self.save_path, solution, "excel")
            os.makedirs(excel_save_path, exist_ok=True)
            date_time_str = datetime.now().strftime("%Y%m%d%H%M%S")
            excel_file_path = os.path.join(
                excel_save_path, f"{solution}_{date_time_str}.xlsx"
            )
            with pd.ExcelWriter(excel_file_path) as writer:
                total_data.to_excel(writer, sheet_name="total_data", index=False)
                stat_data.to_excel(writer, sheet_name="stat_data", index=False)

            chart_dict = {}
            self.draw_chart(stat_data, chart_dict)
            self.output_chart_to_excel(excel_file_path, chart_dict)

    def mean_reversion(self, symbol: str, bar: str, solution: str):
        """Run the mean-reversion trading sandbox for one (symbol, bar, solution).

        Returns:
            Whatever ``MeanReversionSandbox(solution).trade_sandbox(...)``
            returns. The caller treats it as a DataFrame of trades that may
            be None or empty.
        """
        mean_reversion_sandbox = MeanReversionSandbox(solution)
        data = mean_reversion_sandbox.trade_sandbox(
            symbol, bar, self.window_size, self.start_date, self.end_date
        )
        return data

    def statistic_data(self, data: pd.DataFrame) -> pd.DataFrame:
        """Aggregate trade results per (symbol, bar) pair.

        Each group records:

        - the number and ratio of take-profit ("止盈") exits;
        - the number and ratio of stop-loss ("止损") exits;
        - the number and ratio of trades with ``profit_pct`` > 0 and with
          ``profit_pct`` < 0;
        - the sum, mean, max and min of ``profit_pct``;
        - the mean of the positive and of the negative ``profit_pct`` values.

        Args:
            data: Trades; must contain the columns ``symbol``, ``bar``,
                ``solution``, ``sell_type`` and ``profit_pct``.

        Returns:
            One row per (symbol, bar), sorted by ``bar`` then ``symbol``.
            Ratios are percentages (0-100) rounded to 4 decimals. A mean over
            an empty subset (e.g. no losing trades) is NaN. If there are no
            groups, an empty DataFrame is returned.
        """
        data_list = []
        # Grouping by two keys yields ((symbol, bar), group_frame) pairs.
        for (symbol_name, bar_name), group in data.groupby(["symbol", "bar"]):
            total = len(group)
            solution = group["solution"].iloc[0]
            sell_type = group["sell_type"]
            profit_pct = group["profit_pct"]
            gains = profit_pct[profit_pct > 0]
            losses = profit_pct[profit_pct < 0]

            # Take-profit exits
            take_profit_count = int((sell_type == "止盈").sum())
            take_profit_ratio = round((take_profit_count / total) * 100, 4)
            # Stop-loss exits
            stop_loss_count = int((sell_type == "止损").sum())
            stop_loss_ratio = round((stop_loss_count / total) * 100, 4)

            profit_pct_gt_0_count = len(gains)
            profit_pct_gt_0_ratio = round((profit_pct_gt_0_count / total) * 100, 4)
            profit_pct_lt_0_count = len(losses)
            profit_pct_lt_0_ratio = round((profit_pct_lt_0_count / total) * 100, 4)
            profit_pct_max = profit_pct.max()
            profit_pct_min = profit_pct.min()
            profit_pct_mean = profit_pct.mean()
            profit_pct_sum = profit_pct.sum()
            profit_pct_gt_0_mean = gains.mean()
            profit_pct_lt_0_mean = losses.mean()

            logger.info(
                f"策略: {solution}, symbol: {symbol_name}, bar: {bar_name}, profit_pct>0的次数: {profit_pct_gt_0_count}, profit_pct<0的次数: {profit_pct_lt_0_count}, profit_pct最大值: {profit_pct_max}, profit_pct最小值: {profit_pct_min}, profit_pct平均值: {profit_pct_mean}, profit_pct>0的平均值: {profit_pct_gt_0_mean}, profit_pct<0的平均值: {profit_pct_lt_0_mean}"
            )
            data_list.append(
                {
                    "solution": solution,
                    "symbol": symbol_name,
                    "bar": bar_name,
                    "profit_pct_sum": profit_pct_sum,
                    "take_profit_count": take_profit_count,
                    "take_profit_ratio": take_profit_ratio,
                    "stop_loss_count": stop_loss_count,
                    "stop_loss_ratio": stop_loss_ratio,
                    "profit_pct_gt_0_count": profit_pct_gt_0_count,
                    "profit_pct_gt_0_ratio": profit_pct_gt_0_ratio,
                    "profit_pct_lt_0_count": profit_pct_lt_0_count,
                    "profit_pct_lt_0_ratio": profit_pct_lt_0_ratio,
                    "profit_pct_mean": profit_pct_mean,
                    "profit_pct_max": profit_pct_max,
                    "profit_pct_min": profit_pct_min,
                    "profit_pct_gt_0_mean": profit_pct_gt_0_mean,
                    "profit_pct_lt_0_mean": profit_pct_lt_0_mean,
                }
            )

        stat_data = pd.DataFrame(data_list)
        if stat_data.empty:
            # No columns exist yet, so sorting by "bar"/"symbol" would raise.
            return stat_data
        stat_data.sort_values(by=["bar", "symbol"], inplace=True)
        stat_data.reset_index(drop=True, inplace=True)
        return stat_data

    def draw_chart(self, stat_data: pd.DataFrame, chart_dict: dict):
        """Draw one bar-chart PNG per statistic field for a single solution.

        For each field in ``y_axis_fields`` a figure is created with one
        subplot per bar (symbol on the x axis, the field on the y axis, blue
        gradient palette). Each bar is annotated with its value. Ratio fields
        are shown as percentages on a 0-100 axis. Layouts:

        - ``only_5m`` with a single bar: 1x1;
        - otherwise: two columns and at least two rows, with more rows when
          more bars exist.

        Files are written to ``<save_path>/<solution>/chart/<solution>_<field>.png``.

        Args:
            stat_data: Output of :meth:`statistic_data` for one solution.
            chart_dict: Mutated in place. ``chart_dict["<solution>_chart"]``
                is set to ``{field: png_path}``, which is the input format of
                :meth:`output_chart_to_excel`.
        """
        if stat_data.empty:
            logger.warning("统计数据为空, 不绘制图表")
            return

        sns.set_theme(style="whitegrid")
        plt.rcParams["font.sans-serif"] = ["SimHei"]  # font name works directly
        plt.rcParams["font.size"] = 11
        plt.rcParams["axes.unicode_minus"] = False
        plt.rcParams["figure.dpi"] = 150
        plt.rcParams["savefig.dpi"] = 150

        solution = stat_data["solution"].iloc[0]
        save_path = os.path.join(self.save_path, solution, "chart")
        os.makedirs(save_path, exist_ok=True)

        # Keep the configured bar order; fall back to data order if none match.
        available_bars = set(stat_data["bar"].unique())
        bars_in_order = [b for b in getattr(self, "bars", []) if b in available_bars]
        if not bars_in_order:
            bars_in_order = list(stat_data["bar"].unique())

        n_plots = len(bars_in_order)
        if self.only_5m and n_plots == 1:
            nrows, ncols = 1, 1
        else:
            # Two columns, at least 2x2; grow rows instead of indexing past the grid.
            ncols = 2
            nrows = max(2, (n_plots + 1) // 2)
        figsize = (10, 5 * max(nrows, 2))

        palette_name = "Blues_d"
        y_axis_fields = [
            "take_profit_ratio",
            "stop_loss_ratio",
            "profit_pct_sum",
            "profit_pct_mean",
            "profit_pct_gt_0_mean",
            "profit_pct_lt_0_mean",
        ]
        sheet_name = f"{solution}_chart"
        chart_dict[sheet_name] = {}

        for y_axis_field in y_axis_fields:
            is_ratio = "ratio" in y_axis_field
            # squeeze=False always returns a 2-D array of Axes.
            fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
            try:
                for j, bar in enumerate(bars_in_order):
                    ax = axs[j // ncols, j % ncols]
                    bar_data = stat_data[stat_data["bar"] == bar].copy()
                    bar_data.sort_values(by=y_axis_field, ascending=False, inplace=True)
                    bar_data.reset_index(drop=True, inplace=True)
                    colors = sns.color_palette(palette_name, n_colors=len(bar_data))
                    sns.barplot(
                        x="symbol",
                        y=y_axis_field,
                        data=bar_data,
                        palette=colors,
                        ax=ax,
                    )

                    # Value labels above each bar; categorical x positions are
                    # 0..n-1 in row order.
                    for position, value in enumerate(bar_data[y_axis_field]):
                        if pd.isna(value):
                            # e.g. no losing trades -> mean is NaN; nothing to place.
                            continue
                        label = f"{value:.2f}%" if is_ratio else f"{value:.4f}"
                        ax.text(
                            position,
                            value,
                            label,
                            ha="center",
                            va="bottom",
                            fontsize=9,
                            fontweight="bold",
                        )

                    ax.set_ylabel(y_axis_field)
                    ax.set_xlabel("symbol")
                    ax.set_title(f"{solution} {bar}")
                    if is_ratio:
                        ax.yaxis.set_major_formatter(PercentFormatter(100))
                        ax.set_ylim(0, 100)
                    for tick_label in ax.get_xticklabels():
                        tick_label.set_rotation(45)
                        tick_label.set_horizontalalignment("right")

                # Hide the unused subplots
                for unused_ax in axs.flat[n_plots:]:
                    unused_ax.axis("off")

                fig.tight_layout()
                chart_path = os.path.join(save_path, f"{solution}_{y_axis_field}.png")
                fig.savefig(chart_path)
            finally:
                # Release the figure even when plotting or saving fails.
                plt.close(fig)
            chart_dict[sheet_name][y_axis_field] = chart_path

    def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict):
        """Append chart images to an existing Excel workbook.

        Each outer key of ``charts_dict`` becomes a new sheet. Each chart is
        written as a bold title cell followed by the image, stacked vertically.
        A failure on one sheet is logged and the next sheet is attempted.

        Args:
            excel_file_path: Existing ``.xlsx`` file; it is modified in place.
            charts_dict: ``{"sheet_name": {"chart_name": "chart_path"}}``.
        """
        logger.info(f"将图表输出到{excel_file_path}")

        # Open the existing workbook
        wb = openpyxl.load_workbook(excel_file_path)

        # Approximate pixel -> row conversion: 1 point ~= 1.333 px and the
        # default row is 15 pt, so ~20 px per row.
        pixels_per_point = 1.333
        points_per_row = 15
        pixels_per_row = points_per_row * pixels_per_point

        for sheet_name, chart_data_dict in charts_dict.items():
            try:
                ws = wb.create_sheet(title=sheet_name)
                row_offset = 1
                for chart_name, chart_path in chart_data_dict.items():
                    # Read the image height to reserve enough rows below it.
                    with PILImage.open(chart_path) as pil_img:
                        _, height_px = pil_img.size
                    chart_rows = max(10, int(height_px / pixels_per_row))  # min 10 rows

                    # Chart title (openpyxl stores str cells natively, Chinese included)
                    ws[f"A{row_offset}"] = chart_name
                    ws[f"A{row_offset}"].font = Font(bold=True, size=12)
                    row_offset += 2  # title row + spacing

                    ws.add_image(Image(chart_path), f"A{row_offset}")
                    row_offset += chart_rows + 5  # image height + padding
            except Exception as e:
                logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")

        wb.save(excel_file_path)
        logger.info(f"Chart saved as {excel_file_path}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Back-test solution_3 on the 5m bar, from a fixed start date up to now.
    runner = MeanReversionSandboxMain(
        start_date="2025-05-15 00:00:00",
        end_date=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        window_size=100,
        only_5m=True,
        solution_list=["solution_3"],
    )
    runner.batch_mean_reversion_sandbox()
|