import pandas as pd
from sqlalchemy import create_engine, exc, text
import re
import datetime
import logging
from core.utils import (
    transform_data_type,
    datetime_to_timestamp,
    check_date_time_format,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")


class DBData:
    def __init__(
        self, db_url: str, table_name: str = "crypto_market_data", columns: list = None
    ):
        self.db_url = db_url
        self.table_name = table_name
        # Timestamp-suffixed staging-table name to avoid collisions between runs.
        self.temp_table_name = (
            f"temp_{table_name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
        )
        self.columns = columns
        if self.columns is None:
            raise ValueError("columns must not be empty")
        if len(self.columns) != len(set(self.columns)):
            raise ValueError("columns must not contain duplicates")

    def create_insert_sql_text_by_temp_table(self, temp_table_name: str):
        """
        Build the upsert SQL that copies rows from a staging table.

        Example:
        INSERT INTO crypto_market_data (symbol, bar, timestamp, date_time, open, high, low, close, volume, volCcy, volCCyQuote, create_time)
        SELECT symbol, bar, timestamp, date_time, open, high, low, close, volume, volCcy, volCCyQuote, create_time
        FROM {temp_table_name}
        ON DUPLICATE KEY UPDATE
            open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
            volume=VALUES(volume), volCcy=VALUES(volCcy), volCCyQuote=VALUES(volCCyQuote),
            date_time=VALUES(date_time), create_time=VALUES(create_time)
        """
        sql = f"""
        INSERT INTO {self.table_name} ({",".join(self.columns)})
        SELECT {",".join(self.columns)}
        FROM {temp_table_name}
        ON DUPLICATE KEY UPDATE
        {", ".join([f"{col}=VALUES({col})" for col in self.columns])}
        """
        return sql

    def create_insert_sql_text(self):
        """
        Build the parameterized upsert SQL (no staging table).

        Example:
        INSERT INTO crypto_market_data (symbol, bar, timestamp, date_time, open, high, low, close, volume, volCcy, volCCyQuote, create_time)
        VALUES (:symbol, :bar, :timestamp, :date_time, :open, :high, :low, :close, :volume, :volCcy, :volCCyQuote, :create_time)
        ON DUPLICATE KEY UPDATE
            open=VALUES(open), high=VALUES(high), low=VALUES(low), close=VALUES(close),
            volume=VALUES(volume), volCcy=VALUES(volCcy), volCCyQuote=VALUES(volCCyQuote),
            date_time=VALUES(date_time), create_time=VALUES(create_time)
        """
        sql = f"""
        INSERT INTO {self.table_name} ({",".join(self.columns)})
        VALUES ({",".join([f":{col}" for col in self.columns])})
        ON DUPLICATE KEY UPDATE
        {", ".join([f"{col}=VALUES({col})" for col in self.columns])}
        """
        return sql

    def insert_data_to_mysql(self, df: pd.DataFrame):
        """
        Save candlestick (K-line) data to the MySQL crypto_market_data table.

        Speed:  ⭐⭐⭐⭐⭐ fastest
        Memory: ⭐⭐⭐⭐ medium
        Best for: small to medium data sets (< 100k rows)

        :param df: DataFrame of K-line data
        """
        if df is None or df.empty:
            logging.warning("DataFrame is empty; nothing to write.")
            return
        df = df[self.columns]
        try:
            engine = create_engine(self.db_url)
            # Scheme 1: staging table + bulk upsert (recommended, fastest).
            with engine.begin() as conn:
                # Write the data into the staging table on the same connection
                # so it participates in the surrounding transaction.
                df.to_sql(
                    name=self.temp_table_name,
                    con=conn,
                    if_exists="replace",
                    index=False,
                    method="multi",
                )
                # Bulk upsert via INSERT ... ON DUPLICATE KEY UPDATE.
                sql = text(
                    self.create_insert_sql_text_by_temp_table(self.temp_table_name)
                )
                conn.execute(sql)
                # Drop the staging table.
                conn.execute(text(f"DROP TABLE IF EXISTS {self.temp_table_name}"))
            logging.info("Data written to the database successfully.")
        except Exception as e:
            logging.error(f"Database connection or write failed: {e}")

    def insert_data_to_mysql_fast(self, df: pd.DataFrame):
        """
        Fast insert of K-line data (scheme 2: executemany batch insert).

        Speed:  ⭐⭐⭐⭐ fast
        Memory: ⭐⭐⭐⭐⭐ low
        Best for: medium-sized data sets
        """
        if df is None or df.empty:
            logging.warning("DataFrame is empty; nothing to write.")
            return
        df = df[self.columns]
        try:
            engine = create_engine(self.db_url)
            with engine.begin() as conn:
                # Batch insert via executemany.
                sql = text(self.create_insert_sql_text())
                # Convert the DataFrame into a list of row dicts.
                data_dicts = df.to_dict(orient="records")
                conn.execute(sql, data_dicts)
            logging.info("Data written to the database successfully.")
        except Exception as e:
            logging.error(f"Database connection or write failed: {e}")

    def insert_data_to_mysql_chunk(self, df: pd.DataFrame, chunk_size: int = 1000):
        """
        Chunked insert of K-line data (scheme 3: for large data sets).

        Speed:  ⭐⭐⭐ medium
        Memory: ⭐⭐⭐⭐⭐ lowest
        Best for: large data sets (> 100k rows)
        """
        if df is None or df.empty:
            logging.warning("DataFrame is empty; nothing to write.")
            return
        df = df[self.columns]
        try:
            engine = create_engine(self.db_url)
            total_rows = len(df)
            # Process the DataFrame chunk by chunk, each chunk in its own transaction.
            for i in range(0, total_rows, chunk_size):
                chunk_df = df.iloc[i : i + chunk_size]
                with engine.begin() as conn:
                    # Per-chunk staging table name.
                    temp_table_name = f"{self.temp_table_name}_{i}"
                    # Write the chunk into the staging table.
                    chunk_df.to_sql(
                        name=temp_table_name,
                        con=conn,
                        if_exists="replace",
                        index=False,
                        method="multi",
                    )
                    # Bulk upsert via INSERT ... ON DUPLICATE KEY UPDATE.
                    sql = text(
                        self.create_insert_sql_text_by_temp_table(temp_table_name)
                    )
                    conn.execute(sql)
                    # Drop the staging table.
                    conn.execute(text(f"DROP TABLE IF EXISTS {temp_table_name}"))
                logging.info(
                    f"Processed {min(i + chunk_size, total_rows)}/{total_rows} rows"
                )
            logging.info("Data written to the database successfully.")
        except Exception as e:
            logging.error(f"Database connection or write failed: {e}")

    def insert_data_to_mysql_simple(self, df: pd.DataFrame):
        """
        Simple insert of K-line data (scheme 4: plain to_sql append, no duplicate handling).

        Speed:  ⭐⭐⭐⭐⭐ fastest
        Memory: ⭐⭐⭐⭐ medium
        Note: duplicate keys raise an error and require extra handling.
        """
        if df is None or df.empty:
            logging.warning("DataFrame is empty; nothing to write.")
            return
        df = df[self.columns]
        try:
            engine = create_engine(self.db_url)
            with engine.begin() as conn:
                df.to_sql(
                    name=self.table_name,
                    con=conn,
                    if_exists="append",
                    index=False,
                    method="multi",
                )
            logging.info("Data written to the database successfully.")
        except Exception as e:
            logging.error(f"Database connection or write failed: {e}")

    def query_data(self, sql: str, condition_dict: dict, return_multi: bool = True):
        """
        Run a parameterized query.

        :param sql: query SQL with named bind parameters
        :param condition_dict: values for the bind parameters
        :param return_multi: return all rows as a list when True, otherwise a single row
        """
        try:
            engine = create_engine(self.db_url)
            with engine.connect() as conn:
                result = conn.execute(text(sql), condition_dict)
                if return_multi:
                    rows = result.fetchall()
                    if rows:
                        return [
                            transform_data_type(dict(row._mapping)) for row in rows
                        ]
                    return None
                row = result.fetchone()
                if row:
                    return transform_data_type(dict(row._mapping))
                return None
        except Exception as e:
            logging.error(f"Query failed: {e}")
            return None
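

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how this class might be driven. The DB URL, the sample
# row, and the credentials below are hypothetical placeholders, not values
# taken from this project's configuration; the column list mirrors the one
# shown in the docstring examples above.
if __name__ == "__main__":
    # Hypothetical MySQL URL; swap in real credentials/host before running.
    example_db_url = "mysql+pymysql://user:password@localhost:3306/market"
    example_columns = [
        "symbol", "bar", "timestamp", "date_time",
        "open", "high", "low", "close",
        "volume", "volCcy", "volCCyQuote", "create_time",
    ]
    db = DBData(db_url=example_db_url, columns=example_columns)

    # Tiny illustrative DataFrame with one value per configured column.
    sample_df = pd.DataFrame(
        [
            {
                "symbol": "BTC-USDT",
                "bar": "1m",
                "timestamp": 1700000000000,
                "date_time": "2023-11-14 22:13:20",
                "open": 36000.0,
                "high": 36100.0,
                "low": 35950.0,
                "close": 36050.0,
                "volume": 12.3,
                "volCcy": 443000.0,
                "volCCyQuote": 443000.0,
                "create_time": "2023-11-14 22:14:00",
            }
        ]
    )

    # Upsert via the staging-table path, then read the row back.
    db.insert_data_to_mysql(sample_df)
    rows = db.query_data(
        sql=f"SELECT * FROM {db.table_name} WHERE symbol = :symbol AND bar = :bar",
        condition_dict={"symbol": "BTC-USDT", "bar": "1m"},
    )
    print(rows)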