diff --git a/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc b/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc index 22ba36c..e041493 100644 Binary files a/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc and b/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc differ diff --git a/core/trade/ma_break_statistics.py b/core/trade/ma_break_statistics.py index 82936f8..319541d 100644 --- a/core/trade/ma_break_statistics.py +++ b/core/trade/ma_break_statistics.py @@ -63,21 +63,39 @@ class MaBreakStatistics: return trade_strategy_config def batch_statistics(self, strategy_name: str = "全均线策略"): - self.stats_output_dir = f"./output/trade_sandbox/ma_strategy/excel/{strategy_name}/" + self.stats_output_dir = ( + f"./output/trade_sandbox/ma_strategy/excel/{strategy_name}/" + ) os.makedirs(self.stats_output_dir, exist_ok=True) - self.stats_chart_dir = f"./output/trade_sandbox/ma_strategy/chart/{strategy_name}/" + self.stats_chart_dir = ( + f"./output/trade_sandbox/ma_strategy/chart/{strategy_name}/" + ) os.makedirs(self.stats_chart_dir, exist_ok=True) ma_break_market_data_list = [] + market_data_pct_chg_list = [] if strategy_name not in self.main_strategy.keys() or strategy_name is None: strategy_name = "全均线策略" for symbol in self.symbols: for bar in self.bars: - logger.info(f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计, 策略: {strategy_name}") - ma_break_market_data = self.trade_simulate(symbol, bar, strategy_name) - if ma_break_market_data is not None: + logger.info( + f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计, 策略: {strategy_name}" + ) + ma_break_market_data, market_data_pct_chg = self.trade_simulate( + symbol, bar, strategy_name + ) + if ( + ma_break_market_data is not None + and len(ma_break_market_data) > 0 + and market_data_pct_chg is not None + ): ma_break_market_data_list.append(ma_break_market_data) + logger.info( + f"{symbol} {bar} 的市场价格变化, {market_data_pct_chg.get('pct_chg', 0)}%" + ) + market_data_pct_chg_list.append(market_data_pct_chg) if len(ma_break_market_data_list) > 0: ma_break_market_data = pd.concat(ma_break_market_data_list) + market_data_pct_chg_df = pd.DataFrame(market_data_pct_chg_list) # 依据symbol和bar分组,统计每个symbol和bar的pct_chg的max, min, mean, std, median, count pct_chg_df = ( ma_break_market_data.groupby(["symbol", "bar"])["pct_chg"] @@ -92,6 +110,66 @@ class MaBreakStatistics: ) .reset_index() ) + pct_chg_df["strategy_name"] = strategy_name + pct_chg_df["pct_chg_total"] = 0 + pct_chg_df["market_pct_chg"] = 0 + # 将pct_chg_total与market_pct_chg的值类型转换为float + pct_chg_df["pct_chg_total"] = pct_chg_df["pct_chg_total"].astype(float) + pct_chg_df["market_pct_chg"] = pct_chg_df["market_pct_chg"].astype(float) + # 统计pct_chg_total + # 算法要求,ma_break_market_data,然后pct_chg/100 + 1 + ma_break_market_data["pct_chg_total"] = ( + ma_break_market_data["pct_chg"] / 100 + 1 + ) + # 遍历symbol和bar,按照end_timestamp排序,计算pct_chg_total的值,然后相乘 + for symbol in pct_chg_df["symbol"].unique(): + for bar in pct_chg_df["bar"].unique(): + symbol_bar_data = ma_break_market_data[ + (ma_break_market_data["symbol"] == symbol) + & (ma_break_market_data["bar"] == bar) + ] + if len(symbol_bar_data) > 0: + symbol_bar_data.sort_values( + by="end_timestamp", ascending=True, inplace=True + ) + symbol_bar_data.reset_index(drop=True, inplace=True) + symbol_bar_data["pct_chg_total"] = symbol_bar_data[ + "pct_chg_total" + ].cumprod() + last_pct_chg_total = symbol_bar_data["pct_chg_total"].iloc[-1] + last_pct_chg_total = (last_pct_chg_total - 1) * 100 + pct_chg_df.loc[ + (pct_chg_df["symbol"] == symbol) + & (pct_chg_df["bar"] == bar), + "pct_chg_total", + ] = last_pct_chg_total + market_pct_chg = market_data_pct_chg_df.loc[ + (market_data_pct_chg_df["symbol"] == symbol) + & (market_data_pct_chg_df["bar"] == bar), + "pct_chg", + ].values[0] + pct_chg_df.loc[ + (pct_chg_df["symbol"] == symbol) + & (pct_chg_df["bar"] == bar), + "market_pct_chg", + ] = market_pct_chg + + pct_chg_df = pct_chg_df[ + [ + "strategy_name", + "symbol", + "bar", + "market_pct_chg", + "pct_chg_total", + "pct_chg_sum", + "pct_chg_max", + "pct_chg_min", + "pct_chg_mean", + "pct_chg_std", + "pct_chg_median", + "pct_chg_count", + ] + ] # 依据symbol和bar分组,统计每个symbol和bar的interval_minutes的max, min, mean, std, median, count interval_minutes_df = ( ma_break_market_data.groupby(["symbol", "bar"])["interval_minutes"] @@ -132,9 +210,10 @@ class MaBreakStatistics: chart_dict = self.draw_pct_chg_mean_chart(pct_chg_df, strategy_name) self.output_chart_to_excel(output_file_path, chart_dict) + return pct_chg_df else: return None - + def get_strategy_info(self, strategy_name: str = "全均线策略"): strategy_config = self.main_strategy.get(strategy_name, None) if strategy_config is None: @@ -179,7 +258,7 @@ class MaBreakStatistics: ) if market_data is None or len(market_data) == 0: logger.warning(f"获取{symbol} {bar} 数据失败") - return + return None, None else: market_data = pd.DataFrame(market_data) market_data.sort_values(by="timestamp", ascending=True, inplace=True) @@ -291,10 +370,20 @@ class MaBreakStatistics: if len(ma_break_market_data_pair_list) > 0: ma_break_market_data = pd.DataFrame(ma_break_market_data_pair_list) - logger.info(f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}") - return ma_break_market_data + logger.info( + f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}" + ) + # 将market_data(最后一条数据的close - 第一条数据的open) / 第一条数据的open * 100 + pct_chg = ( + (market_data["close"].iloc[-1] - market_data["open"].iloc[0]) + / market_data["open"].iloc[0] + * 100 + ) + pct_chg = round(pct_chg, 4) + market_data_pct_chg = {"symbol": symbol, "bar": bar, "pct_chg": pct_chg} + return ma_break_market_data, market_data_pct_chg else: - return None + return None, None def fit_strategy( self, @@ -438,10 +527,7 @@ class MaBreakStatistics: plt.rcParams["font.size"] = 11 # 设置字体大小 plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题 chart_dict = {} - column_name_dict = { - "pct_chg_sum": "涨跌总和", - "pct_chg_mean": "涨跌均值", - } + column_name_dict = {"pct_chg_total": "涨跌总和", "pct_chg_mean": "涨跌均值"} for column_name, column_name_text in column_name_dict.items(): for bar in data["bar"].unique(): bar_data = data[data["bar"] == bar].copy() # 一次筛选即可 @@ -453,20 +539,31 @@ class MaBreakStatistics: bar_data.reset_index(drop=True, inplace=True) plt.figure(figsize=(10, 6)) - ax = sns.barplot(x="symbol", y=column_name_text, data=bar_data, palette="Blues_d") + ax = sns.barplot( + x="symbol", y=column_name_text, data=bar_data, palette="Blues_d" + ) plt.title(f"{bar}趋势{column_name_text}(%)") plt.xlabel("symbol") plt.ylabel(column_name_text) plt.xticks(rotation=45, ha="right") - + # 在柱状图上添加数值标签 for i, v in enumerate(bar_data[column_name_text]): - ax.text(i, v, f'{v:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold') - + ax.text( + i, + v, + f"{v:.3f}", + ha="center", + va="bottom", + fontsize=10, + fontweight="bold", + ) + plt.tight_layout() save_path = os.path.join( - self.stats_chart_dir, f"{bar}_ma_break_{column_name}_{strategy_name}.png" + self.stats_chart_dir, + f"{bar}_ma_break_{column_name}_{strategy_name}.png", ) plt.savefig(save_path, dpi=150) plt.close() diff --git a/trade_ma_strategy_main.py b/trade_ma_strategy_main.py index 430d5a2..01174a6 100644 --- a/trade_ma_strategy_main.py +++ b/trade_ma_strategy_main.py @@ -34,8 +34,28 @@ class TradeMaStrategyMain: """ logger.info("开始批量计算MA突破统计") strategy_dict = self.ma_break_statistics.main_strategy + pct_chg_df_list = [] for strategy_name, strategy_info in strategy_dict.items(): - self.ma_break_statistics.batch_statistics(strategy_name=strategy_name) + pct_chg_df = self.ma_break_statistics.batch_statistics(strategy_name=strategy_name) + pct_chg_df_list.append(pct_chg_df) + pct_chg_df = pd.concat(pct_chg_df_list) + + def statistics_pct_chg(self, pct_chg_df: pd.DataFrame): + """ + 1. 将各个symbol, 各个bar, 各个策略的pct_chg_total构建为新的数据结构,如: + symbol, bar, stratege_name_1, stratege_name_2, stratege_name_3, ... + stratege_name_1的值, 为该策略的pct_chg_total的值 + 2. 构建新的数据结构: symbol, bar, max_pct_chg_total_strategy_name, min_pct_chg_total_strategy_name + 如: BCT-USDT, 15m, 均线macd结合策略2, 全均线策略 + 3. 构建新的数据结构, bar, max_pct_chg_total_strategy_name, min_pct_chg_total_strategy_name + 如: 15m, 均线macd结合策略2, 全均线策略 + 4. 构建新的数据结构, symbol, max_pct_chg_total_strategy_name, min_pct_chg_total_strategy_name + 如: BCT-USDT, 均线macd结合策略2, 全均线策略 + """ + logger.info("开始统计pct_chg") + + + if __name__ == "__main__": trade_ma_strategy_main = TradeMaStrategyMain()