from core.db.db_huge_volume_data import DBHugeVolumeData
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
import matplotlib.pyplot as plt
import seaborn as sns
from openpyxl import Workbook
from openpyxl.drawing.image import Image
from PIL import Image as PILImage
import logging
from datetime import datetime
import pandas as pd
import os
import re
import openpyxl
from openpyxl.styles import Font

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
sns.set_theme(style="whitegrid")
# Use a font with CJK glyphs so the Chinese titles/labels render, and keep the
# ASCII minus sign (SimHei lacks the Unicode minus glyph).
plt.rcParams["font.sans-serif"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False

# Characters that are illegal in file names on common file systems (notably
# Windows); they are replaced before a chart PNG is written.
_INVALID_FILE_NAME_CHARS = re.compile(r'[\\/:*?"<>|]')
# Characters Excel forbids in worksheet titles.
_INVALID_SHEET_TITLE_CHARS = re.compile(r"[\\/*?:\[\]]")
# Excel's maximum worksheet title length.
_MAX_SHEET_TITLE_LENGTH = 31


def _safe_file_name(name: str) -> str:
    """Return *name* with file-name-illegal characters replaced by ``_``."""
    return _INVALID_FILE_NAME_CHARS.sub("_", name)


def _safe_sheet_title(name: str, existing_titles: list) -> str:
    """
    Turn *name* into a valid worksheet title not already in *existing_titles*.

    Characters Excel forbids in titles (backslash, /, *, ?, :, [ and ]) are
    replaced by ``_``. The result is truncated to 31 characters. If the title
    is already taken (compared case-insensitively, as Excel does), a numeric
    suffix is appended while staying within the length limit.
    """
    base = _INVALID_SHEET_TITLE_CHARS.sub("_", name)[:_MAX_SHEET_TITLE_LENGTH] or "Sheet"
    taken = {title.lower() for title in existing_titles}
    candidate = base
    counter = 1
    while candidate.lower() in taken:
        suffix = f"_{counter}"
        candidate = base[: _MAX_SHEET_TITLE_LENGTH - len(suffix)] + suffix
        counter += 1
    return candidate


class HugeVolumeChart:
    """Render huge-volume rise/fall statistics as PNG charts collected into Excel workbooks."""

    def __init__(
        self,
        data: pd.DataFrame,
        output_folder: str = "./output/huge_volume_statistics/",
    ):
        """
        Prepare the chart generator.

        data: statistics table with the columns
            symbol: trading symbol
            bar: candle period
            window_size: window size
            huge_volume: whether the row is a huge-volume event
            volume_ratio_percentile_10: 10% percentile
            volume_ratio_percentile_10_mean: mean of the 10% percentile
            price_type: price type
            next_period: look-ahead period
            average_return: average return
            max_return: maximum return
            min_return: minimum return
            rise_count: number of rises
            rise_ratio: share of rises
            fall_count: number of falls
            fall_ratio: share of falls
            draw_count: number of flat outcomes
            draw_ratio: share of flat outcomes
            total_count: total number of observations
        output_folder: directory receiving the Excel files; chart PNGs are
            written to its ``temp`` sub-directory. Both are created if missing.

        Rows whose ``bar`` equals ``"1D"`` are dropped.
        """
        # Drop daily bars and renumber the rows. A non-inplace reset_index
        # avoids mutating a filtered slice in place.
        self.data = data[data["bar"] != "1D"].reset_index(drop=True)
        self.output_folder = output_folder
        os.makedirs(self.output_folder, exist_ok=True)
        self.temp_dir = os.path.join(self.output_folder, "temp")
        os.makedirs(self.temp_dir, exist_ok=True)

        # Sorted distinct values of each dimension drive the chart loops.
        self.symbol_list = self._sorted_unique("symbol")
        self.bar_list = self._sorted_unique("bar")
        self.window_size_list = self._sorted_unique("window_size")
        self.next_period_list = self._sorted_unique("next_period")
        self.volume_ratio_percentile_10_list = self._sorted_unique(
            "volume_ratio_percentile_10"
        )
        self.price_type_list = self._sorted_unique("price_type")

    def _sorted_unique(self, column: str) -> list:
        """Return the distinct values of ``self.data[column]`` as a sorted list."""
        values = self.data[column].unique().tolist()
        values.sort()
        return values

    def _select(self, **conditions) -> pd.DataFrame:
        """Return the rows of ``self.data`` where every ``column == value`` in *conditions* holds."""
        mask = pd.Series(True, index=self.data.index)
        for column, value in conditions.items():
            mask = mask & (self.data[column] == value)
        return self.data[mask]

    def plot_entrance(self, include_heatmap: bool = True, include_line: bool = True):
        """
        Draw the requested chart families and write one Excel workbook per family.

        Heatmaps are drawn first, then line charts. The returned dict merges
        both ``{sheet_name: {chart_name: png_path}}`` mappings, line charts
        first. NOTE: ``output_excel`` deletes the PNGs after saving, so the
        returned paths no longer exist on disk.
        """
        heatmap_plot_dict = self.plot_heatmap_entrance() if include_heatmap else {}
        line_plot_dict = self.plot_line_chart_entrance() if include_line else {}
        charts_dict = {}
        charts_dict.update(line_plot_dict)
        charts_dict.update(heatmap_plot_dict)
        return charts_dict

    def plot_line_chart_entrance(self):
        """
        Draw every rise/fall bar-chart breakdown and save them to a ``line_chart`` workbook.

        Returns ``{sheet_name: {chart_name: png_path}}``. The PNGs have been
        deleted by ``output_excel`` by the time this returns.
        """
        charts_dict = {}
        # Overall averages per price_type across the whole dataset.
        total_chart_path = self.plot_pice_rise_fall(data=self.data, prefix="总体")
        charts_dict["总体"] = {"总体": total_chart_path}
        self.plot_window_size_rise_fall(charts_dict=charts_dict)
        self.plot_window_size_bar_rise_fall(charts_dict=charts_dict)
        self.plot_window_size_bar_next_period_rise_fall(charts_dict=charts_dict)
        self.plot_symbol_rise_fall(charts_dict=charts_dict)
        self.plot_symbol_bar_rise_fall(charts_dict=charts_dict)
        self.plot_symbol_bar_window_size_rise_fall(charts_dict=charts_dict)
        self.plot_symbol_bar_window_size_next_period_rise_fall(charts_dict=charts_dict)
        # The volume_ratio_percentile_10 breakdowns are currently disabled:
        # self.plot_symbol_bar_window_size_volume_ratio_percentile_10_mean_rise_fall(charts_dict=charts_dict)
        # self.plot_symbol_bar_window_size_volume_ratio_percentile_10_mean_next_period_rise_fall(charts_dict=charts_dict)
        self.output_excel(chart_type="line_chart", charts_dict=charts_dict)
        return charts_dict

    def plot_heatmap_entrance(self):
        """
        Draw per-symbol heatmaps of rise_ratio, fall_ratio and average_return.

        Saves them to a ``heatmap_chart`` workbook and returns
        ``{sheet_name: {chart_name: png_path}}``. The PNGs have already been
        deleted by ``output_excel`` when this returns.
        """
        charts_dict = {}
        self.plot_symbol_heatmap(
            charts_dict=charts_dict,
            ratio_column="rise_ratio",
            title="Rise Ratio Heatmap by Window Size and Bar",
        )
        self.plot_symbol_heatmap(
            charts_dict=charts_dict,
            ratio_column="fall_ratio",
            title="Fall Ratio Heatmap by Window Size and Bar",
        )
        self.plot_symbol_heatmap(
            charts_dict=charts_dict,
            ratio_column="average_return",
            title="Average Return Heatmap by Window Size and Bar",
        )
        self.output_excel(chart_type="heatmap_chart", charts_dict=charts_dict)
        return charts_dict

    def plot_symbol_heatmap(
        self,
        charts_dict: dict,
        ratio_column: str = "rise_ratio",
        title: str = "Rise Ratio Heatmap by Window Size and Bar",
    ):
        """
        Draw one window_size x bar heatmap of *ratio_column* per symbol and price_type.

        Cells hold the mean of *ratio_column*. Results are added to *charts_dict*
        under the sheet ``"{symbol}_{ratio_column}_heatmap"``. Combinations with
        no data are skipped with a warning, because seaborn cannot draw an empty
        matrix.
        """
        for symbol in self.symbol_list:
            sheet_name = f"{symbol}_{ratio_column}_heatmap"
            charts_dict[sheet_name] = {}
            for price_type in self.price_type_list:
                logging.info(f"绘制{symbol} {price_type} {ratio_column}热力图")
                df = self._select(symbol=symbol, price_type=price_type)
                pivot_table = (
                    pd.DataFrame()
                    if df.empty
                    else df.pivot_table(
                        values=ratio_column,
                        index="window_size",
                        columns="bar",
                        aggfunc="mean",
                    )
                )
                if pivot_table.empty:
                    logging.warning(
                        f"No {ratio_column} data for {symbol} {price_type}, heatmap skipped"
                    )
                    continue

                chart_name = f"{symbol}_{price_type}_{ratio_column}_heatmap"
                fig, ax = plt.subplots(figsize=(10, 6))
                try:
                    # Reversed red-yellow-green colormap: high values red, low values green.
                    sns.heatmap(pivot_table, annot=True, cmap="RdYlGn_r", fmt=".3f", ax=ax)
                    ax.set_xlabel("Period")
                    ax.set_ylabel("Window Size")
                    ax.set_title(f"{title} {price_type}")
                    chart_path = os.path.join(
                        self.temp_dir, f"{_safe_file_name(chart_name)}.png"
                    )
                    fig.savefig(chart_path, bbox_inches="tight", dpi=100)
                finally:
                    # Always release the figure, even if drawing or saving failed.
                    plt.close(fig)
                charts_dict[sheet_name][chart_name] = chart_path

    def plot_window_size_rise_fall(self, charts_dict: dict):
        """Across all symbols, draw one rise/fall chart per window_size (sheet ``window_size``)."""
        charts_dict["window_size"] = sheet = {}
        for window_size in self.window_size_list:
            name = f"window_size_{window_size}"
            sheet[name] = self.plot_pice_rise_fall(
                self._select(window_size=window_size), prefix=name
            )

    def plot_window_size_bar_rise_fall(self, charts_dict: dict):
        """Across all symbols, draw one rise/fall chart per (window_size, bar) (sheet ``window_size_bar``)."""
        charts_dict["window_size_bar"] = sheet = {}
        for window_size in self.window_size_list:
            for bar in self.bar_list:
                name = f"window_size_{window_size}_bar_{bar}"
                sheet[name] = self.plot_pice_rise_fall(
                    self._select(window_size=window_size, bar=bar), prefix=name
                )

    def plot_window_size_bar_next_period_rise_fall(self, charts_dict: dict):
        """
        Across all symbols, draw one rise/fall chart per (window_size, bar, next_period).

        Results go to the sheet ``window_size_bar_period``.
        """
        charts_dict["window_size_bar_period"] = sheet = {}
        for window_size in self.window_size_list:
            for bar in self.bar_list:
                for next_period in self.next_period_list:
                    name = f"window_size_{window_size}_bar_{bar}_next_period_{next_period}"
                    data = self._select(
                        window_size=window_size, bar=bar, next_period=next_period
                    )
                    sheet[name] = self.plot_pice_rise_fall(data, prefix=name)

    def plot_symbol_rise_fall(self, charts_dict: dict):
        """Draw one overall rise/fall chart per symbol (sheet ``"{symbol}_总体"``)."""
        for symbol in self.symbol_list:
            charts_dict[f"{symbol}_总体"] = sheet = {}
            sheet[symbol] = self.plot_pice_rise_fall(
                self._select(symbol=symbol), prefix=symbol
            )

    def plot_symbol_bar_rise_fall(self, charts_dict: dict):
        """Draw one rise/fall chart per (symbol, bar) (sheet ``"{symbol}_bar"``)."""
        for symbol in self.symbol_list:
            charts_dict[f"{symbol}_bar"] = sheet = {}
            for bar in self.bar_list:
                name = f"{symbol}_{bar}"
                sheet[name] = self.plot_pice_rise_fall(
                    self._select(symbol=symbol, bar=bar), prefix=name
                )

    def plot_symbol_bar_window_size_rise_fall(self, charts_dict: dict):
        """Draw one rise/fall chart per (symbol, bar, window_size) (sheet ``"{symbol}_bar_window_size"``)."""
        for symbol in self.symbol_list:
            charts_dict[f"{symbol}_bar_window_size"] = sheet = {}
            for bar in self.bar_list:
                for window_size in self.window_size_list:
                    name = f"{symbol}_{bar}_ws_{window_size}"
                    data = self._select(symbol=symbol, bar=bar, window_size=window_size)
                    sheet[name] = self.plot_pice_rise_fall(data, prefix=name)

    def plot_symbol_bar_window_size_next_period_rise_fall(self, charts_dict: dict):
        """
        Draw one rise/fall chart per (symbol, bar, window_size, next_period).

        Results go to the sheet ``"{symbol}_bar_ws_period"``.
        """
        for symbol in self.symbol_list:
            charts_dict[f"{symbol}_bar_ws_period"] = sheet = {}
            for bar in self.bar_list:
                for window_size in self.window_size_list:
                    for next_period in self.next_period_list:
                        name = f"{symbol}_{bar}_ws_{window_size}_next_period_{next_period}"
                        data = self._select(
                            symbol=symbol,
                            bar=bar,
                            window_size=window_size,
                            next_period=next_period,
                        )
                        sheet[name] = self.plot_pice_rise_fall(data, prefix=name)

    def plot_symbol_bar_window_size_volume_ratio_percentile_10_mean_rise_fall(self, charts_dict: dict):
        """
        Draw one rise/fall chart per (symbol, bar, window_size, volume_ratio_percentile_10).

        Results go to the sheet ``"{symbol}_bar_window_size_vol_per"``.
        """
        for symbol in self.symbol_list:
            charts_dict[f"{symbol}_bar_window_size_vol_per"] = sheet = {}
            for bar in self.bar_list:
                for window_size in self.window_size_list:
                    for volume_ratio_percentile_10 in self.volume_ratio_percentile_10_list:
                        name = f"{symbol}_{bar}_ws_{window_size}_vol_per_{volume_ratio_percentile_10}"
                        data = self._select(
                            symbol=symbol,
                            bar=bar,
                            window_size=window_size,
                            volume_ratio_percentile_10=volume_ratio_percentile_10,
                        )
                        sheet[name] = self.plot_pice_rise_fall(data, prefix=name)

    def plot_symbol_bar_window_size_volume_ratio_percentile_10_mean_next_period_rise_fall(
        self,
        charts_dict: dict,
    ):
        """
        Draw one rise/fall chart per (symbol, bar, window_size, volume_ratio_percentile_10, next_period).

        Results go to the sheet ``"{symbol}_bar_ws_vol_period"``.
        """
        for symbol in self.symbol_list:
            charts_dict[f"{symbol}_bar_ws_vol_period"] = sheet = {}
            for bar in self.bar_list:
                for window_size in self.window_size_list:
                    for volume_ratio_percentile_10 in self.volume_ratio_percentile_10_list:
                        for next_period in self.next_period_list:
                            name = (
                                f"{symbol}_{bar}_ws_{window_size}"
                                f"_vol_per_{volume_ratio_percentile_10}"
                                f"_next_period_{next_period}"
                            )
                            data = self._select(
                                symbol=symbol,
                                bar=bar,
                                window_size=window_size,
                                volume_ratio_percentile_10=volume_ratio_percentile_10,
                                next_period=next_period,
                            )
                            sheet[name] = self.plot_pice_rise_fall(data, prefix=name)

    def plot_pice_rise_fall(self, data: pd.DataFrame, prefix: str = ""):
        """
        Draw a 2x2 bar chart of per-price_type averages for *data* and save it as a PNG.

        For each price type in ``self.price_type_list``, the four panels show
        the mean rise_ratio, fall_ratio, draw_ratio and average_return of the
        matching rows. A price type with no rows in *data* gets NaN.

        data: subset of ``self.data`` to summarise.
        prefix: chart-title prefix, also used as the PNG file name
            (file-name-illegal characters are replaced by ``_``).

        Returns the path of the written PNG inside ``self.temp_dir``.
        """
        logging.info(f"绘制价格上涨下跌图: {prefix}")
        price_types = list(self.price_type_list)

        def column_means(column: str) -> list:
            # One mean per price type, in price_types order.
            return [data.loc[data["price_type"] == pt, column].mean() for pt in price_types]

        # (metric column, bar colour, panel title, y-axis label), drawn row-major
        # into the 2x2 grid.
        panels = (
            ("rise_ratio", "green", f"{prefix}平均上涨比例", "上涨比例"),
            ("fall_ratio", "red", f"{prefix}平均下跌比例", "下跌比例"),
            ("draw_ratio", "gray", f"{prefix}平均持平比例", "持平比例"),
            ("average_return", "blue", f"{prefix}平均回报", "平均回报"),
        )

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        try:
            for ax, (column, color, title, ylabel) in zip(axes.flat, panels):
                bars = ax.bar(price_types, column_means(column), color=color, alpha=0.7)
                ax.set_title(title)
                ax.set_ylabel(ylabel)
                ax.tick_params(axis="x", rotation=0)
                ax.bar_label(bars, fmt="%.2f")
            # Extra bottom margin and vertical spacing so the x tick labels are not clipped.
            fig.tight_layout()
            fig.subplots_adjust(bottom=0.15, hspace=0.4)
            chart_path = os.path.join(self.temp_dir, f"{_safe_file_name(prefix)}.png")
            fig.savefig(chart_path, bbox_inches="tight", dpi=100)
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
        return chart_path

    def output_excel(self, chart_type: str, charts_dict: dict):
        """
        Write all charts into a timestamped Excel workbook, then delete the chart PNGs.

        charts_dict: ``{sheet_name: {chart_name: chart_path}}``. Each sheet
            lists its charts top to bottom: a bold title cell followed by the
            image. Sheet names are made Excel-safe (no forbidden characters,
            at most 31 characters, unique).

        A sheet that fails to build is logged and skipped. If no sheet could
        be created, no workbook is saved. The PNG files are removed
        afterwards in either case.
        """
        logging.info(f"输出Excel文件,包含所有{chart_type}图表")
        file_name = f"huge_volume_{chart_type}_{datetime.now().strftime('%Y%m%d%H%M%S')}.xlsx"
        file_path = os.path.join(self.output_folder, file_name)

        # Approximate pixel height of one default Excel row:
        # 15 points per row * 1.333 pixels per point ≈ 20 pixels.
        pixels_per_row = 15 * 1.333

        wb = Workbook()
        wb.remove(wb.active)  # start without the default empty sheet
        for sheet_name, chart_data_dict in charts_dict.items():
            try:
                ws = wb.create_sheet(title=_safe_sheet_title(sheet_name, wb.sheetnames))
                row_offset = 1
                for chart_name, chart_path in chart_data_dict.items():
                    with PILImage.open(chart_path) as img:
                        height_px = img.size[1]
                    # Rows the image covers; at least 10 so small charts still get room.
                    chart_rows = max(10, int(height_px / pixels_per_row))

                    title_cell = ws[f"A{row_offset}"]
                    title_cell.value = chart_name
                    title_cell.font = Font(bold=True, size=12)
                    row_offset += 2  # title row plus one spacer row

                    ws.add_image(Image(chart_path), f"A{row_offset}")
                    row_offset += chart_rows + 5  # image height plus padding before the next chart
            except Exception as e:
                # Best effort: one broken sheet must not prevent the others from being written.
                logging.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
                continue

        # openpyxl reads the PNGs only when saving, so they are deleted afterwards.
        if wb.sheetnames:
            wb.save(file_path)
            logging.info(f"Excel file saved as {file_path}")
        else:
            logging.warning(f"No worksheet could be created, Excel file not saved: {file_path}")

        for chart_data_dict in charts_dict.values():
            for chart_path in chart_data_dict.values():
                try:
                    os.remove(chart_path)
                except OSError as e:
                    logging.error(f"删除临时文件失败: {e}")