crypto_quant/core/trade/ma_break_statistics.py

1537 lines
68 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import core.logger as logging
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta, timezone
from core.utils import get_current_date_time
import re
import json
import math
from openpyxl import Workbook
from openpyxl.drawing.image import Image
import openpyxl
from openpyxl.styles import Font
from PIL import Image as PILImage
from config import (
OKX_MONITOR_CONFIG,
US_STOCK_MONITOR_CONFIG,
COIN_MYSQL_CONFIG,
A_MYSQL_CONFIG,
WINDOW_SIZE,
BINANCE_MONITOR_CONFIG,
A_STOCK_MONITOR_CONFIG,
A_INDEX_MONITOR_CONFIG,
)
from core.biz.metrics_calculation import MetricsCalculation
from core.db.db_market_data import DBMarketData
from core.db.db_astock import DBAStockData
from core.db.db_binance_data import DBBinanceData
from core.db.db_huge_volume_data import DBHugeVolumeData
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
# Use the SimHei font so that matplotlib/seaborn figures can render Chinese text.
plt.rcParams["font.family"] = ["SimHei"]
# Module-level logger used by the methods below.
logger = logging.logger
class MaBreakStatistics:
    """
    Price-change statistics after moving-average (MA) breakouts.

    An upward breakout is a bar whose moving averages are ordered
    MA5 > MA10 > MA20 > MA30; a downward breakout is a bar ordered
    MA30 > MA20 > MA10 > MA5. The statistics cover the price change
    between an upward breakout and the next downward breakout.
    """
def __init__(
    self,
    is_us_stock: bool = False,
    is_astock: bool = False,
    is_aindex: bool = False,
    is_binance: bool = True,
    buy_by_long_period: dict = {
        "by_week": False,
        "by_month": False,
        "buy_by_10_percentile": False,
    },
    long_period_condition: dict = {
        "ma5>ma10": True,
        "ma10>ma20": False,
        "macd_diff>0": True,
        "macd>0": True,
    },
    cut_loss_by_valleys_median: bool = False,
    commission_per_share: float = 0.0008,
):
    """
    Configure the database connection, the symbol universe and the strategy options.

    Market precedence is: US stock, A-share stock, A-share index, Binance, OKX.

    :param is_us_stock: use the US stock monitor config and the ``date_time_us`` column
    :param is_astock: use the A-share stock monitor config and the A-share MySQL database
    :param is_aindex: use the A-share index monitor config and the A-share MySQL database
    :param is_binance: use the Binance monitor config (bars are fixed to 30m and 1H);
        when False and no other market flag is set, the OKX monitor config is used
    :param buy_by_long_period: flags ``by_week`` / ``by_month`` / ``buy_by_10_percentile``
    :param long_period_condition: flags naming the long-period conditions to apply
    :param cut_loss_by_valleys_median: enable the valleys-median stop loss
    :param commission_per_share: commission rate applied to each buy and sell
    :raises ValueError: if the selected MySQL config has no password
    """
    # A-share data lives in its own MySQL database; every other market uses the coin database.
    if is_astock or is_aindex:
        mysql_config = A_MYSQL_CONFIG
        default_user, default_database = "root", "astock"
    else:
        mysql_config = COIN_MYSQL_CONFIG
        default_user, default_database = "xch", "okx"
    mysql_user = mysql_config.get("user", default_user)
    mysql_password = mysql_config.get("password", "")
    if not mysql_password:
        raise ValueError("MySQL password is not set")
    mysql_host = mysql_config.get("host", "localhost")
    mysql_port = mysql_config.get("port", 3306)
    mysql_database = mysql_config.get("database", default_database)

    self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
    self.db_huge_volume_data = DBHugeVolumeData(self.db_url)
    self.is_us_stock = is_us_stock
    self.is_astock = is_astock
    self.is_aindex = is_aindex
    self.date_time_field = "date_time_us" if self.is_us_stock else "date_time"
    self.is_binance = is_binance

    # Pick the monitor config, its fallbacks and the market-data accessor for the market.
    fixed_bars = None
    default_bars = ["5m"]
    if is_us_stock:
        monitor_config = US_STOCK_MONITOR_CONFIG
        default_symbols = ["QQQ"]
        default_initial_date = "2014-11-30 00:00:00"
        market_db_class = DBMarketData
    elif is_astock:
        monitor_config = A_STOCK_MONITOR_CONFIG
        default_symbols = ["000001.SH"]
        default_initial_date = "2014-11-30 00:00:00"
        market_db_class = DBAStockData
    elif is_aindex:
        monitor_config = A_INDEX_MONITOR_CONFIG
        default_symbols = ["000001.SH"]
        default_initial_date = "2014-11-30 00:00:00"
        market_db_class = DBAStockData
    elif is_binance:
        monitor_config = BINANCE_MONITOR_CONFIG
        default_symbols = ["BTC-USDT"]
        default_initial_date = "2017-08-16 00:00:00"
        market_db_class = DBBinanceData
        # Binance ignores the configured bars and always simulates these two.
        fixed_bars = ["30m", "1H"]
    else:
        monitor_config = OKX_MONITOR_CONFIG
        default_symbols = ["XCH-USDT"]
        default_bars = ["5m", "15m", "30m", "1H"]
        default_initial_date = "2025-05-15 00:00:00"
        market_db_class = DBMarketData

    volume_monitor = monitor_config.get("volume_monitor", {})
    self.symbols = volume_monitor.get("symbols", default_symbols)
    if fixed_bars is not None:
        self.bars = fixed_bars
    else:
        self.bars = volume_monitor.get("bars", default_bars)
    self.initial_date = volume_monitor.get("initial_date", default_initial_date)
    self.db_market_data = market_db_class(self.db_url)

    # Only the date part (YYYY-MM-DD) of the configured start is kept.
    if len(self.initial_date) > 10:
        self.initial_date = self.initial_date[:10]
    self.end_date = get_current_date_time(format="%Y-%m-%d")
    self.commission_per_share = commission_per_share
    self.trade_strategy_config = self.get_trade_strategy_config()
    self.main_strategy = self.trade_strategy_config.get("均线系统策略", None)
    # Store copies: the signature defaults are single dict objects shared by every
    # call, so keeping a reference would let one instance's mutation leak into
    # all later instances (and into the caller's own dict).
    self.buy_by_long_period = (
        dict(buy_by_long_period) if buy_by_long_period is not None else {}
    )
    self.long_period_condition = (
        dict(long_period_condition) if long_period_condition is not None else {}
    )
    self.cut_loss_by_valleys_median = cut_loss_by_valleys_median
    self.metrics_calculation = MetricsCalculation()
def get_trade_strategy_config(self):
    """
    Load the trade-strategy definitions from ``./json/trade_strategy.json``.

    The path is relative to the current working directory.

    :return: the parsed JSON document
    """
    config_path = "./json/trade_strategy.json"
    with open(config_path, "r", encoding="utf-8") as config_file:
        return json.load(config_file)
def get_by_long_period_desc(self):
    """
    Build the folder-name fragment describing the long-period buy filter.

    The fragment is the period prefix (``1W`` and/or ``1M``), then
    ``10percentile`` when that option is on, then one token per enabled
    long-period condition, all joined with single underscores. Joining the
    parts (rather than hard-coding a leading underscore on some tokens)
    avoids doubled, leading or trailing underscores when only some of the
    options are enabled.

    :return: the fragment, or ``"no_long_period_judge"`` when no long-period
        option is enabled
    """
    period_prefix = ""
    if self.buy_by_long_period.get("by_week", False):
        period_prefix += "1W"
    if self.buy_by_long_period.get("by_month", False):
        period_prefix += "1M"

    parts = [period_prefix] if period_prefix else []
    if self.buy_by_long_period.get("buy_by_10_percentile", False):
        parts.append("10percentile")
    if not parts:
        return "no_long_period_judge"

    condition_tokens = (
        ("ma5>ma10", "ma5gtma10"),
        ("ma10>ma20", "ma10gtma20"),
        ("macd_diff>0", "macd_diffgt0"),
        ("macd>0", "macdgt0"),
    )
    parts.extend(
        token
        for key, token in condition_tokens
        if self.long_period_condition.get(key, False)
    )
    return "_".join(parts)
def batch_statistics(self, strategy_name: str = "全均线策略"):
    """
    Simulate the strategy for every configured symbol/bar and export the results.

    For each symbol and bar ``trade_simulate`` is run; the resulting round
    trips are merged, summarised per symbol/bar and written to one Excel
    workbook (strategy info, trade details, account-value change,
    account-value statistics, holding-interval statistics). Charts are then
    added to the workbook through the chart helpers of this class.

    Side effects: sets ``self.stats_output_dir`` / ``self.stats_chart_dir``
    and creates both directories.

    :param strategy_name: strategy key inside ``self.main_strategy``; an
        unknown or ``None`` name falls back to "全均线策略"
    :return: DataFrame with one account-value summary row per symbol/bar,
        or ``None`` when no symbol/bar produced any trade
    """
    # Resolve the strategy first so that the folders, the file name and the
    # report all carry the name of the strategy that is actually simulated.
    if strategy_name is None or strategy_name not in self.main_strategy:
        strategy_name = "全均线策略"

    # Same market precedence as __init__ (US stock, A-share stock, A-share
    # index, Binance, OKX). is_binance defaults to True, so it must be checked
    # after the A-share flags or A-share results land in the binance folder.
    if self.is_us_stock:
        market_folder = "us_stock"
    elif self.is_astock:
        market_folder = "astock"
    elif self.is_aindex:
        market_folder = "aindex"
    elif self.is_binance:
        market_folder = "binance"
    else:
        market_folder = "okx"
    main_folder = f"./output/trade_sandbox/ma_strategy/{market_folder}/"
    if self.cut_loss_by_valleys_median:
        main_folder += "cut_loss_by_valleys_median/"
    else:
        main_folder += "no_cut_loss_by_valleys_median/"
    # Only the A-share markets get a sub-folder describing the long-period filter.
    if market_folder in ("astock", "aindex"):
        long_period_desc = self.get_by_long_period_desc()
        if long_period_desc:
            main_folder += f"{long_period_desc}/"
    self.stats_output_dir = f"{main_folder}excel/{strategy_name}/"
    self.stats_chart_dir = f"{main_folder}chart/{strategy_name}/"
    os.makedirs(self.stats_output_dir, exist_ok=True)
    os.makedirs(self.stats_chart_dir, exist_ok=True)

    ma_break_market_data_list = []
    market_data_pct_chg_list = []
    for symbol in self.symbols:
        for bar in self.bars:
            logger.info(
                f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计, 策略: {strategy_name},交易费率:{self.commission_per_share}"
            )
            ma_break_market_data, market_data_pct_chg = self.trade_simulate(
                symbol, bar, strategy_name
            )
            if (
                ma_break_market_data is not None
                and len(ma_break_market_data) > 0
                and market_data_pct_chg is not None
            ):
                ma_break_market_data_list.append(ma_break_market_data)
                logger.info(
                    f"{symbol} {bar} 的市场价格变化, {market_data_pct_chg.get('pct_chg', 0)}%"
                )
                market_data_pct_chg_list.append(market_data_pct_chg)

    if len(ma_break_market_data_list) == 0:
        return None

    ma_break_market_data = pd.concat(ma_break_market_data_list)
    market_data_pct_chg_df = pd.DataFrame(market_data_pct_chg_list)
    ma_break_market_data.sort_values(
        by="begin_timestamp", ascending=True, inplace=True
    )
    ma_break_market_data.reset_index(drop=True, inplace=True)

    # One summary row per symbol/bar: commissions, start/end capital and the
    # strategy return next to the market return over the same trades.
    account_value_chg_list = []
    for symbol in market_data_pct_chg_df["symbol"].unique():
        for bar in market_data_pct_chg_df["bar"].unique():
            symbol_bar_data = ma_break_market_data[
                (ma_break_market_data["symbol"] == symbol)
                & (ma_break_market_data["bar"] == bar)
            ].copy()  # work on a copy so the in-place sort cannot touch the merged frame
            if len(symbol_bar_data) == 0:
                continue
            symbol_bar_data.sort_values(
                by="end_timestamp", ascending=True, inplace=True
            )
            symbol_bar_data.reset_index(drop=True, inplace=True)
            pct_chg_mask = (market_data_pct_chg_df["symbol"] == symbol) & (
                market_data_pct_chg_df["bar"] == bar
            )
            initial_capital = int(
                market_data_pct_chg_df.loc[pct_chg_mask, "initial_capital"].values[0]
            )
            # The account value after the last closed trade is the final value.
            final_account_value = float(
                symbol_bar_data["end_account_value"].iloc[-1]
            )
            account_value_chg = (
                (final_account_value - initial_capital) / initial_capital * 100
            )
            account_value_chg = round(account_value_chg, 4)
            market_pct_chg = market_data_pct_chg_df.loc[
                pct_chg_mask, "pct_chg"
            ].values[0]
            total_buy_commission = float(symbol_bar_data["buy_commission"].sum())
            total_sell_commission = float(symbol_bar_data["sell_commission"].sum())
            total_commission = total_buy_commission + total_sell_commission
            total_commission = round(total_commission, 4)
            total_buy_commission = round(total_buy_commission, 4)
            total_sell_commission = round(total_sell_commission, 4)
            symbol_name = str(symbol_bar_data["symbol_name"].iloc[0])
            account_value_chg_list.append(
                {
                    "strategy_name": strategy_name,
                    "symbol": symbol,
                    "symbol_name": symbol_name,
                    "bar": bar,
                    "total_buy_commission": total_buy_commission,
                    "total_sell_commission": total_sell_commission,
                    "total_commission": total_commission,
                    "initial_account_value": initial_capital,
                    "final_account_value": final_account_value,
                    "account_value_chg": account_value_chg,
                    "market_pct_chg": market_pct_chg,
                }
            )
    account_value_chg_df = pd.DataFrame(account_value_chg_list)
    account_value_chg_df = account_value_chg_df[
        [
            "strategy_name",
            "symbol",
            "symbol_name",
            "bar",
            "total_buy_commission",
            "total_sell_commission",
            "total_commission",
            "initial_account_value",
            "final_account_value",
            "account_value_chg",
            "market_pct_chg",
        ]
    ]

    # Per symbol/bar statistics of the account value after each closed trade.
    account_value_statistics_df = (
        ma_break_market_data.groupby(["symbol", "symbol_name", "bar"])[
            "end_account_value"
        ]
        .agg(
            account_value_max="max",
            account_value_min="min",
            account_value_mean="mean",
            account_value_std="std",
            account_value_median="median",
            account_value_count="count",
        )
        .reset_index()
    )
    account_value_statistics_df["strategy_name"] = strategy_name
    account_value_statistics_df = account_value_statistics_df[
        [
            "strategy_name",
            "symbol",
            "symbol_name",
            "bar",
            "account_value_max",
            "account_value_min",
            "account_value_mean",
            "account_value_std",
            "account_value_median",
            "account_value_count",
        ]
    ]

    # Per symbol/bar statistics of the holding time (in minutes) of each trade.
    interval_minutes_df = (
        ma_break_market_data.groupby(["symbol", "symbol_name", "bar"])[
            "interval_minutes"
        ]
        .agg(
            interval_minutes_max="max",
            interval_minutes_min="min",
            interval_minutes_mean="mean",
            interval_minutes_std="std",
            interval_minutes_median="median",
            interval_minutes_count="count",
        )
        .reset_index()
    )
    interval_minutes_df["strategy_name"] = strategy_name
    interval_minutes_df = interval_minutes_df[
        [
            "strategy_name",
            "symbol",
            "symbol_name",
            "bar",
            "interval_minutes_max",
            "interval_minutes_min",
            "interval_minutes_mean",
            "interval_minutes_std",
            "interval_minutes_median",
            "interval_minutes_count",
        ]
    ]

    # The file name carries the covered time range with separators stripped.
    earliest_market_date_time = ma_break_market_data["begin_date_time"].min()
    earliest_market_date_time = re.sub(
        r"[\:\-\s]", "", str(earliest_market_date_time)
    )
    latest_market_date_time = ma_break_market_data["end_date_time"].max()
    if latest_market_date_time is None:
        latest_market_date_time = get_current_date_time(format="%Y%m%d%H%M%S")
    latest_market_date_time = re.sub(r"[\:\-\s]", "", str(latest_market_date_time))
    if self.commission_per_share > 0:
        commission_suffix = "with_commission"
    else:
        commission_suffix = "without_commission"
    output_file_name = f"ma_break_stats_from_{earliest_market_date_time}_to_{latest_market_date_time}_{strategy_name}_{commission_suffix}.xlsx"
    output_file_path = os.path.join(self.stats_output_dir, output_file_name)
    logger.info(f"导出{output_file_path}")

    strategy_info_df = self.get_strategy_info(strategy_name)
    with pd.ExcelWriter(output_file_path) as writer:
        strategy_info_df.to_excel(writer, sheet_name="策略信息", index=False)
        ma_break_market_data.to_excel(
            writer, sheet_name="买卖记录明细", index=False
        )
        account_value_chg_df.to_excel(
            writer, sheet_name="资产价值变化", index=False
        )
        account_value_statistics_df.to_excel(
            writer, sheet_name="买卖账户价值统计", index=False
        )
        interval_minutes_df.to_excel(
            writer, sheet_name="买卖时间间隔统计", index=False
        )

    # Charts are added after the writer is closed, i.e. once the workbook exists on disk.
    chart_dict = self.draw_quant_pct_chg_bar_chart(
        account_value_chg_df, strategy_name
    )
    self.output_chart_to_excel(output_file_path, chart_dict)
    chart_dict = self.draw_quant_line_chart(
        ma_break_market_data, market_data_pct_chg_df, strategy_name
    )
    self.output_chart_to_excel(output_file_path, chart_dict)
    return account_value_chg_df
def get_strategy_info(self, strategy_name: str = "全均线策略"):
    """
    Describe a strategy's buy and sell rules as a one-row DataFrame.

    Each rule list is rendered one condition per line; when a side has "or"
    conditions they follow the "and" conditions after a separator line. The
    buy text additionally lists the weekly/monthly switches and the enabled
    long-period conditions of this instance.

    :param strategy_name: strategy key inside ``self.main_strategy``
    :return: DataFrame with the columns 策略名称 / 买入策略 / 卖出策略, or
        ``None`` when the strategy is not defined
    """
    strategy_config = self.main_strategy.get(strategy_name, None)
    if strategy_config is None:
        logger.error(f"策略{strategy_name}不存在")
        return None

    def render_side(side_config):
        # "and" conditions first, then (if any) the separator and the "or" conditions.
        and_text = "".join(f"{item}, \n" for item in side_config.get("and", []))
        or_text = "".join(f"{item}, \n" for item in side_config.get("or", []))
        if len(or_text) > 0:
            return and_text + " 或者 \n" + or_text
        return and_text

    buy_text = render_side(strategy_config.get("buy", {}))
    # Append the long-period settings of this instance to the buy description.
    use_week = self.buy_by_long_period.get("by_week", False)
    use_month = self.buy_by_long_period.get("by_month", False)
    if use_week:
        buy_text += "根据周线指标,\n"
    if use_month:
        buy_text += "根据月线指标,\n"
    long_period_lines = (
        ("ma5>ma10", "ma5>ma10, \n"),
        ("ma10>ma20", "ma10>ma20, \n"),
        ("macd_diff>0", "macd_diff>0, \n"),
        ("macd>0", "macd>0, \n"),
    )
    for flag, line in long_period_lines:
        if self.long_period_condition.get(flag, False):
            buy_text += line

    sell_text = render_side(strategy_config.get("sell", {}))

    strategy_info = {
        "策略名称": strategy_name,
        "买入策略": buy_text.strip(),
        "卖出策略": sell_text.strip(),
    }
    return pd.DataFrame([strategy_info])
def trade_simulate(self, symbol: str, bar: str, strategy_name: str = "全均线策略"):
    """
    Simulate long-only round trips for one symbol/bar with the given strategy.

    Rows are walked in timestamp order. While no position is open, a row that
    satisfies the strategy's buy rules opens one with as many whole shares as
    the account can pay for. While a position is open, it is closed on the
    first row that satisfies the sell rules (including the optional
    valleys-median stop loss) or on the last row of the data.

    Side effect: ``self.initial_capital`` is reset through
    ``update_initial_capital`` from the mean close price.

    :param symbol: instrument code to load
    :param bar: bar size to load
    :param strategy_name: strategy key passed on to ``fit_strategy``
    :return: ``(trades, summary)`` where ``trades`` is a DataFrame with one row
        per closed round trip and ``summary`` is a dict with symbol,
        symbol_name, bar, pct_chg and initial_capital; ``(None, None)`` when no
        data could be loaded or no round trip was closed
    """
    market_data = self.get_full_data(symbol, bar)
    if market_data is None or len(market_data) == 0:
        logger.warning(f"获取{symbol} {bar} 数据失败")
        return None, None
    logger.info(f"获取{symbol} {bar} 数据成功,数据条数: {len(market_data)}")

    # Keep only the rows where all four moving averages exist. The explicit
    # copy makes the column assignments below act on an independent frame
    # instead of a slice of the loaded data (avoids SettingWithCopyWarning).
    market_data = market_data[
        (market_data["ma5"].notna())
        & (market_data["ma10"].notna())
        & (market_data["ma20"].notna())
        & (market_data["ma30"].notna())
    ].copy()
    logger.info(
        f"ma5, ma10, ma20, ma30不为空的行数据条数: {len(market_data)}"
    )
    # Relative deviation of each bar's volume from its 5-bar rolling mean.
    market_data["volume_ma5"] = market_data["volume"].rolling(window=5).mean()
    market_data["volume_pct_chg"] = (
        market_data["volume"] - market_data["volume_ma5"]
    ) / market_data["volume_ma5"]
    market_data["volume_pct_chg"] = market_data["volume_pct_chg"].fillna(0)
    # Walk the bars in chronological order with a clean 0..n-1 index; the
    # stop-loss window below relies on that positional index.
    market_data = market_data.sort_values(by="timestamp", ascending=True)
    market_data.reset_index(drop=True, inplace=True)

    ma_break_market_data_pair_list = []
    ma_break_market_data_pair = {}
    close_mean = market_data["close"].mean()
    self.update_initial_capital(close_mean)
    logger.info(
        f"成功获取{symbol}数据:{len(market_data)}{bar}K线,开始日期={market_data[self.date_time_field].min()},结束日期={market_data[self.date_time_field].max()}"
    )
    account_value = self.initial_capital
    last_index = len(market_data) - 1
    # Only the A-share queries select a separate symbol_name column.
    has_symbol_name = self.is_astock or self.is_aindex
    for index, row in market_data.iterrows():
        symbol_name = row["symbol_name"] if has_symbol_name else row["symbol"]
        timestamp = row["timestamp"]
        date_time = row[self.date_time_field]
        close = row["close"]
        ma5 = row["ma5"]
        ma10 = row["ma10"]
        ma20 = row["ma20"]
        ma30 = row["ma30"]
        macd_diff = float(row["dif"])
        macd_dea = float(row["dea"])
        macd = float(row["macd"])

        if ma_break_market_data_pair.get("begin_timestamp", None) is None:
            # No open position: look for an entry.
            buy_condition = self.fit_strategy(
                strategy_name=strategy_name,
                row=row,
                behavior="buy",
                buy_price=None,
                window_100_valleys_median=None,
            )
            if buy_condition:
                entry_price = close
                shares, account_value = self.calculate_shares(
                    account_value, entry_price
                )
                if shares == 0:
                    # Nothing affordable: skip this signal.
                    continue
                buy_commission = shares * close * self.commission_per_share
                ma_break_market_data_pair = {
                    "symbol": symbol,
                    "symbol_name": symbol_name,
                    "bar": bar,
                    "begin_timestamp": timestamp,
                    "begin_date_time": date_time,
                    "begin_close": close,
                    "begin_ma5": ma5,
                    "begin_ma10": ma10,
                    "begin_ma20": ma20,
                    "begin_ma30": ma30,
                    "begin_macd_diff": macd_diff,
                    "begin_macd_dea": macd_dea,
                    "begin_macd": macd,
                    "shares": shares,
                    "buy_commission": buy_commission,
                    "begin_account_value": account_value,
                }
            continue

        # A position is open: look for an exit.
        valleys_median = None
        if self.cut_loss_by_valleys_median and index >= 100:
            # Stop-loss threshold from the 100 bars preceding the current one.
            window_100_records = market_data.iloc[index - 100 : index]
            peaks_valleys = self.metrics_calculation.get_peaks_valleys_mean(
                window_100_records
            )
            valleys_median = peaks_valleys.get("valleys_median", None)
            if valleys_median is not None and valleys_median > 0:
                # Positive values are divided by 100 to obtain a ratio.
                valleys_median = valleys_median / 100
        sell_condition = self.fit_strategy(
            strategy_name=strategy_name,
            row=row,
            behavior="sell",
            buy_price=ma_break_market_data_pair["begin_close"],
            window_100_valleys_median=valleys_median,
        )
        if sell_condition or index == last_index:
            # Sell on a sell signal, or force-close on the last bar.
            shares = ma_break_market_data_pair["shares"]
            entry_price = ma_break_market_data_pair["begin_close"]
            exit_price = close
            sell_commission = shares * exit_price * self.commission_per_share
            profit_loss = (exit_price - entry_price) * shares
            begin_account_value = ma_break_market_data_pair["begin_account_value"]
            account_value = begin_account_value + profit_loss - sell_commission
            ma_break_market_data_pair["end_timestamp"] = timestamp
            ma_break_market_data_pair["end_date_time"] = date_time
            ma_break_market_data_pair["end_close"] = close
            ma_break_market_data_pair["end_ma5"] = ma5
            ma_break_market_data_pair["end_ma10"] = ma10
            ma_break_market_data_pair["end_ma20"] = ma20
            ma_break_market_data_pair["end_ma30"] = ma30
            ma_break_market_data_pair["end_macd_diff"] = macd_diff
            ma_break_market_data_pair["end_macd_dea"] = macd_dea
            ma_break_market_data_pair["end_macd"] = macd
            ma_break_market_data_pair["pct_chg"] = (
                exit_price - entry_price
            ) / entry_price
            ma_break_market_data_pair["pct_chg"] = round(
                ma_break_market_data_pair["pct_chg"] * 100, 4
            )
            if valleys_median is not None:
                ma_break_market_data_pair["valleys_median"] = valleys_median * 100
            else:
                ma_break_market_data_pair["valleys_median"] = None
            ma_break_market_data_pair["profit_loss"] = profit_loss
            ma_break_market_data_pair["sell_commission"] = sell_commission
            ma_break_market_data_pair["end_account_value"] = account_value
            # The timestamp difference is divided by 1000 to obtain seconds.
            ma_break_market_data_pair["interval_seconds"] = (
                timestamp - ma_break_market_data_pair["begin_timestamp"]
            ) / 1000
            ma_break_market_data_pair["interval_minutes"] = (
                ma_break_market_data_pair["interval_seconds"] / 60
            )
            ma_break_market_data_pair["interval_hours"] = (
                ma_break_market_data_pair["interval_seconds"] / 3600
            )
            ma_break_market_data_pair["interval_days"] = (
                ma_break_market_data_pair["interval_seconds"] / 86400
            )
            ma_break_market_data_pair_list.append(ma_break_market_data_pair)
            ma_break_market_data_pair = {}

    if len(ma_break_market_data_pair_list) == 0:
        return None, None

    ma_break_market_data = pd.DataFrame(ma_break_market_data_pair_list)
    # Order the round trips by their entry time.
    ma_break_market_data.sort_values(
        by="begin_timestamp", ascending=True, inplace=True
    )
    ma_break_market_data.reset_index(drop=True, inplace=True)
    logger.info(
        f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}"
    )
    # Market move over the traded period: from the first entry close to the
    # last exit close, in percent.
    pct_chg = (
        (
            ma_break_market_data["end_close"].iloc[-1]
            - ma_break_market_data["begin_close"].iloc[0]
        )
        / ma_break_market_data["begin_close"].iloc[0]
        * 100
    )
    pct_chg = round(pct_chg, 4)
    symbol_name = ma_break_market_data["symbol_name"].iloc[0]
    market_data_pct_chg = {
        "symbol": symbol,
        "symbol_name": symbol_name,
        "bar": bar,
        "pct_chg": pct_chg,
        "initial_capital": self.initial_capital,
    }
    return ma_break_market_data, market_data_pct_chg
def update_initial_capital(self, close_mean: float):
    """
    Reset ``self.initial_capital`` to 25000 scaled by the price level.

    The base amount is multiplied by the largest of 10000, 5000, 1000, 500
    and 100 that ``close_mean`` strictly exceeds; at or below 100 the base
    amount is kept as is.

    :param close_mean: mean close price of the instrument being simulated
    """
    self.initial_capital = 25000
    # Checked from the largest threshold down; the first match wins.
    for price_level in (10000, 5000, 1000, 500, 100):
        if close_mean > price_level:
            self.initial_capital = self.initial_capital * price_level
            break
    logger.info(f"收盘价均值:{close_mean}")
    logger.info(f"初始资金调整为:{self.initial_capital}")
def calculate_shares(self, account_value, entry_price):
    """
    Work out how many whole shares the account can buy at ``entry_price``.

    The per-share cost includes the commission rate stored in
    ``self.commission_per_share``.

    :param account_value: current account value
    :param entry_price: price at which the shares are bought
    :return: ``(shares, account_value_after_buy)``; the second item is the
        value of the bought shares at ``entry_price`` plus the cash left
        over. ``(0, account_value)`` is returned for non-positive
        ``account_value`` / ``entry_price``, a negative commission rate, or
        any error during the calculation.
    """
    logger.info(
        f"开始计算交易股数:账户价值={account_value},买入价格={entry_price}"
    )
    # Returned whenever no purchase can be made: zero shares, account untouched.
    no_trade = (0, account_value)
    try:
        if account_value <= 0 or entry_price <= 0 or self.commission_per_share < 0:
            logger.error("账户价值、买入价格或佣金不能为负或零")
            return no_trade

        # Cost of one share including commission.
        unit_cost = entry_price * (1 + self.commission_per_share)
        # Largest whole number of shares the account can pay for.
        shares = math.floor(account_value / unit_cost)
        total_cost = shares * unit_cost
        remaining_cash = account_value - total_cost
        # Holdings valued at the entry price plus the unspent cash.
        remaining_value = (shares * entry_price) + remaining_cash
        logger.info(
            f"计算结果:可购买股数={shares},总成本={total_cost:.2f}美元,"
            f"剩余现金={remaining_cash:.2f}美元,总资产价值={remaining_value:.2f}美元"
        )
        return shares, remaining_value
    except Exception as e:
        logger.error(f"计算股数或账户价值时出错:{str(e)}")
        return no_trade
def get_full_data(self, symbol: str, bar: str = "5m"):
    """
    Load the full history of ``symbol`` in segments and merge them.

    The range from ``self.initial_date`` to the day after ``self.end_date``
    is queried in segments of at most 180 days. The segments are
    concatenated, de-duplicated and sorted by ``self.date_time_field``.
    A-share data is additionally passed through ``update_data``.

    :param symbol: instrument code to query
    :param bar: bar size; for A-share markets it selects the table (1D/1W/1M)
        and is echoed into the result as the ``bar`` column
    :return: DataFrame with the merged rows; an empty DataFrame when no
        segment returned any row
    """
    start_date = datetime.strptime(self.initial_date, "%Y-%m-%d")
    end_date = datetime.strptime(self.end_date, "%Y-%m-%d") + timedelta(days=1)

    table_name = ""
    if self.is_astock:
        if bar == "1D":
            table_name = "stock_daily_price_from_2021"
        elif bar == "1W":
            table_name = "stock_weekly_price_from_2020"
        elif bar == "1M":
            table_name = "stock_monthly_price_from_2015"
    elif self.is_aindex:
        if bar == "1D":
            table_name = "index_daily_price_from_2021"
        elif bar == "1W":
            table_name = "index_weekly_price_from_2020"
        elif bar == "1M":
            table_name = "index_monthly_price_from_2015"
    elif self.is_us_stock:
        table_name = "crypto_market_data"
    elif self.is_binance:
        table_name = "crypto_binance_data"
    else:
        # NOTE(review): the non-Binance fallback reads the same table as the
        # Binance branch - confirm this is intended.
        table_name = "crypto_binance_data"

    if self.is_astock or self.is_aindex:
        # A-share tables use different column names; alias them to the
        # names used by the rest of this class.
        fields = [
            "a.ts_code as symbol",
            "b.name as symbol_name",
            f"'{bar}' as bar",
            "0 as timestamp",
            "trade_date as date_time",
            "open",
            "high",
            "low",
            "close",
            "vol as volume",
            "MA5 as ma5",
            "MA10 as ma10",
            "MA20 as ma20",
            "MA30 as ma30",
            "均线交叉 as ma_cross",
            "DIF as dif",
            "DEA as dea",
            "MACD as macd",
        ]
    else:
        fields = [
            "symbol",
            "bar",
            "timestamp",
            "date_time",
            "date_time_us",
            "open",
            "high",
            "low",
            "close",
            "volume",
            "sar_signal",
            "ma5",
            "ma10",
            "ma20",
            "ma30",
            "ma_cross",
            "dif",
            "dea",
            "macd",
        ]

    # NOTE(review): ``bar`` is not passed to the query below - confirm how the
    # DB helper restricts the non-A-share tables to one bar size.
    segments = []
    while start_date < end_date:
        current_end_date = min(start_date + timedelta(days=180), end_date)
        start_date_str = start_date.strftime("%Y-%m-%d")
        current_end_date_str = current_end_date.strftime("%Y-%m-%d")
        logger.info(f"获取{symbol}数据:{start_date_str}{current_end_date_str}")
        current_data = self.db_market_data.query_market_data_by_symbol_bar(
            symbol,
            fields,
            start=start_date_str,
            end=current_end_date_str,
            table_name=table_name,
        )
        if current_data is not None and len(current_data) > 0:
            segments.append(pd.DataFrame(current_data))
        # The next segment starts on the end date of this one; rows returned
        # twice are removed by drop_duplicates below.
        start_date = current_end_date

    if not segments:
        # Nothing was returned: hand back an empty frame so the caller's
        # empty-data check applies, instead of failing on the sort column.
        return pd.DataFrame()

    # Concatenate once instead of growing a frame inside the loop.
    data = pd.concat(segments)
    data.drop_duplicates(inplace=True)
    data.sort_values(by=self.date_time_field, inplace=True)
    data.reset_index(drop=True, inplace=True)
    if self.is_astock or self.is_aindex:
        data = self.update_data(data)
    return data
def get_long_period_data(self, symbol: str, bar: str, end_date: str):
    """
    Fetch the latest weekly or monthly row of an A-share symbol up to a date.

    Only the A-share markets and the bars "1W" (looking back 360 * 2 days)
    and "1M" (looking back 360 * 5 days) are supported. The queried rows are
    enriched with the 10% percentile indicator over a 50-row window before
    the last row is returned.

    :param symbol: instrument code to query
    :param bar: "1W" or "1M"
    :param end_date: last date to include, as YYYY-MM-DD or YYYYMMDD
    :return: the latest row as a Series, or ``None`` when the market/bar is
        not supported, no row was found, or one of ma5, ma10, ma20, dif,
        macd is missing in the latest row
    """
    if not (self.is_astock or self.is_aindex):
        return None

    table_name = ""
    if self.is_astock:
        if bar == "1M":
            table_name = "stock_monthly_price_from_2015"
        elif bar == "1W":
            table_name = "stock_weekly_price_from_2020"
    else:
        if bar == "1M":
            table_name = "index_monthly_price_from_2015"
        elif bar == "1W":
            table_name = "index_weekly_price_from_2020"

    lookback_days = {"1M": 360 * 5, "1W": 360 * 2}.get(bar)
    if len(table_name) == 0 or lookback_days is None:
        return None

    if len(end_date) != 10:
        end_date = self.change_date_format(end_date)
    last_date = (
        datetime.strptime(end_date, "%Y-%m-%d") - timedelta(days=lookback_days)
    ).strftime("%Y-%m-%d")

    fields = [
        "a.ts_code as symbol",
        "b.name as symbol_name",
        f"'{bar}' as bar",
        "0 as timestamp",
        "trade_date as date_time",
        "open",
        "high",
        "low",
        "close",
        "vol as volume",
        "MA5 as ma5",
        "MA10 as ma10",
        "MA20 as ma20",
        "MA30 as ma30",
        "均线交叉 as ma_cross",
        "DIF as dif",
        "DEA as dea",
        "MACD as macd",
    ]
    data = self.db_market_data.query_market_data_by_symbol_bar(
        symbol, fields, start=last_date, end=end_date, table_name=table_name
    )
    if data is None or len(data) == 0:
        return None

    data = pd.DataFrame(data)
    data.sort_values(by="date_time", inplace=True)
    data = self.metrics_calculation.calculate_percentile_indicators(
        data=data,
        window_size=50,
        price_column="close",
        percentiles=[(0.1, "10")],
    )
    latest_row = data.iloc[-1]
    # pd.isna catches both None and NaN; an ``is None`` test misses the NaN
    # that pandas stores for missing numeric values.
    required_columns = ("ma5", "ma10", "ma20", "dif", "macd")
    if any(pd.isna(latest_row[column]) for column in required_columns):
        return None
    return latest_row
def update_data(self, data: pd.DataFrame):
    """
    Normalise the date columns of A-share data and refresh the MA-cross column.

    1. Rewrite ``date_time`` values through ``change_date_format``
       (e.g. 20210104 becomes 2021-01-04).
    2. Fill ``timestamp`` from the rewritten ``date_time`` values.
    3. Pass the frame through ``MetricsCalculation.ma5102030``.

    The first two steps modify ``data`` in place.

    :param data: market data with a ``date_time`` column
    :return: the frame returned by ``MetricsCalculation.ma5102030``
    """
    normalized_dates = data["date_time"].apply(self.change_date_format)
    data["date_time"] = normalized_dates
    data["timestamp"] = data["date_time"].apply(transform_date_time_to_timestamp)
    return self.metrics_calculation.ma5102030(data)
def change_date_format(self, date_text: str):
    """
    Convert a compact ``YYYYMMDD`` date string to ``YYYY-MM-DD``.

    Only strings made of exactly eight digits are rewritten, so other
    eight-character inputs (for example ``21-01-04``) are not mangled.

    :param date_text: date string to normalise
    :return: the dashed date, or ``date_text`` unchanged when it is not an
        eight-digit string
    """
    if len(date_text) == 8 and date_text.isdigit():
        return f"{date_text[:4]}-{date_text[4:6]}-{date_text[6:]}"
    return date_text
def fit_strategy(
    self,
    strategy_name: str = "全均线策略",
    row: pd.Series = None,
    behavior: str = "buy",
    buy_price: float = None,
    window_100_valleys_median: float = None,
):
    """
    Decide whether ``row`` satisfies the buy or sell rules of a strategy.

    Sell checks start with the stop loss: when both ``buy_price`` and
    ``window_100_valleys_median`` are given and the close has fallen below
    ``buy_price`` by a ratio greater than ``window_100_valleys_median``, the
    result is True straight away. Otherwise the strategy's "and" and "or"
    condition lists are evaluated through ``get_judge_result``.

    A buy signal on an A-share daily bar is additionally filtered by the
    enabled long-period conditions on the weekly and/or monthly row; a
    failed filter is overridden when ``buy_by_10_percentile`` is on and the
    long-period row has ``close_10_low == 1``.

    :param strategy_name: strategy key inside ``self.main_strategy``
    :param row: the bar being judged
    :param behavior: "buy" or "sell"
    :param buy_price: entry price of the open position (sell checks only)
    :param window_100_valleys_median: stop-loss ratio (sell checks only)
    :return: True when the rules are satisfied; False otherwise, including
        when the strategy or its ``behavior`` section does not exist
    """
    stop_loss_armed = (
        behavior == "sell"
        and buy_price is not None
        and window_100_valleys_median is not None
    )
    if stop_loss_armed:
        current_price = float(row["close"])
        if current_price < buy_price:
            loss_ratio = (buy_price - current_price) / buy_price
            if loss_ratio > window_100_valleys_median:
                return True

    strategy_config = self.main_strategy.get(strategy_name, None)
    if strategy_config is None:
        logger.error(f"策略{strategy_name}不存在")
        return False
    condition_dict = strategy_config.get(behavior, None)
    if condition_dict is None:
        logger.error(f"策略{strategy_name}{behavior}条件不存在")
        return False

    condition = self.get_judge_result(
        row, condition_dict.get("and", []), "and", True
    )
    condition = self.get_judge_result(
        row, condition_dict.get("or", []), "or", condition
    )

    # Everything below only refines a positive buy signal.
    if behavior != "buy" or not condition:
        return condition
    bar = row["bar"]
    if not ((self.is_astock or self.is_aindex) and bar == "1D"):
        return condition

    date_time = row["date_time"]
    long_period_condition_list = [
        name
        for name in ("ma5>ma10", "ma10>ma20", "macd_diff>0", "macd>0")
        if self.long_period_condition.get(name, False)
    ]
    if len(long_period_condition_list) == 0:
        return condition

    def apply_long_period(period_bar, current, rejection_message):
        # Filter ``current`` by the long-period row; without such a row the
        # incoming value is passed through unchanged.
        long_period_data = self.get_long_period_data(
            row["symbol"], period_bar, date_time
        )
        if long_period_data is None:
            return current
        current = self.get_judge_result(
            long_period_data, long_period_condition_list, "and", current
        )
        if not current:
            # A failed filter is overridden when the long-period row is
            # flagged close_10_low and the option is enabled.
            if self.buy_by_long_period.get("buy_by_10_percentile", False):
                if long_period_data["close_10_low"] == 1:
                    current = True
            if not current:
                logger.info(rejection_message)
        return current

    if self.buy_by_long_period.get("by_week", False):
        condition = apply_long_period(
            "1W", condition, f"根据周线指标,{row['symbol']}不满足买入条件"
        )
    if self.buy_by_long_period.get("by_month", False):
        condition = apply_long_period(
            "1M", condition, f"根据月线指标,{row['symbol']}不满足买入条件"
        )
    return condition
def get_judge_result(
    self,
    row: pd.Series,
    condition_list: list,
    and_or: str = "and",
    raw_condition: bool = True,
):
    """
    Fold a list of named indicator conditions into ``raw_condition``.

    Every recognised name in ``condition_list`` is evaluated against ``row``
    and merged into the running result with ``and`` or ``or``, as selected by
    ``and_or``. Unrecognised names are ignored, and so is
    "volume_pct_chg>0.2" when the row has no usable "volume_pct_chg" value.

    :param row: one record; the code reads "ma_cross", "ma5", "ma10", "ma20",
        "ma30", "close", "dif", "dea" and "macd", plus the optional
        "volume_pct_chg"
    :param condition_list: condition names, e.g. "5上穿10", "ma5>ma10",
        "macd<0"
    :param and_or: "and" or "or"; any other value returns ``raw_condition``
        unchanged
    :param raw_condition: starting value of the fold
    :return: the combined condition
    """
    # A missing / NaN cross description behaves like "no cross happened".
    cross_text = row["ma_cross"]
    cross_text = "" if (cross_text is None or pd.isna(cross_text)) else str(cross_text)

    ma5, ma10, ma20, ma30, close = (
        float(row[key]) for key in ("ma5", "ma10", "ma20", "ma30", "close")
    )

    volume_pct_chg = None
    if "volume_pct_chg" in list(row.index):
        volume_value = row["volume_pct_chg"]
        if volume_value is not None:
            volume_pct_chg = float(volume_value)

    macd_diff, macd_dea, macd = (
        float(row[key]) for key in ("dif", "dea", "macd")
    )

    # Outcome of every supported condition for this row, keyed by its name.
    outcomes = {
        cross_name: cross_name in cross_text
        for cross_name in (
            "5上穿10",
            "10上穿20",
            "20上穿30",
            "10下穿5",
            "20下穿10",
            "30下穿20",
        )
    }
    outcomes.update(
        {
            "ma5>ma10": ma5 > ma10,
            "ma10>ma20": ma10 > ma20,
            "ma20>ma30": ma20 > ma30,
            "close>ma20": close > ma20,
            "macd_diff>0": macd_diff > 0,
            "macd_dea>0": macd_dea > 0,
            "macd>0": macd > 0,
            "ma5<ma10": ma5 < ma10,
            "ma10<ma20": ma10 < ma20,
            "ma20<ma30": ma20 < ma30,
            "close<ma20": close < ma20,
            "macd_diff<0": macd_diff < 0,
            "macd_dea<0": macd_dea < 0,
            "macd<0": macd < 0,
        }
    )
    # The volume condition only exists when a volume change is available;
    # otherwise it is skipped like any unknown name.
    if volume_pct_chg is not None:
        outcomes["volume_pct_chg>0.2"] = volume_pct_chg > 0.2

    if and_or not in ("and", "or"):
        return raw_condition

    for condition_name in condition_list:
        if condition_name not in outcomes:
            continue
        if and_or == "and":
            raw_condition = raw_condition and outcomes[condition_name]
        else:
            raw_condition = raw_condition or outcomes[condition_name]
    return raw_condition
def draw_quant_pct_chg_bar_chart(
    self, data: pd.DataFrame, strategy_name: str = "全均线策略"
):
    """
    Draw one grouped bar chart per bar period comparing the strategy's
    account value change with the market's own change, one pair of bars per
    symbol, and save each chart as a PNG under ``self.stats_chart_dir``.

    :param data: summary rows; the code reads the columns "bar",
        "symbol_name", "account_value_chg" and "market_pct_chg"
    :param strategy_name: strategy name embedded in file and sheet names
    :return: dict mapping sheet name to the saved image path, or None when
        ``data`` is None or empty
    """
    if data is None or data.empty:
        return None

    # seaborn theme, then font settings applied on top of it
    sns.set_theme(style="whitegrid")
    plt.rcParams["font.sans-serif"] = ["SimHei"]
    plt.rcParams["font.size"] = 11
    plt.rcParams["axes.unicode_minus"] = False  # keep minus signs readable

    market_label = "市场自然涨跌"

    def _label_bars(rects, values, text_color):
        # Print each value just past the tip of its bar: above for
        # non-negative values, below for negative ones.
        for rect, value in zip(rects, values):
            non_negative = value >= 0
            plt.text(
                rect.get_x() + rect.get_width() / 2,
                value + (0.01 if non_negative else -0.01),
                f"{value:.3f}%",
                ha="center",
                va="bottom" if non_negative else "top",
                fontsize=9,
                fontweight="bold",
                color=text_color,
            )

    saved_charts = {}
    series_labels = {
        "account_value_chg": "量化策略涨跌",
    }
    for source_column, strategy_label in series_labels.items():
        for bar in data["bar"].unique():
            per_bar = data[data["bar"] == bar].copy()
            if per_bar.empty:
                continue
            # Rename to the display labels, then order symbols by strategy
            # result, best first.
            per_bar = (
                per_bar.rename(
                    columns={
                        source_column: strategy_label,
                        "market_pct_chg": market_label,
                    }
                )
                .sort_values(by=strategy_label, ascending=False)
                .reset_index(drop=True)
            )

            plt.figure(figsize=(12, 7))
            positions = np.arange(len(per_bar))
            bar_width = 0.35
            # Tick labels must be plain strings.
            per_bar["symbol_name"] = per_bar["symbol_name"].astype(str)
            shades = np.linspace(0.6, 0.9, len(per_bar))

            # Strategy bars (blue gradient) on the left of each pair.
            strategy_rects = plt.bar(
                positions - bar_width / 2,
                per_bar[strategy_label],
                bar_width,
                label=strategy_label,
                color=plt.cm.Blues(shades),
            )
            # Market bars (green gradient) on the right of each pair.
            market_rects = plt.bar(
                positions + bar_width / 2,
                per_bar[market_label],
                bar_width,
                label=market_label,
                color=plt.cm.Greens(shades),
            )

            plt.title(
                f"{bar}趋势{strategy_label}与市场自然涨跌对比(%)",
                fontsize=14,
                fontweight="bold",
            )
            plt.xlabel("Symbol", fontsize=12)
            plt.ylabel("涨跌幅(%)", fontsize=12)
            plt.xticks(positions, per_bar["symbol_name"], rotation=45, ha="right")
            plt.legend()
            plt.grid(True, alpha=0.3)

            _label_bars(strategy_rects, per_bar[strategy_label], "darkblue")
            _label_bars(market_rects, per_bar[market_label], "darkgreen")

            plt.tight_layout()
            # The file name records whether commission was charged.
            if self.commission_per_share > 0:
                commission_tag = "with_commission"
            else:
                commission_tag = "without_commission"
            save_path = os.path.join(
                self.stats_chart_dir,
                f"{bar}_bar_chart_{source_column}_{strategy_name}_{commission_tag}.png",
            )
            plt.savefig(save_path, dpi=150)
            plt.close()

            saved_charts[f"{bar}_趋势{strategy_label}柱状图_{strategy_name}"] = save_path
    return saved_charts
def draw_quant_line_chart(
    self,
    data: pd.DataFrame,
    market_data_pct_chg_df: pd.DataFrame,
    strategy_name: str = "全均线策略",
):
    """
    Draw, for every symbol / bar pair, a line chart of the strategy's account
    value against the market close price, both normalised to 1 at the first
    trade, and save each chart as a PNG under ``self.stats_chart_dir``.

    :param data: trade records; the code reads "symbol_name", "bar",
        "end_account_value", and the "begin_*" / "end_*" timestamp,
        date_time, close, ma and macd columns
    :param market_data_pct_chg_df: per symbol / bar rows providing
        "initial_capital" (matched on "symbol_name" and "bar")
    :param strategy_name: strategy name embedded in titles, file and sheet
        names
    :return: dict mapping sheet name to the saved image path; empty when
        there is nothing to draw
    """
    # Nothing to draw. Return an empty dict (the same value a frame with
    # columns but no rows produces) instead of failing on the column lookup.
    if data is None or data.empty:
        return {}

    symbols = data["symbol_name"].unique()
    bars = data["bar"].unique()
    chart_dict = {}
    for symbol in symbols:
        for bar in bars:
            symbol_bar_data = data[
                (data["symbol_name"] == symbol) & (data["bar"] == bar)
            ]
            if symbol_bar_data.empty:
                continue

            capital_values = market_data_pct_chg_df.loc[
                (market_data_pct_chg_df["symbol_name"] == symbol)
                & (market_data_pct_chg_df["bar"] == bar),
                "initial_capital",
            ].values
            if len(capital_values) == 0:
                # Without a baseline the curve cannot be normalised; skip
                # this pair rather than abort every remaining chart.
                logger.error(f"{symbol} {bar} 缺少初始资金记录,跳过折线图")
                continue
            initial_capital = int(capital_values[0])

            # Build a synthetic starting point from the first trade: account
            # value equals the initial capital and every "end_*" field takes
            # the trade's "begin_*" value.
            first_row = symbol_bar_data.iloc[0].copy()
            init_row = first_row.copy()
            init_row.loc["profit_loss"] = 0
            init_row.loc["end_account_value"] = initial_capital
            for field in (
                "timestamp",
                "date_time",
                "close",
                "ma5",
                "ma10",
                "ma20",
                "ma30",
                "macd_diff",
                "macd_dea",
                "macd",
            ):
                init_row.loc[f"end_{field}"] = first_row[f"begin_{field}"]
            for field in (
                "pct_chg",
                "interval_seconds",
                "interval_minutes",
                "interval_hours",
                "interval_days",
            ):
                init_row.loc[field] = 0

            symbol_bar_data = (
                pd.concat([pd.DataFrame([init_row]), symbol_bar_data])
                .sort_values(by="end_timestamp", ascending=True)
                .reset_index(drop=True)
            )
            # matplotlib needs real datetimes on the x axis.
            symbol_bar_data["end_date_time"] = pd.to_datetime(
                symbol_bar_data["end_date_time"]
            )
            # Normalise both series so they start at 1 and are comparable.
            symbol_bar_data["end_account_value_to_1"] = (
                symbol_bar_data["end_account_value"] / initial_capital
            ).round(4)
            symbol_bar_data["end_close_to_1"] = (
                symbol_bar_data["end_close"] / init_row["end_close"]
            ).round(4)

            plt.figure(figsize=(12, 7))
            # Strategy curve (blue).
            plt.plot(
                symbol_bar_data["end_date_time"],
                symbol_bar_data["end_account_value_to_1"],
                label="量化策略涨跌",
                color="blue",
                linewidth=2,
                marker="o",
                markersize=4,
            )
            # Market curve (green).
            plt.plot(
                symbol_bar_data["end_date_time"],
                symbol_bar_data["end_close_to_1"],
                label="市场自然涨跌",
                color="green",
                linewidth=2,
                marker="s",
                markersize=4,
            )
            plt.title(
                f"{symbol} {bar} 量化与市场折线图_{strategy_name}",
                fontsize=14,
                fontweight="bold",
            )
            plt.xlabel("时间", fontsize=12)
            plt.ylabel("涨跌变化", fontsize=12)
            plt.legend(fontsize=11)
            plt.grid(True, alpha=0.3)

            # With many points, thin the tick labels to roughly 30 while
            # always keeping the first and the last one.
            point_count = len(symbol_bar_data)
            if point_count > 30:
                step = max(1, point_count // 30)
                label_indices = [
                    0,
                    *range(step, point_count - 1, step),
                    point_count - 1,
                ]
                tick_times = symbol_bar_data["end_date_time"].iloc[label_indices]
            else:
                tick_times = symbol_bar_data["end_date_time"]
            plt.xticks(
                tick_times,
                tick_times.dt.strftime("%Y%m%d %H:%M"),
                rotation=45,
                ha="right",
            )

            plt.tight_layout()
            # The file name records whether commission was charged.
            if self.commission_per_share > 0:
                commission_tag = "with_commission"
            else:
                commission_tag = "without_commission"
            save_path = os.path.join(
                self.stats_chart_dir,
                f"{symbol}_{bar}_line_chart_{strategy_name}_{commission_tag}.png",
            )
            plt.savefig(save_path, dpi=150, bbox_inches="tight")
            plt.close()

            sheet_name = f"{symbol}_{bar}_折线图_{strategy_name}"
            chart_dict[sheet_name] = save_path
    return chart_dict
def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict):
    """
    Add one worksheet per chart image to an existing Excel workbook.

    The workbook at ``excel_file_path`` is loaded, one sheet is created for
    every entry of ``charts_dict`` with the image anchored at cell A1, and
    the workbook is saved back to the same path. A chart that cannot be
    added is logged and skipped; the remaining charts are still written.

    :param excel_file_path: path of an existing workbook
    :param charts_dict: flat mapping of sheet name to image file path:
        {
            "sheet_name": "chart_path"
        }
    :return: None
    """
    if not charts_dict:
        # Covers None and {}: nothing to add, so leave the workbook untouched.
        logger.info(f"没有可输出的图表,跳过{excel_file_path}")
        return

    logger.info(f"将图表输出到{excel_file_path}")
    wb = openpyxl.load_workbook(excel_file_path)
    for sheet_name, chart_path in charts_dict.items():
        try:
            # Load the image before creating the sheet so that a missing or
            # unreadable file does not leave an empty sheet in the workbook.
            img = Image(chart_path)
            # NOTE(review): Excel limits sheet titles to 31 characters; the
            # titles passed in may be longer - confirm whether they need to
            # be shortened by the caller.
            ws = wb.create_sheet(title=sheet_name)
            ws.add_image(img, "A1")
        except Exception as e:
            # Best effort: report the failing chart and carry on.
            logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
    wb.save(excel_file_path)
    print(f"Chart saved as {excel_file_path}")