begin to research trade strategy

blade 2025-08-20 11:33:13 +08:00
parent d4291c06a3
commit 1e204839a9
6 changed files with 822 additions and 12 deletions

View File

@ -0,0 +1,130 @@
import pandas as pd
from core.db.db_market_data import DBMarketData
from core.db.db_huge_volume_data import DBHugeVolumeData
class DBMergeMarketHugeVolume:
def __init__(self, db_url: str):
self.db_url = db_url
self.db_market_data = DBMarketData(self.db_url)
self.db_huge_volume_data = DBHugeVolumeData(self.db_url)
def merge_market_huge_volume(
self, symbol: str, bar: str, window_size: int, start: str, end: str
):
market_data = self.db_market_data.query_market_data_by_symbol_bar(
symbol, bar, start, end
)
huge_volume_data = (
self.db_huge_volume_data.query_huge_volume_data_by_symbol_bar_window_size(
symbol, bar, window_size, start, end
)
)
if market_data is None or huge_volume_data is None:
return None
market_data = pd.DataFrame(market_data)
huge_volume_data = pd.DataFrame(huge_volume_data)
market_data = market_data.merge(
huge_volume_data, on=["symbol", "bar", "timestamp"], how="left"
)
# drop id_x, date_time_x, open_x, high_x, low_x, close_x,
# volume_x, volCcy_x, volCCyQuote_x, buy_sz, sell_sz, create_time_x
market_data.drop(
columns=[
"id_x",
"date_time_x",
"open_x",
"high_x",
"low_x",
"close_x",
"volume_x",
"volCcy_x",
"volCCyQuote_x",
"buy_sz",
"sell_sz",
"create_time_x",
],
inplace=True,
)
market_data.rename(
columns={
"id_y": "id",
"date_time_y": "date_time",
"open_y": "open",
"high_y": "high",
"low_y": "low",
"close_y": "close",
"volume_y": "volume",
"volCcy_y": "volCcy",
"volCCyQuote_y": "volCCyQuote",
"create_time_y": "create_time",
},
inplace=True,
)
# keep below columns: id, symbol, bar, timestamp, date_time, window_size, open, high, low, close, pct_chg, volume, volCcy, volCCyQuote, volume_ma, huge_volume, volume_ratio,
# macd, macd_signal, macd_divergence, kdj_k, kdj_d, kdj_j, kdj_signal, kdj_pattern, ma5, ma10, ma20, ma30, ma_cross, ma_long_short, ma_divergence, rsi_14, rsi_signal,
# boll_upper, boll_middle, boll_lower, boll_signal, boll_pattern, k_length, k_shape, k_up_down,
# close_80_high, close_90_high, close_20_low, close_10_low,
# high_80_high, high_90_high, low_20_low, low_10_low
# create_time
market_data = market_data[
[
"id",
"symbol",
"bar",
"timestamp",
"date_time",
"window_size",
"open",
"high",
"low",
"close",
"pct_chg",
"volume",
"volCcy",
"volCCyQuote",
"volume_ma",
"huge_volume",
"volume_ratio",
"macd",
"macd_signal",
"macd_divergence",
"kdj_k",
"kdj_d",
"kdj_j",
"kdj_signal",
"kdj_pattern",
"ma5",
"ma10",
"ma20",
"ma30",
"ma_cross",
"ma_long_short",
"ma_divergence",
"rsi_14",
"rsi_signal",
"boll_upper",
"boll_middle",
"boll_lower",
"boll_signal",
"boll_pattern",
"k_length",
"k_shape",
"k_up_down",
"close_80_high",
"close_90_high",
"close_20_low",
"close_10_low",
"high_80_high",
"high_90_high",
"low_20_low",
"low_10_low",
"create_time",
]
]
market_data.sort_values(by="timestamp", ascending=True, inplace=True)
market_data.reset_index(drop=True, inplace=True)
return market_data
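
A minimal usage sketch of the helper above. The connection string, symbol and date range are placeholders; it assumes both the market-data and huge-volume tables are already populated for the requested window.

# Illustrative only: credentials, symbol and dates below are placeholders.
from core.db.db_merge_market_huge_volume import DBMergeMarketHugeVolume

db_url = "mysql+pymysql://user:password@localhost:3306/okx"  # placeholder DSN
merger = DBMergeMarketHugeVolume(db_url)
merged = merger.merge_market_huge_volume(
    symbol="XCH-USDT",
    bar="5m",
    window_size=100,
    start="2025-05-15 00:00:00",
    end="2025-08-19 00:00:00",
)
if merged is not None:
    # one row per candle: market data joined with the huge-volume indicators
    print(merged[["date_time", "close", "huge_volume", "volume_ratio"]].head())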

View File

@ -121,12 +121,25 @@ class MaBreakStatistics:
(market_data["ma20"].notna()) & (market_data["ma20"].notna()) &
(market_data["ma30"].notna())] (market_data["ma30"].notna())]
logger.info(f"ma5, ma10, ma20, ma30不为空的行数据条数: {len(market_data)}") logger.info(f"ma5, ma10, ma20, ma30不为空的行数据条数: {len(market_data)}")
# 获得5上穿10且ma5 > ma10 > ma20 > ma30且close > ma20的行 # 计算volume_ma5
long_market_data = market_data[(market_data["ma_cross"] == "5上穿10") & (market_data["ma5"] > market_data["ma10"]) & market_data["volume_ma5"] = market_data["volume"].rolling(window=5).mean()
(market_data["ma10"] > market_data["ma20"]) & # 获得5上穿10且ma5 > ma10 > ma20 > ma30且close > ma20的行,成交量较前5日均量放大20%以上
(market_data["ma20"] > market_data["ma30"]) & market_data["volume_pct_chg"] = (market_data["volume"] - market_data["volume_ma5"]) / market_data["volume_ma5"]
(market_data["close"] > market_data["ma20"])] market_data["volume_pct_chg"] = market_data["volume_pct_chg"].fillna(0)
logger.info(f"5上穿10, 且ma5 > ma10 > ma20 > ma30并且close > ma20的行数据条数: {len(long_market_data)}")
if all_change:
long_market_data = market_data[(market_data["ma_cross"] == "5上穿10") & (market_data["ma5"] > market_data["ma10"]) &
(market_data["ma10"] > market_data["ma20"]) &
(market_data["ma20"] > market_data["ma30"]) &
(market_data["close"] > market_data["ma20"]) &
(market_data["volume_pct_chg"] > 0.2)]
logger.info(f"5上穿10, 且ma5 > ma10 > ma20 > ma30并且close > ma20并且成交量较前5日均量放大20%以上的行,数据条数: {len(long_market_data)}")
else:
long_market_data = market_data[(market_data["ma_cross"] == "5上穿10") & (market_data["ma5"] > market_data["ma10"]) &
(market_data["volume_pct_chg"] > 0.2)]
logger.info(f"5上穿10, 且ma5 > ma10并且成交量较前5日均量放大20%以上的行,数据条数: {len(long_market_data)}")
if len(long_market_data) == 0:
return None
if all_change: if all_change:
# 获得ma5 < ma10 < ma20 < ma30的行 # 获得ma5 < ma10 < ma20 < ma30的行
short_market_data = market_data[(market_data["ma5"] < market_data["ma10"]) & short_market_data = market_data[(market_data["ma5"] < market_data["ma10"]) &
@ -134,10 +147,10 @@ class MaBreakStatistics:
(market_data["ma20"] < market_data["ma30"])] (market_data["ma20"] < market_data["ma30"])]
logger.info(f"ma5 < ma10 < ma20 < ma30的行数据条数: {len(short_market_data)}") logger.info(f"ma5 < ma10 < ma20 < ma30的行数据条数: {len(short_market_data)}")
else: else:
# ma5 < ma10 and close < ma20 # ma5 < ma10 or close < ma20
short_market_data = market_data[(market_data["ma5"] < market_data["ma10"]) & short_market_data = market_data[(market_data["ma5"] < market_data["ma10"]) |
(market_data["close"] < market_data["ma20"])] (market_data["close"] < market_data["ma20"])]
logger.info(f"ma5 < ma10 and close < ma20的行数据条数: {len(short_market_data)}") logger.info(f"ma5 < ma10 or close < ma20的行数据条数: {len(short_market_data)}")
# concat long_market_data和short_market_data # concat long_market_data和short_market_data
ma_break_market_data = pd.concat([long_market_data, short_market_data]) ma_break_market_data = pd.concat([long_market_data, short_market_data])
# 按照timestamp排序 # 按照timestamp排序
@ -156,7 +169,12 @@ class MaBreakStatistics:
 ma30 = row["ma30"]
 if pd.notna(ma_cross) and ma_cross is not None:
     ma_cross = str(ma_cross)
-if ma_cross == "5上穿10" and (ma5 > ma10 and ma10 > ma20 and ma20 > ma30) and (close > ma20):
+buy_condition = False
+if all_change:
+    buy_condition = (ma5 > ma10 and ma10 > ma20 and ma20 > ma30) and (close > ma20)
+else:
+    buy_condition = ma_cross == "5上穿10" and (ma5 > ma10)
+if buy_condition:
     ma_break_market_data_pair = {}
     ma_break_market_data_pair["symbol"] = symbol
     ma_break_market_data_pair["bar"] = bar
@ -172,7 +190,7 @@ class MaBreakStatistics:
     change_condition = (ma5 < ma10 and ma10 < ma20 and ma20 < ma30)
 else:
     # change_condition = (ma5 < ma10 or ma10 < ma20 or ma20 < ma30)
-    change_condition = (ma5 < ma10) and (close < ma20)
+    change_condition = (ma5 < ma10) or (close < ma20)
 if change_condition:
     if ma_break_market_data_pair.get("begin_timestamp", None) is None:
@ -232,7 +250,10 @@ class MaBreakStatistics:
 plt.xticks(rotation=45, ha="right")
 plt.tight_layout()
-save_path = os.path.join(self.stats_chart_dir, f"{bar}_ma_break_pct_chg_mean.png")
+if all_change:
+    save_path = os.path.join(self.stats_chart_dir, f"{bar}_ma_break_pct_chg_mean_all_change.png")
+else:
+    save_path = os.path.join(self.stats_chart_dir, f"{bar}_ma_break_pct_chg_mean_part_change.png")
 plt.savefig(save_path, dpi=150)
 plt.close()
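
For reference, a self-contained sketch of the volume-expansion filter this hunk adds (volume at least 20% above the trailing 5-bar average). The data below is synthetic and only illustrates the rolling comparison; note that, as in the committed code, the 5-bar mean includes the current bar.

# Synthetic data, illustrative only.
import pandas as pd

df = pd.DataFrame({"volume": [100, 110, 90, 120, 105, 180, 95]})
df["volume_ma5"] = df["volume"].rolling(window=5).mean()  # includes the current bar
df["volume_pct_chg"] = ((df["volume"] - df["volume_ma5"]) / df["volume_ma5"]).fillna(0)
expanded = df[df["volume_pct_chg"] > 0.2]  # bars with volume 20%+ above the 5-bar average
print(expanded)  # only the 180-volume bar passes the filter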

View File

@ -0,0 +1,381 @@
import json
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import re
from openpyxl import Workbook
from openpyxl.drawing.image import Image
import openpyxl
from openpyxl.styles import Font
from PIL import Image as PILImage
from config import MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE
import core.logger as logging
from core.db.db_merge_market_huge_volume import DBMergeMarketHugeVolume
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
# Chinese font support for matplotlib / seaborn charts
plt.rcParams["font.family"] = ["SimHei"]
logger = logging.logger
class MeanReversionSandbox:
def __init__(self, solution: str):
mysql_user = MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "")
if not mysql_password:
raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_merge_market_huge_volume = DBMergeMarketHugeVolume(self.db_url)
self.peak_valley_data = self.get_peak_valley_data()
self.solution = solution
self.save_path = f"./output/trade_sandbox/mean_reversion/{self.solution}/"
os.makedirs(self.save_path, exist_ok=True)
self.strategy_description = self.get_strategy_description()
def get_strategy_description(self):
desc_dict = {
"buy": [
"1. The window size is 100, i.e. 100 candles",
"2. The current low_10_low is 1, i.e. the current low is below the 10th percentile of the window",
"3. Among the current candle and the 2 preceding candles, at least one has huge_volume == 1, i.e. at least one candle shows huge volume",
"4. The current candle is bullish, i.e. close > open",
],
"stop_loss": ["Sell once the decline exceeds the median decline of down swings, i.e. down_median"],
"take_profit": {
"solution_1": [
"Take profit on high-volume highs - simple version",
"1. The current high_80_high is 1 or high_90_high is 1",
"2. Among the current candle and the 2 preceding candles, at least one has huge_volume == 1, i.e. at least one candle shows huge volume",
],
"solution_2": [
"Take profit on high-volume highs - extended version",
"Preconditions:",
"1. The current high_80_high is 1 or high_90_high is 1",
"2. Among the current candle and the 2 preceding candles, at least one has huge_volume == 1, i.e. at least one candle shows huge volume",
"Then either of the following two conditions is sufficient:",
"1. The candle is bearish, i.e. close < open",
"2. The candle is bullish, i.e. close >= open, and k_shape is one of:",
"一字, 长吊锤线, 吊锤线, 长倒T线, 倒T线, 长十字星, 十字星, 长上影线纺锤体, 长下影线纺锤体",
],
"solution_3": [
"Median up-swing take-profit",
"1. Once the gain exceeds the median up-swing gain (up_median), record the current price and keep holding",
"2. In the next period, if the price rises, record that price and keep holding",
"3. In the next period, if the price falls below the recorded price, sell",
],
},
}
buy_list = desc_dict.get("buy", [])
stop_loss_list = desc_dict.get("stop_loss", [])
take_profit_list = desc_dict.get("take_profit", {}).get(self.solution, [])
if len(take_profit_list) == 0:
self.solution = "solution_1"
take_profit_list = desc_dict.get("take_profit", {}).get(self.solution, [])
desc = f"Strategy name: {self.solution}\n\n"
buy_desc = "\n".join(buy_list)
stop_loss_desc = "\n".join(stop_loss_list)
take_profit_desc = "\n".join(take_profit_list)
desc += f"Buy strategy\n {buy_desc}\n\n"
desc += f"Stop-loss strategy\n {stop_loss_desc}\n\n"
desc += f"Take-profit strategy\n {take_profit_desc}\n\n"
with open(f"{self.save_path}/策略描述.txt", "w", encoding="utf-8") as f:
f.write(desc)
return desc
def get_peak_valley_data(self):
os.makedirs("./json/", exist_ok=True)
json_file_path = "./json/peak_valley_data.json"
if not os.path.exists(json_file_path):
excel_file_path = "./output/statistics/excel/price_volume_stats_window_size_100_from_20250515000000_to_20250819110500.xlsx"
if not os.path.exists(excel_file_path):
raise FileNotFoundError(f"Excel file not found: {excel_file_path}")
sheet_name = "波峰波谷统计"
df = pd.read_excel(excel_file_path, sheet_name=sheet_name)
if df is None or len(df) == 0:
raise ValueError("Excel file is empty")
data_list = []
for index, row in df.iterrows():
data_list.append(
{
"symbol": row["symbol"],
"bar": row["bar"],
"up_mean": row["up_mean"],
"up_median": row["up_median"],
"down_mean": row["down_mean"],
"down_median": row["down_median"],
}
)
with open(json_file_path, "w", encoding="utf-8") as f:
json.dump(data_list, f, ensure_ascii=False, indent=4)
peak_valley_data = pd.DataFrame(data_list)
else:
with open(json_file_path, "r", encoding="utf-8") as f:
peak_valley_data = json.load(f)
return pd.DataFrame(peak_valley_data)
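# Note: the JSON file above acts as a cache of the "波峰波谷统计" sheet; delete
# ./json/peak_valley_data.json to force a re-read of the Excel workbook.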
def trade_sandbox(
self, symbol: str, bar: str, window_size: int, start: str, end: str
):
logger.info(f"策略描述: {self.strategy_description}")
logger.info(
f"开始获取{symbol} {bar}{window_size}分钟窗口大小的数据, 开始时间: {start}, 结束时间: {end}"
)
market_data = self.db_merge_market_huge_volume.merge_market_huge_volume(
symbol, bar, window_size, start, end
)
if market_data is None:
return None
logger.info(f"数据条数: {len(market_data)}")
trade_list = []
trade_pair_dict = {}
for index, row in market_data.iterrows():
# check buy condition
if trade_pair_dict.get("buy_timestamp", None) is None:
buy_condition = self.check_buy_condition(market_data, row, index)
else:
buy_condition = False
if buy_condition:
trade_pair_dict = {}
trade_pair_dict["solution"] = self.solution
trade_pair_dict["symbol"] = symbol
trade_pair_dict["bar"] = bar
trade_pair_dict["window_size"] = window_size
trade_pair_dict["buy_timestamp"] = row["timestamp"]
trade_pair_dict["buy_date_time"] = timestamp_to_datetime(
row["timestamp"]
)
trade_pair_dict["buy_close"] = row["close"]
trade_pair_dict["buy_pct_chg"] = row["pct_chg"]
trade_pair_dict["buy_volume"] = row["volume"]
trade_pair_dict["buy_huge_volume"] = row["huge_volume"]
trade_pair_dict["buy_volume_ratio"] = row["volume_ratio"]
trade_pair_dict["buy_k_shape"] = row["k_shape"]
trade_pair_dict["buy_low_10_low"] = row["low_10_low"]
continue
if trade_pair_dict.get("buy_timestamp", None) is not None:
sell_condition = False
# check stop loss condition
sell_condition = self.check_stop_loss_condition(trade_pair_dict, row)
if sell_condition:
trade_pair_dict["sell_type"] = "止损"
else:
# check take profit condition
sell_condition = self.check_take_profit_condition(
trade_pair_dict, market_data, row, index
)
if sell_condition:
trade_pair_dict["sell_type"] = "止盈"
if sell_condition:
trade_pair_dict["sell_timestamp"] = row["timestamp"]
trade_pair_dict["sell_date_time"] = timestamp_to_datetime(
row["timestamp"]
)
trade_pair_dict["sell_close"] = row["close"]
trade_pair_dict["sell_pct_chg"] = row["pct_chg"]
trade_pair_dict["sell_volume"] = row["volume"]
trade_pair_dict["sell_huge_volume"] = row["huge_volume"]
trade_pair_dict["sell_volume_ratio"] = row["volume_ratio"]
trade_pair_dict["sell_k_shape"] = row["k_shape"]
trade_pair_dict["sell_high_80_high"] = row["high_80_high"]
trade_pair_dict["sell_high_90_high"] = row["high_90_high"]
trade_pair_dict["sell_low_10_low"] = row["low_10_low"]
trade_pair_dict["sell_low_20_low"] = row["low_20_low"]
trade_pair_dict["profit_pct"] = round(
(trade_pair_dict["sell_close"] - trade_pair_dict["buy_close"])
/ trade_pair_dict["buy_close"]
* 100,
4,
)
if trade_pair_dict["sell_type"] == "止盈" and trade_pair_dict["profit_pct"] < 0:
trade_pair_dict["sell_type"] = "止损"
if trade_pair_dict.get("last_max_close", None) is not None:
# remove last_max_close
trade_pair_dict.pop("last_max_close")
trade_list.append(trade_pair_dict)
trade_pair_dict = {}
if len(trade_list) == 0:
return None
trade_data = pd.DataFrame(trade_list)
trade_data.sort_values(by="buy_timestamp", inplace=True)
trade_data.reset_index(drop=True, inplace=True)
return trade_data
def check_buy_condition(
self, market_data: pd.DataFrame, row: pd.Series, index: int
):
"""
Buy conditions
1. The window size is 100, i.e. 100 candles
2. The current low_10_low is 1, i.e. the current low is below the 10th percentile of the window
3. Among the current candle and the 2 preceding candles, at least one has huge_volume == 1, i.e. at least one candle shows huge volume
4. The current candle is bullish, i.e. close > open
5. TODO: take candle patterns (k_shape) into account
"""
if index < 2:
return False
if row["close"] <= row["open"]:
return False
if row["low_10_low"] != 1:
return False
# If neither the current candle nor the two preceding candles has huge_volume == 1, return False
if (
row["huge_volume"] != 1
and market_data.loc[index - 1, "huge_volume"] != 1
and market_data.loc[index - 2, "huge_volume"] != 1
):
return False
logger.info(f"符合买入条件")
return True
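# Note (illustrative, not part of the committed strategy): the 3-bar huge_volume
# lookback above could also be pre-computed once before the loop, e.g.
#   market_data["huge_volume_3"] = market_data["huge_volume"].rolling(3, min_periods=1).max()
# after which the volume part of the buy check reduces to row["huge_volume_3"] == 1.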
def check_stop_loss_condition(self, trade_pair_dict: dict, row: pd.Series):
symbol = trade_pair_dict["symbol"]
bar = trade_pair_dict["bar"]
# Get the median decline of down swings, expressed as a percentage
down_median = (
self.peak_valley_data.loc[
(self.peak_valley_data["symbol"] == symbol)
& (self.peak_valley_data["bar"] == bar),
"down_median",
].values[0]
/ 100
)
buy_close = trade_pair_dict["buy_close"]
current_close = row["close"]
if (
current_close < buy_close
and (current_close - buy_close) / buy_close < down_median
):
logger.info(f"符合止损条件")
return True
return False
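# Worked example (illustrative numbers): if down_median is stored as -3.5 (percent),
# dividing by 100 gives -0.035. With buy_close = 100 and current_close = 96 the
# drawdown is (96 - 100) / 100 = -0.04 < -0.035, so the stop-loss triggers; at
# current_close = 97 the drawdown is -0.03, above -0.035, so the position is kept.
# This assumes down_median is recorded as a negative percentage in the peak/valley sheet.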
def check_take_profit_condition(
self,
trade_pair_dict: dict,
market_data: pd.DataFrame,
row: pd.Series,
index: int,
):
try:
if self.solution == "solution_1":
return self.check_take_profit_condition_solution_1(
market_data, row, index
)
elif self.solution == "solution_2":
return self.check_take_profit_condition_solution_2(
market_data, row, index
)
elif self.solution == "solution_3":
return self.check_take_profit_condition_solution_3(
trade_pair_dict, row
)
else:
raise ValueError(f"Invalid strategy name: {self.solution}")
except Exception as e:
logger.error(f"检查止盈条件时发生错误: {e}")
return False
def check_take_profit_condition_solution_1(
self,
market_data: pd.DataFrame,
row: pd.Series,
index: int,
):
"""
Take profit on high-volume highs - simple version
1. The current high_80_high is 1 or high_90_high is 1
2. Among the current candle and the 2 preceding candles, at least one has huge_volume == 1, i.e. at least one candle shows huge volume
"""
if row["high_80_high"] != 1 and row["high_90_high"] != 1:
return False
if (
row["huge_volume"] != 1
and market_data.loc[index - 1, "huge_volume"] != 1
and market_data.loc[index - 2, "huge_volume"] != 1
):
return False
logger.info(f"符合高位放量止盈 - 简易版条件")
return True
def check_take_profit_condition_solution_2(
self,
market_data: pd.DataFrame,
row: pd.Series,
index: int,
):
"""
Take profit on high-volume highs - extended version
Preconditions:
1. The current high_80_high is 1 or high_90_high is 1
2. Among the current candle and the 2 preceding candles, at least one has huge_volume == 1, i.e. at least one candle shows huge volume
Then either of the following two conditions is sufficient:
1. The candle is bearish, i.e. close < open
2. The candle is bullish, i.e. close >= open, and k_shape is one of:
一字, 长吊锤线, 吊锤线, 长倒T线, 倒T线, 长十字星, 十字星, 长上影线纺锤体, 长下影线纺锤体
"""
if not self.check_take_profit_condition_solution_1(market_data, row, index):
return False
if row["close"] < row["open"]:
logger.info(f"符合高位放量止盈 - 复杂版条件")
return True
elif row["k_shape"] in ["一字", "长吊锤线", "吊锤线", "长倒T线", "倒T线", "长十字星", "十字星", "长上影线纺锤体", "长下影线纺锤体"]:
logger.info(f"符合高位放量止盈 - 复杂版条件")
return True
else:
return False
def check_take_profit_condition_solution_3(
self,
trade_pair_dict: dict,
row: pd.Series
):
"""
Median up-swing take-profit
1. Once the gain exceeds the median up-swing gain (up_median), record the current price and keep holding
2. In the next period, if the price rises, record that price and keep holding
3. In the next period, if the price falls below the recorded price, sell
"""
current_close = row["close"]
last_max_close = trade_pair_dict.get("last_max_close", None)
if last_max_close is not None:
if current_close >= last_max_close:
logger.info(f"价格上涨, 继续持仓")
trade_pair_dict["last_max_close"] = current_close
return False
else:
logger.info(f"符合上涨波段盈利中位数止盈法条件")
return True
else:
symbol = trade_pair_dict["symbol"]
bar = trade_pair_dict["bar"]
up_median = (
self.peak_valley_data.loc[
(self.peak_valley_data["symbol"] == symbol)
& (self.peak_valley_data["bar"] == bar),
"up_median",
].values[0]
/ 100
)
buy_close = trade_pair_dict["buy_close"]
price_chg = (current_close - buy_close) / buy_close
if price_chg > up_median:
logger.info("Gain exceeds the median up-swing gain; recording current price and holding")
trade_pair_dict["last_max_close"] = current_close
return False
# gain has not reached up_median yet: keep holding (explicit return instead of implicit None)
return False
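
A minimal driver sketch for the sandbox above. It assumes MYSQL_CONFIG is filled in and the peak/valley Excel (or its cached JSON) exists; symbol, bar and dates are placeholders.

# Illustrative only: assumes config and peak/valley data are in place.
from core.trade.mean_reversion_sandbox import MeanReversionSandbox

sandbox = MeanReversionSandbox(solution="solution_1")
trades = sandbox.trade_sandbox(
    symbol="XCH-USDT",
    bar="30m",
    window_size=100,
    start="2025-05-15 00:00:00",
    end="2025-08-19 00:00:00",
)
if trades is not None:
    print(trades[["buy_date_time", "sell_date_time", "sell_type", "profit_pct"]])
    print(f"win rate: {(trades['profit_pct'] > 0).mean():.2%}")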

File diff suppressed because one or more lines are too long

trade_sandbox_main.py (Normal file, 277 lines added)
View File

@ -0,0 +1,277 @@
import core.logger as logging
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import PercentFormatter
from datetime import datetime
import re
from openpyxl import Workbook
from openpyxl.drawing.image import Image
import openpyxl
from openpyxl.styles import Font
from PIL import Image as PILImage
from config import MONITOR_CONFIG
from core.trade.mean_reversion_sandbox import MeanReversionSandbox
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
# Chinese font support for matplotlib / seaborn charts
plt.rcParams["font.family"] = ["SimHei"]
logger = logging.logger
class MeanReversionSandboxMain:
def __init__(self, start_date: str, end_date: str, window_size: int):
self.symbols = MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["XCH-USDT"]
)
self.bars = MONITOR_CONFIG.get("volume_monitor", {}).get(
"bars", ["5m", "15m", "30m", "1H"]
)
self.solution_list = ["solution_1", "solution_2", "solution_3"]
self.start_date = start_date
self.end_date = end_date
self.window_size = window_size
self.save_path = f"./output/trade_sandbox/mean_reversion/"
os.makedirs(self.save_path, exist_ok=True)
def batch_mean_reversion_sandbox(self):
"""
Batch-run the mean-reversion backtest for all solutions, symbols and bars
"""
logger.info("开始批量计算均值回归交易策略")
logger.info(
f"开始时间: {self.start_date}, 结束时间: {self.end_date}, 窗口大小: {self.window_size}"
)
for solution in self.solution_list:
data_list = []
for symbol in self.symbols:
for bar in self.bars:
data = self.mean_reversion(symbol, bar, solution)
if data is not None and len(data) > 0:
data_list.append(data)
if len(data_list) == 0:
# no trades were produced for this solution; skip it instead of aborting the whole batch
continue
total_data = pd.concat(data_list)
total_data.sort_values(by="buy_timestamp", ascending=True, inplace=True)
total_data.reset_index(drop=True, inplace=True)
stat_data = self.statistic_data(total_data)
excel_save_path = os.path.join(self.save_path, solution, "excel")
os.makedirs(excel_save_path, exist_ok=True)
date_time_str = datetime.now().strftime("%Y%m%d%H%M%S")
excel_file_path = os.path.join(
excel_save_path, f"{solution}_{date_time_str}.xlsx"
)
with pd.ExcelWriter(excel_file_path) as writer:
total_data.to_excel(writer, sheet_name="total_data", index=False)
stat_data.to_excel(writer, sheet_name="stat_data", index=False)
chart_dict = {}
self.draw_chart(stat_data, chart_dict)
self.output_chart_to_excel(excel_file_path, chart_dict)
def mean_reversion(self, symbol: str, bar: str, solution: str):
"""
Run the mean-reversion sandbox for a single symbol and bar
"""
mean_reversion_sandbox = MeanReversionSandbox(solution)
data = mean_reversion_sandbox.trade_sandbox(
symbol, bar, self.window_size, self.start_date, self.end_date
)
return data
def statistic_data(self, data: pd.DataFrame):
"""
Compute summary statistics grouped by symbol and bar
"""
data_list = []
# Group by symbol and bar, count trades with profit_pct > 0, and compute the
# max, min and mean of profit_pct as well as the mean of the positive and negative profit_pct values
data_grouped = data.groupby(["symbol", "bar"])
for (symbol_name, bar_name), group in data_grouped:
solution = group["solution"].iloc[0]
# take-profit trades
take_profit_count = len(group[group["sell_type"] == "止盈"])
take_profit_ratio = round((take_profit_count / len(group)) * 100, 4)
# stop-loss trades
stop_loss_count = len(group[group["sell_type"] == "止损"])
stop_loss_ratio = round((stop_loss_count / len(group)) * 100, 4)
profit_pct_gt_0_count = len(group[group["profit_pct"] > 0])
profit_pct_gt_0_ratio = round((profit_pct_gt_0_count / len(group)) * 100, 4)
profit_pct_lt_0_count = len(group[group["profit_pct"] < 0])
profit_pct_lt_0_ratio = round((profit_pct_lt_0_count / len(group)) * 100, 4)
profit_pct_max = group["profit_pct"].max()
profit_pct_min = group["profit_pct"].min()
profit_pct_mean = group["profit_pct"].mean()
profit_pct_gt_0_mean = group[group["profit_pct"] > 0]["profit_pct"].mean()
profit_pct_lt_0_mean = group[group["profit_pct"] < 0]["profit_pct"].mean()
logger.info(
f"Strategy: {solution}, symbol: {symbol_name}, bar: {bar_name}, trades with profit_pct>0: {profit_pct_gt_0_count}, trades with profit_pct<0: {profit_pct_lt_0_count}, max profit_pct: {profit_pct_max}, min profit_pct: {profit_pct_min}, mean profit_pct: {profit_pct_mean}, mean of positive profit_pct: {profit_pct_gt_0_mean}, mean of negative profit_pct: {profit_pct_lt_0_mean}"
)
data_list.append(
{
"solution": solution,
"symbol": symbol_name,
"bar": bar_name,
"take_profit_count": take_profit_count,
"take_profit_ratio": take_profit_ratio,
"stop_loss_count": stop_loss_count,
"stop_loss_ratio": stop_loss_ratio,
"profit_pct_gt_0_count": profit_pct_gt_0_count,
"profit_pct_gt_0_ratio": profit_pct_gt_0_ratio,
"profit_pct_lt_0_count": profit_pct_lt_0_count,
"profit_pct_lt_0_ratio": profit_pct_lt_0_ratio,
"profit_pct_max": profit_pct_max,
"profit_pct_min": profit_pct_min,
"profit_pct_mean": profit_pct_mean,
"profit_pct_gt_0_mean": profit_pct_gt_0_mean,
"profit_pct_lt_0_mean": profit_pct_lt_0_mean,
}
)
stat_data = pd.DataFrame(data_list)
stat_data.sort_values(by=["bar", "symbol"], inplace=True)
stat_data.reset_index(drop=True, inplace=True)
return stat_data
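# Note (illustrative alternative): the numeric columns above could also be computed
# with a single named aggregation, e.g.
#   data.groupby(["symbol", "bar"])["profit_pct"].agg(
#       profit_pct_max="max", profit_pct_min="min", profit_pct_mean="mean"
#   ).reset_index()
# The explicit loop is kept because the ratio columns mix sell_type counts and
# signed profit buckets, which read more clearly row by row.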
def draw_chart(self, stat_data: pd.DataFrame, chart_dict: dict):
"""
Draw the charts for one solution
"""
sns.set_theme(style="whitegrid")
plt.rcParams["font.sans-serif"] = ["SimHei"] # 也可直接用字体名
plt.rcParams["font.size"] = 11 # 设置字体大小
plt.rcParams["axes.unicode_minus"] = False
plt.rcParams["figure.dpi"] = 150
plt.rcParams["savefig.dpi"] = 150
# Draw bar charts of profit_pct_gt_0_ratio and the other metrics for each solution
# bar has four categories: 5m, 15m, 30m, 1H
# each bar gets one subplot, arranged on a 2x2 canvas
# y axis is a percentage, x axis is the symbol
# use a blue gradient palette
# each metric is saved as one chart image per solution under output/trade_sandbox/mean_reversion/<solution>/chart/
solution = stat_data["solution"].iloc[0]
save_path = os.path.join(self.save_path, solution, "chart")
os.makedirs(save_path, exist_ok=True)
bars_in_order = [
b for b in getattr(self, "bars", []) if b in stat_data["bar"].unique()
]
if not bars_in_order:
bars_in_order = list(stat_data["bar"].unique())
palette_name = "Blues_d"
y_axis_fields = [
"take_profit_ratio",
"stop_loss_ratio",
"profit_pct_mean",
"profit_pct_gt_0_mean",
"profit_pct_lt_0_mean",
]
sheet_name = f"{solution}_chart"
chart_dict[sheet_name] = {}
for y_axis_field in y_axis_fields:
# draw a 2x2 grid of subplots
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
for j, bar in enumerate(bars_in_order):
ax = axs[j // 2, j % 2]
bar_data = stat_data[stat_data["bar"] == bar].copy()
bar_data.sort_values(by=y_axis_field, ascending=False, inplace=True)
bar_data.reset_index(drop=True, inplace=True)
colors = sns.color_palette(palette_name, n_colors=len(bar_data))
sns.barplot(
x="symbol",
y=y_axis_field,
data=bar_data,
palette=colors,
ax=ax,
)
ax.set_ylabel(y_axis_field)
ax.set_xlabel("symbol")
ax.set_title(f"{solution} {bar}")
if "ratio" in y_axis_field:
ax.yaxis.set_major_formatter(PercentFormatter(100))
ax.set_ylim(0, 100)
for label in ax.get_xticklabels():
label.set_rotation(45)
label.set_horizontalalignment("right")
# hide unused subplots
total_used = len(bars_in_order)
for k in range(total_used, 4):
ax = axs[k // 2, k % 2]
ax.axis("off")
fig.tight_layout()
file_name = f"{solution}_{y_axis_field}.png"
fig.savefig(os.path.join(save_path, file_name))
plt.close(fig)
chart_dict[sheet_name][y_axis_field] = os.path.join(save_path, file_name)
def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict):
"""
Write all charts into the Excel workbook.
charts_dict: chart data dictionary in the form
{
"sheet_name": {
"chart_name": "chart_path"
}
}
"""
logger.info(f"将图表输出到{excel_file_path}")
# 打开已经存在的Excel文件
wb = openpyxl.load_workbook(excel_file_path)
for sheet_name, chart_data_dict in charts_dict.items():
try:
ws = wb.create_sheet(title=sheet_name)
row_offset = 1
for chart_name, chart_path in chart_data_dict.items():
# Load image to get dimensions
with PILImage.open(chart_path) as img:
width_px, height_px = img.size
# Convert pixel height to Excel row height (approximate: 1 point = 1.333 pixels, 1 row ≈ 15 points for 20 pixels)
pixels_per_point = 1.333
points_per_row = 15 # Default row height in points
pixels_per_row = (
points_per_row * pixels_per_point
) # ≈ 20 pixels per row
chart_rows = max(
10, int(height_px / pixels_per_row)
) # Minimum 10 rows for small charts
# Add chart title
# chart title (Chinese titles are supported)
ws[f"A{row_offset}"] = chart_name
ws[f"A{row_offset}"].font = openpyxl.styles.Font(bold=True, size=12)
row_offset += 2 # Add 2 rows for title and spacing
# Insert chart image
img = Image(chart_path)
ws.add_image(img, f"A{row_offset}")
# Update row offset (chart height + padding)
row_offset += (
chart_rows + 5
) # Add 5 rows for padding between charts
except Exception as e:
logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
continue
# Save Excel file
wb.save(excel_file_path)
print(f"Chart saved as {excel_file_path}")
if __name__ == "__main__":
start_date = "2025-05-15 00:00:00"
end_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
mean_reversion_sandbox_main = MeanReversionSandboxMain(
start_date=start_date, end_date=end_date, window_size=100
)
mean_reversion_sandbox_main.batch_mean_reversion_sandbox()