optimize strategy statistics
This commit is contained in:
parent
b1e7ddc261
commit
5700e8e7e1
Binary file not shown.
|
|
@ -63,21 +63,39 @@ class MaBreakStatistics:
|
|||
return trade_strategy_config
|
||||
|
||||
def batch_statistics(self, strategy_name: str = "全均线策略"):
|
||||
self.stats_output_dir = f"./output/trade_sandbox/ma_strategy/excel/{strategy_name}/"
|
||||
self.stats_output_dir = (
|
||||
f"./output/trade_sandbox/ma_strategy/excel/{strategy_name}/"
|
||||
)
|
||||
os.makedirs(self.stats_output_dir, exist_ok=True)
|
||||
self.stats_chart_dir = f"./output/trade_sandbox/ma_strategy/chart/{strategy_name}/"
|
||||
self.stats_chart_dir = (
|
||||
f"./output/trade_sandbox/ma_strategy/chart/{strategy_name}/"
|
||||
)
|
||||
os.makedirs(self.stats_chart_dir, exist_ok=True)
|
||||
ma_break_market_data_list = []
|
||||
market_data_pct_chg_list = []
|
||||
if strategy_name not in self.main_strategy.keys() or strategy_name is None:
|
||||
strategy_name = "全均线策略"
|
||||
for symbol in self.symbols:
|
||||
for bar in self.bars:
|
||||
logger.info(f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计, 策略: {strategy_name}")
|
||||
ma_break_market_data = self.trade_simulate(symbol, bar, strategy_name)
|
||||
if ma_break_market_data is not None:
|
||||
logger.info(
|
||||
f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计, 策略: {strategy_name}"
|
||||
)
|
||||
ma_break_market_data, market_data_pct_chg = self.trade_simulate(
|
||||
symbol, bar, strategy_name
|
||||
)
|
||||
if (
|
||||
ma_break_market_data is not None
|
||||
and len(ma_break_market_data) > 0
|
||||
and market_data_pct_chg is not None
|
||||
):
|
||||
ma_break_market_data_list.append(ma_break_market_data)
|
||||
logger.info(
|
||||
f"{symbol} {bar} 的市场价格变化, {market_data_pct_chg.get('pct_chg', 0)}%"
|
||||
)
|
||||
market_data_pct_chg_list.append(market_data_pct_chg)
|
||||
if len(ma_break_market_data_list) > 0:
|
||||
ma_break_market_data = pd.concat(ma_break_market_data_list)
|
||||
market_data_pct_chg_df = pd.DataFrame(market_data_pct_chg_list)
|
||||
# 依据symbol和bar分组,统计每个symbol和bar的pct_chg的max, min, mean, std, median, count
|
||||
pct_chg_df = (
|
||||
ma_break_market_data.groupby(["symbol", "bar"])["pct_chg"]
|
||||
|
|
@ -92,6 +110,66 @@ class MaBreakStatistics:
|
|||
)
|
||||
.reset_index()
|
||||
)
|
||||
pct_chg_df["strategy_name"] = strategy_name
|
||||
pct_chg_df["pct_chg_total"] = 0
|
||||
pct_chg_df["market_pct_chg"] = 0
|
||||
# 将pct_chg_total与market_pct_chg的值类型转换为float
|
||||
pct_chg_df["pct_chg_total"] = pct_chg_df["pct_chg_total"].astype(float)
|
||||
pct_chg_df["market_pct_chg"] = pct_chg_df["market_pct_chg"].astype(float)
|
||||
# 统计pct_chg_total
|
||||
# 算法要求,ma_break_market_data,然后pct_chg/100 + 1
|
||||
ma_break_market_data["pct_chg_total"] = (
|
||||
ma_break_market_data["pct_chg"] / 100 + 1
|
||||
)
|
||||
# 遍历symbol和bar,按照end_timestamp排序,计算pct_chg_total的值,然后相乘
|
||||
for symbol in pct_chg_df["symbol"].unique():
|
||||
for bar in pct_chg_df["bar"].unique():
|
||||
symbol_bar_data = ma_break_market_data[
|
||||
(ma_break_market_data["symbol"] == symbol)
|
||||
& (ma_break_market_data["bar"] == bar)
|
||||
]
|
||||
if len(symbol_bar_data) > 0:
|
||||
symbol_bar_data.sort_values(
|
||||
by="end_timestamp", ascending=True, inplace=True
|
||||
)
|
||||
symbol_bar_data.reset_index(drop=True, inplace=True)
|
||||
symbol_bar_data["pct_chg_total"] = symbol_bar_data[
|
||||
"pct_chg_total"
|
||||
].cumprod()
|
||||
last_pct_chg_total = symbol_bar_data["pct_chg_total"].iloc[-1]
|
||||
last_pct_chg_total = (last_pct_chg_total - 1) * 100
|
||||
pct_chg_df.loc[
|
||||
(pct_chg_df["symbol"] == symbol)
|
||||
& (pct_chg_df["bar"] == bar),
|
||||
"pct_chg_total",
|
||||
] = last_pct_chg_total
|
||||
market_pct_chg = market_data_pct_chg_df.loc[
|
||||
(market_data_pct_chg_df["symbol"] == symbol)
|
||||
& (market_data_pct_chg_df["bar"] == bar),
|
||||
"pct_chg",
|
||||
].values[0]
|
||||
pct_chg_df.loc[
|
||||
(pct_chg_df["symbol"] == symbol)
|
||||
& (pct_chg_df["bar"] == bar),
|
||||
"market_pct_chg",
|
||||
] = market_pct_chg
|
||||
|
||||
pct_chg_df = pct_chg_df[
|
||||
[
|
||||
"strategy_name",
|
||||
"symbol",
|
||||
"bar",
|
||||
"market_pct_chg",
|
||||
"pct_chg_total",
|
||||
"pct_chg_sum",
|
||||
"pct_chg_max",
|
||||
"pct_chg_min",
|
||||
"pct_chg_mean",
|
||||
"pct_chg_std",
|
||||
"pct_chg_median",
|
||||
"pct_chg_count",
|
||||
]
|
||||
]
|
||||
# 依据symbol和bar分组,统计每个symbol和bar的interval_minutes的max, min, mean, std, median, count
|
||||
interval_minutes_df = (
|
||||
ma_break_market_data.groupby(["symbol", "bar"])["interval_minutes"]
|
||||
|
|
@ -132,9 +210,10 @@ class MaBreakStatistics:
|
|||
|
||||
chart_dict = self.draw_pct_chg_mean_chart(pct_chg_df, strategy_name)
|
||||
self.output_chart_to_excel(output_file_path, chart_dict)
|
||||
return pct_chg_df
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_strategy_info(self, strategy_name: str = "全均线策略"):
|
||||
strategy_config = self.main_strategy.get(strategy_name, None)
|
||||
if strategy_config is None:
|
||||
|
|
@ -179,7 +258,7 @@ class MaBreakStatistics:
|
|||
)
|
||||
if market_data is None or len(market_data) == 0:
|
||||
logger.warning(f"获取{symbol} {bar} 数据失败")
|
||||
return
|
||||
return None, None
|
||||
else:
|
||||
market_data = pd.DataFrame(market_data)
|
||||
market_data.sort_values(by="timestamp", ascending=True, inplace=True)
|
||||
|
|
@ -291,10 +370,20 @@ class MaBreakStatistics:
|
|||
|
||||
if len(ma_break_market_data_pair_list) > 0:
|
||||
ma_break_market_data = pd.DataFrame(ma_break_market_data_pair_list)
|
||||
logger.info(f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}")
|
||||
return ma_break_market_data
|
||||
logger.info(
|
||||
f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}"
|
||||
)
|
||||
# 将market_data(最后一条数据的close - 第一条数据的open) / 第一条数据的open * 100
|
||||
pct_chg = (
|
||||
(market_data["close"].iloc[-1] - market_data["open"].iloc[0])
|
||||
/ market_data["open"].iloc[0]
|
||||
* 100
|
||||
)
|
||||
pct_chg = round(pct_chg, 4)
|
||||
market_data_pct_chg = {"symbol": symbol, "bar": bar, "pct_chg": pct_chg}
|
||||
return ma_break_market_data, market_data_pct_chg
|
||||
else:
|
||||
return None
|
||||
return None, None
|
||||
|
||||
def fit_strategy(
|
||||
self,
|
||||
|
|
@ -438,10 +527,7 @@ class MaBreakStatistics:
|
|||
plt.rcParams["font.size"] = 11 # 设置字体大小
|
||||
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
|
||||
chart_dict = {}
|
||||
column_name_dict = {
|
||||
"pct_chg_sum": "涨跌总和",
|
||||
"pct_chg_mean": "涨跌均值",
|
||||
}
|
||||
column_name_dict = {"pct_chg_total": "涨跌总和", "pct_chg_mean": "涨跌均值"}
|
||||
for column_name, column_name_text in column_name_dict.items():
|
||||
for bar in data["bar"].unique():
|
||||
bar_data = data[data["bar"] == bar].copy() # 一次筛选即可
|
||||
|
|
@ -453,20 +539,31 @@ class MaBreakStatistics:
|
|||
bar_data.reset_index(drop=True, inplace=True)
|
||||
|
||||
plt.figure(figsize=(10, 6))
|
||||
ax = sns.barplot(x="symbol", y=column_name_text, data=bar_data, palette="Blues_d")
|
||||
ax = sns.barplot(
|
||||
x="symbol", y=column_name_text, data=bar_data, palette="Blues_d"
|
||||
)
|
||||
plt.title(f"{bar}趋势{column_name_text}(%)")
|
||||
plt.xlabel("symbol")
|
||||
plt.ylabel(column_name_text)
|
||||
plt.xticks(rotation=45, ha="right")
|
||||
|
||||
|
||||
# 在柱状图上添加数值标签
|
||||
for i, v in enumerate(bar_data[column_name_text]):
|
||||
ax.text(i, v, f'{v:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
|
||||
|
||||
ax.text(
|
||||
i,
|
||||
v,
|
||||
f"{v:.3f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
save_path = os.path.join(
|
||||
self.stats_chart_dir, f"{bar}_ma_break_{column_name}_{strategy_name}.png"
|
||||
self.stats_chart_dir,
|
||||
f"{bar}_ma_break_{column_name}_{strategy_name}.png",
|
||||
)
|
||||
plt.savefig(save_path, dpi=150)
|
||||
plt.close()
|
||||
|
|
|
|||
|
|
@ -34,8 +34,28 @@ class TradeMaStrategyMain:
|
|||
"""
|
||||
logger.info("开始批量计算MA突破统计")
|
||||
strategy_dict = self.ma_break_statistics.main_strategy
|
||||
pct_chg_df_list = []
|
||||
for strategy_name, strategy_info in strategy_dict.items():
|
||||
self.ma_break_statistics.batch_statistics(strategy_name=strategy_name)
|
||||
pct_chg_df = self.ma_break_statistics.batch_statistics(strategy_name=strategy_name)
|
||||
pct_chg_df_list.append(pct_chg_df)
|
||||
pct_chg_df = pd.concat(pct_chg_df_list)
|
||||
|
||||
def statistics_pct_chg(self, pct_chg_df: pd.DataFrame):
|
||||
"""
|
||||
1. 将各个symbol, 各个bar, 各个策略的pct_chg_total构建为新的数据结构,如:
|
||||
symbol, bar, stratege_name_1, stratege_name_2, stratege_name_3, ...
|
||||
stratege_name_1的值, 为该策略的pct_chg_total的值
|
||||
2. 构建新的数据结构: symbol, bar, max_pct_chg_total_strategy_name, min_pct_chg_total_strategy_name
|
||||
如: BCT-USDT, 15m, 均线macd结合策略2, 全均线策略
|
||||
3. 构建新的数据结构, bar, max_pct_chg_total_strategy_name, min_pct_chg_total_strategy_name
|
||||
如: 15m, 均线macd结合策略2, 全均线策略
|
||||
4. 构建新的数据结构, symbol, max_pct_chg_total_strategy_name, min_pct_chg_total_strategy_name
|
||||
如: BCT-USDT, 均线macd结合策略2, 全均线策略
|
||||
"""
|
||||
logger.info("开始统计pct_chg")
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
trade_ma_strategy_main = TradeMaStrategyMain()
|
||||
|
|
|
|||
Loading…
Reference in New Issue