diff --git a/core/statistics/ma_break_statistics.py b/core/statistics/ma_break_statistics.py new file mode 100644 index 0000000..d53dc58 --- /dev/null +++ b/core/statistics/ma_break_statistics.py @@ -0,0 +1,274 @@ +import core.logger as logging +import os +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from datetime import datetime +import re +from openpyxl import Workbook +from openpyxl.drawing.image import Image +import openpyxl +from openpyxl.styles import Font +from PIL import Image as PILImage +from config import MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE +from core.db.db_market_data import DBMarketData +from core.db.db_huge_volume_data import DBHugeVolumeData +from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp + +# seaborn支持中文 +plt.rcParams["font.family"] = ["SimHei"] + +logger = logging.logger + +class MaBreakStatistics: + """ + 统计MA突破之后的涨跌幅 + MA向上突破的点位周期K线:5 > 10 > 20 > 30 + 统计MA向上突破的点位周期K线,突破之后,到: + 下一个MA向下突破的点位周期K线:30 > 20 > 10 > 5 + 之间的涨跌幅 + """ + def __init__(self): + mysql_user = MYSQL_CONFIG.get("user", "xch") + mysql_password = MYSQL_CONFIG.get("password", "") + if not mysql_password: + raise ValueError("MySQL password is not set") + mysql_host = MYSQL_CONFIG.get("host", "localhost") + mysql_port = MYSQL_CONFIG.get("port", 3306) + mysql_database = MYSQL_CONFIG.get("database", "okx") + + self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" + self.db_market_data = DBMarketData(self.db_url) + self.db_huge_volume_data = DBHugeVolumeData(self.db_url) + self.symbols = MONITOR_CONFIG.get("volume_monitor", {}).get( + "symbols", ["XCH-USDT"] + ) + self.bars = MONITOR_CONFIG.get("volume_monitor", {}).get( + "bars", ["5m", "15m", "30m", "1H"] + ) + self.stats_output_dir = "./output/statistics/excel/" + os.makedirs(self.stats_output_dir, exist_ok=True) + self.stats_chart_dir = "./output/statistics/chart/" + os.makedirs(self.stats_chart_dir, exist_ok=True) + + def batch_statistics(self, all_change: bool = True): + ma_break_market_data_list = [] + for symbol in self.symbols: + for bar in self.bars: + logger.info(f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计") + ma_break_market_data = self.statistics(symbol, bar, all_change) + if ma_break_market_data is not None: + ma_break_market_data_list.append(ma_break_market_data) + if len(ma_break_market_data_list) > 0: + ma_break_market_data = pd.concat(ma_break_market_data_list) + # 依据symbol和bar分组,统计每个symbol和bar的pct_chg的max, min, mean, std, median, count + pct_chg_df = (ma_break_market_data + .groupby(['symbol', 'bar'])['pct_chg'] + .agg(pct_chg_max='max', + pct_chg_min='min', + pct_chg_mean='mean', + pct_chg_std='std', + pct_chg_median='median', + pct_chg_count='count') + .reset_index()) + # 依据symbol和bar分组,统计每个symbol和bar的interval_minutes的max, min, mean, std, median, count + interval_minutes_df = (ma_break_market_data + .groupby(['symbol', 'bar'])['interval_minutes'] + .agg(interval_minutes_max='max', + interval_minutes_min='min', + interval_minutes_mean='mean', + interval_minutes_std='std', + interval_minutes_median='median', + interval_minutes_count='count') + .reset_index()) + + earliest_market_date_time = ma_break_market_data["begin_date_time"].min() + earliest_market_date_time = re.sub(r"[\:\-\s]", "", str(earliest_market_date_time)) + latest_market_date_time = ma_break_market_data["end_date_time"].max() + if latest_market_date_time is None: + latest_market_date_time = datetime.now().strftime("%Y%m%d") + latest_market_date_time = re.sub(r"[\:\-\s]", "", str(latest_market_date_time)) + if all_change: + output_file_name = f"ma_break_stats_from_{earliest_market_date_time}_to_{latest_market_date_time}_完全转势.xlsx" + else: + output_file_name = f"ma_break_stats_from_{earliest_market_date_time}_to_{latest_market_date_time}_部分转势.xlsx" + output_file_path = os.path.join(self.stats_output_dir, output_file_name) + logger.info(f"导出{output_file_path}") + with pd.ExcelWriter(output_file_path) as writer: + ma_break_market_data.to_excel(writer, sheet_name="ma_break_market_data", index=False) + pct_chg_df.to_excel(writer, sheet_name="pct_chg_stats", index=False) + interval_minutes_df.to_excel(writer, sheet_name="interval_minutes_stats", index=False) + + chart_dict = self.draw_pct_chg_mean_chart(pct_chg_df, all_change) + self.output_chart_to_excel(output_file_path, chart_dict) + else: + return None + + def statistics(self, symbol: str, bar: str, all_change: bool = False): + market_data = self.db_market_data.query_market_data_by_symbol_bar(symbol, bar, start=None, end=None) + if market_data is None or len(market_data) == 0: + logger.warning(f"获取{symbol} {bar} 数据失败") + return + else: + market_data = pd.DataFrame(market_data) + market_data.sort_values(by="timestamp", ascending=True, inplace=True) + market_data.reset_index(drop=True, inplace=True) + logger.info(f"获取{symbol} {bar} 数据成功,数据条数: {len(market_data)}") + # 获得ma5, ma10, ma20, ma30不为空的行 + market_data = market_data[(market_data["ma5"].notna()) & + (market_data["ma10"].notna()) & + (market_data["ma20"].notna()) & + (market_data["ma30"].notna())] + logger.info(f"ma5, ma10, ma20, ma30不为空的行,数据条数: {len(market_data)}") + # 获得5上穿10且ma5 > ma10 > ma20 > ma30且close > ma20的行 + long_market_data = market_data[(market_data["ma_cross"] == "5上穿10") & (market_data["ma5"] > market_data["ma10"]) & + (market_data["ma10"] > market_data["ma20"]) & + (market_data["ma20"] > market_data["ma30"]) & + (market_data["close"] > market_data["ma20"])] + logger.info(f"5上穿10, 且ma5 > ma10 > ma20 > ma30,并且close > ma20的行,数据条数: {len(long_market_data)}") + if all_change: + # 获得ma5 < ma10 < ma20 < ma30的行 + short_market_data = market_data[(market_data["ma5"] < market_data["ma10"]) & + (market_data["ma10"] < market_data["ma20"]) & + (market_data["ma20"] < market_data["ma30"])] + logger.info(f"ma5 < ma10 < ma20 < ma30的行,数据条数: {len(short_market_data)}") + else: + # ma5 < ma10 and close < ma20 + short_market_data = market_data[(market_data["ma5"] < market_data["ma10"]) & + (market_data["close"] < market_data["ma20"])] + logger.info(f"ma5 < ma10 and close < ma20的行,数据条数: {len(short_market_data)}") + # concat long_market_data和short_market_data + ma_break_market_data = pd.concat([long_market_data, short_market_data]) + # 按照timestamp排序 + ma_break_market_data = ma_break_market_data.sort_values(by="timestamp", ascending=True) + # 获得ma_break_market_data的close列 + ma_break_market_data.reset_index(drop=True, inplace=True) + ma_break_market_data_pair_list = [] + ma_break_market_data_pair = {} + for index, row in ma_break_market_data.iterrows(): + ma_cross = row["ma_cross"] + timestamp = row["timestamp"] + close = row["close"] + ma5 = row["ma5"] + ma10 = row["ma10"] + ma20 = row["ma20"] + ma30 = row["ma30"] + if pd.notna(ma_cross) and ma_cross is not None: + ma_cross = str(ma_cross) + if ma_cross == "5上穿10" and (ma5 > ma10 and ma10 > ma20 and ma20 > ma30) and (close > ma20): + ma_break_market_data_pair = {} + ma_break_market_data_pair["symbol"] = symbol + ma_break_market_data_pair["bar"] = bar + ma_break_market_data_pair["begin_timestamp"] = timestamp + ma_break_market_data_pair["begin_date_time"] = timestamp_to_datetime(timestamp) + ma_break_market_data_pair["begin_close"] = close + ma_break_market_data_pair["begin_ma5"] = ma5 + ma_break_market_data_pair["begin_ma10"] = ma10 + ma_break_market_data_pair["begin_ma20"] = ma20 + ma_break_market_data_pair["begin_ma30"] = ma30 + + if all_change: + change_condition = (ma5 < ma10 and ma10 < ma20 and ma20 < ma30) + else: + # change_condition = (ma5 < ma10 or ma10 < ma20 or ma20 < ma30) + change_condition = (ma5 < ma10) and (close < ma20) + + if change_condition: + if ma_break_market_data_pair.get("begin_timestamp", None) is None: + continue + ma_break_market_data_pair["end_timestamp"] = timestamp + ma_break_market_data_pair["end_date_time"] = timestamp_to_datetime(timestamp) + ma_break_market_data_pair["end_close"] = close + ma_break_market_data_pair["end_ma5"] = ma5 + ma_break_market_data_pair["end_ma10"] = ma10 + ma_break_market_data_pair["end_ma20"] = ma20 + ma_break_market_data_pair["end_ma30"] = ma30 + ma_break_market_data_pair["pct_chg"] = (close - ma_break_market_data_pair["begin_close"]) / ma_break_market_data_pair["begin_close"] + ma_break_market_data_pair["pct_chg"] = round(ma_break_market_data_pair["pct_chg"] * 100, 4) + ma_break_market_data_pair["interval_seconds"] = (timestamp - ma_break_market_data_pair["begin_timestamp"]) / 1000 + # 将interval转换为分钟 + ma_break_market_data_pair["interval_minutes"] = ma_break_market_data_pair["interval_seconds"] / 60 + ma_break_market_data_pair["interval_hours"] = ma_break_market_data_pair["interval_seconds"] / 3600 + ma_break_market_data_pair["interval_days"] = ma_break_market_data_pair["interval_seconds"] / 86400 + ma_break_market_data_pair_list.append(ma_break_market_data_pair) + ma_break_market_data_pair = {} + if len(ma_break_market_data_pair_list) > 0: + ma_break_market_data = pd.DataFrame(ma_break_market_data_pair_list) + return ma_break_market_data + else: + return None + + def draw_pct_chg_mean_chart(self, data: pd.DataFrame, all_change: bool = True): + """ + 绘制pct_chg mean的柱状图表(美观,保存到self.stats_chart_dir) + :param data: 波段pct_chg_mean的数据 + :return: None + """ + if data is None or data.empty: + return None + # seaborn风格设置 + sns.set_theme(style="whitegrid") + plt.rcParams["font.sans-serif"] = ["SimHei"] # 也可直接用字体名 + plt.rcParams["font.size"] = 11 # 设置字体大小 + plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题 + chart_dict = {} + + for bar in data["bar"].unique(): + bar_data = data[data["bar"] == bar].copy() # 一次筛选即可 + if bar_data.empty: + continue + + bar_data.rename(columns={"pct_chg_mean": "涨跌幅均值"}, inplace=True) + # 可选:按均值排序 + bar_data.sort_values(by="涨跌幅均值", ascending=False, inplace=True) + bar_data.reset_index(drop=True, inplace=True) + + plt.figure(figsize=(10, 6)) + sns.barplot(x="symbol", y="涨跌幅均值", data=bar_data, palette="Blues_d") + plt.title(f"{bar}趋势涨跌幅均值分布") + plt.xlabel("symbol") + plt.ylabel("涨跌幅均值") + plt.xticks(rotation=45, ha="right") + plt.tight_layout() + + save_path = os.path.join(self.stats_chart_dir, f"{bar}_ma_break_pct_chg_mean.png") + plt.savefig(save_path, dpi=150) + plt.close() + + if all_change: + sheet_name = f"{bar}_趋势涨跌幅均值分布图表_完全转势" + else: + sheet_name = f"{bar}_趋势涨跌幅均值分布图表_部分转势" + chart_dict[sheet_name] = save_path + return chart_dict + + def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict): + """ + 输出Excel文件,包含所有图表 + charts_dict: 图表数据字典,格式为: + { + "sheet_name": { + "chart_name": "chart_path" + } + } + """ + logger.info(f"将图表输出到{excel_file_path}") + + # 打开已经存在的Excel文件 + wb = openpyxl.load_workbook(excel_file_path) + + for sheet_name, chart_path in charts_dict.items(): + try: + ws = wb.create_sheet(title=sheet_name) + row_offset = 1 + # Insert chart image + img = Image(chart_path) + ws.add_image(img, f"A{row_offset}") + + except Exception as e: + logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}") + continue + # Save Excel file + wb.save(excel_file_path) + print(f"Chart saved as {excel_file_path}") \ No newline at end of file diff --git a/market_data_main.py b/market_data_main.py index 2aafc4c..21820e3 100644 --- a/market_data_main.py +++ b/market_data_main.py @@ -3,6 +3,7 @@ from datetime import datetime from time import sleep import pandas as pd from core.biz.market_data import MarketData +from core.statistics.ma_break_statistics import MaBreakStatistics from core.db.db_market_data import DBMarketData from core.biz.metrics_calculation import MetricsCalculation from core.utils import ( @@ -52,6 +53,7 @@ class MarketDataMain: self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_market_data = DBMarketData(self.db_url) self.trade_data_main = TradeDataMain() + self.ma_break_statistics = MaBreakStatistics() def initial_data(self): """ @@ -391,10 +393,20 @@ class MarketDataMain: data = self.calculate_metrics(data) logger.info(f"开始保存技术指标数据: {symbol} {bar}") self.db_market_data.insert_data_to_mysql(data) + + def batch_ma_break_statistics(self): + """ + 批量计算MA突破统计 + """ + logger.info("开始批量计算MA突破统计") + self.ma_break_statistics.batch_statistics(all_change=False) + self.ma_break_statistics.batch_statistics(all_change=True) + if __name__ == "__main__": market_data_main = MarketDataMain() # market_data_main.batch_update_data() # market_data_main.initial_data() - market_data_main.batch_calculate_metrics() \ No newline at end of file + # market_data_main.batch_calculate_metrics() + market_data_main.batch_ma_break_statistics() \ No newline at end of file