126 lines
4.5 KiB
Python
126 lines
4.5 KiB
Python
import pandas as pd
|
|
from sqlalchemy import create_engine, exc, text
|
|
import re
|
|
from core.utils import get_current_date_time
|
|
import core.logger as logging
|
|
from core.utils import transform_data_type
|
|
|
|
# Module-level logger. NOTE: ``logging`` here is the project's ``core.logger``
# module (imported under that alias), not the standard-library ``logging``.
logger = logging.logger
|
|
|
|
|
|
class DBAStockData:
    """Data-access wrapper around a pooled SQLAlchemy engine for price tables.

    All reads go through :meth:`query_data`, which executes parameterized
    ``text()`` SQL and converts each returned row with ``transform_data_type``.
    """

    # Plain or schema-qualified SQL identifier (e.g. ``index_daily_price_from_2021``
    # or ``mydb.all_stock``). Table names cannot be bound as SQL parameters, so
    # they are validated before being interpolated into the statement text.
    _TABLE_NAME_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?")

    def __init__(
        self,
        db_url: str,
        *,
        pool_size: int = 25,
        max_overflow: int = 10,
        pool_timeout: int = 30,
        pool_recycle: int = 60,
        **engine_kwargs,
    ):
        """Create the connection-pooled engine.

        :param db_url: SQLAlchemy database URL.
        :param pool_size: number of connections kept in the pool.
        :param max_overflow: extra connections allowed beyond ``pool_size``.
        :param pool_timeout: seconds to wait for a free pooled connection.
        :param pool_recycle: seconds after which a pooled connection is
            recycled, to avoid reusing long-idle connections.
        :param engine_kwargs: further keyword arguments forwarded unchanged
            to ``create_engine``.
        """
        self.db_url = db_url
        self.db_engine = create_engine(
            self.db_url,
            pool_size=pool_size,
            max_overflow=max_overflow,
            pool_timeout=pool_timeout,
            pool_recycle=pool_recycle,
            **engine_kwargs,
        )

    def query_data(self, sql: str, condition_dict: dict = None, return_multi: bool = True):
        """Execute a parameterized SQL query and return converted rows.

        :param sql: SQL text; bind parameters use ``:name`` placeholders.
        :param condition_dict: values for the bind parameters; ``None`` is
            treated as "no parameters".
        :param return_multi: if True, fetch all rows and return a list of
            dicts; if False, fetch only the first row and return one dict.
        :return: a list of dicts (or a single dict), each passed through
            ``transform_data_type``; ``None`` when no row matched or when the
            query failed (the error is logged, not raised).
        """
        params = condition_dict if condition_dict is not None else {}
        try:
            with self.db_engine.connect() as conn:
                cursor_result = conn.execute(text(sql), params)
                if return_multi:
                    rows = cursor_result.fetchall()
                    if not rows:
                        return None
                    return [transform_data_type(dict(row._mapping)) for row in rows]

                row = cursor_result.fetchone()
                if row is None:
                    return None
                return transform_data_type(dict(row._mapping))
        except Exception as e:
            # Deliberate best-effort boundary: callers treat None as "no data",
            # so failures are logged rather than propagated.
            logger.error(f"查询数据出错: {e}")
            return None

    def query_market_data_by_symbol_bar(
        self,
        symbol: str,
        bar: str,
        fields: list = None,
        start: str = None,
        end: str = None,
        table_name: str = "index_daily_price_from_2021",
    ):
        """Query price rows for one symbol, joined with its reference table.

        Tables whose name starts with ``"index"`` are joined with ``all_index``;
        every other table is joined with ``all_stock``. Rows are ordered by
        ``a.trade_date`` ascending.

        :param symbol: trading symbol, matched against ``a.ts_code``.
        :param bar: K-line period. NOTE(review): not used in the query;
            presumably the period is implied by ``table_name`` — confirm.
        :param fields: columns to select. They are interpolated verbatim into
            the SQL, so they must come from trusted code. ``None`` or an empty
            list selects ``*``.
        :param start: inclusive start date; dashes are stripped, so
            ``"2021-01-01"`` and ``"20210101"`` are equivalent.
        :param end: inclusive end date, normalized like ``start``. If both
            bounds are given in reverse order they are swapped.
        :param table_name: price table to query; ``None`` falls back to the
            default. Must be a plain or schema-qualified identifier.
        :return: list of row dicts, or ``None`` if nothing matched, the query
            failed, or ``table_name`` is not a valid identifier.
        """
        if not fields:
            fields = ["*"]
        fields_str = ", ".join(fields)

        if table_name is None:
            table_name = "index_daily_price_from_2021"
        if not self._TABLE_NAME_RE.fullmatch(table_name):
            # Refuse to splice an arbitrary string into the SQL text.
            logger.error(f"非法的表名: {table_name!r}")
            return None
        join_table = "all_index" if table_name.startswith("index") else "all_stock"

        # Strip dashes so "2021-01-01" becomes "20210101"
        # (presumably the trade_date storage format — confirm).
        if start is not None:
            start = start.replace("-", "")
        if end is not None:
            end = end.replace("-", "")

        conditions = ["a.ts_code = :symbol"]
        condition_dict = {"symbol": symbol}
        if start is not None and end is not None:
            # Compact date strings compare correctly as plain strings.
            if start > end:
                start, end = end, start
            conditions.append("a.trade_date BETWEEN :start AND :end")
            condition_dict["start"] = start
            condition_dict["end"] = end
        elif start is not None:
            conditions.append("a.trade_date >= :start")
            condition_dict["start"] = start
        elif end is not None:
            conditions.append("a.trade_date <= :end")
            condition_dict["end"] = end

        where_clause = " AND ".join(conditions)
        sql = f"""
            SELECT {fields_str} FROM {table_name} a
            INNER JOIN {join_table} b ON a.ts_code = b.ts_code
            WHERE {where_clause}
            ORDER BY a.trade_date ASC
        """
        return self.query_data(sql, condition_dict, return_multi=True)
|