382 lines
14 KiB
Python
382 lines
14 KiB
Python
import json
import math
import os
import re
from datetime import datetime, timedelta, timezone
from urllib.parse import quote

import matplotlib.pyplot as plt
import mplfinance as mpf
import numpy as np
import openpyxl
import pandas as pd
import seaborn as sns
from openpyxl import Workbook
from openpyxl.drawing.image import Image
from openpyxl.styles import Font
from PIL import Image as PILImage

import core.logger as logging
from config import (
    A_MYSQL_CONFIG,
)
from core.db.db_astock import DBAStockData
from core.utils import get_current_date_time

# Apply a uniform Seaborn look FIRST: sns.set_theme() resets "font.family"
# (to "sans-serif" by default), so running it after the font settings below
# would silently discard the Chinese font.
sns.set_theme(style="whitegrid")

# Use SimHei so Chinese titles/labels render in charts, and draw negative numbers
# with an ASCII "-" (CJK fonts such as SimHei often lack the U+2212 minus glyph).
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False

logger = logging.logger

class SimilarPatternStocks:
    """Find A-share stocks whose recent price/volume pattern resembles a target stock.

    Market data is read through ``DBAStockData``; results (a similarity ranking
    sheet plus candlestick charts embedded in the same workbook) are written under
    ``self.output_dir``.
    """

    # Price table per bar size; any other bar value falls back to the daily table.
    _TABLE_BY_BAR = {
        "1W": "stock_weekly_price_from_2020",
        "1M": "stock_monthly_price_from_2015",
    }
    _DEFAULT_TABLE = "stock_daily_price_from_2021"

    def __init__(self):
        """Build the MySQL URL from ``A_MYSQL_CONFIG`` and create the output folder.

        Raises:
            ValueError: if no MySQL password is configured.
        """
        mysql_user = A_MYSQL_CONFIG.get("user", "root")
        mysql_password = A_MYSQL_CONFIG.get("password", "")
        if not mysql_password:
            raise ValueError("MySQL password is not set")
        mysql_host = A_MYSQL_CONFIG.get("host", "localhost")
        mysql_port = A_MYSQL_CONFIG.get("port", 3306)
        mysql_database = A_MYSQL_CONFIG.get("database", "astock")
        # Percent-encode the credentials: "@", ":", "/", "#" etc. in a password
        # would otherwise corrupt the connection URL.
        user_part = quote(str(mysql_user), safe="")
        password_part = quote(str(mysql_password), safe="")
        self.db_url = (
            f"mysql+pymysql://{user_part}:{password_part}"
            f"@{mysql_host}:{mysql_port}/{mysql_database}"
        )
        self.db_astock = DBAStockData(self.db_url)
        self.output_dir = r"./output/similar_pattern_stocks/"
        os.makedirs(self.output_dir, exist_ok=True)

    def get_stock_list(self):
        """Return the stock list restricted to the 主板 / 创业板 / 科创板 markets."""
        stock_data = self.db_astock.query_stock_data()
        # Keep only main-board, ChiNext and STAR-market stocks.
        return stock_data[stock_data["market"].isin(["主板", "创业板", "科创板"])]

    def get_stock_market_data(
        self,
        symbol: str,
        bar: str,
        start_date: str,
        end_date: str,
        compare_record_amount: int = 100,
    ):
        """Fetch OHLCV + indicator rows for ``symbol`` and keep the most recent ones.

        Args:
            symbol: ts_code of the stock (passed through to the query).
            bar: "1W" reads the weekly table, "1M" the monthly table, anything
                else the daily table. The value is also echoed as a ``bar`` column.
            start_date: lower bound passed through to the query (may be None).
            end_date: upper bound passed through to the query (may be None).
            compare_record_amount: number of most recent rows to keep.

        Returns:
            DataFrame sorted ascending by ``date_time`` with a fresh 0..n-1 index,
            or an empty DataFrame when the query returns no rows.
        """
        # bar is embedded as a SQL string literal; double any single quote so it
        # cannot terminate the literal.
        bar_literal = str(bar).replace("'", "''")
        fields = [
            "a.ts_code as symbol",
            "b.name as symbol_name",
            f"'{bar_literal}' as bar",
            "trade_date as date_time",
            "open",
            "high",
            "low",
            "close",
            "vol as volume",
            "MA5 as ma5",
            "MA10 as ma10",
            "MA20 as ma20",
            "MA30 as ma30",
            "均线交叉 as ma_cross",
            "DIF as dif",
            "DEA as dea",
            "MACD as macd",
        ]
        table_name = self._TABLE_BY_BAR.get(bar, self._DEFAULT_TABLE)
        data = pd.DataFrame(
            self.db_astock.query_market_data_by_symbol_bar(
                symbol=symbol,
                fields=fields,
                start=start_date,
                end=end_date,
                table_name=table_name,
            )
        )
        if data.empty:
            # No rows also means no "date_time" column; sorting would raise KeyError.
            return data
        # Keep the last `compare_record_amount` bars in chronological order.
        data = data.sort_values(by=["date_time"]).tail(compare_record_amount)
        return data.reset_index(drop=True)

    def get_stock_market_data_similar_pattern(
        self,
        target_symbol: str,
        bar: str,
        start_date: str = None,
        end_date: str = None,
        compare_record_amount: int = 100,
        top_n: int = 10,
    ):
        """Rank stocks by pattern similarity to ``target_symbol`` and export the results.

        Steps:
            1. Load the target's last ``compare_record_amount`` bars up to ``end_date``.
            2. For every stock from ``get_stock_list`` except the target, load its
               last bars up to *today* (the target's window ending at ``end_date``
               is compared with the other stocks' current window).
            3. Score each candidate with ``calculate_similarity`` (lower distance =
               more similar). Candidates without exactly as many bars as the target,
               or with a NaN score, are dropped.
            4. Write the ranking to an Excel file, draw candlestick charts for the
               target and the ``top_n`` closest stocks, and embed them in the workbook.

        Returns:
            list of chart-info dicts (target first, then up to ``top_n`` candidates),
            or ``[]`` when the target has no data or the stock list is empty.
        """
        logger.info(f"获取目标股票{target_symbol}的{bar}数据, 截止日期{end_date}")
        target_stock_data = self.get_stock_market_data(
            target_symbol, bar, start_date, end_date, compare_record_amount
        )
        if len(target_stock_data) == 0:
            logger.error(f"目标股票{target_symbol}的{bar}数据为空")
            return []
        # The target may have fewer bars than requested; candidates must match it.
        compare_record_amount = len(target_stock_data)

        stock_list_data = self.get_stock_list()
        if len(stock_list_data) == 0:
            logger.error("所有股票数据为空")
            return []
        stock_list_data = stock_list_data[stock_list_data["ts_code"] != target_symbol]

        now_date = datetime.now().strftime("%Y%m%d")
        similarity_data = self._collect_similarity_data(
            target_stock_data, stock_list_data, bar, now_date, compare_record_amount
        )
        ranking = self._build_ranking(similarity_data)

        target_stock_symbol = str(target_stock_data["symbol"].iloc[0]).split(".")[0]
        target_stock_name = str(target_stock_data["symbol_name"].iloc[0])
        # Names such as "*ST..." contain characters that are invalid in Windows paths.
        safe_target_name = self._safe_filename(target_stock_name)
        folder_date = end_date if end_date else now_date
        target_stock_folder = os.path.join(
            self.output_dir, f"{target_stock_symbol}_{safe_target_name}_{bar}_{folder_date}"
        )
        os.makedirs(target_stock_folder, exist_ok=True)

        excel_file_path = os.path.join(
            target_stock_folder,
            f"{target_stock_symbol}_{safe_target_name}_{bar}_similar_stocks.xlsx",
        )
        with pd.ExcelWriter(excel_file_path) as writer:
            ranking.to_excel(writer, sheet_name="股票形态相似度", index=False)

        similar_stocks_chart_folder = os.path.join(target_stock_folder, "similar_stocks_chart")
        os.makedirs(similar_stocks_chart_folder, exist_ok=True)

        chart_list = [
            self.draw_similar_stocks_chart(
                target_stock_symbol,
                target_stock_name,
                bar,
                "0",
                target_stock_data,
                similar_stocks_chart_folder,
            )
        ]
        # O(1) lookup of each candidate's bars instead of rescanning the list per chart.
        stock_data_by_symbol = {item["symbol"]: item["stock_data"] for item in similarity_data}
        for row in ranking.head(max(top_n, 0)).itertuples(index=False):
            chart_list.append(
                self.draw_similar_stocks_chart(
                    row.symbol,
                    row.symbol_name,
                    bar,
                    row.similarity_distance,
                    stock_data_by_symbol[row.symbol],
                    similar_stocks_chart_folder,
                )
            )

        self.output_chart_to_excel(excel_file_path, chart_list)
        return chart_list

    def _collect_similarity_data(
        self,
        target_stock_data: pd.DataFrame,
        stock_list_data: pd.DataFrame,
        bar: str,
        now_date: str,
        compare_record_amount: int,
    ) -> list:
        """Load each candidate's latest bars and score it against the target.

        Returns:
            list of dicts with keys ``symbol``, ``symbol_name``,
            ``similarity_distance`` and ``stock_data``.
        """
        similarity_data = []
        # enumerate gives the real 1-based position; the DataFrame index is not
        # contiguous after the market/target filters.
        for position, (_, row) in enumerate(stock_list_data.iterrows(), start=1):
            ts_code = row["ts_code"]
            logger.info(f"获取第{position}只股票{ts_code} {row['name']} 的{bar}数据, 截止日期{now_date}")
            stock_data = self.get_stock_market_data(
                symbol=ts_code,
                bar=bar,
                start_date=None,
                end_date=now_date,
                compare_record_amount=compare_record_amount,
            )
            if len(stock_data) == 0:
                logger.error(f"股票{ts_code}的{bar}数据为空")
                continue
            if len(stock_data) != compare_record_amount:
                logger.error(f"股票{ts_code}的数据不足{compare_record_amount}条")
                continue
            similarity_data.append(
                {
                    "symbol": ts_code,
                    "symbol_name": row["name"],
                    "similarity_distance": self.calculate_similarity(target_stock_data, stock_data),
                    "stock_data": stock_data,
                }
            )
        return similarity_data

    @staticmethod
    def _build_ranking(similarity_data: list) -> pd.DataFrame:
        """Return symbol/name/distance rows without NaN, sorted by ascending distance.

        Explicit columns keep this valid (and empty) when ``similarity_data`` is empty.
        """
        columns = ["symbol", "symbol_name", "similarity_distance"]
        ranking = pd.DataFrame(
            [{key: item[key] for key in columns} for item in similarity_data],
            columns=columns,
        )
        # Constant price/volume windows produce NaN distances; they are not comparable.
        ranking = ranking[ranking["similarity_distance"].notna()]
        return ranking.sort_values(by=["similarity_distance"], kind="stable").reset_index(drop=True)

    @staticmethod
    def _safe_filename(name) -> str:
        """Return ``name`` as a string without characters invalid in Windows file names."""
        return re.sub(r'[\\/:*?"<>|]', "", str(name))

    def output_chart_to_excel(self, excel_file_path: str, charts_list: list):
        """Append a "股票K线图" sheet with one titled chart image per entry.

        Args:
            excel_file_path: an existing .xlsx workbook; it is modified and saved in place.
            charts_list: dicts with keys ``chart_path``, ``symbol``, ``stock_name``,
                ``bar`` and ``similarity_distance`` (as returned by
                ``draw_similar_stocks_chart``).
        """
        logger.info(f"将图表输出到{excel_file_path}")

        workbook = openpyxl.load_workbook(excel_file_path)
        sheet = workbook.create_sheet(title="股票K线图")

        # Heuristic row height: default row is 15 pt and ~1.333 px per pt (≈20 px/row).
        pixels_per_row = 15 * 1.333
        row_offset = 1
        for chart_info in charts_list:
            chart_path = chart_info["chart_path"]
            chart_name = (
                f"{chart_info['symbol']} {chart_info['stock_name']} "
                f"{chart_info['bar']} 距离:{chart_info['similarity_distance']}"
            )
            with PILImage.open(chart_path) as pil_image:
                _, height_px = pil_image.size
            # Reserve at least 10 rows so small charts still get visible spacing.
            chart_rows = max(10, int(height_px / pixels_per_row))

            title_cell = sheet[f"A{row_offset}"]
            title_cell.value = chart_name
            title_cell.font = Font(bold=True, size=12)
            row_offset += 2  # title row + one spacer row

            sheet.add_image(Image(chart_path), f"A{row_offset}")
            row_offset += chart_rows + 5  # image height + padding before the next chart

        workbook.save(excel_file_path)
        logger.info(f"Chart saved as {excel_file_path}")

    def draw_similar_stocks_chart(
        self,
        symbol: str,
        stock_name: str,
        bar: str,
        similarity_distance: float,
        stock_data: pd.DataFrame,
        similar_stocks_chart_folder: str,
    ):
        """Draw a candlestick chart of ``stock_data`` and save it as PNG.

        Returns:
            dict with ``chart_path``, ``symbol``, ``stock_name``, ``bar`` and
            ``similarity_distance``, as consumed by ``output_chart_to_excel``.
        """
        chart_path = os.path.join(
            similar_stocks_chart_folder,
            f"{symbol}_{self._safe_filename(stock_name)}_{bar}_chart.png",
        )

        # mplfinance needs a DatetimeIndex and capitalised OHLC column names.
        df = stock_data.copy()
        if "trade_date" in df.columns:
            df = df.rename(columns={"trade_date": "Date"})
        elif "date_time" in df.columns:
            df = df.rename(columns={"date_time": "Date"})
        df["Date"] = pd.to_datetime(df["Date"])
        df = df.set_index("Date")
        df = df.rename(columns={"open": "Open", "high": "High", "low": "Low", "close": "Close"})

        # Red for up bars and green for down bars, following the A-share colour convention.
        market_colors = mpf.make_marketcolors(
            up="#e41a1c",
            down="#4daf4a",
            edge="inherit",
            wick="inherit",
            volume="inherit",
        )
        mpf_style = mpf.make_mpf_style(
            base_mpf_style="yahoo",
            marketcolors=market_colors,
            facecolor="#FAFAFA",
            edgecolor="#EAEAEA",
            gridcolor="#E6E6E6",
            gridstyle="--",
            rc={"font.family": "SimHei", "axes.unicode_minus": False},
        )
        mpf.plot(
            df,
            type="candle",
            style=mpf_style,
            title=f"{symbol} {stock_name} {bar} 距离:{similarity_distance}",
            ylabel="价格",
            figsize=(12, 6),
            savefig={"fname": chart_path, "dpi": 150, "bbox_inches": "tight"},
        )

        return {
            "chart_path": chart_path,
            "symbol": symbol,
            "stock_name": stock_name,
            "bar": bar,
            "similarity_distance": similarity_distance,
        }

    @staticmethod
    def _min_max_normalize(values) -> np.ndarray:
        """Scale ``values`` to [0, 1] as a float array.

        A constant series (max == min) or one containing NaN yields NaN values,
        which the ranking step treats as "not comparable".
        """
        arr = np.asarray(values, dtype=float)
        if arr.size == 0:
            return arr
        low, high = np.min(arr), np.max(arr)
        with np.errstate(invalid="ignore", divide="ignore"):
            return (arr - low) / (high - low)

    def calculate_similarity(
        self,
        target_stock_data: pd.DataFrame,
        stock_data: pd.DataFrame,
        close_weight=0.8,
        volume_weight=0.2,
    ):
        """Weighted Euclidean distance between min-max normalised close and volume series.

        Rows are compared by position, not by index label. A lower value means a
        more similar pattern.

        Returns:
            float distance; NaN when the frames differ in length or a series is
            constant, so there is nothing to normalise.
        """
        if len(target_stock_data) != len(stock_data):
            return float("nan")

        close_distance = np.linalg.norm(
            self._min_max_normalize(target_stock_data["close"])
            - self._min_max_normalize(stock_data["close"])
        )
        volume_distance = np.linalg.norm(
            self._min_max_normalize(target_stock_data["volume"])
            - self._min_max_normalize(stock_data["volume"])
        )
        return float(close_weight * close_distance + volume_weight * volume_distance)