604 lines
28 KiB
Python
604 lines
28 KiB
Python
import core.logger as logging
|
||
import os
|
||
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
from datetime import datetime
|
||
import re
|
||
import json
|
||
from openpyxl import Workbook
|
||
from openpyxl.drawing.image import Image
|
||
import openpyxl
|
||
from openpyxl.styles import Font
|
||
from PIL import Image as PILImage
|
||
from config import MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE
|
||
from core.db.db_market_data import DBMarketData
|
||
from core.db.db_huge_volume_data import DBHugeVolumeData
|
||
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
|
||
|
||
# Use the SimHei font so matplotlib/seaborn charts can render Chinese text
|
||
plt.rcParams["font.family"] = ["SimHei"]
|
||
|
||
logger = logging.logger
|
||
|
||
|
||
class MaBreakStatistics:
|
||
"""
|
||
统计MA突破之后的涨跌幅
|
||
MA向上突破的点位周期K线:5 > 10 > 20 > 30
|
||
统计MA向上突破的点位周期K线,突破之后,到:
|
||
下一个MA向下突破的点位周期K线:30 > 20 > 10 > 5
|
||
之间的涨跌幅
|
||
"""
|
||
|
||
def __init__(self):
    """Initialise database handles, monitored markets, output folders and strategies.

    Connection settings are read from ``MYSQL_CONFIG`` (falling back to defaults
    for user, host, port and database), the monitored symbols/bars from
    ``MONITOR_CONFIG["volume_monitor"]``, and the strategy definitions from
    ``get_trade_strategy_config()``.

    Side effects:
        Creates the default Excel and chart output directories if missing.

    Raises:
        ValueError: If ``MYSQL_CONFIG`` has no (or an empty) password.
    """
    # Imported locally so the method does not rely on a new module-level import.
    from urllib.parse import quote

    mysql_user = MYSQL_CONFIG.get("user", "xch")
    mysql_password = MYSQL_CONFIG.get("password", "")
    if not mysql_password:
        raise ValueError("MySQL password is not set")
    mysql_host = MYSQL_CONFIG.get("host", "localhost")
    mysql_port = MYSQL_CONFIG.get("port", 3306)
    mysql_database = MYSQL_CONFIG.get("database", "okx")

    # Percent-encode the credentials: characters such as "@", "/", ":" or "#"
    # inside a raw password would otherwise be parsed as URL delimiters.
    # NOTE(review): assumes the configured password is stored un-encoded.
    encoded_user = quote(str(mysql_user), safe="")
    encoded_password = quote(str(mysql_password), safe="")
    self.db_url = (
        f"mysql+pymysql://{encoded_user}:{encoded_password}"
        f"@{mysql_host}:{mysql_port}/{mysql_database}"
    )
    self.db_market_data = DBMarketData(self.db_url)
    self.db_huge_volume_data = DBHugeVolumeData(self.db_url)

    volume_monitor = MONITOR_CONFIG.get("volume_monitor", {})
    self.symbols = volume_monitor.get("symbols", ["XCH-USDT"])
    self.bars = volume_monitor.get("bars", ["5m", "15m", "30m", "1H"])

    self.stats_output_dir = "./output/trade_sandbox/ma_strategy/excel/"
    os.makedirs(self.stats_output_dir, exist_ok=True)
    self.stats_chart_dir = "./output/trade_sandbox/ma_strategy/chart/"
    os.makedirs(self.stats_chart_dir, exist_ok=True)

    self.trade_strategy_config = self.get_trade_strategy_config()
    # Fall back to an empty mapping so later ``.get()`` / membership checks do
    # not fail with AttributeError when the config lacks this section.
    self.main_strategy = self.trade_strategy_config.get("均线系统策略") or {}
|
||
|
||
def get_trade_strategy_config(self):
    """Load and return the parsed contents of ``./json/trade_strategy.json``.

    The path is relative to the current working directory and the file is read
    as UTF-8 JSON.
    """
    config_path = "./json/trade_strategy.json"
    with open(config_path, mode="r", encoding="utf-8") as config_file:
        return json.load(config_file)
|
||
|
||
def batch_statistics(self, strategy_name: str = "全均线策略"):
    """Simulate a strategy on every configured symbol/bar and export the statistics.

    For each symbol/bar pair ``trade_simulate`` is run; the resulting round
    trips are concatenated, aggregated per (symbol, bar) and written to an Excel
    workbook together with the strategy description and bar charts.

    Args:
        strategy_name: Key of the strategy in ``self.main_strategy``. Unknown
            names (or ``None``) fall back to ``"全均线策略"``.

    Returns:
        The per-(symbol, bar) statistics DataFrame, or ``None`` when no
        simulation produced any completed trade.

    Side effects:
        Re-points ``self.stats_output_dir`` / ``self.stats_chart_dir`` to
        strategy-specific sub-folders (created if missing) and writes the
        workbook and chart images there.
    """
    # Resolve the effective strategy *before* deriving any paths from it, so the
    # output folders are named after the strategy that is actually simulated.
    known_strategies = self.main_strategy or {}
    if strategy_name is None or strategy_name not in known_strategies:
        strategy_name = "全均线策略"

    self.stats_output_dir = (
        f"./output/trade_sandbox/ma_strategy/excel/{strategy_name}/"
    )
    os.makedirs(self.stats_output_dir, exist_ok=True)
    self.stats_chart_dir = (
        f"./output/trade_sandbox/ma_strategy/chart/{strategy_name}/"
    )
    os.makedirs(self.stats_chart_dir, exist_ok=True)

    trade_frames = []
    market_pct_chg_records = []
    for symbol in self.symbols:
        for bar in self.bars:
            logger.info(
                f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计, 策略: {strategy_name}"
            )
            trades, market_data_pct_chg = self.trade_simulate(
                symbol, bar, strategy_name
            )
            if trades is None or len(trades) == 0 or market_data_pct_chg is None:
                continue
            trade_frames.append(trades)
            logger.info(
                f"{symbol} {bar} 的市场价格变化, {market_data_pct_chg.get('pct_chg', 0)}%"
            )
            market_pct_chg_records.append(market_data_pct_chg)

    if not trade_frames:
        return None

    group_keys = ["symbol", "bar"]
    # ignore_index gives a unique index, which keeps the column assignment
    # below unambiguous (each per-pair frame starts again at 0).
    ma_break_market_data = pd.concat(trade_frames, ignore_index=True)
    market_data_pct_chg_df = pd.DataFrame(market_pct_chg_records)

    # Per (symbol, bar) distribution of the individual trade returns.
    pct_chg_df = (
        ma_break_market_data.groupby(group_keys)["pct_chg"]
        .agg(
            pct_chg_sum="sum",
            pct_chg_max="max",
            pct_chg_min="min",
            pct_chg_mean="mean",
            pct_chg_std="std",
            pct_chg_median="median",
            pct_chg_count="count",
        )
        .reset_index()
    )
    pct_chg_df["strategy_name"] = strategy_name

    # Growth factor of each trade (pct_chg is expressed in percent); this column
    # is also exported as part of the trade detail sheet.
    ma_break_market_data["pct_chg_total"] = (
        ma_break_market_data["pct_chg"] / 100 + 1
    )
    # Compounded return per (symbol, bar): product of the factors, back to %.
    compounded = (
        ma_break_market_data.groupby(group_keys)["pct_chg_total"]
        .prod()
        .sub(1)
        .mul(100)
        .reset_index()
    )
    pct_chg_df = pct_chg_df.merge(compounded, on=group_keys, how="left")

    # Buy-and-hold change of the market over the same data, first record wins.
    market_lookup = (
        market_data_pct_chg_df[group_keys + ["pct_chg"]]
        .drop_duplicates(subset=group_keys)
        .rename(columns={"pct_chg": "market_pct_chg"})
    )
    pct_chg_df = pct_chg_df.merge(market_lookup, on=group_keys, how="left")
    pct_chg_df["pct_chg_total"] = pct_chg_df["pct_chg_total"].astype(float)
    pct_chg_df["market_pct_chg"] = pct_chg_df["market_pct_chg"].astype(float)

    pct_chg_df = pct_chg_df[
        [
            "strategy_name",
            "symbol",
            "bar",
            "market_pct_chg",
            "pct_chg_total",
            "pct_chg_sum",
            "pct_chg_max",
            "pct_chg_min",
            "pct_chg_mean",
            "pct_chg_std",
            "pct_chg_median",
            "pct_chg_count",
        ]
    ]

    # Per (symbol, bar) distribution of the holding time in minutes.
    interval_minutes_df = (
        ma_break_market_data.groupby(group_keys)["interval_minutes"]
        .agg(
            interval_minutes_max="max",
            interval_minutes_min="min",
            interval_minutes_mean="mean",
            interval_minutes_std="std",
            interval_minutes_median="median",
            interval_minutes_count="count",
        )
        .reset_index()
    )

    # The covered time range becomes part of the file name (separators removed).
    earliest_market_date_time = ma_break_market_data["begin_date_time"].min()
    earliest_market_date_time = re.sub(
        r"[\:\-\s]", "", str(earliest_market_date_time)
    )
    latest_market_date_time = ma_break_market_data["end_date_time"].max()
    # pandas reports a missing maximum as NaN/NaT rather than None.
    if pd.isna(latest_market_date_time):
        latest_market_date_time = datetime.now().strftime("%Y%m%d")
    latest_market_date_time = re.sub(
        r"[\:\-\s]", "", str(latest_market_date_time)
    )
    output_file_name = f"ma_break_stats_from_{earliest_market_date_time}_to_{latest_market_date_time}_{strategy_name}.xlsx"
    output_file_path = os.path.join(self.stats_output_dir, output_file_name)
    logger.info(f"导出{output_file_path}")

    strategy_info_df = self.get_strategy_info(strategy_name)
    with pd.ExcelWriter(output_file_path) as writer:
        # get_strategy_info returns None for an unknown strategy; skip the sheet
        # instead of failing after all statistics have been computed.
        if strategy_info_df is not None:
            strategy_info_df.to_excel(writer, sheet_name="策略信息", index=False)
        ma_break_market_data.to_excel(
            writer, sheet_name="买卖记录明细", index=False
        )
        pct_chg_df.to_excel(writer, sheet_name="买卖涨跌幅统计", index=False)
        interval_minutes_df.to_excel(
            writer, sheet_name="买卖时间间隔统计", index=False
        )

    chart_dict = self.draw_pct_chg_mean_chart(pct_chg_df, strategy_name)
    if chart_dict:
        self.output_chart_to_excel(output_file_path, chart_dict)
    return pct_chg_df
|
||
|
||
def get_strategy_info(self, strategy_name: str = "全均线策略"):
    """Describe a strategy's buy and sell rules as a one-row DataFrame.

    Each rule is rendered as ``"<rule>, \\n"``; when a side has "or" rules they
    are appended to the "and" rules after a `` 或者 \\n`` separator.

    Args:
        strategy_name: Key of the strategy in ``self.main_strategy``.

    Returns:
        A DataFrame with the columns ``策略名称``, ``买入策略`` and ``卖出策略``,
        or ``None`` (after logging an error) when the strategy is unknown.
    """
    strategy_config = self.main_strategy.get(strategy_name)
    if strategy_config is None:
        logger.error(f"策略{strategy_name}不存在")
        return None

    def describe(side):
        # Render the "and" rules, optionally followed by the "or" rules.
        rules = strategy_config.get(side, {})
        all_of = "".join(f"{rule}, \n" for rule in rules.get("and", []))
        any_of = "".join(f"{rule}, \n" for rule in rules.get("or", []))
        if any_of:
            return all_of + " 或者 \n" + any_of
        return all_of

    summary = {
        "策略名称": strategy_name,
        "买入策略": describe("buy"),
        "卖出策略": describe("sell"),
    }
    return pd.DataFrame([summary])
|
||
|
||
def trade_simulate(self, symbol: str, bar: str, strategy_name: str = "全均线策略"):
    """Replay stored candles for one symbol/bar and record every buy -> sell round trip.

    Candles are loaded through ``self.db_market_data``, ordered by timestamp and
    restricted to rows where ma5/ma10/ma20/ma30 are all present. While no
    position is open each row is tested with ``fit_strategy(..., "buy")``; once
    a position is open each following row is tested with
    ``fit_strategy(..., "sell")``. A position still open at the end of the data
    is not reported.

    Args:
        symbol: Instrument identifier passed to the market-data query.
        bar: Candle period passed to the market-data query.
        strategy_name: Strategy key forwarded to ``fit_strategy``.

    Returns:
        ``(trades, market_change)`` where ``trades`` is a DataFrame with one row
        per completed round trip (entry/exit snapshot, ``pct_chg`` in percent
        and the holding time in seconds/minutes/hours/days) and
        ``market_change`` is ``{"symbol", "bar", "pct_chg"}`` holding the change
        from the first open to the last close of the filtered candles.
        ``(None, None)`` when no candles were found or no round trip completed.
    """
    raw_rows = self.db_market_data.query_market_data_by_symbol_bar(
        symbol, bar, start=None, end=None
    )
    if raw_rows is None or len(raw_rows) == 0:
        logger.warning(f"获取{symbol} {bar} 数据失败")
        return None, None

    market_data = pd.DataFrame(raw_rows)
    market_data.sort_values(by="timestamp", ascending=True, inplace=True)
    market_data.reset_index(drop=True, inplace=True)
    logger.info(f"获取{symbol} {bar} 数据成功,数据条数: {len(market_data)}")

    # Keep only candles where every moving average is available. The explicit
    # copy makes this an independent frame, so the column assignments below do
    # not write into a slice of the original (chained-assignment warning).
    has_all_ma = market_data[["ma5", "ma10", "ma20", "ma30"]].notna().all(axis=1)
    market_data = market_data[has_all_ma].copy()
    logger.info(
        f"ma5, ma10, ma20, ma30不为空的行,数据条数: {len(market_data)}"
    )

    # Relative deviation of each candle's volume from its 5-candle mean; the
    # first rows (no full window yet) are reported as 0.
    market_data["volume_ma5"] = market_data["volume"].rolling(window=5).mean()
    market_data["volume_pct_chg"] = (
        (market_data["volume"] - market_data["volume_ma5"])
        / market_data["volume_ma5"]
    ).fillna(0)
    # Filtering preserves the timestamp order established above, so only the
    # index needs to be rebuilt.
    market_data.reset_index(drop=True, inplace=True)

    def as_float(value):
        # A None indicator value would make float() raise; record it as NaN.
        return float("nan") if value is None else float(value)

    def snapshot(prefix, row):
        # Values captured at entry ("begin") and exit ("end") of a round trip.
        return {
            f"{prefix}_timestamp": row["timestamp"],
            f"{prefix}_date_time": timestamp_to_datetime(row["timestamp"]),
            f"{prefix}_close": row["close"],
            f"{prefix}_ma5": row["ma5"],
            f"{prefix}_ma10": row["ma10"],
            f"{prefix}_ma20": row["ma20"],
            f"{prefix}_ma30": row["ma30"],
            f"{prefix}_macd_diff": as_float(row["dif"]),
            f"{prefix}_macd_dea": as_float(row["dea"]),
            f"{prefix}_macd": as_float(row["macd"]),
        }

    completed_trades = []
    open_trade = None
    for _, row in market_data.iterrows():
        if open_trade is None:
            # Flat: look for an entry signal on this candle.
            if self.fit_strategy(
                strategy_name=strategy_name,
                market_data=market_data,
                row=row,
                behavior="buy",
            ):
                open_trade = {"symbol": symbol, "bar": bar}
                open_trade.update(snapshot("begin", row))
            continue

        # Holding: look for an exit signal on this candle.
        if not self.fit_strategy(
            strategy_name=strategy_name,
            market_data=market_data,
            row=row,
            behavior="sell",
        ):
            continue

        open_trade.update(snapshot("end", row))
        relative_change = (
            row["close"] - open_trade["begin_close"]
        ) / open_trade["begin_close"]
        open_trade["pct_chg"] = round(relative_change * 100, 4)
        # The division by 1000 treats the timestamp difference as milliseconds.
        open_trade["interval_seconds"] = (
            row["timestamp"] - open_trade["begin_timestamp"]
        ) / 1000
        open_trade["interval_minutes"] = open_trade["interval_seconds"] / 60
        open_trade["interval_hours"] = open_trade["interval_seconds"] / 3600
        open_trade["interval_days"] = open_trade["interval_seconds"] / 86400
        completed_trades.append(open_trade)
        open_trade = None

    if not completed_trades:
        return None, None

    ma_break_market_data = pd.DataFrame(completed_trades)
    logger.info(
        f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}"
    )
    # Buy-and-hold reference: first open -> last close of the filtered candles.
    first_open = market_data["open"].iloc[0]
    pct_chg = (market_data["close"].iloc[-1] - first_open) / first_open * 100
    pct_chg = round(pct_chg, 4)
    market_data_pct_chg = {"symbol": symbol, "bar": bar, "pct_chg": pct_chg}
    return ma_break_market_data, market_data_pct_chg
|
||
|
||
def fit_strategy(
    self,
    strategy_name: str = "全均线策略",
    market_data: pd.DataFrame = None,
    row: pd.Series = None,
    behavior: str = "buy",
):
    """Decide whether one candle satisfies a strategy's buy or sell rules.

    The rules for ``behavior`` consist of an "and" list and an optional "or"
    list of condition names. The candle fits when every recognised "and"
    condition holds; otherwise it still fits when at least one recognised "or"
    condition holds. Condition names that are not recognised are ignored.

    NOTE(review): with an empty or missing "and" list the result is always
    ``True`` and the "or" list is never consulted. This mirrors the previous
    behaviour; confirm it is intended for strategies defined only by "or" rules.

    Args:
        strategy_name: Key of the strategy in ``self.main_strategy``.
        market_data: Full candle frame; accepted for interface compatibility
            and not used by the evaluation.
        row: Mapping-like candle providing ``ma_cross``, ``ma5``, ``ma10``,
            ``ma20``, ``ma30``, ``close``, ``volume_pct_chg``, ``dif``, ``dea``
            and ``macd``.
        behavior: Rule set to evaluate, e.g. ``"buy"`` or ``"sell"``.

    Returns:
        ``True`` when the candle fits the rules, ``False`` otherwise, including
        when the strategy or the requested rule set does not exist (an error is
        logged in that case).
    """
    strategy_config = self.main_strategy.get(strategy_name, None)
    if strategy_config is None:
        logger.error(f"策略{strategy_name}不存在")
        return False
    condition_dict = strategy_config.get(behavior, None)
    if condition_dict is None:
        logger.error(f"策略{strategy_name}的{behavior}条件不存在")
        return False

    def number(column):
        # A None value would make float() raise; NaN instead makes every
        # comparison on that value evaluate to False.
        value = row[column]
        return float("nan") if value is None else float(value)

    ma_cross = row["ma_cross"]
    ma_cross = "" if ma_cross is None or pd.isna(ma_cross) else str(ma_cross)
    ma5 = number("ma5")
    ma10 = number("ma10")
    ma20 = number("ma20")
    ma30 = number("ma30")
    close = number("close")
    volume_pct_chg = number("volume_pct_chg")
    macd_diff = number("dif")
    macd_dea = number("dea")
    macd = number("macd")

    # Truth value of every supported condition for this candle. Crossing
    # conditions are substring checks against the ma_cross text.
    crossings = ("5上穿10", "10上穿20", "20上穿30", "10下穿5", "20下穿10", "30下穿20")
    checks = {token: token in ma_cross for token in crossings}
    checks.update(
        {
            "ma5>ma10": ma5 > ma10,
            "ma10>ma20": ma10 > ma20,
            "ma20>ma30": ma20 > ma30,
            "close>ma20": close > ma20,
            "volume_pct_chg>0.2": volume_pct_chg > 0.2,
            "macd_diff>0": macd_diff > 0,
            "macd_dea>0": macd_dea > 0,
            "macd>0": macd > 0,
            "ma5<ma10": ma5 < ma10,
            "ma10<ma20": ma10 < ma20,
            "ma20<ma30": ma20 < ma30,
            "close<ma20": close < ma20,
            "macd_diff<0": macd_diff < 0,
            "macd_dea<0": macd_dea < 0,
            "macd<0": macd < 0,
        }
    )

    def holds(name, when_unknown):
        # Unrecognised (or non-string) names are skipped: they neither veto the
        # "and" group nor satisfy the "or" group.
        if isinstance(name, str) and name in checks:
            return checks[name]
        return when_unknown

    if all(holds(name, True) for name in condition_dict.get("and", [])):
        return True
    return any(holds(name, False) for name in condition_dict.get("or", []))
|
||
|
||
def draw_pct_chg_mean_chart(
    self, data: pd.DataFrame, strategy_name: str = "全均线策略"
):
    """Draw per-bar bar charts of the compounded and the mean trade return.

    For each of the columns ``pct_chg_total`` and ``pct_chg_mean`` and for each
    distinct ``bar`` value, one chart (symbols on the x axis, sorted by value in
    descending order, value labels on the bars) is saved as a PNG under
    ``self.stats_chart_dir``.

    :param data: statistics frame with ``symbol``, ``bar``, ``pct_chg_total``
        and ``pct_chg_mean`` columns
    :param strategy_name: strategy name used in file and sheet names
    :return: dict mapping a worksheet title to the saved PNG path, or ``None``
        when ``data`` is ``None`` or empty
    """
    if data is None or data.empty:
        return None

    # Global plot styling: seaborn grid theme, a font able to render Chinese
    # labels, and a minus sign that displays correctly with that font.
    sns.set_theme(style="whitegrid")
    plt.rcParams["font.sans-serif"] = ["SimHei"]
    plt.rcParams["font.size"] = 11
    plt.rcParams["axes.unicode_minus"] = False

    metric_labels = {"pct_chg_total": "涨跌总和", "pct_chg_mean": "涨跌均值"}
    charts = {}
    for metric, label in metric_labels.items():
        for bar in data["bar"].unique():
            subset = data[data["bar"] == bar].copy()
            if subset.empty:
                continue
            # Show the metric under its display label, largest value first.
            ranked = (
                subset.rename(columns={metric: label})
                .sort_values(by=label, ascending=False)
                .reset_index(drop=True)
            )

            plt.figure(figsize=(10, 6))
            ax = sns.barplot(x="symbol", y=label, data=ranked, palette="Blues_d")
            plt.title(f"{bar}趋势{label}(%)")
            plt.xlabel("symbol")
            plt.ylabel(label)
            plt.xticks(rotation=45, ha="right")

            # Print each value at the top of its bar.
            for position, value in enumerate(ranked[label]):
                ax.text(
                    position,
                    value,
                    f"{value:.3f}",
                    ha="center",
                    va="bottom",
                    fontsize=10,
                    fontweight="bold",
                )

            plt.tight_layout()

            image_name = f"{bar}_ma_break_{metric}_{strategy_name}.png"
            save_path = os.path.join(self.stats_chart_dir, image_name)
            plt.savefig(save_path, dpi=150)
            plt.close()

            charts[f"{bar}_趋势{label}分布图表_{strategy_name}"] = save_path
    return charts
|
||
|
||
def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict):
    """Append one worksheet per chart image to an existing Excel workbook.

    Args:
        excel_file_path: Path of an existing workbook; it is loaded, extended
            and saved back to the same path.
        charts_dict: Mapping of worksheet title to image file path, e.g.
            ``{"sheet_name": "chart_path"}``. ``None`` or an empty mapping
            leaves the workbook untouched.

    A chart that cannot be added (unreadable image, invalid sheet title, ...)
    is logged and skipped; the remaining charts are still written.
    """
    if not charts_dict:
        # Nothing to embed: avoid a pointless load/save round trip (and a crash
        # on ``None.items()``).
        return

    logger.info(f"将图表输出到{excel_file_path}")

    # Open the workbook that was written earlier.
    wb = openpyxl.load_workbook(excel_file_path)

    for sheet_name, chart_path in charts_dict.items():
        try:
            # Load the image before creating the sheet so that a broken image
            # does not leave an empty worksheet behind.
            img = Image(chart_path)
            ws = wb.create_sheet(title=sheet_name)
            ws.add_image(img, "A1")
        except Exception as e:
            # Deliberately best-effort per chart: report and keep going.
            logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
            continue

    wb.save(excel_file_path)
    # Reported through the logger, like every other message in this file.
    logger.info(f"Chart saved as {excel_file_path}")
|