Add support for finding stocks with similar price patterns
This commit is contained in:
parent
18168010ce
commit
8ae5dab544
|
|
@ -54,7 +54,6 @@ class DBAStockData:
|
|||
def query_market_data_by_symbol_bar(
|
||||
self,
|
||||
symbol: str,
|
||||
bar: str,
|
||||
fields: list = None,
|
||||
start: str = None,
|
||||
end: str = None,
|
||||
|
|
@ -123,3 +122,21 @@ class DBAStockData:
|
|||
"""
|
||||
condition_dict = {"symbol": symbol, "end": end}
|
||||
return self.query_data(sql, condition_dict, return_multi=True)
|
||||
|
||||
def query_index_data(self):
    """Load every row of the ``all_index`` table, ordered by ``ts_code``.

    The statement takes no bind parameters, so an empty condition dict is
    passed to ``self.query_data``. The rows it returns are wrapped in a
    ``pandas.DataFrame``.
    """
    statement = """
    SELECT * FROM all_index a
    order by ts_code
    """
    rows = self.query_data(statement, {}, return_multi=True)
    return pd.DataFrame(rows)
|
||||
|
||||
def query_stock_data(self):
    """Load every row of the ``all_stock`` table, ordered by ``ts_code``.

    The statement takes no bind parameters, so an empty condition dict is
    passed to ``self.query_data``. The rows it returns are wrapped in a
    ``pandas.DataFrame``.
    """
    statement = """
    SELECT * FROM all_stock a
    order by ts_code
    """
    rows = self.query_data(statement, {}, return_multi=True)
    return pd.DataFrame(rows)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,381 @@
|
|||
import core.logger as logging
|
||||
import os
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import mplfinance as mpf
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from core.utils import get_current_date_time
|
||||
import re
|
||||
import json
|
||||
import math
|
||||
from openpyxl import Workbook
|
||||
from openpyxl.drawing.image import Image
|
||||
import openpyxl
|
||||
from openpyxl.styles import Font
|
||||
from PIL import Image as PILImage
|
||||
from config import (
|
||||
A_MYSQL_CONFIG,
|
||||
)
|
||||
from core.db.db_astock import DBAStockData
|
||||
|
||||
# Apply one consistent seaborn look first. set_theme() resets
# rcParams["font.family"], so it has to run BEFORE the font override below;
# in the opposite order the CJK font setting is silently discarded.
sns.set_theme(style="whitegrid")
# Use a CJK-capable font so Chinese labels render, and keep the minus sign
# as a plain ASCII hyphen that the font is able to draw.
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False

logger = logging.logger
|
||||
|
||||
|
||||
class SimilarPatternStocks:
    """Find stocks whose recent price/volume shape resembles a target stock.

    The class reads bar data through ``DBAStockData``, ranks every candidate
    by a weighted Euclidean distance on min-max normalised close and volume
    series, and writes the ranking plus candlestick charts to an Excel
    workbook under ``self.output_dir``.
    """

    # Bar code -> price table; any other bar falls back to the daily table.
    _PRICE_TABLES = {
        "1W": "stock_weekly_price_from_2020",
        "1M": "stock_monthly_price_from_2015",
    }
    _DEFAULT_PRICE_TABLE = "stock_daily_price_from_2021"
    # Only stocks listed on these boards are used as candidates.
    _BOARDS = ["主板", "创业板", "科创板"]
    # Worksheet layout: a default 15-point row is roughly 20 pixels tall.
    _PIXELS_PER_ROW = 15 * 1.333
    _MIN_CHART_ROWS = 10
    _CHART_PADDING_ROWS = 5

    def __init__(self):
        """Create the database accessor and make sure the output folder exists.

        Raises:
            ValueError: If ``A_MYSQL_CONFIG`` carries no password.
        """
        from urllib.parse import quote_plus

        mysql_user = A_MYSQL_CONFIG.get("user", "root")
        mysql_password = A_MYSQL_CONFIG.get("password", "")
        if not mysql_password:
            raise ValueError("MySQL password is not set")
        mysql_host = A_MYSQL_CONFIG.get("host", "localhost")
        mysql_port = A_MYSQL_CONFIG.get("port", 3306)
        mysql_database = A_MYSQL_CONFIG.get("database", "astock")
        # Percent-encode the credentials: a raw "@", ":" or "/" inside the
        # password would otherwise be parsed as URL structure.
        self.db_url = (
            f"mysql+pymysql://{quote_plus(str(mysql_user))}:"
            f"{quote_plus(str(mysql_password))}"
            f"@{mysql_host}:{mysql_port}/{mysql_database}"
        )
        self.db_astock = DBAStockData(self.db_url)
        self.output_dir = r"./output/similar_pattern_stocks/"
        os.makedirs(self.output_dir, exist_ok=True)

    @staticmethod
    def _file_safe_name(name: str) -> str:
        """Drop ``*`` (as in ``*ST`` names) so the name can be used in a path."""
        return name.replace("*", "")

    @staticmethod
    def _min_max_normalize(series: pd.Series) -> np.ndarray:
        """Scale *series* to [0, 1] by position; all-NaN when it has no range."""
        values = series.to_numpy(dtype=float)
        if values.size == 0:
            return values
        low = values.min()
        span = values.max() - low
        if np.isnan(span) or span == 0:
            # A flat (or NaN-containing) series has no shape to compare.
            return np.full(values.shape, np.nan)
        return (values - low) / span

    def get_stock_list(self):
        """Return the stock master table restricted to the supported boards.

        An empty query result is returned unchanged so callers can test its
        length instead of hitting a ``KeyError`` on the missing column.
        """
        stock_data = self.db_astock.query_stock_data()
        if stock_data.empty:
            return stock_data
        return stock_data[stock_data["market"].isin(self._BOARDS)]

    def get_stock_market_data(
        self,
        symbol: str,
        bar: str,
        start_date: str,
        end_date: str,
        compare_record_amount: int = 100,
    ):
        """Fetch the latest ``compare_record_amount`` bars of one stock.

        Args:
            symbol: Stock code passed to the database query.
            bar: ``"1W"`` selects the weekly table, ``"1M"`` the monthly
                table; anything else selects the daily table.
            start_date: Lower date bound forwarded to the query (may be None).
            end_date: Upper date bound forwarded to the query (may be None).
            compare_record_amount: Number of trailing rows to keep.

        Returns:
            A DataFrame sorted ascending by ``date_time`` with a fresh
            0-based index, or an empty DataFrame when the query has no rows.
        """
        # NOTE(review): ``bar`` is interpolated into the SQL select list;
        # only pass trusted, internally defined bar codes.
        fields = [
            "a.ts_code as symbol",
            "b.name as symbol_name",
            f"'{bar}' as bar",
            "trade_date as date_time",
            "open",
            "high",
            "low",
            "close",
            "vol as volume",
            "MA5 as ma5",
            "MA10 as ma10",
            "MA20 as ma20",
            "MA30 as ma30",
            "均线交叉 as ma_cross",
            "DIF as dif",
            "DEA as dea",
            "MACD as macd",
        ]
        table_name = self._PRICE_TABLES.get(bar, self._DEFAULT_PRICE_TABLE)
        rows = self.db_astock.query_market_data_by_symbol_bar(
            symbol=symbol,
            fields=fields,
            start=start_date,
            end=end_date,
            table_name=table_name,
        )
        data = pd.DataFrame(rows)
        if data.empty:
            # Nothing to sort: an empty frame has no "date_time" column.
            return data
        data = data.sort_values(by=["date_time"]).tail(compare_record_amount)
        return data.reset_index(drop=True)

    def get_stock_market_data_similar_pattern(
        self,
        target_symbol: str,
        bar: str,
        start_date: str = None,
        end_date: str = None,
        compare_record_amount: int = 100,
        top_n: int = 10,
    ):
        """Rank all other stocks by similarity to the target and chart the best.

        Steps:
            1. Load the target's last ``compare_record_amount`` bars up to
               ``end_date``.
            2. Load the same number of latest bars (up to today) for every
               other stock on the supported boards.
            3. Compute the distance of each candidate to the target and skip
               candidates with too few bars or an undefined (NaN) distance.
            4. Write the ranking to an Excel sheet, draw candlestick charts
               for the target and the ``top_n`` closest stocks, and embed
               the charts in the workbook.

        Returns:
            The list of chart-info dicts (target first), or ``[]`` when the
            target or the stock list has no data.
        """
        logger.info(f"获取目标股票{target_symbol}的{bar}数据, 截止日期{end_date}")
        target_stock_data = self.get_stock_market_data(
            target_symbol, bar, start_date, end_date, compare_record_amount
        )
        if len(target_stock_data) == 0:
            logger.error(f"目标股票{target_symbol}的{bar}数据为空")
            return []
        # The target may have fewer bars than requested; compare on what exists.
        compare_record_amount = len(target_stock_data)

        stock_list_data = self.get_stock_list()
        if len(stock_list_data) == 0:
            logger.error("所有股票数据为空")
            return []
        stock_list_data = stock_list_data[stock_list_data["ts_code"] != target_symbol]

        now_date = datetime.now().strftime("%Y%m%d")
        similarity_data = []
        # Count positions explicitly: the frame's index labels are not
        # consecutive after filtering.
        for position, (_, row) in enumerate(stock_list_data.iterrows(), start=1):
            logger.info(
                f"获取第{position}只股票{row['ts_code']} {row['name']} 的{bar}数据, 截止日期{now_date}"
            )
            stock_data = self.get_stock_market_data(
                symbol=row["ts_code"],
                bar=bar,
                start_date=None,
                end_date=now_date,
                compare_record_amount=compare_record_amount,
            )
            if len(stock_data) == 0:
                logger.error(f"股票{row['ts_code']}的{bar}数据为空")
                continue
            if len(stock_data) != compare_record_amount:
                logger.error(f"股票{row['ts_code']}的数据不足{compare_record_amount}条")
                continue
            similarity = self.calculate_similarity(target_stock_data, stock_data)
            if pd.isna(similarity):
                # Undefined distance (e.g. flat price series): not rankable.
                continue
            similarity_data.append(
                {
                    "symbol": row["ts_code"],
                    "symbol_name": row["name"],
                    "similarity_distance": similarity,
                    "stock_data": stock_data,
                }
            )
        # Smaller distance means more similar.
        similarity_data.sort(key=lambda item: item["similarity_distance"])

        ranking_columns = ["symbol", "symbol_name", "similarity_distance"]
        # Explicit columns keep the sheet header even when nothing qualified.
        pure_data = pd.DataFrame(
            [{key: item[key] for key in ranking_columns} for item in similarity_data],
            columns=ranking_columns,
        )

        target_stock_symbol = str(target_stock_data["symbol"].iloc[0]).split(".")[0]
        target_stock_name = str(target_stock_data["symbol_name"].iloc[0])
        # Strip "*" for paths, consistent with the chart file names.
        safe_target_name = self._file_safe_name(target_stock_name)
        folder_date = end_date if end_date else now_date
        target_stock_folder = os.path.join(
            self.output_dir,
            f"{target_stock_symbol}_{safe_target_name}_{bar}_{folder_date}",
        )
        os.makedirs(target_stock_folder, exist_ok=True)
        excel_file_path = os.path.join(
            target_stock_folder,
            f"{target_stock_symbol}_{safe_target_name}_{bar}_similar_stocks.xlsx",
        )
        with pd.ExcelWriter(excel_file_path) as writer:
            pure_data.to_excel(writer, sheet_name="股票形态相似度", index=False)

        similar_stocks_chart_folder = os.path.join(
            target_stock_folder, "similar_stocks_chart"
        )
        os.makedirs(similar_stocks_chart_folder, exist_ok=True)

        # The target's own chart comes first, labelled with distance "0".
        chart_list = [
            self.draw_similar_stocks_chart(
                target_stock_symbol,
                target_stock_name,
                bar,
                "0",
                target_stock_data,
                similar_stocks_chart_folder,
            )
        ]
        for item in similarity_data[: max(top_n, 0)]:
            chart_list.append(
                self.draw_similar_stocks_chart(
                    item["symbol"],
                    item["symbol_name"],
                    bar,
                    item["similarity_distance"],
                    item["stock_data"],
                    similar_stocks_chart_folder,
                )
            )
        self.output_chart_to_excel(excel_file_path, chart_list)
        return chart_list

    def output_chart_to_excel(self, excel_file_path: str, charts_list: list):
        """Append a sheet with every chart image to an existing workbook.

        Args:
            excel_file_path: Path of an existing ``.xlsx`` file; it is
                loaded, extended with a new sheet and saved in place.
            charts_list: Dicts with the keys ``chart_path``, ``symbol``,
                ``stock_name``, ``bar`` and ``similarity_distance``.
        """
        logger.info(f"将图表输出到{excel_file_path}")

        workbook = openpyxl.load_workbook(excel_file_path)
        sheet = workbook.create_sheet(title="股票K线图")
        row_offset = 1
        for chart_info in charts_list:
            chart_path = chart_info["chart_path"]
            chart_name = (
                f"{chart_info['symbol']} {chart_info['stock_name']} "
                f"{chart_info['bar']} 距离:{chart_info['similarity_distance']}"
            )
            # Only the pixel height is needed, to reserve enough rows.
            with PILImage.open(chart_path) as picture:
                _, height_px = picture.size
            chart_rows = max(
                self._MIN_CHART_ROWS, int(height_px / self._PIXELS_PER_ROW)
            )

            title_cell = sheet[f"A{row_offset}"]
            title_cell.value = chart_name
            title_cell.font = openpyxl.styles.Font(bold=True, size=12)
            row_offset += 2  # title row plus one spacer row

            sheet.add_image(Image(chart_path), f"A{row_offset}")
            row_offset += chart_rows + self._CHART_PADDING_ROWS

        workbook.save(excel_file_path)
        logger.info(f"Chart saved as {excel_file_path}")

    def draw_similar_stocks_chart(
        self,
        symbol: str,
        stock_name: str,
        bar: str,
        similarity_distance: float,
        stock_data: pd.DataFrame,
        similar_stocks_chart_folder: str,
    ):
        """Draw a candlestick chart for one stock and save it as a PNG.

        Args:
            symbol: Stock code used in the title and file name.
            stock_name: Stock name; ``*`` is removed for the file name only.
            bar: Bar code used in the title and file name.
            similarity_distance: Value shown in the chart title.
            stock_data: Bars with a ``trade_date`` or ``date_time`` column
                and ``open``/``high``/``low``/``close`` columns.
            similar_stocks_chart_folder: Existing folder receiving the PNG.

        Returns:
            A dict with ``chart_path``, ``symbol``, ``stock_name``, ``bar``
            and ``similarity_distance``.
        """
        pure_stock_name = self._file_safe_name(stock_name)
        chart_path = os.path.join(
            similar_stocks_chart_folder, f"{symbol}_{pure_stock_name}_{bar}_chart.png"
        )

        # mplfinance expects a DatetimeIndex and capitalised OHLC column names.
        df = stock_data.copy()
        if "trade_date" in df.columns:
            df = df.rename(columns={"trade_date": "Date"})
        elif "date_time" in df.columns:
            df = df.rename(columns={"date_time": "Date"})
        df["Date"] = pd.to_datetime(df["Date"])
        df = df.set_index("Date")
        df = df.rename(
            columns={"open": "Open", "high": "High", "low": "Low", "close": "Close"}
        )

        # A-share convention: red candles for rising bars, green for falling.
        mc = mpf.make_marketcolors(
            up="#e41a1c",
            down="#4daf4a",
            edge="inherit",
            wick="inherit",
            volume="inherit",
        )
        mpf_style = mpf.make_mpf_style(
            base_mpf_style="yahoo",
            marketcolors=mc,
            facecolor="#FAFAFA",
            edgecolor="#EAEAEA",
            gridcolor="#E6E6E6",
            gridstyle="--",
            rc={"font.family": "SimHei", "axes.unicode_minus": False},
        )
        mpf.plot(
            df,
            type="candle",
            style=mpf_style,
            title=f"{symbol} {stock_name} {bar} 距离:{similarity_distance}",
            ylabel="价格",
            figsize=(12, 6),
            savefig={"fname": chart_path, "dpi": 150, "bbox_inches": "tight"},
        )
        return {
            "chart_path": chart_path,
            "symbol": symbol,
            "stock_name": stock_name,
            "bar": bar,
            "similarity_distance": similarity_distance,
        }

    def calculate_similarity(
        self,
        target_stock_data: pd.DataFrame,
        stock_data: pd.DataFrame,
        close_weight=0.8,
        volume_weight=0.2,
    ):
        """Return the weighted shape distance between two bar series.

        ``close`` and ``volume`` are each min-max normalised per stock and
        compared row by row (by position, ignoring index labels) with the
        Euclidean norm; the two distances are then combined using the given
        weights. Smaller means more similar.

        Returns:
            The distance as a float. ``nan`` when the frames differ in
            length or when any series is flat or contains NaN, because the
            normalised shape is undefined in those cases.
        """
        if len(target_stock_data) != len(stock_data):
            return float("nan")

        close_distance = np.linalg.norm(
            self._min_max_normalize(target_stock_data["close"])
            - self._min_max_normalize(stock_data["close"])
        )
        volume_distance = np.linalg.norm(
            self._min_max_normalize(target_stock_data["volume"])
            - self._min_max_normalize(stock_data["volume"])
        )
        return float(close_weight * close_distance + volume_weight * volume_distance)
|
||||
|
|
@ -864,7 +864,6 @@ class MaBreakStatistics:
|
|||
logger.info(f"获取{symbol}数据:{start_date_str}至{current_end_date_str}")
|
||||
current_data = self.db_market_data.query_market_data_by_symbol_bar(
|
||||
symbol,
|
||||
bar,
|
||||
fields,
|
||||
start=start_date_str,
|
||||
end=current_end_date_str,
|
||||
|
|
@ -944,7 +943,7 @@ class MaBreakStatistics:
|
|||
"MACD as macd",
|
||||
]
|
||||
data = self.db_market_data.query_market_data_by_symbol_bar(
|
||||
symbol, bar, fields, start=last_date, end=end_date, table_name=table_name
|
||||
symbol, fields, start=last_date, end=end_date, table_name=table_name
|
||||
)
|
||||
if data is not None and len(data) > 0:
|
||||
data = pd.DataFrame(data)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
import core.logger as logging
|
||||
from core.statistics.similar_pattern_stocks import SimilarPatternStocks
|
||||
import os
|
||||
|
||||
logger = logging.logger
|
||||
|
||||
def main():
    """Run the similar-pattern search for a fixed batch of target stocks."""
    finder = SimilarPatternStocks()

    # One tuple per run: (symbol, bar, start_date, end_date, compare_record_amount)
    batch = (
        ("600111.SH", "1W", None, "20250711", 100),
        ("600111.SH", "1M", None, "20250630", 100),
        ("601398.SH", "1M", None, "20230430", 100),
    )
    for symbol, bar, start_date, end_date, record_amount in batch:
        finder.get_stock_market_data_similar_pattern(
            target_symbol=symbol,
            bar=bar,
            start_date=start_date,
            end_date=end_date,
            compare_record_amount=record_amount,
        )


# Run the batch only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue