optimize chart output

This commit is contained in:
blade 2025-08-25 16:58:38 +08:00
parent 5700e8e7e1
commit 69aff37c24
3 changed files with 201 additions and 27 deletions

View File

@ -93,10 +93,13 @@ class MaBreakStatistics:
f"{symbol} {bar} 的市场价格变化, {market_data_pct_chg.get('pct_chg', 0)}%"
)
market_data_pct_chg_list.append(market_data_pct_chg)
if len(ma_break_market_data_list) > 0:
ma_break_market_data = pd.concat(ma_break_market_data_list)
market_data_pct_chg_df = pd.DataFrame(market_data_pct_chg_list)
# 依据symbol和bar分组统计每个symbol和bar的pct_chg的max, min, mean, std, median, count
ma_break_market_data.sort_values(by="begin_timestamp", ascending=True, inplace=True)
ma_break_market_data.reset_index(drop=True, inplace=True)
pct_chg_df = (
ma_break_market_data.groupby(["symbol", "bar"])["pct_chg"]
.agg(
@ -127,7 +130,7 @@ class MaBreakStatistics:
symbol_bar_data = ma_break_market_data[
(ma_break_market_data["symbol"] == symbol)
& (ma_break_market_data["bar"] == bar)
]
].copy() # 创建副本避免SettingWithCopyWarning
if len(symbol_bar_data) > 0:
symbol_bar_data.sort_values(
by="end_timestamp", ascending=True, inplace=True
@ -136,6 +139,14 @@ class MaBreakStatistics:
symbol_bar_data["pct_chg_total"] = symbol_bar_data[
"pct_chg_total"
].cumprod()
# 将更新后的pct_chg_total数据同步更新到ma_break_market_data的对应数据行中
for idx, row in symbol_bar_data.iterrows():
mask = (ma_break_market_data["symbol"] == symbol) & \
(ma_break_market_data["bar"] == bar) & \
(ma_break_market_data["end_timestamp"] == row["end_timestamp"])
ma_break_market_data.loc[mask, "pct_chg_total"] = row["pct_chg_total"]
last_pct_chg_total = symbol_bar_data["pct_chg_total"].iloc[-1]
last_pct_chg_total = (last_pct_chg_total - 1) * 100
pct_chg_df.loc[
@ -208,7 +219,9 @@ class MaBreakStatistics:
writer, sheet_name="买卖时间间隔统计", index=False
)
chart_dict = self.draw_pct_chg_mean_chart(pct_chg_df, strategy_name)
chart_dict = self.draw_quant_pct_chg_bar_chart(pct_chg_df, strategy_name)
self.output_chart_to_excel(output_file_path, chart_dict)
chart_dict = self.draw_quant_line_chart(ma_break_market_data, strategy_name)
self.output_chart_to_excel(output_file_path, chart_dict)
return pct_chg_df
else:
@ -370,13 +383,17 @@ class MaBreakStatistics:
if len(ma_break_market_data_pair_list) > 0:
ma_break_market_data = pd.DataFrame(ma_break_market_data_pair_list)
# sort by end_timestamp
ma_break_market_data.sort_values(by="begin_timestamp", ascending=True, inplace=True)
ma_break_market_data.reset_index(drop=True, inplace=True)
logger.info(
f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}"
)
# 将market_data(最后一条数据的close - 第一条数据的open) / 第一条数据的open * 100
# 量化期间,市场的波动率:
# ma_break_market_data(最后一条数据的end_close - 第一条数据的begin_close) / 第一条数据的begin_close * 100
pct_chg = (
(market_data["close"].iloc[-1] - market_data["open"].iloc[0])
/ market_data["open"].iloc[0]
(ma_break_market_data["end_close"].iloc[-1] - ma_break_market_data["begin_close"].iloc[0])
/ ma_break_market_data["begin_close"].iloc[0]
* 100
)
pct_chg = round(pct_chg, 4)
@ -511,7 +528,7 @@ class MaBreakStatistics:
pass
return condition
def draw_pct_chg_mean_chart(
def draw_quant_pct_chg_bar_chart(
self, data: pd.DataFrame, strategy_name: str = "全均线策略"
):
"""
@ -527,18 +544,66 @@ class MaBreakStatistics:
plt.rcParams["font.size"] = 11 # 设置字体大小
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
chart_dict = {}
column_name_dict = {"pct_chg_total": "涨跌总和", "pct_chg_mean": "涨跌均值"}
column_name_dict = {"pct_chg_total": "量化策略涨跌", "pct_chg_mean": "量化策略涨跌均值"}
for column_name, column_name_text in column_name_dict.items():
for bar in data["bar"].unique():
bar_data = data[data["bar"] == bar].copy() # 一次筛选即可
if bar_data.empty:
continue
bar_data.rename(columns={column_name: column_name_text}, inplace=True)
if column_name == "pct_chg_total":
bar_data.rename(columns={"market_pct_chg": "市场自然涨跌"}, inplace=True)
# 可选:按均值排序
bar_data.sort_values(by=column_name_text, ascending=False, inplace=True)
bar_data.reset_index(drop=True, inplace=True)
# 如果column_name_text是"量化策略涨跌",则柱状图,同时绘制量化策略涨跌与市场自然涨跌的柱状图,并绘制在同一个图表中
if column_name == "pct_chg_total":
plt.figure(figsize=(12, 7))
# 设置x轴位置为并列柱状图做准备
x = np.arange(len(bar_data))
width = 0.35 # 柱状图宽度
# 确保symbol列是字符串类型避免matplotlib警告
bar_data["symbol"] = bar_data["symbol"].astype(str)
# 绘制量化策略涨跌柱状图(蓝色渐变色)
bars1 = plt.bar(x - width/2, bar_data[column_name_text], width,
label=column_name_text,
color=plt.cm.Blues(np.linspace(0.6, 0.9, len(bar_data))))
# 绘制市场自然涨跌柱状图(绿色渐变色)
bars2 = plt.bar(x + width/2, bar_data["市场自然涨跌"], width,
label="市场自然涨跌",
color=plt.cm.Greens(np.linspace(0.6, 0.9, len(bar_data))))
# 设置图表标题和标签
plt.title(f"{bar}趋势{column_name_text}与市场自然涨跌对比(%)", fontsize=14, fontweight='bold')
plt.xlabel("Symbol", fontsize=12)
plt.ylabel("涨跌幅(%)", fontsize=12)
plt.xticks(x, bar_data['symbol'], rotation=45, ha='right')
plt.legend()
plt.grid(True, alpha=0.3)
# 在量化策略涨跌柱状图上方添加数值标签
for i, (bar1, value1) in enumerate(zip(bars1, bar_data[column_name_text])):
plt.text(bar1.get_x() + bar1.get_width()/2, value1 + (0.01 if value1 >= 0 else -0.01),
f'{value1:.3f}%', ha='center', va='bottom' if value1 >= 0 else 'top',
fontsize=9, fontweight='bold', color='darkblue')
# 在市场自然涨跌柱状图上方添加数值标签
for i, (bar2, value2) in enumerate(zip(bars2, bar_data["市场自然涨跌"])):
plt.text(bar2.get_x() + bar2.get_width()/2, value2 + (0.01 if value2 >= 0 else -0.01),
f'{value2:.3f}%', ha='center', va='bottom' if value2 >= 0 else 'top',
fontsize=9, fontweight='bold', color='darkgreen')
else:
plt.figure(figsize=(10, 6))
# 确保symbol列是字符串类型避免matplotlib警告
bar_data["symbol"] = bar_data["symbol"].astype(str)
ax = sns.barplot(
x="symbol", y=column_name_text, data=bar_data, palette="Blues_d"
)
@ -563,12 +628,120 @@ class MaBreakStatistics:
save_path = os.path.join(
self.stats_chart_dir,
f"{bar}_ma_break_{column_name}_{strategy_name}.png",
f"{bar}_bar_chart_{column_name}_{strategy_name}.png",
)
plt.savefig(save_path, dpi=150)
plt.close()
sheet_name = f"{bar}_趋势{column_name_text}分布图表_{strategy_name}"
sheet_name = f"{bar}_趋势{column_name_text}柱状图_{strategy_name}"
chart_dict[sheet_name] = save_path
return chart_dict
def draw_quant_line_chart(self, data: pd.DataFrame, strategy_name: str = "全均线策略") -> dict:
    """
    Draw one line chart per (symbol, bar) pair comparing the strategy's
    cumulative return with the market price over the same trades.

    Both series are normalised to start at 1.0: a synthetic baseline row is
    placed at the first trade's begin time, the strategy line plots
    ``pct_chg_total`` and the market line plots ``end_close`` divided by the
    first trade's ``begin_close``.

    The first row of each (symbol, bar) group is used as the baseline, so the
    rows are expected to arrive in chronological order.

    :param data: trade detail records; must contain the columns symbol, bar,
        begin_timestamp, begin_date_time, begin_close, end_timestamp,
        end_date_time, end_close and pct_chg_total. begin_/end_ indicator
        columns (ma5, ma10, ma20, ma30, macd_diff, macd_dea, macd) are
        mirrored onto the baseline row when present.
    :param strategy_name: strategy name, used in the chart title, the image
        file name and the sheet name
    :return: dict mapping sheet name to the path of the saved PNG image
    """
    # Columns the chart really depends on: a missing one must fail loudly.
    required_pairs = (
        ("end_timestamp", "begin_timestamp"),
        ("end_date_time", "begin_date_time"),
        ("end_close", "begin_close"),
    )
    # Indicator columns are only mirrored for a consistent baseline row; the
    # chart never reads them, so their absence is tolerated.
    optional_suffixes = ("ma5", "ma10", "ma20", "ma30", "macd_diff", "macd_dea", "macd")
    zeroed_columns = (
        "pct_chg",
        "interval_seconds",
        "interval_minutes",
        "interval_hours",
        "interval_days",
    )

    chart_dict = {}
    for symbol in data["symbol"].unique():
        for bar in data["bar"].unique():
            symbol_bar_data = data[(data["symbol"] == symbol) & (data["bar"] == bar)]
            if symbol_bar_data.empty:
                continue

            # Build the baseline row from the first trade: its "end" state is
            # the first trade's "begin" state, and the strategy value is 1.
            first_row = symbol_bar_data.iloc[0]
            init_row = first_row.copy()
            init_row.loc["pct_chg_total"] = 1.0
            for end_column, begin_column in required_pairs:
                init_row.loc[end_column] = first_row[begin_column]
            for suffix in optional_suffixes:
                begin_column = f"begin_{suffix}"
                if begin_column in first_row.index:
                    init_row.loc[f"end_{suffix}"] = first_row[begin_column]
            for column in zeroed_columns:
                init_row.loc[column] = 0

            # A stable sort keeps the baseline row (concatenated first) ahead
            # of a real row that happens to share its end_timestamp.
            symbol_bar_data = (
                pd.concat([pd.DataFrame([init_row]), symbol_bar_data])
                .sort_values(by="end_timestamp", ascending=True, kind="mergesort")
                .reset_index(drop=True)
            )

            # Convert to datetime so matplotlib treats the x axis as time.
            symbol_bar_data["end_date_time"] = pd.to_datetime(symbol_bar_data["end_date_time"])

            # Market price normalised to the price at the baseline row.
            symbol_bar_data["end_close_to_1"] = (
                symbol_bar_data["end_close"] / init_row["end_close"]
            )
            symbol_bar_data["end_close_to_1"] = symbol_bar_data["end_close_to_1"].round(4)

            time_axis = symbol_bar_data["end_date_time"]
            point_count = len(symbol_bar_data)
            if point_count > 30:
                # Thin the tick labels to roughly 30, always keeping the
                # first and the last data point labelled.
                step = max(1, point_count // 30)
                last_index = point_count - 1
                label_indices = list(range(0, last_index, step)) + [last_index]
                tick_times = time_axis.iloc[label_indices]
            else:
                tick_times = time_axis

            plt.figure(figsize=(12, 7))
            try:
                # Strategy cumulative return (blue).
                plt.plot(time_axis, symbol_bar_data["pct_chg_total"],
                         label="量化策略涨跌", color='blue', linewidth=2, marker='o', markersize=4)
                # Normalised market price (green).
                plt.plot(time_axis, symbol_bar_data["end_close_to_1"],
                         label="市场自然涨跌", color='green', linewidth=2, marker='s', markersize=4)
                plt.title(f"{symbol} {bar} 量化与市场折线图_{strategy_name}",
                          fontsize=14, fontweight='bold')
                plt.xlabel("时间", fontsize=12)
                plt.ylabel("涨跌变化", fontsize=12)
                plt.legend(fontsize=11)
                plt.grid(True, alpha=0.3)
                plt.xticks(tick_times,
                           tick_times.dt.strftime('%m-%d %H:%M'),
                           rotation=45, ha='right')
                plt.tight_layout()

                save_path = os.path.join(
                    self.stats_chart_dir,
                    f"{symbol}_{bar}_line_chart_{strategy_name}.png",
                )
                plt.savefig(save_path, dpi=150, bbox_inches='tight')
            finally:
                # Always release the figure, even when plotting or saving fails.
                plt.close()

            sheet_name = f"{symbol}_{bar}_折线图_{strategy_name}"
            chart_dict[sheet_name] = save_path
    return chart_dict

View File

@ -38,6 +38,7 @@ class TradeMaStrategyMain:
for strategy_name, strategy_info in strategy_dict.items():
pct_chg_df = self.ma_break_statistics.batch_statistics(strategy_name=strategy_name)
pct_chg_df_list.append(pct_chg_df)
pct_chg_df = pd.concat(pct_chg_df_list)
def statistics_pct_chg(self, pct_chg_df: pd.DataFrame):