# crypto_quant/core/biz/huge_volume_chart.py
#
# (459 lines, 20 KiB, Python)
#
# Note carried over from the repository viewer: this file contains Unicode
# characters that might be confused with other characters. If that is
# intentional, the warning can be safely ignored.


from core.db.db_huge_volume_data import DBHugeVolumeData
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
import matplotlib.pyplot as plt
import seaborn as sns
from openpyxl import Workbook
from openpyxl.drawing.image import Image
from PIL import Image as PILImage
import core.logger as logging
from datetime import datetime, timezone, timedelta
from core.utils import get_current_date_time
import pandas as pd
import os
import re
import openpyxl
from openpyxl.styles import Font
# Module-wide logger taken from the project's logging wrapper (core.logger).
logger = logging.logger

# Apply the seaborn theme first, then set the font-related rcParams in one go
# so that Chinese chart labels can be drawn with the SimHei font and the minus
# sign is not taken from the Unicode minus glyph.
sns.set_theme(style="whitegrid")
plt.rcParams.update(
    {
        "font.sans-serif": ["SimHei"],
        "axes.unicode_minus": False,
    }
)
class HugeVolumeChart:
    """Draw huge-volume statistics charts and bundle them into Excel workbooks.

    The input frame is sliced by symbol, bar, window size and next period. Each
    slice is rendered as a PNG inside ``<output_folder>/temp``, the PNGs are
    embedded into an ``.xlsx`` workbook (one worksheet per chart group), and
    the PNG files are deleted once the workbook has been written.
    """

    # Image height -> worksheet rows: 15 points per row * 1.333 pixels per
    # point, i.e. roughly 20 pixels per row.
    _PIXELS_PER_ROW = 15 * 1.333
    # Excel worksheet titles are limited to 31 characters.
    _MAX_SHEET_TITLE_LENGTH = 31
    # Characters that are not allowed inside a worksheet title.
    _INVALID_SHEET_TITLE_CHARS = r"[\\/*?:\[\]]"

    def __init__(
        self,
        data: pd.DataFrame,
        output_folder: str = "./output/huge_volume_statistics/",
    ):
        """Store the statistics frame and prepare the output folders.

        Args:
            data: Statistics frame. Expected columns:
                symbol: trading symbol
                bar: bar period
                window_size: window size
                huge_volume: whether the bar is a huge-volume bar
                volume_ratio_percentile_10: 10% percentile
                volume_ratio_percentile_10_mean: mean of the 10% percentile
                price_type: price type
                next_period: next period
                average_return: average return
                max_return: maximum return
                min_return: minimum return
                rise_count: number of rises
                rise_ratio: ratio of rises
                fall_count: number of falls
                fall_ratio: ratio of falls
                draw_count: number of flat outcomes
                draw_ratio: ratio of flat outcomes
                total_count: total number of samples
            output_folder: Folder that receives the Excel files; a ``temp``
                sub-folder is created inside it for the intermediate PNGs.
        """
        # Rows with bar == "1D" are excluded from every chart. The index is
        # rebuilt with a chained call instead of an in-place reset on the
        # filtered slice.
        self.data = data[data["bar"] != "1D"].reset_index(drop=True)
        self.output_folder = output_folder
        os.makedirs(self.output_folder, exist_ok=True)
        self.temp_dir = os.path.join(self.output_folder, "temp")
        os.makedirs(self.temp_dir, exist_ok=True)
        # Sorted distinct values of every dimension the charts are grouped by.
        self.symbol_list = self._sorted_unique("symbol")
        self.bar_list = self._sorted_unique("bar")
        self.window_size_list = self._sorted_unique("window_size")
        self.next_period_list = self._sorted_unique("next_period")
        self.volume_ratio_percentile_10_list = self._sorted_unique(
            "volume_ratio_percentile_10"
        )
        self.price_type_list = self._sorted_unique("price_type")

    def _sorted_unique(self, column: str) -> list:
        """Return the distinct values of ``column`` in ascending order."""
        return sorted(self.data[column].unique().tolist())

    def _select(self, **criteria) -> pd.DataFrame:
        """Return the rows of ``self.data`` whose columns equal the given values.

        Each keyword is a column name and its value is the value that column
        must equal; all conditions are combined with logical AND.
        """
        mask = None
        for column, value in criteria.items():
            condition = self.data[column] == value
            mask = condition if mask is None else mask & condition
        return self.data if mask is None else self.data[mask]

    @classmethod
    def _safe_sheet_title(cls, name: str) -> str:
        """Turn an arbitrary chart-group key into a valid worksheet title.

        Forbidden characters are replaced by ``_`` and the result is cut to
        the 31-character limit; an empty result falls back to ``"Sheet"``.
        """
        cleaned = re.sub(cls._INVALID_SHEET_TITLE_CHARS, "_", str(name))
        return cleaned[: cls._MAX_SHEET_TITLE_LENGTH] or "Sheet"

    def plot_entrance(self, include_heatmap: bool = True, include_line: bool = True):
        """Entry point: draw the requested chart families and merge the results.

        Heatmaps are generated before the rise/fall charts. In the returned
        dict the rise/fall entries come first, followed by the heatmap entries.

        Returns:
            dict mapping sheet name -> {chart name -> chart path}.
        """
        heatmap_plot_dict = self.plot_heatmap_entrance() if include_heatmap else {}
        line_plot_dict = self.plot_line_chart_entrance() if include_line else {}
        return {**line_plot_dict, **heatmap_plot_dict}

    def plot_line_chart_entrance(self):
        """Draw every rise/fall chart group and export them to one workbook.

        Despite the name, the charts are bar charts drawn by
        ``plot_pice_rise_fall``. ``output_excel`` deletes the PNG files after
        writing the workbook, so the returned paths no longer exist on disk.

        Returns:
            dict mapping sheet name -> {chart name -> chart path}.
        """
        charts_dict = {}
        # Overall averages per price type across the whole data set.
        total_chart_path = self.plot_pice_rise_fall(data=self.data, prefix="总体")
        charts_dict["总体"] = {"总体": total_chart_path}
        # The two volume_ratio_percentile_10 plotters exist on the class but
        # are intentionally not part of this run.
        group_plotters = (
            self.plot_window_size_rise_fall,
            self.plot_window_size_bar_rise_fall,
            self.plot_window_size_bar_next_period_rise_fall,
            self.plot_symbol_rise_fall,
            self.plot_symbol_bar_rise_fall,
            self.plot_symbol_bar_window_size_rise_fall,
            self.plot_symbol_bar_window_size_next_period_rise_fall,
        )
        for plot_group in group_plotters:
            plot_group(charts_dict=charts_dict)
        self.output_excel(chart_type="line_chart", charts_dict=charts_dict)
        return charts_dict

    def plot_heatmap_entrance(self):
        """Draw the rise-ratio, fall-ratio and average-return heatmaps and export them.

        ``output_excel`` deletes the PNG files after writing the workbook, so
        the returned paths no longer exist on disk.

        Returns:
            dict mapping sheet name -> {chart name -> chart path}.
        """
        charts_dict = {}
        heatmap_specs = (
            ("rise_ratio", "Rise Ratio Heatmap by Window Size and Bar"),
            ("fall_ratio", "Fall Ratio Heatmap by Window Size and Bar"),
            ("average_return", "Average Return Heatmap by Window Size and Bar"),
        )
        for ratio_column, title in heatmap_specs:
            self.plot_symbol_heatmap(
                charts_dict=charts_dict, ratio_column=ratio_column, title=title
            )
        self.output_excel(chart_type="heatmap_chart", charts_dict=charts_dict)
        return charts_dict

    def plot_symbol_heatmap(
        self,
        charts_dict: dict,
        ratio_column: str = "rise_ratio",
        title: str = "Rise Ratio Heatmap by Window Size and Bar",
    ):
        """Draw one window_size x bar heatmap per symbol and price type.

        Args:
            charts_dict: Result dict, filled in place under the key
                ``"{symbol}_{ratio_column}_heatmap"``.
            ratio_column: Column whose mean is shown in each cell.
            title: Chart title; the price type is appended to it.
        """
        for symbol in self.symbol_list:
            sheet_key = f"{symbol}_{ratio_column}_heatmap"
            charts_dict[sheet_key] = {}
            for price_type in self.price_type_list:
                logger.info(f"绘制{symbol} {price_type} {ratio_column}热力图")
                subset = self._select(symbol=symbol, price_type=price_type)
                if subset.empty:
                    # Not every symbol has rows for every price type; an empty
                    # pivot cannot be drawn as a heatmap.
                    logger.info(f"{symbol} {price_type} {ratio_column}无数据, 跳过热力图")
                    continue
                pivot_table = subset.pivot_table(
                    values=ratio_column,
                    index="window_size",
                    columns="bar",
                    aggfunc="mean",
                )
                if pivot_table.empty:
                    logger.info(f"{symbol} {price_type} {ratio_column}无数据, 跳过热力图")
                    continue
                chart_name = f"{symbol}_{price_type}_{ratio_column}_heatmap"
                chart_path = os.path.join(self.temp_dir, f"{chart_name}.png")
                fig = plt.figure(figsize=(10, 6))
                try:
                    # Reversed red-yellow-green map: red marks high values,
                    # green marks low values.
                    sns.heatmap(pivot_table, annot=True, cmap="RdYlGn_r", fmt=".3f")
                    plt.xlabel("Period")
                    plt.ylabel("Window Size")
                    plt.title(f"{title} {price_type}")
                    plt.savefig(chart_path, bbox_inches="tight", dpi=100)
                finally:
                    # Release the figure even when drawing or saving fails.
                    plt.close(fig)
                charts_dict[sheet_key][chart_name] = chart_path

    def plot_window_size_rise_fall(self, charts_dict: dict):
        """Across all symbols, draw one rise/fall chart per window size."""
        group = charts_dict["window_size"] = {}
        for window_size in self.window_size_list:
            name = f"window_size_{window_size}"
            group[name] = self.plot_pice_rise_fall(
                self._select(window_size=window_size), prefix=name
            )

    def plot_window_size_bar_rise_fall(self, charts_dict: dict):
        """Across all symbols, draw one rise/fall chart per (window size, bar)."""
        group = charts_dict["window_size_bar"] = {}
        for window_size in self.window_size_list:
            for bar in self.bar_list:
                name = f"window_size_{window_size}_bar_{bar}"
                group[name] = self.plot_pice_rise_fall(
                    self._select(window_size=window_size, bar=bar), prefix=name
                )

    def plot_window_size_bar_next_period_rise_fall(self, charts_dict: dict):
        """Across all symbols, draw one chart per (window size, bar, next period)."""
        group = charts_dict["window_size_bar_period"] = {}
        for window_size in self.window_size_list:
            for bar in self.bar_list:
                for next_period in self.next_period_list:
                    name = (
                        f"window_size_{window_size}_bar_{bar}"
                        f"_next_period_{next_period}"
                    )
                    subset = self._select(
                        window_size=window_size, bar=bar, next_period=next_period
                    )
                    group[name] = self.plot_pice_rise_fall(subset, prefix=name)

    def plot_symbol_rise_fall(self, charts_dict: dict):
        """Draw one rise/fall chart per symbol."""
        for symbol in self.symbol_list:
            chart_path = self.plot_pice_rise_fall(
                self._select(symbol=symbol), prefix=symbol
            )
            charts_dict[f"{symbol}_总体"] = {symbol: chart_path}

    def plot_symbol_bar_rise_fall(self, charts_dict: dict):
        """Draw one rise/fall chart per (symbol, bar)."""
        for symbol in self.symbol_list:
            group = charts_dict[f"{symbol}_bar"] = {}
            for bar in self.bar_list:
                name = f"{symbol}_{bar}"
                group[name] = self.plot_pice_rise_fall(
                    self._select(symbol=symbol, bar=bar), prefix=name
                )

    def plot_symbol_bar_window_size_rise_fall(self, charts_dict: dict):
        """Draw one rise/fall chart per (symbol, bar, window size)."""
        for symbol in self.symbol_list:
            group = charts_dict[f"{symbol}_bar_window_size"] = {}
            for bar in self.bar_list:
                for window_size in self.window_size_list:
                    name = f"{symbol}_{bar}_ws_{window_size}"
                    subset = self._select(
                        symbol=symbol, bar=bar, window_size=window_size
                    )
                    group[name] = self.plot_pice_rise_fall(subset, prefix=name)

    def plot_symbol_bar_window_size_next_period_rise_fall(self, charts_dict: dict):
        """Draw one rise/fall chart per (symbol, bar, window size, next period)."""
        for symbol in self.symbol_list:
            group = charts_dict[f"{symbol}_bar_ws_period"] = {}
            for bar in self.bar_list:
                for window_size in self.window_size_list:
                    for next_period in self.next_period_list:
                        name = (
                            f"{symbol}_{bar}_ws_{window_size}"
                            f"_next_period_{next_period}"
                        )
                        subset = self._select(
                            symbol=symbol,
                            bar=bar,
                            window_size=window_size,
                            next_period=next_period,
                        )
                        group[name] = self.plot_pice_rise_fall(subset, prefix=name)

    def plot_symbol_bar_window_size_volume_ratio_percentile_10_mean_rise_fall(
        self, charts_dict: dict
    ):
        """Draw one chart per (symbol, bar, window size, volume_ratio_percentile_10)."""
        for symbol in self.symbol_list:
            group = charts_dict[f"{symbol}_bar_window_size_vol_per"] = {}
            for bar in self.bar_list:
                for window_size in self.window_size_list:
                    for volume_ratio_percentile_10 in self.volume_ratio_percentile_10_list:
                        name = (
                            f"{symbol}_{bar}_ws_{window_size}"
                            f"_vol_per_{volume_ratio_percentile_10}"
                        )
                        subset = self._select(
                            symbol=symbol,
                            bar=bar,
                            window_size=window_size,
                            volume_ratio_percentile_10=volume_ratio_percentile_10,
                        )
                        group[name] = self.plot_pice_rise_fall(subset, prefix=name)

    def plot_symbol_bar_window_size_volume_ratio_percentile_10_mean_next_period_rise_fall(
        self,
        charts_dict: dict,
    ):
        """Draw one chart per (symbol, bar, window size, percentile, next period)."""
        for symbol in self.symbol_list:
            group = charts_dict[f"{symbol}_bar_ws_vol_period"] = {}
            for bar in self.bar_list:
                for window_size in self.window_size_list:
                    for volume_ratio_percentile_10 in self.volume_ratio_percentile_10_list:
                        for next_period in self.next_period_list:
                            name = (
                                f"{symbol}_{bar}_ws_{window_size}"
                                f"_vol_per_{volume_ratio_percentile_10}"
                                f"_next_period_{next_period}"
                            )
                            subset = self._select(
                                symbol=symbol,
                                bar=bar,
                                window_size=window_size,
                                volume_ratio_percentile_10=volume_ratio_percentile_10,
                                next_period=next_period,
                            )
                            group[name] = self.plot_pice_rise_fall(
                                subset, prefix=name
                            )

    def plot_pice_rise_fall(self, data: pd.DataFrame, prefix: str = ""):
        """Draw a 2x2 bar chart of the per-price-type averages of ``data``.

        The four panels show the mean of ``rise_ratio``, ``fall_ratio``,
        ``draw_ratio`` and ``average_return`` for every entry of
        ``self.price_type_list``.

        Args:
            data: Rows to average; must contain ``price_type`` and the four
                metric columns.
            prefix: Prepended to every panel title and used as the PNG file
                name (``<temp_dir>/<prefix>.png``).

        Returns:
            Path of the saved PNG file.
        """
        logger.info(f"绘制价格上涨下跌图: {prefix}")
        # (metric column, bar colour, panel title, y-axis label) per panel,
        # in row-major order of the 2x2 grid.
        panels = (
            ("rise_ratio", "green", f"{prefix}平均上涨比例", "上涨比例"),
            ("fall_ratio", "red", f"{prefix}平均下跌比例", "下跌比例"),
            ("draw_ratio", "gray", f"{prefix}平均持平比例", "持平比例"),
            ("average_return", "blue", f"{prefix}平均回报", "平均回报"),
        )
        price_types = list(self.price_type_list)
        # Mean of every metric for every price type, in price_types order.
        averages = {column: [] for column, _, _, _ in panels}
        for price_type in price_types:
            rows = data[data["price_type"] == price_type]
            for column in averages:
                averages[column].append(rows[column].mean())

        chart_path = os.path.join(self.temp_dir, f"{prefix}.png")
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        try:
            for ax, (column, color, panel_title, y_label) in zip(axes.flat, panels):
                bars = ax.bar(price_types, averages[column], color=color, alpha=0.7)
                ax.set_title(panel_title)
                ax.set_ylabel(y_label)
                ax.tick_params(axis="x", rotation=0)
                ax.bar_label(bars, fmt="%.2f")
            # Leave extra room at the bottom and between the rows so the
            # x-axis labels are fully visible.
            fig.tight_layout()
            fig.subplots_adjust(bottom=0.15, hspace=0.4)
            fig.savefig(chart_path, bbox_inches="tight", dpi=100)
        finally:
            # Release the figure even when drawing or saving fails.
            plt.close(fig)
        return chart_path

    def output_excel(self, chart_type: str, charts_dict: dict):
        """Write all charts into one Excel workbook and delete the PNG files.

        Args:
            chart_type: Used in the log message and in the file name
                ``huge_volume_<chart_type>_<current date time>.xlsx``.
            charts_dict: Chart paths in the form
                ``{"sheet_name": {"chart_name": "chart_path"}}``. One
                worksheet is created per sheet name; the charts are stacked
                vertically, each preceded by its name in bold.

        A worksheet that fails is logged and skipped. The workbook is only
        saved when it contains at least one worksheet. The PNG files listed in
        ``charts_dict`` are removed afterwards in every case.
        """
        logger.info(f"输出Excel文件包含所有{chart_type}图表")
        file_name = f"huge_volume_{chart_type}_{get_current_date_time()}.xlsx"
        file_path = os.path.join(self.output_folder, file_name)
        wb = Workbook()
        wb.remove(wb.active)  # Remove the default sheet
        try:
            for sheet_name, chart_data_dict in charts_dict.items():
                try:
                    # Raw keys can exceed the title length limit or contain
                    # forbidden characters, so they are normalised first.
                    ws = wb.create_sheet(title=self._safe_sheet_title(sheet_name))
                    self._fill_chart_sheet(ws, chart_data_dict)
                except Exception as e:
                    # Best effort: one broken sheet must not lose the others.
                    logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
            if wb.sheetnames:
                wb.save(file_path)
                logger.info(f"Excel file saved as {file_path}")
            else:
                # The default sheet was removed above, so there is nothing
                # to save when no chart sheet could be created.
                logger.error(f"没有可输出的图表, 跳过保存Excel文件: {file_path}")
        finally:
            self._remove_chart_files(charts_dict)

    def _fill_chart_sheet(self, ws, chart_data_dict: dict):
        """Stack the given charts vertically on worksheet ``ws``."""
        row_offset = 1
        for chart_name, chart_path in chart_data_dict.items():
            # Only the image height is needed, to reserve enough rows.
            with PILImage.open(chart_path) as pil_image:
                _, height_px = pil_image.size
            # At least 10 rows are reserved, even for small charts.
            chart_rows = max(10, int(height_px / self._PIXELS_PER_ROW))
            title_cell = ws[f"A{row_offset}"]
            title_cell.value = chart_name
            title_cell.font = Font(bold=True, size=12)
            row_offset += 2  # title row plus one spacer row
            ws.add_image(Image(chart_path), f"A{row_offset}")
            row_offset += chart_rows + 5  # image rows plus padding

    def _remove_chart_files(self, charts_dict: dict):
        """Delete every PNG referenced by ``charts_dict``; failures are logged."""
        for chart_data_dict in charts_dict.values():
            for chart_path in chart_data_dict.values():
                try:
                    os.remove(chart_path)
                except OSError as e:
                    logger.error(f"删除临时文件失败: {e}")