diff --git a/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc b/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc index e041493..10d11b4 100644 Binary files a/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc and b/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc differ diff --git a/core/trade/ma_break_statistics.py b/core/trade/ma_break_statistics.py index 319541d..1a4a155 100644 --- a/core/trade/ma_break_statistics.py +++ b/core/trade/ma_break_statistics.py @@ -93,10 +93,13 @@ class MaBreakStatistics: f"{symbol} {bar} 的市场价格变化, {market_data_pct_chg.get('pct_chg', 0)}%" ) market_data_pct_chg_list.append(market_data_pct_chg) + if len(ma_break_market_data_list) > 0: ma_break_market_data = pd.concat(ma_break_market_data_list) market_data_pct_chg_df = pd.DataFrame(market_data_pct_chg_list) # 依据symbol和bar分组,统计每个symbol和bar的pct_chg的max, min, mean, std, median, count + ma_break_market_data.sort_values(by="begin_timestamp", ascending=True, inplace=True) + ma_break_market_data.reset_index(drop=True, inplace=True) pct_chg_df = ( ma_break_market_data.groupby(["symbol", "bar"])["pct_chg"] .agg( @@ -127,7 +130,7 @@ class MaBreakStatistics: symbol_bar_data = ma_break_market_data[ (ma_break_market_data["symbol"] == symbol) & (ma_break_market_data["bar"] == bar) - ] + ].copy() # 创建副本避免SettingWithCopyWarning if len(symbol_bar_data) > 0: symbol_bar_data.sort_values( by="end_timestamp", ascending=True, inplace=True @@ -136,6 +139,14 @@ class MaBreakStatistics: symbol_bar_data["pct_chg_total"] = symbol_bar_data[ "pct_chg_total" ].cumprod() + + # 将更新后的pct_chg_total数据同步更新到ma_break_market_data的对应数据行中 + for idx, row in symbol_bar_data.iterrows(): + mask = (ma_break_market_data["symbol"] == symbol) & \ + (ma_break_market_data["bar"] == bar) & \ + (ma_break_market_data["end_timestamp"] == row["end_timestamp"]) + ma_break_market_data.loc[mask, "pct_chg_total"] = row["pct_chg_total"] + last_pct_chg_total = symbol_bar_data["pct_chg_total"].iloc[-1] last_pct_chg_total = (last_pct_chg_total - 1) * 100 pct_chg_df.loc[ @@ -208,7 +219,9 @@ class MaBreakStatistics: writer, sheet_name="买卖时间间隔统计", index=False ) - chart_dict = self.draw_pct_chg_mean_chart(pct_chg_df, strategy_name) + chart_dict = self.draw_quant_pct_chg_bar_chart(pct_chg_df, strategy_name) + self.output_chart_to_excel(output_file_path, chart_dict) + chart_dict = self.draw_quant_line_chart(ma_break_market_data, strategy_name) self.output_chart_to_excel(output_file_path, chart_dict) return pct_chg_df else: @@ -370,13 +383,17 @@ class MaBreakStatistics: if len(ma_break_market_data_pair_list) > 0: ma_break_market_data = pd.DataFrame(ma_break_market_data_pair_list) + # sort by end_timestamp + ma_break_market_data.sort_values(by="begin_timestamp", ascending=True, inplace=True) + ma_break_market_data.reset_index(drop=True, inplace=True) logger.info( f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}" ) - # 将market_data(最后一条数据的close - 第一条数据的open) / 第一条数据的open * 100 + # 量化期间,市场的波动率: + # ma_break_market_data(最后一条数据的end_close - 第一条数据的begin_close) / 第一条数据的begin_close * 100 pct_chg = ( - (market_data["close"].iloc[-1] - market_data["open"].iloc[0]) - / market_data["open"].iloc[0] + (ma_break_market_data["end_close"].iloc[-1] - ma_break_market_data["begin_close"].iloc[0]) + / ma_break_market_data["begin_close"].iloc[0] * 100 ) pct_chg = round(pct_chg, 4) @@ -511,7 +528,7 @@ class MaBreakStatistics: pass return condition - def draw_pct_chg_mean_chart( + def draw_quant_pct_chg_bar_chart( self, data: pd.DataFrame, strategy_name: str = "全均线策略" ): """ @@ -527,48 +544,204 @@ class MaBreakStatistics: plt.rcParams["font.size"] = 11 # 设置字体大小 plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题 chart_dict = {} - column_name_dict = {"pct_chg_total": "涨跌总和", "pct_chg_mean": "涨跌均值"} + column_name_dict = {"pct_chg_total": "量化策略涨跌", "pct_chg_mean": "量化策略涨跌均值"} for column_name, column_name_text in column_name_dict.items(): for bar in data["bar"].unique(): bar_data = data[data["bar"] == bar].copy() # 一次筛选即可 if bar_data.empty: continue bar_data.rename(columns={column_name: column_name_text}, inplace=True) + if column_name == "pct_chg_total": + bar_data.rename(columns={"market_pct_chg": "市场自然涨跌"}, inplace=True) # 可选:按均值排序 bar_data.sort_values(by=column_name_text, ascending=False, inplace=True) bar_data.reset_index(drop=True, inplace=True) - plt.figure(figsize=(10, 6)) - ax = sns.barplot( - x="symbol", y=column_name_text, data=bar_data, palette="Blues_d" - ) - plt.title(f"{bar}趋势{column_name_text}(%)") - plt.xlabel("symbol") - plt.ylabel(column_name_text) - plt.xticks(rotation=45, ha="right") + # 如果column_name_text是"量化策略涨跌",则柱状图,同时绘制量化策略涨跌与市场自然涨跌的柱状图,并绘制在同一个图表中 + if column_name == "pct_chg_total": + plt.figure(figsize=(12, 7)) + + # 设置x轴位置,为并列柱状图做准备 + x = np.arange(len(bar_data)) + width = 0.35 # 柱状图宽度 + + # 确保symbol列是字符串类型,避免matplotlib警告 + bar_data["symbol"] = bar_data["symbol"].astype(str) + + # 绘制量化策略涨跌柱状图(蓝色渐变色) + bars1 = plt.bar(x - width/2, bar_data[column_name_text], width, + label=column_name_text, + color=plt.cm.Blues(np.linspace(0.6, 0.9, len(bar_data)))) + + # 绘制市场自然涨跌柱状图(绿色渐变色) + bars2 = plt.bar(x + width/2, bar_data["市场自然涨跌"], width, + label="市场自然涨跌", + color=plt.cm.Greens(np.linspace(0.6, 0.9, len(bar_data)))) + + # 设置图表标题和标签 + plt.title(f"{bar}趋势{column_name_text}与市场自然涨跌对比(%)", fontsize=14, fontweight='bold') + plt.xlabel("Symbol", fontsize=12) + plt.ylabel("涨跌幅(%)", fontsize=12) + plt.xticks(x, bar_data['symbol'], rotation=45, ha='right') + plt.legend() + plt.grid(True, alpha=0.3) + + # 在量化策略涨跌柱状图上方添加数值标签 + for i, (bar1, value1) in enumerate(zip(bars1, bar_data[column_name_text])): + plt.text(bar1.get_x() + bar1.get_width()/2, value1 + (0.01 if value1 >= 0 else -0.01), + f'{value1:.3f}%', ha='center', va='bottom' if value1 >= 0 else 'top', + fontsize=9, fontweight='bold', color='darkblue') + + # 在市场自然涨跌柱状图上方添加数值标签 + for i, (bar2, value2) in enumerate(zip(bars2, bar_data["市场自然涨跌"])): + plt.text(bar2.get_x() + bar2.get_width()/2, value2 + (0.01 if value2 >= 0 else -0.01), + f'{value2:.3f}%', ha='center', va='bottom' if value2 >= 0 else 'top', + fontsize=9, fontweight='bold', color='darkgreen') - # 在柱状图上添加数值标签 - for i, v in enumerate(bar_data[column_name_text]): - ax.text( - i, - v, - f"{v:.3f}", - ha="center", - va="bottom", - fontsize=10, - fontweight="bold", + else: + plt.figure(figsize=(10, 6)) + + # 确保symbol列是字符串类型,避免matplotlib警告 + bar_data["symbol"] = bar_data["symbol"].astype(str) + + ax = sns.barplot( + x="symbol", y=column_name_text, data=bar_data, palette="Blues_d" ) + plt.title(f"{bar}趋势{column_name_text}(%)") + plt.xlabel("symbol") + plt.ylabel(column_name_text) + plt.xticks(rotation=45, ha="right") + + # 在柱状图上添加数值标签 + for i, v in enumerate(bar_data[column_name_text]): + ax.text( + i, + v, + f"{v:.3f}", + ha="center", + va="bottom", + fontsize=10, + fontweight="bold", + ) plt.tight_layout() save_path = os.path.join( self.stats_chart_dir, - f"{bar}_ma_break_{column_name}_{strategy_name}.png", + f"{bar}_bar_chart_{column_name}_{strategy_name}.png", ) plt.savefig(save_path, dpi=150) plt.close() - sheet_name = f"{bar}_趋势{column_name_text}分布图表_{strategy_name}" + sheet_name = f"{bar}_趋势{column_name_text}柱状图_{strategy_name}" + chart_dict[sheet_name] = save_path + return chart_dict + + def draw_quant_line_chart(self, data: pd.DataFrame, strategy_name: str = "全均线策略"): + """ + 根据量化策略买卖明细记录,绘制量化策略涨跌与市场自然涨跌的折线图 + :param data: 量化策略买卖明细记录 + :param strategy_name: 策略名称 + :return: None + """ + symbols = data["symbol"].unique() + bars = data["bar"].unique() + chart_dict = {} + for symbol in symbols: + for bar in bars: + symbol_bar_data = data[(data["symbol"] == symbol) & (data["bar"] == bar)] + if symbol_bar_data.empty: + continue + + # 获取第一行数据作为基准 + first_row = symbol_bar_data.iloc[0].copy() + + # 创建初始化行,设置基准值 + init_row = first_row.copy() + init_row.loc["pct_chg_total"] = 1.0 # 量化策略初始值为1 + init_row.loc["end_timestamp"] = first_row["begin_timestamp"] + init_row.loc["end_date_time"] = first_row["begin_date_time"] + init_row.loc["end_close"] = first_row["begin_close"] + init_row.loc["end_ma5"] = first_row["begin_ma5"] + init_row.loc["end_ma10"] = first_row["begin_ma10"] + init_row.loc["end_ma20"] = first_row["begin_ma20"] + init_row.loc["end_ma30"] = first_row["begin_ma30"] + init_row.loc["end_macd_diff"] = first_row["begin_macd_diff"] + init_row.loc["end_macd_dea"] = first_row["begin_macd_dea"] + init_row.loc["end_macd"] = first_row["begin_macd"] + init_row.loc["pct_chg"] = 0 + init_row.loc["interval_seconds"] = 0 + init_row.loc["interval_minutes"] = 0 + init_row.loc["interval_hours"] = 0 + init_row.loc["interval_days"] = 0 + + # 将初始化行添加到数据开头 + symbol_bar_data = pd.concat([pd.DataFrame([init_row]), symbol_bar_data]) + symbol_bar_data.sort_values(by="end_timestamp", ascending=True, inplace=True) + symbol_bar_data.reset_index(drop=True, inplace=True) + + # 确保时间列是datetime类型,避免matplotlib警告 + symbol_bar_data["end_date_time"] = pd.to_datetime(symbol_bar_data["end_date_time"]) + + # 计算市场价位归一化数据(相对于初始价格) + symbol_bar_data["end_close_to_1"] = symbol_bar_data["end_close"] / init_row["end_close"] + symbol_bar_data["end_close_to_1"] = symbol_bar_data["end_close_to_1"].round(4) + + # 绘制折线图 + plt.figure(figsize=(12, 7)) + + # 绘制量化策略涨跌线(蓝色) + plt.plot(symbol_bar_data["end_date_time"], symbol_bar_data["pct_chg_total"], + label="量化策略涨跌", color='blue', linewidth=2, marker='o', markersize=4) + + # 绘制市场自然涨跌线(绿色) + plt.plot(symbol_bar_data["end_date_time"], symbol_bar_data["end_close_to_1"], + label="市场自然涨跌", color='green', linewidth=2, marker='s', markersize=4) + + plt.title(f"{symbol} {bar} 量化与市场折线图_{strategy_name}", + fontsize=14, fontweight='bold') + plt.xlabel("时间", fontsize=12) + plt.ylabel("涨跌变化", fontsize=12) + plt.legend(fontsize=11) + plt.grid(True, alpha=0.3) + + # 设置x轴标签,避免matplotlib警告 + # 选择合适的时间间隔显示标签,避免过于密集 + if len(symbol_bar_data) > 30: + # 如果数据点较多,选择间隔显示,但确保第一条和最后一条始终显示 + step = max(1, len(symbol_bar_data) // 30) + + # 创建标签索引列表,确保包含首尾数据 + label_indices = [0] # 第一条 + + # 添加中间间隔的标签 + for i in range(step, len(symbol_bar_data) - 1, step): + label_indices.append(i) + + # 添加最后一条(如果还没有包含的话) + if len(symbol_bar_data) - 1 not in label_indices: + label_indices.append(len(symbol_bar_data) - 1) + + # 设置x轴标签 + plt.xticks(symbol_bar_data["end_date_time"].iloc[label_indices], + symbol_bar_data["end_date_time"].iloc[label_indices].dt.strftime('%m-%d %H:%M'), + rotation=45, ha='right') + else: + # 如果数据点较少,全部显示 + plt.xticks(symbol_bar_data["end_date_time"], + symbol_bar_data["end_date_time"].dt.strftime('%m-%d %H:%M'), + rotation=45, ha='right') + + plt.tight_layout() + + save_path = os.path.join( + self.stats_chart_dir, + f"{symbol}_{bar}_line_chart_{strategy_name}.png", + ) + plt.savefig(save_path, dpi=150, bbox_inches='tight') + plt.close() + + sheet_name = f"{symbol}_{bar}_折线图_{strategy_name}" chart_dict[sheet_name] = save_path return chart_dict diff --git a/trade_ma_strategy_main.py b/trade_ma_strategy_main.py index 01174a6..7b815b8 100644 --- a/trade_ma_strategy_main.py +++ b/trade_ma_strategy_main.py @@ -38,6 +38,7 @@ class TradeMaStrategyMain: for strategy_name, strategy_info in strategy_dict.items(): pct_chg_df = self.ma_break_statistics.batch_statistics(strategy_name=strategy_name) pct_chg_df_list.append(pct_chg_df) + pct_chg_df = pd.concat(pct_chg_df_list) def statistics_pct_chg(self, pct_chg_df: pd.DataFrame):