feat(engine): 添加数据服务层与技术指标库
- data/service.py: 数据拉取服务,从 TimescaleDB 读取 K 线/Ticker 等行情数据
- indicators/momentum.py: 动量类指标(RSI/MACD/Stochastic 等)
- indicators/trend.py: 趋势类指标(EMA/SMA/ADX/SuperTrend 等)
- indicators/volatility.py: 波动率指标(Bollinger/ATR/Keltner 等)
- indicators/volume.py: 成交量指标(OBV/VWAP/MFI 等)
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
# engine.data — K 线数据服务
|
||||
|
||||
from .service import DataService
|
||||
|
||||
# Public API of this package: only DataService is exported by `import *`.
__all__ = ["DataService"]
|
||||
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
K-line data service — reads historical K-line (candlestick) data from
TimescaleDB and serves as the data source for the backtest engine.

Usage:
    from datetime import datetime

    from engine.common.config import config
    from engine.data import DataService

    ds = DataService(config.db)
    await ds.connect()

    # Fetch BTCUSDT 1h K-lines, filtered by time range
    klines = await ds.fetch_klines(
        symbol="BTCUSDT",
        interval="1h",
        start_time=datetime(2025, 1, 1),
        end_time=datetime(2026, 1, 1),
    )

    await ds.close()
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import asyncpg
|
||||
|
||||
from ..common.config import DBConfig
|
||||
from ..common.models import Kline, KlineInterval
|
||||
|
||||
# ── Interval → table name mapping ──
# "1m" maps to the unsuffixed table name; every other interval maps to
# "klines_<interval>".
INTERVAL_TO_TABLE: dict[KlineInterval, str] = {
    "1m": "klines",
    "5m": "klines_5m",
    "15m": "klines_15m",
    "30m": "klines_30m",
    "1h": "klines_1h",
    "4h": "klines_4h",
    "1d": "klines_1d",
    "1w": "klines_1w",
}

# ── Interval length in milliseconds ──
# Written as multiples of one minute (60_000 ms) so each value is self-evident.
INTERVAL_MS: dict[KlineInterval, int] = {
    "1m": 60_000,
    "5m": 5 * 60_000,
    "15m": 15 * 60_000,
    "30m": 30 * 60_000,
    "1h": 60 * 60_000,
    "4h": 4 * 60 * 60_000,
    "1d": 24 * 60 * 60_000,
    "1w": 7 * 24 * 60 * 60_000,
}

# Default number of K-lines per batch when streaming.
DEFAULT_BATCH_SIZE = 5000
|
||||
|
||||
|
||||
def dt_to_unix_ms(dt: datetime) -> float:
    """Convert a datetime to a Unix timestamp in milliseconds.

    A naive datetime (``tzinfo`` is ``None``) is interpreted as UTC; an
    aware datetime is converted according to its own offset.

    Args:
        dt: The datetime to convert.

    Returns:
        Milliseconds since the Unix epoch, as a float.
    """
    aware = dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
    return aware.timestamp() * 1000
|
||||
|
||||
|
||||
def _to_float(val) -> float:
|
||||
"""将 Decimal / int 等数值转为 float"""
|
||||
if val is None:
|
||||
return 0.0
|
||||
if isinstance(val, Decimal):
|
||||
return float(val)
|
||||
return float(val)
|
||||
|
||||
|
||||
class DataService:
    """K-line (candlestick) data service.

    Wraps the K-line queries against TimescaleDB and converts database rows
    into ``engine.common.models.Kline`` models for the backtest engine.
    """

    # Columns every K-line query selects unconditionally.
    _BASE_COLUMNS = (
        "time", "exchange", "symbol", "interval",
        "open", "high", "low", "close", "volume",
    )
    # Columns selected only when the target table actually has them.
    _OPTIONAL_COLUMNS = (
        "trade_count", "quote_volume", "taker_buy_base_vol",
        "taker_buy_quote_vol", "is_closed",
    )

    def __init__(self, db_config: DBConfig, pool_size: int = 4):
        """Store the configuration; nothing is opened until ``connect()``.

        Args:
            db_config: Database settings. ``user``, ``password``, ``host``,
                ``port`` and ``name`` are read when the DSN is built.
            pool_size: Upper bound on the number of pooled connections.
        """
        self._db_config = db_config
        self._pool_size = pool_size
        self._pool: asyncpg.Pool | None = None
        # table name -> set of column names, filled lazily by _get_columns()
        self._col_cache: dict[str, set[str]] = {}

    # ── Lifecycle ──

    @property
    def dsn(self) -> str:
        """PostgreSQL connection URI built from the database config.

        User, password and database name are percent-encoded so reserved
        characters such as ``@``, ``:`` or ``/`` cannot corrupt the URI.
        """
        from urllib.parse import quote

        db = self._db_config
        user = quote(str(db.user), safe="")
        password = quote(str(db.password), safe="")
        database = quote(str(db.name), safe="")
        return f"postgresql://{user}:{password}@{db.host}:{db.port}/{database}"

    async def connect(self) -> None:
        """Create the connection pool.

        Calling this while a pool already exists is a no-op, so a repeated
        call cannot overwrite (and leak) the existing pool.
        """
        if self._pool is not None:
            return
        self._pool = await asyncpg.create_pool(
            dsn=self.dsn,
            min_size=1,
            max_size=self._pool_size,
        )

    async def close(self) -> None:
        """Close the connection pool, if any."""
        # Detach the pool first so the service reads as disconnected even if
        # closing the pool raises.
        pool, self._pool = self._pool, None
        if pool is not None:
            await pool.close()

    @property
    def is_connected(self) -> bool:
        """``True`` while a connection pool is held."""
        return self._pool is not None

    # ── Metadata queries ──

    async def fetch_available_symbols(
        self, interval: KlineInterval = "1m"
    ) -> list[str]:
        """Return every distinct symbol stored for ``interval``, sorted.

        Raises:
            KeyError: If ``interval`` has no table mapping.
            RuntimeError: If ``connect()`` has not been called.
        """
        table = INTERVAL_TO_TABLE[interval]
        async with self._require_pool().acquire() as conn:
            rows = await conn.fetch(
                f"SELECT DISTINCT symbol FROM {table} ORDER BY symbol"
            )
        return [r["symbol"] for r in rows]

    async def fetch_symbol_date_range(
        self, symbol: str, interval: KlineInterval
    ) -> tuple[datetime, datetime]:
        """Return the earliest and latest ``time`` stored for a symbol.

        Returns:
            ``(min_time, max_time)`` for ``symbol`` in the table of ``interval``.

        Raises:
            ValueError: If the table holds no rows for ``symbol``.
            KeyError: If ``interval`` has no table mapping.
            RuntimeError: If ``connect()`` has not been called.
        """
        table = INTERVAL_TO_TABLE[interval]
        async with self._require_pool().acquire() as conn:
            row = await conn.fetchrow(
                f"""
                SELECT MIN(time) AS min_time, MAX(time) AS max_time
                FROM {table}
                WHERE symbol = $1
                """,
                symbol,
            )
        # MIN/MAX over zero rows yields NULL, i.e. no data for this symbol.
        if row is None or row["min_time"] is None:
            raise ValueError(f"无数据: {symbol} {interval}")
        return row["min_time"], row["max_time"]

    async def fetch_klines_count(
        self,
        symbol: str,
        interval: KlineInterval,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> int:
        """Count the K-lines matching the filter (to estimate data volume).

        Args:
            symbol: Trading pair, e.g. ``BTCUSDT``.
            interval: K-line interval.
            start_time: Inclusive lower bound on ``time``; ``None`` for no bound.
            end_time: Exclusive upper bound on ``time``; ``None`` for no bound.

        Raises:
            KeyError: If ``interval`` has no table mapping.
            RuntimeError: If ``connect()`` has not been called.
        """
        table = INTERVAL_TO_TABLE[interval]
        where, params = self._build_filter(symbol, interval, start_time, end_time)
        async with self._require_pool().acquire() as conn:
            return await conn.fetchval(
                f"SELECT COUNT(*) FROM {table} WHERE {where}", *params
            )

    # ── Core queries ──

    async def fetch_klines(
        self,
        symbol: str,
        interval: KlineInterval,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        limit: int = 1000,
        offset: int = 0,
    ) -> list[Kline]:
        """Fetch K-lines as a list of ``Kline`` models.

        NOTE(review): ``start_time`` / ``end_time`` are handed to asyncpg
        unchanged, while ``dt_to_unix_ms`` treats naive datetimes as UTC —
        confirm how the ``time`` column interprets naive values.

        Args:
            symbol: Trading pair, e.g. ``BTCUSDT``.
            interval: K-line interval.
            start_time: Inclusive lower bound on ``time``.
            end_time: Exclusive upper bound on ``time``.
            limit: Maximum number of rows returned.
            offset: Number of leading rows skipped (pagination).

        Returns:
            ``Kline`` models ordered by time, ascending.

        Raises:
            KeyError: If ``interval`` has no table / duration mapping.
            RuntimeError: If ``connect()`` has not been called.
        """
        table = INTERVAL_TO_TABLE[interval]
        interval_ms = INTERVAL_MS[interval]
        pool = self._require_pool()

        where, params = self._build_filter(symbol, interval, start_time, end_time)
        available = await self._get_columns(table)
        select_cols = [
            *self._BASE_COLUMNS,
            *(col for col in self._OPTIONAL_COLUMNS if col in available),
        ]

        # LIMIT and OFFSET take the two placeholders after the filter params.
        limit_pos = len(params) + 1
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"""
                SELECT {', '.join(select_cols)}
                FROM {table}
                WHERE {where}
                ORDER BY time ASC
                LIMIT ${limit_pos} OFFSET ${limit_pos + 1}
                """,
                *params,
                limit,
                offset,
            )

        return [self._row_to_kline(r, interval, interval_ms) for r in rows]

    async def stream_klines(
        self,
        symbol: str,
        interval: KlineInterval,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        batch_size: int = DEFAULT_BATCH_SIZE,
    ) -> AsyncGenerator[list[Kline], None]:
        """Stream K-lines in batches, for data sets too large to load at once.

        Each batch is one ``fetch_klines`` call with ``limit=batch_size``;
        iteration stops at the first empty batch.

        NOTE(review): paging is OFFSET-based, so rows inserted or removed
        while streaming can shift later pages, and large offsets may get
        slow — consider keyset pagination if that becomes a problem.

        Yields:
            One list of ``Kline`` per batch, ordered by time ascending.
        """
        offset = 0
        while True:
            batch = await self.fetch_klines(
                symbol=symbol,
                interval=interval,
                start_time=start_time,
                end_time=end_time,
                limit=batch_size,
                offset=offset,
            )
            if not batch:
                break
            yield batch
            offset += len(batch)

    async def fetch_multi_klines(
        self,
        symbols: list[str],
        interval: KlineInterval,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        limit: int = 1000,
    ) -> dict[str, list[Kline]]:
        """Fetch K-lines for several symbols concurrently.

        One ``fetch_klines`` call is issued per symbol and all calls are
        awaited together with ``asyncio.gather``.

        Returns:
            ``{symbol: [Kline, ...]}``.
        """
        tasks = [
            self.fetch_klines(
                symbol=sym,
                interval=interval,
                start_time=start_time,
                end_time=end_time,
                limit=limit,
            )
            for sym in symbols
        ]
        results = await asyncio.gather(*tasks)
        return dict(zip(symbols, results))

    # ── Internal helpers ──

    def _require_pool(self) -> asyncpg.Pool:
        """Return the live pool, or fail clearly when not connected.

        Raises:
            RuntimeError: If ``connect()`` has not been called (or the
                service was closed).
        """
        if self._pool is None:
            raise RuntimeError(
                "DataService is not connected; call connect() first"
            )
        return self._pool

    @staticmethod
    def _build_filter(
        symbol: str,
        interval: KlineInterval,
        start_time: datetime | None,
        end_time: datetime | None,
    ) -> tuple[str, list]:
        """Build the shared WHERE clause and its positional parameters.

        Returns:
            ``(where_sql, params)`` where ``where_sql`` uses ``$1..$n``
            placeholders matching ``params`` in order.
        """
        conditions = ["symbol = $1", "interval = $2"]
        params: list = [symbol, interval]
        if start_time is not None:
            params.append(start_time)
            conditions.append(f"time >= ${len(params)}")
        if end_time is not None:
            params.append(end_time)
            conditions.append(f"time < ${len(params)}")
        return " AND ".join(conditions), params

    async def _get_columns(self, table: str) -> set[str]:
        """Return the column names of ``table`` (cached per table name).

        The lookup queries ``information_schema.columns`` by table name only;
        no schema filter is applied.
        """
        if table not in self._col_cache:
            async with self._require_pool().acquire() as conn:
                rows = await conn.fetch(
                    """
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_name = $1
                    """,
                    table,
                )
            self._col_cache[table] = {r["column_name"] for r in rows}
        return self._col_cache[table]

    @staticmethod
    def _row_to_kline(
        row: asyncpg.Record, interval: KlineInterval, interval_ms: int
    ) -> Kline:
        """Convert one database row into a ``Kline`` model.

        ``closeTime`` is derived as ``openTime + interval_ms``. Optional
        columns absent from the row fall back to ``0`` (``is_closed`` to
        ``True``).
        """
        open_time = dt_to_unix_ms(row["time"])

        return Kline(
            exchange=row["exchange"],
            symbol=row["symbol"],
            interval=interval,
            openTime=open_time,
            closeTime=open_time + interval_ms,
            open=_to_float(row["open"]),
            high=_to_float(row["high"]),
            low=_to_float(row["low"]),
            close=_to_float(row["close"]),
            volume=_to_float(row["volume"]),
            quoteVolume=_to_float(row.get("quote_volume", 0)),
            takerBuyBaseVol=_to_float(row.get("taker_buy_base_vol", 0)),
            takerBuyQuoteVol=_to_float(row.get("taker_buy_quote_vol", 0)),
            tradeCount=int(row.get("trade_count") or 0),
            isClosed=bool(row.get("is_closed", True)),
        )
|
||||
Reference in New Issue
Block a user