import pandas as pd
from sqlalchemy import create_engine, exc, text
import re

from core.utils import get_current_date_time
import core.logger as logging
from core.utils import transform_data_type

logger = logging.logger


class DBData:
    """Write pandas DataFrames into one MySQL table and run ad-hoc queries.

    All write methods first select ``self.columns`` from the incoming DataFrame
    (a DataFrame lacking any of them raises ``KeyError`` before any database
    work). Options 1-3 upsert with ``INSERT ... ON DUPLICATE KEY UPDATE``;
    option 4 (:meth:`insert_data_to_mysql_simple`) plainly appends. Database
    errors raised during a write are logged via ``logger.error`` and not
    re-raised.
    """

    def __init__(
        self,
        db_url: str,
        table_name: str = "crypto_market_data",
        columns: list = None,
    ):
        """
        :param db_url: SQLAlchemy database URL. The generated SQL
            (``ON DUPLICATE KEY UPDATE``, backtick quoting) is MySQL-specific.
        :param table_name: target table for reads/writes.
        :param columns: column names to write (any iterable of str). Required,
            non-empty, without duplicates.
        :raises ValueError: if ``columns`` is None/empty or has duplicates.
        """
        # Materialize as a list: df[tuple] would be treated as a single key.
        columns = list(columns) if columns is not None else []
        if not columns:
            raise ValueError("columns不能为空")
        if len(columns) != len(set(columns)):
            raise ValueError("columns不能有重复")

        self.table_name = table_name
        self.columns = columns
        # Base name for staging tables, fixed at construction time.
        # NOTE(review): all calls on the same instance share this name, so
        # concurrent writes through one instance would collide — confirm that
        # instances are not shared across threads.
        self.temp_table_name = f"temp_{table_name}_{get_current_date_time()}"
        self.db_url = db_url
        self.db_engine = create_engine(
            self.db_url,
            pool_size=25,  # connection pool size
            max_overflow=10,  # extra connections allowed beyond pool_size
            pool_timeout=30,  # seconds to wait for a free pooled connection
            pool_recycle=60,  # seconds before a pooled connection is recycled
        )

    # ------------------------------------------------------------------ #
    # SQL builders
    # ------------------------------------------------------------------ #
    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a backtick-quoted MySQL identifier.

        Quoting keeps names that are reserved words or that contain characters
        such as spaces, '-' or ':' valid. Such characters may appear in the
        timestamped staging-table name, depending on the format of
        ``get_current_date_time`` (not visible here). Embedded backticks are
        doubled, which is MySQL's escaping rule.
        """
        return "`" + str(name).replace("`", "``") + "`"

    def _update_clause(self) -> str:
        """Build the ``col=VALUES(col), ...`` list that follows ON DUPLICATE KEY UPDATE."""
        return ", ".join(
            f"{quoted}=VALUES({quoted})"
            for quoted in map(self._quote_identifier, self.columns)
        )

    def create_insert_sql_text_by_temp_table(self, temp_table_name: str) -> str:
        """Build the upsert statement that copies every row of a staging table.

        Shape of the output (identifiers are backtick-quoted)::

            INSERT INTO `crypto_market_data` (`symbol`,`bar`,`timestamp`,...)
            SELECT `symbol`,`bar`,`timestamp`,... FROM `{temp_table_name}`
            ON DUPLICATE KEY UPDATE `symbol`=VALUES(`symbol`), ...

        Every column in ``self.columns`` is listed in the UPDATE clause,
        including key columns.

        NOTE: ``VALUES(col)`` inside ON DUPLICATE KEY UPDATE is deprecated
        since MySQL 8.0.20 but still accepted; it is kept here for
        compatibility with older MySQL/MariaDB servers.

        :param temp_table_name: unquoted name of the staging table.
        :return: SQL string (no bind parameters).
        """
        column_list = ",".join(map(self._quote_identifier, self.columns))
        return f"""
            INSERT INTO {self._quote_identifier(self.table_name)} ({column_list})
            SELECT {column_list} FROM {self._quote_identifier(temp_table_name)}
            ON DUPLICATE KEY UPDATE {self._update_clause()}
        """

    def create_insert_sql_text(self) -> str:
        """Build a parameterized single-row upsert statement (no staging table).

        Shape of the output::

            INSERT INTO `crypto_market_data` (`symbol`,`bar`,...)
            VALUES (:symbol,:bar,...)
            ON DUPLICATE KEY UPDATE `symbol`=VALUES(`symbol`), ...

        Bind parameters are named after the columns, so it can be executed
        with a list of ``{column: value}`` dicts (executemany).

        :return: SQL string with ``:column`` bind parameters.
        """
        column_list = ",".join(map(self._quote_identifier, self.columns))
        placeholders = ",".join(f":{col}" for col in self.columns)
        return f"""
            INSERT INTO {self._quote_identifier(self.table_name)} ({column_list})
            VALUES ({placeholders})
            ON DUPLICATE KEY UPDATE {self._update_clause()}
        """

    # ------------------------------------------------------------------ #
    # Staging-table helpers
    # ------------------------------------------------------------------ #
    def _upsert_via_temp_table(self, df: pd.DataFrame, temp_table_name: str) -> None:
        """Stage *df* in *temp_table_name*, upsert it into the target table, drop the stage.

        The staging write and the upsert run inside ``engine.begin()``, which
        commits on success. A bare ``engine.connect()`` block does not commit
        under SQLAlchemy 2.x. The staging table is a regular table created by
        ``DataFrame.to_sql`` and is dropped even when the upsert fails.
        Exceptions from staging/upsert propagate to the caller.
        """
        try:
            with self.db_engine.begin() as conn:
                df.to_sql(
                    name=temp_table_name,
                    con=conn,
                    if_exists="replace",
                    index=False,
                    method="multi",
                )
                conn.execute(
                    text(self.create_insert_sql_text_by_temp_table(temp_table_name))
                )
        finally:
            self._drop_table_quietly(temp_table_name)

    def _drop_table_quietly(self, table_name: str) -> None:
        """Best-effort ``DROP TABLE IF EXISTS``; a failure is logged, never raised."""
        try:
            with self.db_engine.begin() as conn:
                conn.execute(
                    text(f"DROP TABLE IF EXISTS {self._quote_identifier(table_name)}")
                )
        except Exception as e:
            # Must not mask the original error when called from a `finally`.
            logger.warning(f"删除临时表 {table_name} 失败: {e}")

    # ------------------------------------------------------------------ #
    # Write strategies
    # ------------------------------------------------------------------ #
    def insert_data_to_mysql(self, df: pd.DataFrame):
        """Upsert K-line data via a staging table (option 1, recommended).

        Speed: ★★★★★ fastest. Memory: ★★★★ moderate.
        Suited to small/medium volumes (< 100k rows); the whole frame is sent
        in one ``to_sql(method="multi")`` call.

        :param df: K-line DataFrame containing every column in ``self.columns``.
            ``None`` or an empty frame only logs a warning.
        """
        if df is None or df.empty:
            logger.warning("DataFrame为空,无需写入数据库。")
            return

        df = df[self.columns]
        try:
            self._upsert_via_temp_table(df, self.temp_table_name)
            logger.info("数据已成功写入数据库。")
        except Exception as e:
            logger.error(f"数据库连接或写入失败: {e}")

    def insert_data_to_mysql_fast(self, df: pd.DataFrame):
        """Upsert K-line data with a parameterized executemany (option 2).

        Speed: ★★★★ fast. Memory: ★★★★★ low. Suited to medium volumes.

        :param df: K-line DataFrame containing every column in ``self.columns``.
            ``None`` or an empty frame only logs a warning.
        """
        if df is None or df.empty:
            logger.warning("DataFrame为空,无需写入数据库。")
            return

        df = df[self.columns]
        try:
            # astype(object) yields plain Python scalars instead of numpy ones,
            # and where(...) replaces NaN/NaT with None so the driver sends NULL.
            records = (
                df.astype(object)
                .where(df.notna(), None)
                .to_dict(orient="records")
            )
            with self.db_engine.begin() as conn:
                conn.execute(text(self.create_insert_sql_text()), records)
            logger.info("数据已成功写入数据库。")
        except Exception as e:
            logger.error(f"数据库连接或写入失败: {e}")

    def insert_data_to_mysql_chunk(self, df: pd.DataFrame, chunk_size: int = 1000):
        """Upsert K-line data chunk by chunk through staging tables (option 3).

        Speed: ★★★ moderate. Memory: ★★★★★ lowest. Suited to large volumes
        (> 100k rows). Each chunk is committed on its own, so if a chunk
        fails, the chunks before it remain written and processing stops.

        :param df: K-line DataFrame containing every column in ``self.columns``.
            ``None`` or an empty frame only logs a warning.
        :param chunk_size: rows per chunk; must be > 0.
        :raises ValueError: if ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size必须大于0")
        if df is None or df.empty:
            logger.warning("DataFrame为空,无需写入数据库。")
            return

        df = df[self.columns]
        try:
            total_rows = len(df)
            for start in range(0, total_rows, chunk_size):
                chunk_df = df.iloc[start : start + chunk_size]
                # One staging table per chunk, suffixed with the chunk offset.
                self._upsert_via_temp_table(
                    chunk_df, f"{self.temp_table_name}_{start}"
                )
                done = min(start + chunk_size, total_rows)
                logger.info(f"已处理 {done}/{total_rows} 条记录")
            logger.info("数据已成功写入数据库。")
        except Exception as e:
            logger.error(f"数据库连接或写入失败: {e}")

    def insert_data_to_mysql_simple(self, df: pd.DataFrame):
        """Append K-line data directly with ``to_sql`` (option 4, no upsert).

        Speed: ★★★★★ fastest. Memory: ★★★★ moderate.
        Duplicates are NOT ignored: a duplicate key makes the database raise,
        which is caught and logged here like any other write error.

        :param df: K-line DataFrame containing every column in ``self.columns``.
            ``None`` or an empty frame only logs a warning.
        """
        if df is None or df.empty:
            logger.warning("DataFrame为空,无需写入数据库。")
            return

        df = df[self.columns]
        try:
            with self.db_engine.begin() as conn:
                df.to_sql(
                    name=self.table_name,
                    con=conn,
                    if_exists="append",
                    index=False,
                    method="multi",
                )
            logger.info("数据已成功写入数据库。")
        except Exception as e:
            logger.error(f"数据库连接或写入失败: {e}")

    # ------------------------------------------------------------------ #
    # Reads
    # ------------------------------------------------------------------ #
    def query_data(
        self, sql: str, condition_dict: dict = None, return_multi: bool = True
    ):
        """Run a query and return its rows as dicts.

        Each row is converted with ``dict(row._mapping)`` and then passed
        through ``transform_data_type``.

        :param sql: SQL text; may contain ``:name`` bind parameters.
        :param condition_dict: values for the bind parameters (optional).
        :param return_multi: True → list of all rows; False → first row only.
        :return: list of dicts / a single dict, or None when there are no rows
            or the query fails (the error is logged).
        """
        try:
            with self.db_engine.connect() as conn:
                result = conn.execute(text(sql), condition_dict)
                if return_multi:
                    rows = result.fetchall()
                    if not rows:
                        return None
                    return [transform_data_type(dict(row._mapping)) for row in rows]

                row = result.fetchone()
                if not row:
                    return None
                return transform_data_type(dict(row._mapping))
        except Exception as e:
            logger.error(f"查询数据出错: {e}")
            return None