Files
trade/engine/example/vol_break_1h_6h.py
Rekey 0cd2cbbb79 feat(engine): 新增 2h/6h 与 1h 策略对比回测
- comparison_2h_6h: 9 策略 × 4 币种 × 2 周期 × 4 数据量 = 288 次回测
  - 包含海龟、超级趋势、MACD、布林收缩、三均线、RSI 回归、
    ATR 波动率突破、EMA 多空、牛熊自适应
  - 结论:6h 夏普显著优于 2h(69% 组合),ATR 策略霸榜
  - 自动生成 Markdown 回测报告

- vol_break_1h_6h: ATR 波动率突破 × 1h/2h/4h/6h 近半年对比
2026-06-14 00:15:16 +08:00

286 lines
11 KiB
Python

"""
ATR波动率突破 — 1h/2h/4h/6h 近半年横向对比回测
用法:
source .venv/bin/activate && python example/vol_break_1h_6h.py
"""
import asyncio
import json
import sys
import time
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
# Put the directory three levels above this script at the front of the import
# path, so the `engine.*` imports below resolve when the file is run directly.
_project_root = Path(__file__).resolve().parent.parent.parent
if str(_project_root) not in sys.path:
    sys.path[:0] = [str(_project_root)]
from engine.common.base import BaseStrategy, Signal, StrategyConfig
from engine.common.models import Kline
from engine.common.config import config
from engine.backtest.models import BacktestConfig
from engine.data import DataService
from engine.indicators.incremental import EmaInc, AtrInc, RsiInc, BbInc
from engine.example.long_short import LongShortEngine
# ── Global constants ──
# Instruments and bar intervals covered by the comparison.
SYMBOLS = [f"{coin}USDT" for coin in ("BTC", "ETH", "BNB", "SOL")]
TIMEFRAMES = [f"{hours}h" for hours in (1, 2, 4, 6)]
INITIAL = float(10_000)  # starting capital handed to every backtest
WARMUP = 150  # passed to BacktestConfig as warmup_bars
MAX_CONCURRENCY = 6  # size of the semaphore bounding concurrent backtests
# Evaluation window: the 182 days (about half a year) ending at import time, in UTC.
NOW = datetime.now(timezone.utc)
PERIOD_END = NOW
PERIOD_START = PERIOD_END - timedelta(days=182)
# ════════════════════════════════════════════════════════
# ATR波动率突破策略
# ════════════════════════════════════════════════════════
class VolBreakConfig(StrategyConfig):
    """Tunable parameters for VolBreakStrategy.

    `symbol` is not declared here; presumably it is inherited from
    StrategyConfig (it is passed as `VolBreakConfig(symbol=...)` and read as
    `cfg.symbol`) — confirm against the base class.
    """

    # Length passed to AtrInc; together with squeeze_period it sets how many
    # bars must arrive before the strategy emits anything.
    atr_period: int = 14
    # Number of most recent ATR values scanned for the window minimum used by
    # the squeeze test; also part of the warm-up bar count.
    squeeze_period: int = 20
    # ATR counts as "squeezed" while below min_atr * (1 + (1 - squeeze_ratio)),
    # i.e. within 30% of the window minimum at the default of 0.7.
    squeeze_ratio: float = 0.7
    # Stop distance from the entry close, in multiples of the current bar's ATR.
    atr_stop: float = 2.0
class VolBreakStrategy(BaseStrategy):
    """Long/short ATR volatility-breakout strategy.

    While flat it arms itself once ATR sits close to its recent minimum (a
    "squeeze"). It then enters when ATR leaves that band while rising more
    than 5% bar over bar: long if EMA(10) > EMA(30), otherwise short.

    While in a position it exits when either condition holds:
    - the close crosses an ATR-multiple stop measured from the entry close,
      using the current bar's ATR;
    - the EMA relationship turns against the position while ATR is not
      squeezed.
    """

    strategy_type = "波动率突破"
    strategy_desc = "ATR(14)收缩至极低后扩张突破 + EMA(10/30)方向确认"

    def __init__(self, c: VolBreakConfig):
        super().__init__(c)
        self.cfg = c
        # Incremental indicators, fed one bar at a time from on_kline.
        self._atr = AtrInc(c.atr_period)
        self._ema_fast = EmaInc(10)
        self._ema_slow = EmaInc(30)
        # Raw bar history; only len(self._closes) is read (as the bar counter).
        self._closes: list[float] = []
        self._highs: list[float] = []
        self._lows: list[float] = []
        # Position state: "" when flat, otherwise "long" or "short".
        self._side: str = ""
        self._entry_price: float = 0.0
        # Set while flat once a squeeze is seen; cleared on entry and on every
        # bar spent in a position.
        self._was_squeezed = False

    async def on_kline(self, k: Kline) -> Optional[Signal]:
        """Consume one bar and return an entry/exit Signal, or None."""
        cfg = self.cfg
        # Record the bar and advance every indicator before any early return.
        self._closes.append(k.close)
        self._highs.append(k.high)
        self._lows.append(k.low)
        self._atr.update(k.high, k.low, k.close)
        self._ema_fast.update(k.close)
        self._ema_slow.update(k.close)

        bar_count = len(self._closes)
        if bar_count < cfg.atr_period + cfg.squeeze_period:
            return None  # still warming up

        atr_cur = self._atr[-1]
        atr_last = 0 if bar_count < 2 else self._atr[-2]
        if atr_cur == 0:
            return None

        first = max(0, bar_count - cfg.squeeze_period)
        recent = (self._atr[idx] for idx in range(first, bar_count))
        positives = [value for value in recent if value > 0]
        if not positives:
            return None

        # Squeezed: ATR lies within a (1 - squeeze_ratio) band above the window minimum.
        is_squeezed = atr_cur < min(positives) * (1 + (1 - cfg.squeeze_ratio))
        # Expanding: ATR grew by more than 5% versus the previous bar.
        atr_expanding = atr_last > 0 and atr_cur > atr_last * 1.05
        fast, slow = self._ema_fast[-1], self._ema_slow[-1]

        side = self._side
        if side in ("long", "short"):
            # Any bar spent in a position discards a squeeze armed earlier.
            self._was_squeezed = False
            stop_gap = cfg.atr_stop * atr_cur
            if side == "long":
                stopped = k.close < self._entry_price - stop_gap
                trend_flipped = fast < slow
                exit_side = "SELL"
            else:
                stopped = k.close > self._entry_price + stop_gap
                trend_flipped = fast > slow
                exit_side = "BUY"
            if not (stopped or (trend_flipped and not is_squeezed)):
                return None
            self._side = ""
            return Signal(symbol=cfg.symbol, side=exit_side, reason="ATR退出", timestamp=k.open_time)

        # Flat: arm on a squeeze, fire on the first expansion after it.
        if is_squeezed:
            self._was_squeezed = True
            return None
        if not self._was_squeezed or not atr_expanding:
            return None
        self._was_squeezed = False
        self._entry_price = k.close
        if fast > slow:
            self._side = "long"
            return Signal(symbol=cfg.symbol, side="BUY", reason="ATR扩张突破做多", timestamp=k.open_time)
        self._side = "short"
        return Signal(symbol=cfg.symbol, side="SELL", reason="ATR扩张突破做空", timestamp=k.open_time)
async def run_one(symbol, interval, start, end):
    """Backtest VolBreakStrategy on one symbol/interval over [start, end].

    Returns ``(result, elapsed_seconds, error)``:
    - on success, ``error`` is None;
    - on failure, ``result`` is None and ``error`` is a non-empty
      ``"<ExceptionType>: <message>"`` string.

    Failures are captured here, including ones raised while building the
    configs or the engine. That way one bad combination cannot abort the
    surrounding ``asyncio.gather`` batch.
    """
    # perf_counter is monotonic, unlike time.time(), so elapsed is never negative.
    t0 = time.perf_counter()
    try:
        sc = VolBreakConfig(symbol=symbol)
        bt = BacktestConfig(
            symbol=symbol, interval=interval,
            start_time=start, end_time=end,
            initial_capital=INITIAL, warmup_bars=WARMUP,
        )
        engine = LongShortEngine(bt, db_config=config.db)
        r = await engine.run(VolBreakStrategy, sc)
    except Exception as ex:  # batch boundary: report the failure and carry on
        detail = str(ex)
        # Some exceptions stringify to "", which would otherwise hide the failure cause.
        err = f"{type(ex).__name__}: {detail}" if detail else type(ex).__name__
        return None, time.perf_counter() - t0, err
    return r, time.perf_counter() - t0, None
# Bar length in milliseconds per timeframe; only used for the bar-count estimate.
_BAR_MS = {"1h": 3_600_000, "2h": 7_200_000, "4h": 14_400_000, "6h": 21_600_000}
# Rank markers for the top three successful runs of each timeframe.
_MEDALS = ("🥇", "🥈", "🥉")
# Full-width rule framing each report section (the detail table is 145 columns).
_RULE = "═" * 145


def _failed(row):
    """Return True for rows produced by a failed backtest (they carry an "错误" entry)."""
    return "错误" in row


async def _fetch_date_ranges():
    """Return {(symbol, tf): (first, last)} for every pair whose range lookup succeeds.

    Failed lookups are printed and omitted. The DataService connection is
    closed even if an unexpected error escapes the loop.
    """
    ds = DataService(config.db)
    await ds.connect()
    date_ranges = {}
    try:
        print("正在获取数据范围...")
        for symbol in SYMBOLS:
            for tf in TIMEFRAMES:
                try:
                    s, e = await ds.fetch_symbol_date_range(symbol, tf)
                    bar_ms = _BAR_MS.get(tf)
                    # An unknown timeframe still runs; only its size estimate is skipped.
                    est = f"{int((e - s).total_seconds() * 1000 / bar_ms):,}" if bar_ms else "?"
                    date_ranges[(symbol, tf)] = (s, e)
                    print(f" {symbol} {tf:<4}: {s.date()} ~ {e.date()} (约{est}根)")
                except Exception as ex:
                    print(f" {symbol} {tf:<4}: 获取失败 — {ex}")
    finally:
        await ds.close()
    return date_ranges


def _build_tasks(date_ranges):
    """Clip each pair's stored range to [PERIOD_START, PERIOD_END].

    Returns task dicts in SYMBOLS × TIMEFRAMES order. Pairs without a stored
    range, or with no overlap, are skipped.
    """
    tasks = []
    for symbol in SYMBOLS:
        for tf in TIMEFRAMES:
            span = date_ranges.get((symbol, tf))
            if span is None:
                continue
            start = max(PERIOD_START, span[0])
            end = min(PERIOD_END, span[1])
            if start < end:
                tasks.append({"symbol": symbol, "tf": tf, "start": start, "end": end})
    return tasks


def _result_row(info, r, elapsed, err):
    """Flatten one backtest outcome into a report/JSON row.

    Failed runs get zeroed metrics plus an "错误" entry.
    """
    row = {
        "币种": info["symbol"],
        "时间级别": info["tf"],
        "日期范围": f"{info['start'].date()}~{info['end'].date()}",
    }
    if r is not None:
        m = r.metrics
        row.update({
            "初始资金": INITIAL,
            "最终权益": round(m.final_equity, 2),
            "总收益%": round(m.total_return_pct, 2),
            "年化收益%": round(m.annual_return_pct, 2),
            "夏普比率": round(m.sharpe_ratio, 2),
            "最大回撤%": round(m.max_drawdown_pct, 2),
            "胜率%": round(m.win_rate * 100, 2),
            "盈亏比": round(m.profit_factor, 2),
            "交易次数": m.total_trades,
            "平均盈亏": round(m.avg_trade_pnl, 2),
            "最佳盈亏": round(m.best_trade_pnl, 2),
            "最差盈亏": round(m.worst_trade_pnl, 2),
            "耗时s": round(elapsed, 1),
        })
    else:
        row.update({
            "初始资金": INITIAL, "最终权益": 0, "总收益%": 0, "年化收益%": 0,
            "夏普比率": 0, "最大回撤%": 0, "胜率%": 0, "盈亏比": 0,
            "交易次数": 0, "平均盈亏": 0, "最佳盈亏": 0, "最差盈亏": 0,
            "耗时s": round(elapsed, 1), "错误": err or "未知错误",
        })
    return row


def _print_report(results):
    """Print the per-run detail table followed by per-timeframe rankings.

    Sorts ``results`` in place by (timeframe, symbol). That gives the table and
    the JSON written afterwards a deterministic order, regardless of which
    backtest finished first.
    """
    print()
    print(_RULE)
    print(f" ATR波动率突破 — {' / '.join(TIMEFRAMES)} 近半年横向对比")
    print(f" 策略: ATR(14)/squeeze=20x0.7/EMA(10,30) | 本金 ${INITIAL:,.0f} | 多空双向")
    print(_RULE)
    print()
    results.sort(key=lambda x: (TIMEFRAMES.index(x["时间级别"]), SYMBOLS.index(x["币种"])))
    print(f" {'币种':<10} {'时间':<5} {'总收益%':>8} {'年化%':>8} {'夏普':>7} {'回撤%':>7} {'胜率%':>7} {'盈亏比':>7} {'交易':>6} {'最佳盈亏':>10} {'最差盈亏':>10} {'日期范围':<24}")
    print(" " + "─" * 140)
    for r in results:
        print(f" {r['币种']:<10} {r['时间级别']:<5} {r['总收益%']:>7.1f}% {r['年化收益%']:>7.1f}% {r['夏普比率']:>7.2f} {r['最大回撤%']:>7.1f}% {r['胜率%']:>6.1f}% {r['盈亏比']:>7.2f} {r['交易次数']:>6} {r['最佳盈亏']:>+9.0f} {r['最差盈亏']:>+9.0f} {r['日期范围']:<24}")
    print()

    # Per-timeframe ranking by annualised return. Failed runs (zeroed metrics)
    # sink to the bottom instead of outranking genuinely losing runs.
    print(_RULE)
    print(" ■ 各时间级别排名(按年化收益)")
    print(_RULE)
    for tf in TIMEFRAMES:
        subset = [r for r in results if r["时间级别"] == tf]
        if not subset:
            continue
        subset.sort(key=lambda x: (not _failed(x), x.get("年化收益%", -9999)), reverse=True)
        print(f"\n{tf} 近半年")
        print(f" {'排名':<5} {'币种':<10} {'总收益%':>8} {'年化%':>8} {'夏普':>7} {'回撤%':>7} {'胜率%':>7} {'盈亏比':>7} {'交易':>6}")
        print(" " + "─" * 100)
        for i, r in enumerate(subset):
            if _failed(r):
                marker = " ✗"
            elif i < len(_MEDALS):
                marker = _MEDALS[i]
            else:
                # Works for any number of symbols (a fixed 4-entry list would raise IndexError).
                marker = f" {i + 1}"
            print(f" {marker:<5} {r['币种']:<10} {r['总收益%']:>7.1f}% {r['年化收益%']:>7.1f}% {r['夏普比率']:>7.2f} {r['最大回撤%']:>7.1f}% {r['胜率%']:>6.1f}% {r['盈亏比']:>7.2f} {r['交易次数']:>6}")
    print()
    print(_RULE)


def _save_json(results, total_elapsed):
    """Write the run configuration and all result rows to vol_break_1h_6h.json."""
    output_file = _project_root / "engine" / "example" / "vol_break_1h_6h.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump({
            "config": {
                "strategy": "ATR波动率突破",
                "symbols": SYMBOLS,
                "timeframes": TIMEFRAMES,
                "period": "近半年",
                "initial_capital": INITIAL,
                "warmup_bars": WARMUP,
                "elapsed_seconds": total_elapsed,
                "run_time": datetime.now(timezone.utc).isoformat(),
            },
            "results": results,
        }, f, ensure_ascii=False, indent=2, default=str)
    print(f" 结果已保存至: {output_file}")


async def main():
    """Backtest every SYMBOLS × TIMEFRAMES pair over the recent-half-year window.

    Runs at most MAX_CONCURRENCY backtests at once. Prints progress and a
    console report, then saves all rows to JSON.
    """
    date_ranges = await _fetch_date_ranges()
    tasks_info = _build_tasks(date_ranges)
    total = len(tasks_info)
    print(f"\n{total} 组回测任务 (ATR波动率突破 × {len(SYMBOLS)}币种 × {len(TIMEFRAMES)}时间 × 近半年)")

    sem = asyncio.Semaphore(MAX_CONCURRENCY)
    results = []
    completed = 0
    errors = 0

    async def run_one_safe(info):
        nonlocal completed, errors
        async with sem:
            r, elapsed, err = await run_one(info["symbol"], info["tf"], info["start"], info["end"])
            completed += 1
            if err:
                errors += 1
                status = f"✗ {err[:60]}"
            elif r is None:
                errors += 1
                status = "✗ 无结果"
            else:
                status = f"{r.metrics.annual_return_pct:+.1f}%/yr"
            print(f" [{completed}/{total}] {info['symbol']} {info['tf']} ({elapsed:.1f}s) {status}", flush=True)
            row = _result_row(info, r, elapsed, err)
            results.append(row)
            return row

    t_total = time.perf_counter()
    await asyncio.gather(*(run_one_safe(info) for info in tasks_info))
    total_elapsed = time.perf_counter() - t_total
    print(f"\n全部完成!成功 {total - errors}/{total},错误 {errors},总耗时 {total_elapsed:.0f}s")

    _print_report(results)
    _save_json(results, total_elapsed)


if __name__ == "__main__":
    asyncio.run(main())