# crypto_quant/core/statistics/similar_pattern_stocks.py
#
# Scraped from a repository web view (reported there as 382 lines, 14 KiB, Python).
# NOTE: the web view warned that this file contains ambiguous Unicode characters,
# i.e. characters that might be confused with other characters. If that is
# intentional, the warning can be ignored.

import core.logger as logging
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import mplfinance as mpf
from datetime import datetime, timedelta, timezone
from core.utils import get_current_date_time
import re
import json
import math
from openpyxl import Workbook
from openpyxl.drawing.image import Image
import openpyxl
from openpyxl.styles import Font
from PIL import Image as PILImage
from config import (
A_MYSQL_CONFIG,
)
from core.db.db_astock import DBAStockData
# Apply the shared seaborn look first: sns.set_theme() resets "font.family"
# to its own `font` argument (default "sans-serif"), so any font chosen
# before this call would be overwritten.
sns.set_theme(style="whitegrid")
# Use SimHei so Chinese text renders in figures, and draw the minus sign as a
# plain hyphen so it still renders with that font.
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False
logger = logging.logger
class SimilarPatternStocks:
    """Rank stocks by how closely their recent close/volume shape matches a target.

    Market data is read through ``DBAStockData``. Results (an Excel workbook
    plus one candlestick PNG per stock) are written below ``self.output_dir``.
    """

    # Table queried for each explicitly supported bar size; every other bar
    # value falls back to the daily table.
    _TABLE_BY_BAR = {
        "1W": "stock_weekly_price_from_2020",
        "1M": "stock_monthly_price_from_2015",
    }
    _DEFAULT_TABLE = "stock_daily_price_from_2021"

    def __init__(self):
        """Create the database accessor and make sure the output directory exists.

        Raises:
            ValueError: If ``A_MYSQL_CONFIG`` has no (or an empty) password.
        """
        mysql_user = A_MYSQL_CONFIG.get("user", "root")
        mysql_password = A_MYSQL_CONFIG.get("password", "")
        if not mysql_password:
            raise ValueError("MySQL password is not set")
        mysql_host = A_MYSQL_CONFIG.get("host", "localhost")
        mysql_port = A_MYSQL_CONFIG.get("port", 3306)
        mysql_database = A_MYSQL_CONFIG.get("database", "astock")
        # NOTE(review): user and password are interpolated verbatim; a value
        # containing URL-reserved characters (e.g. "@" or "/") may need
        # percent-encoding -- confirm against how DBAStockData parses the URL.
        self.db_url = (
            f"mysql+pymysql://{mysql_user}:{mysql_password}"
            f"@{mysql_host}:{mysql_port}/{mysql_database}"
        )
        self.db_astock = DBAStockData(self.db_url)
        self.output_dir = r"./output/similar_pattern_stocks/"
        os.makedirs(self.output_dir, exist_ok=True)

    def get_stock_list(self):
        """Return the stock table restricted to the three supported markets.

        Returns:
            The rows of ``query_stock_data()`` whose ``market`` column is one
            of "主板", "创业板" or "科创板".
        """
        stock_data = self.db_astock.query_stock_data()
        supported_markets = ["主板", "创业板", "科创板"]
        return stock_data[stock_data["market"].isin(supported_markets)]

    def get_stock_market_data(
        self,
        symbol: str,
        bar: str,
        start_date: str,
        end_date: str,
        compare_record_amount: int = 100,
    ) -> pd.DataFrame:
        """Fetch the most recent bars of one symbol, oldest first.

        Args:
            symbol: Stock code passed to the database query.
            bar: Bar size. "1W" and "1M" select the weekly and monthly tables;
                any other value selects the daily table.
            start_date: Lower bound forwarded to the query (may be ``None``).
            end_date: Upper bound forwarded to the query (may be ``None``).
            compare_record_amount: Maximum number of trailing rows to keep.

        Returns:
            A DataFrame sorted by ``date_time`` with a fresh 0-based index and
            at most ``compare_record_amount`` rows. An empty DataFrame is
            returned when the query yields no rows.
        """
        # NOTE(review): `bar` is embedded in a SQL field expression as a string
        # literal; callers are presumably trusted -- verify before exposing
        # this method to external input.
        fields = [
            "a.ts_code as symbol",
            "b.name as symbol_name",
            f"'{bar}' as bar",
            "trade_date as date_time",
            "open",
            "high",
            "low",
            "close",
            "vol as volume",
            "MA5 as ma5",
            "MA10 as ma10",
            "MA20 as ma20",
            "MA30 as ma30",
            "均线交叉 as ma_cross",
            "DIF as dif",
            "DEA as dea",
            "MACD as macd",
        ]
        table_name = self._TABLE_BY_BAR.get(bar, self._DEFAULT_TABLE)
        data = self.db_astock.query_market_data_by_symbol_bar(
            symbol=symbol,
            fields=fields,
            start=start_date,
            end=end_date,
            table_name=table_name,
        )
        data = pd.DataFrame(data)
        # An empty result may have no columns at all, in which case sorting by
        # "date_time" would raise KeyError before callers can test for
        # emptiness. Return it untouched instead.
        if data.empty:
            return data
        data = data.sort_values(by=["date_time"])
        # Keep only the trailing window that will be compared.
        data = data.tail(compare_record_amount)
        return data.reset_index(drop=True)

    def get_stock_market_data_similar_pattern(
        self,
        target_symbol: str,
        bar: str,
        start_date: str = None,
        end_date: str = None,
        compare_record_amount: int = 100,
        top_n: int = 10,
    ):
        """Find, chart and export the stocks most similar to ``target_symbol``.

        Steps:
            1. Load the target's trailing window (bounded by ``end_date``).
            2. Load the same number of trailing bars, up to today, for every
               other stock returned by ``get_stock_list``.
            3. Score each stock with ``calculate_similarity`` (lower is closer).
            4. Write every valid score to an Excel sheet.
            5. Draw a candlestick chart for the target and for the ``top_n``
               closest stocks, then embed those charts in the workbook.

        Args:
            target_symbol: Code of the reference stock.
            bar: Bar size forwarded to ``get_stock_market_data``.
            start_date: Optional lower date bound for the target's data.
            end_date: Optional upper date bound for the target's data; also
                used in the output folder name when non-empty.
            compare_record_amount: Requested window length. The actual length
                is whatever the target stock provides.
            top_n: Number of closest stocks to chart (default 10).

        Returns:
            A list of chart-info dicts (target first, then the closest stocks
            in ascending distance), or ``[]`` when the target or the stock
            list has no data.
        """
        logger.info(f"获取目标股票{target_symbol}{bar}数据, 截止日期{end_date}")
        target_stock_data = self.get_stock_market_data(
            target_symbol, bar, start_date, end_date, compare_record_amount
        )
        if len(target_stock_data) == 0:
            logger.error(f"目标股票{target_symbol}{bar}数据为空")
            return []
        # Candidates must supply exactly as many bars as the target does.
        compare_record_amount = len(target_stock_data)

        stock_list_data = self.get_stock_list()
        if len(stock_list_data) == 0:
            logger.error("所有股票数据为空")
            return []
        stock_list_data = stock_list_data[stock_list_data["ts_code"] != target_symbol]

        now_date = datetime.now().strftime("%Y%m%d")
        similarity_data = []
        # Count positions explicitly: the DataFrame index label is not a
        # running counter once rows have been filtered out.
        for position, (_, row) in enumerate(stock_list_data.iterrows(), start=1):
            logger.info(
                f"获取第{position}只股票{row['ts_code']} {row['name']}{bar}数据, 截止日期{now_date}"
            )
            stock_data = self.get_stock_market_data(
                symbol=row["ts_code"],
                bar=bar,
                start_date=None,
                end_date=now_date,
                compare_record_amount=compare_record_amount,
            )
            if len(stock_data) == 0:
                logger.error(f"股票{row['ts_code']}{bar}数据为空")
                continue
            if len(stock_data) != compare_record_amount:
                logger.error(f"股票{row['ts_code']}的数据不足{compare_record_amount}")
                continue
            similarity = self.calculate_similarity(target_stock_data, stock_data)
            # A NaN distance (e.g. a flat series that cannot be normalised)
            # carries no ranking information, so it is dropped here.
            if pd.isna(similarity):
                continue
            similarity_data.append(
                {
                    "symbol": row["ts_code"],
                    "symbol_name": row["name"],
                    "similarity_distance": similarity,
                    "stock_data": stock_data,
                }
            )

        # Ascending distance: the closest pattern comes first.
        similarity_data.sort(key=lambda item: item["similarity_distance"])
        score_columns = ["symbol", "symbol_name", "similarity_distance"]
        # Explicit columns keep the sheet well-formed even with zero matches.
        pure_data = pd.DataFrame(
            [{key: item[key] for key in score_columns} for item in similarity_data],
            columns=score_columns,
        )

        target_stock_symbol = str(target_stock_data["symbol"].iloc[0]).split(".")[0]
        target_stock_name = str(target_stock_data["symbol_name"].iloc[0])
        # "*" (as in "*ST" names) is stripped from paths, matching what
        # draw_similar_stocks_chart does for chart file names.
        safe_target_name = target_stock_name.replace("*", "")
        date_label = end_date if end_date else now_date
        target_stock_folder = os.path.join(
            self.output_dir,
            f"{target_stock_symbol}_{safe_target_name}_{bar}_{date_label}",
        )
        os.makedirs(target_stock_folder, exist_ok=True)

        excel_file_path = os.path.join(
            target_stock_folder,
            f"{target_stock_symbol}_{safe_target_name}_{bar}_similar_stocks.xlsx",
        )
        with pd.ExcelWriter(excel_file_path) as writer:
            pure_data.to_excel(writer, sheet_name="股票形态相似度", index=False)

        similar_stocks_chart_folder = os.path.join(
            target_stock_folder, "similar_stocks_chart"
        )
        os.makedirs(similar_stocks_chart_folder, exist_ok=True)

        # The target's own chart goes first, labelled with distance "0".
        chart_list = [
            self.draw_similar_stocks_chart(
                target_stock_symbol,
                target_stock_name,
                bar,
                "0",
                target_stock_data,
                similar_stocks_chart_folder,
            )
        ]
        for item in similarity_data[: max(top_n, 0)]:
            chart_list.append(
                self.draw_similar_stocks_chart(
                    item["symbol"],
                    item["symbol_name"],
                    bar,
                    item["similarity_distance"],
                    item["stock_data"],
                    similar_stocks_chart_folder,
                )
            )
        self.output_chart_to_excel(excel_file_path, chart_list)
        return chart_list

    def output_chart_to_excel(self, excel_file_path: str, charts_list: list) -> None:
        """Append a sheet holding every chart image to an existing workbook.

        Args:
            excel_file_path: Path of a workbook that already exists; it is
                loaded, extended with a "股票K线图" sheet and saved in place.
            charts_list: Chart-info dicts, each with the keys ``chart_path``,
                ``symbol``, ``stock_name``, ``bar`` and ``similarity_distance``.
        """
        logger.info(f"将图表输出到{excel_file_path}")
        wb = openpyxl.load_workbook(excel_file_path)
        ws = wb.create_sheet(title="股票K线图")

        # Approximate conversion: a default row is 15 points high and one
        # point is taken as 1.333 pixels, i.e. roughly 20 pixels per row.
        pixels_per_point = 1.333
        points_per_row = 15
        pixels_per_row = points_per_row * pixels_per_point

        row_offset = 1
        for chart_info in charts_list:
            chart_path = chart_info["chart_path"]
            chart_name = (
                f"{chart_info['symbol']} {chart_info['stock_name']} "
                f"{chart_info['bar']} 距离:{chart_info['similarity_distance']}"
            )
            # Only the image height is needed, to reserve enough rows.
            with PILImage.open(chart_path) as picture:
                _, height_px = picture.size
            # Reserve at least 10 rows, even for small charts.
            chart_rows = max(10, int(height_px / pixels_per_row))

            title_cell = ws.cell(row=row_offset, column=1, value=chart_name)
            title_cell.font = Font(bold=True, size=12)
            row_offset += 2  # title row plus one spacer row

            ws.add_image(Image(chart_path), f"A{row_offset}")
            row_offset += chart_rows + 5  # image height plus padding

        wb.save(excel_file_path)
        logger.info(f"Chart saved as {excel_file_path}")

    def draw_similar_stocks_chart(
        self,
        symbol: str,
        stock_name: str,
        bar: str,
        similarity_distance: float,
        stock_data: pd.DataFrame,
        similar_stocks_chart_folder: str,
    ) -> dict:
        """Render a candlestick chart for one stock and save it as a PNG.

        Args:
            symbol: Stock code, used in the file name and the title.
            stock_name: Stock name; "*" is removed for the file name only.
            bar: Bar size, used in the file name and the title.
            similarity_distance: Value shown in the chart title.
            stock_data: Price rows with a ``trade_date`` or ``date_time``
                column plus ``open``/``high``/``low``/``close``.
            similar_stocks_chart_folder: Directory that receives the PNG.

        Returns:
            A dict with ``chart_path``, ``symbol``, ``stock_name``, ``bar``
            and ``similarity_distance``.
        """
        pure_stock_name = stock_name.replace("*", "")
        chart_path = os.path.join(
            similar_stocks_chart_folder, f"{symbol}_{pure_stock_name}_{bar}_chart.png"
        )

        # Work on a copy so the caller's frame keeps its column names.
        df = stock_data.copy()
        # The date column is renamed to "Date" and used as a datetime index.
        if "trade_date" in df.columns:
            df = df.rename(columns={"trade_date": "Date"})
        elif "date_time" in df.columns:
            df = df.rename(columns={"date_time": "Date"})
        df["Date"] = pd.to_datetime(df["Date"])
        df = df.set_index("Date")
        # Capitalise the OHLC column names for plotting.
        df = df.rename(
            columns={"open": "Open", "high": "High", "low": "Low", "close": "Close"}
        )

        # Red for rising and green for falling candles (A-share convention).
        mc = mpf.make_marketcolors(
            up="#e41a1c",
            down="#4daf4a",
            edge="inherit",
            wick="inherit",
            volume="inherit",
        )
        mpf_style = mpf.make_mpf_style(
            base_mpf_style="yahoo",
            marketcolors=mc,
            facecolor="#FAFAFA",
            edgecolor="#EAEAEA",
            gridcolor="#E6E6E6",
            gridstyle="--",
            rc={"font.family": "SimHei", "axes.unicode_minus": False},
        )
        mpf.plot(
            df,
            type="candle",
            style=mpf_style,
            title=f"{symbol} {stock_name} {bar} 距离:{similarity_distance}",
            ylabel="价格",
            figsize=(12, 6),
            savefig={"fname": chart_path, "dpi": 150, "bbox_inches": "tight"},
        )
        return {
            "chart_path": chart_path,
            "symbol": symbol,
            "stock_name": stock_name,
            "bar": bar,
            "similarity_distance": similarity_distance,
        }

    @staticmethod
    def _min_max_normalize(series: pd.Series) -> pd.Series:
        """Rescale ``series`` to [0, 1]; a constant series yields NaN values."""
        lowest = series.min()
        return (series - lowest) / (series.max() - lowest)

    def calculate_similarity(
        self,
        target_stock_data: pd.DataFrame,
        stock_data: pd.DataFrame,
        close_weight=0.8,
        volume_weight=0.2,
    ) -> float:
        """Return a weighted Euclidean distance between two price/volume windows.

        The ``close`` and ``volume`` columns of each frame are min-max
        normalised independently, the Euclidean distance is taken per column,
        and the two distances are combined with the given weights. Lower
        values mean more similar patterns.

        Args:
            target_stock_data: Frame of the reference stock.
            stock_data: Frame of the candidate stock; rows are paired with the
                target's rows by index label.
            close_weight: Weight of the close-price distance.
            volume_weight: Weight of the volume distance.

        Returns:
            The combined distance as a float. The result is NaN when any
            normalised column is NaN (for example a constant series).
        """
        close_distance = np.linalg.norm(
            self._min_max_normalize(target_stock_data["close"])
            - self._min_max_normalize(stock_data["close"])
        )
        volume_distance = np.linalg.norm(
            self._min_max_normalize(target_stock_data["volume"])
            - self._min_max_normalize(stock_data["volume"])
        )
        similarity_distance = (
            close_weight * close_distance + volume_weight * volume_distance
        )
        return float(similarity_distance)