Files
trade/engine/example/strategy_optimize.py
Rekey 515e61c517 feat(engine): 添加策略示例集(18 个 Demo)
- backtest_demo.py: 回测基础演示
- strategy_simple.py / three_ema.py / long_short.py: 基础策略(双均线/三均线/多空)
- strategy_optimize*.py (3 版本): 参数优化示例(网格搜索/贝叶斯/遗传算法)
- multi_tf_*.py (4 版本): 多时间框架策略(EMA200/多周期共振/混合信号)
- regime_*.py (4 版本): 市场状态检测(趋势/震荡/波动率区间/全状态)
- cross_section.py: 截面多品种策略
- factor_demo.py: 多因子模型演示
- strategy_battle.py / strategy_more.py: 策略对比与组合
- full_cycle.py: 全流程演示(数据→回测→分析)
- data.py: 数据读取示例
2026-06-12 10:27:04 +08:00

370 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Strategy optimization comparison — original vs. optimized versions.

Optimizations (all strategies are long-only):
    EMA v2:  dual EMA crossover + ATR trailing stop + long-term EMA trend
             filter (EmaV2Config.trend, default 100 bars).
    RSI v2:  trend confirmation (buy only while the close is above EMA50),
             looser entry at RSI < 40 and looser exit at RSI > 75.
    MACD v2: zero-line filter (buy golden crosses only while MACD > 0);
             a single-bar crossover is used, with no multi-bar confirmation.
    COMBO:   multi-factor combination (rising EMA trend + price above the
             EMA + RSI pullback entry); exits on RSI overheat or on the
             close falling to/below the EMA. No ATR stop is used.

Usage:
    source .venv/bin/activate && python example/strategy_optimize.py
"""
import asyncio
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
# Put the directory three levels above this file at the front of sys.path so
# the `engine.*` imports below resolve when the script is run directly.
_this_file = Path(__file__).resolve()
_project_root = _this_file.parent.parent.parent
_project_root_str = str(_project_root)
if _project_root_str not in sys.path:
    sys.path.insert(0, _project_root_str)
from engine.common.base import BaseStrategy, Signal, StrategyConfig
from engine.common.models import Kline
from engine.common.config import config
from engine.backtest import BacktestEngine, BacktestConfig, BacktestResult
from engine.indicators import macd, ema, rsi, bollinger, atr
# ════════════════════════════════════════════════════════
# EMA v2: dual EMA crossover + ATR trailing stop + long-term EMA trend filter
# ════════════════════════════════════════════════════════
class EmaV2Config(StrategyConfig):
    """Parameters for :class:`EmaV2Strategy`."""

    fast: int = 20
    slow: int = 50
    trend: int = 100  # long-term trend EMA period (default 100, not 200)
    atr_period: int = 14
    atr_stop_mult: float = 3.0  # stop = highest high since entry - atr_stop_mult * ATR


class EmaV2Strategy(BaseStrategy):
    """Long-only EMA crossover with a long-term trend filter and an ATR trailing stop.

    Entry: the fast EMA crosses above the slow EMA while the close is above
    the ``trend``-period EMA (default 100).
    Exit: the close falls below ``highest high since entry - atr_stop_mult * ATR``,
    or the fast EMA crosses below the slow EMA.
    """

    strategy_type = "ema_v2"

    def __init__(self, c: EmaV2Config):
        super().__init__(c)
        self.cfg = c
        self._closes: list[float] = []
        self._highs: list[float] = []
        self._lows: list[float] = []
        self._entry_price: float = 0.0
        self._highest_since_entry: float = 0.0
        self._in_position = False

    async def on_kline(self, k: Kline) -> Optional[Signal]:
        """Consume one bar and return a BUY/SELL signal, or ``None``."""
        self._closes.append(k.close)
        self._highs.append(k.high)
        self._lows.append(k.low)
        n = len(self._closes)
        # Wait until every indicator has a full window. `trend` (default 100) and
        # `atr_period` can exceed `slow`, so all periods are included; otherwise
        # the trend filter is read before the trend EMA has warmed up.
        warmup = max(self.cfg.fast, self.cfg.slow, self.cfg.trend, self.cfg.atr_period) + 5
        if n < warmup:
            return None
        fast = ema(self._closes, self.cfg.fast)
        slow = ema(self._closes, self.cfg.slow)
        trd = ema(self._closes, self.cfg.trend)
        atr_vals = atr(self._highs, self._lows, self._closes, self.cfg.atr_period)
        cur_f, cur_s, cur_trd, cur_atr = fast[-1], slow[-1], trd[-1], atr_vals[-1]
        # A zero value is treated as "indicator not ready".
        if cur_f == 0 or cur_s == 0 or cur_atr == 0:
            return None
        is_bull_market = cur_trd > 0 and k.close > cur_trd
        # ── Exit: ATR trailing stop or EMA death cross ──
        if self._in_position:
            self._highest_since_entry = max(self._highest_since_entry, k.high)
            stop_price = self._highest_since_entry - self.cfg.atr_stop_mult * cur_atr
            death_cross = fast[-2] >= slow[-2] and cur_f < cur_s
            if k.close < stop_price or death_cross:
                self._in_position = False
                reason = "ATR止损" if k.close < stop_price else "EMA死叉"
                return Signal(symbol=self.cfg.symbol, side="SELL", reason=reason, timestamp=k.open_time)
        # ── Entry: EMA golden cross while above the trend EMA ──
        if not self._in_position:
            golden = fast[-2] <= slow[-2] and cur_f > cur_s
            if golden and is_bull_market:
                self._in_position = True
                self._entry_price = k.close
                # The trailing stop is anchored to the highest high seen since entry.
                self._highest_since_entry = k.close
                return Signal(symbol=self.cfg.symbol, side="BUY",
                              reason=f"EMA金叉+多头 P={k.close:.0f}>EMA{self.cfg.trend}={cur_trd:.0f}",
                              timestamp=k.open_time)
        return None
# ════════════════════════════════════════════════════════
# RSI v2: EMA trend filter + looser RSI entry/exit thresholds
# ════════════════════════════════════════════════════════
class RsiV2Config(StrategyConfig):
    """Parameters for :class:`RsiV2Strategy`."""

    period: int = 14  # RSI look-back
    entry_rsi: float = 40.0  # buy below this RSI (original version used 30)
    exit_rsi: float = 75.0  # sell above this RSI (original version used 70)
    trend_ema: int = 50  # EMA period of the trend filter


class RsiV2Strategy(BaseStrategy):
    """Long-only RSI pullback strategy gated by an EMA trend filter.

    Buys when RSI < ``entry_rsi`` while the close is above the ``trend_ema``
    EMA. Sells when RSI > ``exit_rsi`` or the close is no longer above that EMA.
    """

    strategy_type = "rsi_v2"

    def __init__(self, c: RsiV2Config):
        super().__init__(c)
        self.cfg = c
        self._closes: list[float] = []
        self._in_position = False

    async def on_kline(self, k: Kline) -> Optional[Signal]:
        """Consume one bar and return a BUY/SELL signal, or ``None``."""
        cfg = self.cfg
        history = self._closes
        history.append(k.close)
        if len(history) < cfg.trend_ema + 5:
            return None

        rsi_now = rsi(history, cfg.period)[-1]
        trend_now = ema(history, cfg.trend_ema)[-1]
        # A zero from either indicator is treated as "not ready yet".
        if rsi_now == 0 or trend_now == 0:
            return None

        above_trend = k.close > trend_now
        overheated = rsi_now > cfg.exit_rsi

        if self._in_position:
            if overheated or not above_trend:
                self._in_position = False
                if overheated:
                    reason = f"RSI过热({rsi_now:.0f})"
                else:
                    reason = f"跌破EMA{cfg.trend_ema}"
                return Signal(symbol=cfg.symbol, side="SELL", reason=reason, timestamp=k.open_time)
            return None

        if not (rsi_now < cfg.entry_rsi and above_trend):
            return None
        self._in_position = True
        return Signal(symbol=cfg.symbol, side="BUY",
                      reason=f"RSI回调({rsi_now:.0f}) 多头确认 P>{trend_now:.0f}",
                      timestamp=k.open_time)
# ════════════════════════════════════════════════════════
# MACD v2: zero-line filter on golden crosses
# ════════════════════════════════════════════════════════
class MacdV2Config(StrategyConfig):
    """Parameters for :class:`MacdV2Strategy` (standard MACD periods)."""

    fast: int = 12
    slow: int = 26
    signal: int = 9


class MacdV2Strategy(BaseStrategy):
    """Long-only MACD crossover that only buys golden crosses above the zero line.

    Entry: MACD crosses above its signal line while MACD > 0 (flat only).
    Exit: MACD crosses below its signal line (long only).
    Only a single-bar crossover is checked; there is no multi-bar confirmation.
    """

    strategy_type = "macd_v2"

    def __init__(self, c: MacdV2Config):
        super().__init__(c)
        self.cfg = c
        self._closes: list[float] = []
        # Tracks whether a BUY is open. This keeps a death cross while flat from
        # emitting a stray SELL, and a re-cross while long from emitting a second
        # BUY. The other v2 strategies in this file track position the same way.
        self._in_position = False

    async def on_kline(self, k: Kline) -> Optional[Signal]:
        """Consume one bar and return a BUY/SELL signal, or ``None``."""
        self._closes.append(k.close)
        mline, sline, _ = macd(self._closes, self.cfg.fast, self.cfg.slow, self.cfg.signal)
        # Need a few MACD values before comparing the current and previous bar.
        if len(mline) < 4:
            return None
        cur_m, cur_s = mline[-1], sline[-1]
        prev_m, prev_s = mline[-2], sline[-2]
        # A zero MACD value is treated as "indicator not ready".
        if cur_m == 0:
            return None
        if not self._in_position:
            # Golden cross with the MACD line above zero (bullish confirmation) → buy.
            golden = prev_m <= prev_s and cur_m > cur_s
            if golden and cur_m > 0:
                self._in_position = True
                return Signal(symbol=self.cfg.symbol, side="BUY",
                              reason=f"零轴上金叉 MACD={cur_m:.1f}", timestamp=k.open_time)
            return None
        # Death cross while long → sell.
        death = prev_m >= prev_s and cur_m < cur_s
        if death:
            self._in_position = False
            return Signal(symbol=self.cfg.symbol, side="SELL",
                          reason="MACD死叉", timestamp=k.open_time)
        return None
# ════════════════════════════════════════════════════════
# COMBO: multi-factor combination (EMA trend + RSI pullback)
# ════════════════════════════════════════════════════════
class ComboConfig(StrategyConfig):
    """Parameters for :class:`ComboStrategy`."""

    ema_trend: int = 50  # EMA period of the trend filter
    rsi_period: int = 14
    rsi_entry: float = 45.0  # buy below this RSI
    rsi_exit: float = 72.0  # sell above this RSI


class ComboStrategy(BaseStrategy):
    """Long-only multi-factor strategy: rising EMA trend + RSI pullback entry.

    Buys when RSI < ``rsi_entry``, the ``ema_trend`` EMA is rising versus the
    previous bar, and the close is above that EMA. Sells when RSI > ``rsi_exit``
    or the close is no longer above the EMA.
    """

    strategy_type = "combo"

    def __init__(self, c: ComboConfig):
        super().__init__(c)
        self.cfg = c
        self._closes: list[float] = []
        self._in_position = False

    async def on_kline(self, k: Kline) -> Optional[Signal]:
        """Consume one bar and return a BUY/SELL signal, or ``None``."""
        cfg = self.cfg
        self._closes.append(k.close)
        if len(self._closes) < cfg.ema_trend + 5:
            return None

        rsi_series = rsi(self._closes, cfg.rsi_period)
        trend_series = ema(self._closes, cfg.ema_trend)
        rsi_now, trend_now, trend_prev = rsi_series[-1], trend_series[-1], trend_series[-2]
        # A zero from either indicator is treated as "not ready yet".
        if rsi_now == 0 or trend_now == 0:
            return None

        ema_rising = trend_now > trend_prev
        above_trend = k.close > trend_now
        overheated = rsi_now > cfg.rsi_exit

        if self._in_position:
            if overheated or not above_trend:
                self._in_position = False
                reason = f"RSI过热{rsi_now:.0f}" if overheated else "趋势转弱"
                return Signal(symbol=cfg.symbol, side="SELL", reason=reason, timestamp=k.open_time)
            return None

        if rsi_now < cfg.rsi_entry and ema_rising and above_trend:
            self._in_position = True
            return Signal(symbol=cfg.symbol, side="BUY",
                          reason=f"多头共振 RSI={rsi_now:.0f} EMA↑ P>{trend_now:.0f}",
                          timestamp=k.open_time)
        return None
# ════════════════════════════════════════════════════════
# Run configuration: optimized strategy set, symbols, baseline results
# ════════════════════════════════════════════════════════
# Each entry: (display name, strategy class, config template).
OPT_STRATEGIES = [
("EMA v2 趋势+止损", EmaV2Strategy, EmaV2Config()),
("RSI v2 趋势过滤", RsiV2Strategy, RsiV2Config()),
("MACD v2 零轴过滤", MacdV2Strategy, MacdV2Config()),
("COMBO 多因子", ComboStrategy, ComboConfig()),
]
SYMBOLS = ["BTCUSDT", "ETHUSDT", "BNBUSDT", "SOLUSDT"]
# Results of the original (un-optimized) strategies, copied from an earlier run
# and used as the comparison baseline.
# Key: (symbol, original strategy name).
# Value: (total return %, Sharpe ratio, max drawdown %, trade count, win rate %).
ORIGINAL = {
("BTCUSDT", "EMA双均线"): (45.5, 0.74, -31.6, 42, 26.2),
("BTCUSDT", "RSI超卖反弹"): (45.4, 0.74, -26.0, 20, 70.0),
("BTCUSDT", "MACD金叉死叉"): (-21.3, -0.16, -41.7, 169, 32.5),
("ETHUSDT", "EMA双均线"): (24.4, 0.47, -54.8, 41, 24.4),
("ETHUSDT", "RSI超卖反弹"): (-42.8, -0.28, -66.1, 18, 61.1),
("ETHUSDT", "MACD金叉死叉"): (47.6, 0.64, -41.5, 162, 34.0),
("BNBUSDT", "EMA双均线"): (52.0, 0.71, -39.8, 41, 39.0),
("BNBUSDT", "RSI超卖反弹"): (67.4, 0.93, -34.2, 18, 77.8),
("BNBUSDT", "MACD金叉死叉"): (4.4, 0.24, -38.1, 177, 35.0),
("SOLUSDT", "EMA双均线"): (27.8, 0.49, -39.5, 45, 40.0),
("SOLUSDT", "RSI超卖反弹"): (-5.3, 0.24, -42.8, 16, 56.2),
("SOLUSDT", "MACD金叉死叉"): (-15.9, 0.17, -58.6, 169, 34.9),
}
async def run_one(
    symbol,
    s_name,
    s_cls,
    s_cfg,
    *,
    interval: str = "4h",
    start_time: datetime = datetime(2024, 1, 1),
    end_time: datetime = datetime(2026, 1, 1),
    initial_capital: float = 10_000.0,
):
    """Back-test one strategy on one symbol and return the engine's result.

    Args:
        symbol: Trading pair, e.g. ``"BTCUSDT"``.
        s_name: Display name; combined with ``symbol`` to form the config's ``name``.
        s_cls: Strategy class passed to ``BacktestEngine.run``.
        s_cfg: Config template. It is NOT modified; a ``model_copy()`` is taken
            and its ``symbol``/``name`` are set for this run.
        interval, start_time, end_time, initial_capital: Back-test window and
            capital (keyword-only). The defaults reproduce the previous hard-coded
            4h / 2024-01-01..2026-01-01 / 10,000 setup.

    Returns:
        Whatever ``BacktestEngine.run`` returns (a ``BacktestResult`` per ``main``).
    """
    # Copy so callers can safely reuse one config template across runs.
    cfg = s_cfg.model_copy()
    cfg.symbol = symbol
    cfg.name = f"{s_name}_{symbol}"
    # NOTE(review): naive datetimes are passed; confirm BacktestEngine expects naive times.
    bt = BacktestConfig(
        symbol=symbol, interval=interval,
        start_time=start_time, end_time=end_time,
        initial_capital=initial_capital,
    )
    engine = BacktestEngine(bt, db_config=config.db)
    return await engine.run(s_cls, cfg)
# Optimized strategy display name -> original strategy name used as its baseline
# in ORIGINAL. COMBO is new and has no baseline.
_OPT_TO_ORIGINAL = {
    "EMA v2 趋势+止损": "EMA双均线",
    "RSI v2 趋势过滤": "RSI超卖反弹",
    "MACD v2 零轴过滤": "MACD金叉死叉",
}
_COMBO_NAME = "COMBO 多因子"
_RULE_WIDTH = 115
_HEAVY_RULE = "═" * _RULE_WIDTH
_LIGHT_RULE = "─" * _RULE_WIDTH


async def _run_all():
    """Run every optimized strategy on every symbol, sequentially.

    Returns a dict mapping ``(symbol, strategy display name)`` to the result.
    """
    results = {}
    for symbol in SYMBOLS:
        for s_name, s_cls, s_cfg in OPT_STRATEGIES:
            # Copy so the shared template in OPT_STRATEGIES is never mutated.
            results[(symbol, s_name)] = await run_one(symbol, s_name, s_cls, s_cfg.model_copy())
    return results


def _print_comparison(opt_results):
    """Print the baseline row, then the optimized row, per strategy; then COMBO per symbol."""
    print()
    print(f" {'币种':<10} {'策略':<20} {'类型':<10} {'收益%':>7} {'夏普':>6} {'回撤%':>7} {'交易':>5} {'胜率%':>6} Δ收益")
    print(_LIGHT_RULE)
    for symbol in SYMBOLS:
        for opt_name, orig_name in _OPT_TO_ORIGINAL.items():
            baseline = ORIGINAL.get((symbol, orig_name))
            if baseline is not None:
                o_ret, o_sh, o_dd, o_tr, o_wr = baseline
                print(f" {symbol:<10} {orig_name+' (原始)':<20} {'原始':<10} {o_ret:>6.1f}% {o_sh:>6.2f} {o_dd:>6.1f}% {o_tr:>5} {o_wr:>5.1f}%")
            result = opt_results.get((symbol, opt_name))
            if result is not None:
                m = result.metrics
                # Without a baseline there is nothing to compare, so show "n/a"
                # rather than a misleading +0.0%.
                if baseline is not None:
                    delta = f"{m.total_return_pct - baseline[0]:+.1f}%"
                else:
                    delta = "n/a"
                print(f" {symbol:<10} {opt_name+' (优化)':<20} {'优化':<10} {m.total_return_pct:>6.1f}% {m.sharpe_ratio:>6.2f} {m.max_drawdown_pct:>6.1f}% {m.total_trades:>5} {m.win_rate*100:>5.1f}% {delta}")
            print()
        combo = opt_results.get((symbol, _COMBO_NAME))
        if combo is not None:
            m = combo.metrics
            print(f" {symbol:<10} {_COMBO_NAME:<20} {'新增':<10} {m.total_return_pct:>6.1f}% {m.sharpe_ratio:>6.2f} {m.max_drawdown_pct:>6.1f}% {m.total_trades:>5} {m.win_rate*100:>5.1f}%")
        print()


def _print_summary(opt_results):
    """Print per-run Δ return vs. baseline (best first), then the average Δ per strategy."""
    print("\n ■ 优化效果汇总 (平均 Δ收益):")
    improvements = []
    deltas_by_strategy = {}
    for (symbol, opt_name), r in opt_results.items():
        orig_name = _OPT_TO_ORIGINAL.get(opt_name)
        baseline = ORIGINAL.get((symbol, orig_name)) if orig_name else None
        if baseline is None:
            continue
        delta = r.metrics.total_return_pct - baseline[0]
        improvements.append((f"{symbol} {opt_name}", delta, r.metrics.sharpe_ratio))
        deltas_by_strategy.setdefault(opt_name, []).append(delta)
    improvements.sort(key=lambda x: x[1], reverse=True)
    for name, delta, sh in improvements:
        print(f" {name:<30} Δ收益={delta:+.1f}% 夏普={sh:.2f}")
    # The header promises an average, which was previously never computed.
    for opt_name, deltas in deltas_by_strategy.items():
        print(f" {opt_name:<30} 平均Δ收益={sum(deltas) / len(deltas):+.1f}%")


def _print_top(opt_results, top_n=5):
    """Print the ``top_n`` (symbol, strategy) runs ranked by Sharpe ratio."""
    print(f"\n ■ 最佳组合 TOP {top_n}:")
    ranked = sorted(opt_results.items(), key=lambda item: item[1].metrics.sharpe_ratio, reverse=True)
    for rank, ((symbol, name), r) in enumerate(ranked[:top_n], start=1):
        m = r.metrics
        label = f"{symbol} {name}"
        print(f" {rank}. {label:<30} 夏普={m.sharpe_ratio:.2f} 收益={m.total_return_pct:+.1f}% 回撤={m.max_drawdown_pct:.1f}% 胜率={m.win_rate*100:.0f}%")


async def main():
    """Run all optimized strategies and print the original-vs-optimized report."""
    print()
    print(_HEAVY_RULE)
    print(" 策略优化对比 — 原始 vs 优化版 | 4h 周期 | 2024-2026")
    print(_HEAVY_RULE)
    opt_results = await _run_all()
    _print_comparison(opt_results)
    print(_HEAVY_RULE)
    _print_summary(opt_results)
    _print_top(opt_results)
    # Closing rule; previously `"\n" * 115` printed 115 blank lines.
    print("\n" + _HEAVY_RULE)


if __name__ == "__main__":
    asyncio.run(main())