Files
trade/engine/indicators/regime.py
T
Rekey 4294ec401d feat: 多周期牛熊判定模块 — 方案一矩阵展示 + 四法投票 + 多TF策略
- engine/indicators/regime.py: RegimeDetector(四法投票) + MultiTimeframeRegime(多周期并行)
  四法: EMA200斜率 / 价格vsEMA200 / ATH回撤 / 窄幅盘整(<3%振幅)
  全部 O(1)/bar 增量计算,适用于回测和实时
- engine/example/regime_display.py: 多周期牛熊矩阵展示脚本
  独立加载各周期数据 → 运行判定 → 日线对齐矩阵 + 详细拆解 + 统计
  输出 engine/backtest/REGIME_MATRIX_BTCUSDT.md
- engine/example/regime_mtf_strategy.py: 多周期共识策略 + 四策略对比回测
  MTF Consensus: 1w定方向 + 1d确认 + 4h EMA入场
  vs Old Regime(单TF基线) vs Long/Short(无过滤)
- engine/indicators/__init__.py: 导出 RegimeDetector, MultiTimeframeRegime
2026-06-17 11:30:19 +08:00

269 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
多周期牛熊判定模块
基于 EMA200 斜率、价格 vs EMA200、ATH 回撤三法投票,支持单周期判定
和多周期并行分析。每根 bar O(1) 增量更新,适用于回测和实时分析。
用法:
# 单周期
det = RegimeDetector()
for price in closes:
det.update(price)
regime = det.detect(closes[-1], len(closes) - 1)
# 多周期
mtf = MultiTimeframeRegime(['1h', '4h', '1d', '1w'])
for price in closes:
mtf.update(price)
matrix = mtf.detect_all(closes[-1], len(closes) - 1)
"""
from .incremental import EmaInc
class RegimeDetector:
"""单周期牛熊判定器 — 四法投票
四个独立判定维度:
1. EMA200 斜率:近 20 根 bar 的 EMA200 变化率 → bull/bear/sideways
2. 价格 vs EMA200:当前价在 EMA200 上方/下方 → bull/bear
3. ATH 回撤:当前价距历史高点的回撤幅度 → bull/bear/sideways
4. 窄幅盘整:近 20 根 bar 收盘价振幅 < 3% → sideways(否则 unknown,不投票)
最终判定:四选二投票(>=2 票即定论)。
"""
def __init__(self, range_period: int = 20, range_threshold: float = 0.03):
self._ath = 0.0
self._e200 = EmaInc(200)
self._range_period = range_period
self._range_threshold = range_threshold
self._recent_high: list[float] = []
self._recent_low: list[float] = []
def update(self, price: float) -> None:
"""每根 bar 调一次:更新 ATH + EMA200 + 价格区间窗口"""
if price > self._ath:
self._ath = price
self._e200.update(price)
# 滚动窗口:维持最近 range_period 根 bar 的最高/最低(只用 close 做代理)
self._recent_high.append(price)
self._recent_low.append(price)
if len(self._recent_high) > self._range_period:
self._recent_high.pop(0)
self._recent_low.pop(0)
def detect(self, price: float, idx: int) -> str:
"""返回当前 bar 的牛熊判定:bull / bear / sideways"""
r1 = self._ema200_slope(idx)
r2 = self._price_vs_ema200(price, idx)
r3 = self._ath_drawdown(price)
r4 = self._price_range()
b = sum(1 for r in [r1, r2, r3, r4] if r == "bull")
br = sum(1 for r in [r1, r2, r3, r4] if r == "bear")
sx = sum(1 for r in [r1, r2, r3, r4] if r == "sideways")
if b >= 2:
return "bull"
if br >= 2:
return "bear"
if sx >= 2:
return "sideways"
return "sideways"
def detect_detail(self, price: float, idx: int) -> dict:
"""返回详细判定(含各子法的结果 + 三态票数)"""
r1 = self._ema200_slope(idx)
r2 = self._price_vs_ema200(price, idx)
r3 = self._ath_drawdown(price)
r4 = self._price_range()
b = sum(1 for r in [r1, r2, r3, r4] if r == "bull")
br = sum(1 for r in [r1, r2, r3, r4] if r == "bear")
sx = sum(1 for r in [r1, r2, r3, r4] if r == "sideways")
if b >= 2:
final = "bull"
elif br >= 2:
final = "bear"
elif sx >= 2:
final = "sideways"
else:
final = "sideways"
return {
"final": final,
"ema200_slope": r1,
"price_vs_ema200": r2,
"ath_drawdown": r3,
"price_range": r4,
"bull_votes": b,
"bear_votes": br,
"sideways_votes": sx,
"range_pct": self._range_pct(),
}
def _ema200_slope(self, idx: int) -> str:
if idx < 220:
return "unknown"
e200 = self._e200
if e200[idx - 20] == 0:
return "unknown"
slope = (e200[idx] - e200[idx - 20]) / e200[idx - 20]
if slope > 0.002:
return "bull"
if slope < -0.002:
return "bear"
return "sideways"
def _price_vs_ema200(self, price: float, idx: int) -> str:
if idx < 210:
return "unknown"
e = self._e200[idx]
if e == 0:
return "unknown"
return "bull" if price > e else "bear"
def _ath_drawdown(self, price: float) -> str:
if self._ath == 0:
return "unknown"
dd = (price - self._ath) / self._ath
if dd > -0.20:
return "bull"
if dd < -0.40:
return "bear"
return "sideways"
def _price_range(self) -> str:
"""窄幅盘整检测:近 range_period 根 bar 振幅 < 阈值 → sideways"""
if len(self._recent_high) < self._range_period:
return "unknown"
highest = max(self._recent_high)
lowest = min(self._recent_low)
if highest == 0:
return "unknown"
range_pct = (highest - lowest) / highest
return "sideways" if range_pct < self._range_threshold else "unknown"
def _range_pct(self) -> float:
"""当前波动区间百分比(用于展示)"""
if len(self._recent_high) < self._range_period:
return 0.0
highest = max(self._recent_high)
if highest == 0:
return 0.0
return (highest - min(self._recent_low)) / highest
@property
def ema200(self) -> float:
"""最新 EMA200 值"""
return self._e200.current
@property
def ath(self) -> float:
"""历史最高价"""
return self._ath
@property
def ready(self) -> bool:
"""是否已有足够数据做判定"""
return len(self._e200) >= 220
class MultiTimeframeRegime:
"""多周期牛熊并行分析
管理多个 RegimeDetector 实例,每个对应一个周期。
所有检测器共用同一价格序列更新,保证 O(1)/bar 的增量计算。
Attributes:
timeframes: 周期列表(如 ['1h', '4h', '1d', '1w']
detectors: 周期 → RegimeDetector 映射
_prices: 价格序列(所有检测器共享)
"""
def __init__(self, timeframes: list[str]):
self.timeframes = timeframes
self.detectors: dict[str, RegimeDetector] = {
tf: RegimeDetector() for tf in timeframes
}
self._prices: list[float] = []
def update(self, price: float) -> None:
"""更新所有周期检测器"""
self._prices.append(price)
for det in self.detectors.values():
det.update(price)
def detect_all(self) -> dict[str, str]:
"""返回各周期当前判定(需至少 220 根 bar)"""
n = len(self._prices)
if n < 220:
return {tf: "unknown" for tf in self.timeframes}
price = self._prices[-1]
return {
tf: self.detectors[tf].detect(price, n - 1)
for tf in self.timeframes
}
def detect_all_detail(self) -> dict[str, dict]:
"""返回各周期详细判定"""
n = len(self._prices)
if n < 220:
return {tf: {"final": "unknown"} for tf in self.timeframes}
price = self._prices[-1]
return {
tf: self.detectors[tf].detect_detail(price, n - 1)
for tf in self.timeframes
}
def detect_at(self, idx: int) -> dict[str, str]:
"""返回指定 bar 位置各周期的判定"""
if idx < 0 or idx >= len(self._prices):
raise IndexError(f"idx {idx} out of range [0, {len(self._prices)})")
price = self._prices[idx]
result = {}
for tf in self.timeframes:
if idx < 220:
result[tf] = "unknown"
else:
result[tf] = self.detectors[tf].detect(price, idx)
return result
def score(self) -> float:
"""加权综合评分 [-1, 1]
各周期赋分:bull=+1, sideways=0, bear=-1
权重:周期越长权重越大
Returns:
-1.0 (强熊) ~ +1.0 (强牛)
"""
n = len(self._prices)
if n < 220:
return 0.0
# 周期权重:1h=1, 4h=2, 1d=3, 1w=4, 1mon=5
tf_weights = {
"1m": 1, "3m": 1, "5m": 1, "15m": 2, "30m": 2,
"1h": 3, "2h": 3, "4h": 4, "6h": 4, "8h": 4,
"1d": 5, "1w": 6, "1mon": 7,
}
score_map = {"bull": 1.0, "sideways": 0.0, "bear": -1.0, "unknown": 0.0}
regimes = self.detect_all()
total_weight = 0.0
weighted_score = 0.0
for tf, regime in regimes.items():
w = tf_weights.get(tf, 1)
weighted_score += score_map[regime] * w
total_weight += w
if total_weight == 0:
return 0.0
return weighted_score / total_weight
@property
def ready(self) -> bool:
"""所有检测器是否都已就绪"""
return len(self._prices) >= 220