optimize ma_break quant logic

This commit is contained in:
blade 2025-09-15 14:12:47 +08:00
parent 5b7e95f4d9
commit ff2c35e1b3
5 changed files with 127 additions and 43 deletions

Binary file not shown.

View File

@ -4,7 +4,7 @@ import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from datetime import datetime, timedelta
import re
import json
from openpyxl import Workbook
@ -12,8 +12,9 @@ from openpyxl.drawing.image import Image
import openpyxl
from openpyxl.styles import Font
from PIL import Image as PILImage
from config import OKX_MONITOR_CONFIG, US_STOCK_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE
from config import OKX_MONITOR_CONFIG, US_STOCK_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE, BINANCE_MONITOR_CONFIG
from core.db.db_market_data import DBMarketData
from core.db.db_binance_data import DBBinanceData
from core.db.db_huge_volume_data import DBHugeVolumeData
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
@ -32,7 +33,7 @@ class MaBreakStatistics:
之间的涨跌幅
"""
def __init__(self, is_us_stock: bool = False):
def __init__(self, is_us_stock: bool = False, is_binance: bool = False):
mysql_user = MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "")
if not mysql_password:
@ -42,28 +43,45 @@ class MaBreakStatistics:
mysql_database = MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_market_data = DBMarketData(self.db_url)
self.db_huge_volume_data = DBHugeVolumeData(self.db_url)
self.is_us_stock = is_us_stock
self.is_binance = is_binance
if is_us_stock:
self.symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["QQQ"]
)
else:
self.symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["XCH-USDT"]
)
if is_us_stock:
self.bars = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"bars", ["5m"]
)
self.initial_date = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"initial_date", "2014-11-30 00:00:00"
)
self.db_market_data = DBMarketData(self.db_url)
else:
self.bars = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"bars", ["5m", "15m", "30m", "1H"]
)
self.stats_output_dir = "./output/trade_sandbox/ma_strategy/excel/"
os.makedirs(self.stats_output_dir, exist_ok=True)
self.stats_chart_dir = "./output/trade_sandbox/ma_strategy/chart/"
os.makedirs(self.stats_chart_dir, exist_ok=True)
if is_binance:
self.symbols = BINANCE_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["BTC-USDT"]
)
self.bars = ["30m", "1H"]
self.initial_date = BINANCE_MONITOR_CONFIG.get("volume_monitor", {}).get(
"initial_date", "2017-08-16 00:00:00"
)
self.db_market_data = DBBinanceData(self.db_url)
else:
self.symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["XCH-USDT"]
)
self.bars = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"bars", ["5m", "15m", "30m", "1H"]
)
self.initial_date = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"initial_date", "2025-05-15 00:00:00"
)
self.db_market_data = DBMarketData(self.db_url)
if len(self.initial_date) > 10:
self.initial_date = self.initial_date[:10]
self.end_date = datetime.now().strftime("%Y-%m-%d")
self.trade_strategy_config = self.get_trade_strategy_config()
self.main_strategy = self.trade_strategy_config.get("均线系统策略", None)
@ -73,14 +91,18 @@ class MaBreakStatistics:
return trade_strategy_config
def batch_statistics(self, strategy_name: str = "全均线策略"):
self.stats_output_dir = (
f"./output/trade_sandbox/ma_strategy/excel/{strategy_name}/"
)
if self.is_us_stock:
self.stats_output_dir = f"./output/trade_sandbox/ma_strategy/us_stock/excel/{strategy_name}/"
self.stats_chart_dir = f"./output/trade_sandbox/ma_strategy/us_stock/chart/{strategy_name}/"
elif self.is_binance:
self.stats_output_dir = f"./output/trade_sandbox/ma_strategy/binance/excel/{strategy_name}/"
self.stats_chart_dir = f"./output/trade_sandbox/ma_strategy/binance/chart/{strategy_name}/"
else:
self.stats_output_dir = f"./output/trade_sandbox/ma_strategy/okx/excel/{strategy_name}/"
self.stats_chart_dir = f"./output/trade_sandbox/ma_strategy/okx/chart/{strategy_name}/"
os.makedirs(self.stats_output_dir, exist_ok=True)
self.stats_chart_dir = (
f"./output/trade_sandbox/ma_strategy/chart/{strategy_name}/"
)
os.makedirs(self.stats_chart_dir, exist_ok=True)
ma_break_market_data_list = []
market_data_pct_chg_list = []
if strategy_name not in self.main_strategy.keys() or strategy_name is None:
@ -276,16 +298,11 @@ class MaBreakStatistics:
return strategy_info_df
def trade_simulate(self, symbol: str, bar: str, strategy_name: str = "全均线策略"):
market_data = self.db_market_data.query_market_data_by_symbol_bar(
symbol, bar, start=None, end=None
)
market_data = self.get_full_data(symbol, bar)
if market_data is None or len(market_data) == 0:
logger.warning(f"获取{symbol} {bar} 数据失败")
return None, None
else:
market_data = pd.DataFrame(market_data)
market_data.sort_values(by="timestamp", ascending=True, inplace=True)
market_data.reset_index(drop=True, inplace=True)
logger.info(f"获取{symbol} {bar} 数据成功,数据条数: {len(market_data)}")
# 获得ma5, ma10, ma20, ma30不为空的行
market_data = market_data[
@ -299,7 +316,7 @@ class MaBreakStatistics:
)
# 计算volume_ma5
market_data["volume_ma5"] = market_data["volume"].rolling(window=5).mean()
# 获得5上穿10且ma5 > ma10 > ma20 > ma30且close > ma20的行,成交量较前5日均量放大20%以上
market_data["volume_pct_chg"] = (
market_data["volume"] - market_data["volume_ma5"]
) / market_data["volume_ma5"]
@ -311,9 +328,14 @@ class MaBreakStatistics:
market_data.reset_index(drop=True, inplace=True)
ma_break_market_data_pair_list = []
ma_break_market_data_pair = {}
if self.is_us_stock:
date_time_field = "date_time_us"
else:
date_time_field = "date_time"
for index, row in market_data.iterrows():
ma_cross = row["ma_cross"]
timestamp = row["timestamp"]
date_time = row[date_time_field]
close = row["close"]
ma5 = row["ma5"]
ma10 = row["ma10"]
@ -336,9 +358,7 @@ class MaBreakStatistics:
ma_break_market_data_pair["symbol"] = symbol
ma_break_market_data_pair["bar"] = bar
ma_break_market_data_pair["begin_timestamp"] = timestamp
ma_break_market_data_pair["begin_date_time"] = (
timestamp_to_datetime(timestamp)
)
ma_break_market_data_pair["begin_date_time"] = date_time
ma_break_market_data_pair["begin_close"] = close
ma_break_market_data_pair["begin_ma5"] = ma5
ma_break_market_data_pair["begin_ma10"] = ma10
@ -358,9 +378,7 @@ class MaBreakStatistics:
if sell_condition:
ma_break_market_data_pair["end_timestamp"] = timestamp
ma_break_market_data_pair["end_date_time"] = (
timestamp_to_datetime(timestamp)
)
ma_break_market_data_pair["end_date_time"] = date_time
ma_break_market_data_pair["end_close"] = close
ma_break_market_data_pair["end_ma5"] = ma5
ma_break_market_data_pair["end_ma10"] = ma10
@ -411,6 +429,59 @@ class MaBreakStatistics:
return ma_break_market_data, market_data_pct_chg
else:
return None, None
def get_full_data(self, symbol: str, bar: str = "5m"):
"""
分段获取数据并将数据合并为完整数据
分段依据如果end_date与start_date相差超过一年则每次取一年数据
"""
data = pd.DataFrame()
start_date = datetime.strptime(self.initial_date, "%Y-%m-%d")
end_date = datetime.strptime(self.end_date, "%Y-%m-%d") + timedelta(days=1)
fields = [
"symbol",
"bar",
"timestamp",
"date_time",
"date_time_us",
"open",
"high",
"low",
"close",
"volume",
"sar_signal",
"ma5",
"ma10",
"ma20",
"ma30",
"ma_cross",
"dif",
"dea",
"macd",
]
while start_date < end_date:
current_end_date = min(start_date + timedelta(days=180), end_date)
start_date_str = start_date.strftime("%Y-%m-%d")
current_end_date_str = current_end_date.strftime("%Y-%m-%d")
logger.info(
f"获取{symbol}数据:{start_date_str}{current_end_date_str}"
)
current_data = self.db_market_data.query_market_data_by_symbol_bar(
symbol, bar, fields, start=start_date_str, end=current_end_date_str
)
if current_data is not None and len(current_data) > 0:
current_data = pd.DataFrame(current_data)
data = pd.concat([data, current_data])
start_date = current_end_date
data.drop_duplicates(inplace=True)
if self.is_us_stock:
date_time_field = "date_time_us"
else:
date_time_field = "date_time"
data.sort_values(by=date_time_field, inplace=True)
data.reset_index(drop=True, inplace=True)
return data
def fit_strategy(
self,
@ -734,12 +805,12 @@ class MaBreakStatistics:
# 设置x轴标签
plt.xticks(symbol_bar_data["end_date_time"].iloc[label_indices],
symbol_bar_data["end_date_time"].iloc[label_indices].dt.strftime('%m-%d %H:%M'),
symbol_bar_data["end_date_time"].iloc[label_indices].dt.strftime('%Y%m%d %H:%M'),
rotation=45, ha='right')
else:
# 如果数据点较少,全部显示
plt.xticks(symbol_bar_data["end_date_time"],
symbol_bar_data["end_date_time"].dt.strftime('%m-%d %H:%M'),
symbol_bar_data["end_date_time"].dt.strftime('%Y%m%d %H:%M'),
rotation=45, ha='right')
plt.tight_layout()

View File

@ -28,11 +28,20 @@ def main():
else:
start_date = "2024-01-01"
end_date = datetime.now().strftime("%Y-%m-%d")
# 原值 盈利目标倍数默认10倍$R即10R
profit_target_multiple = 10
# 新值 盈利目标倍数默认20倍$R即10R -- 20250909
# profit_target_multiple = 20
initial_capital = 25000
max_leverage = 4
risk_per_trade = 0.01
commission_per_share = 0.0005
# if is_us_stock:
# commission_per_share = 0.0005
# else:
# commission_per_share = 0
# commission_per_share = 0
trades_df_list = []
trades_summary_df_list = []
@ -47,6 +56,7 @@ def main():
symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["QQQ"]
)
commission_per_share = 0.0005
else:
if is_binance:
symbols = BINANCE_MONITOR_CONFIG.get("volume_monitor", {}).get(
@ -56,6 +66,7 @@ def main():
symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["BTC-USDT"]
)
commission_per_share = 0
for symbol in symbols:
logger.info(
f"开始回测 {symbol}, 交易周期:{bar}, 开始日期:{start_date}, 结束日期:{end_date}, 是否是美股:{is_us_stock}, 交易方向:{direction}, 是否使用SAR:{by_sar}, 是否使用R为entry减stop:{price_range_mean_as_R}, 是否使用K线实体过50%:{by_big_k}"

View File

@ -25,8 +25,8 @@ from config import (
logger = logging.logger
class TradeMaStrategyMain:
def __init__(self, is_us_stock: bool = False):
self.ma_break_statistics = MaBreakStatistics(is_us_stock=is_us_stock)
def __init__(self, is_us_stock: bool = False, is_binance: bool = False):
self.ma_break_statistics = MaBreakStatistics(is_us_stock=is_us_stock, is_binance=is_binance)
def batch_ma_break_statistics(self):
"""
@ -36,8 +36,10 @@ class TradeMaStrategyMain:
strategy_dict = self.ma_break_statistics.main_strategy
pct_chg_df_list = []
for strategy_name, strategy_info in strategy_dict.items():
pct_chg_df = self.ma_break_statistics.batch_statistics(strategy_name=strategy_name)
pct_chg_df_list.append(pct_chg_df)
if "macd" in strategy_name:
# 只计算macd策略
pct_chg_df = self.ma_break_statistics.batch_statistics(strategy_name=strategy_name)
pct_chg_df_list.append(pct_chg_df)
pct_chg_df = pd.concat(pct_chg_df_list)
@ -59,5 +61,5 @@ class TradeMaStrategyMain:
if __name__ == "__main__":
trade_ma_strategy_main = TradeMaStrategyMain(is_us_stock=True)
trade_ma_strategy_main = TradeMaStrategyMain(is_us_stock=False, is_binance=True)
trade_ma_strategy_main.batch_ma_break_statistics()