Files
trade/engine/data/service.py
T
Rekey edc50e8809 feat: 新增2h/6h时间框架支持,策略重构为增量指标计算
- 数据层: build_aggregates_sql 新增 2h/6h 聚合视图,默认起始时间调整为 2017-05
- 模型层: KlineInterval 类型扩展 2h/6h,DataService 新增对应表名和毫秒映射
- 指标层: 新增 incremental.py 增量指标模块 (EmaInc/AtrInc/RsiInc/BbInc),O(1) per bar
- 策略重构: long_short.py 和 regime_all.py 从批量 ema/atr 迁移至增量指标,避免每 bar 重复全量计算
- regime 探测器: RegimeDetector3 改为增量 EMA200,detect() 接口简化
- 回测扩展: regime_timeframe_comparison 从 4h/1d 扩展至 2h/4h/6h/1d
- 新增示例: multi_strategy_report, vol_break_compare/periods, intraday_explore, top3_trades 等分析脚本
2026-06-13 19:30:25 +08:00

345 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
K 线数据服务 — 从 TimescaleDB 读取历史 K 线数据,为回测引擎提供数据源
用法:
from engine.common.config import config
from engine.data import DataService
ds = DataService(config.db)
await ds.connect()
# 获取 BTCUSDT 1h K 线,按时间范围过滤
klines = await ds.fetch_klines(
symbol="BTCUSDT",
interval="1h",
start_time=datetime(2025, 1, 1),
end_time=datetime(2026, 1, 1),
)
await ds.close()
"""
import asyncio
from datetime import datetime, timezone
from decimal import Decimal
from typing import AsyncGenerator
import asyncpg
from ..common.config import DBConfig
from ..common.models import Kline, KlineInterval
# ── Interval → table name ──
# Each supported kline interval is read from its own relation: "1m" maps to
# ``klines`` and every other interval maps to ``klines_<interval>``.
INTERVAL_TO_TABLE: dict[KlineInterval, str] = {
    "1m": "klines",
    "5m": "klines_5m",
    "15m": "klines_15m",
    "30m": "klines_30m",
    "1h": "klines_1h",
    "2h": "klines_2h",
    "4h": "klines_4h",
    "6h": "klines_6h",
    "1d": "klines_1d",
    "1w": "klines_1w",
}

# ── Interval length in milliseconds ──
# Spelled out as unit arithmetic so each entry can be checked at a glance.
INTERVAL_MS: dict[KlineInterval, int] = {
    "1m": 60 * 1000,
    "5m": 5 * 60 * 1000,
    "15m": 15 * 60 * 1000,
    "30m": 30 * 60 * 1000,
    "1h": 60 * 60 * 1000,
    "2h": 2 * 60 * 60 * 1000,
    "4h": 4 * 60 * 60 * 1000,
    "6h": 6 * 60 * 60 * 1000,
    "1d": 24 * 60 * 60 * 1000,
    "1w": 7 * 24 * 60 * 60 * 1000,
}

# Default number of rows per batch used by DataService.stream_klines().
DEFAULT_BATCH_SIZE = 5_000
def dt_to_unix_ms(dt: datetime) -> float:
    """Convert a datetime to a Unix timestamp expressed in milliseconds.

    A naive datetime (``tzinfo`` is ``None``) is interpreted as UTC; an aware
    datetime is converted using its own time zone.

    Args:
        dt: The datetime to convert.

    Returns:
        Milliseconds since the Unix epoch, as a float.
    """
    aware = dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
    return aware.timestamp() * 1000
def _to_float(val) -> float:
"""将 Decimal / int 等数值转为 float"""
if val is None:
return 0.0
if isinstance(val, Decimal):
return float(val)
return float(val)
class DataService:
"""K 线数据服务
封装 TimescaleDB 中 K 线数据的查询逻辑,
将数据库行转换为 engine.common.models.Kline 模型,
供回测引擎消费。
"""
def __init__(self, db_config: DBConfig, pool_size: int = 4):
self._db_config = db_config
self._pool_size = pool_size
self._pool: asyncpg.Pool | None = None
self._col_cache: dict[str, set[str]] = {}
# ── 生命周期 ──
@property
def dsn(self) -> str:
db = self._db_config
return f"postgresql://{db.user}:{db.password}@{db.host}:{db.port}/{db.name}"
async def connect(self) -> None:
"""建立数据库连接池"""
self._pool = await asyncpg.create_pool(
dsn=self.dsn,
min_size=1,
max_size=self._pool_size,
)
async def close(self) -> None:
"""关闭连接池"""
if self._pool:
await self._pool.close()
self._pool = None
@property
def is_connected(self) -> bool:
return self._pool is not None
# ── 元数据查询 ──
async def fetch_available_symbols(
self, interval: KlineInterval = "1m"
) -> list[str]:
"""获取指定周期下所有有数据的交易对"""
table = INTERVAL_TO_TABLE[interval]
async with self._pool.acquire() as conn:
rows = await conn.fetch(
f"SELECT DISTINCT symbol FROM {table} ORDER BY symbol"
)
return [r["symbol"] for r in rows]
async def fetch_symbol_date_range(
self, symbol: str, interval: KlineInterval
) -> tuple[datetime, datetime]:
"""获取指定交易对 + 周期的数据起止时间"""
table = INTERVAL_TO_TABLE[interval]
async with self._pool.acquire() as conn:
row = await conn.fetchrow(
f"""
SELECT MIN(time) AS min_time, MAX(time) AS max_time
FROM {table}
WHERE symbol = $1
""",
symbol,
)
if row is None or row["min_time"] is None:
raise ValueError(f"无数据: {symbol} {interval}")
return row["min_time"], row["max_time"]
async def fetch_klines_count(
self,
symbol: str,
interval: KlineInterval,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> int:
"""获取指定条件的 K 线条数(用于预判数据量)"""
table = INTERVAL_TO_TABLE[interval]
conditions = ["symbol = $1", "interval = $2"]
params: list = [symbol, interval]
idx = 3
if start_time is not None:
conditions.append(f"time >= ${idx}")
params.append(start_time)
idx += 1
if end_time is not None:
conditions.append(f"time < ${idx}")
params.append(end_time)
idx += 1
where = " AND ".join(conditions)
async with self._pool.acquire() as conn:
count = await conn.fetchval(
f"SELECT COUNT(*) FROM {table} WHERE {where}", *params
)
return count
# ── 核心查询 ──
async def fetch_klines(
self,
symbol: str,
interval: KlineInterval,
start_time: datetime | None = None,
end_time: datetime | None = None,
limit: int = 1000,
offset: int = 0,
) -> list[Kline]:
"""获取 K 线数据,返回 Pydantic 模型列表
Args:
symbol: 交易对(如 BTCUSDT
interval: K 线周期
start_time: 起始时间(包含)
end_time: 结束时间(不包含)
limit: 最大返回条数
offset: 分页偏移
Returns:
按时间升序排列的 Kline 列表
"""
table = INTERVAL_TO_TABLE[interval]
interval_ms = INTERVAL_MS[interval]
conditions = ["symbol = $1", "interval = $2"]
params: list = [symbol, interval]
idx = 3
if start_time is not None:
conditions.append(f"time >= ${idx}")
params.append(start_time)
idx += 1
if end_time is not None:
conditions.append(f"time < ${idx}")
params.append(end_time)
idx += 1
where = " AND ".join(conditions)
cols = await self._get_columns(table)
select_cols = [
"time", "exchange", "symbol", "interval",
"open", "high", "low", "close", "volume",
]
for extra in (
"trade_count", "quote_volume", "taker_buy_base_vol",
"taker_buy_quote_vol", "is_closed",
):
if extra in cols:
select_cols.append(extra)
async with self._pool.acquire() as conn:
rows = await conn.fetch(
f"""
SELECT {', '.join(select_cols)}
FROM {table}
WHERE {where}
ORDER BY time ASC
LIMIT ${idx} OFFSET ${idx + 1}
""",
*params,
limit,
offset,
)
return [self._row_to_kline(r, interval, interval_ms) for r in rows]
async def stream_klines(
self,
symbol: str,
interval: KlineInterval,
start_time: datetime | None = None,
end_time: datetime | None = None,
batch_size: int = DEFAULT_BATCH_SIZE,
) -> AsyncGenerator[list[Kline], None]:
"""流式获取 K 线数据,适合大数据集
每次返回一批 K 线列表,避免一次性加载过多数据到内存。
Yields:
每批 Kline 列表(按时间升序)
"""
offset = 0
while True:
batch = await self.fetch_klines(
symbol=symbol,
interval=interval,
start_time=start_time,
end_time=end_time,
limit=batch_size,
offset=offset,
)
if not batch:
break
yield batch
offset += len(batch)
async def fetch_multi_klines(
self,
symbols: list[str],
interval: KlineInterval,
start_time: datetime | None = None,
end_time: datetime | None = None,
limit: int = 1000,
) -> dict[str, list[Kline]]:
"""批量获取多个交易对的 K 线
Returns:
{symbol: [Kline, ...]} 字典
"""
tasks = [
self.fetch_klines(
symbol=sym,
interval=interval,
start_time=start_time,
end_time=end_time,
limit=limit,
)
for sym in symbols
]
results = await asyncio.gather(*tasks)
return dict(zip(symbols, results))
# ── 内部方法 ──
async def _get_columns(self, table: str) -> set[str]:
"""获取表的列名集合(带缓存)"""
if table not in self._col_cache:
async with self._pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_name = $1
""",
table,
)
self._col_cache[table] = {r["column_name"] for r in rows}
return self._col_cache[table]
@staticmethod
def _row_to_kline(
row: asyncpg.Record, interval: KlineInterval, interval_ms: int
) -> Kline:
"""将数据库行转换为 Kline 模型"""
open_time = dt_to_unix_ms(row["time"])
return Kline(
exchange=row["exchange"],
symbol=row["symbol"],
interval=interval,
openTime=open_time,
closeTime=open_time + interval_ms,
open=_to_float(row["open"]),
high=_to_float(row["high"]),
low=_to_float(row["low"]),
close=_to_float(row["close"]),
volume=_to_float(row["volume"]),
quoteVolume=_to_float(row.get("quote_volume", 0)),
takerBuyBaseVol=_to_float(row.get("taker_buy_base_vol", 0)),
takerBuyQuoteVol=_to_float(row.get("taker_buy_quote_vol", 0)),
tradeCount=int(row.get("trade_count") or 0),
isClosed=bool(row.get("is_closed", True)),
)