combine trade strategy

This commit is contained in:
blade 2025-08-22 18:48:59 +08:00
parent b322aaa421
commit b1e7ddc261
9 changed files with 654 additions and 302 deletions

View File

@ -1,295 +0,0 @@
import core.logger as logging
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import re
from openpyxl import Workbook
from openpyxl.drawing.image import Image
import openpyxl
from openpyxl.styles import Font
from PIL import Image as PILImage
from config import MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE
from core.db.db_market_data import DBMarketData
from core.db.db_huge_volume_data import DBHugeVolumeData
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
# seaborn支持中文
plt.rcParams["font.family"] = ["SimHei"]
logger = logging.logger
class MaBreakStatistics:
"""
统计MA突破之后的涨跌幅
MA向上突破的点位周期K线5 > 10 > 20 > 30
统计MA向上突破的点位周期K线突破之后
下一个MA向下突破的点位周期K线30 > 20 > 10 > 5
之间的涨跌幅
"""
def __init__(self):
mysql_user = MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "")
if not mysql_password:
raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_market_data = DBMarketData(self.db_url)
self.db_huge_volume_data = DBHugeVolumeData(self.db_url)
self.symbols = MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["XCH-USDT"]
)
self.bars = MONITOR_CONFIG.get("volume_monitor", {}).get(
"bars", ["5m", "15m", "30m", "1H"]
)
self.stats_output_dir = "./output/statistics/excel/"
os.makedirs(self.stats_output_dir, exist_ok=True)
self.stats_chart_dir = "./output/statistics/chart/"
os.makedirs(self.stats_chart_dir, exist_ok=True)
def batch_statistics(self, all_change: bool = True):
ma_break_market_data_list = []
for symbol in self.symbols:
for bar in self.bars:
logger.info(f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计")
ma_break_market_data = self.statistics(symbol, bar, all_change)
if ma_break_market_data is not None:
ma_break_market_data_list.append(ma_break_market_data)
if len(ma_break_market_data_list) > 0:
ma_break_market_data = pd.concat(ma_break_market_data_list)
# 依据symbol和bar分组统计每个symbol和bar的pct_chg的max, min, mean, std, median, count
pct_chg_df = (ma_break_market_data
.groupby(['symbol', 'bar'])['pct_chg']
.agg(pct_chg_max='max',
pct_chg_min='min',
pct_chg_mean='mean',
pct_chg_std='std',
pct_chg_median='median',
pct_chg_count='count')
.reset_index())
# 依据symbol和bar分组统计每个symbol和bar的interval_minutes的max, min, mean, std, median, count
interval_minutes_df = (ma_break_market_data
.groupby(['symbol', 'bar'])['interval_minutes']
.agg(interval_minutes_max='max',
interval_minutes_min='min',
interval_minutes_mean='mean',
interval_minutes_std='std',
interval_minutes_median='median',
interval_minutes_count='count')
.reset_index())
earliest_market_date_time = ma_break_market_data["begin_date_time"].min()
earliest_market_date_time = re.sub(r"[\:\-\s]", "", str(earliest_market_date_time))
latest_market_date_time = ma_break_market_data["end_date_time"].max()
if latest_market_date_time is None:
latest_market_date_time = datetime.now().strftime("%Y%m%d")
latest_market_date_time = re.sub(r"[\:\-\s]", "", str(latest_market_date_time))
if all_change:
output_file_name = f"ma_break_stats_from_{earliest_market_date_time}_to_{latest_market_date_time}_完全转势.xlsx"
else:
output_file_name = f"ma_break_stats_from_{earliest_market_date_time}_to_{latest_market_date_time}_部分转势.xlsx"
output_file_path = os.path.join(self.stats_output_dir, output_file_name)
logger.info(f"导出{output_file_path}")
with pd.ExcelWriter(output_file_path) as writer:
ma_break_market_data.to_excel(writer, sheet_name="ma_break_market_data", index=False)
pct_chg_df.to_excel(writer, sheet_name="pct_chg_stats", index=False)
interval_minutes_df.to_excel(writer, sheet_name="interval_minutes_stats", index=False)
chart_dict = self.draw_pct_chg_mean_chart(pct_chg_df, all_change)
self.output_chart_to_excel(output_file_path, chart_dict)
else:
return None
def statistics(self, symbol: str, bar: str, all_change: bool = False):
market_data = self.db_market_data.query_market_data_by_symbol_bar(symbol, bar, start=None, end=None)
if market_data is None or len(market_data) == 0:
logger.warning(f"获取{symbol} {bar} 数据失败")
return
else:
market_data = pd.DataFrame(market_data)
market_data.sort_values(by="timestamp", ascending=True, inplace=True)
market_data.reset_index(drop=True, inplace=True)
logger.info(f"获取{symbol} {bar} 数据成功,数据条数: {len(market_data)}")
# 获得ma5, ma10, ma20, ma30不为空的行
market_data = market_data[(market_data["ma5"].notna()) &
(market_data["ma10"].notna()) &
(market_data["ma20"].notna()) &
(market_data["ma30"].notna())]
logger.info(f"ma5, ma10, ma20, ma30不为空的行数据条数: {len(market_data)}")
# 计算volume_ma5
market_data["volume_ma5"] = market_data["volume"].rolling(window=5).mean()
# 获得5上穿10且ma5 > ma10 > ma20 > ma30且close > ma20的行,成交量较前5日均量放大20%以上
market_data["volume_pct_chg"] = (market_data["volume"] - market_data["volume_ma5"]) / market_data["volume_ma5"]
market_data["volume_pct_chg"] = market_data["volume_pct_chg"].fillna(0)
if all_change:
long_market_data = market_data[(market_data["ma_cross"] == "5上穿10") & (market_data["ma5"] > market_data["ma10"]) &
(market_data["ma10"] > market_data["ma20"]) &
(market_data["ma20"] > market_data["ma30"]) &
(market_data["close"] > market_data["ma20"]) &
(market_data["volume_pct_chg"] > 0.2)]
logger.info(f"5上穿10, 且ma5 > ma10 > ma20 > ma30并且close > ma20并且成交量较前5日均量放大20%以上的行,数据条数: {len(long_market_data)}")
else:
long_market_data = market_data[(market_data["ma_cross"] == "5上穿10") & (market_data["ma5"] > market_data["ma10"]) &
(market_data["volume_pct_chg"] > 0.2)]
logger.info(f"5上穿10, 且ma5 > ma10并且成交量较前5日均量放大20%以上的行,数据条数: {len(long_market_data)}")
if len(long_market_data) == 0:
return None
if all_change:
# 获得ma5 < ma10 < ma20 < ma30的行
short_market_data = market_data[(market_data["ma5"] < market_data["ma10"]) &
(market_data["ma10"] < market_data["ma20"]) &
(market_data["ma20"] < market_data["ma30"])]
logger.info(f"ma5 < ma10 < ma20 < ma30的行数据条数: {len(short_market_data)}")
else:
# ma5 < ma10 or close < ma20
short_market_data = market_data[(market_data["ma5"] < market_data["ma10"]) |
(market_data["close"] < market_data["ma20"])]
logger.info(f"ma5 < ma10 or close < ma20的行数据条数: {len(short_market_data)}")
# concat long_market_data和short_market_data
ma_break_market_data = pd.concat([long_market_data, short_market_data])
# 按照timestamp排序
ma_break_market_data = ma_break_market_data.sort_values(by="timestamp", ascending=True)
# 获得ma_break_market_data的close列
ma_break_market_data.reset_index(drop=True, inplace=True)
ma_break_market_data_pair_list = []
ma_break_market_data_pair = {}
for index, row in ma_break_market_data.iterrows():
ma_cross = row["ma_cross"]
timestamp = row["timestamp"]
close = row["close"]
ma5 = row["ma5"]
ma10 = row["ma10"]
ma20 = row["ma20"]
ma30 = row["ma30"]
if pd.notna(ma_cross) and ma_cross is not None:
ma_cross = str(ma_cross)
buy_condition = False
if all_change:
buy_condition = (ma_cross == "5上穿10") and (ma5 > ma10 and ma10 > ma20 and ma20 > ma30) and (close > ma20)
else:
buy_condition = (ma_cross == "5上穿10") and (ma5 > ma10)
if buy_condition:
ma_break_market_data_pair = {}
ma_break_market_data_pair["symbol"] = symbol
ma_break_market_data_pair["bar"] = bar
ma_break_market_data_pair["begin_timestamp"] = timestamp
ma_break_market_data_pair["begin_date_time"] = timestamp_to_datetime(timestamp)
ma_break_market_data_pair["begin_close"] = close
ma_break_market_data_pair["begin_ma5"] = ma5
ma_break_market_data_pair["begin_ma10"] = ma10
ma_break_market_data_pair["begin_ma20"] = ma20
ma_break_market_data_pair["begin_ma30"] = ma30
if all_change:
change_condition = (ma5 < ma10 and ma10 < ma20 and ma20 < ma30)
else:
# change_condition = (ma5 < ma10 or ma10 < ma20 or ma20 < ma30)
change_condition = (ma5 < ma10) or (close < ma20)
if change_condition:
if ma_break_market_data_pair.get("begin_timestamp", None) is None:
continue
ma_break_market_data_pair["end_timestamp"] = timestamp
ma_break_market_data_pair["end_date_time"] = timestamp_to_datetime(timestamp)
ma_break_market_data_pair["end_close"] = close
ma_break_market_data_pair["end_ma5"] = ma5
ma_break_market_data_pair["end_ma10"] = ma10
ma_break_market_data_pair["end_ma20"] = ma20
ma_break_market_data_pair["end_ma30"] = ma30
ma_break_market_data_pair["pct_chg"] = (close - ma_break_market_data_pair["begin_close"]) / ma_break_market_data_pair["begin_close"]
ma_break_market_data_pair["pct_chg"] = round(ma_break_market_data_pair["pct_chg"] * 100, 4)
ma_break_market_data_pair["interval_seconds"] = (timestamp - ma_break_market_data_pair["begin_timestamp"]) / 1000
# 将interval转换为分钟
ma_break_market_data_pair["interval_minutes"] = ma_break_market_data_pair["interval_seconds"] / 60
ma_break_market_data_pair["interval_hours"] = ma_break_market_data_pair["interval_seconds"] / 3600
ma_break_market_data_pair["interval_days"] = ma_break_market_data_pair["interval_seconds"] / 86400
ma_break_market_data_pair_list.append(ma_break_market_data_pair)
ma_break_market_data_pair = {}
if len(ma_break_market_data_pair_list) > 0:
ma_break_market_data = pd.DataFrame(ma_break_market_data_pair_list)
return ma_break_market_data
else:
return None
def draw_pct_chg_mean_chart(self, data: pd.DataFrame, all_change: bool = True):
"""
绘制pct_chg mean的柱状图表美观保存到self.stats_chart_dir
:param data: 波段pct_chg_mean的数据
:return: None
"""
if data is None or data.empty:
return None
# seaborn风格设置
sns.set_theme(style="whitegrid")
plt.rcParams["font.sans-serif"] = ["SimHei"] # 也可直接用字体名
plt.rcParams["font.size"] = 11 # 设置字体大小
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
chart_dict = {}
for bar in data["bar"].unique():
bar_data = data[data["bar"] == bar].copy() # 一次筛选即可
if bar_data.empty:
continue
bar_data.rename(columns={"pct_chg_mean": "涨跌幅均值"}, inplace=True)
# 可选:按均值排序
bar_data.sort_values(by="涨跌幅均值", ascending=False, inplace=True)
bar_data.reset_index(drop=True, inplace=True)
plt.figure(figsize=(10, 6))
sns.barplot(x="symbol", y="涨跌幅均值", data=bar_data, palette="Blues_d")
plt.title(f"{bar}趋势涨跌幅均值分布")
plt.xlabel("symbol")
plt.ylabel("涨跌幅均值")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
if all_change:
save_path = os.path.join(self.stats_chart_dir, f"{bar}_ma_break_pct_chg_mean_all_change.png")
else:
save_path = os.path.join(self.stats_chart_dir, f"{bar}_ma_break_pct_chg_mean_part_change.png")
plt.savefig(save_path, dpi=150)
plt.close()
if all_change:
sheet_name = f"{bar}_趋势涨跌幅均值分布图表_完全转势"
else:
sheet_name = f"{bar}_趋势涨跌幅均值分布图表_部分转势"
chart_dict[sheet_name] = save_path
return chart_dict
def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict):
"""
输出Excel文件包含所有图表
charts_dict: 图表数据字典格式为
{
"sheet_name": {
"chart_name": "chart_path"
}
}
"""
logger.info(f"将图表输出到{excel_file_path}")
# 打开已经存在的Excel文件
wb = openpyxl.load_workbook(excel_file_path)
for sheet_name, chart_path in charts_dict.items():
try:
ws = wb.create_sheet(title=sheet_name)
row_offset = 1
# Insert chart image
img = Image(chart_path)
ws.add_image(img, f"A{row_offset}")
except Exception as e:
logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
continue
# Save Excel file
wb.save(excel_file_path)
print(f"Chart saved as {excel_file_path}")

View File

@ -0,0 +1,506 @@
import core.logger as logging
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import re
import json
from openpyxl import Workbook
from openpyxl.drawing.image import Image
import openpyxl
from openpyxl.styles import Font
from PIL import Image as PILImage
from config import MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE
from core.db.db_market_data import DBMarketData
from core.db.db_huge_volume_data import DBHugeVolumeData
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
# seaborn支持中文
plt.rcParams["font.family"] = ["SimHei"]
logger = logging.logger
class MaBreakStatistics:
"""
统计MA突破之后的涨跌幅
MA向上突破的点位周期K线5 > 10 > 20 > 30
统计MA向上突破的点位周期K线突破之后
下一个MA向下突破的点位周期K线30 > 20 > 10 > 5
之间的涨跌幅
"""
def __init__(self):
mysql_user = MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "")
if not mysql_password:
raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_market_data = DBMarketData(self.db_url)
self.db_huge_volume_data = DBHugeVolumeData(self.db_url)
self.symbols = MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["XCH-USDT"]
)
self.bars = MONITOR_CONFIG.get("volume_monitor", {}).get(
"bars", ["5m", "15m", "30m", "1H"]
)
self.stats_output_dir = "./output/trade_sandbox/ma_strategy/excel/"
os.makedirs(self.stats_output_dir, exist_ok=True)
self.stats_chart_dir = "./output/trade_sandbox/ma_strategy/chart/"
os.makedirs(self.stats_chart_dir, exist_ok=True)
self.trade_strategy_config = self.get_trade_strategy_config()
self.main_strategy = self.trade_strategy_config.get("均线系统策略", None)
def get_trade_strategy_config(self):
with open("./json/trade_strategy.json", "r", encoding="utf-8") as f:
trade_strategy_config = json.load(f)
return trade_strategy_config
def batch_statistics(self, strategy_name: str = "全均线策略"):
self.stats_output_dir = f"./output/trade_sandbox/ma_strategy/excel/{strategy_name}/"
os.makedirs(self.stats_output_dir, exist_ok=True)
self.stats_chart_dir = f"./output/trade_sandbox/ma_strategy/chart/{strategy_name}/"
os.makedirs(self.stats_chart_dir, exist_ok=True)
ma_break_market_data_list = []
if strategy_name not in self.main_strategy.keys() or strategy_name is None:
strategy_name = "全均线策略"
for symbol in self.symbols:
for bar in self.bars:
logger.info(f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计, 策略: {strategy_name}")
ma_break_market_data = self.trade_simulate(symbol, bar, strategy_name)
if ma_break_market_data is not None:
ma_break_market_data_list.append(ma_break_market_data)
if len(ma_break_market_data_list) > 0:
ma_break_market_data = pd.concat(ma_break_market_data_list)
# 依据symbol和bar分组统计每个symbol和bar的pct_chg的max, min, mean, std, median, count
pct_chg_df = (
ma_break_market_data.groupby(["symbol", "bar"])["pct_chg"]
.agg(
pct_chg_sum="sum",
pct_chg_max="max",
pct_chg_min="min",
pct_chg_mean="mean",
pct_chg_std="std",
pct_chg_median="median",
pct_chg_count="count",
)
.reset_index()
)
# 依据symbol和bar分组统计每个symbol和bar的interval_minutes的max, min, mean, std, median, count
interval_minutes_df = (
ma_break_market_data.groupby(["symbol", "bar"])["interval_minutes"]
.agg(
interval_minutes_max="max",
interval_minutes_min="min",
interval_minutes_mean="mean",
interval_minutes_std="std",
interval_minutes_median="median",
interval_minutes_count="count",
)
.reset_index()
)
earliest_market_date_time = ma_break_market_data["begin_date_time"].min()
earliest_market_date_time = re.sub(
r"[\:\-\s]", "", str(earliest_market_date_time)
)
latest_market_date_time = ma_break_market_data["end_date_time"].max()
if latest_market_date_time is None:
latest_market_date_time = datetime.now().strftime("%Y%m%d")
latest_market_date_time = re.sub(
r"[\:\-\s]", "", str(latest_market_date_time)
)
output_file_name = f"ma_break_stats_from_{earliest_market_date_time}_to_{latest_market_date_time}_{strategy_name}.xlsx"
output_file_path = os.path.join(self.stats_output_dir, output_file_name)
logger.info(f"导出{output_file_path}")
strategy_info_df = self.get_strategy_info(strategy_name)
with pd.ExcelWriter(output_file_path) as writer:
strategy_info_df.to_excel(writer, sheet_name="策略信息", index=False)
ma_break_market_data.to_excel(
writer, sheet_name="买卖记录明细", index=False
)
pct_chg_df.to_excel(writer, sheet_name="买卖涨跌幅统计", index=False)
interval_minutes_df.to_excel(
writer, sheet_name="买卖时间间隔统计", index=False
)
chart_dict = self.draw_pct_chg_mean_chart(pct_chg_df, strategy_name)
self.output_chart_to_excel(output_file_path, chart_dict)
else:
return None
def get_strategy_info(self, strategy_name: str = "全均线策略"):
strategy_config = self.main_strategy.get(strategy_name, None)
if strategy_config is None:
logger.error(f"策略{strategy_name}不存在")
return None
strategy_info = {"策略名称": strategy_name, "买入策略": "", "卖出策略": ""}
buy_dict = strategy_config.get("buy", {})
buy_and_list = buy_dict.get("and", [])
buy_or_list = buy_dict.get("or", [])
buy_and_text = ""
buy_or_text = ""
for and_condition in buy_and_list:
buy_and_text += f"{and_condition}, \n"
if len(buy_or_list) > 0:
for or_condition in buy_or_list:
buy_or_text += f"{or_condition}, \n"
if len(buy_or_text) > 0:
strategy_info["买入策略"] = buy_and_text + " 或者 \n" + buy_or_text
else:
strategy_info["买入策略"] = buy_and_text
sell_dict = strategy_config.get("sell", {})
sell_and_list = sell_dict.get("and", [])
sell_or_list = sell_dict.get("or", [])
sell_and_text = ""
sell_or_text = ""
for and_condition in sell_and_list:
sell_and_text += f"{and_condition}, \n"
if len(sell_or_list) > 0:
for or_condition in sell_or_list:
sell_or_text += f"{or_condition}, \n"
if len(sell_or_text) > 0:
strategy_info["卖出策略"] = sell_and_text + " 或者 \n" + sell_or_text
else:
strategy_info["卖出策略"] = sell_and_text
# 将strategy_info转换为pd.DataFrame
strategy_info_df = pd.DataFrame([strategy_info])
return strategy_info_df
def trade_simulate(self, symbol: str, bar: str, strategy_name: str = "全均线策略"):
market_data = self.db_market_data.query_market_data_by_symbol_bar(
symbol, bar, start=None, end=None
)
if market_data is None or len(market_data) == 0:
logger.warning(f"获取{symbol} {bar} 数据失败")
return
else:
market_data = pd.DataFrame(market_data)
market_data.sort_values(by="timestamp", ascending=True, inplace=True)
market_data.reset_index(drop=True, inplace=True)
logger.info(f"获取{symbol} {bar} 数据成功,数据条数: {len(market_data)}")
# 获得ma5, ma10, ma20, ma30不为空的行
market_data = market_data[
(market_data["ma5"].notna())
& (market_data["ma10"].notna())
& (market_data["ma20"].notna())
& (market_data["ma30"].notna())
]
logger.info(
f"ma5, ma10, ma20, ma30不为空的行数据条数: {len(market_data)}"
)
# 计算volume_ma5
market_data["volume_ma5"] = market_data["volume"].rolling(window=5).mean()
# 获得5上穿10且ma5 > ma10 > ma20 > ma30且close > ma20的行,成交量较前5日均量放大20%以上
market_data["volume_pct_chg"] = (
market_data["volume"] - market_data["volume_ma5"]
) / market_data["volume_ma5"]
market_data["volume_pct_chg"] = market_data["volume_pct_chg"].fillna(0)
# 按照timestamp排序
market_data = market_data.sort_values(by="timestamp", ascending=True)
# 获得ma_break_market_data的close列
market_data.reset_index(drop=True, inplace=True)
ma_break_market_data_pair_list = []
ma_break_market_data_pair = {}
for index, row in market_data.iterrows():
ma_cross = row["ma_cross"]
timestamp = row["timestamp"]
close = row["close"]
ma5 = row["ma5"]
ma10 = row["ma10"]
ma20 = row["ma20"]
ma30 = row["ma30"]
macd_diff = float(row["dif"])
macd_dea = float(row["dea"])
macd = float(row["macd"])
if ma_break_market_data_pair.get("begin_timestamp", None) is None:
buy_condition = self.fit_strategy(
strategy_name=strategy_name,
market_data=market_data,
row=row,
behavior="buy",
)
if buy_condition:
ma_break_market_data_pair = {}
ma_break_market_data_pair["symbol"] = symbol
ma_break_market_data_pair["bar"] = bar
ma_break_market_data_pair["begin_timestamp"] = timestamp
ma_break_market_data_pair["begin_date_time"] = (
timestamp_to_datetime(timestamp)
)
ma_break_market_data_pair["begin_close"] = close
ma_break_market_data_pair["begin_ma5"] = ma5
ma_break_market_data_pair["begin_ma10"] = ma10
ma_break_market_data_pair["begin_ma20"] = ma20
ma_break_market_data_pair["begin_ma30"] = ma30
ma_break_market_data_pair["begin_macd_diff"] = macd_diff
ma_break_market_data_pair["begin_macd_dea"] = macd_dea
ma_break_market_data_pair["begin_macd"] = macd
continue
else:
sell_condition = self.fit_strategy(
strategy_name=strategy_name,
market_data=market_data,
row=row,
behavior="sell",
)
if sell_condition:
ma_break_market_data_pair["end_timestamp"] = timestamp
ma_break_market_data_pair["end_date_time"] = (
timestamp_to_datetime(timestamp)
)
ma_break_market_data_pair["end_close"] = close
ma_break_market_data_pair["end_ma5"] = ma5
ma_break_market_data_pair["end_ma10"] = ma10
ma_break_market_data_pair["end_ma20"] = ma20
ma_break_market_data_pair["end_ma30"] = ma30
ma_break_market_data_pair["end_macd_diff"] = macd_diff
ma_break_market_data_pair["end_macd_dea"] = macd_dea
ma_break_market_data_pair["end_macd"] = macd
ma_break_market_data_pair["pct_chg"] = (
close - ma_break_market_data_pair["begin_close"]
) / ma_break_market_data_pair["begin_close"]
ma_break_market_data_pair["pct_chg"] = round(
ma_break_market_data_pair["pct_chg"] * 100, 4
)
ma_break_market_data_pair["interval_seconds"] = (
timestamp - ma_break_market_data_pair["begin_timestamp"]
) / 1000
# 将interval转换为分钟
ma_break_market_data_pair["interval_minutes"] = (
ma_break_market_data_pair["interval_seconds"] / 60
)
ma_break_market_data_pair["interval_hours"] = (
ma_break_market_data_pair["interval_seconds"] / 3600
)
ma_break_market_data_pair["interval_days"] = (
ma_break_market_data_pair["interval_seconds"] / 86400
)
ma_break_market_data_pair_list.append(ma_break_market_data_pair)
ma_break_market_data_pair = {}
if len(ma_break_market_data_pair_list) > 0:
ma_break_market_data = pd.DataFrame(ma_break_market_data_pair_list)
logger.info(f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}")
return ma_break_market_data
else:
return None
def fit_strategy(
self,
strategy_name: str = "全均线策略",
market_data: pd.DataFrame = None,
row: pd.Series = None,
behavior: str = "buy",
):
strategy_config = self.main_strategy.get(strategy_name, None)
if strategy_config is None:
logger.error(f"策略{strategy_name}不存在")
return False
condition_dict = strategy_config.get(behavior, None)
if condition_dict is None:
logger.error(f"策略{strategy_name}{behavior}条件不存在")
return False
ma_cross = row["ma_cross"]
if pd.isna(ma_cross) or ma_cross is None:
ma_cross = ""
ma_cross = str(ma_cross)
ma5 = float(row["ma5"])
ma10 = float(row["ma10"])
ma20 = float(row["ma20"])
ma30 = float(row["ma30"])
close = float(row["close"])
volume_pct_chg = float(row["volume_pct_chg"])
macd_diff = float(row["dif"])
macd_dea = float(row["dea"])
macd = float(row["macd"])
and_list = condition_dict.get("and", [])
condition = True
for and_condition in and_list:
if and_condition == "5上穿10":
condition = condition and ("5上穿10" in ma_cross)
elif and_condition == "10上穿20":
condition = condition and ("10上穿20" in ma_cross)
elif and_condition == "20上穿30":
condition = condition and ("20上穿30" in ma_cross)
elif and_condition == "ma5>ma10":
condition = condition and (ma5 > ma10)
elif and_condition == "ma10>ma20":
condition = condition and (ma10 > ma20)
elif and_condition == "ma20>ma30":
condition = condition and (ma20 > ma30)
elif and_condition == "close>ma20":
condition = condition and (close > ma20)
elif and_condition == "volume_pct_chg>0.2":
condition = condition and (volume_pct_chg > 0.2)
elif and_condition == "macd_diff>0":
condition = condition and (macd_diff > 0)
elif and_condition == "macd_dea>0":
condition = condition and (macd_dea > 0)
elif and_condition == "macd>0":
condition = condition and (macd > 0)
elif and_condition == "10下穿5":
condition = condition and ("10下穿5" in ma_cross)
elif and_condition == "20下穿10":
condition = condition and ("20下穿10" in ma_cross)
elif and_condition == "30下穿20":
condition = condition and ("30下穿20" in ma_cross)
elif and_condition == "ma5<ma10":
condition = condition and (ma5 < ma10)
elif and_condition == "ma10<ma20":
condition = condition and (ma10 < ma20)
elif and_condition == "ma20<ma30":
condition = condition and (ma20 < ma30)
elif and_condition == "close<ma20":
condition = condition and (close < ma20)
elif and_condition == "macd_diff<0":
condition = condition and (macd_diff < 0)
elif and_condition == "macd_dea<0":
condition = condition and (macd_dea < 0)
elif and_condition == "macd<0":
condition = condition and (macd < 0)
else:
pass
if not condition:
or_list = condition_dict.get("or", [])
for or_condition in or_list:
if or_condition == "5上穿10":
condition = condition or ("5上穿10" in ma_cross)
elif or_condition == "10上穿20":
condition = condition or ("10上穿20" in ma_cross)
elif or_condition == "20上穿30":
condition = condition or ("20上穿30" in ma_cross)
elif or_condition == "ma5>ma10":
condition = condition or (ma5 > ma10)
elif or_condition == "ma10>ma20":
condition = condition or (ma10 > ma20)
elif or_condition == "ma20>ma30":
condition = condition or (ma20 > ma30)
elif or_condition == "close>ma20":
condition = condition or (close > ma20)
elif or_condition == "volume_pct_chg>0.2":
condition = condition or (volume_pct_chg > 0.2)
elif or_condition == "macd_diff>0":
condition = condition or (macd_diff > 0)
elif or_condition == "macd_dea>0":
condition = condition or (macd_dea > 0)
elif or_condition == "macd>0":
condition = condition or (macd > 0)
elif or_condition == "10下穿5":
condition = condition or ("10下穿5" in ma_cross)
elif or_condition == "20下穿10":
condition = condition or ("20下穿10" in ma_cross)
elif or_condition == "30下穿20":
condition = condition or ("30下穿20" in ma_cross)
elif or_condition == "ma5<ma10":
condition = condition or (ma5 < ma10)
elif or_condition == "ma10<ma20":
condition = condition or (ma10 < ma20)
elif or_condition == "ma20<ma30":
condition = condition or (ma20 < ma30)
elif or_condition == "close<ma20":
condition = condition or (close < ma20)
elif or_condition == "macd_diff<0":
condition = condition or (macd_diff < 0)
elif or_condition == "macd_dea<0":
condition = condition or (macd_dea < 0)
elif or_condition == "macd<0":
condition = condition or (macd < 0)
else:
pass
return condition
def draw_pct_chg_mean_chart(
self, data: pd.DataFrame, strategy_name: str = "全均线策略"
):
"""
绘制pct_chg mean的柱状图表美观保存到self.stats_chart_dir
:param data: 波段pct_chg_mean的数据
:return: None
"""
if data is None or data.empty:
return None
# seaborn风格设置
sns.set_theme(style="whitegrid")
plt.rcParams["font.sans-serif"] = ["SimHei"] # 也可直接用字体名
plt.rcParams["font.size"] = 11 # 设置字体大小
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
chart_dict = {}
column_name_dict = {
"pct_chg_sum": "涨跌总和",
"pct_chg_mean": "涨跌均值",
}
for column_name, column_name_text in column_name_dict.items():
for bar in data["bar"].unique():
bar_data = data[data["bar"] == bar].copy() # 一次筛选即可
if bar_data.empty:
continue
bar_data.rename(columns={column_name: column_name_text}, inplace=True)
# 可选:按均值排序
bar_data.sort_values(by=column_name_text, ascending=False, inplace=True)
bar_data.reset_index(drop=True, inplace=True)
plt.figure(figsize=(10, 6))
ax = sns.barplot(x="symbol", y=column_name_text, data=bar_data, palette="Blues_d")
plt.title(f"{bar}趋势{column_name_text}(%)")
plt.xlabel("symbol")
plt.ylabel(column_name_text)
plt.xticks(rotation=45, ha="right")
# 在柱状图上添加数值标签
for i, v in enumerate(bar_data[column_name_text]):
ax.text(i, v, f'{v:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
plt.tight_layout()
save_path = os.path.join(
self.stats_chart_dir, f"{bar}_ma_break_{column_name}_{strategy_name}.png"
)
plt.savefig(save_path, dpi=150)
plt.close()
sheet_name = f"{bar}_趋势{column_name_text}分布图表_{strategy_name}"
chart_dict[sheet_name] = save_path
return chart_dict
def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict):
"""
输出Excel文件包含所有图表
charts_dict: 图表数据字典格式为
{
"sheet_name": {
"chart_name": "chart_path"
}
}
"""
logger.info(f"将图表输出到{excel_file_path}")
# 打开已经存在的Excel文件
wb = openpyxl.load_workbook(excel_file_path)
for sheet_name, chart_path in charts_dict.items():
try:
ws = wb.create_sheet(title=sheet_name)
row_offset = 1
# Insert chart image
img = Image(chart_path)
ws.add_image(img, f"A{row_offset}")
except Exception as e:
logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
continue
# Save Excel file
wb.save(excel_file_path)
print(f"Chart saved as {excel_file_path}")

View File

@ -97,7 +97,7 @@ class MeanReversionSandbox:
os.makedirs("./json/", exist_ok=True) os.makedirs("./json/", exist_ok=True)
json_file_path = "./json/peak_valley_data.json" json_file_path = "./json/peak_valley_data.json"
if not os.path.exists(json_file_path): if not os.path.exists(json_file_path):
excel_file_path = "./output/statistics/excel/price_volume_stats_window_size_100_from_20250515000000_to_20250819110500.xlsx" excel_file_path = "./output/statistics/excel/price_volume_stats_window_size_100_from_20250515000000_to_20250822145000.xlsx"
if not os.path.exists(excel_file_path): if not os.path.exists(excel_file_path):
raise FileNotFoundError(f"Excel file not found: {excel_file_path}") raise FileNotFoundError(f"Excel file not found: {excel_file_path}")
sheet_name = "波峰波谷统计" sheet_name = "波峰波谷统计"

File diff suppressed because one or more lines are too long

102
json/trade_strategy.json Normal file
View File

@ -0,0 +1,102 @@
{
"均线系统策略": {
"全均线策略": {
"buy": {
"and": [
"5上穿10",
"ma5>ma10",
"ma10>ma20",
"ma20>ma30",
"close>ma20",
"volume_pct_chg>0.2"
]
},
"sell": {
"and": [
"ma5<ma10",
"ma10<ma20",
"ma20<ma30",
"volume_pct_chg>0.2"
]
}
},
"5上穿10策略": {
"buy": {
"and": [
"5上穿10",
"ma5>ma10",
"volume_pct_chg>0.2"
]
},
"sell": {
"and": [
"ma5<ma10",
"volume_pct_chg>0.2"
],
"or": [
"close < ma20"
]
}
},
"三均线策略": {
"buy": {
"and": [
"ma5>ma10",
"ma10>ma20",
"10上穿20",
"volume_pct_chg>0.2"
]
},
"sell": {
"and": [
"ma5<ma10",
"ma10<ma20"
]
}
},
"均线价格突破策略": {
"buy": {
"and": [
"ma5>ma10",
"close>ma20"
]
},
"sell": {
"and": [
"ma5<ma10",
"close<ma20"
]
}
},
"均线macd结合策略1": {
"buy": {
"and": [
"ma5>ma10",
"macd_diff>0"
]
},
"sell": {
"and": [
"ma5<ma10",
"macd_diff<0"
]
}
},
"均线macd结合策略2": {
"buy": {
"and": [
"ma5>ma10",
"macd_diff>0",
"macd>0"
]
},
"sell": {
"and": [
"ma5<ma10",
"macd_diff<0",
"macd<0"
]
}
}
}
}

View File

@ -3,7 +3,6 @@ from datetime import datetime
from time import sleep from time import sleep
import pandas as pd import pandas as pd
from core.biz.market_data import MarketData from core.biz.market_data import MarketData
from core.statistics.ma_break_statistics import MaBreakStatistics
from core.db.db_market_data import DBMarketData from core.db.db_market_data import DBMarketData
from core.biz.metrics_calculation import MetricsCalculation from core.biz.metrics_calculation import MetricsCalculation
from core.utils import ( from core.utils import (
@ -53,7 +52,6 @@ class MarketDataMain:
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_market_data = DBMarketData(self.db_url) self.db_market_data = DBMarketData(self.db_url)
self.trade_data_main = TradeDataMain() self.trade_data_main = TradeDataMain()
self.ma_break_statistics = MaBreakStatistics()
def initial_data(self): def initial_data(self):
""" """
@ -408,5 +406,5 @@ if __name__ == "__main__":
market_data_main = MarketDataMain() market_data_main = MarketDataMain()
# market_data_main.batch_update_data() # market_data_main.batch_update_data()
# market_data_main.initial_data() # market_data_main.initial_data()
# market_data_main.batch_calculate_metrics() market_data_main.batch_calculate_metrics()
market_data_main.batch_ma_break_statistics() # market_data_main.batch_ma_break_statistics()

View File

@ -2,7 +2,7 @@ select * from crypto_market_monitor
order by date_time desc; order by date_time desc;
select symbol, bar, date_time, close, select symbol, bar, date_time, close,
pct_chg, kdj_k, kdj_d, kdj_k, kdj_pattern, pct_chg, ma_cross, ma5, ma10, ma20, ma30, dif, dea, macd, kdj_k, kdj_d, kdj_k, kdj_pattern,
rsi_14, rsi_signal, rsi_14, rsi_signal,
boll_upper, boll_middle, boll_lower, boll_pattern, boll_signal boll_upper, boll_middle, boll_lower, boll_pattern, boll_signal
from crypto_market_data from crypto_market_data

42
trade_ma_strategy_main.py Normal file
View File

@ -0,0 +1,42 @@
import core.logger as logging
from datetime import datetime
from time import sleep
import pandas as pd
from core.biz.market_data import MarketData
from core.trade.ma_break_statistics import MaBreakStatistics
from core.db.db_market_data import DBMarketData
from core.biz.metrics_calculation import MetricsCalculation
from core.utils import (
datetime_to_timestamp,
timestamp_to_datetime,
transform_date_time_to_timestamp,
)
from trade_data_main import TradeDataMain
from config import (
API_KEY,
SECRET_KEY,
PASSPHRASE,
SANDBOX,
MONITOR_CONFIG,
MYSQL_CONFIG,
BAR_THRESHOLD,
)
logger = logging.logger
class TradeMaStrategyMain:
def __init__(self):
self.ma_break_statistics = MaBreakStatistics()
def batch_ma_break_statistics(self):
"""
批量计算MA突破统计
"""
logger.info("开始批量计算MA突破统计")
strategy_dict = self.ma_break_statistics.main_strategy
for strategy_name, strategy_info in strategy_dict.items():
self.ma_break_statistics.batch_statistics(strategy_name=strategy_name)
if __name__ == "__main__":
trade_ma_strategy_main = TradeMaStrategyMain()
trade_ma_strategy_main.batch_ma_break_statistics()