"""ATR volatility breakout — side-by-side backtest of 1h/2h/4h/6h over roughly the last six months.

Usage:
    source .venv/bin/activate && python example/vol_break_1h_6h.py
"""

import asyncio
import json
import sys
import time
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional

# Three parents up from this file is treated as the project root (the JSON output path below
# presumes this file lives in <root>/engine/example/); put it on sys.path so `engine` imports.
_project_root = Path(__file__).resolve().parent.parent.parent
if str(_project_root) not in sys.path:
    sys.path.insert(0, str(_project_root))

from engine.common.base import BaseStrategy, Signal, StrategyConfig  # noqa: E402
from engine.common.models import Kline  # noqa: E402
from engine.common.config import config  # noqa: E402
from engine.backtest.models import BacktestConfig  # noqa: E402
from engine.data import DataService  # noqa: E402
from engine.indicators.incremental import EmaInc, AtrInc, RsiInc, BbInc  # noqa: E402
from engine.example.long_short import LongShortEngine  # noqa: E402

# ── Global constants ──
SYMBOLS = ["BTCUSDT", "ETHUSDT", "BNBUSDT", "SOLUSDT"]
TIMEFRAMES = ["1h", "2h", "4h", "6h"]
INITIAL = 10_000.0
WARMUP = 150
MAX_CONCURRENCY = 6

# Bar length in milliseconds per timeframe; used only to estimate bar counts for display.
BAR_MS = {"1h": 3_600_000, "2h": 7_200_000, "4h": 14_400_000, "6h": 21_600_000}

NOW = datetime.now(timezone.utc)
PERIOD_START = NOW - timedelta(days=182)  # roughly the last six months
PERIOD_END = NOW

# Markers for the top three places in each timeframe ranking.
_MEDALS = ("🥇", "🥈", "🥉")


# ════════════════════════════════════════════════════════
# ATR volatility breakout strategy
# ════════════════════════════════════════════════════════

class VolBreakConfig(StrategyConfig):
    """Parameters for VolBreakStrategy, on top of the base StrategyConfig fields."""

    atr_period: int = 14  # ATR lookback
    squeeze_period: int = 20  # number of most recent bars scanned for the minimum ATR
    squeeze_ratio: float = 0.7  # squeezed when ATR < min_atr * (1 + (1 - squeeze_ratio))
    atr_stop: float = 2.0  # stop distance, in multiples of the current ATR


class VolBreakStrategy(BaseStrategy):
    """ATR squeeze/expansion breakout with an EMA(10/30) direction filter.

    While flat, a bar whose ATR is "squeezed" (close to its recent minimum) arms the
    strategy; on the first later bar that is not squeezed and whose ATR grew more than 5%
    over the previous bar, it enters long if EMA(10) > EMA(30), otherwise short.
    While in a position, it exits when the close crosses entry -/+ atr_stop * current ATR,
    or when the EMAs cross against the position while ATR is not squeezed.
    """

    strategy_type = "波动率突破"
    strategy_desc = "ATR(14)收缩至极低后扩张突破 + EMA(10/30)方向确认"

    def __init__(self, c: VolBreakConfig):
        super().__init__(c)
        self.cfg = c
        self._atr = AtrInc(c.atr_period)
        self._ema_fast = EmaInc(10)
        self._ema_slow = EmaInc(30)
        self._closes: list[float] = []
        self._highs: list[float] = []
        self._lows: list[float] = []
        self._side: str = ""  # "" when flat, otherwise "long" or "short"
        self._entry_price: float = 0.0
        self._was_squeezed = False  # set when a squeeze was seen while flat

    async def on_kline(self, k: Kline) -> Optional[Signal]:
        """Consume one bar and return an entry/exit Signal, or None."""
        self._closes.append(k.close)
        self._highs.append(k.high)
        self._lows.append(k.low)
        self._atr.update(k.high, k.low, k.close)
        self._ema_fast.update(k.close)
        self._ema_slow.update(k.close)

        n = len(self._closes)
        if n < self.cfg.atr_period + self.cfg.squeeze_period:
            return None

        atr_now = self._atr[-1]
        atr_prev = self._atr[-2] if n >= 2 else 0
        if atr_now == 0:
            return None

        # NOTE(review): assumes AtrInc supports absolute indexes 0..n-1, one per update.
        atr_window = [
            self._atr[i]
            for i in range(max(0, n - self.cfg.squeeze_period), n)
            if self._atr[i] > 0
        ]
        if not atr_window:
            return None
        min_atr = min(atr_window)
        is_squeezed = atr_now < min_atr * (1 + (1 - self.cfg.squeeze_ratio))
        atr_expanding = atr_now > atr_prev * 1.05 if atr_prev > 0 else False

        cf, cs = self._ema_fast[-1], self._ema_slow[-1]
        trend_up = cf > cs

        if self._side == "long":
            self._was_squeezed = False
            stop = self._entry_price - self.cfg.atr_stop * atr_now
            if k.close < stop or (cf < cs and not is_squeezed):
                self._side = ""
                return Signal(symbol=self.cfg.symbol, side="SELL", reason="ATR退出", timestamp=k.open_time)
        elif self._side == "short":
            self._was_squeezed = False
            stop = self._entry_price + self.cfg.atr_stop * atr_now
            if k.close > stop or (cf > cs and not is_squeezed):
                self._side = ""
                return Signal(symbol=self.cfg.symbol, side="BUY", reason="ATR退出", timestamp=k.open_time)
        else:
            if is_squeezed:
                self._was_squeezed = True
            elif self._was_squeezed and atr_expanding:
                self._was_squeezed = False
                self._entry_price = k.close
                if trend_up:
                    self._side = "long"
                    return Signal(symbol=self.cfg.symbol, side="BUY", reason="ATR扩张突破做多", timestamp=k.open_time)
                self._side = "short"
                return Signal(symbol=self.cfg.symbol, side="SELL", reason="ATR扩张突破做空", timestamp=k.open_time)
        return None


async def run_one(symbol, interval, start, end):
    """Backtest VolBreakStrategy for one symbol/interval window.

    Returns:
        (result, elapsed_seconds, error). On success error is None; if the engine run
        raises, result is None and error is "<ExceptionType>: <message>" (never empty).
    """
    sc = VolBreakConfig(symbol=symbol)
    bt = BacktestConfig(
        symbol=symbol,
        interval=interval,
        start_time=start,
        end_time=end,
        initial_capital=INITIAL,
        warmup_bars=WARMUP,
    )
    engine = LongShortEngine(bt, db_config=config.db)
    t0 = time.perf_counter()
    try:
        r = await engine.run(VolBreakStrategy, sc)
    except Exception as ex:  # batch boundary: report this run's failure, keep the others going
        # Prefix the type name: str(ex) alone can be empty and would hide the failure cause.
        return None, time.perf_counter() - t0, f"{type(ex).__name__}: {ex}"
    return r, time.perf_counter() - t0, None


async def _fetch_date_ranges() -> dict:
    """Look up and print the stored date range of every (symbol, timeframe) pair.

    Returns {(symbol, tf): (start, end, estimated_bars)}; pairs whose lookup fails are
    printed as failures and omitted. The DataService is always closed.
    """
    ds = DataService(config.db)
    await ds.connect()
    date_ranges = {}
    try:
        print("正在获取数据范围...")
        for symbol in SYMBOLS:
            for tf in TIMEFRAMES:
                try:
                    s, e = await ds.fetch_symbol_date_range(symbol, tf)
                    estimated_bars = int((e - s).total_seconds() * 1000 / BAR_MS[tf])
                    date_ranges[(symbol, tf)] = (s, e, estimated_bars)
                    print(f" {symbol} {tf:<4}: {s.date()} ~ {e.date()} (约{estimated_bars:,}根)")
                except Exception as ex:
                    print(f" {symbol} {tf:<4}: 获取失败 — {ex}")
    finally:
        await ds.close()
    return date_ranges


def _build_tasks(date_ranges: dict) -> list[dict]:
    """Clip each available data range to [PERIOD_START, PERIOD_END]; skip empty windows."""
    tasks_info = []
    for symbol in SYMBOLS:
        for tf in TIMEFRAMES:
            key = (symbol, tf)
            if key not in date_ranges:
                continue
            fs, fe, _ = date_ranges[key]
            # NOTE(review): PERIOD_* are tz-aware; this presumes fs/fe are tz-aware too
            # (comparing naive with aware datetimes raises TypeError) — confirm.
            actual_start = max(PERIOD_START, fs)
            actual_end = min(PERIOD_END, fe)
            if actual_start >= actual_end:
                continue
            tasks_info.append({"symbol": symbol, "tf": tf, "start": actual_start, "end": actual_end})
    return tasks_info


def _result_row(info: dict, r, elapsed: float, err: Optional[str]) -> dict:
    """Flatten one run into a report row; failed runs get zeroed metrics plus an "错误" key."""
    row = {
        "币种": info["symbol"],
        "时间级别": info["tf"],
        "日期范围": f"{info['start'].date()}~{info['end'].date()}",
    }
    if r is not None:
        m = r.metrics
        row.update({
            "初始资金": INITIAL,
            "最终权益": round(m.final_equity, 2),
            "总收益%": round(m.total_return_pct, 2),
            "年化收益%": round(m.annual_return_pct, 2),
            "夏普比率": round(m.sharpe_ratio, 2),
            "最大回撤%": round(m.max_drawdown_pct, 2),
            "胜率%": round(m.win_rate * 100, 2),
            "盈亏比": round(m.profit_factor, 2),
            "交易次数": m.total_trades,
            "平均盈亏": round(m.avg_trade_pnl, 2),
            "最佳盈亏": round(m.best_trade_pnl, 2),
            "最差盈亏": round(m.worst_trade_pnl, 2),
            "耗时s": round(elapsed, 1),
        })
    else:
        row.update({
            "初始资金": INITIAL,
            "最终权益": 0,
            "总收益%": 0,
            "年化收益%": 0,
            "夏普比率": 0,
            "最大回撤%": 0,
            "胜率%": 0,
            "盈亏比": 0,
            "交易次数": 0,
            "平均盈亏": 0,
            "最佳盈亏": 0,
            "最差盈亏": 0,
            "耗时s": round(elapsed, 1),
            "错误": err or "未知错误",
        })
    return row


def _rank_key(row: dict) -> float:
    """Sort key for rankings: annual return, with failed runs pushed below every real result."""
    # Failed rows carry a placeholder 0 return, which would otherwise outrank losing runs.
    if "错误" in row:
        return float("-inf")
    return row["年化收益%"]


def _rank_marker(i: int) -> str:
    """Medal for places 1-3, right-aligned rank number afterwards (e.g. " 4")."""
    return _MEDALS[i] if i < len(_MEDALS) else f"{i + 1:>2}"


def _print_results(results: list[dict]) -> None:
    """Print the full comparison table (rows are printed in the order given)."""
    print()
    print("═" * 145)
    print(" ATR波动率突破 — 1h / 2h / 4h / 6h 近半年横向对比")
    print(f" 策略: ATR(14)/squeeze=20x0.7/EMA(10,30) | 本金 ${INITIAL:,.0f} | 多空双向")
    print("═" * 145)
    print()
    print(f" {'币种':<10} {'时间':<5} {'总收益%':>8} {'年化%':>8} {'夏普':>7} {'回撤%':>7} {'胜率%':>7} {'盈亏比':>7} {'交易':>6} {'最佳盈亏':>10} {'最差盈亏':>10} {'日期范围':<24}")
    print(" " + "─" * 140)
    for r in results:
        print(f" {r['币种']:<10} {r['时间级别']:<5} {r['总收益%']:>7.1f}% {r['年化收益%']:>7.1f}% {r['夏普比率']:>7.2f} {r['最大回撤%']:>7.1f}% {r['胜率%']:>6.1f}% {r['盈亏比']:>7.2f} {r['交易次数']:>6} {r['最佳盈亏']:>+9.0f} {r['最差盈亏']:>+9.0f} {r['日期范围']:<24}")
    print()


def _print_rankings(results: list[dict]) -> None:
    """Print, per timeframe, the symbols ranked by annual return (failed runs last, marked ✗)."""
    print("═" * 145)
    print(" ■ 各时间级别排名(按年化收益)")
    print("═" * 145)
    for tf in TIMEFRAMES:
        subset = [r for r in results if r["时间级别"] == tf]
        if not subset:
            continue
        subset.sort(key=_rank_key, reverse=True)
        print(f"\n ▲ {tf} 近半年")
        print(f" {'排名':<5} {'币种':<10} {'总收益%':>8} {'年化%':>8} {'夏普':>7} {'回撤%':>7} {'胜率%':>7} {'盈亏比':>7} {'交易':>6}")
        print(" " + "─" * 100)
        for i, r in enumerate(subset):
            marker = " ✗" if "错误" in r else _rank_marker(i)
            print(f" {marker:<5} {r['币种']:<10} {r['总收益%']:>7.1f}% {r['年化收益%']:>7.1f}% {r['夏普比率']:>7.2f} {r['最大回撤%']:>7.1f}% {r['胜率%']:>6.1f}% {r['盈亏比']:>7.2f} {r['交易次数']:>6}")
    print()
    print("═" * 145)


def _save_json(results: list[dict], total_elapsed: float) -> None:
    """Write the run configuration and all result rows to vol_break_1h_6h.json."""
    output_file = _project_root / "engine" / "example" / "vol_break_1h_6h.json"
    payload = {
        "config": {
            "strategy": "ATR波动率突破",
            "symbols": SYMBOLS,
            "timeframes": TIMEFRAMES,
            "period": "近半年",
            "initial_capital": INITIAL,
            "warmup_bars": WARMUP,
            "elapsed_seconds": total_elapsed,
            "run_time": datetime.now(timezone.utc).isoformat(),
        },
        "results": results,
    }
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2, default=str)
    print(f" 结果已保存至: {output_file}")


async def main():
    """Fetch data ranges, run all backtests concurrently, print the comparison, save JSON."""
    date_ranges = await _fetch_date_ranges()
    tasks_info = _build_tasks(date_ranges)
    total = len(tasks_info)
    print(f"\n共 {total} 组回测任务 (ATR波动率突破 × {len(SYMBOLS)}币种 × {len(TIMEFRAMES)}时间 × 近半年)")

    sem = asyncio.Semaphore(MAX_CONCURRENCY)
    results = []
    completed = 0
    errors = 0

    async def run_one_safe(info):
        # Counters are shared by all tasks; safe because they only change between awaits.
        nonlocal completed, errors
        async with sem:
            r, elapsed, err = await run_one(info["symbol"], info["tf"], info["start"], info["end"])
        completed += 1
        if err:
            errors += 1
            status = f"✗ {err[:40]}"
        elif r is None:
            errors += 1
            status = "✗ 无结果"
        else:
            status = f"✓ {r.metrics.annual_return_pct:+.1f}%/yr"
        print(f" [{completed}/{total}] {info['symbol']} {info['tf']} ({elapsed:.1f}s) {status}", flush=True)
        row = _result_row(info, r, elapsed, err)
        results.append(row)
        return row

    t_total = time.perf_counter()
    await asyncio.gather(*[run_one_safe(info) for info in tasks_info])
    total_elapsed = time.perf_counter() - t_total
    print(f"\n全部完成!成功 {total - errors}/{total},错误 {errors},总耗时 {total_elapsed:.0f}s")

    # Deterministic order (timeframe, then symbol) instead of task-completion order.
    results.sort(key=lambda x: (TIMEFRAMES.index(x["时间级别"]), SYMBOLS.index(x["币种"])))

    _print_results(results)
    _print_rankings(results)
    _save_json(results, total_elapsed)


if __name__ == "__main__":
    asyncio.run(main())