"""Find A-share stocks whose recent price/volume pattern resembles a target stock."""

import json
import math
import os
import re
from datetime import datetime, timedelta, timezone
from urllib.parse import quote

import matplotlib.pyplot as plt
import mplfinance as mpf
import numpy as np
import openpyxl
import pandas as pd
import seaborn as sns
from openpyxl import Workbook
from openpyxl.drawing.image import Image
from openpyxl.styles import Font
from PIL import Image as PILImage

import core.logger as logging
from config import (
    A_MYSQL_CONFIG,
)
from core.db.db_astock import DBAStockData
from core.utils import get_current_date_time

# Apply the seaborn theme first: ``sns.set_theme`` re-applies its own ``font``
# default to ``font.family``, so the CJK-capable font has to be configured
# afterwards or it is overwritten.
sns.set_theme(style="whitegrid")
# Let matplotlib/seaborn render Chinese text and the minus sign correctly.
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams["axes.unicode_minus"] = False

logger = logging.logger


class SimilarPatternStocks:
    """Rank stocks by how closely their recent bars resemble a target stock.

    The ranking is written to an Excel workbook together with candlestick
    charts of the target stock and of the closest matches.
    """

    # Column layout of the ranking sheet written to the Excel report.
    _RANKING_COLUMNS = ["symbol", "symbol_name", "similarity_distance"]

    def __init__(self):
        """Build the database accessor and make sure the output folder exists.

        Raises:
            ValueError: If no MySQL password is configured.
        """
        mysql_user = A_MYSQL_CONFIG.get("user", "root")
        mysql_password = A_MYSQL_CONFIG.get("password", "")
        if not mysql_password:
            raise ValueError("MySQL password is not set")
        mysql_host = A_MYSQL_CONFIG.get("host", "localhost")
        mysql_port = A_MYSQL_CONFIG.get("port", 3306)
        mysql_database = A_MYSQL_CONFIG.get("database", "astock")

        # Percent-encode the credentials: characters such as "@", "/" or ":"
        # would otherwise be parsed as URL delimiters.
        # NOTE(review): assumes the configured password is the raw (not already
        # URL-encoded) value - confirm against the config file.
        quoted_user = quote(str(mysql_user), safe="")
        quoted_password = quote(str(mysql_password), safe="")
        self.db_url = (
            f"mysql+pymysql://{quoted_user}:{quoted_password}"
            f"@{mysql_host}:{mysql_port}/{mysql_database}"
        )
        self.db_astock = DBAStockData(self.db_url)
        self.output_dir = r"./output/similar_pattern_stocks/"
        os.makedirs(self.output_dir, exist_ok=True)

    @staticmethod
    def _strip_unsafe_chars(name: str) -> str:
        """Drop ``*`` from a stock name so it can be used inside a file path."""
        return name.replace("*", "")

    @staticmethod
    def _min_max_normalize(values) -> np.ndarray:
        """Scale ``values`` to [0, 1]; return all-NaN when the range is zero or undefined."""
        arr = np.asarray(values, dtype=float)
        if arr.size == 0:
            return arr
        low = arr.min()
        span = arr.max() - low
        if not np.isfinite(span) or span == 0:
            # A flat (or NaN-containing) series cannot be normalized; NaN marks
            # the pair as "not comparable" for the caller.
            return np.full(arr.shape, np.nan)
        return (arr - low) / span

    def get_stock_list(self):
        """Return the stock list restricted to the 主板, 创业板 and 科创板 markets."""
        stock_data = self.db_astock.query_stock_data()
        stock_data = stock_data[stock_data["market"].isin(["主板", "创业板", "科创板"])]
        return stock_data

    def get_stock_market_data(
        self,
        symbol: str,
        bar: str,
        start_date: str,
        end_date: str,
        compare_record_amount: int = 100,
    ):
        """Load the most recent bars of one stock.

        Args:
            symbol: Stock code passed to the database query.
            bar: Bar size. ``"1W"`` reads the weekly table, ``"1M"`` the monthly
                table, anything else the daily table.
            start_date: Lower bound forwarded to the query.
            end_date: Upper bound forwarded to the query.
            compare_record_amount: Number of trailing rows to keep.

        Returns:
            A DataFrame sorted by ``date_time`` with a fresh 0..n-1 index and at
            most ``compare_record_amount`` rows; an empty DataFrame when the
            query returns nothing.
        """
        # NOTE(review): ``bar`` is embedded verbatim as a SQL literal below -
        # it must come from trusted code, not from user input.
        fields = [
            "a.ts_code as symbol",
            "b.name as symbol_name",
            f"'{bar}' as bar",
            "trade_date as date_time",
            "open",
            "high",
            "low",
            "close",
            "vol as volume",
            "MA5 as ma5",
            "MA10 as ma10",
            "MA20 as ma20",
            "MA30 as ma30",
            "均线交叉 as ma_cross",
            "DIF as dif",
            "DEA as dea",
            "MACD as macd",
        ]
        if bar == "1W":
            table_name = "stock_weekly_price_from_2020"
        elif bar == "1M":
            table_name = "stock_monthly_price_from_2015"
        else:
            table_name = "stock_daily_price_from_2021"
        data = self.db_astock.query_market_data_by_symbol_bar(
            symbol=symbol,
            fields=fields,
            start=start_date,
            end=end_date,
            table_name=table_name,
        )
        data = pd.DataFrame(data)
        if data.empty:
            # An empty result has no "date_time" column, so sorting would raise
            # KeyError; hand the empty frame back so callers can test len() == 0.
            return data
        # Keep only the trailing ``compare_record_amount`` bars, oldest first.
        return (
            data.sort_values(by=["date_time"])
            .tail(compare_record_amount)
            .reset_index(drop=True)
        )

    def get_stock_market_data_similar_pattern(
        self,
        target_symbol: str,
        bar: str,
        start_date: str = None,
        end_date: str = None,
        compare_record_amount: int = 100,
        top_n: int = 10,
    ):
        """Find the stocks whose recent bars look most like the target stock.

        Steps:
            1. Load the target stock's bars up to ``end_date`` and keep the last
               ``compare_record_amount`` of them.
            2. For every other listed stock, load the same number of bars up to
               today.
            3. Compute the similarity distance between the target and each stock
               (smaller means more similar).
            4. Write the full ranking to an Excel workbook.
            5. Draw candlestick charts for the target and the ``top_n`` closest
               stocks and embed them in the workbook.

        Args:
            target_symbol: Stock code of the reference stock.
            bar: Bar size, see ``get_stock_market_data``.
            start_date: Lower date bound for the target stock's data.
            end_date: Upper date bound for the target stock's data; also used in
                the output folder name when non-empty.
            compare_record_amount: Number of trailing bars to compare.
            top_n: How many of the closest stocks get a chart.

        Returns:
            A list of chart-info dicts (target first, then the closest stocks),
            or an empty list when the target data or the stock list is empty.
        """
        logger.info(f"获取目标股票{target_symbol}的{bar}数据, 截止日期{end_date}")
        target_stock_data = self.get_stock_market_data(
            target_symbol, bar, start_date, end_date, compare_record_amount
        )
        if len(target_stock_data) == 0:
            logger.error(f"目标股票{target_symbol}的{bar}数据为空")
            return []
        # The target may have fewer bars than requested; compare on what exists.
        compare_record_amount = len(target_stock_data)

        stock_list_data = self.get_stock_list()
        if len(stock_list_data) == 0:
            logger.error("所有股票数据为空")
            return []
        stock_list_data = stock_list_data[stock_list_data["ts_code"] != target_symbol]

        now_date = datetime.now().strftime("%Y%m%d")
        similarity_data = []
        # Count with enumerate: the DataFrame index is not sequential after the
        # filters above and is not guaranteed to be an integer.
        for position, (_, row) in enumerate(stock_list_data.iterrows(), start=1):
            logger.info(
                f"获取第{position}只股票{row['ts_code']} {row['name']} 的{bar}数据, 截止日期{now_date}"
            )
            stock_data = self.get_stock_market_data(
                symbol=row["ts_code"],
                bar=bar,
                start_date=None,
                end_date=now_date,
                compare_record_amount=compare_record_amount,
            )
            if len(stock_data) == 0:
                logger.error(f"股票{row['ts_code']}的{bar}数据为空")
                continue
            if len(stock_data) != compare_record_amount:
                logger.error(f"股票{row['ts_code']}的数据不足{compare_record_amount}条")
                continue
            similarity_data.append(
                {
                    "symbol": row["ts_code"],
                    "symbol_name": row["name"],
                    "similarity_distance": self.calculate_similarity(
                        target_stock_data, stock_data
                    ),
                    "stock_data": stock_data,
                }
            )

        # Drop pairs without a usable distance (NaN), then rank ascending.
        ranked = sorted(
            (item for item in similarity_data if pd.notna(item["similarity_distance"])),
            key=lambda item: item["similarity_distance"],
        )
        # Explicit columns keep the sheet layout intact even with zero matches.
        pure_data = pd.DataFrame(
            [{column: item[column] for column in self._RANKING_COLUMNS} for item in ranked],
            columns=self._RANKING_COLUMNS,
        )

        target_stock_symbol = str(target_stock_data["symbol"].iloc[0]).split(".")[0]
        target_stock_name = str(target_stock_data["symbol_name"].iloc[0])
        # Same sanitising as the chart file names, so "*ST" names yield valid paths.
        safe_stock_name = self._strip_unsafe_chars(target_stock_name)
        folder_date = end_date if end_date else now_date
        target_stock_folder = os.path.join(
            self.output_dir,
            f"{target_stock_symbol}_{safe_stock_name}_{bar}_{folder_date}",
        )
        os.makedirs(target_stock_folder, exist_ok=True)

        excel_file_path = os.path.join(
            target_stock_folder,
            f"{target_stock_symbol}_{safe_stock_name}_{bar}_similar_stocks.xlsx",
        )
        with pd.ExcelWriter(excel_file_path) as writer:
            pure_data.to_excel(writer, sheet_name="股票形态相似度", index=False)

        similar_stocks_chart_folder = os.path.join(
            target_stock_folder, "similar_stocks_chart"
        )
        os.makedirs(similar_stocks_chart_folder, exist_ok=True)

        # The target stock's own chart goes first, with distance "0".
        chart_list = [
            self.draw_similar_stocks_chart(
                target_stock_symbol,
                target_stock_name,
                bar,
                "0",
                target_stock_data,
                similar_stocks_chart_folder,
            )
        ]
        for item in ranked[: max(top_n, 0)]:
            chart_list.append(
                self.draw_similar_stocks_chart(
                    item["symbol"],
                    item["symbol_name"],
                    bar,
                    item["similarity_distance"],
                    item["stock_data"],
                    similar_stocks_chart_folder,
                )
            )

        self.output_chart_to_excel(excel_file_path, chart_list)
        return chart_list

    def output_chart_to_excel(self, excel_file_path: str, charts_list: list):
        """Append a sheet holding every chart image to an existing workbook.

        Args:
            excel_file_path: Path of an existing ``.xlsx`` file; it is loaded,
                extended with a new sheet and saved back in place.
            charts_list: Chart-info dicts, each with the keys ``chart_path``,
                ``symbol``, ``stock_name``, ``bar`` and ``similarity_distance``.
        """
        logger.info(f"将图表输出到{excel_file_path}")
        wb = openpyxl.load_workbook(excel_file_path)
        ws = wb.create_sheet(title="股票K线图")

        # Approximate pixel-to-row conversion: 1 point ~ 1.333 px and a default
        # row is 15 points high, i.e. roughly 20 px per row.
        pixels_per_point = 1.333
        points_per_row = 15
        pixels_per_row = points_per_row * pixels_per_point
        min_chart_rows = 10  # reserve at least this many rows for small charts
        padding_rows = 5  # blank rows between consecutive charts

        row_offset = 1
        for chart_info in charts_list:
            chart_path = chart_info["chart_path"]
            chart_name = (
                chart_info["symbol"]
                + " "
                + chart_info["stock_name"]
                + " "
                + chart_info["bar"]
                + " 距离:"
                + str(chart_info["similarity_distance"])
            )
            # Only the pixel height is needed, to know how many rows to skip.
            with PILImage.open(chart_path) as pil_image:
                _, height_px = pil_image.size
            chart_rows = max(min_chart_rows, int(height_px / pixels_per_row))

            # Bold title cell, followed by one spacer row.
            title_cell = ws[f"A{row_offset}"]
            title_cell.value = chart_name
            title_cell.font = Font(bold=True, size=12)
            row_offset += 2

            ws.add_image(Image(chart_path), f"A{row_offset}")
            row_offset += chart_rows + padding_rows

        wb.save(excel_file_path)
        logger.info(f"Chart saved as {excel_file_path}")

    def draw_similar_stocks_chart(
        self,
        symbol: str,
        stock_name: str,
        bar: str,
        similarity_distance: float,
        stock_data: pd.DataFrame,
        similar_stocks_chart_folder: str,
    ):
        """Draw a candlestick chart for one stock and save it as a PNG.

        Args:
            symbol: Stock code, used in the file name and the chart title.
            stock_name: Stock name; ``*`` is removed for the file name only.
            bar: Bar size, used in the file name and the chart title.
            similarity_distance: Distance shown in the chart title.
            stock_data: Bars with ``open``/``high``/``low``/``close`` columns and
                a ``trade_date`` or ``date_time`` column.
            similar_stocks_chart_folder: Folder the PNG is written to.

        Returns:
            A dict with ``chart_path``, ``symbol``, ``stock_name``, ``bar`` and
            ``similarity_distance``.
        """
        chart_path = os.path.join(
            similar_stocks_chart_folder,
            f"{symbol}_{self._strip_unsafe_chars(stock_name)}_{bar}_chart.png",
        )

        # mplfinance expects a DatetimeIndex and capitalised OHLC column names.
        df = stock_data.copy()
        if "trade_date" in df.columns:
            df = df.rename(columns={"trade_date": "Date"})
        elif "date_time" in df.columns:
            df = df.rename(columns={"date_time": "Date"})
        df["Date"] = pd.to_datetime(df["Date"])
        df = df.set_index("Date")
        df = df.rename(
            columns={"open": "Open", "high": "High", "low": "Low", "close": "Close"}
        )

        # Red for rising and green for falling candles.
        mc = mpf.make_marketcolors(
            up="#e41a1c",
            down="#4daf4a",
            edge="inherit",
            wick="inherit",
            volume="inherit",
        )
        mpf_style = mpf.make_mpf_style(
            base_mpf_style="yahoo",
            marketcolors=mc,
            facecolor="#FAFAFA",
            edgecolor="#EAEAEA",
            gridcolor="#E6E6E6",
            gridstyle="--",
            rc={"font.family": "SimHei", "axes.unicode_minus": False},
        )
        mpf.plot(
            df,
            type="candle",
            style=mpf_style,
            title=f"{symbol} {stock_name} {bar} 距离:{similarity_distance}",
            ylabel="价格",
            figsize=(12, 6),
            savefig={"fname": chart_path, "dpi": 150, "bbox_inches": "tight"},
        )

        return {
            "chart_path": chart_path,
            "symbol": symbol,
            "stock_name": stock_name,
            "bar": bar,
            "similarity_distance": similarity_distance,
        }

    def calculate_similarity(
        self,
        target_stock_data: pd.DataFrame,
        stock_data: pd.DataFrame,
        close_weight=0.8,
        volume_weight=0.2,
    ):
        """Return a weighted Euclidean distance between two stocks' bars.

        ``close`` and ``volume`` are each min-max normalized per stock, the
        Euclidean distance between the two normalized series is taken, and the
        two distances are combined with the given weights. Smaller is closer.

        Args:
            target_stock_data: Bars of the reference stock (``close``, ``volume``).
            stock_data: Bars of the candidate stock (``close``, ``volume``).
            close_weight: Weight of the close-price distance.
            volume_weight: Weight of the volume distance.

        Returns:
            The distance as a float; NaN when a series is flat (max == min),
            contains missing values, or the two inputs differ in length.
        """
        distances = []
        for column in ("close", "volume"):
            # Plain float arrays: no reliance on index alignment, and object
            # columns (e.g. Decimal values) are converted up front.
            target_values = self._min_max_normalize(target_stock_data[column])
            other_values = self._min_max_normalize(stock_data[column])
            if target_values.shape != other_values.shape:
                return float("nan")
            distances.append(np.linalg.norm(target_values - other_values))

        close_distance, volume_distance = distances
        similarity_distance = (
            close_weight * close_distance + volume_weight * volume_distance
        )
        return float(similarity_distance)