crypto_quant/core/trade/ma_break_statistics.py

1099 lines
49 KiB
Python
Raw Normal View History

2025-08-22 10:48:59 +00:00
import core.logger as logging
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
2025-09-16 06:31:15 +00:00
from datetime import datetime, timedelta, timezone
from core.utils import get_current_date_time
2025-08-22 10:48:59 +00:00
import re
import json
import math
2025-08-22 10:48:59 +00:00
from openpyxl import Workbook
from openpyxl.drawing.image import Image
import openpyxl
from openpyxl.styles import Font
from PIL import Image as PILImage
from config import (
OKX_MONITOR_CONFIG,
US_STOCK_MONITOR_CONFIG,
MYSQL_CONFIG,
WINDOW_SIZE,
BINANCE_MONITOR_CONFIG,
)
2025-08-22 10:48:59 +00:00
from core.db.db_market_data import DBMarketData
2025-09-15 06:12:47 +00:00
from core.db.db_binance_data import DBBinanceData
2025-08-22 10:48:59 +00:00
from core.db.db_huge_volume_data import DBHugeVolumeData
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
# Select the SimHei font so Chinese text can be rendered in seaborn/matplotlib charts.
plt.rcParams["font.family"] = ["SimHei"]
# Module-wide logger object exposed by the project's core.logger module.
logger = logging.logger
class MaBreakStatistics:
    """
    Statistics on the price change that follows a moving-average (MA) breakout.

    An upward MA breakout is a K-line on which MA5 > MA10 > MA20 > MA30.
    A downward MA breakout is a K-line on which MA30 > MA20 > MA10 > MA5.
    The statistic is the price change between an upward breakout and the
    next downward breakout.
    """
def __init__(
    self,
    is_us_stock: bool = False,
    is_binance: bool = False,
    commission_per_share: float = 0.0008,
):
    """Build the DB handles, pick the symbol/bar universe and load the strategies.

    Args:
        is_us_stock: Read symbols/bars from ``US_STOCK_MONITOR_CONFIG`` and
            use ``DBMarketData``. Takes precedence over ``is_binance``.
        is_binance: Read symbols from ``BINANCE_MONITOR_CONFIG`` (bars are
            fixed to 30m and 1H) and use ``DBBinanceData``.
        commission_per_share: Commission rate; the simulation multiplies it
            by ``shares * price`` on every buy and sell.

    Raises:
        ValueError: If ``MYSQL_CONFIG`` carries no password.
    """
    # Function-scope import so this block does not depend on a file-level change.
    from urllib.parse import quote_plus

    mysql_user = MYSQL_CONFIG.get("user", "xch")
    mysql_password = MYSQL_CONFIG.get("password", "")
    if not mysql_password:
        raise ValueError("MySQL password is not set")
    mysql_host = MYSQL_CONFIG.get("host", "localhost")
    mysql_port = MYSQL_CONFIG.get("port", 3306)
    mysql_database = MYSQL_CONFIG.get("database", "okx")
    # Credentials are URL-encoded: a raw '@', ':' or '/' in the password
    # would otherwise be parsed as a URL delimiter and corrupt the DSN.
    self.db_url = (
        f"mysql+pymysql://{quote_plus(str(mysql_user))}:"
        f"{quote_plus(str(mysql_password))}"
        f"@{mysql_host}:{mysql_port}/{mysql_database}"
    )

    self.db_huge_volume_data = DBHugeVolumeData(self.db_url)
    self.is_us_stock = is_us_stock
    self.is_binance = is_binance

    if is_us_stock:
        monitor = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {})
        self.symbols = monitor.get("symbols", ["QQQ"])
        self.bars = monitor.get("bars", ["5m"])
        self.initial_date = monitor.get("initial_date", "2014-11-30 00:00:00")
        self.db_market_data = DBMarketData(self.db_url)
    elif is_binance:
        monitor = BINANCE_MONITOR_CONFIG.get("volume_monitor", {})
        self.symbols = monitor.get("symbols", ["BTC-USDT"])
        # Bars are hard-coded for Binance rather than read from the config.
        self.bars = ["30m", "1H"]
        self.initial_date = monitor.get("initial_date", "2017-08-16 00:00:00")
        self.db_market_data = DBBinanceData(self.db_url)
    else:
        monitor = OKX_MONITOR_CONFIG.get("volume_monitor", {})
        self.symbols = monitor.get("symbols", ["XCH-USDT"])
        self.bars = monitor.get("bars", ["5m", "15m", "30m", "1H"])
        self.initial_date = monitor.get("initial_date", "2025-05-15 00:00:00")
        self.db_market_data = DBMarketData(self.db_url)

    # Keep only the "YYYY-MM-DD" part; get_full_data parses it with "%Y-%m-%d".
    if len(self.initial_date) > 10:
        self.initial_date = self.initial_date[:10]
    self.end_date = get_current_date_time(format="%Y-%m-%d")
    self.commission_per_share = commission_per_share

    self.trade_strategy_config = self.get_trade_strategy_config()
    self.main_strategy = self.trade_strategy_config.get("均线系统策略", None)
def get_trade_strategy_config(self):
    """Return the parsed content of ``./json/trade_strategy.json``.

    The path is relative to the current working directory.
    """
    config_path = "./json/trade_strategy.json"
    with open(config_path, "r", encoding="utf-8") as config_file:
        return json.load(config_file)
def batch_statistics(self, strategy_name: str = "全均线策略"):
    """Simulate one strategy on every configured symbol/bar and export the results.

    For each (symbol, bar) pair ``trade_simulate`` is run; the resulting trade
    records are merged, summarised and written to an Excel workbook under
    ``self.stats_output_dir``. Chart images are saved under
    ``self.stats_chart_dir`` and embedded into the same workbook.

    Args:
        strategy_name: Key into ``self.main_strategy``. An unknown or ``None``
            name falls back to "全均线策略".

    Returns:
        A DataFrame with one row per (symbol, bar) holding commissions,
        initial/final account value, the account value change (%) and the
        market change (%); ``None`` when no pair produced any trade.
    """
    # Resolve the strategy name BEFORE deriving any path from it, so output
    # directories and file names always carry the strategy actually used.
    if strategy_name is None or strategy_name not in self.main_strategy:
        strategy_name = "全均线策略"

    if self.is_us_stock:
        market = "us_stock"
    elif self.is_binance:
        market = "binance"
    else:
        market = "okx"
    base_dir = f"./output/trade_sandbox/ma_strategy/{market}"
    self.stats_output_dir = f"{base_dir}/excel/{strategy_name}/"
    self.stats_chart_dir = f"{base_dir}/chart/{strategy_name}/"
    os.makedirs(self.stats_output_dir, exist_ok=True)
    os.makedirs(self.stats_chart_dir, exist_ok=True)

    trade_frames = []
    market_change_rows = []
    for symbol in self.symbols:
        for bar in self.bars:
            logger.info(
                f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计, 策略: {strategy_name},交易费率:{self.commission_per_share}"
            )
            trades, market_change = self.trade_simulate(symbol, bar, strategy_name)
            if trades is None or len(trades) == 0 or market_change is None:
                continue
            trade_frames.append(trades)
            logger.info(
                f"{symbol} {bar} 的市场价格变化, {market_change.get('pct_chg', 0)}%"
            )
            market_change_rows.append(market_change)

    if not trade_frames:
        return None

    ma_break_market_data = pd.concat(trade_frames)
    market_data_pct_chg_df = pd.DataFrame(market_change_rows)
    ma_break_market_data.sort_values(
        by="begin_timestamp", ascending=True, inplace=True
    )
    ma_break_market_data.reset_index(drop=True, inplace=True)

    # Per (symbol, bar): commissions paid and account value change vs. market.
    summary_columns = [
        "strategy_name",
        "symbol",
        "bar",
        "total_buy_commission",
        "total_sell_commission",
        "total_commission",
        "initial_account_value",
        "final_account_value",
        "account_value_chg",
        "market_pct_chg",
    ]
    summary_rows = []
    for symbol in market_data_pct_chg_df["symbol"].unique():
        for bar in market_data_pct_chg_df["bar"].unique():
            symbol_bar_data = ma_break_market_data[
                (ma_break_market_data["symbol"] == symbol)
                & (ma_break_market_data["bar"] == bar)
            ].copy()  # copy so the in-place sort below never touches a view
            if len(symbol_bar_data) == 0:
                continue
            symbol_bar_data.sort_values(
                by="end_timestamp", ascending=True, inplace=True
            )
            symbol_bar_data.reset_index(drop=True, inplace=True)

            pair_mask = (market_data_pct_chg_df["symbol"] == symbol) & (
                market_data_pct_chg_df["bar"] == bar
            )
            initial_capital = int(
                market_data_pct_chg_df.loc[pair_mask, "initial_capital"].values[0]
            )
            market_pct_chg = market_data_pct_chg_df.loc[
                pair_mask, "pct_chg"
            ].values[0]

            # The last closed trade carries the final account value.
            final_account_value = float(
                symbol_bar_data["end_account_value"].iloc[-1]
            )
            account_value_chg = round(
                (final_account_value - initial_capital) / initial_capital * 100, 4
            )
            total_buy_commission = float(symbol_bar_data["buy_commission"].sum())
            total_sell_commission = float(symbol_bar_data["sell_commission"].sum())
            total_commission = round(
                total_buy_commission + total_sell_commission, 4
            )
            summary_rows.append(
                {
                    "strategy_name": strategy_name,
                    "symbol": symbol,
                    "bar": bar,
                    "total_buy_commission": round(total_buy_commission, 4),
                    "total_sell_commission": round(total_sell_commission, 4),
                    "total_commission": total_commission,
                    "initial_account_value": initial_capital,
                    "final_account_value": final_account_value,
                    "account_value_chg": account_value_chg,
                    "market_pct_chg": market_pct_chg,
                }
            )
    account_value_chg_df = pd.DataFrame(summary_rows)[summary_columns]

    stat_names = ["max", "min", "mean", "std", "median", "count"]

    def describe(column, prefix):
        """Group by symbol/bar and aggregate one column into <prefix>_<stat> columns."""
        stats = (
            ma_break_market_data.groupby(["symbol", "bar"])[column]
            .agg(**{f"{prefix}_{name}": name for name in stat_names})
            .reset_index()
        )
        stats["strategy_name"] = strategy_name
        ordered = ["strategy_name", "symbol", "bar"]
        ordered += [f"{prefix}_{name}" for name in stat_names]
        return stats[ordered]

    account_value_statistics_df = describe("end_account_value", "account_value")
    interval_minutes_df = describe("interval_minutes", "interval_minutes")

    # The covered time range becomes part of the workbook file name.
    earliest_market_date_time = re.sub(
        r"[\:\-\s]", "", str(ma_break_market_data["begin_date_time"].min())
    )
    latest_market_date_time = ma_break_market_data["end_date_time"].max()
    if pd.isna(latest_market_date_time):
        latest_market_date_time = get_current_date_time(format="%Y%m%d%H%M%S")
    latest_market_date_time = re.sub(
        r"[\:\-\s]", "", str(latest_market_date_time)
    )
    if self.commission_per_share > 0:
        output_file_name = f"ma_break_stats_from_{earliest_market_date_time}_to_{latest_market_date_time}_{strategy_name}_with_commission.xlsx"
    else:
        output_file_name = f"ma_break_stats_from_{earliest_market_date_time}_to_{latest_market_date_time}_{strategy_name}_without_commission.xlsx"
    output_file_path = os.path.join(self.stats_output_dir, output_file_name)
    logger.info(f"导出{output_file_path}")

    strategy_info_df = self.get_strategy_info(strategy_name)
    with pd.ExcelWriter(output_file_path) as writer:
        strategy_info_df.to_excel(writer, sheet_name="策略信息", index=False)
        ma_break_market_data.to_excel(
            writer, sheet_name="买卖记录明细", index=False
        )
        account_value_chg_df.to_excel(
            writer, sheet_name="资产价值变化", index=False
        )
        account_value_statistics_df.to_excel(
            writer, sheet_name="买卖账户价值统计", index=False
        )
        interval_minutes_df.to_excel(
            writer, sheet_name="买卖时间间隔统计", index=False
        )

    # The bar-chart helper returns None for empty input; only embed real charts.
    bar_charts = self.draw_quant_pct_chg_bar_chart(
        account_value_chg_df, strategy_name
    )
    if bar_charts:
        self.output_chart_to_excel(output_file_path, bar_charts)
    line_charts = self.draw_quant_line_chart(
        ma_break_market_data, market_data_pct_chg_df, strategy_name
    )
    if line_charts:
        self.output_chart_to_excel(output_file_path, line_charts)
    return account_value_chg_df
2025-08-23 17:44:33 +00:00
2025-08-22 10:48:59 +00:00
def get_strategy_info(self, strategy_name: str = "全均线策略"):
    """Describe a strategy's buy and sell rules as a one-row DataFrame.

    Args:
        strategy_name: Key into ``self.main_strategy``.

    Returns:
        A DataFrame with the columns 策略名称 / 买入策略 / 卖出策略, or ``None``
        (after logging an error) when the strategy is not configured.
    """
    strategy_config = self.main_strategy.get(strategy_name, None)
    if strategy_config is None:
        logger.error(f"策略{strategy_name}不存在")
        return None

    def describe(side):
        """Render the "and" rules, followed by the "or" rules when there are any."""
        rules = strategy_config.get(side, {})
        required_text = "".join(f"{rule}, \n" for rule in rules.get("and", []))
        optional_text = "".join(f"{rule}, \n" for rule in rules.get("or", []))
        if optional_text:
            return required_text + " 或者 \n" + optional_text
        return required_text

    # One row, wrapped in a list so pandas builds a single-row frame.
    return pd.DataFrame(
        [
            {
                "策略名称": strategy_name,
                "买入策略": describe("buy"),
                "卖出策略": describe("sell"),
            }
        ]
    )
def trade_simulate(self, symbol: str, bar: str, strategy_name: str = "全均线策略"):
    """Replay the strategy over one symbol/bar and record every closed trade.

    The method walks the K-lines in timestamp order. While flat it asks
    ``fit_strategy`` for a buy signal; while holding it asks for a sell
    signal. Each buy/sell pair becomes one record.

    Args:
        symbol: Instrument to simulate.
        bar: K-line period.
        strategy_name: Key into ``self.main_strategy``.

    Returns:
        ``(trades, market_change)``: ``trades`` is a DataFrame of closed
        trades sorted by ``begin_timestamp``; ``market_change`` is a dict with
        ``symbol``, ``bar``, ``pct_chg`` (change from the first trade's entry
        close to the last trade's exit close, in %) and ``initial_capital``.
        ``(None, None)`` when there is no data or no closed trade.
    """
    market_data = self.get_full_data(symbol, bar)
    if market_data is None or len(market_data) == 0:
        logger.warning(f"获取{symbol} {bar} 数据失败")
        return None, None
    logger.info(f"获取{symbol} {bar} 数据成功,数据条数: {len(market_data)}")

    # Keep rows where all four MAs exist. The explicit copy makes the column
    # assignments below operate on an independent frame instead of a slice.
    ma_columns = ["ma5", "ma10", "ma20", "ma30"]
    market_data = market_data[market_data[ma_columns].notna().all(axis=1)].copy()
    logger.info(
        f"ma5, ma10, ma20, ma30不为空的行数据条数: {len(market_data)}"
    )

    # Sort first: the rolling volume mean is only meaningful on rows that are
    # in chronological order.
    market_data.sort_values(by="timestamp", ascending=True, inplace=True)
    market_data.reset_index(drop=True, inplace=True)
    market_data["volume_ma5"] = market_data["volume"].rolling(window=5).mean()
    # Relative deviation of volume from its 5-bar mean; the first four rows
    # have no mean yet and are filled with 0.
    market_data["volume_pct_chg"] = (
        (market_data["volume"] - market_data["volume_ma5"])
        / market_data["volume_ma5"]
    ).fillna(0)

    date_time_field = "date_time_us" if self.is_us_stock else "date_time"
    close_mean = market_data["close"].mean()
    self.update_initial_capital(close_mean)
    logger.info(
        f"成功获取{symbol}数据:{len(market_data)}{bar}K线,开始日期={market_data[date_time_field].min()},结束日期={market_data[date_time_field].max()}"
    )
    account_value = self.initial_capital

    closed_trades = []
    open_trade = None  # record of the position currently held, if any
    for _, row in market_data.iterrows():
        timestamp = row["timestamp"]
        date_time = row[date_time_field]
        close = row["close"]
        # Indicator snapshot stored on the trade record at entry and at exit.
        snapshot = {
            "close": close,
            "ma5": row["ma5"],
            "ma10": row["ma10"],
            "ma20": row["ma20"],
            "ma30": row["ma30"],
            "macd_diff": float(row["dif"]),
            "macd_dea": float(row["dea"]),
            "macd": float(row["macd"]),
        }

        if open_trade is None:
            if not self.fit_strategy(
                strategy_name=strategy_name,
                market_data=market_data,
                row=row,
                behavior="buy",
            ):
                continue
            shares, account_value = self.calculate_shares(account_value, close)
            if shares == 0:
                # Cannot afford a single share: skip this signal.
                continue
            open_trade = {
                "symbol": symbol,
                "bar": bar,
                "begin_timestamp": timestamp,
                "begin_date_time": date_time,
            }
            for key, value in snapshot.items():
                open_trade[f"begin_{key}"] = value
            open_trade["shares"] = shares
            open_trade["buy_commission"] = (
                shares * close * self.commission_per_share
            )
            # calculate_shares already deducted the buy commission.
            open_trade["begin_account_value"] = account_value
            continue

        if not self.fit_strategy(
            strategy_name=strategy_name,
            market_data=market_data,
            row=row,
            behavior="sell",
        ):
            continue

        shares = open_trade["shares"]
        entry_price = open_trade["begin_close"]
        exit_price = close
        sell_commission = shares * exit_price * self.commission_per_share
        profit_loss = (exit_price - entry_price) * shares
        account_value = (
            open_trade["begin_account_value"] + profit_loss - sell_commission
        )

        open_trade["end_timestamp"] = timestamp
        open_trade["end_date_time"] = date_time
        for key, value in snapshot.items():
            open_trade[f"end_{key}"] = value
        open_trade["pct_chg"] = round(
            (exit_price - entry_price) / entry_price * 100, 4
        )
        open_trade["profit_loss"] = profit_loss
        open_trade["sell_commission"] = sell_commission
        open_trade["end_account_value"] = account_value
        # Timestamps are divided by 1000 here, i.e. treated as milliseconds.
        interval_seconds = (timestamp - open_trade["begin_timestamp"]) / 1000
        open_trade["interval_seconds"] = interval_seconds
        open_trade["interval_minutes"] = interval_seconds / 60
        open_trade["interval_hours"] = interval_seconds / 3600
        open_trade["interval_days"] = interval_seconds / 86400
        closed_trades.append(open_trade)
        open_trade = None

    if not closed_trades:
        return None, None

    ma_break_market_data = pd.DataFrame(closed_trades)
    ma_break_market_data.sort_values(
        by="begin_timestamp", ascending=True, inplace=True
    )
    ma_break_market_data.reset_index(drop=True, inplace=True)
    logger.info(
        f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}"
    )

    # Market move over the traded period:
    # (last exit close - first entry close) / first entry close * 100.
    first_entry_close = ma_break_market_data["begin_close"].iloc[0]
    last_exit_close = ma_break_market_data["end_close"].iloc[-1]
    pct_chg = round(
        (last_exit_close - first_entry_close) / first_entry_close * 100, 4
    )
    market_data_pct_chg = {
        "symbol": symbol,
        "bar": bar,
        "pct_chg": pct_chg,
        "initial_capital": self.initial_capital,
    }
    return ma_break_market_data, market_data_pct_chg
def update_initial_capital(self, close_mean: float):
    """Scale ``self.initial_capital`` to the instrument's price level.

    The base capital is 25000. It is multiplied by the first threshold in
    (10000, 5000, 1000, 500, 100) that ``close_mean`` strictly exceeds, and is
    left unscaled when ``close_mean`` exceeds none of them.

    Args:
        close_mean: Mean close price of the data being simulated.
    """
    base_capital = 25000
    multiplier = 1
    # Thresholds are checked from the largest down; the first hit wins.
    for threshold in (10000, 5000, 1000, 500, 100):
        if close_mean > threshold:
            multiplier = threshold
            break
    self.initial_capital = base_capital * multiplier
    logger.info(f"收盘价均值:{close_mean}")
    logger.info(f"初始资金调整为:{self.initial_capital}")
def calculate_shares(self, account_value, entry_price):
    """Work out how many whole shares the account can buy at ``entry_price``.

    The per-share cost includes the commission
    (``entry_price * (1 + self.commission_per_share)``) and the share count is
    rounded down.

    Args:
        account_value: Current account value.
        entry_price: Price at which the position is opened.

    Returns:
        ``(shares, value)``: the integer share count and the account value
        after the purchase (position at entry price plus leftover cash).
        ``(0, account_value)`` when an input is rejected or the computation
        raises.
    """
    logger.info(
        f"开始计算交易股数:账户价值={account_value},买入价格={entry_price}"
    )
    try:
        # Reject non-positive value/price and a negative commission rate.
        if account_value <= 0 or entry_price <= 0 or self.commission_per_share < 0:
            logger.error("账户价值、买入价格或佣金不能为负或零")
            return 0, account_value

        unit_cost = entry_price * (1 + self.commission_per_share)
        shares = math.floor(account_value / unit_cost)
        total_cost = shares * unit_cost
        remaining_cash = account_value - total_cost
        # Position valued at the entry price plus the cash left over.
        remaining_value = (shares * entry_price) + remaining_cash
        logger.info(
            f"计算结果:可购买股数={shares},总成本={total_cost:.2f}美元,"
            f"剩余现金={remaining_cash:.2f}美元,总资产价值={remaining_value:.2f}美元"
        )
        return shares, remaining_value
    except Exception as e:
        # Best effort: report the failure and leave the account untouched.
        logger.error(f"计算股数或账户价值时出错:{str(e)}")
        return 0, account_value
2025-09-15 06:12:47 +00:00
def get_full_data(self, symbol: str, bar: str = "5m"):
    """Fetch market data in chunks and merge them into one DataFrame.

    The range from ``self.initial_date`` to the day after ``self.end_date`` is
    queried in windows of at most 180 days. Adjacent windows share their
    boundary date, so duplicated rows are dropped after merging.

    Args:
        symbol: Instrument to query.
        bar: K-line period.

    Returns:
        A DataFrame sorted by ``date_time_us`` (US stocks) or ``date_time``
        with a fresh index; an empty DataFrame when no rows were fetched.
    """
    start_date = datetime.strptime(self.initial_date, "%Y-%m-%d")
    end_date = datetime.strptime(self.end_date, "%Y-%m-%d") + timedelta(days=1)
    fields = [
        "symbol",
        "bar",
        "timestamp",
        "date_time",
        "date_time_us",
        "open",
        "high",
        "low",
        "close",
        "volume",
        "sar_signal",
        "ma5",
        "ma10",
        "ma20",
        "ma30",
        "ma_cross",
        "dif",
        "dea",
        "macd",
    ]

    # Collect the chunks and concatenate once, rather than growing a
    # DataFrame inside the loop.
    chunks = []
    while start_date < end_date:
        current_end_date = min(start_date + timedelta(days=180), end_date)
        start_date_str = start_date.strftime("%Y-%m-%d")
        current_end_date_str = current_end_date.strftime("%Y-%m-%d")
        logger.info(f"获取{symbol}数据:{start_date_str}{current_end_date_str}")
        current_data = self.db_market_data.query_market_data_by_symbol_bar(
            symbol, bar, fields, start=start_date_str, end=current_end_date_str
        )
        if current_data is not None and len(current_data) > 0:
            chunks.append(pd.DataFrame(current_data))
        start_date = current_end_date

    if not chunks:
        # Nothing fetched: sorting a column-less frame by name would raise
        # KeyError, so hand back an empty frame for the caller to check.
        return pd.DataFrame()

    data = pd.concat(chunks)
    data.drop_duplicates(inplace=True)
    date_time_field = "date_time_us" if self.is_us_stock else "date_time"
    data.sort_values(by=date_time_field, inplace=True)
    data.reset_index(drop=True, inplace=True)
    return data
2025-08-22 10:48:59 +00:00
def fit_strategy(
    self,
    strategy_name: str = "全均线策略",
    market_data: pd.DataFrame = None,
    row: pd.Series = None,
    behavior: str = "buy",
):
    """Decide whether one K-line satisfies a strategy's buy or sell rules.

    The rules under ``behavior`` have an "and" list and an "or" list. The row
    matches when every known "and" rule holds; failing that, when any known
    "or" rule holds. Rule names that are not recognised are ignored.

    NOTE(review): with an empty "and" list the result is always True, even
    when an "or" list is configured - confirm this is intended.

    Args:
        strategy_name: Key into ``self.main_strategy``.
        market_data: Not used by this method.
        row: K-line providing ma_cross, ma5/10/20/30, close, volume_pct_chg,
            dif, dea and macd.
        behavior: "buy" or "sell" - the rule set to evaluate.

    Returns:
        True when the row matches, otherwise False. Also False (after logging
        an error) when the strategy or its rule set is missing.
    """
    strategy_config = self.main_strategy.get(strategy_name, None)
    if strategy_config is None:
        logger.error(f"策略{strategy_name}不存在")
        return False
    condition_dict = strategy_config.get(behavior, None)
    if condition_dict is None:
        logger.error(f"策略{strategy_name}{behavior}条件不存在")
        return False

    # A missing cross marker becomes "" so the substring tests below are False.
    ma_cross = row["ma_cross"]
    if pd.isna(ma_cross) or ma_cross is None:
        ma_cross = ""
    ma_cross = str(ma_cross)
    ma5 = float(row["ma5"])
    ma10 = float(row["ma10"])
    ma20 = float(row["ma20"])
    ma30 = float(row["ma30"])
    close = float(row["close"])
    volume_pct_chg = float(row["volume_pct_chg"])
    macd_diff = float(row["dif"])
    macd_dea = float(row["dea"])
    macd = float(row["macd"])

    # Truth value of every supported rule for this row, keyed by rule name.
    outcomes = {
        "ma5>ma10": ma5 > ma10,
        "ma10>ma20": ma10 > ma20,
        "ma20>ma30": ma20 > ma30,
        "close>ma20": close > ma20,
        "volume_pct_chg>0.2": volume_pct_chg > 0.2,
        "macd_diff>0": macd_diff > 0,
        "macd_dea>0": macd_dea > 0,
        "macd>0": macd > 0,
        "ma5<ma10": ma5 < ma10,
        "ma10<ma20": ma10 < ma20,
        "ma20<ma30": ma20 < ma30,
        "close<ma20": close < ma20,
        "macd_diff<0": macd_diff < 0,
        "macd_dea<0": macd_dea < 0,
        "macd<0": macd < 0,
    }
    # Cross rules hold when their name appears in the row's ma_cross text.
    for cross_name in (
        "5上穿10",
        "10上穿20",
        "20上穿30",
        "10下穿5",
        "20下穿10",
        "30下穿20",
    ):
        outcomes[cross_name] = cross_name in ma_cross

    # Unknown names are neutral: True inside "and", False inside "or".
    if all(outcomes.get(name, True) for name in condition_dict.get("and", [])):
        return True
    return any(outcomes.get(name, False) for name in condition_dict.get("or", []))
2025-08-25 08:58:38 +00:00
def draw_quant_pct_chg_bar_chart(
    self, data: pd.DataFrame, strategy_name: str = "全均线策略"
):
    """Draw, per bar period, strategy vs. market change as paired bars.

    One PNG is written to ``self.stats_chart_dir`` for each distinct ``bar``
    value; every symbol gets a strategy bar and a market bar side by side.

    Args:
        data: Per (symbol, bar) summary with the columns ``symbol``, ``bar``,
            ``account_value_chg`` and ``market_pct_chg``.
        strategy_name: Used in the file and sheet names.

    Returns:
        Dict mapping an Excel sheet name to the saved image path, or ``None``
        when ``data`` is ``None`` or empty.
    """
    if data is None or data.empty:
        return None

    sns.set_theme(style="whitegrid")
    plt.rcParams.update(
        {
            "font.sans-serif": ["SimHei"],  # font for Chinese text
            "font.size": 11,
            "axes.unicode_minus": False,  # render the minus sign correctly
        }
    )

    market_label = "市场自然涨跌"

    def annotate(rects, values, text_color):
        """Write each bar's value just beyond its end."""
        for rect, value in zip(rects, values):
            plt.text(
                rect.get_x() + rect.get_width() / 2,
                value + (0.01 if value >= 0 else -0.01),
                f"{value:.3f}%",
                ha="center",
                va="bottom" if value >= 0 else "top",
                fontsize=9,
                fontweight="bold",
                color=text_color,
            )

    chart_dict = {}
    column_name_dict = {"account_value_chg": "量化策略涨跌"}
    for column_name, column_name_text in column_name_dict.items():
        for bar in data["bar"].unique():
            bar_data = data[data["bar"] == bar].copy()
            if bar_data.empty:
                continue
            bar_data.rename(
                columns={column_name: column_name_text, "market_pct_chg": market_label},
                inplace=True,
            )
            # Best performing symbol first.
            bar_data.sort_values(by=column_name_text, ascending=False, inplace=True)
            bar_data.reset_index(drop=True, inplace=True)
            # String labels avoid matplotlib category warnings.
            bar_data["symbol"] = bar_data["symbol"].astype(str)

            plt.figure(figsize=(12, 7))
            positions = np.arange(len(bar_data))
            width = 0.35
            shades = np.linspace(0.6, 0.9, len(bar_data))
            strategy_rects = plt.bar(
                positions - width / 2,
                bar_data[column_name_text],
                width,
                label=column_name_text,
                color=plt.cm.Blues(shades),
            )
            market_rects = plt.bar(
                positions + width / 2,
                bar_data[market_label],
                width,
                label=market_label,
                color=plt.cm.Greens(shades),
            )
            plt.title(
                f"{bar}趋势{column_name_text}与市场自然涨跌对比(%)",
                fontsize=14,
                fontweight="bold",
            )
            plt.xlabel("Symbol", fontsize=12)
            plt.ylabel("涨跌幅(%)", fontsize=12)
            plt.xticks(positions, bar_data["symbol"], rotation=45, ha="right")
            plt.legend()
            plt.grid(True, alpha=0.3)
            annotate(strategy_rects, bar_data[column_name_text], "darkblue")
            annotate(market_rects, bar_data[market_label], "darkgreen")
            plt.tight_layout()

            if self.commission_per_share > 0:
                file_name = f"{bar}_bar_chart_{column_name}_{strategy_name}_with_commission.png"
            else:
                file_name = f"{bar}_bar_chart_{column_name}_{strategy_name}_without_commission.png"
            save_path = os.path.join(self.stats_chart_dir, file_name)
            plt.savefig(save_path, dpi=150)
            plt.close()

            sheet_name = f"{bar}_趋势{column_name_text}柱状图_{strategy_name}"
            chart_dict[sheet_name] = save_path
    return chart_dict
def draw_quant_line_chart(
    self,
    data: pd.DataFrame,
    market_data_pct_chg_df: pd.DataFrame,
    strategy_name: str = "全均线策略",
):
    """Plot normalised account value against normalised close price per symbol/bar.

    For each (symbol, bar) with trades, a synthetic starting point is put in
    front of the trade records (account value = initial capital, close = first
    entry close). Both series are divided by their starting value, and one PNG
    is saved to ``self.stats_chart_dir``.

    Args:
        data: Trade records as produced by ``trade_simulate``.
        market_data_pct_chg_df: Per (symbol, bar) rows providing
            ``initial_capital``.
        strategy_name: Used in the title, file and sheet names.

    Returns:
        Dict mapping an Excel sheet name to the saved image path.
    """
    # begin_<field> of the first trade is copied to end_<field> of the start row.
    mirrored_fields = (
        "timestamp",
        "date_time",
        "close",
        "ma5",
        "ma10",
        "ma20",
        "ma30",
        "macd_diff",
        "macd_dea",
        "macd",
    )
    zeroed_fields = (
        "pct_chg",
        "interval_seconds",
        "interval_minutes",
        "interval_hours",
        "interval_days",
    )

    symbols = data["symbol"].unique()
    bars = data["bar"].unique()
    chart_dict = {}
    for symbol in symbols:
        for bar in bars:
            symbol_bar_data = data[(data["symbol"] == symbol) & (data["bar"] == bar)]
            if symbol_bar_data.empty:
                continue

            first_trade = symbol_bar_data.iloc[0].copy()
            initial_capital = int(
                market_data_pct_chg_df.loc[
                    (market_data_pct_chg_df["symbol"] == symbol)
                    & (market_data_pct_chg_df["bar"] == bar),
                    "initial_capital",
                ].values[0]
            )

            # Synthetic starting point: nothing earned yet, at the first entry.
            start_row = first_trade.copy()
            start_row.loc["profit_loss"] = 0
            start_row.loc["end_account_value"] = initial_capital
            for field in mirrored_fields:
                start_row.loc[f"end_{field}"] = first_trade[f"begin_{field}"]
            for field in zeroed_fields:
                start_row.loc[field] = 0

            symbol_bar_data = pd.concat([pd.DataFrame([start_row]), symbol_bar_data])
            symbol_bar_data.sort_values(
                by="end_timestamp", ascending=True, inplace=True
            )
            symbol_bar_data.reset_index(drop=True, inplace=True)
            # Real datetimes on the x axis avoid matplotlib warnings.
            symbol_bar_data["end_date_time"] = pd.to_datetime(
                symbol_bar_data["end_date_time"]
            )
            # Both series are expressed relative to their starting value.
            symbol_bar_data["end_account_value_to_1"] = (
                symbol_bar_data["end_account_value"] / initial_capital
            ).round(4)
            symbol_bar_data["end_close_to_1"] = (
                symbol_bar_data["end_close"] / start_row["end_close"]
            ).round(4)

            plt.figure(figsize=(12, 7))
            plt.plot(
                symbol_bar_data["end_date_time"],
                symbol_bar_data["end_account_value_to_1"],
                label="量化策略涨跌",
                color="blue",
                linewidth=2,
                marker="o",
                markersize=4,
            )
            plt.plot(
                symbol_bar_data["end_date_time"],
                symbol_bar_data["end_close_to_1"],
                label="市场自然涨跌",
                color="green",
                linewidth=2,
                marker="s",
                markersize=4,
            )
            plt.title(
                f"{symbol} {bar} 量化与市场折线图_{strategy_name}",
                fontsize=14,
                fontweight="bold",
            )
            plt.xlabel("时间", fontsize=12)
            plt.ylabel("涨跌变化", fontsize=12)
            plt.legend(fontsize=11)
            plt.grid(True, alpha=0.3)

            point_count = len(symbol_bar_data)
            if point_count > 30:
                # Thin the tick labels, always keeping the first and last point.
                step = max(1, point_count // 30)
                tick_positions = list(range(0, point_count - 1, step))
                tick_positions.append(point_count - 1)
                tick_times = symbol_bar_data["end_date_time"].iloc[tick_positions]
            else:
                tick_times = symbol_bar_data["end_date_time"]
            plt.xticks(
                tick_times,
                tick_times.dt.strftime("%Y%m%d %H:%M"),
                rotation=45,
                ha="right",
            )
            plt.tight_layout()

            if self.commission_per_share > 0:
                file_name = f"{symbol}_{bar}_line_chart_{strategy_name}_with_commission.png"
            else:
                file_name = f"{symbol}_{bar}_line_chart_{strategy_name}_without_commission.png"
            save_path = os.path.join(self.stats_chart_dir, file_name)
            plt.savefig(save_path, dpi=150, bbox_inches="tight")
            plt.close()

            sheet_name = f"{symbol}_{bar}_折线图_{strategy_name}"
            chart_dict[sheet_name] = save_path
    return chart_dict
def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict):
    """Embed chart images into an existing Excel workbook, one sheet per chart.

    Args:
        excel_file_path: Path of a workbook that already exists; it is
            loaded, extended and saved back to the same path.
        charts_dict: Flat mapping ``{sheet_name: image_path}``. ``None`` or an
            empty dict leaves the workbook untouched.

    A chart that cannot be embedded is logged and skipped; the remaining
    charts are still written.
    """
    if not charts_dict:
        # Nothing to embed: skip the load/save round trip entirely.
        return

    logger.info(f"将图表输出到{excel_file_path}")
    wb = openpyxl.load_workbook(excel_file_path)
    for sheet_name, chart_path in charts_dict.items():
        try:
            # Load the image first so a bad path does not leave an empty
            # sheet behind in the workbook.
            img = Image(chart_path)
            ws = wb.create_sheet(title=sheet_name)
            ws.add_image(img, "A1")
        except Exception as e:
            logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
            continue
    wb.save(excel_file_path)
    logger.info(f"Chart saved as {excel_file_path}")