Files
trade/engine/example/regime_all.py
T
Rekey edc50e8809 feat: 新增2h/6h时间框架支持,策略重构为增量指标计算
- 数据层: build_aggregates_sql 新增 2h/6h 聚合视图,默认起始时间调整为 2017-05
- 模型层: KlineInterval 类型扩展 2h/6h,DataService 新增对应表名和毫秒映射
- 指标层: 新增 incremental.py 增量指标模块 (EmaInc/AtrInc/RsiInc/BbInc),O(1) per bar
- 策略重构: long_short.py 和 regime_all.py 从批量 ema/atr 迁移至增量指标,避免每 bar 重复全量计算
- regime 探测器: RegimeDetector3 改为增量 EMA200,detect() 接口简化
- 回测扩展: regime_timeframe_comparison 从 4h/1d 扩展至 2h/4h/6h/1d
- 新增示例: multi_strategy_report, vol_break_compare/periods, intraday_explore, top3_trades 等分析脚本
2026-06-13 19:30:25 +08:00

248 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
最佳牛熊判定 — 全币种全周期回测
方法:EMA200斜率 + 价格vs EMA200 + ATH回撤,3选2投票
策略:牛市只做多 / 熊市只做空 / 震荡空仓
币种:BTC / ETH / BNB / SOL,各自最早有数据到2026
用法:
source .venv/bin/activate && python example/regime_all.py
"""
import asyncio
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
_project_root = Path(__file__).resolve().parent.parent.parent
if str(_project_root) not in sys.path:
sys.path.insert(0, str(_project_root))
from engine.common.base import BaseStrategy, Signal, StrategyConfig
from engine.common.models import Kline
from engine.common.config import config
from engine.backtest import BacktestConfig
from engine.indicators.incremental import EmaInc, AtrInc
from engine.data import DataService
from engine.example.long_short import LongShortEngine
# ════════════════════════════════════════════════════════
# Three-vote regime detector (incremental EMA200, O(1) per bar)
# ════════════════════════════════════════════════════════
class RegimeDetector3:
    """Bull/bear regime detector based on a 2-of-3 vote.

    The three votes are: the slope of EMA(200) over the last 20 bars, the
    price relative to EMA(200), and the drawdown from the running all-time
    high. EMA(200) is kept in an incremental ``EmaInc`` that is indexed by
    bar position, so nothing is recomputed from scratch per bar.
    """

    def __init__(self):
        self._ath = 0.0
        self._e200 = EmaInc(200)

    def update(self, price: float):
        """Feed one bar's price: refresh the running high and EMA(200)."""
        self._ath = max(self._ath, price)
        self._e200.update(price)

    def _ema200_slope(self, idx: int) -> str:
        """Vote from the relative change of EMA(200) between idx-20 and idx."""
        # No vote before bar index 220.
        if idx < 220:
            return "unknown"
        ema = self._e200
        earlier = ema[idx - 20]
        if earlier == 0:
            # Avoid dividing by zero.
            return "unknown"
        change = (ema[idx] - earlier) / earlier
        if change > 0.002:
            return "bull"
        if change < -0.002:
            return "bear"
        return "sideways"

    def _price_vs_ema200(self, price: float, idx: int) -> str:
        """Vote "bull" when price is above EMA(200) at idx, otherwise "bear"."""
        if idx < 210:
            return "unknown"
        ema_now = self._e200[idx]
        if ema_now == 0:
            return "unknown"
        if price > ema_now:
            return "bull"
        return "bear"

    def _ath_drawdown(self, price: float) -> str:
        """Vote from how far price sits below the running all-time high."""
        peak = self._ath
        if peak == 0:
            return "unknown"
        drawdown = (price - peak) / peak
        # Within 15% of the high -> bull; more than 35% below -> bear.
        if drawdown > -0.15:
            return "bull"
        if drawdown < -0.35:
            return "bear"
        return "sideways"

    def detect(self, price: float, idx: int) -> str:
        """Return "bull", "bear" or "sideways" for the bar at position idx.

        A side wins when at least two of the three votes agree on it; "bull"
        is checked first. "unknown" and "sideways" votes support neither side.
        """
        votes = (
            self._ema200_slope(idx),
            self._price_vs_ema200(price, idx),
            self._ath_drawdown(price),
        )
        if votes.count("bull") >= 2:
            return "bull"
        if votes.count("bear") >= 2:
            return "bear"
        return "sideways"
# ════════════════════════════════════════════════════════
# 自适应策略
# ════════════════════════════════════════════════════════
class RegimeEmaConfig(StrategyConfig):
    """Tunable parameters for RegimeEmaStrategy."""

    # Period of the fast EMA used for the crossover.
    fast: int = 10
    # Period of the slow EMA used for the crossover.
    slow: int = 50
    # Trailing-stop distance, expressed as a multiple of ATR.
    atr_stop: float = 2.5
class RegimeEmaStrategy(BaseStrategy):
    """EMA-crossover strategy gated by market regime.

    Opens longs only on a golden cross in a bull regime and shorts only on a
    death cross in a bear regime; stays flat otherwise. Open positions are
    closed on the opposite cross, an ATR trailing stop, or a regime flip.
    All indicators are updated incrementally, once per bar.
    """

    strategy_type = "regime_ema"

    def __init__(self, c: RegimeEmaConfig):
        super().__init__(c)
        self.cfg = c
        # Per-bar close/high/low history; only len(self._c) is read below.
        self._c: list[float] = []
        self._h: list[float] = []
        self._l: list[float] = []
        self._detector = RegimeDetector3()
        self._ema_fast = EmaInc(c.fast)
        self._ema_slow = EmaInc(c.slow)
        self._atr = AtrInc(14)
        # "" = flat, "long" or "short".
        self._side: str = ""
        # Highest high since long entry / lowest low since short entry,
        # used as the anchor of the trailing stop.
        self._hp: float = 0.0
        self._lp: float = float('inf')

    async def on_kline(self, k: Kline) -> Optional[Signal]:
        """Consume one bar and return an entry/exit Signal, or None."""
        self._c.append(k.close)
        self._h.append(k.high)
        self._l.append(k.low)

        # Advance every incremental indicator by one bar.
        self._detector.update(k.close)
        self._ema_fast.update(k.close)
        self._ema_slow.update(k.close)
        self._atr.update(k.high, k.low, k.close)

        bars_seen = len(self._c)
        if bars_seen < 220:
            # Warm-up: no trading before 220 bars have been seen.
            return None

        regime = self._detector.detect(k.close, bars_seen - 1)
        fast_now, slow_now = self._ema_fast[-1], self._ema_slow[-1]
        atr_now = self._atr[-1]
        fast_prev, slow_prev = self._ema_fast[-2], self._ema_slow[-2]
        if 0 in (fast_now, slow_now, atr_now):
            return None

        golden = fast_prev <= slow_prev and fast_now > slow_now
        death = fast_prev >= slow_prev and fast_now < slow_now

        # Holding a long: trail the stop below the highest high.
        if self._side == "long":
            self._hp = max(self._hp, k.high)
            stop = self._hp - self.cfg.atr_stop * atr_now
            stopped = k.close < stop
            if not (death or stopped or regime == "bear"):
                return None
            self._side = ""
            if death:
                reason = "死叉"
            elif stopped:
                reason = "ATR止损"
            else:
                reason = "转熊"
            return Signal(symbol=self.cfg.symbol, side="SELL", reason=reason, timestamp=k.open_time)

        # Holding a short: trail the stop above the lowest low.
        if self._side == "short":
            self._lp = min(self._lp, k.low)
            stop = self._lp + self.cfg.atr_stop * atr_now
            stopped = k.close > stop
            if not (golden or stopped or regime == "bull"):
                return None
            self._side = ""
            if golden:
                reason = "金叉"
            elif stopped:
                reason = "ATR止损"
            else:
                reason = "转牛"
            return Signal(symbol=self.cfg.symbol, side="BUY", reason=reason, timestamp=k.open_time)

        # Flat: wait for a cross that agrees with the regime.
        if regime == "bull" and golden:
            self._side = "long"
            self._hp = k.close
            return Signal(symbol=self.cfg.symbol, side="BUY", reason="牛市金叉", timestamp=k.open_time)
        if regime == "bear" and death:
            self._side = "short"
            self._lp = k.close
            return Signal(symbol=self.cfg.symbol, side="SELL", reason="熊市死叉", timestamp=k.open_time)
        return None
# ════════════════════════════════════════════════════════
# Per-symbol (fast, slow) EMA periods; insertion order is the report order.
PARAMS: dict[str, tuple[int, int]] = {
    "BTCUSDT": (10, 50),
    "ETHUSDT": (10, 75),
    "BNBUSDT": (20, 50),
    "SOLUSDT": (30, 50),
}
# Symbols to backtest, in the same order as PARAMS.
SYMBOLS: list[str] = list(PARAMS)
# Fallback backtest window used when a symbol's real data range is unavailable.
DATE_START: datetime = datetime(2017, 1, 1)
DATE_END: datetime = datetime(2026, 1, 1)
async def get_actual_range(symbol: str) -> tuple[datetime, datetime]:
    """Return the (start, end) range of stored 4h data for *symbol*.

    Opens a DataService connection for the lookup and always closes it.
    If the range query fails, falls back to (DATE_START, DATE_END).
    A failure to connect is not handled here and propagates to the caller.
    """
    ds = DataService(config.db)
    await ds.connect()
    try:
        start, end = await ds.fetch_symbol_date_range(symbol, "4h")
        return start, end
    except Exception:
        # Deliberately not a bare ``except``: task cancellation
        # (asyncio.CancelledError), KeyboardInterrupt and SystemExit derive
        # from BaseException and must propagate rather than be turned into
        # the default range.
        return DATE_START, DATE_END
    finally:
        await ds.close()
def _pnl_by_side(trades) -> tuple[float, float]:
    """Return (long_pnl, short_pnl) summed over trades that carry a PnL.

    Trades whose ``pnl`` is None are skipped. PnL on "SELL" trades is
    reported as long PnL and PnL on "BUY" trades as short PnL — presumably
    because PnL is recorded on the closing fill; confirm against
    LongShortEngine.
    """
    long_pnl = sum(t.pnl for t in trades if t.pnl is not None and t.side == "SELL")
    short_pnl = sum(t.pnl for t in trades if t.pnl is not None and t.side == "BUY")
    return long_pnl, short_pnl


async def main():
    """Backtest every symbol over its full range, then BTC period by period.

    Prints two tables to stdout: a per-symbol summary over each symbol's
    available data range, and a per-period breakdown for BTCUSDT.
    """
    # Visible horizontal rules; plain ASCII so terminals render them reliably.
    heavy_rule = "=" * 125

    print()
    print(heavy_rule)
    print(" 牛熊自适应策略 — 全币种全周期 | 牛市只多/熊市只空/震荡空仓")
    print(heavy_rule)
    print("\n ■ 全周期汇总")
    print(f" {'币种':<10} {'数据范围':<22} {'总收益%':>7} {'年化%':>7} {'夏普':>6} {'回撤%':>7} {'交易':>5} {'多头P&L':>10} {'空头P&L':>10}")
    print(" " + "-" * 115)

    for symbol in SYMBOLS:
        fast, slow = PARAMS[symbol]
        sc = RegimeEmaConfig(symbol=symbol, fast=fast, slow=slow)
        # Use the symbol's real data range; fall back to the default window
        # if the lookup itself fails (e.g. the connection cannot be opened).
        try:
            act_start, act_end = await get_actual_range(symbol)
            range_str = f"{act_start.date()}~{act_end.date()}"
        except Exception:
            # Not a bare ``except`` so that cancellation and Ctrl-C propagate.
            act_start, act_end = DATE_START, DATE_END
            range_str = "2017-2026"
        bt = BacktestConfig(symbol=symbol, interval="4h",
                            start_time=act_start, end_time=act_end, initial_capital=10_000.0)
        engine = LongShortEngine(bt, db_config=config.db)
        r = await engine.run(RegimeEmaStrategy, sc)
        m = r.metrics
        long_pnl, short_pnl = _pnl_by_side(r.trades)
        print(f" {symbol:<10} {range_str:<22} {m.total_return_pct:>6.1f}% {m.annual_return_pct:>6.1f}% {m.sharpe_ratio:>6.2f} {m.max_drawdown_pct:>6.1f}% {m.total_trades:>5} {long_pnl:>+9.0f} {short_pnl:>+9.0f}")

    # ── BTC broken down by period ──
    periods = [
        ("2017 牛市", datetime(2017, 1, 1), datetime(2018, 1, 1)),
        ("2018 熊市", datetime(2018, 1, 1), datetime(2019, 1, 1)),
        ("2019 反弹", datetime(2019, 1, 1), datetime(2020, 1, 1)),
        ("2020 牛初", datetime(2020, 1, 1), datetime(2021, 1, 1)),
        ("2021 牛市", datetime(2021, 1, 1), datetime(2022, 1, 1)),
        ("2022 熊市", datetime(2022, 1, 1), datetime(2023, 1, 1)),
        ("2023 震荡", datetime(2023, 1, 1), datetime(2024, 1, 1)),
        ("2024-25牛", datetime(2024, 1, 1), datetime(2026, 1, 1)),
    ]
    print("\n ■ BTC 分段表现")
    print(f" {'阶段':<16} {'总收益%':>7} {'夏普':>6} {'多头P&L':>9} {'空头P&L':>9}")
    print(" " + "-" * 65)

    # Same BTC parameters as the summary table above.
    btc_fast, btc_slow = PARAMS["BTCUSDT"]
    for name, s, e in periods:
        try:
            sc = RegimeEmaConfig(symbol="BTCUSDT", fast=btc_fast, slow=btc_slow)
            bt = BacktestConfig(symbol="BTCUSDT", interval="4h", start_time=s, end_time=e, initial_capital=10_000.0)
            eng = LongShortEngine(bt, db_config=config.db)
            r = await eng.run(RegimeEmaStrategy, sc)
            long_pnl, short_pnl = _pnl_by_side(r.trades)
            print(f" {name:<16} {r.metrics.total_return_pct:>+6.1f}% {r.metrics.sharpe_ratio:>6.2f} {long_pnl:>+8.0f} {short_pnl:>+8.0f}")
        except Exception as ex:
            # Best-effort: one failing period must not abort the report, but
            # show what went wrong instead of assuming it was missing data.
            print(f" {name:<16} 数据不足 ({type(ex).__name__}: {ex})")

    print("\n" + heavy_rule)
# Script entry point: run the async report when executed directly.
if __name__ == "__main__":
    asyncio.run(main())