diff --git a/core/db/db_astock.py b/core/db/db_astock.py
index 26eed5a..797b5fc 100644
--- a/core/db/db_astock.py
+++ b/core/db/db_astock.py
@@ -54,7 +54,6 @@ class DBAStockData:
     def query_market_data_by_symbol_bar(
         self,
         symbol: str,
-        bar: str,
         fields: list = None,
         start: str = None,
         end: str = None,
@@ -123,3 +122,23 @@ class DBAStockData:
         """
         condition_dict = {"symbol": symbol, "end": end}
         return self.query_data(sql, condition_dict, return_multi=True)
+
+    def query_index_data(self):
+        """Return every row of the all_index table, ordered by ts_code, as a DataFrame."""
+        sql = """
+        SELECT * FROM all_index a
+        order by ts_code
+        """
+        condition_dict = {}
+        data = self.query_data(sql, condition_dict, return_multi=True)
+        return pd.DataFrame(data)
+
+    def query_stock_data(self):
+        """Return every row of the all_stock table, ordered by ts_code, as a DataFrame."""
+        sql = """
+        SELECT * FROM all_stock a
+        order by ts_code
+        """
+        condition_dict = {}
+        data = self.query_data(sql, condition_dict, return_multi=True)
+        return pd.DataFrame(data)
diff --git a/core/statistics/similar_pattern_stocks.py b/core/statistics/similar_pattern_stocks.py
new file mode 100644
index 0000000..098025f
--- /dev/null
+++ b/core/statistics/similar_pattern_stocks.py
@@ -0,0 +1,454 @@
+import core.logger as logging
+import os
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+import mplfinance as mpf
+from datetime import datetime, timedelta, timezone
+from core.utils import get_current_date_time
+import re
+import json
+import math
+from urllib.parse import quote_plus
+from openpyxl import Workbook
+from openpyxl.drawing.image import Image
+import openpyxl
+from openpyxl.styles import Font
+from PIL import Image as PILImage
+from config import (
+    A_MYSQL_CONFIG,
+)
+from core.db.db_astock import DBAStockData
+
+# Apply the seaborn theme first: set_theme() resets "font.family", so the
+# Chinese-capable font must be configured afterwards or it is overwritten.
+sns.set_theme(style="whitegrid")
+plt.rcParams["font.family"] = ["SimHei"]
+plt.rcParams["axes.unicode_minus"] = False
+
+logger = logging.logger
+
+# Characters that are invalid in Windows file names. Stock names such as
+# "*ST ..." contain "*", so names are cleaned before being used in paths.
+_INVALID_FILENAME_CHARS = re.compile(r'[\\/:*?"<>|]')
+
+
+def _safe_filename(name) -> str:
+    """Return ``name`` as a string without file-name-invalid characters."""
+    return _INVALID_FILENAME_CHARS.sub("", str(name))
+
+
+class SimilarPatternStocks:
+    """Rank stocks by how closely their recent bars resemble a target stock."""
+
+    def __init__(self):
+        """
+        Build the MySQL connection URL from A_MYSQL_CONFIG, create the DB
+        accessor and make sure the output directory exists.
+
+        Raises:
+            ValueError: if no MySQL password is configured.
+        """
+        mysql_user = A_MYSQL_CONFIG.get("user", "root")
+        mysql_password = A_MYSQL_CONFIG.get("password", "")
+        if not mysql_password:
+            raise ValueError("MySQL password is not set")
+        mysql_host = A_MYSQL_CONFIG.get("host", "localhost")
+        mysql_port = A_MYSQL_CONFIG.get("port", 3306)
+        mysql_database = A_MYSQL_CONFIG.get("database", "astock")
+        # Percent-encode the credentials so characters such as "@", ":" or
+        # "/" in a password cannot corrupt the connection URL.
+        quoted_user = quote_plus(str(mysql_user))
+        quoted_password = quote_plus(str(mysql_password))
+        self.db_url = (
+            f"mysql+pymysql://{quoted_user}:{quoted_password}"
+            f"@{mysql_host}:{mysql_port}/{mysql_database}"
+        )
+        self.db_astock = DBAStockData(self.db_url)
+        self.output_dir = r"./output/similar_pattern_stocks/"
+        os.makedirs(self.output_dir, exist_ok=True)
+
+    def get_stock_list(self):
+        """
+        Return the stocks listed on the main board, ChiNext or STAR market.
+
+        An empty DataFrame is returned unchanged when the stock table is empty.
+        """
+        stock_data = self.db_astock.query_stock_data()
+        # An empty frame has no "market" column; filtering would raise KeyError
+        if stock_data.empty:
+            return stock_data
+        # Keep only the markets 主板, 创业板 and 科创板
+        stock_data = stock_data[stock_data["market"].isin(["主板", "创业板", "科创板"])]
+        return stock_data
+
+    def get_stock_market_data(
+        self,
+        symbol: str,
+        bar: str,
+        start_date: str,
+        end_date: str,
+        compare_record_amount: int = 100,
+    ):
+        """
+        Load the most recent bars of one stock.
+
+        Args:
+            symbol: stock code passed to the query (e.g. "600111.SH").
+            bar: "1W" reads the weekly table, "1M" the monthly table, any
+                other value the daily table.
+            start_date: passed through to the query as ``start``; may be None.
+            end_date: passed through to the query as ``end``; may be None.
+            compare_record_amount: number of latest rows to keep.
+
+        Returns:
+            A DataFrame sorted by date_time with at most
+            compare_record_amount rows; empty when the query finds nothing.
+        """
+        fields = [
+            "a.ts_code as symbol",
+            "b.name as symbol_name",
+            f"'{bar}' as bar",
+            "trade_date as date_time",
+            "open",
+            "high",
+            "low",
+            "close",
+            "vol as volume",
+            "MA5 as ma5",
+            "MA10 as ma10",
+            "MA20 as ma20",
+            "MA30 as ma30",
+            "均线交叉 as ma_cross",
+            "DIF as dif",
+            "DEA as dea",
+            "MACD as macd",
+        ]
+        if bar == "1W":
+            table_name = "stock_weekly_price_from_2020"
+        elif bar == "1M":
+            table_name = "stock_monthly_price_from_2015"
+        else:
+            table_name = "stock_daily_price_from_2021"
+        data = self.db_astock.query_market_data_by_symbol_bar(
+            symbol=symbol,
+            fields=fields,
+            start=start_date,
+            end=end_date,
+            table_name=table_name,
+        )
+        # Callers test for an empty result; without this guard sorting a frame
+        # that has no "date_time" column would raise KeyError instead.
+        if data is None or len(data) == 0:
+            return pd.DataFrame()
+        data = pd.DataFrame(data)
+        data.sort_values(by=["date_time"], inplace=True)
+        # Keep only the latest compare_record_amount rows
+        data = data.tail(compare_record_amount)
+        data.reset_index(drop=True, inplace=True)
+        return data
+
+    def get_stock_market_data_similar_pattern(
+        self,
+        target_symbol: str,
+        bar: str,
+        start_date: str = None,
+        end_date: str = None,
+        compare_record_amount: int = 100,
+        top_n: int = 10,
+    ):
+        """
+        Find the stocks whose latest bars look most like the target's bars.
+
+        1. Load the target's bars up to end_date and keep the last
+           compare_record_amount of them.
+        2. Load the same number of latest bars (up to today) for every other
+           stock returned by get_stock_list.
+        3. Score each candidate with calculate_similarity (smaller = closer).
+        4. Write the ranking to an Excel file, draw candlestick charts for the
+           target and the top_n closest stocks, and embed them in the file.
+
+        Args:
+            target_symbol: code of the reference stock.
+            bar: bar size, see get_stock_market_data.
+            start_date: start argument for the target's query, or None.
+            end_date: end argument for the target's query, or None.
+            compare_record_amount: maximum number of bars to compare.
+            top_n: how many of the closest stocks get a chart.
+
+        Returns:
+            The list of chart-info dicts (target first), or [] when the
+            target or the stock list has no data.
+        """
+        logger.info(f"获取目标股票{target_symbol}的{bar}数据, 截止日期{end_date}")
+        target_stock_data = self.get_stock_market_data(
+            target_symbol, bar, start_date, end_date, compare_record_amount
+        )
+        if len(target_stock_data) == 0:
+            logger.error(f"目标股票{target_symbol}的{bar}数据为空")
+            return []
+        # The target may have fewer bars than requested; compare on what exists
+        compare_record_amount = len(target_stock_data)
+        stock_list_data = self.get_stock_list()
+        if len(stock_list_data) == 0:
+            logger.error("所有股票数据为空")
+            return []
+        stock_list_data = stock_list_data[stock_list_data["ts_code"] != target_symbol]
+        now_date = datetime.now().strftime("%Y%m%d")
+        similarity_data = []
+        # Number the stocks by position: the frame's index labels have gaps
+        # after the filtering above and would give misleading log counters.
+        for position, (_, row) in enumerate(stock_list_data.iterrows(), start=1):
+            logger.info(f"获取第{position}只股票{row['ts_code']} {row['name']} 的{bar}数据, 截止日期{now_date}")
+            stock_data = self.get_stock_market_data(
+                symbol=row["ts_code"],
+                bar=bar,
+                start_date=None,
+                end_date=now_date,
+                compare_record_amount=compare_record_amount,
+            )
+            if len(stock_data) == 0:
+                logger.error(f"股票{row['ts_code']}的{bar}数据为空")
+                continue
+            if len(stock_data) != compare_record_amount:
+                logger.error(f"股票{row['ts_code']}的数据不足{compare_record_amount}条")
+                continue
+            similarity = self.calculate_similarity(target_stock_data, stock_data)
+            # A NaN distance (e.g. a completely flat series) cannot be ranked
+            if math.isnan(similarity):
+                continue
+            similarity_data.append(
+                {
+                    "symbol": row["ts_code"],
+                    "symbol_name": row["name"],
+                    "similarity_distance": similarity,
+                    "stock_data": stock_data,
+                }
+            )
+        # Ascending distance: the most similar stock comes first
+        similarity_data.sort(key=lambda item: item["similarity_distance"])
+
+        # Explicit columns keep the sheet well-formed even with zero matches
+        pure_data = pd.DataFrame(
+            [
+                {
+                    "symbol": item["symbol"],
+                    "symbol_name": item["symbol_name"],
+                    "similarity_distance": item["similarity_distance"],
+                }
+                for item in similarity_data
+            ],
+            columns=["symbol", "symbol_name", "similarity_distance"],
+        )
+        target_stock_symbol = str(target_stock_data["symbol"].iloc[0]).split(".")[0]
+        target_stock_name = str(target_stock_data["symbol_name"].iloc[0])
+        # Names such as "*ST ..." are not valid in file or folder names
+        safe_target_name = _safe_filename(target_stock_name)
+        folder_date = end_date if end_date else now_date
+        target_stock_folder = os.path.join(
+            self.output_dir, f"{target_stock_symbol}_{safe_target_name}_{bar}_{folder_date}"
+        )
+        os.makedirs(target_stock_folder, exist_ok=True)
+        excel_file_path = os.path.join(
+            target_stock_folder,
+            f"{target_stock_symbol}_{safe_target_name}_{bar}_similar_stocks.xlsx",
+        )
+        with pd.ExcelWriter(excel_file_path) as writer:
+            pure_data.to_excel(writer, sheet_name="股票形态相似度", index=False)
+
+        similar_stocks_chart_folder = os.path.join(
+            target_stock_folder, "similar_stocks_chart"
+        )
+        os.makedirs(similar_stocks_chart_folder, exist_ok=True)
+        # The target itself is drawn first, with distance "0"
+        chart_list = [
+            self.draw_similar_stocks_chart(
+                target_stock_symbol,
+                target_stock_name,
+                bar,
+                "0",
+                target_stock_data,
+                similar_stocks_chart_folder,
+            )
+        ]
+        for item in similarity_data[: max(top_n, 0)]:
+            chart_list.append(
+                self.draw_similar_stocks_chart(
+                    item["symbol"],
+                    item["symbol_name"],
+                    bar,
+                    item["similarity_distance"],
+                    item["stock_data"],
+                    similar_stocks_chart_folder,
+                )
+            )
+        self.output_chart_to_excel(excel_file_path, chart_list)
+        return chart_list
+
+    def output_chart_to_excel(self, excel_file_path: str, charts_list: list):
+        """
+        Add a sheet holding every chart image to an existing Excel file.
+
+        Args:
+            excel_file_path: path of the workbook to open and save again.
+            charts_list: chart-info dicts with the keys "chart_path",
+                "symbol", "stock_name", "bar" and "similarity_distance".
+        """
+        logger.info(f"将图表输出到{excel_file_path}")
+
+        # Open the workbook that already exists and add a sheet for the charts
+        wb = openpyxl.load_workbook(excel_file_path)
+        ws = wb.create_sheet(title="股票K线图")
+        # Rough pixel height of one row: 15 pt default height * 1.333 px/pt
+        pixels_per_row = 15 * 1.333
+        row_offset = 1
+        for chart_info in charts_list:
+            chart_path = chart_info["chart_path"]
+            chart_name = (
+                f"{chart_info['symbol']} {chart_info['stock_name']} "
+                f"{chart_info['bar']} 距离:{chart_info['similarity_distance']}"
+            )
+            # Only the image height is needed to reserve enough rows
+            with PILImage.open(chart_path) as img:
+                _, height_px = img.size
+            # Reserve at least 10 rows, even for small charts
+            chart_rows = max(10, int(height_px / pixels_per_row))
+
+            # Title row; str values are written as-is, Chinese text included
+            ws[f"A{row_offset}"] = chart_name
+            ws[f"A{row_offset}"].font = Font(bold=True, size=12)
+            row_offset += 2  # title row plus one spacing row
+
+            ws.add_image(Image(chart_path), f"A{row_offset}")
+
+            # Chart height plus 5 rows of padding before the next chart
+            row_offset += chart_rows + 5
+        wb.save(excel_file_path)
+        logger.info(f"Chart saved as {excel_file_path}")
+
+    def draw_similar_stocks_chart(
+        self,
+        symbol: str,
+        stock_name: str,
+        bar: str,
+        similarity_distance: float,
+        stock_data: pd.DataFrame,
+        similar_stocks_chart_folder: str,
+    ):
+        """
+        Draw a candlestick chart for one stock and save it as a PNG file.
+
+        Args:
+            symbol: stock code used in the file name and chart title.
+            stock_name: stock name used in the file name and chart title.
+            bar: bar size label used in the file name and chart title.
+            similarity_distance: distance shown in the chart title.
+            stock_data: bars with a trade_date or date_time column plus
+                open/high/low/close columns.
+            similar_stocks_chart_folder: folder the PNG is written to.
+
+        Returns:
+            A chart-info dict: chart_path, symbol, stock_name, bar and
+            similarity_distance.
+        """
+        chart_path = os.path.join(
+            similar_stocks_chart_folder,
+            f"{symbol}_{_safe_filename(stock_name)}_{bar}_chart.png",
+        )
+
+        # Work on a copy so the caller's frame is never modified
+        df = stock_data.copy()
+
+        # The date column becomes a datetime index named "Date" for plotting
+        if "trade_date" in df.columns:
+            df = df.rename(columns={"trade_date": "Date"})
+        elif "date_time" in df.columns:
+            df = df.rename(columns={"date_time": "Date"})
+        df["Date"] = pd.to_datetime(df["Date"])
+        df = df.set_index("Date")
+
+        # Capitalise the OHLC column names for plotting
+        df = df.rename(
+            columns={"open": "Open", "high": "High", "low": "Low", "close": "Close"}
+        )
+
+        # Red for rising and green for falling bars (A-share convention)
+        mc = mpf.make_marketcolors(
+            up="#e41a1c",
+            down="#4daf4a",
+            edge="inherit",
+            wick="inherit",
+            volume="inherit",
+        )
+        mpf_style = mpf.make_mpf_style(
+            base_mpf_style="yahoo",
+            marketcolors=mc,
+            facecolor="#FAFAFA",
+            edgecolor="#EAEAEA",
+            gridcolor="#E6E6E6",
+            gridstyle="--",
+            rc={"font.family": "SimHei", "axes.unicode_minus": False},
+        )
+        mpf.plot(
+            df,
+            type="candle",
+            style=mpf_style,
+            title=f"{symbol} {stock_name} {bar} 距离:{similarity_distance}",
+            ylabel="价格",
+            figsize=(12, 6),
+            savefig={"fname": chart_path, "dpi": 150, "bbox_inches": "tight"},
+        )
+        chart_info = {
+            "chart_path": chart_path,
+            "symbol": symbol,
+            "stock_name": stock_name,
+            "bar": bar,
+            "similarity_distance": similarity_distance,
+        }
+        return chart_info
+
+    @staticmethod
+    def _min_max_normalize(series) -> np.ndarray:
+        """Scale values to [0, 1]; all-NaN when the values have no spread."""
+        values = np.asarray(series, dtype=float)
+        if values.size == 0:
+            return values
+        low = values.min()
+        span = values.max() - low
+        if not np.isfinite(span) or span == 0:
+            return np.full(values.shape, np.nan)
+        return (values - low) / span
+
+    def calculate_similarity(
+        self,
+        target_stock_data: pd.DataFrame,
+        stock_data: pd.DataFrame,
+        close_weight=0.8,
+        volume_weight=0.2,
+    ):
+        """
+        Score how different two stocks' bars are (smaller = more similar).
+
+        The close and volume columns of both frames are min-max normalised and
+        compared row by row with the Euclidean distance; the two distances are
+        combined as close_weight * close + volume_weight * volume.
+
+        Returns:
+            The weighted distance as a float, or NaN when the frames differ in
+            length or a compared column has no spread (max equals min).
+        """
+        if len(target_stock_data) != len(stock_data):
+            return float("nan")
+        normalize = SimilarPatternStocks._min_max_normalize
+        # Plain arrays compare by position, whatever index the frames carry
+        close_distance = np.linalg.norm(
+            normalize(target_stock_data["close"]) - normalize(stock_data["close"])
+        )
+        volume_distance = np.linalg.norm(
+            normalize(target_stock_data["volume"]) - normalize(stock_data["volume"])
+        )
+        similarity_distance = (
+            close_weight * close_distance + volume_weight * volume_distance
+        )
+        return float(similarity_distance)
diff --git a/core/trade/ma_break_statistics.py b/core/trade/ma_break_statistics.py
index a374334..ca14451 100644
--- a/core/trade/ma_break_statistics.py
+++ b/core/trade/ma_break_statistics.py
@@ -864,7 +864,6 @@ class MaBreakStatistics:
             logger.info(f"获取{symbol}数据:{start_date_str}至{current_end_date_str}")
             current_data = self.db_market_data.query_market_data_by_symbol_bar(
                 symbol,
-                bar,
                 fields,
                 start=start_date_str,
                 end=current_end_date_str,
@@ -944,7 +943,7 @@ class MaBreakStatistics:
             "MACD as macd",
         ]
         data = self.db_market_data.query_market_data_by_symbol_bar(
-            symbol, bar, fields, start=last_date, end=end_date, table_name=table_name
+            symbol, fields, start=last_date, end=end_date, table_name=table_name
         )
         if data is not None and len(data) > 0:
             data = pd.DataFrame(data)
diff --git a/similar_pattern_stocks_main.py b/similar_pattern_stocks_main.py
new file mode 100644
index 0000000..1dae3b9
--- /dev/null
+++ b/similar_pattern_stocks_main.py
@@ -0,0 +1,47 @@
+import core.logger as logging
+from core.statistics.similar_pattern_stocks import SimilarPatternStocks
+import os
+
+logger = logging.logger
+
+
+def main():
+    """Run the similar-pattern search for each configured target stock."""
+    similar_pattern_stocks = SimilarPatternStocks()
+
+    target_stock_list = [
+        {
+            "symbol": "600111.SH",
+            "bar": "1W",
+            "start_date": None,
+            "end_date": "20250711",
+            "compare_record_amount": 100,
+        },
+        {
+            "symbol": "600111.SH",
+            "bar": "1M",
+            "start_date": None,
+            "end_date": "20250630",
+            "compare_record_amount": 100,
+        },
+        {
+            "symbol": "601398.SH",
+            "bar": "1M",
+            "start_date": None,
+            "end_date": "20230430",
+            "compare_record_amount": 100,
+        },
+    ]
+
+    for target_stock in target_stock_list:
+        similar_pattern_stocks.get_stock_market_data_similar_pattern(
+            target_symbol=target_stock["symbol"],
+            bar=target_stock["bar"],
+            start_date=target_stock["start_date"],
+            end_date=target_stock["end_date"],
+            compare_record_amount=target_stock["compare_record_amount"],
+        )
+
+
+if __name__ == "__main__":
+    main()