support SAR

This commit is contained in:
blade 2025-09-02 12:44:34 +08:00
parent e990db26a6
commit 6ee64abaf5
9 changed files with 296 additions and 104 deletions

View File

@ -73,7 +73,7 @@ OKX_MONITOR_CONFIG = {
US_STOCK_MONITOR_CONFIG = {
"volume_monitor":{
"symbols": ["QQQ", "TQQQ", "MSFT", "AAPL", "GOOG", "NVDA", "META", "AMZN", "TSLA", "AVGO"],
"bars": ["5", "15m", "30m", "1H"],
"bars": ["5m", "15m", "30m", "1H"],
"initial_date": "2015-08-31 00:00:00"
}
}

View File

@ -137,6 +137,64 @@ class MetricsCalculation:
] = "死叉"
return df
def sar(self, df: pd.DataFrame, acceleration=0.02, maximum=0.2):
"""
计算SAR抛物线转向指标
Args:
df: 包含high, low, close列的DataFrame
acceleration: 加速因子默认0.02控制SAR值随价格变化的加速程度
maximum: 最大加速因子默认0.2设定加速因子的上限防止过度增加
参数说明
- acceleration=0.02: 标准设置适合大多数市场
- maximum=0.2: 标准设置防止SAR过度敏感
- 对于高波动性市场如加密货币可适当增加acceleration到0.03-0.04
- 对于低波动性市场可降低acceleration到0.015-0.02
"""
logger.info(f"计算SAR指标acceleration={acceleration}, maximum={maximum}")
# 初始化sar和sar_signal列
df["sar"] = np.nan
df["sar_signal"] = ""
df["sar"] = tb.SAR(
df["high"], df["low"], acceleration=acceleration, maximum=maximum
)
# sar_position = df["sar"] > df["close"]
# df.loc[
# sar_position[
# (sar_position == True) & (sar_position.shift() == False)
# ].index,
# "sar_signal",
# ] = "SAR多头"
# df.loc[
# sar_position[
# (sar_position == False) & (sar_position.shift() == True)
# ].index,
# "sar_signal",
# ] = "SAR空头"
# df.loc[sar_position[sar_position == False].index, "sar_signal"] = "SAR观望"
# 生成交易信号
# SAR多头: SAR < close
# SAR空头: SAR > close
# SAR观望: SAR == close 或 SAR为NaN
df["sar_signal"] = np.where(
df["sar"].isna(),
"SAR观望",
np.where(
df["sar"] < df["close"],
"SAR多头",
np.where(df["sar"] > df["close"], "SAR空头", "SAR观望"),
),
)
# 确保sar列为float类型
df["sar"] = df["sar"].astype(float)
# 确保sar_signal列为str类型
df["sar_signal"] = df["sar_signal"].astype(str)
return df
def set_kdj_pattern(self, df: pd.DataFrame):
"""
设置每一根K线数据对应的KDJ形态超买超卖情况
@ -209,7 +267,7 @@ class MetricsCalculation:
# 震荡条件已经在初始化时设置,无需额外处理
data["ma_long_short"] = "震荡"
data = self._trend_strength_method(data)
# 计算各均线偏离度的标准差和均值
data["ma_divergence"] = "未知"
ma_diffs = data[
@ -421,7 +479,7 @@ class MetricsCalculation:
df["ma30"] = df["close"].rolling(window=30).mean().dropna()
df["ma_cross"] = ""
# 定义均线交叉检测函数
def detect_cross(short_ma, long_ma, short_name, long_name):
"""检测均线交叉"""
@ -429,22 +487,22 @@ class MetricsCalculation:
cross_up = (position == True) & (position.shift() == False)
cross_down = (position == False) & (position.shift() == True)
return cross_up, cross_down
# 检测所有均线交叉
crosses = {}
# MA5与其他均线的交叉
ma5_ma10_up, ma5_ma10_down = detect_cross("ma5", "ma10", "5", "10")
ma5_ma20_up, ma5_ma20_down = detect_cross("ma5", "ma20", "5", "20")
ma5_ma30_up, ma5_ma30_down = detect_cross("ma5", "ma30", "5", "30")
# MA10与其他均线的交叉
ma10_ma20_up, ma10_ma20_down = detect_cross("ma10", "ma20", "10", "20")
ma10_ma30_up, ma10_ma30_down = detect_cross("ma10", "ma30", "10", "30")
# MA20与MA30的交叉
ma20_ma30_up, ma20_ma30_down = detect_cross("ma20", "ma30", "20", "30")
# 存储上穿信号
crosses["5上穿10"] = ma5_ma10_up
crosses["5上穿20"] = ma5_ma20_up
@ -452,7 +510,7 @@ class MetricsCalculation:
crosses["10上穿20"] = ma10_ma20_up
crosses["10上穿30"] = ma10_ma30_up
crosses["20上穿30"] = ma20_ma30_up
# 存储下穿信号
crosses["10下穿5"] = ma5_ma10_down
crosses["20下穿10"] = ma10_ma20_down
@ -460,22 +518,22 @@ class MetricsCalculation:
crosses["30下穿20"] = ma20_ma30_down
crosses["30下穿10"] = ma10_ma30_down
crosses["30下穿5"] = ma5_ma30_down
# 分析每个时间点的交叉组合
for idx in df.index:
current_crosses = []
# 检查当前时间点的所有交叉信号
for cross_name, cross_signal in crosses.items():
if cross_signal.loc[idx]:
current_crosses.append(cross_name)
# 根据交叉类型组合信号
if len(current_crosses) > 0:
# 分离上穿和下穿信号
up_crosses = [c for c in current_crosses if "上穿" in c]
down_crosses = [c for c in current_crosses if "下穿" in c]
# 组合信号
if len(up_crosses) > 1:
# 多个上穿信号
@ -486,7 +544,7 @@ class MetricsCalculation:
else:
# 单个交叉信号
df.loc[idx, "ma_cross"] = current_crosses[0]
return df
def rsi(self, df: pd.DataFrame):
@ -726,13 +784,21 @@ class MetricsCalculation:
) # 下影线长度
# 计算实体占比
df["open_close_fill"] = df["open_close_diff"] / df["high_low_diff"].replace(0, np.nan)
df["open_close_fill"] = df["open_close_diff"] / df["high_low_diff"].replace(
0, np.nan
)
df["open_close_fill"] = df["open_close_fill"].fillna(1.0) # 处理除零情况
# 计算影线占比
df["upper_shadow_ratio"] = df["high_close_diff"] / df["high_low_diff"].replace(0, np.nan)
df["lower_shadow_ratio"] = df["low_close_diff"] / df["high_low_diff"].replace(0, np.nan)
df["upper_shadow_ratio"] = df["upper_shadow_ratio"].fillna(0) # 无波动时影线占比为 0
df["upper_shadow_ratio"] = df["high_close_diff"] / df["high_low_diff"].replace(
0, np.nan
)
df["lower_shadow_ratio"] = df["low_close_diff"] / df["high_low_diff"].replace(
0, np.nan
)
df["upper_shadow_ratio"] = df["upper_shadow_ratio"].fillna(
0
) # 无波动时影线占比为 0
df["lower_shadow_ratio"] = df["lower_shadow_ratio"].fillna(0)
# 初始化k_shape列
@ -760,15 +826,23 @@ class MetricsCalculation:
)
# 计算滚动窗口内 price_range_ratio 和 price_range_zscore 的分位数
df["price_range_ratio_p75"] = df["price_range_ratio"].rolling(window=window_size, min_periods=1).quantile(0.75)
df["price_range_zscore_p75"] = df["price_range_zscore"].rolling(window=window_size, min_periods=1).quantile(0.75)
df["price_range_ratio_p75"] = (
df["price_range_ratio"]
.rolling(window=window_size, min_periods=1)
.quantile(0.75)
)
df["price_range_zscore_p75"] = (
df["price_range_zscore"]
.rolling(window=window_size, min_periods=1)
.quantile(0.75)
)
# 识别“一字”形态波动极小Z 分数 < -1.0 或 price_range_ratio < 0.05%)且无影线
one_line_condition = (
((df["price_range_zscore"] < -1.0) | (df["price_range_ratio"] < 0.05)) &
(df["upper_shadow_ratio"] <= 0.01) & # 上影线极小或无
(df["lower_shadow_ratio"] <= 0.01) & # 下影线极小或无
(df["open_close_diff"] / df["close"] < 0.0005) # 开收盘价差小于0.05%
((df["price_range_zscore"] < -1.0) | (df["price_range_ratio"] < 0.05))
& (df["upper_shadow_ratio"] <= 0.01) # 上影线极小或无
& (df["lower_shadow_ratio"] <= 0.01) # 下影线极小或无
& (df["open_close_diff"] / df["close"] < 0.0005) # 开收盘价差小于0.05%
)
df.loc[one_line_condition, "k_shape"] = "一字"
@ -857,10 +931,18 @@ class MetricsCalculation:
& (df["open_close_fill"] <= 0.55)
& (df["k_shape"] != "一字")
)
df.loc[small_body_condition_2
& (df["upper_shadow_ratio"] >= 0.25) & (df["k_shape"] == "未知"), "k_shape"] = "长上影线纺锤体"
df.loc[small_body_condition_2
& (df["lower_shadow_ratio"] >= 0.25) & (df["k_shape"] == "未知"), "k_shape"] = "长下影线纺锤体"
df.loc[
small_body_condition_2
& (df["upper_shadow_ratio"] >= 0.25)
& (df["k_shape"] == "未知"),
"k_shape",
] = "长上影线纺锤体"
df.loc[
small_body_condition_2
& (df["lower_shadow_ratio"] >= 0.25)
& (df["k_shape"] == "未知"),
"k_shape",
] = "长下影线纺锤体"
df.loc[small_body_condition_2 & (df["k_shape"] == "未知"), "k_shape"] = "小实体"
# 大实体实体占比55%-90%
@ -873,16 +955,20 @@ class MetricsCalculation:
# 识别“超大实体”形态:实体占比 75%-90%,价格波动显著,且非“一字”或“大实体”
super_large_body_condition = (
(df["open_close_fill"] > 0.75) &
(df["open_close_fill"] <= 1) &
(df["price_range_ratio"] >= df["price_range_ratio_p75"]) & # 价格波动范围超过75th分位数
(df["k_shape"] != "一字")
(df["open_close_fill"] > 0.75)
& (df["open_close_fill"] <= 1)
& (
df["price_range_ratio"] >= df["price_range_ratio_p75"]
) # 价格波动范围超过75th分位数
& (df["k_shape"] != "一字")
)
df.loc[super_large_body_condition, "k_shape"] = "超大实体"
# 光头光脚:实体占比>90%(非一字情况)
bald_body_condition = (df["open_close_fill"] > 0.9) & (df["k_shape"] != "一字")
df.loc[bald_body_condition & (df["k_shape"] == "超大实体"), "k_shape"] = "超大实体+光头光脚"
df.loc[bald_body_condition & (df["k_shape"] == "超大实体"), "k_shape"] = (
"超大实体+光头光脚"
)
df.loc[bald_body_condition & (df["k_shape"] == "未知"), "k_shape"] = "光头光脚"
# 清理临时列
@ -911,7 +997,7 @@ class MetricsCalculation:
def set_ma_long_short_advanced(self, data: pd.DataFrame, method="weighted_voting"):
"""
高级均线多空判定方法提供多种科学的判定策略
Args:
data: 包含均线数据的DataFrame
method: 判定方法
@ -922,7 +1008,7 @@ class MetricsCalculation:
- "hybrid": 混合方法
"""
logger.info(f"使用{method}方法设置均线多空")
if method == "weighted_voting":
return self._weighted_voting_method(data)
elif method == "trend_strength":
@ -936,20 +1022,20 @@ class MetricsCalculation:
else:
logger.warning(f"未知的方法: {method},使用默认加权投票方法")
return self._weighted_voting_method(data)
def _weighted_voting_method(self, data: pd.DataFrame):
"""加权投票机制:短期均线权重更高"""
# 权重设置:短期均线权重更高
weights = {
"ma5_close_diff": 0.4, # 40%权重
"ma5_close_diff": 0.4, # 40%权重
"ma10_close_diff": 0.3, # 30%权重
"ma20_close_diff": 0.2, # 20%权重
"ma30_close_diff": 0.1 # 10%权重
"ma30_close_diff": 0.1, # 10%权重
}
# 计算加权得分
weighted_score = sum(data[col] * weight for col, weight in weights.items())
# 动态阈值:基于历史分布
window_size = min(50, len(data) // 4)
if window_size > 10:
@ -960,69 +1046,78 @@ class MetricsCalculation:
else:
long_threshold = 0.3
short_threshold = -0.3
# 判定逻辑
data.loc[weighted_score > long_threshold, "ma_long_short"] = ""
data.loc[weighted_score < short_threshold, "ma_long_short"] = ""
return data
def _trend_strength_method(self, data: pd.DataFrame):
"""趋势强度评估:考虑偏离幅度和趋势持续性"""
# 计算趋势强度(考虑偏离幅度)
trend_strength = data["ma_close_avg"]
# 计算趋势持续性(连续同向的周期数)
trend_persistence = self._calculate_trend_persistence(data)
# 综合评分
strength_threshold = 0.5
persistence_threshold = 3 # 至少连续3个周期
long_condition = (trend_strength > strength_threshold) & (trend_persistence >= persistence_threshold)
short_condition = (trend_strength < -strength_threshold) & (trend_persistence >= persistence_threshold)
long_condition = (trend_strength > strength_threshold) & (
trend_persistence >= persistence_threshold
)
short_condition = (trend_strength < -strength_threshold) & (
trend_persistence >= persistence_threshold
)
data.loc[long_condition, "ma_long_short"] = ""
data.loc[short_condition, "ma_long_short"] = ""
return data
def _ma_alignment_method(self, data: pd.DataFrame):
"""均线排列分析:检查均线的排列顺序和间距"""
# 检查均线排列顺序
ma_alignment_score = 0
# 多头排列MA5 > MA10 > MA20 > MA30
bullish_alignment = (
(data["ma5_close_diff"] > data["ma10_close_diff"]) &
(data["ma10_close_diff"] > data["ma20_close_diff"]) &
(data["ma20_close_diff"] > data["ma30_close_diff"])
(data["ma5_close_diff"] > data["ma10_close_diff"])
& (data["ma10_close_diff"] > data["ma20_close_diff"])
& (data["ma20_close_diff"] > data["ma30_close_diff"])
)
# 空头排列MA5 < MA10 < MA20 < MA30
bearish_alignment = (
(data["ma5_close_diff"] < data["ma10_close_diff"]) &
(data["ma10_close_diff"] < data["ma20_close_diff"]) &
(data["ma20_close_diff"] < data["ma30_close_diff"])
(data["ma5_close_diff"] < data["ma10_close_diff"])
& (data["ma10_close_diff"] < data["ma20_close_diff"])
& (data["ma20_close_diff"] < data["ma30_close_diff"])
)
# 计算均线间距的合理性
ma_spacing = self._calculate_ma_spacing(data)
# 综合判定
long_condition = bullish_alignment & (ma_spacing > 0.2)
short_condition = bearish_alignment & (ma_spacing > 0.2)
data.loc[long_condition, "ma_long_short"] = ""
data.loc[short_condition, "ma_long_short"] = ""
return data
def _statistical_method(self, data: pd.DataFrame):
"""统计分布方法基于历史分位数和Z-score"""
# 计算各均线偏离度的Z-score
ma_cols = ["ma5_close_diff", "ma10_close_diff", "ma20_close_diff", "ma30_close_diff"]
ma_cols = [
"ma5_close_diff",
"ma10_close_diff",
"ma20_close_diff",
"ma30_close_diff",
]
# 使用滚动窗口计算Z-score
window_size = min(30, len(data) // 4)
if window_size > 10:
@ -1031,44 +1126,46 @@ class MetricsCalculation:
rolling_mean = data[col].rolling(window=window_size).mean()
rolling_std = data[col].rolling(window=window_size).std()
z_scores[col] = (data[col] - rolling_mean) / rolling_std
# 计算综合Z-score
avg_z_score = z_scores.mean(axis=1)
# 基于Z-score判定
long_condition = avg_z_score > 0.5
short_condition = avg_z_score < -0.5
data.loc[long_condition, "ma_long_short"] = ""
data.loc[short_condition, "ma_long_short"] = ""
return data
def _hybrid_method(self, data: pd.DataFrame):
"""混合方法:结合多种判定策略"""
# 1. 加权投票得分
weights = {"ma5_close_diff": 0.4, "ma10_close_diff": 0.3,
"ma20_close_diff": 0.2, "ma30_close_diff": 0.1}
weights = {
"ma5_close_diff": 0.4,
"ma10_close_diff": 0.3,
"ma20_close_diff": 0.2,
"ma30_close_diff": 0.1,
}
weighted_score = sum(data[col] * weight for col, weight in weights.items())
# 2. 均线排列得分
alignment_score = (
(data["ma5_close_diff"] >= data["ma10_close_diff"]) * 0.25 +
(data["ma10_close_diff"] >= data["ma20_close_diff"]) * 0.25 +
(data["ma20_close_diff"] >= data["ma30_close_diff"]) * 0.25 +
(data["ma_close_avg"] > 0) * 0.25
(data["ma5_close_diff"] >= data["ma10_close_diff"]) * 0.25
+ (data["ma10_close_diff"] >= data["ma20_close_diff"]) * 0.25
+ (data["ma20_close_diff"] >= data["ma30_close_diff"]) * 0.25
+ (data["ma_close_avg"] > 0) * 0.25
)
# 3. 趋势强度得分
strength_score = data["ma_close_avg"].abs()
# 4. 综合评分
composite_score = (
weighted_score * 0.4 +
alignment_score * 0.3 +
strength_score * 0.3
weighted_score * 0.4 + alignment_score * 0.3 + strength_score * 0.3
)
# 动态阈值
window_size = min(50, len(data) // 4)
if window_size > 10:
@ -1079,38 +1176,44 @@ class MetricsCalculation:
else:
long_threshold = 0.4
short_threshold = -0.4
# 判定
long_condition = composite_score > long_threshold
short_condition = composite_score < short_threshold
data.loc[long_condition, "ma_long_short"] = ""
data.loc[short_condition, "ma_long_short"] = ""
return data
def _calculate_trend_persistence(self, data: pd.DataFrame):
"""计算趋势持续性"""
trend_persistence = pd.Series(0, index=data.index)
for i in range(1, len(data)):
if data["ma_close_avg"].iloc[i] > 0 and data["ma_close_avg"].iloc[i-1] > 0:
trend_persistence.iloc[i] = trend_persistence.iloc[i-1] + 1
elif data["ma_close_avg"].iloc[i] < 0 and data["ma_close_avg"].iloc[i-1] < 0:
trend_persistence.iloc[i] = trend_persistence.iloc[i-1] + 1
if (
data["ma_close_avg"].iloc[i] > 0
and data["ma_close_avg"].iloc[i - 1] > 0
):
trend_persistence.iloc[i] = trend_persistence.iloc[i - 1] + 1
elif (
data["ma_close_avg"].iloc[i] < 0
and data["ma_close_avg"].iloc[i - 1] < 0
):
trend_persistence.iloc[i] = trend_persistence.iloc[i - 1] + 1
else:
trend_persistence.iloc[i] = 0
return trend_persistence
def _calculate_ma_spacing(self, data: pd.DataFrame):
"""计算均线间距的合理性"""
# 计算相邻均线之间的间距
spacing_5_10 = abs(data["ma5_close_diff"] - data["ma10_close_diff"])
spacing_10_20 = abs(data["ma10_close_diff"] - data["ma20_close_diff"])
spacing_20_30 = abs(data["ma20_close_diff"] - data["ma30_close_diff"])
# 平均间距
avg_spacing = (spacing_5_10 + spacing_10_20 + spacing_20_30) / 3
return avg_spacing

View File

@ -44,6 +44,8 @@ class DBMarketData:
"kdj_j",
"kdj_signal",
"kdj_pattern",
"sar",
"sar_signal",
"ma5",
"ma10",
"ma20",
@ -471,7 +473,7 @@ class DBMarketData:
return self.db_manager.query_data(sql, condition_dict, return_multi=False)
def query_market_data_by_symbol_bar(self, symbol: str, bar: str, start: str, end: str):
def query_market_data_by_symbol_bar(self, symbol: str, bar: str, start: str = None, end: str = None):
"""
根据交易对和K线周期查询数据
:param symbol: 交易对

View File

@ -134,13 +134,20 @@ class ORBStrategy:
return int(max_shares) # 股数取整
def generate_orb_signals(self):
def generate_orb_signals(self, direction: str = None, by_sar: bool = False):
"""
生成ORB策略信号每日仅1次交易机会
- 第一根5分钟K线确定开盘区间High1, Low1
- 第二根5分钟K线根据第一根K线方向生成多空信号
:param direction: 方向None=自动Long=多头Short=空头
:param by_sar: 是否根据SAR指标生成信号True=False=
"""
logger.info("开始生成ORB策略信号")
direction_desc = "既做多又做空"
if direction == "Long":
direction_desc = "做多"
elif direction == "Short":
direction_desc = "做空"
logger.info(f"开始生成ORB策略信号{direction_desc}根据SAR指标{by_sar}")
if self.data is None:
raise ValueError("请先调用fetch_intraday_data获取数据")
@ -164,16 +171,16 @@ class ORBStrategy:
entry_time = second_candle.date_time # entry时间
# 生成信号第一根K线方向决定多空排除十字星open1 == close1
if open1 < close1:
if open1 < close1 and (direction == "Long" or direction is None):
# 第一根K线收涨→多头信号
signal = "Long"
stop_price = low1 # 多头止损=第一根K线最低价
elif open1 > close1:
elif open1 > close1 and (direction == "Short" or direction is None):
# 第一根K线收跌→空头信号
signal = "Short"
stop_price = high1 # 空头止损=第一根K线最高价
else:
# 十字星→无信号
# 与direction不一致或十字星→无信号
signal = None
stop_price = None

View File

@ -512,7 +512,7 @@ def test_send_huge_volume_data_to_wechat():
if __name__ == "__main__":
# test_send_huge_volume_data_to_wechat()
# batch_initial_detect_volume_spike(threshold=2.0)
batch_update_volume_spike(threshold=2.0, is_us_stock=True)
batch_update_volume_spike(threshold=2.0, is_us_stock=False)
# huge_volume_main = HugeVolumeMain(threshold=2.0)
# huge_volume_main.batch_next_periods_rise_or_fall(output_excel=True)
# data_file_path = "./output/huge_volume_statistics/next_periods_rise_or_fall_stat_20250731200304.xlsx"

View File

@ -252,6 +252,8 @@ class MarketDataMain:
data["kdj_j"] = None
data["kdj_signal"] = None
data["kdj_pattern"] = None
data["sar"] = None
data["sar_signal"] = None
data["ma5"] = None
data["ma10"] = None
data["ma20"] = None
@ -301,6 +303,8 @@ class MarketDataMain:
kdj_j DOUBLE DEFAULT NULL COMMENT 'KDJ指标J值',
kdj_signal VARCHAR(15) DEFAULT NULL COMMENT 'KDJ金叉死叉信号',
kdj_pattern varchar(25) DEFAULT NULL COMMENT 'KDJ超买超卖徘徊',
sar DOUBLE DEFAULT NULL COMMENT 'SAR指标',
sar_signal VARCHAR(15) DEFAULT NULL COMMENT 'SAR多头SAR空头SAR观望',
ma5 DOUBLE DEFAULT NULL COMMENT '5移动平均线',
ma10 DOUBLE DEFAULT NULL COMMENT '10移动平均线',
ma20 DOUBLE DEFAULT NULL COMMENT '20移动平均线',
@ -331,6 +335,7 @@ class MarketDataMain:
data = metrics_calculation.pre_close(data)
data = metrics_calculation.macd(data)
data = metrics_calculation.kdj(data)
data = metrics_calculation.sar(data)
data = metrics_calculation.set_kdj_pattern(data)
data = metrics_calculation.update_macd_divergence_column_simple(data)
data = metrics_calculation.ma5102030(data)

View File

@ -33,6 +33,8 @@ CREATE TABLE IF NOT EXISTS crypto_market_data (
kdj_j DOUBLE DEFAULT NULL COMMENT 'KDJ指标J值',
kdj_signal VARCHAR(15) DEFAULT NULL COMMENT 'KDJ金叉死叉信号',
kdj_pattern varchar(25) DEFAULT NULL COMMENT 'KDJ超买超卖徘徊',
sar DOUBLE DEFAULT NULL COMMENT 'SAR指标值',
sar_signal VARCHAR(15) DEFAULT NULL COMMENT 'SAR多头空头信号',
ma5 DOUBLE DEFAULT NULL COMMENT '5移动平均线',
ma10 DOUBLE DEFAULT NULL COMMENT '10移动平均线',
ma20 DOUBLE DEFAULT NULL COMMENT '20移动平均线',
@ -64,3 +66,7 @@ ALTER TABLE crypto_market_data MODIFY COLUMN ma_cross VARCHAR(150) DEFAULT NULL
--date_time_us字段
ALTER TABLE crypto_market_data ADD COLUMN date_time_us VARCHAR(50) NULL COMMENT '美国时间格式的日期时间' AFTER date_time;
--SAR相关字段
ALTER TABLE crypto_market_data ADD COLUMN sar DOUBLE DEFAULT NULL COMMENT 'SAR指标值' AFTER kdj_pattern;
ALTER TABLE crypto_market_data ADD COLUMN sar_signal VARCHAR(15) DEFAULT NULL COMMENT 'SAR多头空头信号' AFTER sar;

69
update_data_main.py Normal file
View File

@ -0,0 +1,69 @@
import pandas as pd
from core.db.db_market_data import DBMarketData
from core.biz.metrics_calculation import MetricsCalculation
from config import MYSQL_CONFIG, US_STOCK_MONITOR_CONFIG, OKX_MONITOR_CONFIG
import core.logger as logging
logger = logging.logger
class UpdateDataMain:
def __init__(self):
mysql_user = MYSQL_CONFIG.get("user", "xch")
mysql_password = MYSQL_CONFIG.get("password", "")
if not mysql_password:
raise ValueError("MySQL password is not set")
mysql_host = MYSQL_CONFIG.get("host", "localhost")
mysql_port = MYSQL_CONFIG.get("port", 3306)
mysql_database = MYSQL_CONFIG.get("database", "okx")
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
self.db_market_data = DBMarketData(self.db_url)
self.metrics_calculation = MetricsCalculation()
def batch_update_data(self, is_us_stock: bool = False):
"""
批量更新数据
"""
if is_us_stock:
symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get("symbols", [])
bars = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get("bars", ["5m", "15m", "1H", "1D"])
else:
symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get("symbols", [])
bars = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get("bars", ["5m", "15m", "1H", "1D"])
for symbol in symbols:
for bar in bars:
self.update_data(symbol, bar)
def update_data(self, symbol: str, bar: str):
"""
更新数据
"""
logger.info(f"开始更新数据: {symbol} {bar}")
data = self.db_market_data.query_market_data_by_symbol_bar(symbol, bar)
logger.info(f"查询数据完成: {symbol} {bar},共有{len(data)}条数据")
data = pd.DataFrame(data)
data.sort_values(by="timestamp", inplace=True)
data = self.update_date_time_us(data)
logger.info("更新SAR指标")
data = self.metrics_calculation.sar(data)
logger.info("更新SAR指标完成")
logger.info(f"开始保存数据: {symbol} {bar}")
self.db_market_data.insert_data_to_mysql(data)
logger.info(f"保存数据完成: {symbol} {bar}")
def update_date_time_us(self, data: pd.DataFrame):
"""
更新日期时间
"""
logger.info(f"开始更新美东日期时间: {data.shape[0]}条数据")
dt_us_series = pd.to_datetime(data['timestamp'].astype(int), unit='ms', utc=True, errors='coerce').dt.tz_convert('America/New_York')
data['date_time_us'] = dt_us_series.dt.strftime('%Y-%m-%d %H:%M:%S')
logger.info(f"更新美东日期时间完成: {data.shape[0]}条数据")
return data
if __name__ == "__main__":
update_data_main = UpdateDataMain()
update_data_main.batch_update_data(is_us_stock=True)
update_data_main.batch_update_data(is_us_stock=False)