crypto_quant/core/db/db_astock.py

143 lines
5.0 KiB
Python
Raw Normal View History

2025-09-25 04:28:43 +00:00
import pandas as pd
from sqlalchemy import create_engine, exc, text
import re
from core.utils import get_current_date_time
import core.logger as logging
from core.utils import transform_data_type
# `logging` is the project's core.logger module (imported as `logging`),
# not the standard-library logging package; reuse its shared logger.
logger = logging.logger
class DBAStockData:
    """Query helper for the A-share index/stock tables.

    Wraps a pooled SQLAlchemy engine built from ``db_url``. Every query is
    executed through :meth:`query_data`.
    """

    # Table names are interpolated into SQL text (identifiers cannot be bound
    # as parameters), so they must be a plain identifier, optionally
    # schema-qualified ("schema.table").
    _IDENTIFIER_RE = re.compile(
        r"[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?"
    )

    def __init__(
        self,
        db_url: str,
        *,
        pool_size: int = 25,
        max_overflow: int = 10,
        pool_timeout: int = 30,
        pool_recycle: int = 60,
    ):
        """Create the engine and its connection pool.

        :param db_url: SQLAlchemy database URL.
        :param pool_size: number of connections kept in the pool.
        :param max_overflow: extra connections allowed beyond ``pool_size``.
        :param pool_timeout: seconds to wait for a free connection.
        :param pool_recycle: seconds after which a pooled connection is
            recycled, so long-idle connections are not reused.
        """
        self.db_url = db_url
        self.db_engine = create_engine(
            self.db_url,
            pool_size=pool_size,
            max_overflow=max_overflow,
            pool_timeout=pool_timeout,
            pool_recycle=pool_recycle,
        )

    def query_data(
        self, sql: str, condition_dict: dict = None, return_multi: bool = True
    ):
        """Execute ``sql`` with bound parameters and return transformed rows.

        :param sql: SQL text; may contain ``:name`` placeholders.
        :param condition_dict: values for the placeholders; ``None`` means none.
        :param return_multi: if True return all rows, otherwise only the first.
        :return: a list of row dicts (``return_multi=True``) or a single row
            dict, each passed through ``transform_data_type``; ``None`` when
            the query yields no rows or raises (the error is logged).
        """
        try:
            with self.db_engine.connect() as conn:
                result = conn.execute(text(sql), condition_dict or {})
                if return_multi:
                    rows = result.fetchall()
                    if not rows:
                        return None
                    return [transform_data_type(dict(row._mapping)) for row in rows]
                row = result.fetchone()
                if not row:
                    return None
                return transform_data_type(dict(row._mapping))
        except Exception as e:
            # Deliberate best-effort contract: callers receive None on failure.
            logger.error(f"查询数据出错: {e}")
            return None

    def query_market_data_by_symbol_bar(
        self,
        symbol: str,
        fields: list = None,
        start: str = None,
        end: str = None,
        table_name: str = "index_daily_price_from_2021",
    ):
        """Fetch price rows for one security, joined with its metadata table.

        Rows come from ``table_name`` (alias ``a``), inner-joined on
        ``ts_code`` with ``all_index`` when ``table_name`` starts with
        "index" and with ``all_stock`` otherwise (alias ``b``). They are
        ordered by ``a.trade_date`` ascending.

        :param symbol: code matched against ``a.ts_code``.
        :param fields: columns to select; ``None`` or empty selects ``*``.
            NOTE(review): fields are inserted verbatim into the SQL, so they
            must come from trusted code. Also, with ``*`` both tables
            contribute a ``ts_code`` column; confirm the duplicate name is
            tolerated when rows are turned into dicts.
        :param start: inclusive lower bound on ``a.trade_date``. Dashes are
            stripped, so "2024-01-31" and "20240131" are equivalent.
        :param end: inclusive upper bound on ``a.trade_date``, same format.
            If both bounds are given in reverse order they are swapped.
        :param table_name: price table to read; ``None`` means the default.
            Must be a plain (optionally schema-qualified) identifier.
        :return: a list of row dicts as produced by ``query_data``, or
            ``None`` when there are no rows, the table name is rejected, or
            the query fails.
        """
        if not fields:
            fields = ["*"]
        if table_name is None:
            table_name = "index_daily_price_from_2021"
        if not self._IDENTIFIER_RE.fullmatch(table_name):
            logger.error(f"invalid table name: {table_name!r}")
            return None
        join_table = "all_index" if table_name.startswith("index") else "all_stock"

        # Strip dashes so ISO-style dates match the dash-free trade_date
        # values (presumably YYYYMMDD text; confirm against the schema).
        if start is not None:
            start = start.replace("-", "")
        if end is not None:
            end = end.replace("-", "")
        if start is not None and end is not None and start > end:
            start, end = end, start  # tolerate reversed bounds

        conditions = ["a.ts_code = :symbol"]
        condition_dict = {"symbol": symbol}
        if start is not None:
            conditions.append("a.trade_date >= :start")
            condition_dict["start"] = start
        if end is not None:
            conditions.append("a.trade_date <= :end")
            condition_dict["end"] = end

        fields_str = ", ".join(fields)
        where_clause = " AND ".join(conditions)
        sql = (
            f"SELECT {fields_str} FROM {table_name} a "
            f"INNER JOIN {join_table} b ON a.ts_code = b.ts_code "
            f"WHERE {where_clause} "
            "ORDER BY a.trade_date ASC"
        )
        return self.query_data(sql, condition_dict, return_multi=True)

    def _query_whole_table(self, table: str) -> pd.DataFrame:
        """Return every row of ``table`` ordered by ``ts_code``.

        ``table`` is interpolated into the SQL, so only pass trusted literals.
        The DataFrame is empty when the query returns no rows or fails.
        """
        sql = f"SELECT * FROM {table} a ORDER BY ts_code"
        data = self.query_data(sql, {}, return_multi=True)
        return pd.DataFrame(data)

    def query_index_data(self):
        """Return the whole ``all_index`` table as a DataFrame."""
        return self._query_whole_table("all_index")

    def query_stock_data(self):
        """Return the whole ``all_stock`` table as a DataFrame."""
        return self._query_whole_table("all_stock")