# File listing metadata: 459 lines, 20 KiB, Python
import itertools
import os
import re
from datetime import datetime, timezone, timedelta

import matplotlib.pyplot as plt
import openpyxl
import pandas as pd
import seaborn as sns
from openpyxl import Workbook
from openpyxl.drawing.image import Image
from openpyxl.styles import Font
from PIL import Image as PILImage

import core.logger as logging
from core.db.db_huge_volume_data import DBHugeVolumeData
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
from core.utils import get_current_date_time
|
||
|
||
# Module-wide logger taken from the project's logging wrapper.
logger = logging.logger

# Apply the seaborn theme first, then override the font settings so that
# Chinese chart labels can be rendered (SimHei font) and the minus sign is
# drawn with the ASCII hyphen instead of the Unicode minus glyph.
sns.set_theme(style="whitegrid")
plt.rcParams.update(
    {
        "font.sans-serif": ["SimHei"],
        "axes.unicode_minus": False,
    }
)
|
||
|
||
|
||
class HugeVolumeChart:
    """Render huge-volume statistics as PNG charts and bundle them into Excel workbooks.

    Charts are first written as PNG files into a ``temp`` directory below
    ``output_folder``.  ``output_excel`` then embeds them into a workbook and
    deletes the PNG files afterwards.
    """

    # Excel layout constants used to estimate how many rows an image covers.
    _PIXELS_PER_POINT = 1.333
    _POINTS_PER_ROW = 15  # default row height in points (about 20 pixels)
    _MIN_CHART_ROWS = 10  # minimum space reserved for small charts
    _TITLE_ROWS = 2  # title row plus one spacer row
    _PADDING_ROWS = 5  # blank rows between consecutive charts

    # Characters that cannot safely appear in a file name on common platforms.
    _UNSAFE_FILE_CHARS = re.compile(r'[\\/:*?"<>|]')

    def __init__(
        self,
        data: pd.DataFrame,
        output_folder: str = "./output/huge_volume_statistics/",
    ):
        """Store the statistics and prepare the output directories.

        Args:
            data: Statistics table.  The documented columns are:
                symbol (trading symbol), bar (bar period), window_size,
                huge_volume (whether the bar is a huge-volume bar),
                volume_ratio_percentile_10 (10% percentile),
                volume_ratio_percentile_10_mean (mean of the 10% percentile),
                price_type, next_period, average_return, max_return,
                min_return, rise_count, rise_ratio, fall_count, fall_ratio,
                draw_count (flat count), draw_ratio (flat ratio), total_count.
            output_folder: Directory that receives the Excel workbooks.  A
                ``temp`` sub-directory is created inside it for PNG files.
        """
        # Drop the "1D" bars.  A new frame is built instead of resetting the
        # index in place on a filtered slice, which can trigger pandas'
        # SettingWithCopyWarning; the caller's frame is left untouched.
        self.data = data[data["bar"] != "1D"].reset_index(drop=True)

        self.output_folder = output_folder
        os.makedirs(self.output_folder, exist_ok=True)
        self.temp_dir = os.path.join(self.output_folder, "temp")
        os.makedirs(self.temp_dir, exist_ok=True)

        # Sorted distinct values of every dimension the charts are split by.
        self.symbol_list = self._sorted_unique("symbol")
        self.bar_list = self._sorted_unique("bar")
        self.window_size_list = self._sorted_unique("window_size")
        self.next_period_list = self._sorted_unique("next_period")
        self.volume_ratio_percentile_10_list = self._sorted_unique(
            "volume_ratio_percentile_10"
        )
        self.price_type_list = self._sorted_unique("price_type")

    def _sorted_unique(self, column: str) -> list:
        """Return the distinct values of ``column`` in ascending order."""
        return sorted(self.data[column].unique().tolist())

    def _safe_file_stem(self, name) -> str:
        """Return ``name`` with path-unsafe characters replaced by underscores."""
        return self._UNSAFE_FILE_CHARS.sub("_", str(name))

    def plot_entrance(self, include_heatmap: bool = True, include_line: bool = True):
        """Entry point: render the requested chart families.

        Args:
            include_heatmap: Run ``plot_heatmap_entrance`` when true.
            include_line: Run ``plot_line_chart_entrance`` when true.

        Returns:
            A ``{sheet_name: {chart_name: chart_path}}`` dict with the line
            chart entries first, followed by the heatmap entries.  The PNG
            files behind the paths have already been removed by
            ``output_excel`` when this method returns.
        """
        # Heatmaps are rendered before the rise/fall charts, but the merged
        # dict lists the rise/fall charts first.
        heatmap_plot_dict = self.plot_heatmap_entrance() if include_heatmap else {}
        line_plot_dict = self.plot_line_chart_entrance() if include_line else {}
        return {**line_plot_dict, **heatmap_plot_dict}

    def plot_line_chart_entrance(self):
        """Render every rise/fall breakdown and export them as one workbook.

        Despite the name, the charts produced by ``plot_pice_rise_fall`` are
        bar charts.  The two breakdowns by ``volume_ratio_percentile_10`` are
        defined on the class but are not rendered here.

        Returns:
            The ``{sheet_name: {chart_name: chart_path}}`` dict that was
            written to the workbook.
        """
        charts_dict = {}
        # Overall chart: per-price_type averages across the whole data set.
        total_chart_path = self.plot_pice_rise_fall(data=self.data, prefix="总体")
        charts_dict["总体"] = {"总体": total_chart_path}

        for plot_group in (
            self.plot_window_size_rise_fall,
            self.plot_window_size_bar_rise_fall,
            self.plot_window_size_bar_next_period_rise_fall,
            self.plot_symbol_rise_fall,
            self.plot_symbol_bar_rise_fall,
            self.plot_symbol_bar_window_size_rise_fall,
            self.plot_symbol_bar_window_size_next_period_rise_fall,
        ):
            plot_group(charts_dict=charts_dict)

        self.output_excel(chart_type="line_chart", charts_dict=charts_dict)
        return charts_dict

    def plot_heatmap_entrance(self):
        """Render the per-symbol heatmaps and export them as one workbook.

        Returns:
            The ``{sheet_name: {chart_name: chart_path}}`` dict that was
            written to the workbook.
        """
        charts_dict = {}
        heatmaps = (
            ("rise_ratio", "Rise Ratio Heatmap by Window Size and Bar"),
            ("fall_ratio", "Fall Ratio Heatmap by Window Size and Bar"),
            ("average_return", "Average Return Heatmap by Window Size and Bar"),
        )
        for ratio_column, title in heatmaps:
            self.plot_symbol_heatmap(
                charts_dict=charts_dict, ratio_column=ratio_column, title=title
            )

        self.output_excel(chart_type="heatmap_chart", charts_dict=charts_dict)
        return charts_dict

    def plot_symbol_heatmap(
        self,
        charts_dict: dict,
        ratio_column: str = "rise_ratio",
        title: str = "Rise Ratio Heatmap by Window Size and Bar",
    ):
        """Draw one window_size x bar heatmap per symbol and price_type.

        Args:
            charts_dict: Receives one ``"{symbol}_{ratio_column}_heatmap"``
                sheet per symbol, mapping chart names to PNG paths.
            ratio_column: Column whose mean fills the heatmap cells.
            title: Chart title; the price_type is appended to it.
        """
        for symbol in self.symbol_list:
            sheet_name = f"{symbol}_{ratio_column}_heatmap"
            sheet_charts = charts_dict[sheet_name] = {}
            for price_type in self.price_type_list:
                logger.info(f"绘制{symbol} {price_type} {ratio_column}热力图")
                subset = self.data[
                    (self.data["symbol"] == symbol)
                    & (self.data["price_type"] == price_type)
                ]
                pivot_table = subset.pivot_table(
                    values=ratio_column,
                    index="window_size",
                    columns="bar",
                    aggfunc="mean",
                )
                # A symbol may have no rows for some price_type; seaborn
                # cannot draw an empty table, so skip instead of aborting
                # the whole run.
                if pivot_table.empty:
                    logger.info(
                        f"{symbol} {price_type} {ratio_column}无数据, 跳过热力图"
                    )
                    continue

                chart_name = f"{symbol}_{price_type}_{ratio_column}_heatmap"
                chart_path = os.path.join(
                    self.temp_dir, f"{self._safe_file_stem(chart_name)}.png"
                )
                fig = plt.figure(figsize=(10, 6))
                try:
                    # Reversed red-yellow-green map: red marks high values,
                    # green marks low values.
                    sns.heatmap(pivot_table, annot=True, cmap="RdYlGn_r", fmt=".3f")
                    plt.xlabel("Period")
                    plt.ylabel("Window Size")
                    plt.title(f"{title} {price_type}")
                    plt.savefig(chart_path, bbox_inches="tight", dpi=100)
                finally:
                    # Always release the figure, even if drawing failed.
                    plt.close(fig)
                sheet_charts[chart_name] = chart_path

    def _plot_combinations(
        self,
        charts_dict: dict,
        sheet_name: str,
        dimensions,
        name_template: str,
    ):
        """Plot one rise/fall chart per combination of dimension values.

        Args:
            charts_dict: Receives a fresh ``sheet_name`` entry mapping each
                chart name to its PNG path.
            sheet_name: Key (and later Excel sheet title) for this group.
            dimensions: Sequence of ``(column, values)`` pairs.  Every
                combination of the values is plotted, with the last pair
                varying fastest.
            name_template: ``str.format`` template with one ``{}`` per
                dimension; the result is both the chart name and the prefix
                passed to ``plot_pice_rise_fall``.
        """
        sheet_charts = charts_dict[sheet_name] = {}
        columns = [column for column, _ in dimensions]
        value_lists = [values for _, values in dimensions]
        for combination in itertools.product(*value_lists):
            # Rows that match every (column == value) condition.
            mask = pd.Series(True, index=self.data.index)
            for column, value in zip(columns, combination):
                mask &= self.data[column] == value
            chart_name = name_template.format(*combination)
            sheet_charts[chart_name] = self.plot_pice_rise_fall(
                self.data[mask], prefix=chart_name
            )

    def plot_window_size_rise_fall(self, charts_dict: dict):
        """Across all symbols, plot one rise/fall chart per window_size."""
        self._plot_combinations(
            charts_dict,
            sheet_name="window_size",
            dimensions=[("window_size", self.window_size_list)],
            name_template="window_size_{}",
        )

    def plot_window_size_bar_rise_fall(self, charts_dict: dict):
        """Across all symbols, plot one chart per (window_size, bar)."""
        self._plot_combinations(
            charts_dict,
            sheet_name="window_size_bar",
            dimensions=[
                ("window_size", self.window_size_list),
                ("bar", self.bar_list),
            ],
            name_template="window_size_{}_bar_{}",
        )

    def plot_window_size_bar_next_period_rise_fall(self, charts_dict: dict):
        """Across all symbols, plot one chart per (window_size, bar, next_period)."""
        self._plot_combinations(
            charts_dict,
            sheet_name="window_size_bar_period",
            dimensions=[
                ("window_size", self.window_size_list),
                ("bar", self.bar_list),
                ("next_period", self.next_period_list),
            ],
            name_template="window_size_{}_bar_{}_next_period_{}",
        )

    def plot_symbol_rise_fall(self, charts_dict: dict):
        """Plot one overall rise/fall chart per symbol."""
        for symbol in self.symbol_list:
            self._plot_combinations(
                charts_dict,
                sheet_name=f"{symbol}_总体",
                dimensions=[("symbol", [symbol])],
                name_template="{}",
            )

    def plot_symbol_bar_rise_fall(self, charts_dict: dict):
        """Per symbol, plot one chart per bar."""
        for symbol in self.symbol_list:
            self._plot_combinations(
                charts_dict,
                sheet_name=f"{symbol}_bar",
                dimensions=[("symbol", [symbol]), ("bar", self.bar_list)],
                name_template="{}_{}",
            )

    def plot_symbol_bar_window_size_rise_fall(self, charts_dict: dict):
        """Per symbol, plot one chart per (bar, window_size)."""
        for symbol in self.symbol_list:
            self._plot_combinations(
                charts_dict,
                sheet_name=f"{symbol}_bar_window_size",
                dimensions=[
                    ("symbol", [symbol]),
                    ("bar", self.bar_list),
                    ("window_size", self.window_size_list),
                ],
                name_template="{}_{}_ws_{}",
            )

    def plot_symbol_bar_window_size_next_period_rise_fall(self, charts_dict: dict):
        """Per symbol, plot one chart per (bar, window_size, next_period)."""
        for symbol in self.symbol_list:
            self._plot_combinations(
                charts_dict,
                sheet_name=f"{symbol}_bar_ws_period",
                dimensions=[
                    ("symbol", [symbol]),
                    ("bar", self.bar_list),
                    ("window_size", self.window_size_list),
                    ("next_period", self.next_period_list),
                ],
                name_template="{}_{}_ws_{}_next_period_{}",
            )

    def plot_symbol_bar_window_size_volume_ratio_percentile_10_mean_rise_fall(
        self, charts_dict: dict
    ):
        """Per symbol, plot one chart per (bar, window_size, volume_ratio_percentile_10)."""
        for symbol in self.symbol_list:
            self._plot_combinations(
                charts_dict,
                sheet_name=f"{symbol}_bar_window_size_vol_per",
                dimensions=[
                    ("symbol", [symbol]),
                    ("bar", self.bar_list),
                    ("window_size", self.window_size_list),
                    (
                        "volume_ratio_percentile_10",
                        self.volume_ratio_percentile_10_list,
                    ),
                ],
                name_template="{}_{}_ws_{}_vol_per_{}",
            )

    def plot_symbol_bar_window_size_volume_ratio_percentile_10_mean_next_period_rise_fall(
        self,
        charts_dict: dict,
    ):
        """Per symbol, plot one chart per (bar, window_size, volume_ratio_percentile_10, next_period)."""
        for symbol in self.symbol_list:
            self._plot_combinations(
                charts_dict,
                sheet_name=f"{symbol}_bar_ws_vol_period",
                dimensions=[
                    ("symbol", [symbol]),
                    ("bar", self.bar_list),
                    ("window_size", self.window_size_list),
                    (
                        "volume_ratio_percentile_10",
                        self.volume_ratio_percentile_10_list,
                    ),
                    ("next_period", self.next_period_list),
                ],
                name_template="{}_{}_ws_{}_vol_per_{}_next_period_{}",
            )

    def plot_pice_rise_fall(self, data: pd.DataFrame, prefix: str = ""):
        """Draw a 2x2 bar chart of the per-price_type averages in ``data``.

        The four panels show the mean ``rise_ratio``, ``fall_ratio``,
        ``draw_ratio`` and ``average_return`` for every entry of
        ``self.price_type_list``.  A price_type with no rows in ``data``
        yields NaN means.

        Args:
            data: Rows to average.
            prefix: Prepended to every panel title and used as the PNG file
                name (path-unsafe characters are replaced by underscores).

        Returns:
            Path of the PNG file written into ``self.temp_dir``.
        """
        logger.info(f"绘制价格上涨下跌图: {prefix}")

        price_types = list(self.price_type_list)
        subsets = [data[data["price_type"] == price_type] for price_type in price_types]

        # (column, bar colour, panel title, y-axis label) in panel order:
        # top-left, top-right, bottom-left, bottom-right.
        panels = (
            ("rise_ratio", "green", f"{prefix}平均上涨比例", "上涨比例"),
            ("fall_ratio", "red", f"{prefix}平均下跌比例", "下跌比例"),
            ("draw_ratio", "gray", f"{prefix}平均持平比例", "持平比例"),
            ("average_return", "blue", f"{prefix}平均回报", "平均回报"),
        )

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        try:
            for ax, (column, color, panel_title, y_label) in zip(axes.flat, panels):
                heights = [subset[column].mean() for subset in subsets]
                bars = ax.bar(price_types, heights, color=color, alpha=0.7)
                ax.set_title(panel_title)
                ax.set_ylabel(y_label)
                ax.tick_params(axis="x", rotation=0)
                ax.bar_label(bars, fmt="%.2f")

            # Extra bottom margin and vertical spacing keep the x-axis
            # labels fully visible.
            fig.tight_layout()
            fig.subplots_adjust(bottom=0.15, hspace=0.4)

            chart_path = os.path.join(
                self.temp_dir, f"{self._safe_file_stem(prefix)}.png"
            )
            fig.savefig(chart_path, bbox_inches="tight", dpi=100)
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)
        return chart_path

    def _write_chart_sheet(self, wb, sheet_name, chart_data_dict: dict):
        """Create sheet ``sheet_name`` in ``wb`` and stack its charts in column A."""
        ws = wb.create_sheet(title=sheet_name)
        pixels_per_row = self._POINTS_PER_ROW * self._PIXELS_PER_POINT
        row_offset = 1
        for chart_name, chart_path in chart_data_dict.items():
            # Only the pixel height is needed, to reserve enough rows.
            with PILImage.open(chart_path) as picture:
                _, height_px = picture.size
            chart_rows = max(self._MIN_CHART_ROWS, int(height_px / pixels_per_row))

            # Bold title above the image.
            title_cell = ws[f"A{row_offset}"]
            title_cell.value = chart_name
            title_cell.font = Font(bold=True, size=12)
            row_offset += self._TITLE_ROWS

            ws.add_image(Image(chart_path), f"A{row_offset}")
            row_offset += chart_rows + self._PADDING_ROWS

    def output_excel(self, chart_type: str, charts_dict: dict):
        """Write all charts into one Excel workbook and delete the PNG files.

        The workbook is saved in ``self.output_folder`` as
        ``huge_volume_{chart_type}_{timestamp}.xlsx``.  A sheet that fails to
        build is logged and skipped.  If no sheet could be created at all,
        nothing is saved.

        Args:
            chart_type: Label embedded in the file name and the log message.
            charts_dict: ``{sheet_name: {chart_name: chart_path}}``.
        """
        logger.info(f"输出Excel文件,包含所有{chart_type}图表")
        # Computed outside the f-string: reusing the same quote character
        # inside a replacement field is a syntax error before Python 3.12.
        timestamp = get_current_date_time(format="%Y%m%d%H%M%S")
        file_name = f"huge_volume_{chart_type}_{timestamp}.xlsx"
        file_path = os.path.join(self.output_folder, file_name)

        wb = Workbook()
        wb.remove(wb.active)  # drop the default sheet

        for sheet_name, chart_data_dict in charts_dict.items():
            try:
                self._write_chart_sheet(wb, sheet_name, chart_data_dict)
            except Exception as e:
                logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
                continue

        if wb.worksheets:
            wb.save(file_path)
            logger.info(f"Excel file saved as {file_path}")
        else:
            # Saving a workbook without any sheet raises, so skip it.
            logger.error(f"没有可写入的图表, 跳过Excel输出: {file_path}")

        # Remove the temporary PNG files only after the save attempt, since
        # the workbook refers to them by path until then.
        for chart_data_dict in charts_dict.values():
            for chart_path in chart_data_dict.values():
                try:
                    os.remove(chart_path)
                except OSError as e:
                    logger.error(f"删除临时文件失败: {e}")
|