import core.logger as logging
from datetime import datetime
from time import sleep
from typing import Optional, Union
import pandas as pd
from core.biz.market_data import MarketData
from core.db.db_market_data import DBMarketData
from core.db.db_binance_data import DBBinanceData
from core.biz.metrics_calculation import MetricsCalculation
from core.utils import (
    datetime_to_timestamp,
    timestamp_to_datetime,
    transform_date_time_to_timestamp,
)
from trade_data_main import TradeDataMain
from config import (
    API_KEY,
    SECRET_KEY,
    PASSPHRASE,
    SANDBOX,
    OKX_MONITOR_CONFIG,
    BINANCE_MONITOR_CONFIG,
    US_STOCK_MONITOR_CONFIG,
    MYSQL_CONFIG,
    BAR_THRESHOLD,
)

logger = logging.logger


class MarketDataMain:
    """Fetch candlestick (K-line) data, store it in MySQL and enrich it with
    technical-indicator columns.

    The data source / storage backend is selected by the two constructor
    flags; ``is_us_stock`` takes precedence over ``is_binance`` when choosing
    the monitor configuration.
    """

    # Number of milliseconds in one day; fetch windows are expressed in ms.
    _MS_PER_DAY = 86400000

    # Fetch-window length (in days) per bar size when pulling US-stock data.
    _US_STOCK_WINDOW_DAYS = {"5m": 4, "15m": 6, "30m": 12, "1H": 24}

    def __init__(self, is_us_stock: bool = False, is_binance: bool = False):
        """Build the market-data client, load monitor settings and DB access.

        Args:
            is_us_stock: Use ``US_STOCK_MONITOR_CONFIG`` and the US-stock
                fetch windows.
            is_binance: Use ``BINANCE_MONITOR_CONFIG`` and ``DBBinanceData``
                for storage.

        Raises:
            ValueError: If ``MYSQL_CONFIG`` has no (or an empty) password.
        """
        self.market_data = MarketData(
            api_key=API_KEY,
            secret_key=SECRET_KEY,
            passphrase=PASSPHRASE,
            sandbox=SANDBOX,
            is_us_stock=is_us_stock,
        )

        # Pick the monitor configuration and its fallback defaults once,
        # then read the three settings from the same "volume_monitor" section.
        if is_us_stock:
            monitor_config = US_STOCK_MONITOR_CONFIG
            default_symbols = ["QQQ"]
            default_bars = ["5m"]
            default_initial_date = "2015-08-30 00:00:00"
        elif is_binance:
            monitor_config = BINANCE_MONITOR_CONFIG
            default_symbols = ["BTC-USDT"]
            default_bars = ["5m", "30m", "1H"]
            default_initial_date = "2017-08-17 00:00:00"
        else:
            monitor_config = OKX_MONITOR_CONFIG
            default_symbols = ["XCH-USDT"]
            default_bars = ["5m", "15m", "1H", "1D"]
            default_initial_date = "2025-07-01 00:00:00"
        volume_monitor = monitor_config.get("volume_monitor", {})
        self.symbols = volume_monitor.get("symbols", default_symbols)
        self.bars = volume_monitor.get("bars", default_bars)
        self.initial_date = volume_monitor.get("initial_date", default_initial_date)

        mysql_user = MYSQL_CONFIG.get("user", "xch")
        mysql_password = MYSQL_CONFIG.get("password", "")
        if not mysql_password:
            raise ValueError("MySQL password is not set")
        mysql_host = MYSQL_CONFIG.get("host", "localhost")
        mysql_port = MYSQL_CONFIG.get("port", 3306)
        mysql_database = MYSQL_CONFIG.get("database", "okx")
        # NOTE(review): user/password are interpolated verbatim; credentials
        # containing URL-reserved characters such as '@' or '/' would need
        # percent-encoding -- confirm how MYSQL_CONFIG stores them.
        self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"

        if is_binance:
            self.db_market_data = DBBinanceData(self.db_url)
        else:
            self.db_market_data = DBMarketData(self.db_url)
        self.is_binance = is_binance
        self.trade_data_main = TradeDataMain()
        self.is_us_stock = is_us_stock

    def initial_data(self):
        """Initialise market data for every configured symbol/bar pair.

        When the database already holds rows for a pair, fetching resumes one
        millisecond after the latest stored timestamp; otherwise it starts at
        ``self.initial_date``.
        """
        for symbol in self.symbols:
            for bar in self.bars:
                logger.info(f"开始初始化行情数据: {symbol} {bar}")
                latest_data = self.db_market_data.query_latest_data(symbol, bar)
                if latest_data:
                    start = latest_data.get("timestamp")
                    start_date_time = timestamp_to_datetime(start)
                    # Skip the bar that is already stored.
                    start = start + 1
                else:
                    start = datetime_to_timestamp(self.initial_date)
                    start_date_time = self.initial_date
                logger.info(
                    f"开始初始化{symbol}, {bar} 行情数据,从 {start_date_time} 开始"
                )
                self.fetch_save_data(symbol, bar, start)

    def _fetch_window_ms(self, bar: str) -> Optional[int]:
        """Return the per-request time window in milliseconds for ``bar``.

        Returns ``None`` for a bar size this class has no window for.
        """
        if bar in ("5m", "15m", "30m", "1H"):
            if self.is_us_stock:
                return self._MS_PER_DAY * self._US_STOCK_WINDOW_DAYS[bar]
            # NOTE(review): the original inline comment said 1H should use a
            # 5-day window, but "1H" has always been matched by this branch
            # and therefore uses 1 day for non-US-stock sources. Behaviour is
            # kept as-is; confirm which one is intended.
            return self._MS_PER_DAY
        if bar == "4H":
            return 5 * self._MS_PER_DAY
        if bar == "1D":
            return 10 * self._MS_PER_DAY
        return None

    def fetch_save_data(self, symbol: str, bar: str, start: Union[str, int]):
        """Fetch K-line data from ``start`` until now, save it, then compute
        indicators for the newly stored range.

        The range is walked backwards from "now" in windows whose length
        depends on ``bar`` (see ``_fetch_window_ms``).

        Args:
            symbol: Instrument symbol.
            bar: Bar size, e.g. ``"5m"``, ``"1H"``, ``"1D"``.
            start: Start of the range. It is passed to
                ``transform_date_time_to_timestamp``; callers in this file
                pass both a date-time string and an integer millisecond
                timestamp.

        Returns:
            The result of ``post_calculate_metrics`` when at least one window
            returned rows. Otherwise the last (empty) fetch result, or
            ``None`` when validation failed or no window was fetched.
        """
        end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        end_time_ts = transform_date_time_to_timestamp(end_time)
        if end_time_ts is None:
            logger.error(f"结束时间格式错误: {end_time}")
            return None

        start_time_ts = transform_date_time_to_timestamp(start)
        if start_time_ts is None:
            logger.error(f"开始时间格式错误: {start}")
            return None

        threshold = self._fetch_window_ms(bar)
        if threshold is None:
            # Previously this fell through and crashed with a TypeError on
            # ``end_time_ts - None``; fail the same way other invalid inputs do.
            logger.error(f"不支持的K线周期: {bar}")
            return None

        # The request size limit does not change between windows.
        limit = 1000 if self.is_us_stock else 100

        # ``data`` must exist even when the loop body never runs
        # (start_time_ts >= end_time_ts).
        data = None
        get_data = False
        min_start_time_ts = start_time_ts
        max_start_time_ts = None
        while start_time_ts < end_time_ts:
            # Window is [end - threshold, end], clamped to the requested start.
            current_start_time_ts = int(end_time_ts - threshold)
            if current_start_time_ts < start_time_ts:
                current_start_time_ts = start_time_ts
            start_date_time = timestamp_to_datetime(current_start_time_ts)
            end_date_time = timestamp_to_datetime(end_time_ts)
            logger.info(
                f"获取行情数据: {symbol} {bar} 从 {start_date_time} 到 {end_date_time}"
            )
            data = self.market_data.get_historical_kline_data(
                symbol=symbol,
                start=current_start_time_ts,
                bar=bar,
                end_time=end_time_ts,
                limit=limit,
            )
            if data is None or len(data) == 0:
                logger.warning(
                    f"获取行情数据为空: {symbol} {bar} 从 {start_date_time} 到 {end_date_time}"
                )
                break

            data = self.post_save_data(data)
            # Track the overall timestamp range that was actually stored.
            current_min_start_time_ts = int(data["timestamp"].min())
            if current_min_start_time_ts < min_start_time_ts:
                min_start_time_ts = current_min_start_time_ts
            current_max_start_time_ts = int(data["timestamp"].max())
            if (
                max_start_time_ts is None
                or current_max_start_time_ts > max_start_time_ts
            ):
                max_start_time_ts = current_max_start_time_ts
            get_data = True

            if current_start_time_ts == start_time_ts:
                # The window reached the requested start; nothing left to fetch.
                break
            # Move the window end back to whichever is earlier: the oldest
            # row received or the window start.
            if current_min_start_time_ts < current_start_time_ts:
                end_time_ts = current_min_start_time_ts
            else:
                end_time_ts = current_start_time_ts

        if get_data:
            # Fill in indicator columns for the stored range;
            # post_calculate_metrics also loads rows preceding
            # min_start_time_ts as warm-up context.
            logger.info(f"开始补充技术指标数据: {symbol} {bar}")
            data = self.post_calculate_metrics(
                symbol, bar, min_start_time_ts, max_start_time_ts
            )
        return data

    def adjust_binance_csv_data(self, symbol: str, bar: str, data: pd.DataFrame):
        """Normalise a Binance CSV frame to the market-data column layout.

        The input must contain ``candle_begin_time``, ``open``, ``high``,
        ``low``, ``close``, ``volume`` and ``quote_volume`` columns. New
        columns are added to ``data`` in place; the returned frame is a
        column-selected copy sorted by ``timestamp`` with a fresh index.

        Args:
            symbol: Value written to the ``symbol`` column.
            bar: Value written to the ``bar`` column.
            data: Raw Binance candle frame.

        Returns:
            A DataFrame with the columns ``symbol, bar, timestamp, date_time,
            date_time_us, open, high, low, close, volume, volCcy,
            volCCyQuote, create_time``.
        """
        data["symbol"] = symbol
        data["bar"] = bar

        # Column-wise conversion instead of a per-row .loc assignment loop.
        data["timestamp"] = [
            datetime_to_timestamp(candle_begin_time, is_utc=True)
            for candle_begin_time in data["candle_begin_time"]
        ]
        data["timestamp"] = data["timestamp"].astype(int)

        # Both local-time renderings derive from the same UTC instant.
        utc_series = pd.to_datetime(
            data["timestamp"], unit="ms", utc=True, errors="coerce"
        )
        data["date_time"] = (
            utc_series.dt.tz_convert("Asia/Shanghai")
            .dt.strftime("%Y-%m-%d %H:%M:%S")
            .astype(str)
        )
        data["date_time_us"] = (
            utc_series.dt.tz_convert("America/New_York")
            .dt.strftime("%Y-%m-%d %H:%M:%S")
            .astype(str)
        )

        # NOTE(review): both volCcy and volCCyQuote are filled from
        # quote_volume, exactly as before -- confirm volCcy is not meant to
        # come from a different column.
        data["volCcy"] = data["quote_volume"].astype(float)
        data["volCCyQuote"] = data["quote_volume"].astype(float)
        data["create_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        data["create_time"] = data["create_time"].astype(str)

        for price_col in ("open", "high", "low", "close", "volume"):
            data[price_col] = data[price_col].astype(float)

        data = data[
            [
                "symbol",
                "bar",
                "timestamp",
                "date_time",
                "date_time_us",
                "open",
                "high",
                "low",
                "close",
                "volume",
                "volCcy",
                "volCCyQuote",
                "create_time",
            ]
        ]
        data = data.sort_values(by="timestamp", ascending=True)
        data = data.reset_index(drop=True)
        return data

    def post_save_data(self, data: pd.DataFrame):
        """Prepare freshly fetched rows and insert them into MySQL.

        ``buy_sz`` / ``sell_sz`` are set to ``-1``, the frame is reduced to
        the storage columns, the indicator columns are added as empty
        placeholders and the result is inserted. ``None`` or an empty frame
        is returned untouched.

        Returns:
            The frame that was inserted, or the input when it was empty.
        """
        if data is not None and len(data) > 0:
            data["buy_sz"] = -1
            data["sell_sz"] = -1
            data = data[
                [
                    "symbol",
                    "bar",
                    "timestamp",
                    "date_time",
                    "date_time_us",
                    "open",
                    "high",
                    "low",
                    "close",
                    "volume",
                    "volCcy",
                    "volCCyQuote",
                    "buy_sz",
                    "sell_sz",
                    "create_time",
                ]
            ]
            data = self.add_new_columns(data)
            self.db_market_data.insert_data_to_mysql(data)
        return data

    def post_calculate_metrics(
        self, symbol: str, bar: str, min_start_time_ts: int, max_start_time_ts: int
    ):
        """Compute and store indicators for rows in the given timestamp range.

        Up to 31 rows preceding ``min_start_time_ts`` are loaded as context
        for the indicator calculation; only rows with ``timestamp >=
        min_start_time_ts`` are written back.

        Returns:
            The stored DataFrame, or ``None`` when no rows were found, the
            queried range holds no rows beyond the context rows, or the query
            result has an unexpected type.
        """
        logger.info(f"开始补充技术指标数据: {symbol} {bar}")
        before_data = self.db_market_data.query_data_before_timestamp(
            symbol, bar, min_start_time_ts, 31
        )
        if before_data is not None and len(before_data) > 0:
            # The last entry is used as the earliest context row.
            earliest_timestamp = before_data[-1]["timestamp"]
        else:
            earliest_timestamp = min_start_time_ts

        handle_data = self.db_market_data.query_market_data_by_symbol_bar(
            symbol=symbol, bar=bar, start=earliest_timestamp, end=max_start_time_ts
        )
        if handle_data is None:
            return None

        if before_data is not None and len(handle_data) <= len(before_data):
            logger.error(
                f"handle_data数据条数小于before_data数据条数: {symbol} {bar}"
            )
            return None

        if isinstance(handle_data, list):
            handle_data = pd.DataFrame(handle_data)
        elif isinstance(handle_data, dict):
            handle_data = pd.DataFrame([handle_data])
        elif not isinstance(handle_data, pd.DataFrame):
            logger.error(f"handle_data类型错误: {type(handle_data)}")
            return None

        handle_data = self.calculate_metrics(handle_data)
        # Drop the context rows; reassign rather than mutating a filtered
        # slice in place.
        handle_data = handle_data[
            handle_data["timestamp"] >= min_start_time_ts
        ].reset_index(drop=True)
        logger.info(f"开始保存技术指标数据: {symbol} {bar}")
        self.db_market_data.insert_data_to_mysql(handle_data)
        return handle_data

    def add_new_columns(self, data: pd.DataFrame):
        """Return a copy of ``data`` with the indicator columns added.

        ``buy_sz`` / ``sell_sz`` are created with ``-1`` when missing; every
        indicator column is (re)set to ``pd.NA``.
        """
        data = data.copy()
        columns = data.columns.tolist()
        if "buy_sz" not in columns:
            data.loc[:, "buy_sz"] = -1
        if "sell_sz" not in columns:
            data.loc[:, "sell_sz"] = -1
        new_cols = [
            "pre_close",
            "close_change",
            "pct_chg",
            "ma1",
            "ma2",
            "dif",
            "dea",
            "macd",
            "macd_signal",
            "macd_divergence",
            "kdj_k",
            "kdj_d",
            "kdj_j",
            "kdj_signal",
            "kdj_pattern",
            "sar",
            "sar_signal",
            "ma5",
            "ma10",
            "ma20",
            "ma30",
            "ma_cross",
            "ma5_close_diff",
            "ma10_close_diff",
            "ma20_close_diff",
            "ma30_close_diff",
            "ma_close_avg",
            "ma_long_short",
            "ma_divergence",
            "rsi_14",
            "rsi_signal",
            "boll_upper",
            "boll_middle",
            "boll_lower",
            "boll_signal",
            "boll_pattern",
            "k_length",
            "k_shape",
            "k_up_down",
        ]
        for col in new_cols:
            data.loc[:, col] = pd.NA
        return data

    def calculate_metrics(self, data: pd.DataFrame):
        """Compute the technical indicators for ``data``.

        The frame is sorted by ``timestamp`` and re-indexed, then passed
        through the ``MetricsCalculation`` steps in this order: previous
        close / change, MACD, KDJ, SAR, KDJ pattern, MACD divergence,
        MA5/10/20/30, MA-to-price percentages, MA long/short and divergence,
        RSI, BOLL, BOLL pattern, K-line length, K-line shape.

        Target columns, as declared in the storage table:

            pre_close DECIMAL(20,10) NULL,
            close_change DECIMAL(20,10) NULL,
            pct_chg DECIMAL(20,10) NULL,
            ma1 DOUBLE DEFAULT NULL COMMENT 'moving average 1',
            ma2 DOUBLE DEFAULT NULL COMMENT 'moving average 2',
            dif DOUBLE DEFAULT NULL COMMENT 'MACD DIF line',
            dea DOUBLE DEFAULT NULL COMMENT 'MACD DEA line',
            macd DOUBLE DEFAULT NULL COMMENT 'MACD value',
            macd_signal VARCHAR(15) DEFAULT NULL COMMENT 'MACD golden/death cross signal',
            macd_divergence varchar(25) DEFAULT NULL COMMENT 'MACD divergence: top or bottom',
            kdj_k DOUBLE DEFAULT NULL COMMENT 'KDJ K value',
            kdj_d DOUBLE DEFAULT NULL COMMENT 'KDJ D value',
            kdj_j DOUBLE DEFAULT NULL COMMENT 'KDJ J value',
            kdj_signal VARCHAR(15) DEFAULT NULL COMMENT 'KDJ golden/death cross signal',
            kdj_pattern varchar(25) DEFAULT NULL COMMENT 'KDJ overbought, oversold, hovering',
            sar DOUBLE DEFAULT NULL COMMENT 'SAR indicator',
            sar_signal VARCHAR(15) DEFAULT NULL COMMENT 'SAR long, SAR short, SAR wait',
            ma5 DOUBLE DEFAULT NULL COMMENT '5-period moving average',
            ma10 DOUBLE DEFAULT NULL COMMENT '10-period moving average',
            ma20 DOUBLE DEFAULT NULL COMMENT '20-period moving average',
            ma30 DOUBLE DEFAULT NULL COMMENT '30-period moving average',
            ma_cross VARCHAR(15) DEFAULT NULL COMMENT 'moving-average cross signal',
            ma5_close_diff double DEFAULT NULL COMMENT 'difference between MA5 and close',
            ma10_close_diff double DEFAULT NULL COMMENT 'difference between MA10 and close',
            ma20_close_diff double DEFAULT NULL COMMENT 'difference between MA20 and close',
            ma30_close_diff double DEFAULT NULL COMMENT 'difference between MA30 and close',
            ma_close_avg double DEFAULT NULL COMMENT 'average of the MA/close differences',
            ma_long_short varchar(25) DEFAULT NULL COMMENT 'moving-average long/short',
            ma_divergence varchar(25) DEFAULT NULL COMMENT 'MA spread: converged, moderate, diverged, over-diverged',
            rsi_14 DOUBLE DEFAULT NULL COMMENT '14-period RSI',
            rsi_signal VARCHAR(15) DEFAULT NULL COMMENT 'RSI strength signal',
            boll_upper DOUBLE DEFAULT NULL COMMENT 'Bollinger upper band',
            boll_middle DOUBLE DEFAULT NULL COMMENT 'Bollinger middle band',
            boll_lower DOUBLE DEFAULT NULL COMMENT 'Bollinger lower band',
            boll_signal VARCHAR(15) DEFAULT NULL COMMENT 'Bollinger strength signal',
            boll_pattern varchar(25) DEFAULT NULL COMMENT 'BOLL overbought, oversold, hovering',
            k_length varchar(25) DEFAULT NULL COMMENT 'K-line length',
            k_shape varchar(25) DEFAULT NULL COMMENT 'K-line shape',
            k_up_down varchar(25) DEFAULT NULL COMMENT 'K-line direction',
        """
        data = data.sort_values(by="timestamp")
        data = data.reset_index(drop=True)
        metrics_calculation = MetricsCalculation()
        data = metrics_calculation.pre_close(data)
        data = metrics_calculation.macd(data)
        data = metrics_calculation.kdj(data)
        data = metrics_calculation.sar(data)
        data = metrics_calculation.set_kdj_pattern(data)
        data = metrics_calculation.update_macd_divergence_column_simple(data)
        data = metrics_calculation.ma5102030(data)
        data = metrics_calculation.calculate_ma_price_percent(data)
        data = metrics_calculation.set_ma_long_short_divergence(data)
        data = metrics_calculation.rsi(data)
        data = metrics_calculation.boll(data)
        data = metrics_calculation.set_boll_pattern(data)
        data = metrics_calculation.set_k_length(data)
        data = metrics_calculation.set_k_shape(data)
        return data

    def batch_update_data(self):
        """Run ``update_data`` for every configured symbol/bar pair."""
        for symbol in self.symbols:
            for bar in self.bars:
                self.update_data(symbol, bar)

    def update_data(self, symbol: str, bar: str):
        """Fetch and store everything newer than the latest stored row.

        When nothing is stored yet for the pair, data is initialised from
        ``self.initial_date``.

        Returns:
            The result of ``fetch_save_data``, or ``None`` when the latest
            stored row has no usable timestamp.
        """
        logger.info(f"开始更新行情数据: {symbol} {bar}")
        latest_data = self.db_market_data.query_latest_data(symbol, bar)
        if not latest_data:
            logger.info(f"{symbol}, {bar} 无数据,开始从{self.initial_date}初始化数据")
            return self.fetch_save_data(symbol, bar, self.initial_date)

        latest_timestamp = latest_data.get("timestamp")
        if not latest_timestamp:
            logger.warning(f"获取{symbol}, {bar} 最新数据失败")
            return

        latest_timestamp = int(latest_timestamp)
        latest_date_time = timestamp_to_datetime(latest_timestamp)
        logger.info(
            f"{symbol}, {bar} 上次获取的最新数据时间: {latest_date_time}"
        )
        # Resume one millisecond after the stored bar to avoid re-fetching it.
        return self.fetch_save_data(symbol, bar, latest_timestamp + 1)

    def batch_calculate_metrics(self):
        """Recompute and store indicators for all configured symbol/bar pairs.

        Rows from ``self.initial_date`` up to now are loaded per pair, passed
        through ``calculate_metrics`` and written back. The start date comes
        from the same configuration the instance was built with, so Binance
        and US-stock instances use their own ``initial_date``.
        """
        logger.info("开始批量计算技术指标")
        start_timestamp = transform_date_time_to_timestamp(self.initial_date)
        if start_timestamp is None:
            logger.error(f"开始时间格式错误: {self.initial_date}")
            return
        current_date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        current_timestamp = transform_date_time_to_timestamp(current_date_time)
        for symbol in self.symbols:
            for bar in self.bars:
                logger.info(f"开始计算技术指标: {symbol} {bar}")
                data = self.db_market_data.query_market_data_by_symbol_bar(
                    symbol=symbol,
                    bar=bar,
                    start=start_timestamp - 1,
                    end=current_timestamp,
                )
                if data is not None and len(data) > 0:
                    data = pd.DataFrame(data)
                    data = self.calculate_metrics(data)
                    logger.info(f"开始保存技术指标数据: {symbol} {bar}")
                    self.db_market_data.insert_data_to_mysql(data)

    def batch_ma_break_statistics(self):
        """Run the MA-break statistics, first with ``all_change=False`` and
        then with ``all_change=True``.

        Raises:
            AttributeError: If ``self.ma_break_statistics`` has not been
                assigned. ``__init__`` does not create it, so it must be set
                on the instance before this method is called.
        """
        statistics = getattr(self, "ma_break_statistics", None)
        if statistics is None:
            raise AttributeError(
                "MarketDataMain.ma_break_statistics is not initialized; "
                "assign a statistics object before calling "
                "batch_ma_break_statistics()"
            )
        logger.info("开始批量计算MA突破统计")
        statistics.batch_statistics(all_change=False)
        statistics.batch_statistics(all_change=True)


if __name__ == "__main__":
    market_data_main = MarketDataMain()
    # market_data_main.batch_update_data()
    # market_data_main.initial_data()
    market_data_main.batch_calculate_metrics()
    # market_data_main.batch_ma_break_statistics()