"""Kline (candlestick) data service.

Reads historical kline data from TimescaleDB and serves it as the data source
for the backtest engine.

Usage:
    from engine.common.config import config
    from engine.data import DataService

    ds = DataService(config.db)
    await ds.connect()

    # Fetch BTCUSDT 1h klines, filtered by time range.
    # Note: fetch_klines returns at most `limit` rows (1000 by default);
    # use stream_klines to iterate over an arbitrarily large range.
    klines = await ds.fetch_klines(
        symbol="BTCUSDT",
        interval="1h",
        start_time=datetime(2025, 1, 1),
        end_time=datetime(2026, 1, 1),
    )

    await ds.close()
"""

import asyncio
from datetime import datetime, timezone
from decimal import Decimal
from typing import AsyncGenerator
from urllib.parse import quote

import asyncpg

from ..common.config import DBConfig
from ..common.models import Kline, KlineInterval

# ── Interval -> table name ──
INTERVAL_TO_TABLE: dict[KlineInterval, str] = {
    "1m": "klines",
    "5m": "klines_5m",
    "15m": "klines_15m",
    "30m": "klines_30m",
    "1h": "klines_1h",
    "4h": "klines_4h",
    "1d": "klines_1d",
    "1w": "klines_1w",
}

# ── Interval length in milliseconds ──
INTERVAL_MS: dict[KlineInterval, int] = {
    "1m": 60_000,
    "5m": 300_000,
    "15m": 900_000,
    "30m": 1_800_000,
    "1h": 3_600_000,
    "4h": 14_400_000,
    "1d": 86_400_000,
    "1w": 604_800_000,
}

DEFAULT_BATCH_SIZE = 5000

# Columns selected from every kline table unconditionally.
_BASE_COLUMNS: tuple[str, ...] = (
    "time",
    "exchange",
    "symbol",
    "interval",
    "open",
    "high",
    "low",
    "close",
    "volume",
)

# Columns selected only when the target table actually has them
# (checked via information_schema, see DataService._get_columns).
_OPTIONAL_COLUMNS: tuple[str, ...] = (
    "trade_count",
    "quote_volume",
    "taker_buy_base_vol",
    "taker_buy_quote_vol",
    "is_closed",
)


def dt_to_unix_ms(dt: datetime) -> float:
    """Convert a datetime to a Unix timestamp in milliseconds.

    Naive datetimes are interpreted as UTC.
    """
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt.timestamp() * 1000


def _to_float(val: Decimal | float | None) -> float:
    """Convert a numeric DB value (Decimal / int / float) to float; None -> 0.0."""
    return 0.0 if val is None else float(val)


def _build_where(
    symbol: str,
    interval: KlineInterval,
    start_time: datetime | None,
    end_time: datetime | None,
) -> tuple[str, list]:
    """Build the WHERE clause shared by the kline queries.

    Returns:
        ``(where_sql, params)``. Placeholders are numbered ``$1..$len(params)``,
        so callers may append extra parameters (e.g. LIMIT/OFFSET) and refer
        to them as ``$len(params)`` after appending.
    """
    conditions = ["symbol = $1", "interval = $2"]
    params: list = [symbol, interval]
    if start_time is not None:
        params.append(start_time)
        conditions.append(f"time >= ${len(params)}")  # inclusive lower bound
    if end_time is not None:
        params.append(end_time)
        conditions.append(f"time < ${len(params)}")  # exclusive upper bound
    return " AND ".join(conditions), params


class DataService:
    """Kline data service.

    Wraps kline queries against TimescaleDB and converts database rows into
    ``engine.common.models.Kline`` models for consumption by the backtest
    engine. Call ``connect()`` before any query method and ``close()`` when
    done.
    """

    def __init__(self, db_config: DBConfig, pool_size: int = 4):
        self._db_config = db_config
        self._pool_size = pool_size
        self._pool: asyncpg.Pool | None = None
        # table name -> set of column names, filled lazily by _get_columns
        self._col_cache: dict[str, set[str]] = {}

    # ── Lifecycle ──

    @property
    def dsn(self) -> str:
        """PostgreSQL connection URI built from the DB config.

        User, password and database name are percent-encoded so values that
        contain URI-reserved characters (``@``, ``:``, ``/``, ``#`` ...) do not
        corrupt the URI.
        """
        db = self._db_config
        user = quote(str(db.user), safe="")
        password = quote(str(db.password), safe="")
        name = quote(str(db.name), safe="")
        return f"postgresql://{user}:{password}@{db.host}:{db.port}/{name}"

    async def connect(self) -> None:
        """Create the connection pool; a no-op if already connected."""
        if self._pool is not None:
            # Creating a second pool would silently leak the first one.
            return
        self._pool = await asyncpg.create_pool(
            dsn=self.dsn,
            min_size=1,
            max_size=self._pool_size,
        )

    async def close(self) -> None:
        """Close the connection pool, if one is open."""
        # Detach first so the service reports "disconnected" even if close() raises.
        pool, self._pool = self._pool, None
        if pool is not None:
            await pool.close()

    @property
    def is_connected(self) -> bool:
        return self._pool is not None

    def _require_pool(self) -> asyncpg.Pool:
        """Return the open pool.

        Raises:
            RuntimeError: if ``connect()`` has not been called (or after ``close()``).
        """
        if self._pool is None:
            raise RuntimeError("DataService is not connected; call connect() first")
        return self._pool

    # ── Metadata queries ──

    async def fetch_available_symbols(
        self, interval: KlineInterval = "1m"
    ) -> list[str]:
        """Return all symbols that have data for the given interval, sorted."""
        table = INTERVAL_TO_TABLE[interval]
        rows = await self._require_pool().fetch(
            f"SELECT DISTINCT symbol FROM {table} ORDER BY symbol"
        )
        return [r["symbol"] for r in rows]

    async def fetch_symbol_date_range(
        self, symbol: str, interval: KlineInterval
    ) -> tuple[datetime, datetime]:
        """Return ``(min_time, max_time)`` of the data for symbol + interval.

        Raises:
            ValueError: if there are no rows for the symbol.
        """
        table = INTERVAL_TO_TABLE[interval]
        row = await self._require_pool().fetchrow(
            f"""
            SELECT MIN(time) AS min_time, MAX(time) AS max_time
            FROM {table}
            WHERE symbol = $1
            """,
            symbol,
        )
        if row is None or row["min_time"] is None:
            raise ValueError(f"无数据: {symbol} {interval}")
        return row["min_time"], row["max_time"]

    async def fetch_klines_count(
        self,
        symbol: str,
        interval: KlineInterval,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> int:
        """Count klines matching the filters (useful to estimate data volume).

        ``start_time`` is inclusive, ``end_time`` is exclusive.
        """
        table = INTERVAL_TO_TABLE[interval]
        where, params = _build_where(symbol, interval, start_time, end_time)
        return await self._require_pool().fetchval(
            f"SELECT COUNT(*) FROM {table} WHERE {where}", *params
        )

    # ── Core queries ──

    async def fetch_klines(
        self,
        symbol: str,
        interval: KlineInterval,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[Kline]:
        """Fetch klines as a list of ``Kline`` models.

        Args:
            symbol: trading pair (e.g. BTCUSDT)
            interval: kline interval
            start_time: start time (inclusive)
            end_time: end time (exclusive)
            limit: maximum number of rows returned
            offset: pagination offset

        Returns:
            Klines sorted by time ascending. Rows sharing a timestamp are
            ordered by exchange so that LIMIT/OFFSET pages are deterministic.

        NOTE(review): start_time/end_time are passed to asyncpg unchanged; how
        naive datetimes are interpreted depends on the column type — confirm.
        """
        table = INTERVAL_TO_TABLE[interval]
        interval_ms = INTERVAL_MS[interval]
        where, params = _build_where(symbol, interval, start_time, end_time)
        select_cols = await self._select_columns(table)

        params.extend((limit, offset))
        rows = await self._require_pool().fetch(
            f"""
            SELECT {', '.join(select_cols)}
            FROM {table}
            WHERE {where}
            ORDER BY time ASC, exchange ASC
            LIMIT ${len(params) - 1} OFFSET ${len(params)}
            """,
            *params,
        )
        return [self._row_to_kline(r, interval, interval_ms) for r in rows]

    async def stream_klines(
        self,
        symbol: str,
        interval: KlineInterval,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        batch_size: int = DEFAULT_BATCH_SIZE,
    ) -> AsyncGenerator[list[Kline], None]:
        """Stream klines in batches, suitable for large datasets.

        Yields one list of at most ``batch_size`` klines at a time so the whole
        range never has to be held in memory.

        Yields:
            Lists of Kline (time ascending).
        """
        offset = 0
        while True:
            batch = await self.fetch_klines(
                symbol=symbol,
                interval=interval,
                start_time=start_time,
                end_time=end_time,
                limit=batch_size,
                offset=offset,
            )
            if not batch:
                break
            yield batch
            if len(batch) < batch_size:
                # Short page: the range is exhausted, skip the extra empty query.
                break
            offset += len(batch)

    async def fetch_multi_klines(
        self,
        symbols: list[str],
        interval: KlineInterval,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        limit: int = 1000,
    ) -> dict[str, list[Kline]]:
        """Fetch klines for several symbols concurrently.

        Concurrency is bounded by the pool size, since each query waits for a
        free connection.

        Returns:
            ``{symbol: [Kline, ...]}``
        """
        tasks = [
            self.fetch_klines(
                symbol=sym,
                interval=interval,
                start_time=start_time,
                end_time=end_time,
                limit=limit,
            )
            for sym in symbols
        ]
        results = await asyncio.gather(*tasks)
        return dict(zip(symbols, results))

    # ── Internals ──

    async def _get_columns(self, table: str) -> set[str]:
        """Return the column names of ``table`` (cached per table name)."""
        cached = self._col_cache.get(table)
        if cached is None:
            rows = await self._require_pool().fetch(
                """
                SELECT column_name FROM information_schema.columns
                WHERE table_name = $1
                """,
                table,
            )
            cached = self._col_cache[table] = {r["column_name"] for r in rows}
        return cached

    async def _select_columns(self, table: str) -> list[str]:
        """Base columns plus whichever optional columns ``table`` actually has."""
        available = await self._get_columns(table)
        return [*_BASE_COLUMNS, *(c for c in _OPTIONAL_COLUMNS if c in available)]

    @staticmethod
    def _row_to_kline(
        row: asyncpg.Record, interval: KlineInterval, interval_ms: int
    ) -> Kline:
        """Convert one database row into a Kline model.

        ``closeTime`` is derived as ``openTime + interval_ms``. Optional
        columns missing from the row fall back to 0 (numbers) / True
        (``isClosed``).
        """
        open_time = dt_to_unix_ms(row["time"])
        return Kline(
            exchange=row["exchange"],
            symbol=row["symbol"],
            interval=interval,
            openTime=open_time,
            closeTime=open_time + interval_ms,
            open=_to_float(row["open"]),
            high=_to_float(row["high"]),
            low=_to_float(row["low"]),
            close=_to_float(row["close"]),
            volume=_to_float(row["volume"]),
            quoteVolume=_to_float(row.get("quote_volume", 0)),
            takerBuyBaseVol=_to_float(row.get("taker_buy_base_vol", 0)),
            takerBuyQuoteVol=_to_float(row.get("taker_buy_quote_vol", 0)),
            tradeCount=int(row.get("trade_count") or 0),
            isClosed=bool(row.get("is_closed", True)),
        )