1099 lines
49 KiB
Python
1099 lines
49 KiB
Python
import core.logger as logging
|
||
import os
|
||
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
from datetime import datetime, timedelta, timezone
|
||
from core.utils import get_current_date_time
|
||
import re
|
||
import json
|
||
import math
|
||
from openpyxl import Workbook
|
||
from openpyxl.drawing.image import Image
|
||
import openpyxl
|
||
from openpyxl.styles import Font
|
||
from PIL import Image as PILImage
|
||
from config import (
|
||
OKX_MONITOR_CONFIG,
|
||
US_STOCK_MONITOR_CONFIG,
|
||
MYSQL_CONFIG,
|
||
WINDOW_SIZE,
|
||
BINANCE_MONITOR_CONFIG,
|
||
)
|
||
from core.db.db_market_data import DBMarketData
|
||
from core.db.db_binance_data import DBBinanceData
|
||
from core.db.db_huge_volume_data import DBHugeVolumeData
|
||
from core.utils import timestamp_to_datetime, transform_date_time_to_timestamp
|
||
|
||
# seaborn支持中文
|
||
plt.rcParams["font.family"] = ["SimHei"]
|
||
|
||
logger = logging.logger
|
||
|
||
|
||
class MaBreakStatistics:
|
||
"""
|
||
统计MA突破之后的涨跌幅
|
||
MA向上突破的点位周期K线:5 > 10 > 20 > 30
|
||
统计MA向上突破的点位周期K线,突破之后,到:
|
||
下一个MA向下突破的点位周期K线:30 > 20 > 10 > 5
|
||
之间的涨跌幅
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
is_us_stock: bool = False,
|
||
is_binance: bool = False,
|
||
commission_per_share: float = 0.0008,
|
||
):
|
||
mysql_user = MYSQL_CONFIG.get("user", "xch")
|
||
mysql_password = MYSQL_CONFIG.get("password", "")
|
||
if not mysql_password:
|
||
raise ValueError("MySQL password is not set")
|
||
mysql_host = MYSQL_CONFIG.get("host", "localhost")
|
||
mysql_port = MYSQL_CONFIG.get("port", 3306)
|
||
mysql_database = MYSQL_CONFIG.get("database", "okx")
|
||
|
||
self.db_url = f"mysql+pymysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}"
|
||
|
||
self.db_huge_volume_data = DBHugeVolumeData(self.db_url)
|
||
self.is_us_stock = is_us_stock
|
||
self.is_binance = is_binance
|
||
if is_us_stock:
|
||
self.symbols = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
|
||
"symbols", ["QQQ"]
|
||
)
|
||
self.bars = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
|
||
"bars", ["5m"]
|
||
)
|
||
self.initial_date = US_STOCK_MONITOR_CONFIG.get("volume_monitor", {}).get(
|
||
"initial_date", "2014-11-30 00:00:00"
|
||
)
|
||
self.db_market_data = DBMarketData(self.db_url)
|
||
else:
|
||
if is_binance:
|
||
self.symbols = BINANCE_MONITOR_CONFIG.get("volume_monitor", {}).get(
|
||
"symbols", ["BTC-USDT"]
|
||
)
|
||
self.bars = ["30m", "1H"]
|
||
self.initial_date = BINANCE_MONITOR_CONFIG.get(
|
||
"volume_monitor", {}
|
||
).get("initial_date", "2017-08-16 00:00:00")
|
||
self.db_market_data = DBBinanceData(self.db_url)
|
||
else:
|
||
self.symbols = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
|
||
"symbols", ["XCH-USDT"]
|
||
)
|
||
self.bars = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
|
||
"bars", ["5m", "15m", "30m", "1H"]
|
||
)
|
||
self.initial_date = OKX_MONITOR_CONFIG.get("volume_monitor", {}).get(
|
||
"initial_date", "2025-05-15 00:00:00"
|
||
)
|
||
self.db_market_data = DBMarketData(self.db_url)
|
||
if len(self.initial_date) > 10:
|
||
self.initial_date = self.initial_date[:10]
|
||
self.end_date = get_current_date_time(format="%Y-%m-%d")
|
||
self.commission_per_share = commission_per_share
|
||
self.trade_strategy_config = self.get_trade_strategy_config()
|
||
self.main_strategy = self.trade_strategy_config.get("均线系统策略", None)
|
||
|
||
def get_trade_strategy_config(self):
|
||
with open("./json/trade_strategy.json", "r", encoding="utf-8") as f:
|
||
trade_strategy_config = json.load(f)
|
||
return trade_strategy_config
|
||
|
||
def batch_statistics(self, strategy_name: str = "全均线策略"):
|
||
if self.is_us_stock:
|
||
self.stats_output_dir = (
|
||
f"./output/trade_sandbox/ma_strategy/us_stock/excel/{strategy_name}/"
|
||
)
|
||
self.stats_chart_dir = (
|
||
f"./output/trade_sandbox/ma_strategy/us_stock/chart/{strategy_name}/"
|
||
)
|
||
elif self.is_binance:
|
||
self.stats_output_dir = (
|
||
f"./output/trade_sandbox/ma_strategy/binance/excel/{strategy_name}/"
|
||
)
|
||
self.stats_chart_dir = (
|
||
f"./output/trade_sandbox/ma_strategy/binance/chart/{strategy_name}/"
|
||
)
|
||
else:
|
||
self.stats_output_dir = (
|
||
f"./output/trade_sandbox/ma_strategy/okx/excel/{strategy_name}/"
|
||
)
|
||
self.stats_chart_dir = (
|
||
f"./output/trade_sandbox/ma_strategy/okx/chart/{strategy_name}/"
|
||
)
|
||
os.makedirs(self.stats_output_dir, exist_ok=True)
|
||
os.makedirs(self.stats_chart_dir, exist_ok=True)
|
||
|
||
ma_break_market_data_list = []
|
||
market_data_pct_chg_list = []
|
||
if strategy_name not in self.main_strategy.keys() or strategy_name is None:
|
||
strategy_name = "全均线策略"
|
||
for symbol in self.symbols:
|
||
for bar in self.bars:
|
||
logger.info(
|
||
f"开始计算{symbol} {bar}的MA突破区间涨跌幅统计, 策略: {strategy_name},交易费率:{self.commission_per_share}"
|
||
)
|
||
ma_break_market_data, market_data_pct_chg = self.trade_simulate(
|
||
symbol, bar, strategy_name
|
||
)
|
||
if (
|
||
ma_break_market_data is not None
|
||
and len(ma_break_market_data) > 0
|
||
and market_data_pct_chg is not None
|
||
):
|
||
ma_break_market_data_list.append(ma_break_market_data)
|
||
logger.info(
|
||
f"{symbol} {bar} 的市场价格变化, {market_data_pct_chg.get('pct_chg', 0)}%"
|
||
)
|
||
market_data_pct_chg_list.append(market_data_pct_chg)
|
||
|
||
if len(ma_break_market_data_list) > 0:
|
||
ma_break_market_data = pd.concat(ma_break_market_data_list)
|
||
market_data_pct_chg_df = pd.DataFrame(market_data_pct_chg_list)
|
||
# 依据symbol和bar分组,统计每个symbol和bar的pct_chg的max, min, mean, std, median, count
|
||
ma_break_market_data.sort_values(
|
||
by="begin_timestamp", ascending=True, inplace=True
|
||
)
|
||
ma_break_market_data.reset_index(drop=True, inplace=True)
|
||
|
||
account_value_chg_list = []
|
||
for symbol in market_data_pct_chg_df["symbol"].unique():
|
||
for bar in market_data_pct_chg_df["bar"].unique():
|
||
symbol_bar_data = ma_break_market_data[
|
||
(ma_break_market_data["symbol"] == symbol)
|
||
& (ma_break_market_data["bar"] == bar)
|
||
].copy() # 创建副本避免SettingWithCopyWarning
|
||
if len(symbol_bar_data) > 0:
|
||
symbol_bar_data.sort_values(
|
||
by="end_timestamp", ascending=True, inplace=True
|
||
)
|
||
symbol_bar_data.reset_index(drop=True, inplace=True)
|
||
initial_capital = int(market_data_pct_chg_df.loc[
|
||
(market_data_pct_chg_df["symbol"] == symbol)
|
||
& (market_data_pct_chg_df["bar"] == bar),
|
||
"initial_capital",
|
||
].values[0])
|
||
final_account_value = float(symbol_bar_data["end_account_value"].iloc[-1])
|
||
account_value_chg = (final_account_value - initial_capital) / initial_capital * 100
|
||
account_value_chg = round(account_value_chg, 4)
|
||
market_pct_chg = market_data_pct_chg_df.loc[
|
||
(market_data_pct_chg_df["symbol"] == symbol)
|
||
& (market_data_pct_chg_df["bar"] == bar),
|
||
"pct_chg",
|
||
].values[0]
|
||
total_buy_commission = float(symbol_bar_data["buy_commission"].sum())
|
||
total_sell_commission = float(symbol_bar_data["sell_commission"].sum())
|
||
total_commission = total_buy_commission + total_sell_commission
|
||
total_commission = round(total_commission, 4)
|
||
total_buy_commission = round(total_buy_commission, 4)
|
||
total_sell_commission = round(total_sell_commission, 4)
|
||
account_value_chg_list.append({
|
||
"strategy_name": strategy_name,
|
||
"symbol": symbol,
|
||
"bar": bar,
|
||
"total_buy_commission": total_buy_commission,
|
||
"total_sell_commission": total_sell_commission,
|
||
"total_commission": total_commission,
|
||
"initial_account_value": initial_capital,
|
||
"final_account_value": final_account_value,
|
||
"account_value_chg": account_value_chg,
|
||
"market_pct_chg": market_pct_chg,
|
||
})
|
||
account_value_chg_df = pd.DataFrame(account_value_chg_list)
|
||
account_value_chg_df = account_value_chg_df[
|
||
[
|
||
"strategy_name",
|
||
"symbol",
|
||
"bar",
|
||
"total_buy_commission",
|
||
"total_sell_commission",
|
||
"total_commission",
|
||
"initial_account_value",
|
||
"final_account_value",
|
||
"account_value_chg",
|
||
"market_pct_chg",
|
||
]
|
||
]
|
||
|
||
account_value_statistics_df = (
|
||
ma_break_market_data.groupby(["symbol", "bar"])["end_account_value"]
|
||
.agg(
|
||
account_value_max="max",
|
||
account_value_min="min",
|
||
account_value_mean="mean",
|
||
account_value_std="std",
|
||
account_value_median="median",
|
||
account_value_count="count",
|
||
)
|
||
.reset_index()
|
||
)
|
||
account_value_statistics_df["strategy_name"] = strategy_name
|
||
account_value_statistics_df = account_value_statistics_df[
|
||
[
|
||
"strategy_name",
|
||
"symbol",
|
||
"bar",
|
||
"account_value_max",
|
||
"account_value_min",
|
||
"account_value_mean",
|
||
"account_value_std",
|
||
"account_value_median",
|
||
"account_value_count",
|
||
]
|
||
]
|
||
|
||
# 依据symbol和bar分组,统计每个symbol和bar的interval_minutes的max, min, mean, std, median, count
|
||
interval_minutes_df = (
|
||
ma_break_market_data.groupby(["symbol", "bar"])["interval_minutes"]
|
||
.agg(
|
||
interval_minutes_max="max",
|
||
interval_minutes_min="min",
|
||
interval_minutes_mean="mean",
|
||
interval_minutes_std="std",
|
||
interval_minutes_median="median",
|
||
interval_minutes_count="count",
|
||
)
|
||
.reset_index()
|
||
)
|
||
interval_minutes_df["strategy_name"] = strategy_name
|
||
interval_minutes_df = interval_minutes_df[
|
||
[
|
||
"strategy_name",
|
||
"symbol",
|
||
"bar",
|
||
"interval_minutes_max",
|
||
"interval_minutes_min",
|
||
"interval_minutes_mean",
|
||
"interval_minutes_std",
|
||
"interval_minutes_median",
|
||
"interval_minutes_count",
|
||
]
|
||
]
|
||
|
||
earliest_market_date_time = ma_break_market_data["begin_date_time"].min()
|
||
earliest_market_date_time = re.sub(
|
||
r"[\:\-\s]", "", str(earliest_market_date_time)
|
||
)
|
||
latest_market_date_time = ma_break_market_data["end_date_time"].max()
|
||
if latest_market_date_time is None:
|
||
latest_market_date_time = get_current_date_time(format="%Y%m%d%H%M%S")
|
||
latest_market_date_time = re.sub(
|
||
r"[\:\-\s]", "", str(latest_market_date_time)
|
||
)
|
||
if self.commission_per_share > 0:
|
||
output_file_name = f"ma_break_stats_from_{earliest_market_date_time}_to_{latest_market_date_time}_{strategy_name}_with_commission.xlsx"
|
||
else:
|
||
output_file_name = f"ma_break_stats_from_{earliest_market_date_time}_to_{latest_market_date_time}_{strategy_name}_without_commission.xlsx"
|
||
output_file_path = os.path.join(self.stats_output_dir, output_file_name)
|
||
logger.info(f"导出{output_file_path}")
|
||
strategy_info_df = self.get_strategy_info(strategy_name)
|
||
with pd.ExcelWriter(output_file_path) as writer:
|
||
strategy_info_df.to_excel(writer, sheet_name="策略信息", index=False)
|
||
ma_break_market_data.to_excel(
|
||
writer, sheet_name="买卖记录明细", index=False
|
||
)
|
||
account_value_chg_df.to_excel(writer, sheet_name="资产价值变化", index=False)
|
||
account_value_statistics_df.to_excel(
|
||
writer, sheet_name="买卖账户价值统计", index=False
|
||
)
|
||
interval_minutes_df.to_excel(
|
||
writer, sheet_name="买卖时间间隔统计", index=False
|
||
)
|
||
|
||
chart_dict = self.draw_quant_pct_chg_bar_chart(account_value_chg_df, strategy_name)
|
||
self.output_chart_to_excel(output_file_path, chart_dict)
|
||
chart_dict = self.draw_quant_line_chart(
|
||
ma_break_market_data, market_data_pct_chg_df, strategy_name
|
||
)
|
||
self.output_chart_to_excel(output_file_path, chart_dict)
|
||
return account_value_chg_df
|
||
else:
|
||
return None
|
||
|
||
def get_strategy_info(self, strategy_name: str = "全均线策略"):
|
||
strategy_config = self.main_strategy.get(strategy_name, None)
|
||
if strategy_config is None:
|
||
logger.error(f"策略{strategy_name}不存在")
|
||
return None
|
||
strategy_info = {"策略名称": strategy_name, "买入策略": "", "卖出策略": ""}
|
||
buy_dict = strategy_config.get("buy", {})
|
||
buy_and_list = buy_dict.get("and", [])
|
||
buy_or_list = buy_dict.get("or", [])
|
||
buy_and_text = ""
|
||
buy_or_text = ""
|
||
for and_condition in buy_and_list:
|
||
buy_and_text += f"{and_condition}, \n"
|
||
if len(buy_or_list) > 0:
|
||
for or_condition in buy_or_list:
|
||
buy_or_text += f"{or_condition}, \n"
|
||
if len(buy_or_text) > 0:
|
||
strategy_info["买入策略"] = buy_and_text + " 或者 \n" + buy_or_text
|
||
else:
|
||
strategy_info["买入策略"] = buy_and_text
|
||
sell_dict = strategy_config.get("sell", {})
|
||
sell_and_list = sell_dict.get("and", [])
|
||
sell_or_list = sell_dict.get("or", [])
|
||
sell_and_text = ""
|
||
sell_or_text = ""
|
||
for and_condition in sell_and_list:
|
||
sell_and_text += f"{and_condition}, \n"
|
||
if len(sell_or_list) > 0:
|
||
for or_condition in sell_or_list:
|
||
sell_or_text += f"{or_condition}, \n"
|
||
if len(sell_or_text) > 0:
|
||
strategy_info["卖出策略"] = sell_and_text + " 或者 \n" + sell_or_text
|
||
else:
|
||
strategy_info["卖出策略"] = sell_and_text
|
||
# 将strategy_info转换为pd.DataFrame
|
||
strategy_info_df = pd.DataFrame([strategy_info])
|
||
return strategy_info_df
|
||
|
||
def trade_simulate(self, symbol: str, bar: str, strategy_name: str = "全均线策略"):
|
||
market_data = self.get_full_data(symbol, bar)
|
||
if market_data is None or len(market_data) == 0:
|
||
logger.warning(f"获取{symbol} {bar} 数据失败")
|
||
return None, None
|
||
else:
|
||
logger.info(f"获取{symbol} {bar} 数据成功,数据条数: {len(market_data)}")
|
||
# 获得ma5, ma10, ma20, ma30不为空的行
|
||
market_data = market_data[
|
||
(market_data["ma5"].notna())
|
||
& (market_data["ma10"].notna())
|
||
& (market_data["ma20"].notna())
|
||
& (market_data["ma30"].notna())
|
||
]
|
||
logger.info(
|
||
f"ma5, ma10, ma20, ma30不为空的行,数据条数: {len(market_data)}"
|
||
)
|
||
# 计算volume_ma5
|
||
market_data["volume_ma5"] = market_data["volume"].rolling(window=5).mean()
|
||
|
||
market_data["volume_pct_chg"] = (
|
||
market_data["volume"] - market_data["volume_ma5"]
|
||
) / market_data["volume_ma5"]
|
||
market_data["volume_pct_chg"] = market_data["volume_pct_chg"].fillna(0)
|
||
|
||
# 按照timestamp排序
|
||
market_data = market_data.sort_values(by="timestamp", ascending=True)
|
||
# 获得ma_break_market_data的close列
|
||
market_data.reset_index(drop=True, inplace=True)
|
||
ma_break_market_data_pair_list = []
|
||
ma_break_market_data_pair = {}
|
||
if self.is_us_stock:
|
||
date_time_field = "date_time_us"
|
||
else:
|
||
date_time_field = "date_time"
|
||
close_mean = market_data["close"].mean()
|
||
self.update_initial_capital(close_mean)
|
||
logger.info(
|
||
f"成功获取{symbol}数据:{len(market_data)}根{bar}K线,开始日期={market_data[date_time_field].min()},结束日期={market_data[date_time_field].max()}"
|
||
)
|
||
|
||
account_value = self.initial_capital
|
||
|
||
for index, row in market_data.iterrows():
|
||
ma_cross = row["ma_cross"]
|
||
timestamp = row["timestamp"]
|
||
date_time = row[date_time_field]
|
||
close = row["close"]
|
||
ma5 = row["ma5"]
|
||
ma10 = row["ma10"]
|
||
ma20 = row["ma20"]
|
||
ma30 = row["ma30"]
|
||
macd_diff = float(row["dif"])
|
||
macd_dea = float(row["dea"])
|
||
macd = float(row["macd"])
|
||
if ma_break_market_data_pair.get("begin_timestamp", None) is None:
|
||
buy_condition = self.fit_strategy(
|
||
strategy_name=strategy_name,
|
||
market_data=market_data,
|
||
row=row,
|
||
behavior="buy",
|
||
)
|
||
|
||
if buy_condition:
|
||
entry_price = close
|
||
# 计算交易股数
|
||
shares, account_value = self.calculate_shares(
|
||
account_value, entry_price
|
||
)
|
||
if shares == 0:
|
||
# 股数为0→不交易
|
||
continue
|
||
|
||
# 计算佣金
|
||
buy_commission = shares * close * self.commission_per_share
|
||
|
||
ma_break_market_data_pair = {}
|
||
ma_break_market_data_pair["symbol"] = symbol
|
||
ma_break_market_data_pair["bar"] = bar
|
||
ma_break_market_data_pair["begin_timestamp"] = timestamp
|
||
ma_break_market_data_pair["begin_date_time"] = date_time
|
||
ma_break_market_data_pair["begin_close"] = close
|
||
ma_break_market_data_pair["begin_ma5"] = ma5
|
||
ma_break_market_data_pair["begin_ma10"] = ma10
|
||
ma_break_market_data_pair["begin_ma20"] = ma20
|
||
ma_break_market_data_pair["begin_ma30"] = ma30
|
||
ma_break_market_data_pair["begin_macd_diff"] = macd_diff
|
||
ma_break_market_data_pair["begin_macd_dea"] = macd_dea
|
||
ma_break_market_data_pair["begin_macd"] = macd
|
||
ma_break_market_data_pair["shares"] = shares
|
||
ma_break_market_data_pair["buy_commission"] = buy_commission
|
||
ma_break_market_data_pair["begin_account_value"] = account_value
|
||
continue
|
||
else:
|
||
sell_condition = self.fit_strategy(
|
||
strategy_name=strategy_name,
|
||
market_data=market_data,
|
||
row=row,
|
||
behavior="sell",
|
||
)
|
||
|
||
if sell_condition:
|
||
shares = ma_break_market_data_pair["shares"]
|
||
entry_price = ma_break_market_data_pair["begin_close"]
|
||
exit_price = close
|
||
sell_commission = (
|
||
shares * exit_price * self.commission_per_share
|
||
)
|
||
profit_loss = (exit_price - entry_price) * shares
|
||
begin_account_value = ma_break_market_data_pair[
|
||
"begin_account_value"
|
||
]
|
||
account_value = (
|
||
begin_account_value + profit_loss - sell_commission
|
||
)
|
||
|
||
ma_break_market_data_pair["end_timestamp"] = timestamp
|
||
ma_break_market_data_pair["end_date_time"] = date_time
|
||
ma_break_market_data_pair["end_close"] = close
|
||
ma_break_market_data_pair["end_ma5"] = ma5
|
||
ma_break_market_data_pair["end_ma10"] = ma10
|
||
ma_break_market_data_pair["end_ma20"] = ma20
|
||
ma_break_market_data_pair["end_ma30"] = ma30
|
||
ma_break_market_data_pair["end_macd_diff"] = macd_diff
|
||
ma_break_market_data_pair["end_macd_dea"] = macd_dea
|
||
ma_break_market_data_pair["end_macd"] = macd
|
||
ma_break_market_data_pair["pct_chg"] = (
|
||
exit_price - entry_price
|
||
) / entry_price
|
||
ma_break_market_data_pair["pct_chg"] = round(
|
||
ma_break_market_data_pair["pct_chg"] * 100, 4
|
||
)
|
||
ma_break_market_data_pair["profit_loss"] = profit_loss
|
||
ma_break_market_data_pair["sell_commission"] = sell_commission
|
||
ma_break_market_data_pair["end_account_value"] = account_value
|
||
ma_break_market_data_pair["interval_seconds"] = (
|
||
timestamp - ma_break_market_data_pair["begin_timestamp"]
|
||
) / 1000
|
||
# 将interval转换为分钟
|
||
ma_break_market_data_pair["interval_minutes"] = (
|
||
ma_break_market_data_pair["interval_seconds"] / 60
|
||
)
|
||
ma_break_market_data_pair["interval_hours"] = (
|
||
ma_break_market_data_pair["interval_seconds"] / 3600
|
||
)
|
||
ma_break_market_data_pair["interval_days"] = (
|
||
ma_break_market_data_pair["interval_seconds"] / 86400
|
||
)
|
||
ma_break_market_data_pair_list.append(ma_break_market_data_pair)
|
||
ma_break_market_data_pair = {}
|
||
|
||
if len(ma_break_market_data_pair_list) > 0:
|
||
ma_break_market_data = pd.DataFrame(ma_break_market_data_pair_list)
|
||
# sort by end_timestamp
|
||
ma_break_market_data.sort_values(
|
||
by="begin_timestamp", ascending=True, inplace=True
|
||
)
|
||
ma_break_market_data.reset_index(drop=True, inplace=True)
|
||
logger.info(
|
||
f"获取{symbol} {bar} 的买卖记录明细成功, 买卖次数: {len(ma_break_market_data)}"
|
||
)
|
||
# 量化期间,市场的波动率:
|
||
# ma_break_market_data(最后一条数据的end_close - 第一条数据的begin_close) / 第一条数据的begin_close * 100
|
||
pct_chg = (
|
||
(
|
||
ma_break_market_data["end_close"].iloc[-1]
|
||
- ma_break_market_data["begin_close"].iloc[0]
|
||
)
|
||
/ ma_break_market_data["begin_close"].iloc[0]
|
||
* 100
|
||
)
|
||
pct_chg = round(pct_chg, 4)
|
||
market_data_pct_chg = {
|
||
"symbol": symbol,
|
||
"bar": bar,
|
||
"pct_chg": pct_chg,
|
||
"initial_capital": self.initial_capital,
|
||
}
|
||
return ma_break_market_data, market_data_pct_chg
|
||
else:
|
||
return None, None
|
||
|
||
def update_initial_capital(self, close_mean: float):
|
||
self.initial_capital = 25000
|
||
if close_mean > 10000:
|
||
self.initial_capital = self.initial_capital * 10000
|
||
elif close_mean > 5000:
|
||
self.initial_capital = self.initial_capital * 5000
|
||
elif close_mean > 1000:
|
||
self.initial_capital = self.initial_capital * 1000
|
||
elif close_mean > 500:
|
||
self.initial_capital = self.initial_capital * 500
|
||
elif close_mean > 100:
|
||
self.initial_capital = self.initial_capital * 100
|
||
else:
|
||
pass
|
||
logger.info(f"收盘价均值:{close_mean}")
|
||
logger.info(f"初始资金调整为:{self.initial_capital}")
|
||
|
||
def calculate_shares(self, account_value, entry_price):
|
||
"""
|
||
根据ORB公式计算交易股数
|
||
:param account_value: 当前账户价值(美元)
|
||
:param entry_price: 交易买入价格
|
||
:param commission_per_share: 交易佣金, 默认为0.0008
|
||
:return: 整数股数(Shares)
|
||
"""
|
||
logger.info(
|
||
f"开始计算交易股数:账户价值={account_value},买入价格={entry_price}"
|
||
)
|
||
try:
|
||
# 验证输入参数
|
||
if account_value <= 0 or entry_price <= 0 or self.commission_per_share < 0:
|
||
logger.error("账户价值、买入价格或佣金不能为负或零")
|
||
return 0, account_value # 返回0股,账户价值不变
|
||
|
||
# 计算考虑手续费后的每单位BTC总成本
|
||
total_cost_per_share = entry_price * (1 + self.commission_per_share)
|
||
|
||
# 计算可购买的BTC数量(向下取整)
|
||
shares = math.floor(account_value / total_cost_per_share)
|
||
|
||
# 计算总成本(含手续费)
|
||
total_cost = shares * total_cost_per_share
|
||
|
||
# 计算剩余现金
|
||
remaining_cash = account_value - total_cost
|
||
|
||
# 计算总资产价值 = (购买的BTC数量 × 买入价格) + 剩余现金
|
||
remaining_value = (shares * entry_price) + remaining_cash
|
||
|
||
# 记录计算结果
|
||
logger.info(
|
||
f"计算结果:可购买股数={shares},总成本={total_cost:.2f}美元,"
|
||
f"剩余现金={remaining_cash:.2f}美元,总资产价值={remaining_value:.2f}美元"
|
||
)
|
||
|
||
return shares, remaining_value
|
||
|
||
except Exception as e:
|
||
logger.error(f"计算股数或账户价值时出错:{str(e)}")
|
||
return 0, account_value # 出错时返回0股,账户价值不变
|
||
|
||
def get_full_data(self, symbol: str, bar: str = "5m"):
|
||
"""
|
||
分段获取数据,并将数据合并为完整数据
|
||
分段依据:如果end_date与start_date相差超过一年,则每次取一年数据
|
||
"""
|
||
data = pd.DataFrame()
|
||
start_date = datetime.strptime(self.initial_date, "%Y-%m-%d")
|
||
end_date = datetime.strptime(self.end_date, "%Y-%m-%d") + timedelta(days=1)
|
||
fields = [
|
||
"symbol",
|
||
"bar",
|
||
"timestamp",
|
||
"date_time",
|
||
"date_time_us",
|
||
"open",
|
||
"high",
|
||
"low",
|
||
"close",
|
||
"volume",
|
||
"sar_signal",
|
||
"ma5",
|
||
"ma10",
|
||
"ma20",
|
||
"ma30",
|
||
"ma_cross",
|
||
"dif",
|
||
"dea",
|
||
"macd",
|
||
]
|
||
while start_date < end_date:
|
||
current_end_date = min(start_date + timedelta(days=180), end_date)
|
||
start_date_str = start_date.strftime("%Y-%m-%d")
|
||
current_end_date_str = current_end_date.strftime("%Y-%m-%d")
|
||
logger.info(f"获取{symbol}数据:{start_date_str}至{current_end_date_str}")
|
||
current_data = self.db_market_data.query_market_data_by_symbol_bar(
|
||
symbol, bar, fields, start=start_date_str, end=current_end_date_str
|
||
)
|
||
if current_data is not None and len(current_data) > 0:
|
||
current_data = pd.DataFrame(current_data)
|
||
data = pd.concat([data, current_data])
|
||
start_date = current_end_date
|
||
data.drop_duplicates(inplace=True)
|
||
if self.is_us_stock:
|
||
date_time_field = "date_time_us"
|
||
else:
|
||
date_time_field = "date_time"
|
||
data.sort_values(by=date_time_field, inplace=True)
|
||
data.reset_index(drop=True, inplace=True)
|
||
return data
|
||
|
||
def fit_strategy(
|
||
self,
|
||
strategy_name: str = "全均线策略",
|
||
market_data: pd.DataFrame = None,
|
||
row: pd.Series = None,
|
||
behavior: str = "buy",
|
||
):
|
||
strategy_config = self.main_strategy.get(strategy_name, None)
|
||
if strategy_config is None:
|
||
logger.error(f"策略{strategy_name}不存在")
|
||
return False
|
||
condition_dict = strategy_config.get(behavior, None)
|
||
if condition_dict is None:
|
||
logger.error(f"策略{strategy_name}的{behavior}条件不存在")
|
||
return False
|
||
ma_cross = row["ma_cross"]
|
||
if pd.isna(ma_cross) or ma_cross is None:
|
||
ma_cross = ""
|
||
ma_cross = str(ma_cross)
|
||
ma5 = float(row["ma5"])
|
||
ma10 = float(row["ma10"])
|
||
ma20 = float(row["ma20"])
|
||
ma30 = float(row["ma30"])
|
||
close = float(row["close"])
|
||
volume_pct_chg = float(row["volume_pct_chg"])
|
||
macd_diff = float(row["dif"])
|
||
macd_dea = float(row["dea"])
|
||
macd = float(row["macd"])
|
||
|
||
and_list = condition_dict.get("and", [])
|
||
|
||
condition = True
|
||
for and_condition in and_list:
|
||
if and_condition == "5上穿10":
|
||
condition = condition and ("5上穿10" in ma_cross)
|
||
elif and_condition == "10上穿20":
|
||
condition = condition and ("10上穿20" in ma_cross)
|
||
elif and_condition == "20上穿30":
|
||
condition = condition and ("20上穿30" in ma_cross)
|
||
elif and_condition == "ma5>ma10":
|
||
condition = condition and (ma5 > ma10)
|
||
elif and_condition == "ma10>ma20":
|
||
condition = condition and (ma10 > ma20)
|
||
elif and_condition == "ma20>ma30":
|
||
condition = condition and (ma20 > ma30)
|
||
elif and_condition == "close>ma20":
|
||
condition = condition and (close > ma20)
|
||
elif and_condition == "volume_pct_chg>0.2":
|
||
condition = condition and (volume_pct_chg > 0.2)
|
||
elif and_condition == "macd_diff>0":
|
||
condition = condition and (macd_diff > 0)
|
||
elif and_condition == "macd_dea>0":
|
||
condition = condition and (macd_dea > 0)
|
||
elif and_condition == "macd>0":
|
||
condition = condition and (macd > 0)
|
||
elif and_condition == "10下穿5":
|
||
condition = condition and ("10下穿5" in ma_cross)
|
||
elif and_condition == "20下穿10":
|
||
condition = condition and ("20下穿10" in ma_cross)
|
||
elif and_condition == "30下穿20":
|
||
condition = condition and ("30下穿20" in ma_cross)
|
||
elif and_condition == "ma5<ma10":
|
||
condition = condition and (ma5 < ma10)
|
||
elif and_condition == "ma10<ma20":
|
||
condition = condition and (ma10 < ma20)
|
||
elif and_condition == "ma20<ma30":
|
||
condition = condition and (ma20 < ma30)
|
||
elif and_condition == "close<ma20":
|
||
condition = condition and (close < ma20)
|
||
elif and_condition == "macd_diff<0":
|
||
condition = condition and (macd_diff < 0)
|
||
elif and_condition == "macd_dea<0":
|
||
condition = condition and (macd_dea < 0)
|
||
elif and_condition == "macd<0":
|
||
condition = condition and (macd < 0)
|
||
else:
|
||
pass
|
||
if not condition:
|
||
or_list = condition_dict.get("or", [])
|
||
for or_condition in or_list:
|
||
if or_condition == "5上穿10":
|
||
condition = condition or ("5上穿10" in ma_cross)
|
||
elif or_condition == "10上穿20":
|
||
condition = condition or ("10上穿20" in ma_cross)
|
||
elif or_condition == "20上穿30":
|
||
condition = condition or ("20上穿30" in ma_cross)
|
||
elif or_condition == "ma5>ma10":
|
||
condition = condition or (ma5 > ma10)
|
||
elif or_condition == "ma10>ma20":
|
||
condition = condition or (ma10 > ma20)
|
||
elif or_condition == "ma20>ma30":
|
||
condition = condition or (ma20 > ma30)
|
||
elif or_condition == "close>ma20":
|
||
condition = condition or (close > ma20)
|
||
elif or_condition == "volume_pct_chg>0.2":
|
||
condition = condition or (volume_pct_chg > 0.2)
|
||
elif or_condition == "macd_diff>0":
|
||
condition = condition or (macd_diff > 0)
|
||
elif or_condition == "macd_dea>0":
|
||
condition = condition or (macd_dea > 0)
|
||
elif or_condition == "macd>0":
|
||
condition = condition or (macd > 0)
|
||
elif or_condition == "10下穿5":
|
||
condition = condition or ("10下穿5" in ma_cross)
|
||
elif or_condition == "20下穿10":
|
||
condition = condition or ("20下穿10" in ma_cross)
|
||
elif or_condition == "30下穿20":
|
||
condition = condition or ("30下穿20" in ma_cross)
|
||
elif or_condition == "ma5<ma10":
|
||
condition = condition or (ma5 < ma10)
|
||
elif or_condition == "ma10<ma20":
|
||
condition = condition or (ma10 < ma20)
|
||
elif or_condition == "ma20<ma30":
|
||
condition = condition or (ma20 < ma30)
|
||
elif or_condition == "close<ma20":
|
||
condition = condition or (close < ma20)
|
||
elif or_condition == "macd_diff<0":
|
||
condition = condition or (macd_diff < 0)
|
||
elif or_condition == "macd_dea<0":
|
||
condition = condition or (macd_dea < 0)
|
||
elif or_condition == "macd<0":
|
||
condition = condition or (macd < 0)
|
||
else:
|
||
pass
|
||
return condition
|
||
|
||
def draw_quant_pct_chg_bar_chart(
|
||
self, data: pd.DataFrame, strategy_name: str = "全均线策略"
|
||
):
|
||
"""
|
||
绘制pct_chg mean的柱状图表(美观,保存到self.stats_chart_dir)
|
||
:param data: 波段pct_chg_mean的数据
|
||
:return: None
|
||
"""
|
||
if data is None or data.empty:
|
||
return None
|
||
# seaborn风格设置
|
||
sns.set_theme(style="whitegrid")
|
||
plt.rcParams["font.sans-serif"] = ["SimHei"] # 也可直接用字体名
|
||
plt.rcParams["font.size"] = 11 # 设置字体大小
|
||
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
|
||
chart_dict = {}
|
||
column_name_dict = {
|
||
"account_value_chg": "量化策略涨跌",
|
||
}
|
||
for column_name, column_name_text in column_name_dict.items():
|
||
for bar in data["bar"].unique():
|
||
bar_data = data[data["bar"] == bar].copy() # 一次筛选即可
|
||
if bar_data.empty:
|
||
continue
|
||
bar_data.rename(columns={column_name: column_name_text}, inplace=True)
|
||
bar_data.rename(
|
||
columns={"market_pct_chg": "市场自然涨跌"}, inplace=True
|
||
)
|
||
# 可选:按均值排序
|
||
bar_data.sort_values(by=column_name_text, ascending=False, inplace=True)
|
||
bar_data.reset_index(drop=True, inplace=True)
|
||
|
||
# 如果column_name_text是"量化策略涨跌",则柱状图,同时绘制量化策略涨跌与市场自然涨跌的柱状图,并绘制在同一个图表中
|
||
plt.figure(figsize=(12, 7))
|
||
|
||
# 设置x轴位置,为并列柱状图做准备
|
||
x = np.arange(len(bar_data))
|
||
width = 0.35 # 柱状图宽度
|
||
|
||
# 确保symbol列是字符串类型,避免matplotlib警告
|
||
bar_data["symbol"] = bar_data["symbol"].astype(str)
|
||
|
||
# 绘制量化策略涨跌柱状图(蓝色渐变色)
|
||
bars1 = plt.bar(
|
||
x - width / 2,
|
||
bar_data[column_name_text],
|
||
width,
|
||
label=column_name_text,
|
||
color=plt.cm.Blues(np.linspace(0.6, 0.9, len(bar_data))),
|
||
)
|
||
|
||
# 绘制市场自然涨跌柱状图(绿色渐变色)
|
||
bars2 = plt.bar(
|
||
x + width / 2,
|
||
bar_data["市场自然涨跌"],
|
||
width,
|
||
label="市场自然涨跌",
|
||
color=plt.cm.Greens(np.linspace(0.6, 0.9, len(bar_data))),
|
||
)
|
||
|
||
# 设置图表标题和标签
|
||
plt.title(
|
||
f"{bar}趋势{column_name_text}与市场自然涨跌对比(%)",
|
||
fontsize=14,
|
||
fontweight="bold",
|
||
)
|
||
plt.xlabel("Symbol", fontsize=12)
|
||
plt.ylabel("涨跌幅(%)", fontsize=12)
|
||
plt.xticks(x, bar_data["symbol"], rotation=45, ha="right")
|
||
plt.legend()
|
||
plt.grid(True, alpha=0.3)
|
||
|
||
# 在量化策略涨跌柱状图上方添加数值标签
|
||
for i, (bar1, value1) in enumerate(
|
||
zip(bars1, bar_data[column_name_text])
|
||
):
|
||
plt.text(
|
||
bar1.get_x() + bar1.get_width() / 2,
|
||
value1 + (0.01 if value1 >= 0 else -0.01),
|
||
f"{value1:.3f}%",
|
||
ha="center",
|
||
va="bottom" if value1 >= 0 else "top",
|
||
fontsize=9,
|
||
fontweight="bold",
|
||
color="darkblue",
|
||
)
|
||
|
||
# 在市场自然涨跌柱状图上方添加数值标签
|
||
for i, (bar2, value2) in enumerate(
|
||
zip(bars2, bar_data["市场自然涨跌"])
|
||
):
|
||
plt.text(
|
||
bar2.get_x() + bar2.get_width() / 2,
|
||
value2 + (0.01 if value2 >= 0 else -0.01),
|
||
f"{value2:.3f}%",
|
||
ha="center",
|
||
va="bottom" if value2 >= 0 else "top",
|
||
fontsize=9,
|
||
fontweight="bold",
|
||
color="darkgreen",
|
||
)
|
||
|
||
plt.tight_layout()
|
||
|
||
if self.commission_per_share > 0:
|
||
save_path = os.path.join(
|
||
self.stats_chart_dir,
|
||
f"{bar}_bar_chart_{column_name}_{strategy_name}_with_commission.png",
|
||
)
|
||
else:
|
||
save_path = os.path.join(
|
||
self.stats_chart_dir,
|
||
f"{bar}_bar_chart_{column_name}_{strategy_name}_without_commission.png",
|
||
)
|
||
|
||
plt.savefig(save_path, dpi=150)
|
||
plt.close()
|
||
|
||
sheet_name = f"{bar}_趋势{column_name_text}柱状图_{strategy_name}"
|
||
chart_dict[sheet_name] = save_path
|
||
return chart_dict
|
||
|
||
def draw_quant_line_chart(
|
||
self,
|
||
data: pd.DataFrame,
|
||
market_data_pct_chg_df: pd.DataFrame,
|
||
strategy_name: str = "全均线策略",
|
||
):
|
||
"""
|
||
根据量化策略买卖明细记录,绘制量化策略涨跌与市场自然涨跌的折线图
|
||
:param data: 量化策略买卖明细记录
|
||
:param market_data_pct_chg_df: 市场自然涨跌记录
|
||
:param strategy_name: 策略名称
|
||
:return: None
|
||
"""
|
||
symbols = data["symbol"].unique()
|
||
bars = data["bar"].unique()
|
||
chart_dict = {}
|
||
for symbol in symbols:
|
||
for bar in bars:
|
||
symbol_bar_data = data[
|
||
(data["symbol"] == symbol) & (data["bar"] == bar)
|
||
]
|
||
if symbol_bar_data.empty:
|
||
continue
|
||
|
||
# 获取第一行数据作为基准
|
||
first_row = symbol_bar_data.iloc[0].copy()
|
||
initial_capital = int(
|
||
market_data_pct_chg_df.loc[
|
||
(market_data_pct_chg_df["symbol"] == symbol)
|
||
& (market_data_pct_chg_df["bar"] == bar),
|
||
"initial_capital",
|
||
].values[0]
|
||
)
|
||
# 创建初始化行,设置基准值
|
||
init_row = first_row.copy()
|
||
init_row.loc["profit_loss"] = 0
|
||
init_row.loc["end_account_value"] = (
|
||
initial_capital # 量化策略初始值为初始资金
|
||
)
|
||
init_row.loc["end_timestamp"] = first_row["begin_timestamp"]
|
||
init_row.loc["end_date_time"] = first_row["begin_date_time"]
|
||
init_row.loc["end_close"] = first_row["begin_close"]
|
||
init_row.loc["end_ma5"] = first_row["begin_ma5"]
|
||
init_row.loc["end_ma10"] = first_row["begin_ma10"]
|
||
init_row.loc["end_ma20"] = first_row["begin_ma20"]
|
||
init_row.loc["end_ma30"] = first_row["begin_ma30"]
|
||
init_row.loc["end_macd_diff"] = first_row["begin_macd_diff"]
|
||
init_row.loc["end_macd_dea"] = first_row["begin_macd_dea"]
|
||
init_row.loc["end_macd"] = first_row["begin_macd"]
|
||
init_row.loc["pct_chg"] = 0
|
||
init_row.loc["interval_seconds"] = 0
|
||
init_row.loc["interval_minutes"] = 0
|
||
init_row.loc["interval_hours"] = 0
|
||
init_row.loc["interval_days"] = 0
|
||
|
||
# 将初始化行添加到数据开头
|
||
symbol_bar_data = pd.concat([pd.DataFrame([init_row]), symbol_bar_data])
|
||
symbol_bar_data.sort_values(
|
||
by="end_timestamp", ascending=True, inplace=True
|
||
)
|
||
symbol_bar_data.reset_index(drop=True, inplace=True)
|
||
|
||
# 确保时间列是datetime类型,避免matplotlib警告
|
||
symbol_bar_data["end_date_time"] = pd.to_datetime(
|
||
symbol_bar_data["end_date_time"]
|
||
)
|
||
|
||
# 计算持仓价值归一化数据(相对于初始价格)
|
||
symbol_bar_data["end_account_value_to_1"] = (
|
||
symbol_bar_data["end_account_value"] / initial_capital
|
||
)
|
||
symbol_bar_data["end_account_value_to_1"] = symbol_bar_data[
|
||
"end_account_value_to_1"
|
||
].round(4)
|
||
|
||
# 计算市场价位归一化数据(相对于初始价格)
|
||
symbol_bar_data["end_close_to_1"] = (
|
||
symbol_bar_data["end_close"] / init_row["end_close"]
|
||
)
|
||
symbol_bar_data["end_close_to_1"] = symbol_bar_data[
|
||
"end_close_to_1"
|
||
].round(4)
|
||
|
||
# 绘制折线图
|
||
plt.figure(figsize=(12, 7))
|
||
|
||
# 绘制量化策略涨跌线(蓝色)
|
||
plt.plot(
|
||
symbol_bar_data["end_date_time"],
|
||
symbol_bar_data["end_account_value_to_1"],
|
||
label="量化策略涨跌",
|
||
color="blue",
|
||
linewidth=2,
|
||
marker="o",
|
||
markersize=4,
|
||
)
|
||
|
||
# 绘制市场自然涨跌线(绿色)
|
||
plt.plot(
|
||
symbol_bar_data["end_date_time"],
|
||
symbol_bar_data["end_close_to_1"],
|
||
label="市场自然涨跌",
|
||
color="green",
|
||
linewidth=2,
|
||
marker="s",
|
||
markersize=4,
|
||
)
|
||
|
||
plt.title(
|
||
f"{symbol} {bar} 量化与市场折线图_{strategy_name}",
|
||
fontsize=14,
|
||
fontweight="bold",
|
||
)
|
||
plt.xlabel("时间", fontsize=12)
|
||
plt.ylabel("涨跌变化", fontsize=12)
|
||
plt.legend(fontsize=11)
|
||
plt.grid(True, alpha=0.3)
|
||
|
||
# 设置x轴标签,避免matplotlib警告
|
||
# 选择合适的时间间隔显示标签,避免过于密集
|
||
if len(symbol_bar_data) > 30:
|
||
# 如果数据点较多,选择间隔显示,但确保第一条和最后一条始终显示
|
||
step = max(1, len(symbol_bar_data) // 30)
|
||
|
||
# 创建标签索引列表,确保包含首尾数据
|
||
label_indices = [0] # 第一条
|
||
|
||
# 添加中间间隔的标签
|
||
for i in range(step, len(symbol_bar_data) - 1, step):
|
||
label_indices.append(i)
|
||
|
||
# 添加最后一条(如果还没有包含的话)
|
||
if len(symbol_bar_data) - 1 not in label_indices:
|
||
label_indices.append(len(symbol_bar_data) - 1)
|
||
|
||
# 设置x轴标签
|
||
plt.xticks(
|
||
symbol_bar_data["end_date_time"].iloc[label_indices],
|
||
symbol_bar_data["end_date_time"]
|
||
.iloc[label_indices]
|
||
.dt.strftime("%Y%m%d %H:%M"),
|
||
rotation=45,
|
||
ha="right",
|
||
)
|
||
else:
|
||
# 如果数据点较少,全部显示
|
||
plt.xticks(
|
||
symbol_bar_data["end_date_time"],
|
||
symbol_bar_data["end_date_time"].dt.strftime("%Y%m%d %H:%M"),
|
||
rotation=45,
|
||
ha="right",
|
||
)
|
||
|
||
plt.tight_layout()
|
||
|
||
if self.commission_per_share > 0:
|
||
save_path = os.path.join(
|
||
self.stats_chart_dir,
|
||
f"{symbol}_{bar}_line_chart_{strategy_name}_with_commission.png",
|
||
)
|
||
else:
|
||
save_path = os.path.join(
|
||
self.stats_chart_dir,
|
||
f"{symbol}_{bar}_line_chart_{strategy_name}_without_commission.png",
|
||
)
|
||
|
||
plt.savefig(save_path, dpi=150, bbox_inches="tight")
|
||
plt.close()
|
||
|
||
sheet_name = f"{symbol}_{bar}_折线图_{strategy_name}"
|
||
chart_dict[sheet_name] = save_path
|
||
return chart_dict
|
||
|
||
def output_chart_to_excel(self, excel_file_path: str, charts_dict: dict):
|
||
"""
|
||
输出Excel文件,包含所有图表
|
||
charts_dict: 图表数据字典,格式为:
|
||
{
|
||
"sheet_name": {
|
||
"chart_name": "chart_path"
|
||
}
|
||
}
|
||
"""
|
||
logger.info(f"将图表输出到{excel_file_path}")
|
||
|
||
# 打开已经存在的Excel文件
|
||
wb = openpyxl.load_workbook(excel_file_path)
|
||
|
||
for sheet_name, chart_path in charts_dict.items():
|
||
try:
|
||
ws = wb.create_sheet(title=sheet_name)
|
||
row_offset = 1
|
||
# Insert chart image
|
||
img = Image(chart_path)
|
||
ws.add_image(img, f"A{row_offset}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"输出Excel Sheet {sheet_name} 失败: {e}")
|
||
continue
|
||
# Save Excel file
|
||
wb.save(excel_file_path)
|
||
print(f"Chart saved as {excel_file_path}")
|