diff --git a/config.py b/config.py index 29ed08b..522012e 100644 --- a/config.py +++ b/config.py @@ -134,6 +134,43 @@ US_STOCK_MONITOR_CONFIG = { } } +A_STOCK_MONITOR_CONFIG = { + "volume_monitor": { + "symbols": [ + "600276.SH", + "002714.SZ", + "600111.SH", + "603019.SH", + "600036.SH", + "300474.SZ", + "600519.SH", + "300750.SZ", + "000858.SZ", + "000651.SZ", + "000333.SZ", + "002230.SZ", + "300308.SZ", + "002475.SZ" + ], + "bars": ["1D", "1W", "1M"], + "initial_date": "2015-01-01 00:00:00", + }, +} + +A_INDEX_MONITOR_CONFIG = { + "volume_monitor": { + "symbols": [ + "000001.SH", + "399006.SZ", + "000300.SH", + "399001.SZ", + "000852.SH", + ], + "bars": ["1D", "1W", "1M"], + "initial_date": "2015-01-01 00:00:00", + }, +} + WINDOW_SIZE = {"window_sizes": [50, 80, 100, 120]} BAR_THRESHOLD = { @@ -146,7 +183,7 @@ BAR_THRESHOLD = { "1D": 1000 * 60 * 60 * 24, } -# MYSQL_CONFIG = { +# COIN_MYSQL_CONFIG = { # "host": "localhost", # "port": 3306, # "user": "xch", @@ -154,7 +191,7 @@ BAR_THRESHOLD = { # "database": "okx", # } -MYSQL_CONFIG = { +COIN_MYSQL_CONFIG = { "host": "218.17.89.43", "port": 11013, "user": "xch", @@ -162,6 +199,14 @@ MYSQL_CONFIG = { "database": "okx", } +A_MYSQL_CONFIG = { + "host": "43.139.95.249", + "port": 3306, + "user": "root", + "password": "bengbu_200!", + "database": "astock", +} + WECHAT_CONFIG = { "general_key": "11e6f7ac-efa9-418a-904c-9325a9f5d324", "btc_key": "529e135d-843b-43dc-8aca-677a860f4b4b", diff --git a/core/db/__pycache__/db_market_data.cpython-312.pyc b/core/db/__pycache__/db_market_data.cpython-312.pyc index 8d27039..41805d5 100644 Binary files a/core/db/__pycache__/db_market_data.cpython-312.pyc and b/core/db/__pycache__/db_market_data.cpython-312.pyc differ diff --git a/core/db/db_astock.py b/core/db/db_astock.py new file mode 100644 index 0000000..26eed5a --- /dev/null +++ b/core/db/db_astock.py @@ -0,0 +1,125 @@ +import pandas as pd +from sqlalchemy import create_engine, exc, text +import re +from core.utils import get_current_date_time +import core.logger as logging +from core.utils import transform_data_type + +logger = logging.logger + + +class DBAStockData: + def __init__( + self, + db_url: str, + ): + self.db_url = db_url + self.db_engine = create_engine( + self.db_url, + pool_size=25, # 连接池大小 + max_overflow=10, # 允许的最大溢出连接 + pool_timeout=30, # 连接超时时间(秒) + pool_recycle=60, # 连接回收时间(秒),避免长时间闲置 + ) + + def query_data(self, sql: str, condition_dict: dict, return_multi: bool = True): + """ + 查询数据 + :param sql: 查询SQL + :param db_url: 数据库连接URL + """ + try: + with self.db_engine.connect() as conn: + result = conn.execute(text(sql), condition_dict) + if return_multi: + result = result.fetchall() + if result: + result_list = [ + transform_data_type(dict(row._mapping)) for row in result + ] + return result_list + else: + return None + else: + result = result.fetchone() + if result: + result_dict = transform_data_type(dict(result._mapping)) + return result_dict + else: + return None + except Exception as e: + logger.error(f"查询数据出错: {e}") + return None + + def query_market_data_by_symbol_bar( + self, + symbol: str, + bar: str, + fields: list = None, + start: str = None, + end: str = None, + table_name: str = "index_daily_price_from_2021", + ): + """ + 根据交易对和K线周期查询数据 + :param symbol: 交易对 + :param bar: K线周期 + :param fields: 字段列表 + :param start: 开始时间 + :param end: 结束时间 + """ + if fields is None: + fields = ["*"] + fields_str = ", ".join(fields) + if table_name is None: + table_name = "index_daily_price_from_2021" + join_table = "all_index" + if table_name.startswith("index"): + join_table = "all_index" + else: + join_table = "all_stock" + + if start is None and end is None: + sql = f""" + SELECT {fields_str} FROM {table_name} a + INNER JOIN {join_table} b ON a.ts_code = b.ts_code + WHERE a.ts_code = :symbol + ORDER BY a.trade_date ASC + """ + condition_dict = {"symbol": symbol} + else: + if start is not None and end is not None: + start = start.replace("-", "") + end = end.replace("-", "") + if start > end: + start, end = end, start + sql = f""" + SELECT {fields_str} FROM {table_name} a + INNER JOIN {join_table} b ON a.ts_code = b.ts_code + WHERE a.ts_code = :symbol AND a.trade_date BETWEEN :start AND :end + ORDER BY a.trade_date ASC + """ + condition_dict = { + "symbol": symbol, + "start": start, + "end": end, + } + elif start is not None: + start = start.replace("-", "") + sql = f""" + SELECT {fields_str} FROM {table_name} a + INNER JOIN {join_table} b ON a.ts_code = b.ts_code + WHERE a.ts_code = :symbol AND a.trade_date >= :start + ORDER BY a.trade_date ASC + """ + condition_dict = {"symbol": symbol, "start": start} + elif end is not None: + end = end.replace("-", "") + sql = f""" + SELECT {fields_str} FROM {table_name} a + INNER JOIN {join_table} b ON a.ts_code = b.ts_code + WHERE a.ts_code = :symbol AND a.trade_date <= :end + ORDER BY a.trade_date ASC + """ + condition_dict = {"symbol": symbol, "end": end} + return self.query_data(sql, condition_dict, return_multi=True) diff --git a/core/db/db_binance_data.py b/core/db/db_binance_data.py index cf967c2..daf66b9 100644 --- a/core/db/db_binance_data.py +++ b/core/db/db_binance_data.py @@ -480,6 +480,7 @@ class DBBinanceData: fields: list = None, start: str = None, end: str = None, + table_name: str = "crypto_binance_data", ): """ 根据交易对和K线周期查询数据 @@ -494,7 +495,7 @@ class DBBinanceData: fields_str = ", ".join(fields) if start is None and end is None: sql = f""" - SELECT {fields_str} FROM crypto_binance_data + SELECT {fields_str} FROM {table_name} WHERE symbol = :symbol AND bar = :bar ORDER BY timestamp ASC """ @@ -514,7 +515,7 @@ class DBBinanceData: if start > end: start, end = end, start sql = f""" - SELECT {fields_str} FROM crypto_binance_data + SELECT {fields_str} FROM {table_name} WHERE symbol = :symbol AND bar = :bar AND timestamp BETWEEN :start AND :end ORDER BY timestamp ASC """ @@ -526,14 +527,14 @@ class DBBinanceData: } elif start is not None: sql = f""" - SELECT {fields_str} FROM crypto_binance_data + SELECT {fields_str} FROM {table_name} WHERE symbol = :symbol AND bar = :bar AND timestamp >= :start ORDER BY timestamp ASC """ condition_dict = {"symbol": symbol, "bar": bar, "start": start} elif end is not None: sql = f""" - SELECT {fields_str} FROM crypto_binance_data + SELECT {fields_str} FROM {table_name} WHERE symbol = :symbol AND bar = :bar AND timestamp <= :end ORDER BY timestamp ASC """ diff --git a/core/db/db_market_data.py b/core/db/db_market_data.py index adff462..646f62e 100644 --- a/core/db/db_market_data.py +++ b/core/db/db_market_data.py @@ -7,68 +7,65 @@ logger = logging.logger class DBMarketData: - def __init__( - self, - db_url: str - ): + def __init__(self, db_url: str): self.db_url = db_url self.table_name = "crypto_market_data" self.columns = [ - "symbol", - "bar", - "timestamp", - "date_time", - "date_time_us", - "open", - "high", - "low", - "close", - "pre_close", - "close_change", - "pct_chg", - "volume", - "volCcy", - "volCCyQuote", - "buy_sz", - "sell_sz", - # 技术指标字段 - "ma1", - "ma2", - "dif", - "dea", - "macd", - "macd_signal", - "macd_divergence", - "kdj_k", - "kdj_d", - "kdj_j", - "kdj_signal", - "kdj_pattern", - "sar", - "sar_signal", - "ma5", - "ma10", - "ma20", - "ma30", - "ma_cross", - "ma5_close_diff", - "ma10_close_diff", - "ma20_close_diff", - "ma30_close_diff", - "ma_close_avg", - "ma_long_short", - "ma_divergence", - "rsi_14", - "rsi_signal", - "boll_upper", - "boll_middle", - "boll_lower", - "boll_signal", - "boll_pattern", - "k_length", - "k_shape", - "k_up_down", - "create_time", + "symbol", + "bar", + "timestamp", + "date_time", + "date_time_us", + "open", + "high", + "low", + "close", + "pre_close", + "close_change", + "pct_chg", + "volume", + "volCcy", + "volCCyQuote", + "buy_sz", + "sell_sz", + # 技术指标字段 + "ma1", + "ma2", + "dif", + "dea", + "macd", + "macd_signal", + "macd_divergence", + "kdj_k", + "kdj_d", + "kdj_j", + "kdj_signal", + "kdj_pattern", + "sar", + "sar_signal", + "ma5", + "ma10", + "ma20", + "ma30", + "ma_cross", + "ma5_close_diff", + "ma10_close_diff", + "ma20_close_diff", + "ma30_close_diff", + "ma_close_avg", + "ma_long_short", + "ma_divergence", + "rsi_14", + "rsi_signal", + "boll_upper", + "boll_middle", + "boll_lower", + "boll_signal", + "boll_pattern", + "k_length", + "k_shape", + "k_up_down", + "create_time", ] self.db_manager = DBData(db_url, self.table_name, self.columns) @@ -85,7 +82,7 @@ class DBMarketData: return self.db_manager.insert_data_to_mysql(df) - + def insert_data_to_mysql_fast(self, df: pd.DataFrame): """ 快速插入K线行情数据(方案2:使用executemany批量插入) @@ -97,9 +94,9 @@ class DBMarketData: if df is None or df.empty: logger.warning("DataFrame为空,无需写入数据库。") return - + self.db_manager.insert_data_to_mysql_fast(df) - + def insert_data_to_mysql_chunk(self, df: pd.DataFrame, chunk_size: int = 1000): """ 分块插入K线行情数据(方案3:适合大数据量) @@ -112,9 +109,9 @@ class DBMarketData: if df is None or df.empty: logger.warning("DataFrame为空,无需写入数据库。") return - + self.db_manager.insert_data_to_mysql_chunk(df, chunk_size) - + def insert_data_to_mysql_simple(self, df: pd.DataFrame): """ 简单插入K线行情数据(方案4:直接使用to_sql,忽略重复) @@ -125,9 +122,9 @@ class DBMarketData: if df is None or df.empty: logger.warning("DataFrame为空,无需写入数据库。") return - + self.db_manager.insert_data_to_mysql_simple(df) - + def query_latest_data(self, symbol: str, bar: str): """ 查询最新数据 @@ -142,8 +139,10 @@ class DBMarketData: """ condition_dict = {"symbol": symbol, "bar": bar} return self.db_manager.query_data(sql, condition_dict, return_multi=False) - - def query_data_before_timestamp(self, symbol: str, bar: str, timestamp: int, limit: int = 100): + + def query_data_before_timestamp( + self, symbol: str, bar: str, timestamp: int, limit: int = 100 + ): """ 根据时间戳查询之前的数据 :param symbol: 交易对 @@ -157,20 +156,25 @@ class DBMarketData: ORDER BY timestamp DESC LIMIT :limit """ - condition_dict = {"symbol": symbol, "bar": bar, "timestamp": timestamp, "limit": limit} + condition_dict = { + "symbol": symbol, + "bar": bar, + "timestamp": timestamp, + "limit": limit, + } return self.db_manager.query_data(sql, condition_dict, return_multi=True) - + def query_data_by_technical_indicators( - self, - symbol: str, - bar: str, - start: str = None, + self, + symbol: str, + bar: str, + start: str = None, end: str = None, macd_signal: str = None, kdj_signal: str = None, rsi_signal: str = None, boll_signal: str = None, - ma_cross: str = None + ma_cross: str = None, ): """ 根据技术指标查询数据 @@ -186,7 +190,7 @@ class DBMarketData: """ conditions = ["symbol = :symbol", "bar = :bar"] condition_dict = {"symbol": symbol, "bar": bar} - + if macd_signal: conditions.append("macd_signal = :macd_signal") condition_dict["macd_signal"] = macd_signal @@ -202,7 +206,7 @@ class DBMarketData: if ma_cross: conditions.append("ma_cross = :ma_cross") condition_dict["ma_cross"] = ma_cross - + # 处理时间范围 if start: start_timestamp = transform_date_time_to_timestamp(start) @@ -214,23 +218,23 @@ class DBMarketData: if end_timestamp: conditions.append("timestamp <= :end") condition_dict["end"] = end_timestamp - + where_clause = " AND ".join(conditions) sql = f""" SELECT * FROM crypto_market_data WHERE {where_clause} ORDER BY timestamp DESC """ - + return self.db_manager.query_data(sql, condition_dict, return_multi=True) - + def query_macd_signals( - self, - symbol: str, - bar: str, + self, + symbol: str, + bar: str, signal: str = None, - start: str = None, - end: str = None + start: str = None, + end: str = None, ): """ 查询MACD信号数据 @@ -242,11 +246,11 @@ class DBMarketData: """ conditions = ["symbol = :symbol", "bar = :bar"] condition_dict = {"symbol": symbol, "bar": bar} - + if signal: conditions.append("macd_signal = :signal") condition_dict["signal"] = signal - + # 处理时间范围 if start: start_timestamp = transform_date_time_to_timestamp(start) @@ -258,24 +262,24 @@ class DBMarketData: if end_timestamp: conditions.append("timestamp <= :end") condition_dict["end"] = end_timestamp - + where_clause = " AND ".join(conditions) sql = f""" SELECT * FROM crypto_market_data WHERE {where_clause} ORDER BY timestamp DESC """ - + return self.db_manager.query_data(sql, condition_dict, return_multi=True) - + def query_kdj_signals( - self, - symbol: str, - bar: str, + self, + symbol: str, + bar: str, signal: str = None, pattern: str = None, - start: str = None, - end: str = None + start: str = None, + end: str = None, ): """ 查询KDJ信号数据 @@ -288,14 +292,14 @@ class DBMarketData: """ conditions = ["symbol = :symbol", "bar = :bar"] condition_dict = {"symbol": symbol, "bar": bar} - + if signal: conditions.append("kdj_signal = :signal") condition_dict["signal"] = signal if pattern: conditions.append("kdj_pattern = :pattern") condition_dict["pattern"] = pattern - + # 处理时间范围 if start: start_timestamp = transform_date_time_to_timestamp(start) @@ -307,25 +311,25 @@ class DBMarketData: if end_timestamp: conditions.append("timestamp <= :end") condition_dict["end"] = end_timestamp - + where_clause = " AND ".join(conditions) sql = f""" SELECT * FROM crypto_market_data WHERE {where_clause} ORDER BY timestamp DESC """ - + return self.db_manager.query_data(sql, condition_dict, return_multi=True) - + def query_ma_signals( - self, - symbol: str, - bar: str, + self, + symbol: str, + bar: str, cross: str = None, long_short: str = None, divergence: str = None, - start: str = None, - end: str = None + start: str = None, + end: str = None, ): """ 查询均线信号数据 @@ -339,7 +343,7 @@ class DBMarketData: """ conditions = ["symbol = :symbol", "bar = :bar"] condition_dict = {"symbol": symbol, "bar": bar} - + if cross: conditions.append("ma_cross = :cross") condition_dict["cross"] = cross @@ -349,7 +353,7 @@ class DBMarketData: if divergence: conditions.append("ma_divergence = :divergence") condition_dict["divergence"] = divergence - + # 处理时间范围 if start: start_timestamp = transform_date_time_to_timestamp(start) @@ -361,24 +365,24 @@ class DBMarketData: if end_timestamp: conditions.append("timestamp <= :end") condition_dict["end"] = end_timestamp - + where_clause = " AND ".join(conditions) sql = f""" SELECT * FROM crypto_market_data WHERE {where_clause} ORDER BY timestamp DESC """ - + return self.db_manager.query_data(sql, condition_dict, return_multi=True) - + def query_bollinger_signals( - self, - symbol: str, - bar: str, + self, + symbol: str, + bar: str, signal: str = None, pattern: str = None, - start: str = None, - end: str = None + start: str = None, + end: str = None, ): """ 查询布林带信号数据 @@ -391,14 +395,14 @@ class DBMarketData: """ conditions = ["symbol = :symbol", "bar = :bar"] condition_dict = {"symbol": symbol, "bar": bar} - + if signal: conditions.append("boll_signal = :signal") condition_dict["signal"] = signal if pattern: conditions.append("boll_pattern = :pattern") condition_dict["pattern"] = pattern - + # 处理时间范围 if start: start_timestamp = transform_date_time_to_timestamp(start) @@ -410,22 +414,18 @@ class DBMarketData: if end_timestamp: conditions.append("timestamp <= :end") condition_dict["end"] = end_timestamp - + where_clause = " AND ".join(conditions) sql = f""" SELECT * FROM crypto_market_data WHERE {where_clause} ORDER BY timestamp DESC """ - + return self.db_manager.query_data(sql, condition_dict, return_multi=True) - + def get_technical_statistics( - self, - symbol: str, - bar: str, - start: str = None, - end: str = None + self, symbol: str, bar: str, start: str = None, end: str = None ): """ 获取技术指标统计信息 @@ -436,7 +436,7 @@ class DBMarketData: """ conditions = ["symbol = :symbol", "bar = :bar"] condition_dict = {"symbol": symbol, "bar": bar} - + # 处理时间范围 if start: start_timestamp = transform_date_time_to_timestamp(start) @@ -448,7 +448,7 @@ class DBMarketData: if end_timestamp: conditions.append("timestamp <= :end") condition_dict["end"] = end_timestamp - + where_clause = " AND ".join(conditions) sql = f""" SELECT @@ -470,10 +470,18 @@ class DBMarketData: FROM crypto_market_data WHERE {where_clause} """ - + return self.db_manager.query_data(sql, condition_dict, return_multi=False) - - def query_market_data_by_symbol_bar(self, symbol: str, bar: str, fields: list = None, start: str = None, end: str = None): + + def query_market_data_by_symbol_bar( + self, + symbol: str, + bar: str, + fields: list = None, + start: str = None, + end: str = None, + table_name: str = "crypto_market_data", + ): """ 根据交易对和K线周期查询数据 :param symbol: 交易对 @@ -487,7 +495,7 @@ class DBMarketData: fields_str = ", ".join(fields) if start is None and end is None: sql = f""" - SELECT {fields_str} FROM crypto_market_data + SELECT {fields_str} FROM {table_name} WHERE symbol = :symbol AND bar = :bar ORDER BY timestamp ASC """ @@ -507,23 +515,28 @@ class DBMarketData: if start > end: start, end = end, start sql = f""" - SELECT {fields_str} FROM crypto_market_data + SELECT {fields_str} FROM {table_name} WHERE symbol = :symbol AND bar = :bar AND timestamp BETWEEN :start AND :end ORDER BY timestamp ASC """ - condition_dict = {"symbol": symbol, "bar": bar, "start": start, "end": end} + condition_dict = { + "symbol": symbol, + "bar": bar, + "start": start, + "end": end, + } elif start is not None: sql = f""" - SELECT {fields_str} FROM crypto_market_data + SELECT {fields_str} FROM {table_name} WHERE symbol = :symbol AND bar = :bar AND timestamp >= :start ORDER BY timestamp ASC """ condition_dict = {"symbol": symbol, "bar": bar, "start": start} elif end is not None: sql = f""" - SELECT {fields_str} FROM crypto_market_data + SELECT {fields_str} FROM {table_name} WHERE symbol = :symbol AND bar = :bar AND timestamp <= :end ORDER BY timestamp ASC """ condition_dict = {"symbol": symbol, "bar": bar, "end": end} - return self.db_manager.query_data(sql, condition_dict, return_multi=True) \ No newline at end of file + return self.db_manager.query_data(sql, condition_dict, return_multi=True) diff --git a/core/statistics/price_volume_stats.py b/core/statistics/price_volume_stats.py index 9d6d5fa..c9a0aea 100644 --- a/core/statistics/price_volume_stats.py +++ b/core/statistics/price_volume_stats.py @@ -12,7 +12,7 @@ from openpyxl.drawing.image import Image import openpyxl from openpyxl.styles import Font from PIL import Image as PILImage -from config import OKX_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE +from config import OKX_MONITOR_CONFIG, COIN_MYSQL_CONFIG, WINDOW_SIZE from core.db.db_market_data import DBMarketData from core.db.db_huge_volume_data import DBHugeVolumeData from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp @@ -25,13 +25,13 @@ logger = logging.logger class PriceVolumeStats: def __init__(self): - mysql_user = MYSQL_CONFIG.get("user", "xch") - mysql_password = MYSQL_CONFIG.get("password", "") + mysql_user = COIN_MYSQL_CONFIG.get("user", "xch") + mysql_password = COIN_MYSQL_CONFIG.get("password", "") if not mysql_password: raise ValueError("MySQL password is not set") - mysql_host = MYSQL_CONFIG.get("host", "localhost") - mysql_port = MYSQL_CONFIG.get("port", 3306) - mysql_database = MYSQL_CONFIG.get("database", "okx") + mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost") + mysql_port = COIN_MYSQL_CONFIG.get("port", 3306) + mysql_database = COIN_MYSQL_CONFIG.get("database", "okx") self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_market_data = DBMarketData(self.db_url) diff --git a/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc b/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc index 4501d00..3c88a56 100644 Binary files a/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc and b/core/trade/__pycache__/ma_break_statistics.cpython-312.pyc differ diff --git a/core/trade/ma_break_statistics.py b/core/trade/ma_break_statistics.py index accba00..140b733 100644 --- a/core/trade/ma_break_statistics.py +++ b/core/trade/ma_break_statistics.py @@ -17,11 +17,16 @@ from PIL import Image as PILImage from config import ( OKX_MONITOR_CONFIG, US_STOCK_MONITOR_CONFIG, - MYSQL_CONFIG, + COIN_MYSQL_CONFIG, + A_MYSQL_CONFIG, WINDOW_SIZE, BINANCE_MONITOR_CONFIG, + A_STOCK_MONITOR_CONFIG, + A_INDEX_MONITOR_CONFIG, ) +from core.biz.metrics_calculation import MetricsCalculation from core.db.db_market_data import DBMarketData +from core.db.db_astock import DBAStockData from core.db.db_binance_data import DBBinanceData from core.db.db_huge_volume_data import DBHugeVolumeData from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp @@ -44,21 +49,40 @@ class MaBreakStatistics: def __init__( self, is_us_stock: bool = False, - is_binance: bool = False, + is_astock: bool = False, + is_aindex: bool = False, + is_binance: bool = True, + buy_by_long_period: dict = {"by_week": False, "by_month": False}, + long_period_condition: dict = {"ma5>ma10": True, "ma10>ma20": False, "macd_diff>0": True, "macd>0": True}, commission_per_share: float = 0.0008, ): - mysql_user = MYSQL_CONFIG.get("user", "xch") - mysql_password = MYSQL_CONFIG.get("password", "") - if not mysql_password: - raise ValueError("MySQL password is not set") - mysql_host = MYSQL_CONFIG.get("host", "localhost") - mysql_port = MYSQL_CONFIG.get("port", 3306) - mysql_database = MYSQL_CONFIG.get("database", "okx") + if is_astock or is_aindex: + mysql_user = A_MYSQL_CONFIG.get("user", "root") + mysql_password = A_MYSQL_CONFIG.get("password", "") + if not mysql_password: + raise ValueError("MySQL password is not set") + mysql_host = A_MYSQL_CONFIG.get("host", "localhost") + mysql_port = A_MYSQL_CONFIG.get("port", 3306) + mysql_database = A_MYSQL_CONFIG.get("database", "astock") + else: + mysql_user = COIN_MYSQL_CONFIG.get("user", "xch") + mysql_password = COIN_MYSQL_CONFIG.get("password", "") + if not mysql_password: + raise ValueError("MySQL password is not set") + mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost") + mysql_port = COIN_MYSQL_CONFIG.get("port", 3306) + mysql_database = COIN_MYSQL_CONFIG.get("database", "okx") self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_huge_volume_data = DBHugeVolumeData(self.db_url) self.is_us_stock = is_us_stock + self.is_astock = is_astock + self.is_aindex = is_aindex + if self.is_us_stock: + self.date_time_field = "date_time_us" + else: + self.date_time_field = "date_time" self.is_binance = is_binance if is_us_stock: self.symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get( @@ -71,6 +95,28 @@ class MaBreakStatistics: "initial_date", "2014-11-30 00:00:00" ) self.db_market_data = DBMarketData(self.db_url) + elif is_astock: + self.symbols = A_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get( + "symbols", ["000001.SH"] + ) + self.bars = A_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get( + "bars", ["5m"] + ) + self.initial_date = A_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get( + "initial_date", "2014-11-30 00:00:00" + ) + self.db_market_data = DBAStockData(self.db_url) + elif is_aindex: + self.symbols = A_INDEX_MONITOR_CONFIG.get("volume_monitor", {}).get( + "symbols", ["000001.SH"] + ) + self.bars = A_INDEX_MONITOR_CONFIG.get("volume_monitor", {}).get( + "bars", ["5m"] + ) + self.initial_date = A_INDEX_MONITOR_CONFIG.get("volume_monitor", {}).get( + "initial_date", "2014-11-30 00:00:00" + ) + self.db_market_data = DBAStockData(self.db_url) else: if is_binance: self.symbols = BINANCE_MONITOR_CONFIG.get("volume_monitor", {}).get( @@ -98,12 +144,36 @@ class MaBreakStatistics: self.commission_per_share = commission_per_share self.trade_strategy_config = self.get_trade_strategy_config() self.main_strategy = self.trade_strategy_config.get("均线系统策略", None) + self.buy_by_long_period = buy_by_long_period + self.long_period_condition = long_period_condition def get_trade_strategy_config(self): with open("./json/trade_strategy.json", "r", encoding="utf-8") as f: trade_strategy_config = json.load(f) return trade_strategy_config + def get_by_long_period_desc(self): + by_week = self.buy_by_long_period.get("by_week", False) + by_month = self.buy_by_long_period.get("by_month", False) + by_long_period = "" + if by_week: + by_long_period += "1W" + if by_month: + by_long_period += "1M" + if by_long_period == "": + return "no_long_period_judge" + by_condition = "" + if self.long_period_condition.get("ma5>ma10", False): + by_condition += "ma5gtma10" + if self.long_period_condition.get("ma10>ma20", False): + by_condition += "_ma10gtma20" + if self.long_period_condition.get("macd_diff>0", False): + by_condition += "_macd_diffgt0" + if self.long_period_condition.get("macd>0", False): + by_condition += "_macdgt0" + return by_long_period + "_" + by_condition + + def batch_statistics(self, strategy_name: str = "全均线策略"): if self.is_us_stock: self.stats_output_dir = ( @@ -119,6 +189,38 @@ class MaBreakStatistics: self.stats_chart_dir = ( f"./output/trade_sandbox/ma_strategy/binance/chart/{strategy_name}/" ) + elif self.is_astock: + long_period_desc = self.get_by_long_period_desc() + if len(long_period_desc) > 0: + self.stats_output_dir = ( + f"./output/trade_sandbox/ma_strategy/astock/{long_period_desc}/excel/{strategy_name}/" + ) + self.stats_chart_dir = ( + f"./output/trade_sandbox/ma_strategy/astock/{long_period_desc}/chart/{strategy_name}/" + ) + else: + self.stats_output_dir = ( + f"./output/trade_sandbox/ma_strategy/astock/excel/{strategy_name}/" + ) + self.stats_chart_dir = ( + f"./output/trade_sandbox/ma_strategy/astock/chart/{strategy_name}/" + ) + elif self.is_aindex: + long_period_desc = self.get_by_long_period_desc() + if len(long_period_desc) > 0: + self.stats_output_dir = ( + f"./output/trade_sandbox/ma_strategy/aindex/{long_period_desc}/excel/{strategy_name}/" + ) + self.stats_chart_dir = ( + f"./output/trade_sandbox/ma_strategy/aindex/{long_period_desc}/chart/{strategy_name}/" + ) + else: + self.stats_output_dir = ( + f"./output/trade_sandbox/ma_strategy/aindex/excel/{strategy_name}/" + ) + self.stats_chart_dir = ( + f"./output/trade_sandbox/ma_strategy/aindex/chart/{strategy_name}/" + ) else: self.stats_output_dir = ( f"./output/trade_sandbox/ma_strategy/okx/excel/{strategy_name}/" @@ -192,9 +294,11 @@ class MaBreakStatistics: total_commission = round(total_commission, 4) total_buy_commission = round(total_buy_commission, 4) total_sell_commission = round(total_sell_commission, 4) + symbol_name = str(symbol_bar_data["symbol_name"].iloc[0]) account_value_chg_list.append({ "strategy_name": strategy_name, "symbol": symbol, + "symbol_name": symbol_name, "bar": bar, "total_buy_commission": total_buy_commission, "total_sell_commission": total_sell_commission, @@ -209,6 +313,7 @@ class MaBreakStatistics: [ "strategy_name", "symbol", + "symbol_name", "bar", "total_buy_commission", "total_sell_commission", @@ -221,7 +326,7 @@ class MaBreakStatistics: ] account_value_statistics_df = ( - ma_break_market_data.groupby(["symbol", "bar"])["end_account_value"] + ma_break_market_data.groupby(["symbol", "symbol_name", "bar"])["end_account_value"] .agg( account_value_max="max", account_value_min="min", @@ -237,6 +342,7 @@ class MaBreakStatistics: [ "strategy_name", "symbol", + "symbol_name", "bar", "account_value_max", "account_value_min", @@ -249,7 +355,7 @@ class MaBreakStatistics: # 依据symbol和bar分组,统计每个symbol和bar的interval_minutes的max, min, mean, std, median, count interval_minutes_df = ( - ma_break_market_data.groupby(["symbol", "bar"])["interval_minutes"] + ma_break_market_data.groupby(["symbol", "symbol_name", "bar"])["interval_minutes"] .agg( interval_minutes_max="max", interval_minutes_min="min", @@ -265,6 +371,7 @@ class MaBreakStatistics: [ "strategy_name", "symbol", + "symbol_name", "bar", "interval_minutes_max", "interval_minutes_min", @@ -335,6 +442,23 @@ class MaBreakStatistics: strategy_info["买入策略"] = buy_and_text + " 或者 \n" + buy_or_text else: strategy_info["买入策略"] = buy_and_text + + # 假如根据长周期判断买入,则需要设置长周期策略 + by_week = self.buy_by_long_period.get("by_week", False) + by_month = self.buy_by_long_period.get("by_month", False) + if by_week: + strategy_info["买入策略"] += "根据周线指标,\n" + if by_month: + strategy_info["买入策略"] += "根据月线指标,\n" + if self.long_period_condition.get("ma5>ma10", False): + strategy_info["买入策略"] += "ma5>ma10, \n" + if self.long_period_condition.get("ma10>ma20", False): + strategy_info["买入策略"] += "ma10>ma20, \n" + if self.long_period_condition.get("macd_diff>0", False): + strategy_info["买入策略"] += "macd_diff>0, \n" + if self.long_period_condition.get("macd>0", False): + strategy_info["买入策略"] += "macd>0, \n" + strategy_info["买入策略"] = strategy_info["买入策略"].strip() sell_dict = strategy_config.get("sell", {}) sell_and_list = sell_dict.get("and", []) sell_or_list = sell_dict.get("or", []) @@ -349,6 +473,7 @@ class MaBreakStatistics: strategy_info["卖出策略"] = sell_and_text + " 或者 \n" + sell_or_text else: strategy_info["卖出策略"] = sell_and_text + strategy_info["卖出策略"] = strategy_info["卖出策略"].strip() # 将strategy_info转换为pd.DataFrame strategy_info_df = pd.DataFrame([strategy_info]) return strategy_info_df @@ -384,22 +509,24 @@ class MaBreakStatistics: market_data.reset_index(drop=True, inplace=True) ma_break_market_data_pair_list = [] ma_break_market_data_pair = {} - if self.is_us_stock: - date_time_field = "date_time_us" - else: - date_time_field = "date_time" close_mean = market_data["close"].mean() self.update_initial_capital(close_mean) logger.info( - f"成功获取{symbol}数据:{len(market_data)}根{bar}K线,开始日期={market_data[date_time_field].min()},结束日期={market_data[date_time_field].max()}" + f"成功获取{symbol}数据:{len(market_data)}根{bar}K线,开始日期={market_data[self.date_time_field].min()},结束日期={market_data[self.date_time_field].max()}" ) account_value = self.initial_capital for index, row in market_data.iterrows(): + if self.is_astock: + symbol_name = row["symbol_name"] + elif self.is_aindex: + symbol_name = row["symbol_name"] + else: + symbol_name = row["symbol"] ma_cross = row["ma_cross"] timestamp = row["timestamp"] - date_time = row[date_time_field] + date_time = row[self.date_time_field] close = row["close"] ma5 = row["ma5"] ma10 = row["ma10"] @@ -411,7 +538,6 @@ class MaBreakStatistics: if ma_break_market_data_pair.get("begin_timestamp", None) is None: buy_condition = self.fit_strategy( strategy_name=strategy_name, - market_data=market_data, row=row, behavior="buy", ) @@ -431,6 +557,7 @@ class MaBreakStatistics: ma_break_market_data_pair = {} ma_break_market_data_pair["symbol"] = symbol + ma_break_market_data_pair["symbol_name"] = symbol_name ma_break_market_data_pair["bar"] = bar ma_break_market_data_pair["begin_timestamp"] = timestamp ma_break_market_data_pair["begin_date_time"] = date_time @@ -449,12 +576,12 @@ class MaBreakStatistics: else: sell_condition = self.fit_strategy( strategy_name=strategy_name, - market_data=market_data, row=row, behavior="sell", ) - if sell_condition: + if sell_condition or index == len(market_data) - 1: + # 达到卖出条件或者最后一条数据,则卖出 shares = ma_break_market_data_pair["shares"] entry_price = ma_break_market_data_pair["begin_close"] exit_price = close @@ -525,8 +652,10 @@ class MaBreakStatistics: * 100 ) pct_chg = round(pct_chg, 4) + symbol_name = ma_break_market_data["symbol_name"].iloc[0] market_data_pct_chg = { "symbol": symbol, + "symbol_name": symbol_name, "bar": bar, "pct_chg": pct_chg, "initial_capital": self.initial_capital, @@ -604,52 +733,191 @@ class MaBreakStatistics: data = pd.DataFrame() start_date = datetime.strptime(self.initial_date, "%Y-%m-%d") end_date = datetime.strptime(self.end_date, "%Y-%m-%d") + timedelta(days=1) - fields = [ - "symbol", - "bar", - "timestamp", - "date_time", - "date_time_us", - "open", - "high", - "low", - "close", - "volume", - "sar_signal", - "ma5", - "ma10", - "ma20", - "ma30", - "ma_cross", - "dif", - "dea", - "macd", - ] + table_name = "" + if self.is_astock: + if bar == "1D": + table_name = "stock_daily_price_from_2021" + elif bar == "1W": + table_name = "stock_weekly_price_from_2020" + elif bar == "1M": + table_name = "stock_monthly_price_from_2015" + elif self.is_aindex: + if bar == "1D": + table_name = "index_daily_price_from_2021" + elif bar == "1W": + table_name = "index_weekly_price_from_2020" + elif bar == "1M": + table_name = "index_monthly_price_from_2015" + elif self.is_us_stock: + table_name = "crypto_market_data" + elif self.is_binance: + table_name = "crypto_binance_data" + else: + table_name = "crypto_binance_data" + + if self.is_astock or self.is_aindex: + fields = [ + "a.ts_code as symbol", + "b.name as symbol_name", + f"'{bar}' as bar", + "0 as timestamp", + "trade_date as date_time", + "open", + "high", + "low", + "close", + "vol as volume", + "MA5 as ma5", + "MA10 as ma10", + "MA20 as ma20", + "MA30 as ma30", + "均线交叉 as ma_cross", + "DIF as dif", + "DEA as dea", + "MACD as macd", + ] + else: + fields = [ + "symbol", + "bar", + "timestamp", + "date_time", + "date_time_us", + "open", + "high", + "low", + "close", + "volume", + "sar_signal", + "ma5", + "ma10", + "ma20", + "ma30", + "ma_cross", + "dif", + "dea", + "macd", + ] while start_date < end_date: current_end_date = min(start_date + timedelta(days=180), end_date) start_date_str = start_date.strftime("%Y-%m-%d") current_end_date_str = current_end_date.strftime("%Y-%m-%d") logger.info(f"获取{symbol}数据:{start_date_str}至{current_end_date_str}") current_data = self.db_market_data.query_market_data_by_symbol_bar( - symbol, bar, fields, start=start_date_str, end=current_end_date_str + symbol, bar, fields, start=start_date_str, end=current_end_date_str, table_name=table_name ) if current_data is not None and len(current_data) > 0: current_data = pd.DataFrame(current_data) data = pd.concat([data, current_data]) start_date = current_end_date data.drop_duplicates(inplace=True) - if self.is_us_stock: - date_time_field = "date_time_us" - else: - date_time_field = "date_time" - data.sort_values(by=date_time_field, inplace=True) + data.sort_values(by=self.date_time_field, inplace=True) data.reset_index(drop=True, inplace=True) + if self.is_astock or self.is_aindex: + data = self.update_data(data) return data + + def get_long_period_data(self, symbol: str, bar: str, end_date: str): + """ + 获取长周期数据 + :param data: 数据 + :return: 长周期数据 + """ + if not (self.is_astock or self.is_aindex): + return None + table_name = "" + if self.is_astock: + if bar == "1M": + table_name = "stock_monthly_price_from_2015" + elif bar == "1W": + table_name = "stock_weekly_price_from_2020" + else: + pass + elif self.is_aindex: + if bar == "1M": + table_name = "index_monthly_price_from_2015" + elif bar == "1W": + table_name = "index_weekly_price_from_2020" + else: + pass + if len(end_date) != 10: + end_date = self.change_date_format(end_date) + if bar == "1M": + # 获取上两个月的日期 + last_date = datetime.strptime(end_date, "%Y-%m-%d") - timedelta(days=60) + last_date = last_date.strftime("%Y-%m-%d") + elif bar == "1W": + # 获取上两周的日期 + last_date = datetime.strptime(end_date, "%Y-%m-%d") - timedelta(days=14) + last_date = last_date.strftime("%Y-%m-%d") + else: + last_date = None + + if len(table_name) == 0 or last_date is None: + return None + fields = [ + "a.ts_code as symbol", + "b.name as symbol_name", + f"'{bar}' as bar", + "0 as timestamp", + "trade_date as date_time", + "open", + "high", + "low", + "close", + "vol as volume", + "MA5 as ma5", + "MA10 as ma10", + "MA20 as ma20", + "MA30 as ma30", + "均线交叉 as ma_cross", + "DIF as dif", + "DEA as dea", + "MACD as macd", + ] + data = self.db_market_data.query_market_data_by_symbol_bar( + symbol, bar, fields, start=last_date, end=end_date, table_name=table_name + ) + if data is not None and len(data) > 0: + data = pd.DataFrame(data) + data.sort_values(by="date_time", inplace=True) + latest_row = data.iloc[-1] + if (latest_row["ma5"] is None or + latest_row["ma10"] is None or + latest_row["ma20"] is None or + latest_row["dif"] is None or + latest_row["macd"] is None): + return None + return latest_row + else: + return None + + + def update_data(self, data: pd.DataFrame): + """ + 更新数据 + 1. 将date_time列中的20210104这种格式,替换为2021-01-04的格式 + 2. 将date_time转换为timestamp,并更新timestamp列 + 3. 通过MetricsCalculation的ma5102030方法更新ma_cross列 + :param data: 数据 + :return: 更新后的数据 + """ + data["date_time"] = data["date_time"].apply(lambda x: self.change_date_format(x)) + data["timestamp"] = data["date_time"].apply(lambda x: transform_date_time_to_timestamp(x)) + metrics_calculation = MetricsCalculation() + data = metrics_calculation.ma5102030(data) + return data + + def change_date_format(self, date_text: str): + # 将20210104这种格式,替换为2021-01-04的格式 + if len(date_text) == 8: + return date_text[0:4] + "-" + date_text[4:6] + "-" + date_text[6:8] + else: + return date_text def fit_strategy( self, strategy_name: str = "全均线策略", - market_data: pd.DataFrame = None, row: pd.Series = None, behavior: str = "buy", ): @@ -661,6 +929,45 @@ class MaBreakStatistics: if condition_dict is None: logger.error(f"策略{strategy_name}的{behavior}条件不存在") return False + + and_list = condition_dict.get("and", []) + + condition = True + condition = self.get_judge_result(row, and_list, "and", condition) + or_list = condition_dict.get("or", []) + condition = self.get_judge_result(row, or_list, "or", condition) + + if behavior == "buy" and condition: + # 如果满足条件,则判断是否根据长周期指标买入 + bar = row["bar"] + if (self.is_astock or self.is_aindex) and bar == "1D": + date_time = row["date_time"] + long_period_condition_list = [] + if self.long_period_condition.get("ma5>ma10", False): + long_period_condition_list.append("ma5>ma10") + if self.long_period_condition.get("ma10>ma20", False): + long_period_condition_list.append("ma10>ma20") + if self.long_period_condition.get("macd_diff>0", False): + long_period_condition_list.append("macd_diff>0") + if self.long_period_condition.get("macd>0", False): + long_period_condition_list.append("macd>0") + if len(long_period_condition_list) > 0: + if self.buy_by_long_period.get("by_week", False): + long_period_data = self.get_long_period_data(row["symbol"], "1W", date_time) + if long_period_data is not None: + condition = self.get_judge_result(long_period_data, long_period_condition_list, "and", condition) + if not condition: + logger.info(f"根据周线指标,{row['symbol']}不满足买入条件") + if self.buy_by_long_period.get("by_month", False): + long_period_data = self.get_long_period_data(row["symbol"], "1M", date_time) + if long_period_data is not None: + condition = self.get_judge_result(long_period_data, long_period_condition_list, "and", condition) + if not condition: + logger.info(f"根据月线指标,{row['symbol']}不满足买入条件") + + return condition + + def get_judge_result(self, row: pd.Series, condition_list: list, and_or: str = "and", raw_condition: bool = True): ma_cross = row["ma_cross"] if pd.isna(ma_cross) or ma_cross is None: ma_cross = "" @@ -670,107 +977,107 @@ class MaBreakStatistics: ma20 = float(row["ma20"]) ma30 = float(row["ma30"]) close = float(row["close"]) - volume_pct_chg = float(row["volume_pct_chg"]) + if "volume_pct_chg" in list(row.index) and row["volume_pct_chg"] is not None: + volume_pct_chg = float(row["volume_pct_chg"]) + else: + volume_pct_chg = None macd_diff = float(row["dif"]) macd_dea = float(row["dea"]) macd = float(row["macd"]) - - and_list = condition_dict.get("and", []) - - condition = True - for and_condition in and_list: + if and_or == "and": + for and_condition in condition_list: if and_condition == "5上穿10": - condition = condition and ("5上穿10" in ma_cross) + raw_condition = raw_condition and ("5上穿10" in ma_cross) elif and_condition == "10上穿20": - condition = condition and ("10上穿20" in ma_cross) + raw_condition = raw_condition and ("10上穿20" in ma_cross) elif and_condition == "20上穿30": - condition = condition and ("20上穿30" in ma_cross) + raw_condition = raw_condition and ("20上穿30" in ma_cross) elif and_condition == "ma5>ma10": - condition = condition and (ma5 > ma10) + raw_condition = raw_condition and (ma5 > ma10) elif and_condition == "ma10>ma20": - condition = condition and (ma10 > ma20) + raw_condition = raw_condition and (ma10 > ma20) elif and_condition == "ma20>ma30": - condition = condition and (ma20 > ma30) + raw_condition = raw_condition and (ma20 > ma30) elif and_condition == "close>ma20": - condition = condition and (close > ma20) - elif and_condition == "volume_pct_chg>0.2": - condition = condition and (volume_pct_chg > 0.2) + raw_condition = raw_condition and (close > ma20) + elif and_condition == "volume_pct_chg>0.2" and volume_pct_chg is not None: + raw_condition = raw_condition and (volume_pct_chg > 0.2) elif and_condition == "macd_diff>0": - condition = condition and (macd_diff > 0) + raw_condition = raw_condition and (macd_diff > 0) elif and_condition == "macd_dea>0": - condition = condition and (macd_dea > 0) + raw_condition = raw_condition and (macd_dea > 0) elif and_condition == "macd>0": - condition = condition and (macd > 0) + raw_condition = raw_condition and (macd > 0) elif and_condition == "10下穿5": - condition = condition and ("10下穿5" in ma_cross) + raw_condition = raw_condition and ("10下穿5" in ma_cross) elif and_condition == "20下穿10": - condition = condition and ("20下穿10" in ma_cross) + raw_condition = raw_condition and ("20下穿10" in ma_cross) elif and_condition == "30下穿20": - condition = condition and ("30下穿20" in ma_cross) + raw_condition = raw_condition and ("30下穿20" in ma_cross) elif and_condition == "ma5ma10": - condition = condition or (ma5 > ma10) + raw_condition = raw_condition or (ma5 > ma10) elif or_condition == "ma10>ma20": - condition = condition or (ma10 > ma20) + raw_condition = raw_condition or (ma10 > ma20) elif or_condition == "ma20>ma30": - condition = condition or (ma20 > ma30) + raw_condition = raw_condition or (ma20 > ma30) elif or_condition == "close>ma20": - condition = condition or (close > ma20) - elif or_condition == "volume_pct_chg>0.2": - condition = condition or (volume_pct_chg > 0.2) + raw_condition = raw_condition or (close > ma20) + elif or_condition == "volume_pct_chg>0.2" and volume_pct_chg is not None: + raw_condition = raw_condition or (volume_pct_chg > 0.2) elif or_condition == "macd_diff>0": - condition = condition or (macd_diff > 0) + raw_condition = raw_condition or (macd_diff > 0) elif or_condition == "macd_dea>0": - condition = condition or (macd_dea > 0) + raw_condition = raw_condition or (macd_dea > 0) elif or_condition == "macd>0": - condition = condition or (macd > 0) + raw_condition = raw_condition or (macd > 0) elif or_condition == "10下穿5": - condition = condition or ("10下穿5" in ma_cross) + raw_condition = raw_condition or ("10下穿5" in ma_cross) elif or_condition == "20下穿10": - condition = condition or ("20下穿10" in ma_cross) + raw_condition = raw_condition or ("20下穿10" in ma_cross) elif or_condition == "30下穿20": - condition = condition or ("30下穿20" in ma_cross) + raw_condition = raw_condition or ("30下穿20" in ma_cross) elif or_condition == "ma5 MaBreakStatistics 批量统计。 + - 聚合多个策略结果输出合并的资金曲线/收益数据。 + - 可配置是否美股/是否Binance、佣金参数等。 + +- trade_sandbox_main.py + - 功能: 均值回归策略沙盒回测,批量跑不同方案并输出Excel和图表。 + - 要点: + - 批量维度:symbols × bars × solutions。 + - 统计指标:止盈/止损次数与占比、收益分布、均值等,按 `symbol, bar` 分组。 + - 自动绘制2×2面板图(不同bar),嵌入Excel文件。 + - 可仅跑5m,也可多周期。 + +- market_data_from_itick_main.py + - 功能: 按时间段分片下载美股/ETF数据(示例使用AlphaVantage类命名),处理并保存CSV,展示统计。 + - 要点: + - 配置symbol/interval/分段天数,降低单次请求压力。 + - 下载→处理→保存→打印统计的串行流程。 + - 作为离线数据拉取脚本模版使用。 + +- auto_schedule.py + - 功能: 简易定时调度器,周期性运行 `huge_volume_main.py`。 + - 要点: + - 使用 schedule 每小时执行一次;记录执行时间与耗时。 + - 兼容不同当前工作目录定位脚本路径。 + - 适合本地常驻调度。 + +- auto_update_market_data.py + - 功能: 同上,定时运行 `huge_volume_main.py`(与 auto_schedule.py 功能基本一致)。 + - 要点: + - 同样是每小时执行,日志与输出一致。 + - 可按需要二选一保留,避免重复。 + +- update_data_main.py + - 功能: 批量更新数据库中行情数据的技术指标与美东时间字段。 + - 要点: + - 从DB读取全量数据→按timestamp排序→更新 `date_time_us` 与 `SAR` 指标→回写DB。 + - 支持美股/加密两套symbols与bars配置。 + - 严格校验MySQL配置是否存在密码。 + +- trade_data_main.py + - 功能: 交易明细拉取与补齐(API与DB结合),返回时间段内整理过的交易数据。 + - 要点: + - 依据DB现有最早/最新时间决定是否调用API补齐前段或后段。 + - 默认时间范围从配置初始时间到当前;最终结果从DB聚合、排序、去重。 + - 依赖 TradeData 实现API交互与DB写入。 + +- statistics_main.py + - 功能: 批量价格/成交量统计。 + - 要点: + - 调用 PriceVolumeStats 批处理,返回价格统计、成交量统计与联动统计结果。 + - 用作一次性统计入口,便于离线分析。 + +- trade_main.py + - 功能: 交易流程演示脚本(三段式示例:开空→现货卖出→平空)。 + - 要点: + - 封装 QuantTrader,对接实盘/模拟(由配置SANDBOX决定)。 + - 展示下单参数:逐仓/全仓、张数、杠杆、缓冲比例;以及余额检查、下单、平仓流程。 + - 以日志形式串联完整交易生命周期,适合作为交易API联通性验证与流程Demo。 + +- huge_volume_main.py + - 功能: 巨量成交检测与统计分析的核心入口(OKX与Binance均支持)。 + - 要点: + - 从行情表读取K线,计算滑窗放量、价格分位等;支持初始化、按窗口增量更新。 + - Binance 支持CSV历史导入,导入后联动更新巨量表。 + - 提供后续N周期涨跌统计、Excel导出与可视化,以及企业微信推送过滤(如volume_ratio>10且极值价位)。 + - 多窗口(50/80/100/120)与多周期(1m~1D)批量处理能力,MySQL落库去重。 \ No newline at end of file diff --git a/doc/trade_code_file_brief.md b/doc/trade_code_file_brief.md new file mode 100644 index 0000000..e33b203 --- /dev/null +++ b/doc/trade_code_file_brief.md @@ -0,0 +1,23 @@ +- ma_break_statistics.py + - 功能: 统计“均线突破”后的收益表现,批量跑不同标的/周期,生成统计与图表/Excel。 + - 要点: + - 数据源可切换:OKX/Binance/美股;时间范围从配置读取。 + - 关注 MA5/10/20/30 多组“上穿/下穿”组合,度量突破后的区间收益。 + - 可配置手续费率,输出多策略合并结果,落地到指定 output 目录。 + - 与数据库类 `DBMarketData/DBBinanceData/DBHugeVolumeData` 协同,读取K线、过滤区间。 + +- mean_reversion_sandbox.py + - 功能: 均值回归策略沙盒(回测器),按多种“买入/止损/止盈方案”跑批评估,生成统计+图表+Excel。 + - 要点: + - 条件以“价格分位+巨量”触发为主:如 close_10_low=1 或 close_80/90_high=1 且近2根K线任一巨量。 + - 多个方案(solution_1/2/3)策略化定义(止盈可用波段中位数、分位、高位形态等)。 + - 读自合并视图 `DBMergeMarketHugeVolume`(含价量与巨量事件),统一回测窗口与分组汇总。 + - 自动出图(seaborn/matplotlib)并将图贴入 Excel,结果分 symbol×bar 汇总。 + +- orb_trade.py + - 功能: ORB(Opening Range Breakout)日内策略回测与可视化。 + - 要点: + - 以开盘第一根5分钟K线的高低(High1/Low1)作为区间,第二根K线产生多空信号;入场价=第二根开盘价,止损价=第一根极值;盈亏基于 $R(entry-stop)。 + - 支持参数:账户初始资金、最大杠杆、单笔风险比例、佣金、盈利目标倍数、仅做多/仅做空/双向、是否参考 SAR、是否参考 1H 形态等。 + - 数据获取两路:优先本地DB(OKX/Binance),也提供 yfinance 拉取美股数据的流程;自动调整初始资金规模以适配价格量级。 + - 回测输出交易清单、资金曲线,生成图表与Excel摘要到 output 目录。 \ No newline at end of file diff --git a/huge_volume_main.py b/huge_volume_main.py index 035f67d..52b4f0e 100644 --- a/huge_volume_main.py +++ b/huge_volume_main.py @@ -11,7 +11,7 @@ import core.logger as logging from config import ( OKX_MONITOR_CONFIG, US_STOCK_MONITOR_CONFIG, - MYSQL_CONFIG, + COIN_MYSQL_CONFIG, WINDOW_SIZE, BINANCE_MONITOR_CONFIG, ) @@ -30,13 +30,13 @@ class HugeVolumeMain: is_us_stock: bool = False, is_binance: bool = False, ): - mysql_user = MYSQL_CONFIG.get("user", "xch") - mysql_password = MYSQL_CONFIG.get("password", "") + mysql_user = COIN_MYSQL_CONFIG.get("user", "xch") + mysql_password = COIN_MYSQL_CONFIG.get("password", "") if not mysql_password: raise ValueError("MySQL password is not set") - mysql_host = MYSQL_CONFIG.get("host", "localhost") - mysql_port = MYSQL_CONFIG.get("port", 3306) - mysql_database = MYSQL_CONFIG.get("database", "okx") + mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost") + mysql_port = COIN_MYSQL_CONFIG.get("port", 3306) + mysql_database = COIN_MYSQL_CONFIG.get("database", "okx") self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.huge_volume = HugeVolume() diff --git a/market_data_main.py b/market_data_main.py index 2278d68..afea5f4 100644 --- a/market_data_main.py +++ b/market_data_main.py @@ -21,7 +21,7 @@ from config import ( OKX_MONITOR_CONFIG, BINANCE_MONITOR_CONFIG, US_STOCK_MONITOR_CONFIG, - MYSQL_CONFIG, + COIN_MYSQL_CONFIG, BAR_THRESHOLD, ) @@ -67,13 +67,13 @@ class MarketDataMain: self.initial_date = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get( "initial_date", "2025-07-01 00:00:00" ) - mysql_user = MYSQL_CONFIG.get("user", "xch") - mysql_password = MYSQL_CONFIG.get("password", "") + mysql_user = COIN_MYSQL_CONFIG.get("user", "xch") + mysql_password = COIN_MYSQL_CONFIG.get("password", "") if not mysql_password: raise ValueError("MySQL password is not set") - mysql_host = MYSQL_CONFIG.get("host", "localhost") - mysql_port = MYSQL_CONFIG.get("port", 3306) - mysql_database = MYSQL_CONFIG.get("database", "okx") + mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost") + mysql_port = COIN_MYSQL_CONFIG.get("port", 3306) + mysql_database = COIN_MYSQL_CONFIG.get("database", "okx") self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" if is_binance: diff --git a/market_monitor_main.py b/market_monitor_main.py index 0d65f38..05b1bd5 100644 --- a/market_monitor_main.py +++ b/market_monitor_main.py @@ -4,7 +4,7 @@ from huge_volume_main import HugeVolumeMain from core.biz.market_monitor import create_metrics_report from core.db.db_market_monitor import DBMarketMonitor from core.wechat import Wechat -from config import OKX_MONITOR_CONFIG, OKX_REALTIME_MONITOR_CONFIG, MYSQL_CONFIG, WECHAT_CONFIG +from config import OKX_MONITOR_CONFIG, OKX_REALTIME_MONITOR_CONFIG, COIN_MYSQL_CONFIG, WECHAT_CONFIG from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp import core.logger as logging @@ -31,13 +31,13 @@ class MarketMonitorMain: self.output_folder = "./output/report/market_monitor/" os.makedirs(self.output_folder, exist_ok=True) - mysql_user = MYSQL_CONFIG.get("user", "xch") - mysql_password = MYSQL_CONFIG.get("password", "") + mysql_user = COIN_MYSQL_CONFIG.get("user", "xch") + mysql_password = COIN_MYSQL_CONFIG.get("password", "") if not mysql_password: raise ValueError("MySQL password is not set") - mysql_host = MYSQL_CONFIG.get("host", "localhost") - mysql_port = MYSQL_CONFIG.get("port", 3306) - mysql_database = MYSQL_CONFIG.get("database", "okx") + mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost") + mysql_port = COIN_MYSQL_CONFIG.get("port", 3306) + mysql_database = COIN_MYSQL_CONFIG.get("database", "okx") self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" diff --git a/play.py b/play.py index 621dce7..689f07b 100644 --- a/play.py +++ b/play.py @@ -2,7 +2,7 @@ import logging from core.biz.quant_trader import QuantTrader from core.biz.strategy import QuantStrategy -from config import MYSQL_CONFIG +from config import COIN_MYSQL_CONFIG from sqlalchemy import create_engine, exc, text import pandas as pd @@ -100,13 +100,13 @@ def main() -> None: def test_query(): - mysql_user = MYSQL_CONFIG.get("user", "xch") - mysql_password = MYSQL_CONFIG.get("password", "") + mysql_user = COIN_MYSQL_CONFIG.get("user", "xch") + mysql_password = COIN_MYSQL_CONFIG.get("password", "") if not mysql_password: raise ValueError("MySQL password is not set") - mysql_host = MYSQL_CONFIG.get("host", "localhost") - mysql_port = MYSQL_CONFIG.get("port", 3306) - mysql_database = MYSQL_CONFIG.get("database", "okx") + mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost") + mysql_port = COIN_MYSQL_CONFIG.get("port", 3306) + mysql_database = COIN_MYSQL_CONFIG.get("database", "okx") db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" db_engine = create_engine( db_url, diff --git a/test_ma_methods.py b/test_ma_methods.py index e399344..add27b9 100644 --- a/test_ma_methods.py +++ b/test_ma_methods.py @@ -10,7 +10,7 @@ import matplotlib.pyplot as plt from core.db.db_market_data import DBMarketData from core.biz.metrics_calculation import MetricsCalculation import logging -from config import OKX_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE +from config import OKX_MONITOR_CONFIG, COIN_MYSQL_CONFIG, WINDOW_SIZE # plt支持中文 plt.rcParams['font.family'] = ['SimHei'] @@ -18,13 +18,13 @@ plt.rcParams['font.family'] = ['SimHei'] logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') def get_real_data(symbol, bar, start, end): - mysql_user = MYSQL_CONFIG.get("user", "xch") - mysql_password = MYSQL_CONFIG.get("password", "") + mysql_user = COIN_MYSQL_CONFIG.get("user", "xch") + mysql_password = COIN_MYSQL_CONFIG.get("password", "") if not mysql_password: raise ValueError("MySQL password is not set") - mysql_host = MYSQL_CONFIG.get("host", "localhost") - mysql_port = MYSQL_CONFIG.get("port", 3306) - mysql_database = MYSQL_CONFIG.get("database", "okx") + mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost") + mysql_port = COIN_MYSQL_CONFIG.get("port", 3306) + mysql_database = COIN_MYSQL_CONFIG.get("database", "okx") db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" db_market_data = DBMarketData(db_url) diff --git a/trade_data_main.py b/trade_data_main.py index 5a30ebe..c3a447c 100644 --- a/trade_data_main.py +++ b/trade_data_main.py @@ -10,7 +10,7 @@ from config import ( PASSPHRASE, SANDBOX, OKX_MONITOR_CONFIG, - MYSQL_CONFIG, + COIN_MYSQL_CONFIG, ) logger = logging.logger @@ -18,13 +18,13 @@ logger = logging.logger class TradeDataMain: def __init__(self): - mysql_user = MYSQL_CONFIG.get("user", "xch") - mysql_password = MYSQL_CONFIG.get("password", "") + mysql_user = COIN_MYSQL_CONFIG.get("user", "xch") + mysql_password = COIN_MYSQL_CONFIG.get("password", "") if not mysql_password: raise ValueError("MySQL password is not set") - mysql_host = MYSQL_CONFIG.get("host", "localhost") - mysql_port = MYSQL_CONFIG.get("port", 3306) - mysql_database = MYSQL_CONFIG.get("database", "okx") + mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost") + mysql_port = COIN_MYSQL_CONFIG.get("port", 3306) + mysql_database = COIN_MYSQL_CONFIG.get("database", "okx") self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.trade_data = TradeData( diff --git a/trade_ma_strategy_main.py b/trade_ma_strategy_main.py index b7df31c..dd7c2bc 100644 --- a/trade_ma_strategy_main.py +++ b/trade_ma_strategy_main.py @@ -18,7 +18,6 @@ from config import ( PASSPHRASE, SANDBOX, OKX_MONITOR_CONFIG, - MYSQL_CONFIG, BAR_THRESHOLD, ) @@ -29,13 +28,21 @@ class TradeMaStrategyMain: def __init__( self, is_us_stock: bool = False, + is_astock: bool = False, + is_aindex: bool = True, is_binance: bool = False, commission_per_share: float = 0, + buy_by_long_period: dict = {"by_week": False, "by_month": False}, + long_period_condition: dict = {"ma5>ma10": False, "ma10>ma20": False, "macd_diff>0": False, "macd>0": False}, ): self.ma_break_statistics = MaBreakStatistics( is_us_stock=is_us_stock, + is_astock=is_astock, + is_aindex=is_aindex, is_binance=is_binance, commission_per_share=commission_per_share, + buy_by_long_period=buy_by_long_period, + long_period_condition=long_period_condition, ) def batch_ma_break_statistics(self): @@ -60,12 +67,51 @@ class TradeMaStrategyMain: logger.info("开始统计account_value_chg") +def test_single_symbol(): + ma_break_statistics = MaBreakStatistics( + is_us_stock=False, + is_astock=True, + is_aindex=False, + is_binance=False, + commission_per_share=0, + ) + symbol = "600111.SH" + bar = "1D" + ma_break_statistics.trade_simulate(symbol=symbol, bar=bar, strategy_name="均线macd结合策略2") + + if __name__ == "__main__": commission_per_share_list = [0, 0.0008] + buy_by_long_period_list = [{"by_week": True, "by_month": True}, + {"by_week": True, "by_month": False}, + {"by_week": False, "by_month": True}, + {"by_week": False, "by_month": False}] + long_period_condition_list = [{"ma5>ma10": True, "ma10>ma20": True, "macd_diff>0": True, "macd>0": True}, + {"ma5>ma10": True, "ma10>ma20": False, "macd_diff>0": True, "macd>0": True}, + {"ma5>ma10": False, "ma10>ma20": True, "macd_diff>0": True, "macd>0": True}] + for commission_per_share in commission_per_share_list: - trade_ma_strategy_main = TradeMaStrategyMain( - is_us_stock=False, - is_binance=True, - commission_per_share=commission_per_share, - ) - trade_ma_strategy_main.batch_ma_break_statistics() + for buy_by_long_period in buy_by_long_period_list: + for long_period_condition in long_period_condition_list: + logger.info(f"开始计算, 主要参数:commission_per_share: {commission_per_share}, buy_by_long_period: {buy_by_long_period}, long_period_condition: {long_period_condition}") + trade_ma_strategy_main = TradeMaStrategyMain( + is_us_stock=False, + is_astock=False, + is_aindex=True, + is_binance=False, + commission_per_share=commission_per_share, + buy_by_long_period=buy_by_long_period, + long_period_condition=long_period_condition, + ) + trade_ma_strategy_main.batch_ma_break_statistics() + + trade_ma_strategy_main = TradeMaStrategyMain( + is_us_stock=False, + is_astock=True, + is_aindex=False, + is_binance=False, + commission_per_share=commission_per_share, + buy_by_long_period=buy_by_long_period, + long_period_condition=long_period_condition, + ) + trade_ma_strategy_main.batch_ma_break_statistics() diff --git a/update_data_main.py b/update_data_main.py index a6a5b37..bd7de01 100644 --- a/update_data_main.py +++ b/update_data_main.py @@ -2,20 +2,20 @@ import pandas as pd from core.db.db_market_data import DBMarketData from core.biz.metrics_calculation import MetricsCalculation -from config import MYSQL_CONFIG, US_STOCK_MONITOR_CONFIG, OKX_MONITOR_CONFIG +from config import COIN_MYSQL_CONFIG, US_STOCK_MONITOR_CONFIG, OKX_MONITOR_CONFIG import core.logger as logging logger = logging.logger class UpdateDataMain: def __init__(self): - mysql_user = MYSQL_CONFIG.get("user", "xch") - mysql_password = MYSQL_CONFIG.get("password", "") + mysql_user = COIN_MYSQL_CONFIG.get("user", "xch") + mysql_password = COIN_MYSQL_CONFIG.get("password", "") if not mysql_password: raise ValueError("MySQL password is not set") - mysql_host = MYSQL_CONFIG.get("host", "localhost") - mysql_port = MYSQL_CONFIG.get("port", 3306) - mysql_database = MYSQL_CONFIG.get("database", "okx") + mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost") + mysql_port = COIN_MYSQL_CONFIG.get("port", 3306) + mysql_database = COIN_MYSQL_CONFIG.get("database", "okx") self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_market_data = DBMarketData(self.db_url)