optimize strategy statistics

This commit is contained in:
blade 2025-08-24 01:44:33 +08:00
parent b1e7ddc261
commit 5700e8e7e1
3 changed files with 137 additions and 20 deletions

View File

@ -63,21 +63,39 @@ class MaBreakStatistics:
return trade_strategy_config return trade_strategy_config
def batch_statistics(self, strategy_name: str = "全均线策略"): def batch_statistics(self, strategy_name: str = "全均线策略"):
self.stats_output_dir = f"./output/trade_sandbox/ma_strategy/excel/{strategy_name}/" self.stats_output_dir = (
f"./output/trade_sandbox/ma_strategy/excel/{strategy_name}/"
)
os.makedirs(self.stats_output_dir, exist_ok=True) os.makedirs(self.stats_output_dir, exist_ok=True)
self.stats_chart_dir = f"./output/trade_sandbox/ma_strategy/chart/{strategy_name}/" self.stats_chart_dir = (
f"./output/trade_sandbox/ma_strategy/chart/{strategy_name}/"
)
os.makedirs(self.stats_chart_dir, exist_ok=True) os.makedirs(self.stats_chart_dir, exist_ok=True)
ma_break_market_data_list = [] ma_break_market_data_list = []
market_data_pct_chg_list = []
if strategy_name not in self.main_strategy.keys() or strategy_name is None: if strategy_name not in self.main_strategy.keys() or strategy_name is None:
strategy_name = "全均线策略" strategy_name = "全均线策略"
for symbol in self.symbols: for symbol in self.symbols:
for bar in self.bars: for bar in self.bars:
logger.info(f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计, 策略: {strategy_name}") logger.info(
ma_break_market_data = self.trade_simulate(symbol, bar, strategy_name) f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计, 策略: {strategy_name}"
if ma_break_market_data is not None: )
ma_break_market_data, market_data_pct_chg = self.trade_simulate(
symbol, bar, strategy_name
)
if (
ma_break_market_data is not None
and len(ma_break_market_data) > 0
and market_data_pct_chg is not None
):
ma_break_market_data_list.append(ma_break_market_data) ma_break_market_data_list.append(ma_break_market_data)
logger.info(
f"{symbol} {bar} 的市场价格变化, {market_data_pct_chg.get('pct_chg', 0)}%"
)
market_data_pct_chg_list.append(market_data_pct_chg)
if len(ma_break_market_data_list) > 0: if len(ma_break_market_data_list) > 0:
ma_break_market_data = pd.concat(ma_break_market_data_list) ma_break_market_data = pd.concat(ma_break_market_data_list)
market_data_pct_chg_df = pd.DataFrame(market_data_pct_chg_list)
# 依据symbol和bar分组统计每个symbol和bar的pct_chg的max, min, mean, std, median, count # 依据symbol和bar分组统计每个symbol和bar的pct_chg的max, min, mean, std, median, count
pct_chg_df = ( pct_chg_df = (
ma_break_market_data.groupby(["symbol", "bar"])["pct_chg"] ma_break_market_data.groupby(["symbol", "bar"])["pct_chg"]
@ -92,6 +110,66 @@ class MaBreakStatistics:
) )
.reset_index() .reset_index()
) )
pct_chg_df["strategy_name"] = strategy_name
pct_chg_df["pct_chg_total"] = 0
pct_chg_df["market_pct_chg"] = 0
# 将pct_chg_total与market_pct_chg的值类型转换为float
pct_chg_df["pct_chg_total"] = pct_chg_df["pct_chg_total"].astype(float)
pct_chg_df["market_pct_chg"] = pct_chg_df["market_pct_chg"].astype(float)
# 统计pct_chg_total
# 算法要求ma_break_market_data然后pct_chg/100 + 1
ma_break_market_data["pct_chg_total"] = (
ma_break_market_data["pct_chg"] / 100 + 1
)
# 遍历symbol和bar按照end_timestamp排序计算pct_chg_total的值然后相乘
for symbol in pct_chg_df["symbol"].unique():
for bar in pct_chg_df["bar"].unique():
symbol_bar_data = ma_break_market_data[
(ma_break_market_data["symbol"] == symbol)
& (ma_break_market_data["bar"] == bar)
]
if len(symbol_bar_data) > 0:
symbol_bar_data.sort_values(
by="end_timestamp", ascending=True, inplace=True
)
symbol_bar_data.reset_index(drop=True, inplace=True)
symbol_bar_data["pct_chg_total"] = symbol_bar_data[
"pct_chg_total"
].cumprod()
last_pct_chg_total = symbol_bar_data["pct_chg_total"].iloc[-1]
last_pct_chg_total = (last_pct_chg_total - 1) * 100
pct_chg_df.loc[
(pct_chg_df["symbol"] == symbol)
& (pct_chg_df["bar"] == bar),
"pct_chg_total",
] = last_pct_chg_total
market_pct_chg = market_data_pct_chg_df.loc[
(market_data_pct_chg_df["symbol"] == symbol)
& (market_data_pct_chg_df["bar"] == bar),
"pct_chg",
].values[0]
pct_chg_df.loc[
(pct_chg_df["symbol"] == symbol)
& (pct_chg_df["bar"] == bar),
"market_pct_chg",
] = market_pct_chg
pct_chg_df = pct_chg_df[
[
"strategy_name",
"symbol",
"bar",
"market_pct_chg",
"pct_chg_total",
"pct_chg_sum",
"pct_chg_max",
"pct_chg_min",
"pct_chg_mean",
"pct_chg_std",
"pct_chg_median",
"pct_chg_count",
]
]
# 依据symbol和bar分组统计每个symbol和bar的interval_minutes的max, min, mean, std, median, count # 依据symbol和bar分组统计每个symbol和bar的interval_minutes的max, min, mean, std, median, count
interval_minutes_df = ( interval_minutes_df = (
ma_break_market_data.groupby(["symbol", "bar"])["interval_minutes"] ma_break_market_data.groupby(["symbol", "bar"])["interval_minutes"]
@ -132,6 +210,7 @@ class MaBreakStatistics:
chart_dict = self.draw_pct_chg_mean_chart(pct_chg_df, strategy_name) chart_dict = self.draw_pct_chg_mean_chart(pct_chg_df, strategy_name)
self.output_chart_to_excel(output_file_path, chart_dict) self.output_chart_to_excel(output_file_path, chart_dict)
return pct_chg_df
else: else:
return None return None
@ -179,7 +258,7 @@ class MaBreakStatistics:
) )
if market_data is None or len(market_data) == 0: if market_data is None or len(market_data) == 0:
logger.warning(f"获取{symbol} {bar} 数据失败") logger.warning(f"获取{symbol} {bar} 数据失败")
return return None, None
else: else:
market_data = pd.DataFrame(market_data) market_data = pd.DataFrame(market_data)
market_data.sort_values(by="timestamp", ascending=True, inplace=True) market_data.sort_values(by="timestamp", ascending=True, inplace=True)
@ -291,10 +370,20 @@ class MaBreakStatistics:
if len(ma_break_market_data_pair_list) > 0: if len(ma_break_market_data_pair_list) > 0:
ma_break_market_data = pd.DataFrame(ma_break_market_data_pair_list) ma_break_market_data = pd.DataFrame(ma_break_market_data_pair_list)
logger.info(f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}") logger.info(
return ma_break_market_data f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}"
)
# 将market_data(最后一条数据的close - 第一条数据的open) / 第一条数据的open * 100
pct_chg = (
(market_data["close"].iloc[-1] - market_data["open"].iloc[0])
/ market_data["open"].iloc[0]
* 100
)
pct_chg = round(pct_chg, 4)
market_data_pct_chg = {"symbol": symbol, "bar": bar, "pct_chg": pct_chg}
return ma_break_market_data, market_data_pct_chg
else: else:
return None return None, None
def fit_strategy( def fit_strategy(
self, self,
@ -438,10 +527,7 @@ class MaBreakStatistics:
plt.rcParams["font.size"] = 11 # 设置字体大小 plt.rcParams["font.size"] = 11 # 设置字体大小
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题 plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
chart_dict = {} chart_dict = {}
column_name_dict = { column_name_dict = {"pct_chg_total": "涨跌总和", "pct_chg_mean": "涨跌均值"}
"pct_chg_sum": "涨跌总和",
"pct_chg_mean": "涨跌均值",
}
for column_name, column_name_text in column_name_dict.items(): for column_name, column_name_text in column_name_dict.items():
for bar in data["bar"].unique(): for bar in data["bar"].unique():
bar_data = data[data["bar"] == bar].copy() # 一次筛选即可 bar_data = data[data["bar"] == bar].copy() # 一次筛选即可
@ -453,7 +539,9 @@ class MaBreakStatistics:
bar_data.reset_index(drop=True, inplace=True) bar_data.reset_index(drop=True, inplace=True)
plt.figure(figsize=(10, 6)) plt.figure(figsize=(10, 6))
ax = sns.barplot(x="symbol", y=column_name_text, data=bar_data, palette="Blues_d") ax = sns.barplot(
x="symbol", y=column_name_text, data=bar_data, palette="Blues_d"
)
plt.title(f"{bar}趋势{column_name_text}(%)") plt.title(f"{bar}趋势{column_name_text}(%)")
plt.xlabel("symbol") plt.xlabel("symbol")
plt.ylabel(column_name_text) plt.ylabel(column_name_text)
@ -461,12 +549,21 @@ class MaBreakStatistics:
# 在柱状图上添加数值标签 # 在柱状图上添加数值标签
for i, v in enumerate(bar_data[column_name_text]): for i, v in enumerate(bar_data[column_name_text]):
ax.text(i, v, f'{v:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold') ax.text(
i,
v,
f"{v:.3f}",
ha="center",
va="bottom",
fontsize=10,
fontweight="bold",
)
plt.tight_layout() plt.tight_layout()
save_path = os.path.join( save_path = os.path.join(
self.stats_chart_dir, f"{bar}_ma_break_{column_name}_{strategy_name}.png" self.stats_chart_dir,
f"{bar}_ma_break_{column_name}_{strategy_name}.png",
) )
plt.savefig(save_path, dpi=150) plt.savefig(save_path, dpi=150)
plt.close() plt.close()

View File

@ -34,8 +34,28 @@ class TradeMaStrategyMain:
""" """
logger.info("开始批量计算MA突破统计") logger.info("开始批量计算MA突破统计")
strategy_dict = self.ma_break_statistics.main_strategy strategy_dict = self.ma_break_statistics.main_strategy
pct_chg_df_list = []
for strategy_name, strategy_info in strategy_dict.items(): for strategy_name, strategy_info in strategy_dict.items():
self.ma_break_statistics.batch_statistics(strategy_name=strategy_name) pct_chg_df = self.ma_break_statistics.batch_statistics(strategy_name=strategy_name)
pct_chg_df_list.append(pct_chg_df)
pct_chg_df = pd.concat(pct_chg_df_list)
def statistics_pct_chg(self, pct_chg_df: pd.DataFrame):
"""
1. 将各个symbol, 各个bar, 各个策略的pct_chg_total构建为新的数据结构,
symbol, bar, stratege_name_1, stratege_name_2, stratege_name_3, ...
stratege_name_1的值, 为该策略的pct_chg_total的值
2. 构建新的数据结构: symbol, bar, max_pct_chg_total_strategy_name, min_pct_chg_total_strategy_name
: BCT-USDT, 15m, 均线macd结合策略2, 全均线策略
3. 构建新的数据结构, bar, max_pct_chg_total_strategy_name, min_pct_chg_total_strategy_name
: 15m, 均线macd结合策略2, 全均线策略
4. 构建新的数据结构, symbol, max_pct_chg_total_strategy_name, min_pct_chg_total_strategy_name
: BCT-USDT, 均线macd结合策略2, 全均线策略
"""
logger.info("开始统计pct_chg")
if __name__ == "__main__": if __name__ == "__main__":
trade_ma_strategy_main = TradeMaStrategyMain() trade_ma_strategy_main = TradeMaStrategyMain()