import core.logger as logging
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import PercentFormatter
from datetime import datetime, timezone, timedelta
from core.utils import get_current_date_time
import re
from openpyxl import Workbook
from openpyxl.drawing.image import Image
import openpyxl
from openpyxl.styles import Font
from PIL import Image as PILImage
from config import OKX_MONITOR_CONFIG
from core.trade.mean_reversion_sandbox import MeanReversionSandbox
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp

# Use a font that can render Chinese glyphs in matplotlib/seaborn charts.
plt.rcParams["font.family"] = ["SimHei"]
logger = logging.logger


class MeanReversionSandboxMain:
    """Batch-run mean-reversion sandbox backtests and export the results.

    For each configured solution, runs ``MeanReversionSandbox.trade_sandbox``
    over every (symbol, bar) pair, aggregates per-(symbol, bar) statistics,
    writes the raw trades and the statistics to an Excel workbook, renders
    bar charts of the statistics, and embeds those charts in the same workbook.
    """

    def __init__(
        self,
        start_date: str,
        end_date: str,
        window_size: int,
        only_5m: bool = False,
        solution_list: list = None,
    ):
        """
        Args:
            start_date: Backtest start, passed through to ``trade_sandbox``.
            end_date: Backtest end, passed through to ``trade_sandbox``.
            window_size: Window size, passed through to ``trade_sandbox``.
            only_5m: If True, only the "5m" bar is tested and charts use a
                single subplot. Otherwise bars come from
                ``OKX_MONITOR_CONFIG["volume_monitor"]["bars"]``.
            solution_list: Solution names to run; defaults to
                ``["solution_1", "solution_2", "solution_3"]``.
        """
        volume_monitor_config = OKX_MONITOR_CONFIG.get("volume_monitor", {})
        self.symbols = volume_monitor_config.get("symbols", ["XCH-USDT"])
        self.only_5m = only_5m
        if only_5m:
            self.bars = ["5m"]
        else:
            self.bars = volume_monitor_config.get(
                "bars", ["5m", "15m", "30m", "1H"]
            )
        if solution_list is None:
            self.solution_list = ["solution_1", "solution_2", "solution_3"]
        else:
            self.solution_list = solution_list
        self.start_date = start_date
        self.end_date = end_date
        self.window_size = window_size
        self.save_path = "./output/trade_sandbox/mean_reversion/"
        os.makedirs(self.save_path, exist_ok=True)

    def batch_mean_reversion_sandbox(self):
        """Run every solution over all (symbol, bar) pairs and export results.

        For each solution, writes
        ``<save_path>/<solution>/excel/<solution>_<timestamp>.xlsx`` with two
        sheets: "total_data" (all trades, sorted by ``buy_timestamp``) and
        "stat_data" (output of ``statistic_data``). Chart sheets are then
        appended. A solution that yields no trades is skipped; the remaining
        solutions are still processed.
        """
        logger.info("开始批量计算均值回归交易策略")
        logger.info(
            f"开始时间: {self.start_date}, 结束时间: {self.end_date}, 窗口大小: {self.window_size}"
        )
        for solution in self.solution_list:
            data_list = []
            for symbol in self.symbols:
                for bar in self.bars:
                    data = self.mean_reversion(symbol, bar, solution)
                    if data is not None and len(data) > 0:
                        data_list.append(data)
            if not data_list:
                # Skip only this solution so the remaining ones still run.
                logger.warning(f"策略 {solution} 无交易数据, 跳过")
                continue

            total_data = pd.concat(data_list)
            total_data.sort_values(by="buy_timestamp", ascending=True, inplace=True)
            total_data.reset_index(drop=True, inplace=True)
            stat_data = self.statistic_data(total_data)

            excel_save_path = os.path.join(self.save_path, solution, "excel")
            os.makedirs(excel_save_path, exist_ok=True)
            # NOTE(review): __main__ uses get_current_date_time() as an end_date
            # alongside "YYYY-MM-DD HH:MM:SS" strings, so it likely contains ':'
            # and spaces, which are illegal in Windows file names. Replace any
            # path-unsafe characters.
            date_time_str = re.sub(
                r'[\\/:*?"<>|\s]+', "_", str(get_current_date_time())
            )
            excel_file_path = os.path.join(
                excel_save_path, f"{solution}_{date_time_str}.xlsx"
            )
            with pd.ExcelWriter(excel_file_path) as writer:
                total_data.to_excel(writer, sheet_name="total_data", index=False)
                stat_data.to_excel(writer, sheet_name="stat_data", index=False)

            chart_dict = {}
            self.draw_chart(stat_data, chart_dict)
            self.output_chart_to_excel(excel_file_path, chart_dict)

    def mean_reversion(self, symbol: str, bar: str, solution: str):
        """Run the mean-reversion sandbox for a single (symbol, bar, solution).

        Returns:
            Whatever ``MeanReversionSandbox.trade_sandbox`` returns. The caller
            treats it as a DataFrame of trades, or None / empty when there are
            no trades.
        """
        mean_reversion_sandbox = MeanReversionSandbox(solution)
        data = mean_reversion_sandbox.trade_sandbox(
            symbol, bar, self.window_size, self.start_date, self.end_date
        )
        return data

    def statistic_data(self, data: pd.DataFrame):
        """Aggregate trade statistics per (symbol, bar).

        Args:
            data: Non-empty trades frame with at least the columns "symbol",
                "bar", "solution", "sell_type" and "profit_pct".

        Returns:
            One row per (symbol, bar), sorted by bar then symbol. Counts cover
            take-profit ("止盈") and stop-loss ("止损") exits and winning and
            losing trades. The ``*_ratio`` columns are percentages (0-100)
            rounded to 4 decimals. Also includes sum/mean/max/min of
            ``profit_pct`` and the mean over winners and over losers. A mean
            over an empty subset (e.g. no losing trades) is NaN.
        """

        def _pct(count, total):
            # Share of `count` in `total`, as a percentage rounded to 4 dp.
            return round(count / total * 100, 4)

        data_list = []
        for (symbol_name, bar_name), group in data.groupby(["symbol", "bar"]):
            total = len(group)
            solution = group["solution"].iloc[0]
            profit_pct = group["profit_pct"]
            winners = profit_pct[profit_pct > 0]
            losers = profit_pct[profit_pct < 0]

            take_profit_count = int((group["sell_type"] == "止盈").sum())
            stop_loss_count = int((group["sell_type"] == "止损").sum())
            profit_pct_gt_0_count = len(winners)
            profit_pct_lt_0_count = len(losers)

            profit_pct_max = profit_pct.max()
            profit_pct_min = profit_pct.min()
            profit_pct_mean = profit_pct.mean()
            profit_pct_sum = profit_pct.sum()
            profit_pct_gt_0_mean = winners.mean()
            profit_pct_lt_0_mean = losers.mean()

            logger.info(
                f"策略: {solution}, symbol: {symbol_name}, bar: {bar_name}, profit_pct>0的次数: {profit_pct_gt_0_count}, profit_pct<0的次数: {profit_pct_lt_0_count}, profit_pct最大值: {profit_pct_max}, profit_pct最小值: {profit_pct_min}, profit_pct平均值: {profit_pct_mean}, profit_pct>0的平均值: {profit_pct_gt_0_mean}, profit_pct<0的平均值: {profit_pct_lt_0_mean}"
            )
            data_list.append(
                {
                    "solution": solution,
                    "symbol": symbol_name,
                    "bar": bar_name,
                    "profit_pct_sum": profit_pct_sum,
                    "take_profit_count": take_profit_count,
                    "take_profit_ratio": _pct(take_profit_count, total),
                    "stop_loss_count": stop_loss_count,
                    "stop_loss_ratio": _pct(stop_loss_count, total),
                    "profit_pct_gt_0_count": profit_pct_gt_0_count,
                    "profit_pct_gt_0_ratio": _pct(profit_pct_gt_0_count, total),
                    "profit_pct_lt_0_count": profit_pct_lt_0_count,
                    "profit_pct_lt_0_ratio": _pct(profit_pct_lt_0_count, total),
                    "profit_pct_mean": profit_pct_mean,
                    "profit_pct_max": profit_pct_max,
                    "profit_pct_min": profit_pct_min,
                    "profit_pct_gt_0_mean": profit_pct_gt_0_mean,
                    "profit_pct_lt_0_mean": profit_pct_lt_0_mean,
                }
            )
        stat_data = pd.DataFrame(data_list)
        stat_data.sort_values(by=["bar", "symbol"], inplace=True)
        stat_data.reset_index(drop=True, inplace=True)
        return stat_data

    def draw_chart(self, stat_data: pd.DataFrame, chart_dict: dict):
        """Render one PNG per statistic field and record the file paths.

        For each field in ``y_axis_fields``, draws one bar subplot per bar
        interval (x = symbol, sorted descending by the field, blue gradient),
        labels each bar with its value, and saves the figure to
        ``<save_path>/<solution>/chart/<solution>_<field>.png``. Ratio fields
        use a 0-100% y axis. The layout is 1x1 when ``only_5m`` is set;
        otherwise it is two columns with at least two rows, adding rows when
        there are more than four bars.

        Args:
            stat_data: Output of ``statistic_data`` for a single solution. The
                solution name is read from the first row.
            chart_dict: Mutated in place. Gains
                ``{"<solution>_chart": {field: png_path, ...}}``.
        """
        sns.set_theme(style="whitegrid")
        plt.rcParams["font.sans-serif"] = ["SimHei"]  # font that renders Chinese
        plt.rcParams["font.size"] = 11
        # Use ASCII '-' for negative tick labels instead of U+2212.
        plt.rcParams["axes.unicode_minus"] = False
        plt.rcParams["figure.dpi"] = 150
        plt.rcParams["savefig.dpi"] = 150

        solution = stat_data["solution"].iloc[0]
        save_path = os.path.join(self.save_path, solution, "chart")
        os.makedirs(save_path, exist_ok=True)

        # Order subplots by the configured bar list; fall back to data order.
        present_bars = set(stat_data["bar"].unique())
        bars_in_order = [b for b in getattr(self, "bars", []) if b in present_bars]
        if not bars_in_order:
            bars_in_order = list(stat_data["bar"].unique())

        palette_name = "Blues_d"
        y_axis_fields = [
            "take_profit_ratio",
            "stop_loss_ratio",
            "profit_pct_sum",
            "profit_pct_mean",
            "profit_pct_gt_0_mean",
            "profit_pct_lt_0_mean",
        ]
        sheet_name = f"{solution}_chart"
        chart_dict[sheet_name] = {}

        n_bars = len(bars_in_order)
        if self.only_5m and n_bars <= 1:
            nrows, ncols = 1, 1
        else:
            # Keep the 2x2 grid for up to four bars; grow rows beyond that
            # instead of indexing past the grid.
            ncols = 2
            nrows = max(2, (n_bars + 1) // 2)
        figsize = (10, max(10, 5 * nrows))

        for y_axis_field in y_axis_fields:
            # squeeze=False always returns a 2-D array of axes.
            fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
            try:
                for j, bar in enumerate(bars_in_order):
                    ax = axs[j // ncols, j % ncols]
                    bar_data = (
                        stat_data[stat_data["bar"] == bar]
                        .sort_values(by=y_axis_field, ascending=False)
                        .reset_index(drop=True)
                    )
                    colors = sns.color_palette(palette_name, n_colors=len(bar_data))
                    # NOTE(review): seaborn >= 0.13 warns when `palette` is
                    # passed without `hue`; kept for compatibility with older
                    # seaborn releases.
                    sns.barplot(
                        x="symbol",
                        y=y_axis_field,
                        data=bar_data,
                        palette=colors,
                        ax=ax,
                    )
                    # Value label above each bar; bar i is the i-th row of bar_data.
                    for i, value in enumerate(bar_data[y_axis_field]):
                        if pd.isna(value):
                            # e.g. the mean of winners when there were none.
                            continue
                        if "ratio" in y_axis_field:
                            label = f"{value:.2f}%"
                        else:
                            label = f"{value:.4f}"
                        ax.text(
                            i,
                            value,
                            label,
                            ha="center",
                            va="bottom",
                            fontsize=9,
                            fontweight="bold",
                        )
                    ax.set_ylabel(y_axis_field)
                    ax.set_xlabel("symbol")
                    ax.set_title(f"{solution} {bar}")
                    if "ratio" in y_axis_field:
                        ax.yaxis.set_major_formatter(PercentFormatter(100))
                        ax.set_ylim(0, 100)
                    for tick_label in ax.get_xticklabels():
                        tick_label.set_rotation(45)
                        tick_label.set_horizontalalignment("right")

                # Hide any grid cells that have no bar assigned.
                for k in range(n_bars, nrows * ncols):
                    axs[k // ncols, k % ncols].axis("off")

                fig.tight_layout()
                file_path = os.path.join(save_path, f"{solution}_{y_axis_field}.png")
                fig.savefig(file_path)
            finally:
                plt.close(fig)
            chart_dict[sheet_name][y_axis_field] = file_path

    def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict):
        """Embed chart images into an existing workbook, one sheet per key.

        Each chart gets a bold title cell in column A followed by the image.
        Charts are stacked vertically with padding rows between them. Failures
        on one sheet are logged and do not stop the other sheets; the workbook
        is saved back to ``excel_file_path``.

        Args:
            excel_file_path: Path of an existing .xlsx file to extend.
            charts_dict: ``{"sheet_name": {"chart_name": "chart_path", ...}}``.
        """
        logger.info(f"将图表输出到{excel_file_path}")
        wb = openpyxl.load_workbook(excel_file_path)
        # Approximate pixel height of a default Excel row:
        # 15 pt/row * 1.333 px/pt ≈ 20 px.
        pixels_per_row = 15 * 1.333
        for sheet_name, chart_data_dict in charts_dict.items():
            try:
                ws = wb.create_sheet(title=sheet_name)
                row_offset = 1
                for chart_name, chart_path in chart_data_dict.items():
                    with PILImage.open(chart_path) as pil_img:
                        _, height_px = pil_img.size
                    # Number of rows the image covers (at least 10), so the next
                    # chart starts below it.
                    chart_rows = max(10, int(height_px / pixels_per_row))

                    ws[f"A{row_offset}"] = chart_name
                    ws[f"A{row_offset}"].font = Font(bold=True, size=12)
                    row_offset += 2  # title row + one blank row

                    ws.add_image(Image(chart_path), f"A{row_offset}")
                    row_offset += chart_rows + 5  # image rows + 5 padding rows
            except Exception as e:
                # Best effort: log and move on to the next sheet.
                logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
        wb.save(excel_file_path)
        logger.info(f"Chart saved as {excel_file_path}")


if __name__ == "__main__":
    start_date = "2025-05-15 00:00:00"
    end_date = get_current_date_time()
    solution_list = ["solution_3"]
    mean_reversion_sandbox_main = MeanReversionSandboxMain(
        start_date=start_date,
        end_date=end_date,
        window_size=100,
        only_5m=True,
        solution_list=solution_list,
    )
    mean_reversion_sandbox_main.batch_mean_reversion_sandbox()