optimize ma_break quant logic

This commit is contained in:
blade 2025-09-15 14:12:47 +08:00
parent 5b7e95f4d9
commit ff2c35e1b3
5 changed files with 127 additions and 43 deletions

Binary file not shown.

View File

@ -4,7 +4,7 @@ import pandas as pd
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns import seaborn as sns
from datetime import datetime from datetime import datetime, timedelta
import re import re
import json import json
from openpyxl import Workbook from openpyxl import Workbook
@ -12,8 +12,9 @@ from openpyxl.drawing.image import Image
import openpyxl import openpyxl
from openpyxl.styles import Font from openpyxl.styles import Font
from PIL import Image as PILImage from PIL import Image as PILImage
from config import OKX_MONITOR_CONFIG, US_STOCK_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE from config import OKX_MONITOR_CONFIG, US_STOCK_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE, BINANCE_MONITOR_CONFIG
from core.db.db_market_data import DBMarketData from core.db.db_market_data import DBMarketData
from core.db.db_binance_data import DBBinanceData
from core.db.db_huge_volume_data import DBHugeVolumeData from core.db.db_huge_volume_data import DBHugeVolumeData
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
@ -32,7 +33,7 @@ class MaBreakStatistics:
之间的涨跌幅 之间的涨跌幅
""" """
def __init__(self, is_us_stock: bool = False): def __init__(self, is_us_stock: bool = False, is_binance: bool = False):
mysql_user = MYSQL_CONFIG.get("user", "xch") mysql_user = MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "") mysql_password = MYSQL_CONFIG.get("password", "")
if not mysql_password: if not mysql_password:
@ -42,28 +43,45 @@ class MaBreakStatistics:
mysql_database = MYSQL_CONFIG.get("database", "okx") mysql_database = MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_market_data = DBMarketData(self.db_url)
self.db_huge_volume_data = DBHugeVolumeData(self.db_url) self.db_huge_volume_data = DBHugeVolumeData(self.db_url)
self.is_us_stock = is_us_stock
self.is_binance = is_binance
if is_us_stock: if is_us_stock:
self.symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get( self.symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["QQQ"] "symbols", ["QQQ"]
) )
self.bars = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"bars", ["5m"]
)
self.initial_date = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"initial_date", "2014-11-30 00:00:00"
)
self.db_market_data = DBMarketData(self.db_url)
else:
if is_binance:
self.symbols = BINANCE_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["BTC-USDT"]
)
self.bars = ["30m", "1H"]
self.initial_date = BINANCE_MONITOR_CONFIG.get("volume_monitor", {}).get(
"initial_date", "2017-08-16 00:00:00"
)
self.db_market_data = DBBinanceData(self.db_url)
else: else:
self.symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get( self.symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["XCH-USDT"] "symbols", ["XCH-USDT"]
) )
if is_us_stock:
self.bars = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"bars", ["5m"]
)
else:
self.bars = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get( self.bars = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"bars", ["5m", "15m", "30m", "1H"] "bars", ["5m", "15m", "30m", "1H"]
) )
self.stats_output_dir = "./output/trade_sandbox/ma_strategy/excel/" self.initial_date = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
os.makedirs(self.stats_output_dir, exist_ok=True) "initial_date", "2025-05-15 00:00:00"
self.stats_chart_dir = "./output/trade_sandbox/ma_strategy/chart/" )
os.makedirs(self.stats_chart_dir, exist_ok=True) self.db_market_data = DBMarketData(self.db_url)
if len(self.initial_date) > 10:
self.initial_date = self.initial_date[:10]
self.end_date = datetime.now().strftime("%Y-%m-%d")
self.trade_strategy_config = self.get_trade_strategy_config() self.trade_strategy_config = self.get_trade_strategy_config()
self.main_strategy = self.trade_strategy_config.get("均线系统策略", None) self.main_strategy = self.trade_strategy_config.get("均线系统策略", None)
@ -73,14 +91,18 @@ class MaBreakStatistics:
return trade_strategy_config return trade_strategy_config
def batch_statistics(self, strategy_name: str = "全均线策略"): def batch_statistics(self, strategy_name: str = "全均线策略"):
self.stats_output_dir = ( if self.is_us_stock:
f"./output/trade_sandbox/ma_strategy/excel/{strategy_name}/" self.stats_output_dir = f"./output/trade_sandbox/ma_strategy/us_stock/excel/{strategy_name}/"
) self.stats_chart_dir = f"./output/trade_sandbox/ma_strategy/us_stock/chart/{strategy_name}/"
elif self.is_binance:
self.stats_output_dir = f"./output/trade_sandbox/ma_strategy/binance/excel/{strategy_name}/"
self.stats_chart_dir = f"./output/trade_sandbox/ma_strategy/binance/chart/{strategy_name}/"
else:
self.stats_output_dir = f"./output/trade_sandbox/ma_strategy/okx/excel/{strategy_name}/"
self.stats_chart_dir = f"./output/trade_sandbox/ma_strategy/okx/chart/{strategy_name}/"
os.makedirs(self.stats_output_dir, exist_ok=True) os.makedirs(self.stats_output_dir, exist_ok=True)
self.stats_chart_dir = (
f"./output/trade_sandbox/ma_strategy/chart/{strategy_name}/"
)
os.makedirs(self.stats_chart_dir, exist_ok=True) os.makedirs(self.stats_chart_dir, exist_ok=True)
ma_break_market_data_list = [] ma_break_market_data_list = []
market_data_pct_chg_list = [] market_data_pct_chg_list = []
if strategy_name not in self.main_strategy.keys() or strategy_name is None: if strategy_name not in self.main_strategy.keys() or strategy_name is None:
@ -276,16 +298,11 @@ class MaBreakStatistics:
return strategy_info_df return strategy_info_df
def trade_simulate(self, symbol: str, bar: str, strategy_name: str = "全均线策略"): def trade_simulate(self, symbol: str, bar: str, strategy_name: str = "全均线策略"):
market_data = self.db_market_data.query_market_data_by_symbol_bar( market_data = self.get_full_data(symbol, bar)
symbol, bar, start=None, end=None
)
if market_data is None or len(market_data) == 0: if market_data is None or len(market_data) == 0:
logger.warning(f"获取{symbol} {bar} 数据失败") logger.warning(f"获取{symbol} {bar} 数据失败")
return None, None return None, None
else: else:
market_data = pd.DataFrame(market_data)
market_data.sort_values(by="timestamp", ascending=True, inplace=True)
market_data.reset_index(drop=True, inplace=True)
logger.info(f"获取{symbol} {bar} 数据成功,数据条数: {len(market_data)}") logger.info(f"获取{symbol} {bar} 数据成功,数据条数: {len(market_data)}")
# 获得ma5, ma10, ma20, ma30不为空的行 # 获得ma5, ma10, ma20, ma30不为空的行
market_data = market_data[ market_data = market_data[
@ -299,7 +316,7 @@ class MaBreakStatistics:
) )
# 计算volume_ma5 # 计算volume_ma5
market_data["volume_ma5"] = market_data["volume"].rolling(window=5).mean() market_data["volume_ma5"] = market_data["volume"].rolling(window=5).mean()
# 获得5上穿10且ma5 > ma10 > ma20 > ma30且close > ma20的行,成交量较前5日均量放大20%以上
market_data["volume_pct_chg"] = ( market_data["volume_pct_chg"] = (
market_data["volume"] - market_data["volume_ma5"] market_data["volume"] - market_data["volume_ma5"]
) / market_data["volume_ma5"] ) / market_data["volume_ma5"]
@ -311,9 +328,14 @@ class MaBreakStatistics:
market_data.reset_index(drop=True, inplace=True) market_data.reset_index(drop=True, inplace=True)
ma_break_market_data_pair_list = [] ma_break_market_data_pair_list = []
ma_break_market_data_pair = {} ma_break_market_data_pair = {}
if self.is_us_stock:
date_time_field = "date_time_us"
else:
date_time_field = "date_time"
for index, row in market_data.iterrows(): for index, row in market_data.iterrows():
ma_cross = row["ma_cross"] ma_cross = row["ma_cross"]
timestamp = row["timestamp"] timestamp = row["timestamp"]
date_time = row[date_time_field]
close = row["close"] close = row["close"]
ma5 = row["ma5"] ma5 = row["ma5"]
ma10 = row["ma10"] ma10 = row["ma10"]
@ -336,9 +358,7 @@ class MaBreakStatistics:
ma_break_market_data_pair["symbol"] = symbol ma_break_market_data_pair["symbol"] = symbol
ma_break_market_data_pair["bar"] = bar ma_break_market_data_pair["bar"] = bar
ma_break_market_data_pair["begin_timestamp"] = timestamp ma_break_market_data_pair["begin_timestamp"] = timestamp
ma_break_market_data_pair["begin_date_time"] = ( ma_break_market_data_pair["begin_date_time"] = date_time
timestamp_to_datetime(timestamp)
)
ma_break_market_data_pair["begin_close"] = close ma_break_market_data_pair["begin_close"] = close
ma_break_market_data_pair["begin_ma5"] = ma5 ma_break_market_data_pair["begin_ma5"] = ma5
ma_break_market_data_pair["begin_ma10"] = ma10 ma_break_market_data_pair["begin_ma10"] = ma10
@ -358,9 +378,7 @@ class MaBreakStatistics:
if sell_condition: if sell_condition:
ma_break_market_data_pair["end_timestamp"] = timestamp ma_break_market_data_pair["end_timestamp"] = timestamp
ma_break_market_data_pair["end_date_time"] = ( ma_break_market_data_pair["end_date_time"] = date_time
timestamp_to_datetime(timestamp)
)
ma_break_market_data_pair["end_close"] = close ma_break_market_data_pair["end_close"] = close
ma_break_market_data_pair["end_ma5"] = ma5 ma_break_market_data_pair["end_ma5"] = ma5
ma_break_market_data_pair["end_ma10"] = ma10 ma_break_market_data_pair["end_ma10"] = ma10
@ -412,6 +430,59 @@ class MaBreakStatistics:
else: else:
return None, None return None, None
def get_full_data(self, symbol: str, bar: str = "5m"):
"""
分段获取数据并将数据合并为完整数据
分段依据如果end_date与start_date相差超过一年则每次取一年数据
"""
data = pd.DataFrame()
start_date = datetime.strptime(self.initial_date, "%Y-%m-%d")
end_date = datetime.strptime(self.end_date, "%Y-%m-%d") + timedelta(days=1)
fields = [
"symbol",
"bar",
"timestamp",
"date_time",
"date_time_us",
"open",
"high",
"low",
"close",
"volume",
"sar_signal",
"ma5",
"ma10",
"ma20",
"ma30",
"ma_cross",
"dif",
"dea",
"macd",
]
while start_date < end_date:
current_end_date = min(start_date + timedelta(days=180), end_date)
start_date_str = start_date.strftime("%Y-%m-%d")
current_end_date_str = current_end_date.strftime("%Y-%m-%d")
logger.info(
f"获取{symbol}数据:{start_date_str}{current_end_date_str}"
)
current_data = self.db_market_data.query_market_data_by_symbol_bar(
symbol, bar, fields, start=start_date_str, end=current_end_date_str
)
if current_data is not None and len(current_data) > 0:
current_data = pd.DataFrame(current_data)
data = pd.concat([data, current_data])
start_date = current_end_date
data.drop_duplicates(inplace=True)
if self.is_us_stock:
date_time_field = "date_time_us"
else:
date_time_field = "date_time"
data.sort_values(by=date_time_field, inplace=True)
data.reset_index(drop=True, inplace=True)
return data
def fit_strategy( def fit_strategy(
self, self,
strategy_name: str = "全均线策略", strategy_name: str = "全均线策略",
@ -734,12 +805,12 @@ class MaBreakStatistics:
# 设置x轴标签 # 设置x轴标签
plt.xticks(symbol_bar_data["end_date_time"].iloc[label_indices], plt.xticks(symbol_bar_data["end_date_time"].iloc[label_indices],
symbol_bar_data["end_date_time"].iloc[label_indices].dt.strftime('%m-%d %H:%M'), symbol_bar_data["end_date_time"].iloc[label_indices].dt.strftime('%Y%m%d %H:%M'),
rotation=45, ha='right') rotation=45, ha='right')
else: else:
# 如果数据点较少,全部显示 # 如果数据点较少,全部显示
plt.xticks(symbol_bar_data["end_date_time"], plt.xticks(symbol_bar_data["end_date_time"],
symbol_bar_data["end_date_time"].dt.strftime('%m-%d %H:%M'), symbol_bar_data["end_date_time"].dt.strftime('%Y%m%d %H:%M'),
rotation=45, ha='right') rotation=45, ha='right')
plt.tight_layout() plt.tight_layout()

View File

@ -28,11 +28,20 @@ def main():
else: else:
start_date = "2024-01-01" start_date = "2024-01-01"
end_date = datetime.now().strftime("%Y-%m-%d") end_date = datetime.now().strftime("%Y-%m-%d")
# 原值 盈利目标倍数默认10倍$R即10R
profit_target_multiple = 10 profit_target_multiple = 10
# 新值 盈利目标倍数默认20倍$R即10R -- 20250909
# profit_target_multiple = 20
initial_capital = 25000 initial_capital = 25000
max_leverage = 4 max_leverage = 4
risk_per_trade = 0.01 risk_per_trade = 0.01
commission_per_share = 0.0005 # if is_us_stock:
# commission_per_share = 0.0005
# else:
# commission_per_share = 0
# commission_per_share = 0
trades_df_list = [] trades_df_list = []
trades_summary_df_list = [] trades_summary_df_list = []
@ -47,6 +56,7 @@ def main():
symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get( symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["QQQ"] "symbols", ["QQQ"]
) )
commission_per_share = 0.0005
else: else:
if is_binance: if is_binance:
symbols = BINANCE_MONITOR_CONFIG.get("volume_monitor", {}).get( symbols = BINANCE_MONITOR_CONFIG.get("volume_monitor", {}).get(
@ -56,6 +66,7 @@ def main():
symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get( symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["BTC-USDT"] "symbols", ["BTC-USDT"]
) )
commission_per_share = 0
for symbol in symbols: for symbol in symbols:
logger.info( logger.info(
f"开始回测 {symbol}, 交易周期:{bar}, 开始日期:{start_date}, 结束日期:{end_date}, 是否是美股:{is_us_stock}, 交易方向:{direction}, 是否使用SAR:{by_sar}, 是否使用R为entry减stop:{price_range_mean_as_R}, 是否使用K线实体过50%:{by_big_k}" f"开始回测 {symbol}, 交易周期:{bar}, 开始日期:{start_date}, 结束日期:{end_date}, 是否是美股:{is_us_stock}, 交易方向:{direction}, 是否使用SAR:{by_sar}, 是否使用R为entry减stop:{price_range_mean_as_R}, 是否使用K线实体过50%:{by_big_k}"

View File

@ -25,8 +25,8 @@ from config import (
logger = logging.logger logger = logging.logger
class TradeMaStrategyMain: class TradeMaStrategyMain:
def __init__(self, is_us_stock: bool = False): def __init__(self, is_us_stock: bool = False, is_binance: bool = False):
self.ma_break_statistics = MaBreakStatistics(is_us_stock=is_us_stock) self.ma_break_statistics = MaBreakStatistics(is_us_stock=is_us_stock, is_binance=is_binance)
def batch_ma_break_statistics(self): def batch_ma_break_statistics(self):
""" """
@ -36,6 +36,8 @@ class TradeMaStrategyMain:
strategy_dict = self.ma_break_statistics.main_strategy strategy_dict = self.ma_break_statistics.main_strategy
pct_chg_df_list = [] pct_chg_df_list = []
for strategy_name, strategy_info in strategy_dict.items(): for strategy_name, strategy_info in strategy_dict.items():
if "macd" in strategy_name:
# 只计算macd策略
pct_chg_df = self.ma_break_statistics.batch_statistics(strategy_name=strategy_name) pct_chg_df = self.ma_break_statistics.batch_statistics(strategy_name=strategy_name)
pct_chg_df_list.append(pct_chg_df) pct_chg_df_list.append(pct_chg_df)
@ -59,5 +61,5 @@ class TradeMaStrategyMain:
if __name__ == "__main__": if __name__ == "__main__":
trade_ma_strategy_main = TradeMaStrategyMain(is_us_stock=True) trade_ma_strategy_main = TradeMaStrategyMain(is_us_stock=False, is_binance=True)
trade_ma_strategy_main.batch_ma_break_statistics() trade_ma_strategy_main.batch_ma_break_statistics()