crypto_quant/core/db/db_manager.py

249 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import datetime
import re
import uuid

import pandas as pd
from sqlalchemy import create_engine, exc, text

import core.logger as logging
from core.utils import transform_data_type
# Logger object exported by the project's core.logger module.
from core.logger import logger
class DBData:
    """Write pandas DataFrames into a MySQL table and run parameterized queries.

    Every write method restricts the incoming DataFrame to ``columns``.
    All of them except :meth:`insert_data_to_mysql_simple`, which plainly
    appends, merge rows into ``table_name`` with
    ``INSERT ... ON DUPLICATE KEY UPDATE``. Database errors in the write and
    query methods are logged and swallowed rather than raised.
    """

    def __init__(
        self, db_url: str, table_name: str = "crypto_market_data", columns: list = None
    ):
        """Validate the column list and create the pooled SQLAlchemy engine.

        :param db_url: SQLAlchemy database URL.
        :param table_name: Target table for every write method.
        :param columns: Ordered column names to write. Required, non-empty and
            free of duplicates. Any iterable (list, tuple, ``pd.Index``) is
            accepted and stored as a list.
        :raises ValueError: if ``columns`` is missing, empty or has duplicates.
        """
        if columns is None:
            raise ValueError("columns不能为空")
        # Always keep a list: pandas treats ``df[some_tuple]`` as a single key.
        columns = list(columns)
        if not columns:
            # An empty column list would render syntactically invalid SQL.
            raise ValueError("columns不能为空")
        if len(columns) != len(set(columns)):
            raise ValueError("columns不能有重复")
        self.table_name = table_name
        self.columns = columns
        # Prefix for staging tables. Each write appends a random suffix
        # (_new_temp_table_name) so concurrent writes never share a table.
        self.temp_table_name = (
            f"temp_{table_name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
        )
        self.db_url = db_url
        self.db_engine = create_engine(
            self.db_url,
            pool_size=25,  # connections kept in the pool
            max_overflow=10,  # extra connections allowed beyond pool_size
            pool_timeout=30,  # seconds to wait for a free connection
            pool_recycle=60,  # seconds before a pooled connection is recycled
        )

    # ------------------------------------------------------------------
    # SQL construction
    # ------------------------------------------------------------------

    @staticmethod
    def _quote_identifier(name) -> str:
        """Backtick-quote a MySQL identifier, quoting each part of ``schema.table``.

        Embedded backticks are doubled (MySQL's escape inside quoted
        identifiers). Quoting lets reserved words such as ``key`` or
        ``signal`` be used as column names.
        """
        return ".".join(
            "`" + part.replace("`", "``") + "`" for part in str(name).split(".")
        )

    @staticmethod
    def _update_assignments(quoted_columns) -> str:
        """Render ``col=VALUES(col)`` pairs for an ON DUPLICATE KEY UPDATE clause."""
        return ", ".join(f"{col}=VALUES({col})" for col in quoted_columns)

    def create_insert_sql_text_by_temp_table(self, temp_table_name: str):
        """Build the upsert that copies every row of *temp_table_name* into the target.

        Example (columns ``symbol, open``; whitespace condensed)::

            INSERT INTO `crypto_market_data` (`symbol`,`open`)
            SELECT `symbol`,`open` FROM `temp_...`
            ON DUPLICATE KEY UPDATE `symbol`=VALUES(`symbol`), `open`=VALUES(`open`)

        Rows whose unique/primary key already exists get all listed columns
        overwritten with the staged values.
        """
        quoted = [self._quote_identifier(col) for col in self.columns]
        column_list = ",".join(quoted)
        return (
            f"INSERT INTO {self._quote_identifier(self.table_name)}\n"
            f"({column_list})\n"
            f"SELECT {column_list}\n"
            f"FROM {self._quote_identifier(temp_table_name)}\n"
            "ON DUPLICATE KEY UPDATE\n"
            f"{self._update_assignments(quoted)}"
        )

    def create_insert_sql_text(self):
        """Build a row-parameterized upsert using ``:column`` bind parameters.

        Example (columns ``symbol, open``; whitespace condensed)::

            INSERT INTO `crypto_market_data` (`symbol`,`open`)
            VALUES (:symbol,:open)
            ON DUPLICATE KEY UPDATE `symbol`=VALUES(`symbol`), `open`=VALUES(`open`)
        """
        quoted = [self._quote_identifier(col) for col in self.columns]
        placeholders = ",".join(f":{col}" for col in self.columns)
        return (
            f"INSERT INTO {self._quote_identifier(self.table_name)}\n"
            f"({','.join(quoted)})\n"
            f"VALUES ({placeholders})\n"
            "ON DUPLICATE KEY UPDATE\n"
            f"{self._update_assignments(quoted)}"
        )

    # ------------------------------------------------------------------
    # Internal write helpers
    # ------------------------------------------------------------------

    def _new_temp_table_name(self) -> str:
        """Return a fresh staging-table name derived from ``self.temp_table_name``.

        Non-word characters become ``_``, and the prefix is trimmed so the
        result never exceeds MySQL's 64-character identifier limit.
        """
        prefix = re.sub(r"\W", "_", self.temp_table_name)[:55]
        return f"{prefix}_{uuid.uuid4().hex[:8]}"

    def _prepare_frame(self, df: pd.DataFrame):
        """Return *df* restricted to ``self.columns``, or ``None`` if there is nothing to write.

        Logs a warning for a ``None`` or empty frame. A frame missing any of
        ``self.columns`` raises ``KeyError`` from pandas; it is not caught here.
        """
        if df is None or df.empty:
            logger.warning("DataFrame为空无需写入数据库。")
            return None
        return df[self.columns]

    def _drop_table_quietly(self, table_name: str):
        """Best-effort ``DROP TABLE IF EXISTS``; failures are logged, never raised."""
        try:
            with self.db_engine.begin() as conn:
                conn.execute(
                    text(f"DROP TABLE IF EXISTS {self._quote_identifier(table_name)}")
                )
        except Exception as e:
            logger.error(f"Failed to drop temp table {table_name}: {e}")

    def _upsert_via_temp_table(self, df: pd.DataFrame):
        """Stage *df* in a fresh table, upsert it into the target, then drop the stage.

        The load and the upsert run inside ``engine.begin()``, which commits on
        success and rolls back on error. A plain ``connect()`` block is rolled
        back on close under SQLAlchemy 2.x. The staging table is dropped even
        when the load or upsert fails; that error is re-raised to the caller.
        """
        temp_table_name = self._new_temp_table_name()
        try:
            with self.db_engine.begin() as conn:
                df.to_sql(
                    name=temp_table_name,
                    con=conn,
                    if_exists="replace",
                    index=False,
                    method="multi",
                )
                conn.execute(
                    text(self.create_insert_sql_text_by_temp_table(temp_table_name))
                )
        finally:
            self._drop_table_quietly(temp_table_name)

    # ------------------------------------------------------------------
    # Public write methods
    # ------------------------------------------------------------------

    def insert_data_to_mysql(self, df: pd.DataFrame):
        """Upsert *df* into the target table through a single staging table.

        Author's rating: speed 5/5 (fastest), memory 4/5; intended for
        small/medium batches (< ~100k rows).

        :param df: Rows to write; only ``self.columns`` are used.
        Database errors are logged, not raised.
        """
        df = self._prepare_frame(df)
        if df is None:
            return
        try:
            self._upsert_via_temp_table(df)
            logger.info("数据已成功写入数据库。")
        except Exception as e:
            logger.error(f"数据库连接或写入失败: {e}")

    def insert_data_to_mysql_fast(self, df: pd.DataFrame):
        """Upsert *df* with one executemany of the parameterized upsert statement.

        Author's rating: speed 4/5, memory 5/5 (low); intended for medium
        batches. Missing values (NaN/NaT/None) are sent as NULL, matching what
        ``DataFrame.to_sql`` writes in the staging-table methods.
        Database errors are logged, not raised.
        """
        df = self._prepare_frame(df)
        if df is None:
            return
        try:
            # to_dict("records") converts column-wise (far cheaper than
            # iterrows, which builds a Series per row); pd.isna maps missing
            # values to None so the driver sends NULL instead of NaN.
            records = [
                {col: (None if pd.isna(val) else val) for col, val in row.items()}
                for row in df.to_dict(orient="records")
            ]
            with self.db_engine.begin() as conn:
                conn.execute(text(self.create_insert_sql_text()), records)
            logger.info("数据已成功写入数据库。")
        except Exception as e:
            logger.error(f"数据库连接或写入失败: {e}")

    def insert_data_to_mysql_chunk(self, df: pd.DataFrame, chunk_size: int = 1000):
        """Upsert *df* in slices of *chunk_size* rows, one staging table per slice.

        Author's rating: speed 3/5, memory 5/5 (lowest); intended for large
        batches (> ~100k rows). Each slice commits independently, so a failure
        leaves earlier slices written. A ``chunk_size`` below 1, or a failure,
        is logged and not raised.
        """
        df = self._prepare_frame(df)
        if df is None:
            return
        total_rows = len(df)
        try:
            if chunk_size < 1:
                raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
            for start in range(0, total_rows, chunk_size):
                self._upsert_via_temp_table(df.iloc[start : start + chunk_size])
                logger.info(
                    f"已处理 {min(start + chunk_size, total_rows)}/{total_rows} 条记录"
                )
            logger.info("数据已成功写入数据库。")
        except Exception as e:
            logger.error(f"数据库连接或写入失败: {e}")

    def insert_data_to_mysql_simple(self, df: pd.DataFrame):
        """Append *df* to the target table with ``DataFrame.to_sql`` (no upsert).

        Author's rating: speed 5/5, memory 4/5. A row whose key already exists
        makes the insert fail. The error is logged, and the write is rolled
        back on a transactional engine such as InnoDB.
        """
        df = self._prepare_frame(df)
        if df is None:
            return
        try:
            with self.db_engine.begin() as conn:
                df.to_sql(
                    name=self.table_name,
                    con=conn,
                    if_exists="append",
                    index=False,
                    method="multi",
                )
            logger.info("数据已成功写入数据库。")
        except Exception as e:
            logger.error(f"数据库连接或写入失败: {e}")

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def query_data(self, sql: str, condition_dict: dict = None, return_multi: bool = True):
        """Execute *sql* with named bind parameters and return converted row dicts.

        :param sql: SQL text; bind parameters use ``:name`` syntax.
        :param condition_dict: Values for the bind parameters; ``None`` binds none.
        :param return_multi: ``True`` returns a list of all row dicts; ``False``
            returns only the first row dict.
        :return: Row dict(s), each passed through ``transform_data_type``.
            Returns ``None`` when no row matched or when the query failed; a
            failure is logged.
        """
        try:
            with self.db_engine.connect() as conn:
                result = conn.execute(text(sql), condition_dict or {})
                if return_multi:
                    rows = [
                        transform_data_type(dict(row._mapping))
                        for row in result.fetchall()
                    ]
                    return rows or None
                row = result.fetchone()
                if row is None:
                    return None
                return transform_data_type(dict(row._mapping))
        except Exception as e:
            logger.error(f"查询数据出错: {e}")
            return None