akquant-0.1.4-cp310-abi3-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of akquant might be problematic.
- akquant/__init__.py +98 -0
- akquant/akquant.pyd +0 -0
- akquant/akquant.pyi +683 -0
- akquant/backtest.py +659 -0
- akquant/config.py +65 -0
- akquant/data.py +136 -0
- akquant/indicator.py +81 -0
- akquant/log.py +135 -0
- akquant/ml/__init__.py +3 -0
- akquant/ml/model.py +234 -0
- akquant/py.typed +0 -0
- akquant/risk.py +40 -0
- akquant/sizer.py +96 -0
- akquant/strategy.py +824 -0
- akquant/utils.py +386 -0
- akquant-0.1.4.dist-info/METADATA +219 -0
- akquant-0.1.4.dist-info/RECORD +19 -0
- akquant-0.1.4.dist-info/WHEEL +4 -0
- akquant-0.1.4.dist-info/licenses/LICENSE +21 -0
akquant/config.py
ADDED
@@ -0,0 +1,65 @@

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class RiskConfig:
    """Configuration for Risk Management."""

    active: bool = True
    max_order_size: Optional[float] = None
    max_order_value: Optional[float] = None
    max_position_size: Optional[float] = None
    restricted_list: Optional[List[str]] = None


@dataclass
class StrategyConfig:
    """
    Global configuration for strategies and backtesting.

    Inspired by PyBroker's configuration system.
    """

    # Capital Management
    initial_cash: float = 100000.0

    # Fees & Commission
    fee_mode: str = "per_order"  # 'per_order', 'per_share', 'percent'
    fee_amount: float = 0.0  # Fixed amount or percentage

    # Execution
    enable_fractional_shares: bool = False
    round_fill_price: bool = True

    # Position Sizing Constraints
    max_long_positions: Optional[int] = None
    max_short_positions: Optional[int] = None

    # Bootstrap Metrics
    bootstrap_samples: int = 1000
    bootstrap_sample_size: Optional[int] = None

    # Other
    exit_on_last_bar: bool = True

    # Risk Config
    risk: Optional[RiskConfig] = None


@dataclass
class BacktestConfig:
    """Configuration specifically for running backtests."""

    strategy_config: StrategyConfig
    start_date: Optional[str] = None
    end_date: Optional[str] = None
    instruments: Optional[List[str]] = None
    benchmark: Optional[str] = None
    timezone: str = "Asia/Shanghai"
    show_progress: bool = True
    history_depth: int = 0


# Global instance
strategy_config = StrategyConfig()
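Since these are plain dataclasses, wiring up a custom configuration is direct. A minimal usage sketch based on the fields above (the values and the "600519" symbol are illustrative, not package defaults):

from akquant.config import RiskConfig, StrategyConfig

config = StrategyConfig(
    initial_cash=50000.0,
    fee_mode="percent",
    fee_amount=0.0003,  # 3 bps per trade
    risk=RiskConfig(max_order_value=100000.0, restricted_list=["600519"]),
)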
akquant/data.py
ADDED
@@ -0,0 +1,136 @@

import hashlib
import logging
from pathlib import Path
from typing import List, Optional

import pandas as pd

from .akquant import Bar
from .utils import load_bar_from_df

logger = logging.getLogger(__name__)


class ParquetDataCatalog:
    """
    Data Catalog using Parquet files for storage.

    Optimized for performance using PyArrow/FastParquet.
    Structure: {root}/{symbol}/data.parquet (simplest for now)
    """

    def __init__(self, root_path: Optional[str] = None):
        """
        Initialize the DataCatalog.

        :param root_path: Root directory for the catalog.
        """
        if root_path:
            self.root = Path(root_path)
        else:
            self.root = Path.home() / ".akquant" / "catalog"

        try:
            if not self.root.exists():
                self.root.mkdir(parents=True, exist_ok=True)
        except PermissionError:
            self.root = Path.cwd() / ".akquant_catalog"
            self.root.mkdir(parents=True, exist_ok=True)

    def write(self, symbol: str, df: pd.DataFrame) -> Path:
        """
        Write DataFrame to Parquet catalog.

        :param symbol: Instrument symbol.
        :param df: DataFrame with DatetimeIndex.
        :return: Path to the written file.
        """
        symbol_path = self.root / symbol
        symbol_path.mkdir(exist_ok=True)
        file_path = symbol_path / "data.parquet"

        # Ensure index is standard
        if not isinstance(df.index, pd.DatetimeIndex):
            # Try to convert a 'date' column if it exists
            if "date" in df.columns:
                df = df.set_index("date")
            df.index = pd.to_datetime(df.index)

        df.to_parquet(file_path, compression="snappy")
        return file_path

    def read(
        self,
        symbol: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        columns: Optional[List[str]] = None,
    ) -> pd.DataFrame:
        """
        Read DataFrame from Parquet catalog.

        :param symbol: Instrument symbol.
        :param start_date: Filter start date (YYYYMMDD or YYYY-MM-DD).
        :param end_date: Filter end date.
        :param columns: Specific columns to read.
        :return: DataFrame.
        """
        symbol_path = self.root / symbol
        file_path = symbol_path / "data.parquet"

        if not file_path.exists():
            return pd.DataFrame()

        # Read with projection (columns)
        df = pd.read_parquet(file_path, columns=columns)

        # Filter by date
        if start_date:
            df = df[df.index >= pd.to_datetime(start_date)]
        if end_date:
            df = df[df.index <= pd.to_datetime(end_date)]

        return df

    def list_symbols(self) -> List[str]:
        """List all symbols in the catalog."""
        return [p.name for p in self.root.iterdir() if p.is_dir()]


class DataLoader:
    """Data Loader with caching capabilities, inspired by PyBroker."""

    def __init__(self, cache_dir: Optional[str] = None):
        """
        Initialize DataLoader.

        Args:
            cache_dir (str, optional): Directory to store cache files.
                Defaults to ~/.akquant/cache.
        """
        if cache_dir:
            self.cache_dir = Path(cache_dir)
        else:
            self.cache_dir = Path.home() / ".akquant" / "cache"

        try:
            if not self.cache_dir.exists():
                self.cache_dir.mkdir(parents=True, exist_ok=True)
        except PermissionError:
            logger.warning(
                f"Permission denied for {self.cache_dir}, "
                "falling back to local .akquant_cache"
            )
            self.cache_dir = Path.cwd() / ".akquant_cache"
            if not self.cache_dir.exists():
                self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _get_cache_path(self, key: str) -> Path:
        """Generate cache file path based on a unique key."""
        # Use a hash of the key to avoid filesystem issues with long/invalid filenames
        hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
        return self.cache_dir / f"{hashed_key}.pkl"

    def df_to_bars(self, df: pd.DataFrame, symbol: Optional[str] = None) -> List[Bar]:
        """Convert DataFrame to list of Bar objects."""
        return load_bar_from_df(df, symbol)
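A short sketch of the catalog round trip using only the methods above (the "600000" symbol and the two-row frame are placeholders):

import pandas as pd

from akquant.data import ParquetDataCatalog

catalog = ParquetDataCatalog()  # defaults to ~/.akquant/catalog

df = pd.DataFrame(
    {"open": [10.0, 10.2], "close": [10.1, 10.3]},
    index=pd.to_datetime(["2024-01-02", "2024-01-03"]),
)
catalog.write("600000", df)

# Column projection and date filtering both happen in read()
closes = catalog.read("600000", start_date="2024-01-01", columns=["close"])
print(catalog.list_symbols())  # ['600000']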
akquant/indicator.py
ADDED
@@ -0,0 +1,81 @@

from typing import Any, Callable, Dict

import pandas as pd


class Indicator:
    """
    Helper class for defining and calculating indicators.

    Inspired by PyBroker's indicator system.
    """

    def __init__(self, name: str, fn: Callable, **kwargs: Any) -> None:
        """Initialize the Indicator."""
        self.name = name
        self.fn = fn
        self.kwargs = kwargs
        self._data: Dict[str, pd.Series] = {}  # symbol -> series

    def __call__(self, df: pd.DataFrame, symbol: str) -> pd.Series:
        """Calculate indicator on a DataFrame."""
        if symbol in self._data:
            return self._data[symbol]

        # Assume fn takes a series/df and returns a series.
        # If kwargs contains parameters (e.g. window sizes), forward them.
        # This is a simplified version of PyBroker's powerful DSL.
        try:
            result = self.fn(df, **self.kwargs)
        except Exception:
            # Fall back to calling fn with the DataFrame alone.
            # Generalizing beyond this without a full DSL is tricky, so we start
            # simple: the user passes a lambda or function that takes df.
            result = self.fn(df)

        if not isinstance(result, pd.Series):
            # Try to convert if it's not a Series (e.g. numpy array)
            result = pd.Series(result, index=df.index)

        self._data[symbol] = result
        return result

    def get_value(self, symbol: str, timestamp: Any) -> float:
        """
        Get the indicator value at a specific timestamp (or the latest before it).

        Uses an asof lookup, which is efficient for sorted time series.
        """
        if symbol not in self._data:
            return float("nan")

        series = self._data[symbol]
        # Assuming the series index is datetime
        try:
            return float(series.asof(timestamp))  # type: ignore[arg-type]
        except Exception:
            return float("nan")


class IndicatorSet:
    """Collection of indicators for easy management."""

    def __init__(self) -> None:
        """Initialize the IndicatorSet."""
        self._indicators: Dict[str, Indicator] = {}

    def add(self, name: str, fn: Callable, **kwargs: Any) -> None:
        """Add an indicator to the set."""
        self._indicators[name] = Indicator(name, fn, **kwargs)

    def get(self, name: str) -> Indicator:
        """Get an indicator by name."""
        return self._indicators[name]

    def calculate_all(self, df: pd.DataFrame, symbol: str) -> Dict[str, pd.Series]:
        """Calculate all indicators for the given dataframe."""
        results = {}
        for name, ind in self._indicators.items():
            results[name] = ind(df, symbol)
        return results
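For example, a rolling mean can be registered with a plain lambda. A minimal sketch, assuming a close-price DataFrame with a DatetimeIndex (the data and the "600000" symbol are placeholders):

import numpy as np
import pandas as pd

from akquant.indicator import IndicatorSet

df = pd.DataFrame(
    {"close": np.linspace(10.0, 12.0, 30)},
    index=pd.date_range("2024-01-01", periods=30, freq="D"),
)

indicators = IndicatorSet()
# fn receives the whole DataFrame; extra kwargs are forwarded on the first call path
indicators.add("sma5", lambda d, window=5: d["close"].rolling(window).mean(), window=5)

results = indicators.calculate_all(df, symbol="600000")            # {"sma5": Series}
latest = indicators.get("sma5").get_value("600000", df.index[-1])  # asof lookup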
akquant/log.py
ADDED
@@ -0,0 +1,135 @@

import logging
import sys
from typing import Optional, Union

# Default format: Time | Level | Message
DEFAULT_FORMAT = "%(asctime)s | %(levelname)s | %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"


class Logger:
    r"""
    akquant logging wrapper.

    :description: Provides shortcut configuration for console and file logging.
    """

    _instance = None

    def __init__(self) -> None:
        """Initialize the Logger."""
        self._logger = logging.getLogger("akquant")
        self._logger.setLevel(logging.INFO)
        self._handlers: dict[str, logging.Handler] = {}  # key -> handler

        # Add default console handler if not present
        if not self._logger.handlers:
            self.enable_console()

    @classmethod
    def get_logger(cls) -> logging.Logger:
        """Get the singleton logger instance."""
        if cls._instance is None:
            cls._instance = Logger()
        return cls._instance._logger

    def set_level(self, level: Union[str, int]) -> None:
        r"""
        Set the log level.

        :param level: Log level as a string or integer (DEBUG/INFO/WARNING/ERROR/CRITICAL)
        :type level: str | int
        """
        self._logger.setLevel(level)

    def enable_console(self, format_str: str = DEFAULT_FORMAT) -> None:
        r"""
        Enable console logging.

        :param format_str: Log format string
        :type format_str: str
        """
        if "console" in self._handlers:
            return

        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(logging.Formatter(format_str, datefmt=DATE_FORMAT))
        self._logger.addHandler(handler)
        self._handlers["console"] = handler

    def disable_console(self) -> None:
        r"""Disable console logging."""
        if "console" in self._handlers:
            self._logger.removeHandler(self._handlers["console"])
            del self._handlers["console"]

    def enable_file(
        self, filename: str, format_str: str = DEFAULT_FORMAT, mode: str = "a"
    ) -> None:
        r"""
        Enable file logging.

        :param filename: Path to the log file
        :type filename: str
        :param format_str: Log format string
        :type format_str: str
        :param mode: File open mode ('a' append or 'w' overwrite)
        :type mode: str
        """
        # Skip if a file handler for this path already exists (simple check)
        key = f"file_{filename}"
        if key in self._handlers:
            return

        handler = logging.FileHandler(filename, mode=mode, encoding="utf-8")
        handler.setFormatter(logging.Formatter(format_str, datefmt=DATE_FORMAT))
        self._logger.addHandler(handler)
        self._handlers[key] = handler


# Global helper functions
def get_logger() -> logging.Logger:
    r"""
    Get the global logger instance.

    :return: The initialized logger
    :rtype: logging.Logger
    """
    return Logger.get_logger()


def set_log_level(level: Union[str, int]) -> None:
    r"""
    Set the global log level.

    :param level: Log level as a string or integer
    :type level: str | int
    """
    Logger.get_logger().setLevel(level)


def register_logger(
    filename: Optional[str] = None, console: bool = True, level: str = "INFO"
) -> None:
    r"""
    One-stop logging configuration.

    :param filename: Path to the log file; if provided, logs are also written to the file
    :type filename: str, optional
    :param console: Whether to log to the console
    :type console: bool
    :param level: Log level ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")
    :type level: str
    """
    logger_manager = Logger._instance or Logger()
    Logger._instance = logger_manager

    logger_manager.set_level(level.upper())

    if console:
        logger_manager.enable_console()
    else:
        logger_manager.disable_console()

    if filename:
        logger_manager.enable_file(filename)
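Typical setup goes through register_logger; a minimal sketch using only the helpers above (the log file name is illustrative):

from akquant.log import get_logger, register_logger, set_log_level

# Console plus file output at DEBUG level
register_logger(filename="backtest.log", console=True, level="DEBUG")

log = get_logger()
log.debug("engine initialized")

set_log_level("WARNING")  # raise the threshold later without touching handlers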
akquant/ml/__init__.py
ADDED
akquant/ml/model.py
ADDED
@@ -0,0 +1,234 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Literal, Optional, Union

import numpy as np
import pandas as pd

# Define unified data input type
DataType = Union[np.ndarray, pd.DataFrame]


@dataclass
class ValidationConfig:
    """Configuration for model validation."""

    method: Literal["walk_forward"] = "walk_forward"
    train_window: Union[str, int] = "1y"
    test_window: Union[str, int] = (
        "3m"  # Not strictly used in rolling execution, but useful for evaluation
    )
    rolling_step: Union[str, int] = "3m"
    frequency: str = "1d"


class QuantModel(ABC):
    """
    Abstract base class for all quantitative models.

    The strategy layer only interacts with this class, not directly with sklearn or
    torch.
    """

    def __init__(self) -> None:
        """Initialize the model."""
        self.validation_config: Optional[ValidationConfig] = None

    def set_validation(
        self,
        method: Literal["walk_forward"] = "walk_forward",
        train_window: Union[str, int] = "1y",
        test_window: Union[str, int] = "3m",
        rolling_step: Union[str, int] = "3m",
        frequency: str = "1d",
    ) -> None:
        """
        Configure validation method (e.g., walk-forward).

        :param method: Validation method (currently only 'walk_forward').
        :param train_window: Training data duration (e.g., '1y', '250d') or bar count.
        :param test_window: Testing/prediction duration (e.g., '3m') or bar count.
        :param rolling_step: How often to retrain (e.g., '3m') or bar count.
        :param frequency: Data frequency ('1d', '1h', '1m') used for parsing time
            strings.
        """
        self.validation_config = ValidationConfig(
            method=method,
            train_window=train_window,
            test_window=test_window,
            rolling_step=rolling_step,
            frequency=frequency,
        )

    @abstractmethod
    def fit(self, X: DataType, y: DataType) -> None:
        """
        Train the model.

        Args:
            X: Training features
            y: Training labels
        """
        pass

    @abstractmethod
    def predict(self, X: DataType) -> np.ndarray:
        """
        Predict using the model.

        Args:
            X: Input features

        Returns:
            np.ndarray: Prediction results (numpy array)
        """
        pass

    @abstractmethod
    def save(self, path: str) -> None:
        """Save the model to the specified path."""
        pass

    @abstractmethod
    def load(self, path: str) -> None:
        """Load the model from the specified path."""
        pass


class SklearnAdapter(QuantModel):
    """Adapter for scikit-learn style models."""

    def __init__(self, estimator: Any):
        """
        Initialize the adapter.

        Args:
            estimator: A sklearn-style estimator instance (e.g., XGBClassifier,
                LGBMRegressor)
        """
        super().__init__()
        self.model = estimator

    def fit(self, X: DataType, y: DataType) -> None:
        """Train the sklearn model."""
        # Convert DataFrame to numpy if necessary, or let sklearn handle it
        self.model.fit(X, y)

    def predict(self, X: DataType) -> np.ndarray:
        """Predict using the sklearn model."""
        # For classification, we usually care about the probability of class 1
        if hasattr(self.model, "predict_proba"):
            # Return probability of class 1.
            # Note: This assumes binary classification. For multi-class, this might
            # need adjustment.
            proba = self.model.predict_proba(X)
            if proba.shape[1] > 1:
                return proba[:, 1]  # type: ignore
            return proba  # type: ignore
        else:
            return self.model.predict(X)  # type: ignore

    def save(self, path: str) -> None:
        """Save the sklearn model using joblib."""
        import joblib  # type: ignore

        joblib.dump(self.model, path)

    def load(self, path: str) -> None:
        """Load the sklearn model using joblib."""
        import joblib  # type: ignore

        self.model = joblib.load(path)


class PyTorchAdapter(QuantModel):
    """Adapter for PyTorch models."""

    def __init__(
        self,
        network: Any,
        criterion: Any,
        optimizer_cls: Any,
        lr: float = 0.001,
        epochs: int = 10,
        batch_size: int = 64,
        device: str = "cpu",
    ):
        """
        Initialize the PyTorch adapter.

        Args:
            network: PyTorch neural network module (nn.Module)
            criterion: Loss function (nn.Module)
            optimizer_cls: Optimizer class (torch.optim.Optimizer)
            lr: Learning rate
            epochs: Number of training epochs
            batch_size: Batch size
            device: Device to run on ('cpu' or 'cuda')
        """
        super().__init__()
        import torch

        self.device = torch.device(device)
        self.network = network.to(self.device)
        self.criterion = criterion
        self.optimizer = optimizer_cls(self.network.parameters(), lr=lr)
        self.epochs = epochs
        self.batch_size = batch_size

    def fit(self, X: DataType, y: DataType) -> None:
        """Train the PyTorch model."""
        import torch
        from torch.utils.data import DataLoader, TensorDataset

        # 1. Data conversion: Numpy/Pandas -> Tensor
        X_array = X.values if isinstance(X, pd.DataFrame) else X
        y_array = y.values if isinstance(y, (pd.DataFrame, pd.Series)) else y

        X_tensor = torch.tensor(X_array, dtype=torch.float32).to(self.device)
        y_tensor = torch.tensor(y_array, dtype=torch.float32).to(self.device)

        # 2. Wrap in DataLoader
        dataset = TensorDataset(X_tensor, y_tensor)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        # 3. Standard training loop
        self.network.train()
        for epoch in range(self.epochs):
            for batch_X, batch_y in loader:
                self.optimizer.zero_grad()
                outputs = self.network(batch_X)

                # Note: Adjust loss calculation dimensions based on task
                # (regression/classification).
                # Squeeze outputs to match batch_y shape if necessary.
                loss = self.criterion(outputs.squeeze(), batch_y)

                loss.backward()
                self.optimizer.step()

    def predict(self, X: DataType) -> np.ndarray:
        """Predict using the PyTorch model."""
        import torch

        self.network.eval()
        with torch.no_grad():
            X_array = X.values if isinstance(X, pd.DataFrame) else X
            X_tensor = torch.tensor(X_array, dtype=torch.float32).to(self.device)
            outputs = self.network(X_tensor)
            # Convert back to Numpy for the strategy layer
            return outputs.cpu().numpy().flatten()  # type: ignore

    def save(self, path: str) -> None:
        """Save the PyTorch model state dict."""
        import torch

        torch.save(self.network.state_dict(), path)

    def load(self, path: str) -> None:
        """Load the PyTorch model state dict."""
        import torch

        self.network.load_state_dict(torch.load(path))
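A minimal sketch of the sklearn path, assuming scikit-learn is installed; the random features/labels and the save path are placeholders. PyTorchAdapter follows the same fit/predict/save/load contract.

import numpy as np
from sklearn.linear_model import LogisticRegression

from akquant.ml.model import SklearnAdapter

X = np.random.rand(500, 4)                   # placeholder features
y = (np.random.rand(500) > 0.5).astype(int)  # placeholder binary labels

model = SklearnAdapter(LogisticRegression())
model.set_validation(train_window="1y", rolling_step="3m", frequency="1d")

model.fit(X, y)
scores = model.predict(X)   # class-1 probabilities via predict_proba
model.save("model.joblib")  # joblib dump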
akquant/py.typed
ADDED
File without changes
akquant/risk.py
ADDED
@@ -0,0 +1,40 @@

from typing import TYPE_CHECKING, Optional

from .config import RiskConfig as PyRiskConfig

if TYPE_CHECKING:
    from .akquant import Engine


def apply_risk_config(engine: "Engine", config: Optional[PyRiskConfig]) -> None:
    """
    Apply the Python-side RiskConfig to the Rust Engine's RiskManager.

    :param engine: The backtest engine instance.
    :param config: The Python RiskConfig object.
    """
    if config is None:
        return

    # Get the Rust RiskConfig object from the engine's risk manager,
    # assuming engine.risk_manager.config is accessible and mutable
    # (or we can create a new one and assign it).
    rust_config = engine.risk_manager.config

    if config.max_order_size is not None:
        rust_config.max_order_size = config.max_order_size

    if config.max_order_value is not None:
        rust_config.max_order_value = config.max_order_value

    if config.max_position_size is not None:
        rust_config.max_position_size = config.max_position_size

    if config.restricted_list is not None:
        rust_config.restricted_list = config.restricted_list

    rust_config.active = config.active

    # Re-assign to ensure the update takes effect (in case it was a copy)
    engine.risk_manager.config = rust_config
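Because apply_risk_config only reads and assigns plain attributes, the mapping can be illustrated without the compiled Engine. A sketch using a SimpleNamespace stand-in that mimics the risk_manager.config shape the function expects (the real Engine lives in the akquant.pyd extension):

from types import SimpleNamespace

from akquant.config import RiskConfig
from akquant.risk import apply_risk_config

# Stand-in for the Rust Engine: anything exposing a mutable risk_manager.config works
engine = SimpleNamespace(
    risk_manager=SimpleNamespace(
        config=SimpleNamespace(
            active=True,
            max_order_size=None,
            max_order_value=None,
            max_position_size=None,
            restricted_list=None,
        )
    )
)

apply_risk_config(engine, RiskConfig(max_order_value=50000.0))
assert engine.risk_manager.config.max_order_value == 50000.0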