# crypto_quant/core/db/db_manager.py
# (242 lines, 9.2 KiB, Python)

import pandas as pd
from sqlalchemy import create_engine, exc, text
import re, datetime
import core.logger as logging
from core.utils import transform_data_type
logger = logging.logger
# Revision timestamp (from repository history view): 2025-07-28 04:29:31 +00:00
class DBData:
    """Helper around a SQLAlchemy engine for upserting DataFrames and querying.

    The write methods build ``INSERT ... ON DUPLICATE KEY UPDATE`` statements
    from ``table_name`` and ``columns``. Table and column names are
    interpolated into the SQL text verbatim, so they must come from trusted
    code, never from user input.
    """

    def __init__(
        self, db_url: str, table_name: str = "crypto_market_data", columns: list = None
    ):
        """Validate the column list and create the pooled engine.

        :param db_url: SQLAlchemy database URL.
        :param table_name: Target table for all write methods.
        :param columns: Column names to write; required, no duplicates.
        :raises ValueError: If ``columns`` is None or contains duplicates.
        """
        # Validate before building the engine so bad input fails fast.
        if columns is None:
            raise ValueError("columns不能为空")
        if len(columns) != len(set(columns)):
            raise ValueError("columns不能有重复")

        self.table_name = table_name
        # Staging-table name is fixed per instance (second resolution).
        self.temp_table_name = (
            f"temp_{table_name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
        )
        self.columns = columns
        self.db_url = db_url
        self.db_engine = create_engine(
            self.db_url,
            pool_size=25,  # connections kept in the pool
            max_overflow=10,  # extra connections allowed beyond pool_size
            pool_timeout=30,  # seconds to wait for a free connection
            pool_recycle=60,  # seconds before a pooled connection is recycled
        )

    def _on_duplicate_clause(self) -> str:
        """Return the ``col=VALUES(col), ...`` list covering every column."""
        return ", ".join(f"{col}=VALUES({col})" for col in self.columns)

    def create_insert_sql_text_by_temp_table(self, temp_table_name: str) -> str:
        """Build the upsert that copies rows from a staging table.

        Example shape of the result::

            INSERT INTO crypto_market_data (symbol, bar, ...)
            SELECT symbol, bar, ...
            FROM <temp_table_name>
            ON DUPLICATE KEY UPDATE symbol=VALUES(symbol), bar=VALUES(bar), ...

        :param temp_table_name: Name of the staging table to select from.
        :return: The SQL statement as a string.
        """
        column_list = ",".join(self.columns)
        return f"""
            INSERT INTO {self.table_name}
            ({column_list})
            SELECT {column_list}
            FROM {temp_table_name}
            ON DUPLICATE KEY UPDATE
            {self._on_duplicate_clause()}
        """

    def create_insert_sql_text(self) -> str:
        """Build the upsert that takes named bind parameters (no staging table).

        Example shape of the result::

            INSERT INTO crypto_market_data (symbol, bar, ...)
            VALUES (:symbol, :bar, ...)
            ON DUPLICATE KEY UPDATE symbol=VALUES(symbol), bar=VALUES(bar), ...

        :return: The SQL statement as a string.
        """
        column_list = ",".join(self.columns)
        placeholders = ",".join(f":{col}" for col in self.columns)
        return f"""
            INSERT INTO {self.table_name}
            ({column_list})
            VALUES ({placeholders})
            ON DUPLICATE KEY UPDATE
            {self._on_duplicate_clause()}
        """

    def _upsert_via_temp_table(self, df: pd.DataFrame, temp_table_name: str) -> None:
        """Write ``df`` to a staging table, upsert from it, then drop it.

        Exceptions from the write/upsert propagate to the caller. The staging
        table is dropped even when the upsert fails, so it is not left behind.
        """
        try:
            # begin() commits on success and rolls back on error; a bare
            # connect() would not commit under SQLAlchemy 2.x.
            with self.db_engine.begin() as conn:
                df.to_sql(
                    name=temp_table_name,
                    con=conn,
                    if_exists="replace",
                    index=False,
                    method="multi",
                )
                conn.execute(
                    text(self.create_insert_sql_text_by_temp_table(temp_table_name))
                )
        finally:
            # Use a fresh connection so cleanup works even if the one above
            # failed. A cleanup failure is only logged, so it cannot mask the
            # original error or turn a committed write into a reported failure.
            try:
                with self.db_engine.begin() as conn:
                    conn.execute(text(f"DROP TABLE IF EXISTS {temp_table_name}"))
            except Exception as drop_err:
                logger.warning(f"临时表 {temp_table_name} 删除失败: {drop_err}")

    def insert_data_to_mysql(self, df: pd.DataFrame):
        """Upsert ``df`` into ``self.table_name`` through one staging table.

        Original author's notes: fastest option, moderate memory use, intended
        for small-to-medium data (< 100k rows).

        Database errors are logged, not raised. A missing column in ``df``
        raises ``KeyError`` because the column selection is outside the
        guarded section.

        :param df: Rows to write; must contain every name in ``self.columns``.
        """
        if df is None or df.empty:
            logger.warning("DataFrame为空,无需写入数据库。")
            return
        df = df[self.columns]
        try:
            self._upsert_via_temp_table(df, self.temp_table_name)
            logger.info("数据已成功写入数据库。")
        except Exception as e:
            logger.error(f"数据库连接或写入失败: {e}")

    def insert_data_to_mysql_fast(self, df: pd.DataFrame):
        """Upsert ``df`` with one executemany call and no staging table.

        Original author's notes: very fast, intended for medium data volumes.

        Database errors are logged, not raised.

        :param df: Rows to write; must contain every name in ``self.columns``.
        """
        if df is None or df.empty:
            logger.warning("DataFrame为空,无需写入数据库。")
            return
        df = df[self.columns]
        try:
            # One dict per row, keyed by column name, matching the :name binds.
            records = df.to_dict(orient="records")
            with self.db_engine.begin() as conn:
                conn.execute(text(self.create_insert_sql_text()), records)
            logger.info("数据已成功写入数据库。")
        except Exception as e:
            logger.error(f"数据库连接或写入失败: {e}")

    def insert_data_to_mysql_chunk(self, df: pd.DataFrame, chunk_size: int = 1000):
        """Upsert ``df`` in slices of ``chunk_size`` rows, one staging table each.

        Original author's notes: medium speed, lowest memory use, intended for
        large data (> 100k rows).

        Each chunk is committed on its own, so a failure part-way through
        leaves earlier chunks written. Errors, including an invalid
        ``chunk_size``, are logged, not raised.

        :param df: Rows to write; must contain every name in ``self.columns``.
        :param chunk_size: Rows per chunk; must be at least 1.
        """
        if df is None or df.empty:
            logger.warning("DataFrame为空,无需写入数据库。")
            return
        df = df[self.columns]
        try:
            # A negative size would skip the loop and falsely report success.
            if chunk_size < 1:
                raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
            total_rows = len(df)
            for start in range(0, total_rows, chunk_size):
                chunk_df = df.iloc[start : start + chunk_size]
                # Suffix with the offset so each chunk has its own staging table.
                self._upsert_via_temp_table(
                    chunk_df, f"{self.temp_table_name}_{start}"
                )
                logger.info(
                    f"已处理 {min(start + chunk_size, total_rows)}/{total_rows} 条记录"
                )
            logger.info("数据已成功写入数据库。")
        except Exception as e:
            logger.error(f"数据库连接或写入失败: {e}")

    def insert_data_to_mysql_simple(self, df: pd.DataFrame):
        """Append ``df`` straight into ``self.table_name`` with ``to_sql``.

        This is a plain insert with no upsert, so duplicate keys cause a
        database error. That error is logged here, not raised.

        :param df: Rows to write; must contain every name in ``self.columns``.
        """
        if df is None or df.empty:
            logger.warning("DataFrame为空,无需写入数据库。")
            return
        df = df[self.columns]
        try:
            with self.db_engine.begin() as conn:
                df.to_sql(
                    name=self.table_name,
                    con=conn,
                    if_exists="append",
                    index=False,
                    method="multi",
                )
            logger.info("数据已成功写入数据库。")
        except Exception as e:
            logger.error(f"数据库连接或写入失败: {e}")

    def query_data(self, sql: str, condition_dict: dict, return_multi: bool = True):
        """Run a parameterised query and return the rows as dicts.

        :param sql: SQL text, optionally with ``:name`` bind parameters.
        :param condition_dict: Values for the bind parameters.
        :param return_multi: True returns all rows as a list of dicts; False
            returns only the first row as a single dict.
        :return: The converted row(s), or None when nothing matched or the
            query failed (failures are logged).
        """
        try:
            with self.db_engine.connect() as conn:
                result = conn.execute(text(sql), condition_dict)
                if return_multi:
                    rows = result.fetchall()
                    if not rows:
                        return None
                    return [transform_data_type(dict(row._mapping)) for row in rows]
                row = result.fetchone()
                if not row:
                    return None
                return transform_data_type(dict(row._mapping))
        except Exception as e:
            logger.error(f"查询数据出错: {e}")
            return None