aimoon 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimoon/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """aimoon — A股量化筛选与交易建议系统"""
2
+ from __future__ import annotations
3
+
4
# Submodule names exported by ``from aimoon import *`` (listing a submodule in a
# package's __all__ makes the star-import load it).
__all__ = [
    "config",
    "models",
    "cache",
    "data",
    "indicators",
    "scoring",
    "screener",
    "backtest",
    "output",
    "demo",
    "cli",
]
aimoon/__main__.py ADDED
@@ -0,0 +1,4 @@
1
"""Entry point that lets the package run as ``python -m aimoon``."""
from aimoon.cli import main

if __name__ == "__main__":
    # The CLI parses sys.argv on its own, so nothing is forwarded here.
    main()
aimoon/backtest.py ADDED
@@ -0,0 +1,81 @@
1
+ """回测引擎"""
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ from dataclasses import dataclass
6
+
7
+ import pandas as pd
8
+
9
+ from aimoon.config import Config
10
+ from aimoon.screener import screen_stock
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
@dataclass(frozen=True)
class TradeRecord:
    """One closed round-trip trade as recorded by ``BacktestEngine.run``.

    Dates are ``YYYY-MM-DD`` strings. Prices are the bar close prices at entry
    and exit. ``return_pct`` is ``(exit - entry) / entry * 100``.
    """

    entry_date: str     # date of the bar the position was opened on
    exit_date: str      # date of the bar the position was closed on
    entry_price: float  # close on entry_date
    exit_price: float   # close on exit_date
    return_pct: float   # percent, e.g. 2.5 means +2.5 %
22
+
23
+
24
@dataclass(frozen=True)
class BacktestResult:
    """Summary of a single stock's backtest.

    ``total_return`` is in percent. ``win_rate`` and ``max_drawdown`` are
    fractions (the CLI formats them with ``%``). ``trades`` holds every round
    trip in the order it was opened.
    """

    code: str                         # stock code the backtest ran on
    total_return: float               # percent
    win_rate: float                   # fraction of trades with a positive return
    max_drawdown: float               # fraction, peak-to-trough
    trade_count: int                  # len(trades)
    trades: tuple[TradeRecord, ...]   # individual round trips
32
+
33
+
34
class BacktestEngine:
    """Walk-forward backtest of the screener's entry signal on one stock.

    Starting after a warm-up window of ``cfg.ma_long`` bars, each bar is
    processed as follows:

    * The stock is re-scored with ``screen_stock`` using only the data up to
      and including that bar.
    * If the score reaches ``min_score``, a position opens at that bar's close.
    * The position closes ``hold_days`` bars later, at that bar's close.

    Positions never overlap: after an exit, the next entry is considered from
    the following bar onward.
    """

    def __init__(self, cfg: Config, hold_days: int = 5, *, min_score: float = 2) -> None:
        """Create an engine.

        Args:
            cfg: Screening configuration; ``cfg.ma_long`` is the warm-up length.
            hold_days: Number of bars each position is held.
            min_score: Minimum ``total_score`` from ``screen_stock`` that
                triggers an entry (default 2).
        """
        self.cfg = cfg
        self.hold_days = hold_days
        self.min_score = min_score

    @staticmethod
    def _iso_date(value: object) -> str:
        """Format a bar's index label (Timestamp, datetime or date string) as ``YYYY-MM-DD``."""
        return str(pd.Timestamp(value).date())

    def run(self, code: str, name: str, kline: pd.DataFrame) -> BacktestResult:
        """Backtest ``code`` over ``kline`` and return summary metrics.

        Args:
            code: Stock code, forwarded to ``screen_stock`` and into the result.
            name: Display name, forwarded to ``screen_stock``.
            kline: Bars with a ``close`` column, indexed by date labels that
                ``pd.Timestamp`` can parse.

        Returns:
            A ``BacktestResult``. It is all-zero with no trades when ``kline``
            has fewer than ``cfg.ma_long + hold_days`` rows or no signal fires.
        """
        min_window = self.cfg.ma_long
        n_bars = len(kline)
        if n_bars < min_window + self.hold_days:
            return BacktestResult(code, 0.0, 0.0, 0.0, 0, ())

        closes = kline["close"]
        trades: list[TradeRecord] = []
        next_entry = 0  # first bar index at which a new position may open
        for i in range(min_window, n_bars - self.hold_days):
            if i < next_entry:
                continue  # still inside the previous holding period
            entry_price = float(closes.iloc[i])
            # A zero, negative or NaN close cannot give a meaningful percentage
            # return (zero would raise ZeroDivisionError). Skip the bar before
            # paying for the expensive re-score.
            if not entry_price > 0:
                continue
            # Score only the data visible at bar i: no look-ahead.
            scored = screen_stock(code, name, kline.iloc[: i + 1])
            if scored is None or scored.total_score < self.min_score:
                continue
            exit_i = min(i + self.hold_days, n_bars - 1)
            exit_price = float(closes.iloc[exit_i])
            ret = (exit_price - entry_price) / entry_price * 100
            trades.append(
                TradeRecord(
                    self._iso_date(kline.index[i]),
                    self._iso_date(kline.index[exit_i]),
                    entry_price,
                    exit_price,
                    ret,
                )
            )
            next_entry = exit_i + 1
        return self._metrics(code, trades, kline)

    def _metrics(self, code: str, trades: list[TradeRecord], kline: pd.DataFrame) -> BacktestResult:
        """Summarise ``trades`` (chronological, as produced by ``run``).

        * ``total_return``: arithmetic sum of the per-trade percent returns
          (not compounded).
        * ``win_rate``: fraction of trades with a strictly positive return.
        * ``max_drawdown``: largest peak-to-trough fall, as a fraction, of an
          equity curve over ``kline`` bars. The curve starts at 100 and
          compounds each trade's return on the bar whose date equals the
          trade's exit date. It is flat in between, so it is marked to market
          only at exits.
        """
        if not trades:
            return BacktestResult(code, 0.0, 0.0, 0.0, 0, ())
        total_ret = sum(t.return_pct for t in trades)
        win_rate = sum(1 for t in trades if t.return_pct > 0) / len(trades)

        equity = peak = 100.0
        max_dd = 0.0
        trade_idx = 0
        for label in kline.index[1:]:
            if trade_idx == len(trades):
                break  # every trade booked; the rest of the curve is flat
            trade = trades[trade_idx]
            if self._iso_date(label) != trade.exit_date:
                continue  # flat step: cannot deepen the drawdown
            equity *= 1 + trade.return_pct / 100
            trade_idx += 1
            peak = max(peak, equity)
            max_dd = max(max_dd, (peak - equity) / peak)  # peak >= 100 > 0
        return BacktestResult(code, total_ret, win_rate, max_dd, len(trades), tuple(trades))
aimoon/cache.py ADDED
@@ -0,0 +1,53 @@
1
+ """文件缓存层 - pickle 序列化 + TTL 过期"""
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ import os
6
+ import time
7
+ from pathlib import Path
8
+
9
+ import pandas as pd
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class DataCache:
    """Per-stock DataFrame cache on disk (pickle files) with TTL expiry.

    Each entry is stored as ``<cache_dir>/<stock_code>.pkl``. An entry counts
    as expired once its file modification time is more than ``ttl_hours`` old.

    Security: entries are read back with ``pd.read_pickle``, which can execute
    code embedded in a crafted file. Only point ``cache_dir`` at a directory
    that this program alone writes to.
    """

    def __init__(self, cache_dir: str = ".aimoon_cache", ttl_hours: int = 4) -> None:
        """Create the cache, making ``cache_dir`` (and parents) if needed."""
        self.cache_dir = Path(cache_dir)
        self.ttl_seconds = ttl_hours * 3600
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _path_for(self, stock_code: str) -> Path:
        """Return the pickle file path used for ``stock_code``."""
        return self.cache_dir / f"{stock_code}.pkl"

    def get(self, stock_code: str) -> pd.DataFrame | None:
        """Return the cached DataFrame, or ``None`` if missing, expired or unreadable."""
        path = self._path_for(stock_code)
        try:
            # A single stat() both checks existence and reads the mtime. This
            # avoids an exists()/stat() race with a concurrent clear().
            mtime = path.stat().st_mtime
        except (FileNotFoundError, NotADirectoryError):
            return None
        age = time.time() - mtime
        if age > self.ttl_seconds:
            logger.debug("Cache expired for %s (%.0fs old)", stock_code, age)
            return None
        try:
            return pd.read_pickle(path)
        except Exception as e:  # corrupt or vanished file, pickle version mismatch: treat as a miss
            logger.warning("Cache read failed for %s: %s", stock_code, e)
            return None

    def put(self, stock_code: str, df: pd.DataFrame) -> None:
        """Write ``df`` to the cache atomically (best-effort; errors are logged).

        The frame is pickled to a temporary file in the cache directory and
        then moved over the final path with ``os.replace``. Concurrent readers
        therefore see either the previous entry or the complete new one, never
        a half-written pickle. The directory is recreated if it was removed
        after construction.
        """
        import tempfile

        path = self._path_for(stock_code)
        tmp_name: str | None = None
        try:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            # ".tmp" suffix: never matched by the "*.pkl" glob in clear().
            fd, tmp_name = tempfile.mkstemp(dir=self.cache_dir, prefix=f".{stock_code}.", suffix=".tmp")
            os.close(fd)
            df.to_pickle(tmp_name)
            os.replace(tmp_name, path)
            tmp_name = None  # moved into place; nothing left to clean up
        except Exception as e:
            logger.warning("Cache write failed for %s: %s", stock_code, e)
        finally:
            if tmp_name is not None:
                try:
                    os.unlink(tmp_name)
                except OSError:
                    pass  # best-effort cleanup of a failed write

    def clear(self) -> int:
        """Delete every cached ``*.pkl`` file and return how many were removed."""
        count = 0
        for p in self.cache_dir.glob("*.pkl"):
            try:
                p.unlink()
            except FileNotFoundError:
                continue  # removed concurrently; not counted
            count += 1
        return count
aimoon/cli.py ADDED
@@ -0,0 +1,123 @@
1
+ """CLI 入口 — 薄管道"""
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import logging
6
+ import sys
7
+ import time
8
+ from pathlib import Path
9
+
10
+ from aimoon.cache import DataCache
11
+ from aimoon.config import Config, load_config
12
+ from aimoon.data import get_spot_for_codes, filter_universe, get_sector_context
13
+ from aimoon.data.filters import get_holdings_pool
14
+ from aimoon.output import OutputFormatter
15
+ from aimoon.scoring.rps import compute_rps
16
+ from aimoon.screener import screen_universe
17
+
18
# The CLI module installs the root logging configuration (WARNING and above)
# when it is imported; the module-level logger below is just a named child.
logging.basicConfig(level="WARNING")
logger = logging.getLogger(name=__name__)
20
+
21
+
22
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line.

    Options that ``load_config`` can also take from the YAML file default to
    ``None``, meaning "not given". ``load_config`` skips ``None`` values,
    whereas a concrete argparse default (e.g. ``30`` or ``store_true``'s
    ``False``) would always override the YAML setting. When neither source
    sets an option, the ``Config`` dataclass default applies.

    Args:
        argv: Arguments to parse; ``None`` (the default) means ``sys.argv[1:]``.

    Returns:
        The parsed namespace. ``command`` is ``None``, ``"backtest"``,
        ``"cache"`` or ``"update"``. ``cache_action`` is present only for the
        ``cache`` sub-command.
    """
    p = argparse.ArgumentParser(description="A-share quant screener")
    p.add_argument("--config", type=str, default=None, help="YAML config file")
    p.add_argument("--top", type=int, default=None, help="number of top-ranked stocks to show")
    p.add_argument("--workers", type=int, default=None, help="worker count (Config.workers)")
    p.add_argument("--no-csv", action="store_true", default=None, help="skip CSV/Markdown export")
    p.add_argument("--demo", action="store_true", default=None, help="screen generated demo data")
    p.add_argument("--refresh", action="store_true", default=None, help="set Config.refresh")
    sub = p.add_subparsers(dest="command")
    bt = sub.add_parser("backtest", help="backtest the screening signal on given stocks")
    bt.add_argument("--stocks", type=str, default=None, help="comma-separated stock codes")
    bt.add_argument("--hold-days", type=int, default=None, help="bars to hold each position")
    cp = sub.add_parser("cache", help="manage the on-disk data cache")
    cs = cp.add_subparsers(dest="cache_action")
    cs.add_parser("clear", help="delete all cached files")
    sub.add_parser("update", help="Clear all caches and re-fetch data")
    return p.parse_args(argv)
39
+
40
+
41
def _cache_command(args: argparse.Namespace, cfg: Config) -> None:
    """Handle ``cache <action>``; ``clear`` is the only action defined."""
    if getattr(args, "cache_action", None) != "clear":
        # Never wipe the cache unless the user explicitly asked for it.
        print("Usage: python -m aimoon cache clear")
        return
    cache = DataCache(cfg.cache_dir, cfg.cache_ttl_hours)
    print(f"Cleared {cache.clear()} cached files")


def _wipe_cache_dir(cfg: Config) -> None:
    """Delete the whole cache directory so the pipeline re-fetches everything."""
    import shutil

    cache_dir = Path(cfg.cache_dir)
    if cache_dir.exists():
        shutil.rmtree(cache_dir)
        print(f"Cleared cache dir: {cache_dir}")


def _run_backtest(cfg: Config, fmt: OutputFormatter) -> None:
    """Backtest each comma-separated code in ``cfg.stocks`` and print one summary line per code."""
    from aimoon.backtest import BacktestEngine
    from aimoon.data.history import get_kline

    engine = BacktestEngine(cfg, hold_days=cfg.hold_days)
    cache = DataCache(cfg.cache_dir, cfg.cache_ttl_hours)
    fmt.console.print(f"[bold blue]=== Backtest (hold {cfg.hold_days}d) ===[/bold blue]")
    # Ignore empty entries such as the one produced by a trailing comma.
    codes = [c.strip() for c in cfg.stocks.split(",") if c.strip()]
    for code in codes:
        r = get_kline(code, cfg.history_days, cache)
        if not r.is_ok():
            # Report instead of silently dropping the stock from the output.
            err = getattr(r, "error", "unknown error")
            fmt.console.print(f" {code}: [red]failed to load kline: {err}[/red]")
            continue
        result = engine.run(code, code, r.unwrap())
        color = "green" if result.total_return > 0 else "red"
        fmt.console.print(
            f" {result.code}: [{color}]{result.total_return:+.2f}%[/{color}] "
            f"胜率={result.win_rate:.0%} 交易={result.trade_count}次 "
            f"最大回撤={result.max_drawdown:.2%}"
        )


def _screen_live(cfg: Config, fmt: OutputFormatter):
    """Fetch live data and screen it: holdings pool, spot quotes, filters, scoring.

    Exits the process with status 1 if spot quotes cannot be fetched.
    Returns the ``(results, tails)`` pair produced by ``screen_universe``.
    """
    # Start from the (cached) holdings pool so fewer spot quotes are requested.
    fmt.console.print("[dim]Loading holdings pool (cached)...[/dim]")
    pool = get_holdings_pool(cfg)
    fmt.console.print(f"[dim]Holdings pool: {len(pool)} stocks[/dim]")

    sr = get_spot_for_codes(pool, cfg)
    if sr.is_err():
        fmt.console.print(f"[red]Failed: {sr.error}[/red]")
        fmt.console.print("[yellow]Try: python -m aimoon --demo[/yellow]")
        sys.exit(1)
    spot = sr.unwrap()
    fmt.console.print(f"[dim]Spot data for {len(spot)} stocks[/dim]")

    fmt.console.print("[dim]Filtering universe...[/dim]")
    universe = filter_universe(spot, cfg)

    fmt.console.print("[dim]Building sector context...[/dim]")
    ctx = get_sector_context(spot)

    cache = DataCache(cfg.cache_dir, cfg.cache_ttl_hours)
    fmt.console.print(f"[dim]Analyzing {len(universe)} stocks...[/dim]")
    t0 = time.time()
    results, tails = screen_universe(universe, cfg, cache, ctx)
    fmt.console.print(f"[dim]Done in {time.time() - t0:.1f}s[/dim]")
    return results, tails


def main() -> None:
    """CLI entry point: dispatch sub-commands, otherwise screen and report.

    * ``cache clear``: delete cached files and stop.
    * ``update``: wipe the cache directory, then run the normal screen.
    * ``backtest``: backtest ``cfg.stocks`` and stop.
    * Otherwise, screen demo or live data, rank by ``total_score``, show the
      top ``cfg.top_n`` and export them unless ``cfg.no_csv`` is set.
    """
    args = parse_args()
    cfg = load_config(args, path=getattr(args, "config", None))
    fmt = OutputFormatter(cfg)

    if cfg.command == "cache":
        _cache_command(args, cfg)
        return

    if cfg.command == "update":
        # Deliberately no return: fall through to the normal pipeline, which
        # now has to re-fetch all data.
        _wipe_cache_dir(cfg)

    if cfg.command == "backtest":
        _run_backtest(cfg, fmt)
        return

    if cfg.demo:
        from aimoon.demo import generate_demo

        spot_df, klines = generate_demo()
        cache = DataCache(cfg.cache_dir, cfg.cache_ttl_hours)
        results, tails = screen_universe(spot_df, cfg, cache, klines=klines)
    else:
        results, tails = _screen_live(cfg, fmt)

    # RPS, then rank, then output.
    results = compute_rps(results, tails)
    top = sorted(results, key=lambda s: s.total_score, reverse=True)[: cfg.top_n]
    fmt.display(top)
    if not cfg.no_csv and top:
        fmt.console.print(f"[dim]Exported: {fmt.export_csv(top)}[/dim]")
        fmt.console.print(f"[dim]Exported: {fmt.export_markdown(top)}[/dim]")


if __name__ == "__main__":
    main()
aimoon/config.py ADDED
@@ -0,0 +1,95 @@
1
+ """配置模块 — frozen dataclass,显式传递,无全局单例"""
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import logging
6
+ from dataclasses import dataclass, fields
7
+ from pathlib import Path
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
@dataclass(frozen=True)
class Config:
    """Immutable run configuration, passed explicitly (no global singleton).

    ``load_config`` builds it from these defaults, an optional YAML file and
    CLI flags. Field order is part of the positional constructor signature.
    """

    # --- Universe screening filters ---
    history_days: int = 250            # history length requested per stock
    min_market_cap_yi: float = 50.0    # presumably units of 1e8 CNY, per the _yi suffix
    max_market_cap_yi: float = 2000.0
    min_turnover_pct: float = 3.0
    max_turnover_pct: float = 30.0
    min_price: float = 5.0
    max_price: float = 100.0
    min_list_days: int = 250
    top_n: int = 30                    # how many ranked stocks the CLI shows/exports
    # --- Institutional holdings thresholds ---
    min_northbound_cap: float = 1.0
    min_fund_pct: float = 5.0
    # --- Technical indicator parameters ---
    ma_short: int = 5
    ma_mid: int = 20
    ma_long: int = 60                  # also the backtest warm-up window
    rsi_period: int = 14
    macd_fast: int = 12
    macd_slow: int = 26
    macd_signal: int = 9
    kdj_period: int = 9
    boll_period: int = 20
    boll_std: float = 2.0
    volume_ma_period: int = 20
    # --- Cache ---
    cache_dir: str = ".aimoon_cache"
    cache_ttl_hours: int = 24
    # --- Output ---
    output_dir: str = "output"
    # --- Values normally supplied by CLI flags ---
    no_csv: bool = False
    workers: int = 5
    demo: bool = False
    refresh: bool = False
    command: str | None = None
    stocks: str = "000001"             # comma-separated codes for the backtest
    hold_days: int = 5
    # --- Exclusion rules ---
    exclude_boards: tuple[str, ...] = ("ST", "退", "北交所")
    exclude_prefixes: tuple[str, ...] = ("8", "4")
55
+
56
+
57
def _load_yaml_overrides(p: Path) -> dict:
    """Read ``Config`` field overrides from the YAML file ``p``; return ``{}`` on any problem."""
    if not p.exists():
        # An explicitly requested config that is missing is almost always a typo.
        logger.warning("Config file %s not found; using defaults", p)
        return {}
    try:
        import yaml

        with open(p, encoding="utf-8") as fh:
            data = yaml.safe_load(fh) or {}
    except Exception as e:  # PyYAML missing, I/O or parse error: config is best-effort
        logger.warning("Failed to load config %s: %s", p, e)
        return {}
    if not isinstance(data, dict):
        logger.warning("Config file %s must contain a mapping, got %s", p, type(data).__name__)
        return {}

    valid = {f.name for f in fields(Config)}
    tuple_fields = {f.name for f in fields(Config) if isinstance(f.default, tuple)}
    overrides: dict = {}
    for key, value in data.items():
        if key not in valid:
            logger.warning("Ignoring unknown config key %r in %s", key, p)
            continue
        if key in tuple_fields:
            # A bare string would otherwise be iterated character by character.
            if isinstance(value, str):
                value = (value,)
            elif isinstance(value, list):
                value = tuple(value)
        overrides[key] = value
    return overrides


def load_config(args: argparse.Namespace | None = None, path: str | None = None) -> Config:
    """Build a ``Config`` by merging, in priority order: CLI args > YAML file > defaults.

    Args:
        args: Parsed CLI namespace. Attributes that are missing or ``None``
            count as "not given" and do not override the YAML file. Flags
            should therefore default to ``None`` in the parser.
        path: Optional YAML file. Its top-level mapping keys must be ``Config``
            field names; unknown keys are ignored with a warning. List or str
            values for tuple-typed fields are coerced to tuples. A missing,
            unreadable or malformed file is logged and otherwise ignored.

    Returns:
        A new frozen ``Config``.
    """
    overrides: dict = {}

    if path:
        overrides.update(_load_yaml_overrides(Path(path)))

    if args is not None:
        cli_map = {
            "top": "top_n", "workers": "workers", "no_csv": "no_csv",
            "demo": "demo", "refresh": "refresh",
            "hold_days": "hold_days", "stocks": "stocks",
        }
        for cli_key, cfg_key in cli_map.items():
            val = getattr(args, cli_key, None)
            if val is not None:
                overrides[cfg_key] = val
        command = getattr(args, "command", None)
        if command:
            overrides["command"] = command

    return Config(**overrides)
92
+
93
+
94
# Backward-compatibility alias: older code reads this module-level CONFIG
# instance. New code should receive a Config explicitly; the alias is slated
# for removal in the "Task 11" cleanup.
CONFIG: Config = Config()
@@ -0,0 +1,15 @@
1
+ """数据获取层"""
2
+ from aimoon.data.spot import get_spot, get_spot_for_codes
3
+ from aimoon.data.history import get_kline
4
+ from aimoon.data.filters import (
5
+ filter_universe,
6
+ filter_by_sectors,
7
+ get_sector_context,
8
+ get_holdings_pool,
9
+ )
10
+
11
__all__ = [
    # from aimoon.data.spot
    "get_spot",
    "get_spot_for_codes",
    # from aimoon.data.history
    "get_kline",
    # from aimoon.data.filters
    "filter_universe",
    "filter_by_sectors",
    "get_sector_context",
    "get_holdings_pool",
]
+ ]