787 lines
38 KiB
Python
787 lines
38 KiB
Python
import core.logger as logging
|
||
import os
|
||
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
from datetime import datetime
|
||
import re
|
||
import json
|
||
from openpyxl import Workbook
|
||
from openpyxl.drawing.image import Image
|
||
import openpyxl
|
||
from openpyxl.styles import Font
|
||
from PIL import Image as PILImage
|
||
from config import OKX_MONITOR_CONFIG, US_STOCK_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE
|
||
from core.db.db_market_data import DBMarketData
|
||
from core.db.db_huge_volume_data import DBHugeVolumeData
|
||
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
|
||
|
||
# Use a CJK-capable font so the Chinese titles and labels in the
# matplotlib/seaborn charts render instead of showing empty boxes.
plt.rcParams.update({"font.family": ["SimHei"]})

# Project-wide logger instance exposed by core.logger.
logger = logging.logger
|
||
|
||
|
||
class MaBreakStatistics:
    """
    Back-test moving-average (MA) breakout strategies and export statistics.

    Intended use: a position is opened on an upward MA breakout
    (e.g. ma5 > ma10 > ma20 > ma30) and closed on the next downward breakout
    (e.g. ma30 > ma20 > ma10 > ma5). The price change of every such buy/sell
    pair is collected and aggregated per symbol and bar.

    The concrete buy/sell rules are read from the "均线系统策略" section of
    ./json/trade_strategy.json and evaluated by ``fit_strategy``.
    """

    def __init__(self, is_us_stock: bool = False):
        """
        Connect to MySQL and load the symbol/bar universe and the strategy config.

        :param is_us_stock: read symbols/bars from US_STOCK_MONITOR_CONFIG
            instead of OKX_MONITOR_CONFIG
        :raises ValueError: if MYSQL_CONFIG has no password
        """
        mysql_user = MYSQL_CONFIG.get("user", "xch")
        mysql_password = MYSQL_CONFIG.get("password", "")
        if not mysql_password:
            raise ValueError("MySQL password is not set")
        mysql_host = MYSQL_CONFIG.get("host", "localhost")
        mysql_port = MYSQL_CONFIG.get("port", 3306)
        mysql_database = MYSQL_CONFIG.get("database", "okx")

        self.db_url = (
            f"mysql+pymysql://{mysql_user}:{mysql_password}"
            f"@{mysql_host}:{mysql_port}/{mysql_database}"
        )
        self.db_market_data = DBMarketData(self.db_url)
        self.db_huge_volume_data = DBHugeVolumeData(self.db_url)

        # Both markets use the same "volume_monitor" section layout; only the
        # source config and the fallback values differ.
        if is_us_stock:
            monitor_config = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {})
            self.symbols = monitor_config.get("symbols", ["QQQ"])
            self.bars = monitor_config.get("bars", ["5m"])
        else:
            monitor_config = OKX_MONITOR_CONFIG.get("volume_monitor", {})
            self.symbols = monitor_config.get("symbols", ["XCH-USDT"])
            self.bars = monitor_config.get("bars", ["5m", "15m", "30m", "1H"])

        self.stats_output_dir = "./output/trade_sandbox/ma_strategy/excel/"
        os.makedirs(self.stats_output_dir, exist_ok=True)
        self.stats_chart_dir = "./output/trade_sandbox/ma_strategy/chart/"
        os.makedirs(self.stats_chart_dir, exist_ok=True)

        self.trade_strategy_config = self.get_trade_strategy_config()
        # Fall back to an empty mapping so later membership tests and .get()
        # calls do not fail with AttributeError on None.
        self.main_strategy = self.trade_strategy_config.get("均线系统策略", None) or {}

    def get_trade_strategy_config(self):
        """
        Load the strategy definitions.

        :return: the parsed content of ./json/trade_strategy.json
        """
        with open("./json/trade_strategy.json", "r", encoding="utf-8") as f:
            trade_strategy_config = json.load(f)
        return trade_strategy_config

    def batch_statistics(self, strategy_name: str = "全均线策略"):
        """
        Simulate the strategy for every configured symbol/bar and export the results.

        The output is an Excel workbook in ``self.stats_output_dir``. It holds the
        strategy description, every buy/sell record, per symbol/bar return
        and holding-time statistics, followed by bar and line chart sheets.

        :param strategy_name: key in the "均线系统策略" config section; unknown
            names (or None) fall back to "全均线策略"
        :return: per symbol/bar return statistics, or None when no symbol/bar
            produced a complete buy/sell record
        """
        # Resolve the strategy *before* creating folders so they are named
        # after the strategy that actually runs.
        if strategy_name is None or strategy_name not in self.main_strategy:
            strategy_name = "全均线策略"

        self.stats_output_dir = (
            f"./output/trade_sandbox/ma_strategy/excel/{strategy_name}/"
        )
        os.makedirs(self.stats_output_dir, exist_ok=True)
        self.stats_chart_dir = (
            f"./output/trade_sandbox/ma_strategy/chart/{strategy_name}/"
        )
        os.makedirs(self.stats_chart_dir, exist_ok=True)

        ma_break_market_data_list = []
        market_data_pct_chg_list = []
        for symbol in self.symbols:
            for bar in self.bars:
                logger.info(
                    f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计, 策略: {strategy_name}"
                )
                ma_break_market_data, market_data_pct_chg = self.trade_simulate(
                    symbol, bar, strategy_name
                )
                if (
                    ma_break_market_data is None
                    or len(ma_break_market_data) == 0
                    or market_data_pct_chg is None
                ):
                    continue
                ma_break_market_data_list.append(ma_break_market_data)
                logger.info(
                    f"{symbol} {bar} 的市场价格变化, {market_data_pct_chg.get('pct_chg', 0)}%"
                )
                market_data_pct_chg_list.append(market_data_pct_chg)

        if not ma_break_market_data_list:
            return None

        ma_break_market_data = pd.concat(ma_break_market_data_list)
        ma_break_market_data.sort_values(by="begin_timestamp", ascending=True, inplace=True)
        ma_break_market_data.reset_index(drop=True, inplace=True)

        # Running compounded growth factor per symbol/bar, taken in
        # end_timestamp order: prod(1 + pct_chg / 100). Assignment aligns on
        # the index, so the exported row order (by begin_timestamp) is kept.
        ma_break_market_data["pct_chg_total"] = ma_break_market_data["pct_chg"] / 100 + 1
        ma_break_market_data["pct_chg_total"] = (
            ma_break_market_data.sort_values(by="end_timestamp")
            .groupby(["symbol", "bar"])["pct_chg_total"]
            .cumprod()
        )

        pct_chg_df = self._summarize_pct_chg(
            ma_break_market_data, market_data_pct_chg_list, strategy_name
        )

        # Holding-time statistics per symbol/bar.
        interval_minutes_df = (
            ma_break_market_data.groupby(["symbol", "bar"])["interval_minutes"]
            .agg(
                interval_minutes_max="max",
                interval_minutes_min="min",
                interval_minutes_mean="mean",
                interval_minutes_std="std",
                interval_minutes_median="median",
                interval_minutes_count="count",
            )
            .reset_index()
        )

        output_file_name = self._build_output_file_name(ma_break_market_data, strategy_name)
        output_file_path = os.path.join(self.stats_output_dir, output_file_name)
        logger.info(f"导出{output_file_path}")

        strategy_info_df = self.get_strategy_info(strategy_name)
        with pd.ExcelWriter(output_file_path) as writer:
            strategy_info_df.to_excel(writer, sheet_name="策略信息", index=False)
            ma_break_market_data.to_excel(writer, sheet_name="买卖记录明细", index=False)
            pct_chg_df.to_excel(writer, sheet_name="买卖涨跌幅统计", index=False)
            interval_minutes_df.to_excel(writer, sheet_name="买卖时间间隔统计", index=False)

        chart_dict = self.draw_quant_pct_chg_bar_chart(pct_chg_df, strategy_name)
        self.output_chart_to_excel(output_file_path, chart_dict)
        chart_dict = self.draw_quant_line_chart(ma_break_market_data, strategy_name)
        self.output_chart_to_excel(output_file_path, chart_dict)
        return pct_chg_df

    @staticmethod
    def _summarize_pct_chg(
        ma_break_market_data: pd.DataFrame,
        market_data_pct_chg_list: list,
        strategy_name: str,
    ) -> pd.DataFrame:
        """
        Aggregate per-trade returns per symbol/bar.

        :param ma_break_market_data: trade records with symbol, bar,
            end_timestamp, pct_chg and the running factor pct_chg_total
        :param market_data_pct_chg_list: dicts {symbol, bar, pct_chg} holding
            the market's own change per symbol/bar (the first entry wins)
        :param strategy_name: value stored in the strategy_name column
        :return: one row per symbol/bar. pct_chg_total is the final compounded
            return in percent; market_pct_chg is the market change.
        """
        # Trades of one symbol/bar never overlap, so end_timestamp order is
        # their chronological order and the last running factor is the final
        # compounded result.
        ordered = ma_break_market_data.sort_values(by="end_timestamp")
        pct_chg_df = (
            ordered.groupby(["symbol", "bar"])
            .agg(
                pct_chg_total=("pct_chg_total", lambda s: (s.iloc[-1] - 1) * 100),
                pct_chg_sum=("pct_chg", "sum"),
                pct_chg_max=("pct_chg", "max"),
                pct_chg_min=("pct_chg", "min"),
                pct_chg_mean=("pct_chg", "mean"),
                pct_chg_std=("pct_chg", "std"),
                pct_chg_median=("pct_chg", "median"),
                pct_chg_count=("pct_chg", "count"),
            )
            .reset_index()
        )

        market_pct_chg_df = (
            pd.DataFrame(market_data_pct_chg_list)
            .drop_duplicates(subset=["symbol", "bar"], keep="first")
            .rename(columns={"pct_chg": "market_pct_chg"})
        )
        pct_chg_df = pct_chg_df.merge(
            market_pct_chg_df[["symbol", "bar", "market_pct_chg"]],
            on=["symbol", "bar"],
            how="left",
        )
        pct_chg_df["strategy_name"] = strategy_name
        pct_chg_df["pct_chg_total"] = pct_chg_df["pct_chg_total"].astype(float)
        pct_chg_df["market_pct_chg"] = pct_chg_df["market_pct_chg"].astype(float)

        return pct_chg_df[
            [
                "strategy_name",
                "symbol",
                "bar",
                "market_pct_chg",
                "pct_chg_total",
                "pct_chg_sum",
                "pct_chg_max",
                "pct_chg_min",
                "pct_chg_mean",
                "pct_chg_std",
                "pct_chg_median",
                "pct_chg_count",
            ]
        ]

    @staticmethod
    def _build_output_file_name(ma_break_market_data: pd.DataFrame, strategy_name: str) -> str:
        """
        Build the Excel file name from the earliest buy and latest sell times.

        Separators (":", "-", whitespace) are stripped from the times. If there
        is no latest sell time, today's date (YYYYMMDD) is used instead.
        """

        def compact(value) -> str:
            return re.sub(r"[\:\-\s]", "", str(value))

        earliest_market_date_time = ma_break_market_data["begin_date_time"].min()
        latest_market_date_time = ma_break_market_data["end_date_time"].max()
        if latest_market_date_time is None or pd.isna(latest_market_date_time):
            latest_market_date_time = datetime.now().strftime("%Y%m%d")
        return (
            f"ma_break_stats_from_{compact(earliest_market_date_time)}"
            f"_to_{compact(latest_market_date_time)}_{strategy_name}.xlsx"
        )

    @staticmethod
    def _describe_conditions(condition_dict: dict) -> str:
        """
        Render a {"and": [...], "or": [...]} rule as text.

        Each condition is followed by ", \\n". The "or" part, if any, is
        appended after " 或者 \\n".
        """
        and_text = "".join(f"{condition}, \n" for condition in condition_dict.get("and", []))
        or_text = "".join(f"{condition}, \n" for condition in condition_dict.get("or", []))
        if or_text:
            return and_text + " 或者 \n" + or_text
        return and_text

    def get_strategy_info(self, strategy_name: str = "全均线策略"):
        """
        Describe a strategy's buy and sell rules as a single-row DataFrame.

        :param strategy_name: key in the "均线系统策略" config section
        :return: DataFrame with columns 策略名称/买入策略/卖出策略, or None
            (after logging an error) when the strategy does not exist
        """
        strategy_config = self.main_strategy.get(strategy_name, None)
        if strategy_config is None:
            logger.error(f"策略{strategy_name}不存在")
            return None

        strategy_info = {
            "策略名称": strategy_name,
            "买入策略": self._describe_conditions(strategy_config.get("buy", {})),
            "卖出策略": self._describe_conditions(strategy_config.get("sell", {})),
        }
        return pd.DataFrame([strategy_info])

    @staticmethod
    def _row_snapshot(row: pd.Series, prefix: str) -> dict:
        """
        Copy one K-line row's trade-relevant fields into "<prefix>_*" keys.

        The key order here fixes the column order of the exported trade records.
        """
        timestamp = row["timestamp"]
        return {
            f"{prefix}_timestamp": timestamp,
            f"{prefix}_date_time": timestamp_to_datetime(timestamp),
            f"{prefix}_close": row["close"],
            f"{prefix}_ma5": row["ma5"],
            f"{prefix}_ma10": row["ma10"],
            f"{prefix}_ma20": row["ma20"],
            f"{prefix}_ma30": row["ma30"],
            f"{prefix}_macd_diff": float(row["dif"]),
            f"{prefix}_macd_dea": float(row["dea"]),
            f"{prefix}_macd": float(row["macd"]),
        }

    def trade_simulate(self, symbol: str, bar: str, strategy_name: str = "全均线策略"):
        """
        Replay one symbol/bar's K-lines and record every completed buy/sell pair.

        Walking forward in time, a position opens on the first row meeting the
        "buy" rule. It closes on the next row meeting the "sell" rule. A
        position still open at the end is not recorded.

        :return: (trade records sorted by begin_timestamp,
            {"symbol", "bar", "pct_chg"} market change from the first buy close
            to the last sell close in percent), or (None, None) when there is
            no data or no completed trade
        """
        market_data = self.db_market_data.query_market_data_by_symbol_bar(
            symbol, bar, start=None, end=None
        )
        if market_data is None or len(market_data) == 0:
            logger.warning(f"获取{symbol} {bar} 数据失败")
            return None, None

        market_data = pd.DataFrame(market_data)
        market_data.sort_values(by="timestamp", ascending=True, inplace=True)
        market_data.reset_index(drop=True, inplace=True)
        logger.info(f"获取{symbol} {bar} 数据成功,数据条数: {len(market_data)}")

        # Keep only rows where every MA is available. Copy so that adding the
        # volume columns below does not write into a view (SettingWithCopyWarning).
        market_data = market_data[
            market_data[["ma5", "ma10", "ma20", "ma30"]].notna().all(axis=1)
        ].copy()
        logger.info(f"ma5, ma10, ma20, ma30不为空的行,数据条数: {len(market_data)}")

        # Volume change relative to the 5-bar average volume. The first 4 rows
        # have no average and get 0.
        market_data["volume_ma5"] = market_data["volume"].rolling(window=5).mean()
        market_data["volume_pct_chg"] = (
            (market_data["volume"] - market_data["volume_ma5"]) / market_data["volume_ma5"]
        ).fillna(0)
        market_data.reset_index(drop=True, inplace=True)

        trades = []
        open_trade = {}
        for _, row in market_data.iterrows():
            if open_trade.get("begin_timestamp", None) is None:
                if self.fit_strategy(
                    strategy_name=strategy_name,
                    market_data=market_data,
                    row=row,
                    behavior="buy",
                ):
                    open_trade = {"symbol": symbol, "bar": bar}
                    open_trade.update(self._row_snapshot(row, "begin"))
                continue

            if not self.fit_strategy(
                strategy_name=strategy_name,
                market_data=market_data,
                row=row,
                behavior="sell",
            ):
                continue

            open_trade.update(self._row_snapshot(row, "end"))
            begin_close = open_trade["begin_close"]
            open_trade["pct_chg"] = round((row["close"] - begin_close) / begin_close * 100, 4)
            # Timestamps are divided by 1000 to get seconds (i.e. milliseconds).
            interval_seconds = (row["timestamp"] - open_trade["begin_timestamp"]) / 1000
            open_trade["interval_seconds"] = interval_seconds
            open_trade["interval_minutes"] = interval_seconds / 60
            open_trade["interval_hours"] = interval_seconds / 3600
            open_trade["interval_days"] = interval_seconds / 86400
            trades.append(open_trade)
            open_trade = {}

        if not trades:
            return None, None

        ma_break_market_data = pd.DataFrame(trades)
        ma_break_market_data.sort_values(by="begin_timestamp", ascending=True, inplace=True)
        ma_break_market_data.reset_index(drop=True, inplace=True)
        logger.info(
            f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}"
        )

        # Market move over the traded period, in percent:
        # (last sell close - first buy close) / first buy close * 100
        first_close = ma_break_market_data["begin_close"].iloc[0]
        last_close = ma_break_market_data["end_close"].iloc[-1]
        pct_chg = round((last_close - first_close) / first_close * 100, 4)
        market_data_pct_chg = {"symbol": symbol, "bar": bar, "pct_chg": pct_chg}
        return ma_break_market_data, market_data_pct_chg

    def fit_strategy(
        self,
        strategy_name: str = "全均线策略",
        market_data: pd.DataFrame = None,
        row: pd.Series = None,
        behavior: str = "buy",
    ):
        """
        Decide whether ``row`` satisfies the strategy's ``behavior`` rule.

        A rule is {"and": [...], "or": [...]}. The result is
        ``all(and) or any(or)``. When the rule has no recognised "and"
        condition, the "or" conditions alone decide. A rule with no
        recognised condition at all evaluates to True. Condition names not
        listed below are ignored.

        :param strategy_name: key in the "均线系统策略" config section
        :param market_data: unused; kept for call compatibility
        :param row: one K-line with ma_cross, ma5/10/20/30, close,
            volume_pct_chg, dif, dea, macd
        :param behavior: "buy" or "sell"
        :return: True if the rule matches; False if it does not match or the
            strategy/behavior is missing (an error is logged in that case)
        """
        strategy_config = self.main_strategy.get(strategy_name, None)
        if strategy_config is None:
            logger.error(f"策略{strategy_name}不存在")
            return False
        condition_dict = strategy_config.get(behavior, None)
        if condition_dict is None:
            logger.error(f"策略{strategy_name}的{behavior}条件不存在")
            return False

        ma_cross = row["ma_cross"]
        if ma_cross is None or pd.isna(ma_cross):
            ma_cross = ""
        ma_cross = str(ma_cross)

        ma5 = float(row["ma5"])
        ma10 = float(row["ma10"])
        ma20 = float(row["ma20"])
        ma30 = float(row["ma30"])
        close = float(row["close"])
        volume_pct_chg = float(row["volume_pct_chg"])
        macd_diff = float(row["dif"])
        macd_dea = float(row["dea"])
        macd = float(row["macd"])

        # Cross events are matched as substrings of the row's ma_cross text.
        cross_conditions = {"5上穿10", "10上穿20", "20上穿30", "10下穿5", "20下穿10", "30下穿20"}
        value_conditions = {
            "ma5>ma10": ma5 > ma10,
            "ma10>ma20": ma10 > ma20,
            "ma20>ma30": ma20 > ma30,
            "close>ma20": close > ma20,
            "volume_pct_chg>0.2": volume_pct_chg > 0.2,
            "macd_diff>0": macd_diff > 0,
            "macd_dea>0": macd_dea > 0,
            "macd>0": macd > 0,
            "ma5<ma10": ma5 < ma10,
            "ma10<ma20": ma10 < ma20,
            "ma20<ma30": ma20 < ma30,
            "close<ma20": close < ma20,
            "macd_diff<0": macd_diff < 0,
            "macd_dea<0": macd_dea < 0,
            "macd<0": macd < 0,
        }

        def evaluate(names):
            # Results of the recognised conditions only; unknown names are skipped.
            results = []
            for name in names:
                if name in cross_conditions:
                    results.append(name in ma_cross)
                elif name in value_conditions:
                    results.append(value_conditions[name])
            return results

        and_results = evaluate(condition_dict.get("and", []))
        or_results = evaluate(condition_dict.get("or", []))
        if not and_results:
            # Without an "and" clause the "or" clause must decide on its own;
            # previously an empty "and" made every or-only rule always True.
            return any(or_results) if or_results else True
        return all(and_results) or any(or_results)

    @staticmethod
    def _draw_strategy_vs_market_bars(bar_data: pd.DataFrame, bar: str, column_name_text: str):
        """
        Draw grouped bars on a new figure: strategy return beside the market change.

        Each bar is annotated with its value.
        """
        plt.figure(figsize=(12, 7))
        x = np.arange(len(bar_data))
        width = 0.35  # width of each bar in a pair

        strategy_bars = plt.bar(
            x - width / 2,
            bar_data[column_name_text],
            width,
            label=column_name_text,
            color=plt.cm.Blues(np.linspace(0.6, 0.9, len(bar_data))),
        )
        market_bars = plt.bar(
            x + width / 2,
            bar_data["市场自然涨跌"],
            width,
            label="市场自然涨跌",
            color=plt.cm.Greens(np.linspace(0.6, 0.9, len(bar_data))),
        )

        plt.title(f"{bar}趋势{column_name_text}与市场自然涨跌对比(%)", fontsize=14, fontweight="bold")
        plt.xlabel("Symbol", fontsize=12)
        plt.ylabel("涨跌幅(%)", fontsize=12)
        plt.xticks(x, bar_data["symbol"], rotation=45, ha="right")
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Labels sit just above positive bars and just below negative ones.
        for drawn_bars, values, color in (
            (strategy_bars, bar_data[column_name_text], "darkblue"),
            (market_bars, bar_data["市场自然涨跌"], "darkgreen"),
        ):
            for rect, value in zip(drawn_bars, values):
                plt.text(
                    rect.get_x() + rect.get_width() / 2,
                    value + (0.01 if value >= 0 else -0.01),
                    f"{value:.3f}%",
                    ha="center",
                    va="bottom" if value >= 0 else "top",
                    fontsize=9,
                    fontweight="bold",
                    color=color,
                )

    @staticmethod
    def _draw_single_metric_bars(bar_data: pd.DataFrame, bar: str, column_name_text: str):
        """Draw one seaborn bar per symbol on a new figure, each annotated with its value."""
        plt.figure(figsize=(10, 6))
        ax = sns.barplot(x="symbol", y=column_name_text, data=bar_data, palette="Blues_d")
        plt.title(f"{bar}趋势{column_name_text}(%)")
        plt.xlabel("symbol")
        plt.ylabel(column_name_text)
        plt.xticks(rotation=45, ha="right")
        for position, value in enumerate(bar_data[column_name_text]):
            ax.text(
                position,
                value,
                f"{value:.3f}",
                ha="center",
                va="bottom",
                fontsize=10,
                fontweight="bold",
            )

    def draw_quant_pct_chg_bar_chart(
        self, data: pd.DataFrame, strategy_name: str = "全均线策略"
    ):
        """
        Draw one bar chart per metric and bar, saved as PNG under ``self.stats_chart_dir``.

        The metrics are "pct_chg_total" (compounded strategy return, drawn next to
        market_pct_chg) and "pct_chg_mean" (mean per-trade return).

        :param data: per symbol/bar statistics as returned by ``batch_statistics``
        :param strategy_name: used in file names and sheet names
        :return: {sheet_name: png_path}, or None when ``data`` is None or empty
        """
        if data is None or data.empty:
            return None

        sns.set_theme(style="whitegrid")
        plt.rcParams["font.sans-serif"] = ["SimHei"]  # CJK-capable font
        plt.rcParams["font.size"] = 11
        plt.rcParams["axes.unicode_minus"] = False  # render minus signs with the CJK font

        chart_dict = {}
        column_name_dict = {"pct_chg_total": "量化策略涨跌", "pct_chg_mean": "量化策略涨跌均值"}
        for column_name, column_name_text in column_name_dict.items():
            for bar in data["bar"].unique():
                bar_data = data[data["bar"] == bar].copy()
                if bar_data.empty:
                    continue
                bar_data.rename(columns={column_name: column_name_text}, inplace=True)
                if column_name == "pct_chg_total":
                    bar_data.rename(columns={"market_pct_chg": "市场自然涨跌"}, inplace=True)
                # Largest value first.
                bar_data.sort_values(by=column_name_text, ascending=False, inplace=True)
                bar_data.reset_index(drop=True, inplace=True)
                # String tick labels avoid matplotlib categorical-unit warnings.
                bar_data["symbol"] = bar_data["symbol"].astype(str)

                if column_name == "pct_chg_total":
                    self._draw_strategy_vs_market_bars(bar_data, bar, column_name_text)
                else:
                    self._draw_single_metric_bars(bar_data, bar, column_name_text)

                plt.tight_layout()
                save_path = os.path.join(
                    self.stats_chart_dir,
                    f"{bar}_bar_chart_{column_name}_{strategy_name}.png",
                )
                plt.savefig(save_path, dpi=150)
                plt.close()

                sheet_name = f"{bar}_趋势{column_name_text}柱状图_{strategy_name}"
                chart_dict[sheet_name] = save_path
        return chart_dict

    def draw_quant_line_chart(self, data: pd.DataFrame, strategy_name: str = "全均线策略"):
        """
        Plot the strategy's compounded growth against the normalised market price.

        One chart is drawn per symbol/bar. The line chart is saved as PNG under
        ``self.stats_chart_dir``.

        :param data: trade records including the running factor pct_chg_total
        :param strategy_name: used in titles, file names and sheet names
        :return: {sheet_name: png_path}
        """
        chart_dict = {}
        for symbol in data["symbol"].unique():
            for bar in data["bar"].unique():
                symbol_bar_data = data[(data["symbol"] == symbol) & (data["bar"] == bar)]
                if symbol_bar_data.empty:
                    continue

                # Synthetic starting point at the first buy: the strategy
                # factor is 1.0 and the "end" fields mirror the first buy, so
                # both lines start at 1.
                init_row = symbol_bar_data.iloc[0].copy()
                for field in (
                    "timestamp", "date_time", "close", "ma5", "ma10", "ma20",
                    "ma30", "macd_diff", "macd_dea", "macd",
                ):
                    init_row.loc[f"end_{field}"] = init_row[f"begin_{field}"]
                init_row.loc["pct_chg_total"] = 1.0
                for field in (
                    "pct_chg", "interval_seconds", "interval_minutes",
                    "interval_hours", "interval_days",
                ):
                    init_row.loc[field] = 0

                symbol_bar_data = pd.concat([pd.DataFrame([init_row]), symbol_bar_data])
                symbol_bar_data.sort_values(by="end_timestamp", ascending=True, inplace=True)
                symbol_bar_data.reset_index(drop=True, inplace=True)
                # Real datetimes avoid matplotlib warnings and allow strftime.
                symbol_bar_data["end_date_time"] = pd.to_datetime(symbol_bar_data["end_date_time"])
                # Market price relative to the first buy close.
                symbol_bar_data["end_close_to_1"] = (
                    symbol_bar_data["end_close"] / init_row["end_close"]
                ).round(4)

                plt.figure(figsize=(12, 7))
                plt.plot(
                    symbol_bar_data["end_date_time"],
                    symbol_bar_data["pct_chg_total"],
                    label="量化策略涨跌", color="blue", linewidth=2, marker="o", markersize=4,
                )
                plt.plot(
                    symbol_bar_data["end_date_time"],
                    symbol_bar_data["end_close_to_1"],
                    label="市场自然涨跌", color="green", linewidth=2, marker="s", markersize=4,
                )
                plt.title(f"{symbol} {bar} 量化与市场折线图_{strategy_name}", fontsize=14, fontweight="bold")
                plt.xlabel("时间", fontsize=12)
                plt.ylabel("涨跌变化", fontsize=12)
                plt.legend(fontsize=11)
                plt.grid(True, alpha=0.3)

                # With many points, label roughly 30 evenly spaced ticks plus
                # the last one, so both ends are always shown.
                times = symbol_bar_data["end_date_time"]
                point_count = len(symbol_bar_data)
                if point_count > 30:
                    step = max(1, point_count // 30)
                    label_indices = list(range(0, point_count - 1, step)) + [point_count - 1]
                    tick_times = times.iloc[label_indices]
                else:
                    tick_times = times
                plt.xticks(tick_times, tick_times.dt.strftime("%m-%d %H:%M"), rotation=45, ha="right")

                plt.tight_layout()
                save_path = os.path.join(
                    self.stats_chart_dir,
                    f"{symbol}_{bar}_line_chart_{strategy_name}.png",
                )
                plt.savefig(save_path, dpi=150, bbox_inches="tight")
                plt.close()

                sheet_name = f"{symbol}_{bar}_折线图_{strategy_name}"
                chart_dict[sheet_name] = save_path
        return chart_dict

    def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict):
        """
        Append one worksheet per chart image to an existing Excel workbook.

        :param excel_file_path: path of an already written .xlsx file
        :param charts_dict: {sheet_name: png_path}; None or empty is a no-op
        """
        if not charts_dict:
            return
        logger.info(f"将图表输出到{excel_file_path}")

        wb = openpyxl.load_workbook(excel_file_path)
        for sheet_name, chart_path in charts_dict.items():
            try:
                ws = wb.create_sheet(title=sheet_name)
                ws.add_image(Image(chart_path), "A1")
            except Exception as e:
                # Best effort: one failing chart must not block the others.
                # NOTE(review): Excel limits sheet titles to 31 characters;
                # openpyxl only warns, so long symbol/strategy names may need
                # shortening.
                logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
                continue
        wb.save(excel_file_path)
        logger.info(f"Chart saved as {excel_file_path}")
|