support quant by A market

This commit is contained in:
blade 2025-09-25 12:28:43 +08:00
parent f3b98bcc22
commit 11c6e25490
21 changed files with 1017 additions and 317 deletions

View File

@ -134,6 +134,43 @@ US_STOCK_MONITOR_CONFIG = {
} }
} }
A_STOCK_MONITOR_CONFIG = {
"volume_monitor": {
"symbols": [
"600276.SH",
"002714.SZ",
"600111.SH",
"603019.SH",
"600036.SH",
"300474.SZ",
"600519.SH",
"300750.SZ",
"000858.SZ",
"000651.SZ",
"000333.SZ",
"002230.SZ",
"300308.SZ",
"002475.SZ"
],
"bars": ["1D", "1W", "1M"],
"initial_date": "2015-01-01 00:00:00",
},
}
A_INDEX_MONITOR_CONFIG = {
"volume_monitor": {
"symbols": [
"000001.SH",
"399006.SZ",
"000300.SH",
"399001.SZ",
"000852.SH",
],
"bars": ["1D", "1W", "1M"],
"initial_date": "2015-01-01 00:00:00",
},
}
WINDOW_SIZE = {"window_sizes": [50, 80, 100, 120]} WINDOW_SIZE = {"window_sizes": [50, 80, 100, 120]}
BAR_THRESHOLD = { BAR_THRESHOLD = {
@ -146,7 +183,7 @@ BAR_THRESHOLD = {
"1D": 1000 * 60 * 60 * 24, "1D": 1000 * 60 * 60 * 24,
} }
# MYSQL_CONFIG = { # COIN_MYSQL_CONFIG = {
# "host": "localhost", # "host": "localhost",
# "port": 3306, # "port": 3306,
# "user": "xch", # "user": "xch",
@ -154,7 +191,7 @@ BAR_THRESHOLD = {
# "database": "okx", # "database": "okx",
# } # }
MYSQL_CONFIG = { COIN_MYSQL_CONFIG = {
"host": "218.17.89.43", "host": "218.17.89.43",
"port": 11013, "port": 11013,
"user": "xch", "user": "xch",
@ -162,6 +199,14 @@ MYSQL_CONFIG = {
"database": "okx", "database": "okx",
} }
A_MYSQL_CONFIG = {
"host": "43.139.95.249",
"port": 3306,
"user": "root",
"password": "bengbu_200!",
"database": "astock",
}
WECHAT_CONFIG = { WECHAT_CONFIG = {
"general_key": "11e6f7ac-efa9-418a-904c-9325a9f5d324", "general_key": "11e6f7ac-efa9-418a-904c-9325a9f5d324",
"btc_key": "529e135d-843b-43dc-8aca-677a860f4b4b", "btc_key": "529e135d-843b-43dc-8aca-677a860f4b4b",

125
core/db/db_astock.py Normal file
View File

@ -0,0 +1,125 @@
import pandas as pd
from sqlalchemy import create_engine, exc, text
import re
from core.utils import get_current_date_time
import core.logger as logging
from core.utils import transform_data_type
logger = logging.logger
class DBAStockData:
def __init__(
self,
db_url: str,
):
self.db_url = db_url
self.db_engine = create_engine(
self.db_url,
pool_size=25, # 连接池大小
max_overflow=10, # 允许的最大溢出连接
pool_timeout=30, # 连接超时时间(秒)
pool_recycle=60, # 连接回收时间(秒),避免长时间闲置
)
def query_data(self, sql: str, condition_dict: dict, return_multi: bool = True):
"""
查询数据
:param sql: 查询SQL
:param db_url: 数据库连接URL
"""
try:
with self.db_engine.connect() as conn:
result = conn.execute(text(sql), condition_dict)
if return_multi:
result = result.fetchall()
if result:
result_list = [
transform_data_type(dict(row._mapping)) for row in result
]
return result_list
else:
return None
else:
result = result.fetchone()
if result:
result_dict = transform_data_type(dict(result._mapping))
return result_dict
else:
return None
except Exception as e:
logger.error(f"查询数据出错: {e}")
return None
def query_market_data_by_symbol_bar(
self,
symbol: str,
bar: str,
fields: list = None,
start: str = None,
end: str = None,
table_name: str = "index_daily_price_from_2021",
):
"""
根据交易对和K线周期查询数据
:param symbol: 交易对
:param bar: K线周期
:param fields: 字段列表
:param start: 开始时间
:param end: 结束时间
"""
if fields is None:
fields = ["*"]
fields_str = ", ".join(fields)
if table_name is None:
table_name = "index_daily_price_from_2021"
join_table = "all_index"
if table_name.startswith("index"):
join_table = "all_index"
else:
join_table = "all_stock"
if start is None and end is None:
sql = f"""
SELECT {fields_str} FROM {table_name} a
INNER JOIN {join_table} b ON a.ts_code = b.ts_code
WHERE a.ts_code = :symbol
ORDER BY a.trade_date ASC
"""
condition_dict = {"symbol": symbol}
else:
if start is not None and end is not None:
start = start.replace("-", "")
end = end.replace("-", "")
if start > end:
start, end = end, start
sql = f"""
SELECT {fields_str} FROM {table_name} a
INNER JOIN {join_table} b ON a.ts_code = b.ts_code
WHERE a.ts_code = :symbol AND a.trade_date BETWEEN :start AND :end
ORDER BY a.trade_date ASC
"""
condition_dict = {
"symbol": symbol,
"start": start,
"end": end,
}
elif start is not None:
start = start.replace("-", "")
sql = f"""
SELECT {fields_str} FROM {table_name} a
INNER JOIN {join_table} b ON a.ts_code = b.ts_code
WHERE a.ts_code = :symbol AND a.trade_date >= :start
ORDER BY a.trade_date ASC
"""
condition_dict = {"symbol": symbol, "start": start}
elif end is not None:
end = end.replace("-", "")
sql = f"""
SELECT {fields_str} FROM {table_name} a
INNER JOIN {join_table} b ON a.ts_code = b.ts_code
WHERE a.ts_code = :symbol AND a.trade_date <= :end
ORDER BY a.trade_date ASC
"""
condition_dict = {"symbol": symbol, "end": end}
return self.query_data(sql, condition_dict, return_multi=True)

View File

@ -480,6 +480,7 @@ class DBBinanceData:
fields: list = None, fields: list = None,
start: str = None, start: str = None,
end: str = None, end: str = None,
table_name: str = "crypto_binance_data",
): ):
""" """
根据交易对和K线周期查询数据 根据交易对和K线周期查询数据
@ -494,7 +495,7 @@ class DBBinanceData:
fields_str = ", ".join(fields) fields_str = ", ".join(fields)
if start is None and end is None: if start is None and end is None:
sql = f""" sql = f"""
SELECT {fields_str} FROM crypto_binance_data SELECT {fields_str} FROM {table_name}
WHERE symbol = :symbol AND bar = :bar WHERE symbol = :symbol AND bar = :bar
ORDER BY timestamp ASC ORDER BY timestamp ASC
""" """
@ -514,7 +515,7 @@ class DBBinanceData:
if start > end: if start > end:
start, end = end, start start, end = end, start
sql = f""" sql = f"""
SELECT {fields_str} FROM crypto_binance_data SELECT {fields_str} FROM {table_name}
WHERE symbol = :symbol AND bar = :bar AND timestamp BETWEEN :start AND :end WHERE symbol = :symbol AND bar = :bar AND timestamp BETWEEN :start AND :end
ORDER BY timestamp ASC ORDER BY timestamp ASC
""" """
@ -526,14 +527,14 @@ class DBBinanceData:
} }
elif start is not None: elif start is not None:
sql = f""" sql = f"""
SELECT {fields_str} FROM crypto_binance_data SELECT {fields_str} FROM {table_name}
WHERE symbol = :symbol AND bar = :bar AND timestamp >= :start WHERE symbol = :symbol AND bar = :bar AND timestamp >= :start
ORDER BY timestamp ASC ORDER BY timestamp ASC
""" """
condition_dict = {"symbol": symbol, "bar": bar, "start": start} condition_dict = {"symbol": symbol, "bar": bar, "start": start}
elif end is not None: elif end is not None:
sql = f""" sql = f"""
SELECT {fields_str} FROM crypto_binance_data SELECT {fields_str} FROM {table_name}
WHERE symbol = :symbol AND bar = :bar AND timestamp <= :end WHERE symbol = :symbol AND bar = :bar AND timestamp <= :end
ORDER BY timestamp ASC ORDER BY timestamp ASC
""" """

View File

@ -7,10 +7,7 @@ logger = logging.logger
class DBMarketData: class DBMarketData:
def __init__( def __init__(self, db_url: str):
self,
db_url: str
):
self.db_url = db_url self.db_url = db_url
self.table_name = "crypto_market_data" self.table_name = "crypto_market_data"
self.columns = [ self.columns = [
@ -143,7 +140,9 @@ class DBMarketData:
condition_dict = {"symbol": symbol, "bar": bar} condition_dict = {"symbol": symbol, "bar": bar}
return self.db_manager.query_data(sql, condition_dict, return_multi=False) return self.db_manager.query_data(sql, condition_dict, return_multi=False)
def query_data_before_timestamp(self, symbol: str, bar: str, timestamp: int, limit: int = 100): def query_data_before_timestamp(
self, symbol: str, bar: str, timestamp: int, limit: int = 100
):
""" """
根据时间戳查询之前的数据 根据时间戳查询之前的数据
:param symbol: 交易对 :param symbol: 交易对
@ -157,7 +156,12 @@ class DBMarketData:
ORDER BY timestamp DESC ORDER BY timestamp DESC
LIMIT :limit LIMIT :limit
""" """
condition_dict = {"symbol": symbol, "bar": bar, "timestamp": timestamp, "limit": limit} condition_dict = {
"symbol": symbol,
"bar": bar,
"timestamp": timestamp,
"limit": limit,
}
return self.db_manager.query_data(sql, condition_dict, return_multi=True) return self.db_manager.query_data(sql, condition_dict, return_multi=True)
def query_data_by_technical_indicators( def query_data_by_technical_indicators(
@ -170,7 +174,7 @@ class DBMarketData:
kdj_signal: str = None, kdj_signal: str = None,
rsi_signal: str = None, rsi_signal: str = None,
boll_signal: str = None, boll_signal: str = None,
ma_cross: str = None ma_cross: str = None,
): ):
""" """
根据技术指标查询数据 根据技术指标查询数据
@ -230,7 +234,7 @@ class DBMarketData:
bar: str, bar: str,
signal: str = None, signal: str = None,
start: str = None, start: str = None,
end: str = None end: str = None,
): ):
""" """
查询MACD信号数据 查询MACD信号数据
@ -275,7 +279,7 @@ class DBMarketData:
signal: str = None, signal: str = None,
pattern: str = None, pattern: str = None,
start: str = None, start: str = None,
end: str = None end: str = None,
): ):
""" """
查询KDJ信号数据 查询KDJ信号数据
@ -325,7 +329,7 @@ class DBMarketData:
long_short: str = None, long_short: str = None,
divergence: str = None, divergence: str = None,
start: str = None, start: str = None,
end: str = None end: str = None,
): ):
""" """
查询均线信号数据 查询均线信号数据
@ -378,7 +382,7 @@ class DBMarketData:
signal: str = None, signal: str = None,
pattern: str = None, pattern: str = None,
start: str = None, start: str = None,
end: str = None end: str = None,
): ):
""" """
查询布林带信号数据 查询布林带信号数据
@ -421,11 +425,7 @@ class DBMarketData:
return self.db_manager.query_data(sql, condition_dict, return_multi=True) return self.db_manager.query_data(sql, condition_dict, return_multi=True)
def get_technical_statistics( def get_technical_statistics(
self, self, symbol: str, bar: str, start: str = None, end: str = None
symbol: str,
bar: str,
start: str = None,
end: str = None
): ):
""" """
获取技术指标统计信息 获取技术指标统计信息
@ -473,7 +473,15 @@ class DBMarketData:
return self.db_manager.query_data(sql, condition_dict, return_multi=False) return self.db_manager.query_data(sql, condition_dict, return_multi=False)
def query_market_data_by_symbol_bar(self, symbol: str, bar: str, fields: list = None, start: str = None, end: str = None): def query_market_data_by_symbol_bar(
self,
symbol: str,
bar: str,
fields: list = None,
start: str = None,
end: str = None,
table_name: str = "crypto_market_data",
):
""" """
根据交易对和K线周期查询数据 根据交易对和K线周期查询数据
:param symbol: 交易对 :param symbol: 交易对
@ -487,7 +495,7 @@ class DBMarketData:
fields_str = ", ".join(fields) fields_str = ", ".join(fields)
if start is None and end is None: if start is None and end is None:
sql = f""" sql = f"""
SELECT {fields_str} FROM crypto_market_data SELECT {fields_str} FROM {table_name}
WHERE symbol = :symbol AND bar = :bar WHERE symbol = :symbol AND bar = :bar
ORDER BY timestamp ASC ORDER BY timestamp ASC
""" """
@ -507,21 +515,26 @@ class DBMarketData:
if start > end: if start > end:
start, end = end, start start, end = end, start
sql = f""" sql = f"""
SELECT {fields_str} FROM crypto_market_data SELECT {fields_str} FROM {table_name}
WHERE symbol = :symbol AND bar = :bar AND timestamp BETWEEN :start AND :end WHERE symbol = :symbol AND bar = :bar AND timestamp BETWEEN :start AND :end
ORDER BY timestamp ASC ORDER BY timestamp ASC
""" """
condition_dict = {"symbol": symbol, "bar": bar, "start": start, "end": end} condition_dict = {
"symbol": symbol,
"bar": bar,
"start": start,
"end": end,
}
elif start is not None: elif start is not None:
sql = f""" sql = f"""
SELECT {fields_str} FROM crypto_market_data SELECT {fields_str} FROM {table_name}
WHERE symbol = :symbol AND bar = :bar AND timestamp >= :start WHERE symbol = :symbol AND bar = :bar AND timestamp >= :start
ORDER BY timestamp ASC ORDER BY timestamp ASC
""" """
condition_dict = {"symbol": symbol, "bar": bar, "start": start} condition_dict = {"symbol": symbol, "bar": bar, "start": start}
elif end is not None: elif end is not None:
sql = f""" sql = f"""
SELECT {fields_str} FROM crypto_market_data SELECT {fields_str} FROM {table_name}
WHERE symbol = :symbol AND bar = :bar AND timestamp <= :end WHERE symbol = :symbol AND bar = :bar AND timestamp <= :end
ORDER BY timestamp ASC ORDER BY timestamp ASC
""" """

View File

@ -12,7 +12,7 @@ from openpyxl.drawing.image import Image
import openpyxl import openpyxl
from openpyxl.styles import Font from openpyxl.styles import Font
from PIL import Image as PILImage from PIL import Image as PILImage
from config import OKX_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE from config import OKX_MONITOR_CONFIG, COIN_MYSQL_CONFIG, WINDOW_SIZE
from core.db.db_market_data import DBMarketData from core.db.db_market_data import DBMarketData
from core.db.db_huge_volume_data import DBHugeVolumeData from core.db.db_huge_volume_data import DBHugeVolumeData
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
@ -25,13 +25,13 @@ logger = logging.logger
class PriceVolumeStats: class PriceVolumeStats:
def __init__(self): def __init__(self):
mysql_user = MYSQL_CONFIG.get("user", "xch") mysql_user = COIN_MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "") mysql_password = COIN_MYSQL_CONFIG.get("password", "")
if not mysql_password: if not mysql_password:
raise ValueError("MySQL password is not set") raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost") mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306) mysql_port = COIN_MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx") mysql_database = COIN_MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_market_data = DBMarketData(self.db_url) self.db_market_data = DBMarketData(self.db_url)

View File

@ -17,11 +17,16 @@ from PIL import Image as PILImage
from config import ( from config import (
OKX_MONITOR_CONFIG, OKX_MONITOR_CONFIG,
US_STOCK_MONITOR_CONFIG, US_STOCK_MONITOR_CONFIG,
MYSQL_CONFIG, COIN_MYSQL_CONFIG,
A_MYSQL_CONFIG,
WINDOW_SIZE, WINDOW_SIZE,
BINANCE_MONITOR_CONFIG, BINANCE_MONITOR_CONFIG,
A_STOCK_MONITOR_CONFIG,
A_INDEX_MONITOR_CONFIG,
) )
from core.biz.metrics_calculation import MetricsCalculation
from core.db.db_market_data import DBMarketData from core.db.db_market_data import DBMarketData
from core.db.db_astock import DBAStockData
from core.db.db_binance_data import DBBinanceData from core.db.db_binance_data import DBBinanceData
from core.db.db_huge_volume_data import DBHugeVolumeData from core.db.db_huge_volume_data import DBHugeVolumeData
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
@ -44,21 +49,40 @@ class MaBreakStatistics:
def __init__( def __init__(
self, self,
is_us_stock: bool = False, is_us_stock: bool = False,
is_binance: bool = False, is_astock: bool = False,
is_aindex: bool = False,
is_binance: bool = True,
buy_by_long_period: dict = {"by_week": False, "by_month": False},
long_period_condition: dict = {"ma5>ma10": True, "ma10>ma20": False, "macd_diff>0": True, "macd>0": True},
commission_per_share: float = 0.0008, commission_per_share: float = 0.0008,
): ):
mysql_user = MYSQL_CONFIG.get("user", "xch") if is_astock or is_aindex:
mysql_password = MYSQL_CONFIG.get("password", "") mysql_user = A_MYSQL_CONFIG.get("user", "root")
mysql_password = A_MYSQL_CONFIG.get("password", "")
if not mysql_password: if not mysql_password:
raise ValueError("MySQL password is not set") raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost") mysql_host = A_MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306) mysql_port = A_MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx") mysql_database = A_MYSQL_CONFIG.get("database", "astock")
else:
mysql_user = COIN_MYSQL_CONFIG.get("user", "xch")
mysql_password = COIN_MYSQL_CONFIG.get("password", "")
if not mysql_password:
raise ValueError("MySQL password is not set")
mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost")
mysql_port = COIN_MYSQL_CONFIG.get("port", 3306)
mysql_database = COIN_MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_huge_volume_data = DBHugeVolumeData(self.db_url) self.db_huge_volume_data = DBHugeVolumeData(self.db_url)
self.is_us_stock = is_us_stock self.is_us_stock = is_us_stock
self.is_astock = is_astock
self.is_aindex = is_aindex
if self.is_us_stock:
self.date_time_field = "date_time_us"
else:
self.date_time_field = "date_time"
self.is_binance = is_binance self.is_binance = is_binance
if is_us_stock: if is_us_stock:
self.symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get( self.symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
@ -71,6 +95,28 @@ class MaBreakStatistics:
"initial_date", "2014-11-30 00:00:00" "initial_date", "2014-11-30 00:00:00"
) )
self.db_market_data = DBMarketData(self.db_url) self.db_market_data = DBMarketData(self.db_url)
elif is_astock:
self.symbols = A_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["000001.SH"]
)
self.bars = A_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"bars", ["5m"]
)
self.initial_date = A_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
"initial_date", "2014-11-30 00:00:00"
)
self.db_market_data = DBAStockData(self.db_url)
elif is_aindex:
self.symbols = A_INDEX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"symbols", ["000001.SH"]
)
self.bars = A_INDEX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"bars", ["5m"]
)
self.initial_date = A_INDEX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"initial_date", "2014-11-30 00:00:00"
)
self.db_market_data = DBAStockData(self.db_url)
else: else:
if is_binance: if is_binance:
self.symbols = BINANCE_MONITOR_CONFIG.get("volume_monitor", {}).get( self.symbols = BINANCE_MONITOR_CONFIG.get("volume_monitor", {}).get(
@ -98,12 +144,36 @@ class MaBreakStatistics:
self.commission_per_share = commission_per_share self.commission_per_share = commission_per_share
self.trade_strategy_config = self.get_trade_strategy_config() self.trade_strategy_config = self.get_trade_strategy_config()
self.main_strategy = self.trade_strategy_config.get("均线系统策略", None) self.main_strategy = self.trade_strategy_config.get("均线系统策略", None)
self.buy_by_long_period = buy_by_long_period
self.long_period_condition = long_period_condition
def get_trade_strategy_config(self): def get_trade_strategy_config(self):
with open("./json/trade_strategy.json", "r", encoding="utf-8") as f: with open("./json/trade_strategy.json", "r", encoding="utf-8") as f:
trade_strategy_config = json.load(f) trade_strategy_config = json.load(f)
return trade_strategy_config return trade_strategy_config
def get_by_long_period_desc(self):
by_week = self.buy_by_long_period.get("by_week", False)
by_month = self.buy_by_long_period.get("by_month", False)
by_long_period = ""
if by_week:
by_long_period += "1W"
if by_month:
by_long_period += "1M"
if by_long_period == "":
return "no_long_period_judge"
by_condition = ""
if self.long_period_condition.get("ma5>ma10", False):
by_condition += "ma5gtma10"
if self.long_period_condition.get("ma10>ma20", False):
by_condition += "_ma10gtma20"
if self.long_period_condition.get("macd_diff>0", False):
by_condition += "_macd_diffgt0"
if self.long_period_condition.get("macd>0", False):
by_condition += "_macdgt0"
return by_long_period + "_" + by_condition
def batch_statistics(self, strategy_name: str = "全均线策略"): def batch_statistics(self, strategy_name: str = "全均线策略"):
if self.is_us_stock: if self.is_us_stock:
self.stats_output_dir = ( self.stats_output_dir = (
@ -119,6 +189,38 @@ class MaBreakStatistics:
self.stats_chart_dir = ( self.stats_chart_dir = (
f"./output/trade_sandbox/ma_strategy/binance/chart/{strategy_name}/" f"./output/trade_sandbox/ma_strategy/binance/chart/{strategy_name}/"
) )
elif self.is_astock:
long_period_desc = self.get_by_long_period_desc()
if len(long_period_desc) > 0:
self.stats_output_dir = (
f"./output/trade_sandbox/ma_strategy/astock/{long_period_desc}/excel/{strategy_name}/"
)
self.stats_chart_dir = (
f"./output/trade_sandbox/ma_strategy/astock/{long_period_desc}/chart/{strategy_name}/"
)
else:
self.stats_output_dir = (
f"./output/trade_sandbox/ma_strategy/astock/excel/{strategy_name}/"
)
self.stats_chart_dir = (
f"./output/trade_sandbox/ma_strategy/astock/chart/{strategy_name}/"
)
elif self.is_aindex:
long_period_desc = self.get_by_long_period_desc()
if len(long_period_desc) > 0:
self.stats_output_dir = (
f"./output/trade_sandbox/ma_strategy/aindex/{long_period_desc}/excel/{strategy_name}/"
)
self.stats_chart_dir = (
f"./output/trade_sandbox/ma_strategy/aindex/{long_period_desc}/chart/{strategy_name}/"
)
else:
self.stats_output_dir = (
f"./output/trade_sandbox/ma_strategy/aindex/excel/{strategy_name}/"
)
self.stats_chart_dir = (
f"./output/trade_sandbox/ma_strategy/aindex/chart/{strategy_name}/"
)
else: else:
self.stats_output_dir = ( self.stats_output_dir = (
f"./output/trade_sandbox/ma_strategy/okx/excel/{strategy_name}/" f"./output/trade_sandbox/ma_strategy/okx/excel/{strategy_name}/"
@ -192,9 +294,11 @@ class MaBreakStatistics:
total_commission = round(total_commission, 4) total_commission = round(total_commission, 4)
total_buy_commission = round(total_buy_commission, 4) total_buy_commission = round(total_buy_commission, 4)
total_sell_commission = round(total_sell_commission, 4) total_sell_commission = round(total_sell_commission, 4)
symbol_name = str(symbol_bar_data["symbol_name"].iloc[0])
account_value_chg_list.append({ account_value_chg_list.append({
"strategy_name": strategy_name, "strategy_name": strategy_name,
"symbol": symbol, "symbol": symbol,
"symbol_name": symbol_name,
"bar": bar, "bar": bar,
"total_buy_commission": total_buy_commission, "total_buy_commission": total_buy_commission,
"total_sell_commission": total_sell_commission, "total_sell_commission": total_sell_commission,
@ -209,6 +313,7 @@ class MaBreakStatistics:
[ [
"strategy_name", "strategy_name",
"symbol", "symbol",
"symbol_name",
"bar", "bar",
"total_buy_commission", "total_buy_commission",
"total_sell_commission", "total_sell_commission",
@ -221,7 +326,7 @@ class MaBreakStatistics:
] ]
account_value_statistics_df = ( account_value_statistics_df = (
ma_break_market_data.groupby(["symbol", "bar"])["end_account_value"] ma_break_market_data.groupby(["symbol", "symbol_name", "bar"])["end_account_value"]
.agg( .agg(
account_value_max="max", account_value_max="max",
account_value_min="min", account_value_min="min",
@ -237,6 +342,7 @@ class MaBreakStatistics:
[ [
"strategy_name", "strategy_name",
"symbol", "symbol",
"symbol_name",
"bar", "bar",
"account_value_max", "account_value_max",
"account_value_min", "account_value_min",
@ -249,7 +355,7 @@ class MaBreakStatistics:
# 依据symbol和bar分组统计每个symbol和bar的interval_minutes的max, min, mean, std, median, count # 依据symbol和bar分组统计每个symbol和bar的interval_minutes的max, min, mean, std, median, count
interval_minutes_df = ( interval_minutes_df = (
ma_break_market_data.groupby(["symbol", "bar"])["interval_minutes"] ma_break_market_data.groupby(["symbol", "symbol_name", "bar"])["interval_minutes"]
.agg( .agg(
interval_minutes_max="max", interval_minutes_max="max",
interval_minutes_min="min", interval_minutes_min="min",
@ -265,6 +371,7 @@ class MaBreakStatistics:
[ [
"strategy_name", "strategy_name",
"symbol", "symbol",
"symbol_name",
"bar", "bar",
"interval_minutes_max", "interval_minutes_max",
"interval_minutes_min", "interval_minutes_min",
@ -335,6 +442,23 @@ class MaBreakStatistics:
strategy_info["买入策略"] = buy_and_text + " 或者 \n" + buy_or_text strategy_info["买入策略"] = buy_and_text + " 或者 \n" + buy_or_text
else: else:
strategy_info["买入策略"] = buy_and_text strategy_info["买入策略"] = buy_and_text
# 假如根据长周期判断买入,则需要设置长周期策略
by_week = self.buy_by_long_period.get("by_week", False)
by_month = self.buy_by_long_period.get("by_month", False)
if by_week:
strategy_info["买入策略"] += "根据周线指标,\n"
if by_month:
strategy_info["买入策略"] += "根据月线指标,\n"
if self.long_period_condition.get("ma5>ma10", False):
strategy_info["买入策略"] += "ma5>ma10, \n"
if self.long_period_condition.get("ma10>ma20", False):
strategy_info["买入策略"] += "ma10>ma20, \n"
if self.long_period_condition.get("macd_diff>0", False):
strategy_info["买入策略"] += "macd_diff>0, \n"
if self.long_period_condition.get("macd>0", False):
strategy_info["买入策略"] += "macd>0, \n"
strategy_info["买入策略"] = strategy_info["买入策略"].strip()
sell_dict = strategy_config.get("sell", {}) sell_dict = strategy_config.get("sell", {})
sell_and_list = sell_dict.get("and", []) sell_and_list = sell_dict.get("and", [])
sell_or_list = sell_dict.get("or", []) sell_or_list = sell_dict.get("or", [])
@ -349,6 +473,7 @@ class MaBreakStatistics:
strategy_info["卖出策略"] = sell_and_text + " 或者 \n" + sell_or_text strategy_info["卖出策略"] = sell_and_text + " 或者 \n" + sell_or_text
else: else:
strategy_info["卖出策略"] = sell_and_text strategy_info["卖出策略"] = sell_and_text
strategy_info["卖出策略"] = strategy_info["卖出策略"].strip()
# 将strategy_info转换为pd.DataFrame # 将strategy_info转换为pd.DataFrame
strategy_info_df = pd.DataFrame([strategy_info]) strategy_info_df = pd.DataFrame([strategy_info])
return strategy_info_df return strategy_info_df
@ -384,22 +509,24 @@ class MaBreakStatistics:
market_data.reset_index(drop=True, inplace=True) market_data.reset_index(drop=True, inplace=True)
ma_break_market_data_pair_list = [] ma_break_market_data_pair_list = []
ma_break_market_data_pair = {} ma_break_market_data_pair = {}
if self.is_us_stock:
date_time_field = "date_time_us"
else:
date_time_field = "date_time"
close_mean = market_data["close"].mean() close_mean = market_data["close"].mean()
self.update_initial_capital(close_mean) self.update_initial_capital(close_mean)
logger.info( logger.info(
f"成功获取{symbol}数据:{len(market_data)}{bar}K线,开始日期={market_data[date_time_field].min()},结束日期={market_data[date_time_field].max()}" f"成功获取{symbol}数据:{len(market_data)}{bar}K线,开始日期={market_data[self.date_time_field].min()},结束日期={market_data[self.date_time_field].max()}"
) )
account_value = self.initial_capital account_value = self.initial_capital
for index, row in market_data.iterrows(): for index, row in market_data.iterrows():
if self.is_astock:
symbol_name = row["symbol_name"]
elif self.is_aindex:
symbol_name = row["symbol_name"]
else:
symbol_name = row["symbol"]
ma_cross = row["ma_cross"] ma_cross = row["ma_cross"]
timestamp = row["timestamp"] timestamp = row["timestamp"]
date_time = row[date_time_field] date_time = row[self.date_time_field]
close = row["close"] close = row["close"]
ma5 = row["ma5"] ma5 = row["ma5"]
ma10 = row["ma10"] ma10 = row["ma10"]
@ -411,7 +538,6 @@ class MaBreakStatistics:
if ma_break_market_data_pair.get("begin_timestamp", None) is None: if ma_break_market_data_pair.get("begin_timestamp", None) is None:
buy_condition = self.fit_strategy( buy_condition = self.fit_strategy(
strategy_name=strategy_name, strategy_name=strategy_name,
market_data=market_data,
row=row, row=row,
behavior="buy", behavior="buy",
) )
@ -431,6 +557,7 @@ class MaBreakStatistics:
ma_break_market_data_pair = {} ma_break_market_data_pair = {}
ma_break_market_data_pair["symbol"] = symbol ma_break_market_data_pair["symbol"] = symbol
ma_break_market_data_pair["symbol_name"] = symbol_name
ma_break_market_data_pair["bar"] = bar ma_break_market_data_pair["bar"] = bar
ma_break_market_data_pair["begin_timestamp"] = timestamp ma_break_market_data_pair["begin_timestamp"] = timestamp
ma_break_market_data_pair["begin_date_time"] = date_time ma_break_market_data_pair["begin_date_time"] = date_time
@ -449,12 +576,12 @@ class MaBreakStatistics:
else: else:
sell_condition = self.fit_strategy( sell_condition = self.fit_strategy(
strategy_name=strategy_name, strategy_name=strategy_name,
market_data=market_data,
row=row, row=row,
behavior="sell", behavior="sell",
) )
if sell_condition: if sell_condition or index == len(market_data) - 1:
# 达到卖出条件或者最后一条数据,则卖出
shares = ma_break_market_data_pair["shares"] shares = ma_break_market_data_pair["shares"]
entry_price = ma_break_market_data_pair["begin_close"] entry_price = ma_break_market_data_pair["begin_close"]
exit_price = close exit_price = close
@ -525,8 +652,10 @@ class MaBreakStatistics:
* 100 * 100
) )
pct_chg = round(pct_chg, 4) pct_chg = round(pct_chg, 4)
symbol_name = ma_break_market_data["symbol_name"].iloc[0]
market_data_pct_chg = { market_data_pct_chg = {
"symbol": symbol, "symbol": symbol,
"symbol_name": symbol_name,
"bar": bar, "bar": bar,
"pct_chg": pct_chg, "pct_chg": pct_chg,
"initial_capital": self.initial_capital, "initial_capital": self.initial_capital,
@ -604,6 +733,50 @@ class MaBreakStatistics:
data = pd.DataFrame() data = pd.DataFrame()
start_date = datetime.strptime(self.initial_date, "%Y-%m-%d") start_date = datetime.strptime(self.initial_date, "%Y-%m-%d")
end_date = datetime.strptime(self.end_date, "%Y-%m-%d") + timedelta(days=1) end_date = datetime.strptime(self.end_date, "%Y-%m-%d") + timedelta(days=1)
table_name = ""
if self.is_astock:
if bar == "1D":
table_name = "stock_daily_price_from_2021"
elif bar == "1W":
table_name = "stock_weekly_price_from_2020"
elif bar == "1M":
table_name = "stock_monthly_price_from_2015"
elif self.is_aindex:
if bar == "1D":
table_name = "index_daily_price_from_2021"
elif bar == "1W":
table_name = "index_weekly_price_from_2020"
elif bar == "1M":
table_name = "index_monthly_price_from_2015"
elif self.is_us_stock:
table_name = "crypto_market_data"
elif self.is_binance:
table_name = "crypto_binance_data"
else:
table_name = "crypto_binance_data"
if self.is_astock or self.is_aindex:
fields = [
"a.ts_code as symbol",
"b.name as symbol_name",
f"'{bar}' as bar",
"0 as timestamp",
"trade_date as date_time",
"open",
"high",
"low",
"close",
"vol as volume",
"MA5 as ma5",
"MA10 as ma10",
"MA20 as ma20",
"MA30 as ma30",
"均线交叉 as ma_cross",
"DIF as dif",
"DEA as dea",
"MACD as macd",
]
else:
fields = [ fields = [
"symbol", "symbol",
"bar", "bar",
@ -631,25 +804,120 @@ class MaBreakStatistics:
current_end_date_str = current_end_date.strftime("%Y-%m-%d") current_end_date_str = current_end_date.strftime("%Y-%m-%d")
logger.info(f"获取{symbol}数据:{start_date_str}{current_end_date_str}") logger.info(f"获取{symbol}数据:{start_date_str}{current_end_date_str}")
current_data = self.db_market_data.query_market_data_by_symbol_bar( current_data = self.db_market_data.query_market_data_by_symbol_bar(
symbol, bar, fields, start=start_date_str, end=current_end_date_str symbol, bar, fields, start=start_date_str, end=current_end_date_str, table_name=table_name
) )
if current_data is not None and len(current_data) > 0: if current_data is not None and len(current_data) > 0:
current_data = pd.DataFrame(current_data) current_data = pd.DataFrame(current_data)
data = pd.concat([data, current_data]) data = pd.concat([data, current_data])
start_date = current_end_date start_date = current_end_date
data.drop_duplicates(inplace=True) data.drop_duplicates(inplace=True)
if self.is_us_stock: data.sort_values(by=self.date_time_field, inplace=True)
date_time_field = "date_time_us"
else:
date_time_field = "date_time"
data.sort_values(by=date_time_field, inplace=True)
data.reset_index(drop=True, inplace=True) data.reset_index(drop=True, inplace=True)
if self.is_astock or self.is_aindex:
data = self.update_data(data)
return data return data
def get_long_period_data(self, symbol: str, bar: str, end_date: str):
"""
获取长周期数据
:param data: 数据
:return: 长周期数据
"""
if not (self.is_astock or self.is_aindex):
return None
table_name = ""
if self.is_astock:
if bar == "1M":
table_name = "stock_monthly_price_from_2015"
elif bar == "1W":
table_name = "stock_weekly_price_from_2020"
else:
pass
elif self.is_aindex:
if bar == "1M":
table_name = "index_monthly_price_from_2015"
elif bar == "1W":
table_name = "index_weekly_price_from_2020"
else:
pass
if len(end_date) != 10:
end_date = self.change_date_format(end_date)
if bar == "1M":
# 获取上两个月的日期
last_date = datetime.strptime(end_date, "%Y-%m-%d") - timedelta(days=60)
last_date = last_date.strftime("%Y-%m-%d")
elif bar == "1W":
# 获取上两周的日期
last_date = datetime.strptime(end_date, "%Y-%m-%d") - timedelta(days=14)
last_date = last_date.strftime("%Y-%m-%d")
else:
last_date = None
if len(table_name) == 0 or last_date is None:
return None
fields = [
"a.ts_code as symbol",
"b.name as symbol_name",
f"'{bar}' as bar",
"0 as timestamp",
"trade_date as date_time",
"open",
"high",
"low",
"close",
"vol as volume",
"MA5 as ma5",
"MA10 as ma10",
"MA20 as ma20",
"MA30 as ma30",
"均线交叉 as ma_cross",
"DIF as dif",
"DEA as dea",
"MACD as macd",
]
data = self.db_market_data.query_market_data_by_symbol_bar(
symbol, bar, fields, start=last_date, end=end_date, table_name=table_name
)
if data is not None and len(data) > 0:
data = pd.DataFrame(data)
data.sort_values(by="date_time", inplace=True)
latest_row = data.iloc[-1]
if (latest_row["ma5"] is None or
latest_row["ma10"] is None or
latest_row["ma20"] is None or
latest_row["dif"] is None or
latest_row["macd"] is None):
return None
return latest_row
else:
return None
def update_data(self, data: pd.DataFrame):
"""
更新数据
1. 将date_time列中的20210104这种格式替换为2021-01-04的格式
2. 将date_time转换为timestamp并更新timestamp列
3. 通过MetricsCalculation的ma5102030方法更新ma_cross列
:param data: 数据
:return: 更新后的数据
"""
data["date_time"] = data["date_time"].apply(lambda x: self.change_date_format(x))
data["timestamp"] = data["date_time"].apply(lambda x: transform_date_time_to_timestamp(x))
metrics_calculation = MetricsCalculation()
data = metrics_calculation.ma5102030(data)
return data
def change_date_format(self, date_text: str):
# 将20210104这种格式替换为2021-01-04的格式
if len(date_text) == 8:
return date_text[0:4] + "-" + date_text[4:6] + "-" + date_text[6:8]
else:
return date_text
def fit_strategy( def fit_strategy(
self, self,
strategy_name: str = "全均线策略", strategy_name: str = "全均线策略",
market_data: pd.DataFrame = None,
row: pd.Series = None, row: pd.Series = None,
behavior: str = "buy", behavior: str = "buy",
): ):
@ -661,6 +929,45 @@ class MaBreakStatistics:
if condition_dict is None: if condition_dict is None:
logger.error(f"策略{strategy_name}{behavior}条件不存在") logger.error(f"策略{strategy_name}{behavior}条件不存在")
return False return False
and_list = condition_dict.get("and", [])
condition = True
condition = self.get_judge_result(row, and_list, "and", condition)
or_list = condition_dict.get("or", [])
condition = self.get_judge_result(row, or_list, "or", condition)
if behavior == "buy" and condition:
# 如果满足条件,则判断是否根据长周期指标买入
bar = row["bar"]
if (self.is_astock or self.is_aindex) and bar == "1D":
date_time = row["date_time"]
long_period_condition_list = []
if self.long_period_condition.get("ma5>ma10", False):
long_period_condition_list.append("ma5>ma10")
if self.long_period_condition.get("ma10>ma20", False):
long_period_condition_list.append("ma10>ma20")
if self.long_period_condition.get("macd_diff>0", False):
long_period_condition_list.append("macd_diff>0")
if self.long_period_condition.get("macd>0", False):
long_period_condition_list.append("macd>0")
if len(long_period_condition_list) > 0:
if self.buy_by_long_period.get("by_week", False):
long_period_data = self.get_long_period_data(row["symbol"], "1W", date_time)
if long_period_data is not None:
condition = self.get_judge_result(long_period_data, long_period_condition_list, "and", condition)
if not condition:
logger.info(f"根据周线指标,{row['symbol']}不满足买入条件")
if self.buy_by_long_period.get("by_month", False):
long_period_data = self.get_long_period_data(row["symbol"], "1M", date_time)
if long_period_data is not None:
condition = self.get_judge_result(long_period_data, long_period_condition_list, "and", condition)
if not condition:
logger.info(f"根据月线指标,{row['symbol']}不满足买入条件")
return condition
def get_judge_result(self, row: pd.Series, condition_list: list, and_or: str = "and", raw_condition: bool = True):
ma_cross = row["ma_cross"] ma_cross = row["ma_cross"]
if pd.isna(ma_cross) or ma_cross is None: if pd.isna(ma_cross) or ma_cross is None:
ma_cross = "" ma_cross = ""
@ -670,107 +977,107 @@ class MaBreakStatistics:
ma20 = float(row["ma20"]) ma20 = float(row["ma20"])
ma30 = float(row["ma30"]) ma30 = float(row["ma30"])
close = float(row["close"]) close = float(row["close"])
if "volume_pct_chg" in list(row.index) and row["volume_pct_chg"] is not None:
volume_pct_chg = float(row["volume_pct_chg"]) volume_pct_chg = float(row["volume_pct_chg"])
else:
volume_pct_chg = None
macd_diff = float(row["dif"]) macd_diff = float(row["dif"])
macd_dea = float(row["dea"]) macd_dea = float(row["dea"])
macd = float(row["macd"]) macd = float(row["macd"])
if and_or == "and":
and_list = condition_dict.get("and", []) for and_condition in condition_list:
condition = True
for and_condition in and_list:
if and_condition == "5上穿10": if and_condition == "5上穿10":
condition = condition and ("5上穿10" in ma_cross) raw_condition = raw_condition and ("5上穿10" in ma_cross)
elif and_condition == "10上穿20": elif and_condition == "10上穿20":
condition = condition and ("10上穿20" in ma_cross) raw_condition = raw_condition and ("10上穿20" in ma_cross)
elif and_condition == "20上穿30": elif and_condition == "20上穿30":
condition = condition and ("20上穿30" in ma_cross) raw_condition = raw_condition and ("20上穿30" in ma_cross)
elif and_condition == "ma5>ma10": elif and_condition == "ma5>ma10":
condition = condition and (ma5 > ma10) raw_condition = raw_condition and (ma5 > ma10)
elif and_condition == "ma10>ma20": elif and_condition == "ma10>ma20":
condition = condition and (ma10 > ma20) raw_condition = raw_condition and (ma10 > ma20)
elif and_condition == "ma20>ma30": elif and_condition == "ma20>ma30":
condition = condition and (ma20 > ma30) raw_condition = raw_condition and (ma20 > ma30)
elif and_condition == "close>ma20": elif and_condition == "close>ma20":
condition = condition and (close > ma20) raw_condition = raw_condition and (close > ma20)
elif and_condition == "volume_pct_chg>0.2": elif and_condition == "volume_pct_chg>0.2" and volume_pct_chg is not None:
condition = condition and (volume_pct_chg > 0.2) raw_condition = raw_condition and (volume_pct_chg > 0.2)
elif and_condition == "macd_diff>0": elif and_condition == "macd_diff>0":
condition = condition and (macd_diff > 0) raw_condition = raw_condition and (macd_diff > 0)
elif and_condition == "macd_dea>0": elif and_condition == "macd_dea>0":
condition = condition and (macd_dea > 0) raw_condition = raw_condition and (macd_dea > 0)
elif and_condition == "macd>0": elif and_condition == "macd>0":
condition = condition and (macd > 0) raw_condition = raw_condition and (macd > 0)
elif and_condition == "10下穿5": elif and_condition == "10下穿5":
condition = condition and ("10下穿5" in ma_cross) raw_condition = raw_condition and ("10下穿5" in ma_cross)
elif and_condition == "20下穿10": elif and_condition == "20下穿10":
condition = condition and ("20下穿10" in ma_cross) raw_condition = raw_condition and ("20下穿10" in ma_cross)
elif and_condition == "30下穿20": elif and_condition == "30下穿20":
condition = condition and ("30下穿20" in ma_cross) raw_condition = raw_condition and ("30下穿20" in ma_cross)
elif and_condition == "ma5<ma10": elif and_condition == "ma5<ma10":
condition = condition and (ma5 < ma10) raw_condition = raw_condition and (ma5 < ma10)
elif and_condition == "ma10<ma20": elif and_condition == "ma10<ma20":
condition = condition and (ma10 < ma20) raw_condition = raw_condition and (ma10 < ma20)
elif and_condition == "ma20<ma30": elif and_condition == "ma20<ma30":
condition = condition and (ma20 < ma30) raw_condition = raw_condition and (ma20 < ma30)
elif and_condition == "close<ma20": elif and_condition == "close<ma20":
condition = condition and (close < ma20) raw_condition = raw_condition and (close < ma20)
elif and_condition == "macd_diff<0": elif and_condition == "macd_diff<0":
condition = condition and (macd_diff < 0) raw_condition = raw_condition and (macd_diff < 0)
elif and_condition == "macd_dea<0": elif and_condition == "macd_dea<0":
condition = condition and (macd_dea < 0) raw_condition = raw_condition and (macd_dea < 0)
elif and_condition == "macd<0": elif and_condition == "macd<0":
condition = condition and (macd < 0) raw_condition = raw_condition and (macd < 0)
else: else:
pass pass
if not condition: elif and_or == "or":
or_list = condition_dict.get("or", []) for or_condition in condition_list:
for or_condition in or_list:
if or_condition == "5上穿10": if or_condition == "5上穿10":
condition = condition or ("5上穿10" in ma_cross) raw_condition = raw_condition or ("5上穿10" in ma_cross)
elif or_condition == "10上穿20": elif or_condition == "10上穿20":
condition = condition or ("10上穿20" in ma_cross) raw_condition = raw_condition or ("10上穿20" in ma_cross)
elif or_condition == "20上穿30": elif or_condition == "20上穿30":
condition = condition or ("20上穿30" in ma_cross) raw_condition = raw_condition or ("20上穿30" in ma_cross)
elif or_condition == "ma5>ma10": elif or_condition == "ma5>ma10":
condition = condition or (ma5 > ma10) raw_condition = raw_condition or (ma5 > ma10)
elif or_condition == "ma10>ma20": elif or_condition == "ma10>ma20":
condition = condition or (ma10 > ma20) raw_condition = raw_condition or (ma10 > ma20)
elif or_condition == "ma20>ma30": elif or_condition == "ma20>ma30":
condition = condition or (ma20 > ma30) raw_condition = raw_condition or (ma20 > ma30)
elif or_condition == "close>ma20": elif or_condition == "close>ma20":
condition = condition or (close > ma20) raw_condition = raw_condition or (close > ma20)
elif or_condition == "volume_pct_chg>0.2": elif or_condition == "volume_pct_chg>0.2" and volume_pct_chg is not None:
condition = condition or (volume_pct_chg > 0.2) raw_condition = raw_condition or (volume_pct_chg > 0.2)
elif or_condition == "macd_diff>0": elif or_condition == "macd_diff>0":
condition = condition or (macd_diff > 0) raw_condition = raw_condition or (macd_diff > 0)
elif or_condition == "macd_dea>0": elif or_condition == "macd_dea>0":
condition = condition or (macd_dea > 0) raw_condition = raw_condition or (macd_dea > 0)
elif or_condition == "macd>0": elif or_condition == "macd>0":
condition = condition or (macd > 0) raw_condition = raw_condition or (macd > 0)
elif or_condition == "10下穿5": elif or_condition == "10下穿5":
condition = condition or ("10下穿5" in ma_cross) raw_condition = raw_condition or ("10下穿5" in ma_cross)
elif or_condition == "20下穿10": elif or_condition == "20下穿10":
condition = condition or ("20下穿10" in ma_cross) raw_condition = raw_condition or ("20下穿10" in ma_cross)
elif or_condition == "30下穿20": elif or_condition == "30下穿20":
condition = condition or ("30下穿20" in ma_cross) raw_condition = raw_condition or ("30下穿20" in ma_cross)
elif or_condition == "ma5<ma10": elif or_condition == "ma5<ma10":
condition = condition or (ma5 < ma10) raw_condition = raw_condition or (ma5 < ma10)
elif or_condition == "ma10<ma20": elif or_condition == "ma10<ma20":
condition = condition or (ma10 < ma20) raw_condition = raw_condition or (ma10 < ma20)
elif or_condition == "ma20<ma30": elif or_condition == "ma20<ma30":
condition = condition or (ma20 < ma30) raw_condition = raw_condition or (ma20 < ma30)
elif or_condition == "close<ma20": elif or_condition == "close<ma20":
condition = condition or (close < ma20) raw_condition = raw_condition or (close < ma20)
elif or_condition == "macd_diff<0": elif or_condition == "macd_diff<0":
condition = condition or (macd_diff < 0) raw_condition = raw_condition or (macd_diff < 0)
elif or_condition == "macd_dea<0": elif or_condition == "macd_dea<0":
condition = condition or (macd_dea < 0) raw_condition = raw_condition or (macd_dea < 0)
elif or_condition == "macd<0": elif or_condition == "macd<0":
condition = condition or (macd < 0) raw_condition = raw_condition or (macd < 0)
else: else:
pass pass
return condition return raw_condition
def draw_quant_pct_chg_bar_chart( def draw_quant_pct_chg_bar_chart(
self, data: pd.DataFrame, strategy_name: str = "全均线策略" self, data: pd.DataFrame, strategy_name: str = "全均线策略"
@ -812,8 +1119,8 @@ class MaBreakStatistics:
width = 0.35 # 柱状图宽度 width = 0.35 # 柱状图宽度
# 确保symbol列是字符串类型避免matplotlib警告 # 确保symbol列是字符串类型避免matplotlib警告
bar_data["symbol"] = bar_data["symbol"].astype(str) # bar_data["symbol"] = bar_data["symbol"].astype(str)
bar_data["symbol_name"] = bar_data["symbol_name"].astype(str)
# 绘制量化策略涨跌柱状图(蓝色渐变色) # 绘制量化策略涨跌柱状图(蓝色渐变色)
bars1 = plt.bar( bars1 = plt.bar(
x - width / 2, x - width / 2,
@ -840,7 +1147,7 @@ class MaBreakStatistics:
) )
plt.xlabel("Symbol", fontsize=12) plt.xlabel("Symbol", fontsize=12)
plt.ylabel("涨跌幅(%)", fontsize=12) plt.ylabel("涨跌幅(%)", fontsize=12)
plt.xticks(x, bar_data["symbol"], rotation=45, ha="right") plt.xticks(x, bar_data["symbol_name"], rotation=45, ha="right")
plt.legend() plt.legend()
plt.grid(True, alpha=0.3) plt.grid(True, alpha=0.3)
@ -907,13 +1214,13 @@ class MaBreakStatistics:
:param strategy_name: 策略名称 :param strategy_name: 策略名称
:return: None :return: None
""" """
symbols = data["symbol"].unique() symbols = data["symbol_name"].unique()
bars = data["bar"].unique() bars = data["bar"].unique()
chart_dict = {} chart_dict = {}
for symbol in symbols: for symbol in symbols:
for bar in bars: for bar in bars:
symbol_bar_data = data[ symbol_bar_data = data[
(data["symbol"] == symbol) & (data["bar"] == bar) (data["symbol_name"] == symbol) & (data["bar"] == bar)
] ]
if symbol_bar_data.empty: if symbol_bar_data.empty:
continue continue
@ -922,7 +1229,7 @@ class MaBreakStatistics:
first_row = symbol_bar_data.iloc[0].copy() first_row = symbol_bar_data.iloc[0].copy()
initial_capital = int( initial_capital = int(
market_data_pct_chg_df.loc[ market_data_pct_chg_df.loc[
(market_data_pct_chg_df["symbol"] == symbol) (market_data_pct_chg_df["symbol_name"] == symbol)
& (market_data_pct_chg_df["bar"] == bar), & (market_data_pct_chg_df["bar"] == bar),
"initial_capital", "initial_capital",
].values[0] ].values[0]

View File

@ -11,7 +11,7 @@ from openpyxl.drawing.image import Image
import openpyxl import openpyxl
from openpyxl.styles import Font from openpyxl.styles import Font
from PIL import Image as PILImage from PIL import Image as PILImage
from config import OKX_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE from config import OKX_MONITOR_CONFIG, COIN_MYSQL_CONFIG, WINDOW_SIZE
import core.logger as logging import core.logger as logging
from core.db.db_merge_market_huge_volume import DBMergeMarketHugeVolume from core.db.db_merge_market_huge_volume import DBMergeMarketHugeVolume
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
@ -24,13 +24,13 @@ logger = logging.logger
class MeanReversionSandbox: class MeanReversionSandbox:
def __init__(self, solution: str): def __init__(self, solution: str):
mysql_user = MYSQL_CONFIG.get("user", "xch") mysql_user = COIN_MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "") mysql_password = COIN_MYSQL_CONFIG.get("password", "")
if not mysql_password: if not mysql_password:
raise ValueError("MySQL password is not set") raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost") mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306) mysql_port = COIN_MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx") mysql_database = COIN_MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_merge_market_huge_volume = DBMergeMarketHugeVolume(self.db_url) self.db_merge_market_huge_volume = DBMergeMarketHugeVolume(self.db_url)

View File

@ -12,7 +12,7 @@ from PIL import Image as PILImage
from datetime import datetime, timedelta from datetime import datetime, timedelta
import core.logger as logging import core.logger as logging
from config import OKX_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE from config import OKX_MONITOR_CONFIG, COIN_MYSQL_CONFIG, WINDOW_SIZE
from core.db.db_market_data import DBMarketData from core.db.db_market_data import DBMarketData
from core.db.db_binance_data import DBBinanceData from core.db.db_binance_data import DBBinanceData
from core.db.db_huge_volume_data import DBHugeVolumeData from core.db.db_huge_volume_data import DBHugeVolumeData
@ -92,13 +92,13 @@ class ORBStrategy:
self.data = None # 存储K线数据 self.data = None # 存储K线数据
self.trades = [] # 存储交易记录 self.trades = [] # 存储交易记录
self.equity_curve = None # 存储账户净值曲线 self.equity_curve = None # 存储账户净值曲线
mysql_user = MYSQL_CONFIG.get("user", "xch") mysql_user = COIN_MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "") mysql_password = COIN_MYSQL_CONFIG.get("password", "")
if not mysql_password: if not mysql_password:
raise ValueError("MySQL password is not set") raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost") mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306) mysql_port = COIN_MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx") mysql_database = COIN_MYSQL_CONFIG.get("database", "okx")
self.is_us_stock = is_us_stock self.is_us_stock = is_us_stock
self.is_binance = is_binance self.is_binance = is_binance

View File

@ -0,0 +1,62 @@
- core/biz/huge_volume.py
- 作用: 放量(巨量)检测与后续涨跌统计的核心逻辑。
- 要点:
- 基于滑窗计算 volume_ma、volume_std、volume_threshold=均值+N倍标准差生成 huge_volume、volume_ratio、spike_intensity。
- 可选价格分位检查:对 close/high/low 计算 80/20、90/10 分位并标注高低点命中位。
- next_periods_rise_or_fall: 以未来 N 周期涨跌结果做分组统计,输出明细与汇总。
- core/biz/huge_volume_chart.py
- 作用: 将“巨量后走势统计”数据绘图(热力图/折线图)。
- 要点:
- 输入统计 DataFrame支持是否包含热力图/折线图。
- 被 `huge_volume_main.plot_huge_volume_data` 调用,输出到 `./output/huge_volume_statistics/`
- core/biz/market_data.py
- 作用: 行情获取统一封装OKX 为主Linux 环境下支持 Binance且支持美股模式
- 要点:
- get_realtime_kline_data / get_historical_kline_data统一返回 DataFrame`timestamp/date_time/date_time_us/symbol/bar/...`,自动数值化与排序。
- 历史数据分页向后抓取,含时间边界、去重与交易时段过滤(美股)。
- 提供 `get_realtime_candlesticks_from_binance/okx` 与基本 trade 聚合辅助buy_sz/sell_sz 预留)。
- core/biz/market_data_from_itick.py
- 作用: 从 iTick 等源拉取美股K线`market_data_from_itick_main.py` 使用)。
- 要点:
- 封装美股数据下载,适配 `MarketData` 历史数据流程,统一列结构。
- core/biz/market_monitor.py
- 作用: 实时监控报表生成面向企业微信推送的Markdown文案
- 要点:
- create_metrics_report: 基于一根最新K线及全量数据汇总价量、分位、MACD/KDJ/RSI/BOLL、均线多空/发散等信号,生成可读文本。
- get_last_huge_volume_record: 最近一次巨量回溯与十周期内巨量次数。
- get_long_short_over_buy_sell: 跨周期或对标BTC的多空/超买超卖对比说明。
- 依赖 `METRICS_CONFIG` 的权重与阈值映射。
- core/biz/metrics_calculation.py
- 作用: 技术指标与形态计算的总入口。
- 要点:
- 指标: pre_close/pct_chg、MACD(含金叉死叉)、KDJ(K/D/J+信号)、RSI、BOLL(上下轨与形态)、SAR(多/空/观望)。
- 均线: ma5/10/20/30、交叉组合信号、价格-均线相对百分比、`ma_long_short`(多/空/震荡)与 `ma_divergence`(发散/粘合等提供多策略判定weighted_voting/trend_strength/ma_alignment/statistical/hybrid。
- K线形态: k_length短/中/长/超长、k_shape吊锤线、倒T、十字星、超大实体、光头光脚等基于统计分布和Z-score自适应阈值。
- MACD 背离: 标准版与滑窗版两套检测。
- core/biz/quant_trader.py
- 作用: OKX 交易封装(账户、下单、行情、公共数据)。
- 要点:
- 余额查询USDT/现货币/合约、当前价格、K线拉取。
- 现货市价单买卖;合约侧设置杠杆、开空(卖出)与平空(买入)流程。
- 计算合约所需保证金与推荐保证金(含缓冲比例)。
- core/biz/strategy.py
- 作用: 策略抽象/占位(被 `trade_main.py` 中的 `QuantStrategy` 引用)。
- 要点:
- 用于承载策略接口或具体策略实现(与 `quant_trader` 协作下单)。
- core/biz/trade_data.py
- 作用: 交易明细获取与存储封装(供 `TradeDataMain` 使用)。
- 要点:
- 负责对接交易API、落库与查询支持按时间段增量补齐。
- core/biz/market_data_from_itick.py若存在
- 作用: iTick 美股数据源适配器。
- 要点:
- 输出结构与 OKX/Binance 对齐,便于统一后续指标/巨量检测流程。

View File

@ -0,0 +1,78 @@
- market_monitor_main.py
- 功能: 实时监控市场K线优先OKXLinux下可切换Binance计算技术指标与巨量信号生成监控报告并推送企业微信。
- 要点:
- 直接API拉取最近K线不访问DB以保证速度。
- 固定滑窗window_size=100实时判定huge_volume与价格分位异常。
- 通过最新时间戳与本地记录去重,避免重复推送。
- 支持过滤条件:仅巨量、仅超过均量、仅上涨。
- 依赖 Wechat 推送、DBMarketMonitor 记录、OKX_REALTIME_MONITOR_CONFIG 配置。
- trade_ma_strategy_main.py
- 功能: 批量运行“均线突破/相关”策略的统计回测主要是MACD命名策略
- 要点:
- 入口类 TradeMaStrategyMain -> MaBreakStatistics 批量统计。
- 聚合多个策略结果输出合并的资金曲线/收益数据。
- 可配置是否美股/是否Binance、佣金参数等。
- trade_sandbox_main.py
- 功能: 均值回归策略沙盒回测批量跑不同方案并输出Excel和图表。
- 要点:
- 批量维度symbols × bars × solutions。
- 统计指标:止盈/止损次数与占比、收益分布、均值等,按 `symbol, bar` 分组。
- 自动绘制2×2面板图不同bar嵌入Excel文件。
- 可仅跑5m也可多周期。
- market_data_from_itick_main.py
- 功能: 按时间段分片下载美股/ETF数据示例使用AlphaVantage类命名处理并保存CSV展示统计。
- 要点:
- 配置symbol/interval/分段天数,降低单次请求压力。
- 下载→处理→保存→打印统计的串行流程。
- 作为离线数据拉取脚本模版使用。
- auto_schedule.py
- 功能: 简易定时调度器,周期性运行 `huge_volume_main.py`
- 要点:
- 使用 schedule 每小时执行一次;记录执行时间与耗时。
- 兼容不同当前工作目录定位脚本路径。
- 适合本地常驻调度。
- auto_update_market_data.py
- 功能: 同上,定时运行 `huge_volume_main.py`(与 auto_schedule.py 功能基本一致)。
- 要点:
- 同样是每小时执行,日志与输出一致。
- 可按需要二选一保留,避免重复。
- update_data_main.py
- 功能: 批量更新数据库中行情数据的技术指标与美东时间字段。
- 要点:
- 从DB读取全量数据→按timestamp排序→更新 `date_time_us``SAR` 指标→回写DB。
- 支持美股/加密两套symbols与bars配置。
- 严格校验MySQL配置是否存在密码。
- trade_data_main.py
- 功能: 交易明细拉取与补齐API与DB结合返回时间段内整理过的交易数据。
- 要点:
- 依据DB现有最早/最新时间决定是否调用API补齐前段或后段。
- 默认时间范围从配置初始时间到当前最终结果从DB聚合、排序、去重。
- 依赖 TradeData 实现API交互与DB写入。
- statistics_main.py
- 功能: 批量价格/成交量统计。
- 要点:
- 调用 PriceVolumeStats 批处理,返回价格统计、成交量统计与联动统计结果。
- 用作一次性统计入口,便于离线分析。
- trade_main.py
- 功能: 交易流程演示脚本(三段式示例:开空→现货卖出→平空)。
- 要点:
- 封装 QuantTrader对接实盘/模拟由配置SANDBOX决定
- 展示下单参数:逐仓/全仓、张数、杠杆、缓冲比例;以及余额检查、下单、平仓流程。
- 以日志形式串联完整交易生命周期适合作为交易API联通性验证与流程Demo。
- huge_volume_main.py
- 功能: 巨量成交检测与统计分析的核心入口OKX与Binance均支持
- 要点:
- 从行情表读取K线计算滑窗放量、价格分位等支持初始化、按窗口增量更新。
- Binance 支持CSV历史导入导入后联动更新巨量表。
- 提供后续N周期涨跌统计、Excel导出与可视化以及企业微信推送过滤如volume_ratio>10且极值价位
- 多窗口50/80/100/120与多周期1m~1D批量处理能力MySQL落库去重。

View File

@ -0,0 +1,23 @@
- ma_break_statistics.py
- 功能: 统计“均线突破”后的收益表现,批量跑不同标的/周期,生成统计与图表/Excel。
- 要点:
- 数据源可切换OKX/Binance/美股;时间范围从配置读取。
- 关注 MA5/10/20/30 多组“上穿/下穿”组合,度量突破后的区间收益。
- 可配置手续费率,输出多策略合并结果,落地到指定 output 目录。
- 与数据库类 `DBMarketData/DBBinanceData/DBHugeVolumeData` 协同读取K线、过滤区间。
- mean_reversion_sandbox.py
- 功能: 均值回归策略沙盒(回测器),按多种“买入/止损/止盈方案”跑批评估,生成统计+图表+Excel。
- 要点:
- 条件以“价格分位+巨量”触发为主:如 close_10_low=1 或 close_80/90_high=1 且近2根K线任一巨量。
- 多个方案solution_1/2/3策略化定义止盈可用波段中位数、分位、高位形态等
- 读自合并视图 `DBMergeMarketHugeVolume`(含价量与巨量事件),统一回测窗口与分组汇总。
- 自动出图seaborn/matplotlib并将图贴入 Excel结果分 symbol×bar 汇总。
- orb_trade.py
- 功能: ORBOpening Range Breakout日内策略回测与可视化。
- 要点:
- 以开盘第一根5分钟K线的高低High1/Low1作为区间第二根K线产生多空信号入场价=第二根开盘价,止损价=第一根极值;盈亏基于 $Rentry-stop
- 支持参数:账户初始资金、最大杠杆、单笔风险比例、佣金、盈利目标倍数、仅做多/仅做空/双向、是否参考 SAR、是否参考 1H 形态等。
- 数据获取两路优先本地DBOKX/Binance也提供 yfinance 拉取美股数据的流程;自动调整初始资金规模以适配价格量级。
- 回测输出交易清单、资金曲线生成图表与Excel摘要到 output 目录。

View File

@ -11,7 +11,7 @@ import core.logger as logging
from config import ( from config import (
OKX_MONITOR_CONFIG, OKX_MONITOR_CONFIG,
US_STOCK_MONITOR_CONFIG, US_STOCK_MONITOR_CONFIG,
MYSQL_CONFIG, COIN_MYSQL_CONFIG,
WINDOW_SIZE, WINDOW_SIZE,
BINANCE_MONITOR_CONFIG, BINANCE_MONITOR_CONFIG,
) )
@ -30,13 +30,13 @@ class HugeVolumeMain:
is_us_stock: bool = False, is_us_stock: bool = False,
is_binance: bool = False, is_binance: bool = False,
): ):
mysql_user = MYSQL_CONFIG.get("user", "xch") mysql_user = COIN_MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "") mysql_password = COIN_MYSQL_CONFIG.get("password", "")
if not mysql_password: if not mysql_password:
raise ValueError("MySQL password is not set") raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost") mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306) mysql_port = COIN_MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx") mysql_database = COIN_MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.huge_volume = HugeVolume() self.huge_volume = HugeVolume()

View File

@ -21,7 +21,7 @@ from config import (
OKX_MONITOR_CONFIG, OKX_MONITOR_CONFIG,
BINANCE_MONITOR_CONFIG, BINANCE_MONITOR_CONFIG,
US_STOCK_MONITOR_CONFIG, US_STOCK_MONITOR_CONFIG,
MYSQL_CONFIG, COIN_MYSQL_CONFIG,
BAR_THRESHOLD, BAR_THRESHOLD,
) )
@ -67,13 +67,13 @@ class MarketDataMain:
self.initial_date = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get( self.initial_date = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
"initial_date", "2025-07-01 00:00:00" "initial_date", "2025-07-01 00:00:00"
) )
mysql_user = MYSQL_CONFIG.get("user", "xch") mysql_user = COIN_MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "") mysql_password = COIN_MYSQL_CONFIG.get("password", "")
if not mysql_password: if not mysql_password:
raise ValueError("MySQL password is not set") raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost") mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306) mysql_port = COIN_MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx") mysql_database = COIN_MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
if is_binance: if is_binance:

View File

@ -4,7 +4,7 @@ from huge_volume_main import HugeVolumeMain
from core.biz.market_monitor import create_metrics_report from core.biz.market_monitor import create_metrics_report
from core.db.db_market_monitor import DBMarketMonitor from core.db.db_market_monitor import DBMarketMonitor
from core.wechat import Wechat from core.wechat import Wechat
from config import OKX_MONITOR_CONFIG, OKX_REALTIME_MONITOR_CONFIG, MYSQL_CONFIG, WECHAT_CONFIG from config import OKX_MONITOR_CONFIG, OKX_REALTIME_MONITOR_CONFIG, COIN_MYSQL_CONFIG, WECHAT_CONFIG
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
import core.logger as logging import core.logger as logging
@ -31,13 +31,13 @@ class MarketMonitorMain:
self.output_folder = "./output/report/market_monitor/" self.output_folder = "./output/report/market_monitor/"
os.makedirs(self.output_folder, exist_ok=True) os.makedirs(self.output_folder, exist_ok=True)
mysql_user = MYSQL_CONFIG.get("user", "xch") mysql_user = COIN_MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "") mysql_password = COIN_MYSQL_CONFIG.get("password", "")
if not mysql_password: if not mysql_password:
raise ValueError("MySQL password is not set") raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost") mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306) mysql_port = COIN_MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx") mysql_database = COIN_MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"

12
play.py
View File

@ -2,7 +2,7 @@ import logging
from core.biz.quant_trader import QuantTrader from core.biz.quant_trader import QuantTrader
from core.biz.strategy import QuantStrategy from core.biz.strategy import QuantStrategy
from config import MYSQL_CONFIG from config import COIN_MYSQL_CONFIG
from sqlalchemy import create_engine, exc, text from sqlalchemy import create_engine, exc, text
import pandas as pd import pandas as pd
@ -100,13 +100,13 @@ def main() -> None:
def test_query(): def test_query():
mysql_user = MYSQL_CONFIG.get("user", "xch") mysql_user = COIN_MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "") mysql_password = COIN_MYSQL_CONFIG.get("password", "")
if not mysql_password: if not mysql_password:
raise ValueError("MySQL password is not set") raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost") mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306) mysql_port = COIN_MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx") mysql_database = COIN_MYSQL_CONFIG.get("database", "okx")
db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
db_engine = create_engine( db_engine = create_engine(
db_url, db_url,

View File

@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
from core.db.db_market_data import DBMarketData from core.db.db_market_data import DBMarketData
from core.biz.metrics_calculation import MetricsCalculation from core.biz.metrics_calculation import MetricsCalculation
import logging import logging
from config import OKX_MONITOR_CONFIG, MYSQL_CONFIG, WINDOW_SIZE from config import OKX_MONITOR_CONFIG, COIN_MYSQL_CONFIG, WINDOW_SIZE
# plt支持中文 # plt支持中文
plt.rcParams['font.family'] = ['SimHei'] plt.rcParams['font.family'] = ['SimHei']
@ -18,13 +18,13 @@ plt.rcParams['font.family'] = ['SimHei']
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def get_real_data(symbol, bar, start, end): def get_real_data(symbol, bar, start, end):
mysql_user = MYSQL_CONFIG.get("user", "xch") mysql_user = COIN_MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "") mysql_password = COIN_MYSQL_CONFIG.get("password", "")
if not mysql_password: if not mysql_password:
raise ValueError("MySQL password is not set") raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost") mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306) mysql_port = COIN_MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx") mysql_database = COIN_MYSQL_CONFIG.get("database", "okx")
db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
db_market_data = DBMarketData(db_url) db_market_data = DBMarketData(db_url)

View File

@ -10,7 +10,7 @@ from config import (
PASSPHRASE, PASSPHRASE,
SANDBOX, SANDBOX,
OKX_MONITOR_CONFIG, OKX_MONITOR_CONFIG,
MYSQL_CONFIG, COIN_MYSQL_CONFIG,
) )
logger = logging.logger logger = logging.logger
@ -18,13 +18,13 @@ logger = logging.logger
class TradeDataMain: class TradeDataMain:
def __init__(self): def __init__(self):
mysql_user = MYSQL_CONFIG.get("user", "xch") mysql_user = COIN_MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "") mysql_password = COIN_MYSQL_CONFIG.get("password", "")
if not mysql_password: if not mysql_password:
raise ValueError("MySQL password is not set") raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost") mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306) mysql_port = COIN_MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx") mysql_database = COIN_MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.trade_data = TradeData( self.trade_data = TradeData(

View File

@ -18,7 +18,6 @@ from config import (
PASSPHRASE, PASSPHRASE,
SANDBOX, SANDBOX,
OKX_MONITOR_CONFIG, OKX_MONITOR_CONFIG,
MYSQL_CONFIG,
BAR_THRESHOLD, BAR_THRESHOLD,
) )
@ -29,13 +28,21 @@ class TradeMaStrategyMain:
def __init__( def __init__(
self, self,
is_us_stock: bool = False, is_us_stock: bool = False,
is_astock: bool = False,
is_aindex: bool = True,
is_binance: bool = False, is_binance: bool = False,
commission_per_share: float = 0, commission_per_share: float = 0,
buy_by_long_period: dict = {"by_week": False, "by_month": False},
long_period_condition: dict = {"ma5>ma10": False, "ma10>ma20": False, "macd_diff>0": False, "macd>0": False},
): ):
self.ma_break_statistics = MaBreakStatistics( self.ma_break_statistics = MaBreakStatistics(
is_us_stock=is_us_stock, is_us_stock=is_us_stock,
is_astock=is_astock,
is_aindex=is_aindex,
is_binance=is_binance, is_binance=is_binance,
commission_per_share=commission_per_share, commission_per_share=commission_per_share,
buy_by_long_period=buy_by_long_period,
long_period_condition=long_period_condition,
) )
def batch_ma_break_statistics(self): def batch_ma_break_statistics(self):
@ -60,12 +67,51 @@ class TradeMaStrategyMain:
logger.info("开始统计account_value_chg") logger.info("开始统计account_value_chg")
def test_single_symbol():
ma_break_statistics = MaBreakStatistics(
is_us_stock=False,
is_astock=True,
is_aindex=False,
is_binance=False,
commission_per_share=0,
)
symbol = "600111.SH"
bar = "1D"
ma_break_statistics.trade_simulate(symbol=symbol, bar=bar, strategy_name="均线macd结合策略2")
if __name__ == "__main__": if __name__ == "__main__":
commission_per_share_list = [0, 0.0008] commission_per_share_list = [0, 0.0008]
buy_by_long_period_list = [{"by_week": True, "by_month": True},
{"by_week": True, "by_month": False},
{"by_week": False, "by_month": True},
{"by_week": False, "by_month": False}]
long_period_condition_list = [{"ma5>ma10": True, "ma10>ma20": True, "macd_diff>0": True, "macd>0": True},
{"ma5>ma10": True, "ma10>ma20": False, "macd_diff>0": True, "macd>0": True},
{"ma5>ma10": False, "ma10>ma20": True, "macd_diff>0": True, "macd>0": True}]
for commission_per_share in commission_per_share_list: for commission_per_share in commission_per_share_list:
for buy_by_long_period in buy_by_long_period_list:
for long_period_condition in long_period_condition_list:
logger.info(f"开始计算, 主要参数commission_per_share: {commission_per_share}, buy_by_long_period: {buy_by_long_period}, long_period_condition: {long_period_condition}")
trade_ma_strategy_main = TradeMaStrategyMain( trade_ma_strategy_main = TradeMaStrategyMain(
is_us_stock=False, is_us_stock=False,
is_binance=True, is_astock=False,
is_aindex=True,
is_binance=False,
commission_per_share=commission_per_share, commission_per_share=commission_per_share,
buy_by_long_period=buy_by_long_period,
long_period_condition=long_period_condition,
)
trade_ma_strategy_main.batch_ma_break_statistics()
trade_ma_strategy_main = TradeMaStrategyMain(
is_us_stock=False,
is_astock=True,
is_aindex=False,
is_binance=False,
commission_per_share=commission_per_share,
buy_by_long_period=buy_by_long_period,
long_period_condition=long_period_condition,
) )
trade_ma_strategy_main.batch_ma_break_statistics() trade_ma_strategy_main.batch_ma_break_statistics()

View File

@ -2,20 +2,20 @@ import pandas as pd
from core.db.db_market_data import DBMarketData from core.db.db_market_data import DBMarketData
from core.biz.metrics_calculation import MetricsCalculation from core.biz.metrics_calculation import MetricsCalculation
from config import MYSQL_CONFIG, US_STOCK_MONITOR_CONFIG, OKX_MONITOR_CONFIG from config import COIN_MYSQL_CONFIG, US_STOCK_MONITOR_CONFIG, OKX_MONITOR_CONFIG
import core.logger as logging import core.logger as logging
logger = logging.logger logger = logging.logger
class UpdateDataMain: class UpdateDataMain:
def __init__(self): def __init__(self):
mysql_user = MYSQL_CONFIG.get("user", "xch") mysql_user = COIN_MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "") mysql_password = COIN_MYSQL_CONFIG.get("password", "")
if not mysql_password: if not mysql_password:
raise ValueError("MySQL password is not set") raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost") mysql_host = COIN_MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306) mysql_port = COIN_MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx") mysql_database = COIN_MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_market_data = DBMarketData(self.db_url) self.db_market_data = DBMarketData(self.db_url)