PyAlgoEngine 0.8.0a16__tar.gz → 0.8.0.post2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/PKG-INFO +1 -1
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/PyAlgoEngine.egg-info/PKG-INFO +1 -1
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/PyAlgoEngine.egg-info/SOURCES.txt +1 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/__init__.py +1 -1
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/apps/backtest/tester.py +11 -9
- pyalgoengine-0.8.0.post2/algo_engine/backtest/__init__.py +19 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/backtest/replay.py +249 -66
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/base/__init__.py +5 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/base/market_data.pyi +3 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/base/market_data_buffer.pyi +12 -2
- pyalgoengine-0.8.0.post2/algo_engine/base/trade_utils_native.py +693 -0
- pyalgoengine-0.8.0.post2/algo_engine/profile/__init__.py +236 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/strategy/strategy_engine.py +6 -5
- pyalgoengine-0.8.0a16/algo_engine/backtest/__init__.py +0 -19
- pyalgoengine-0.8.0a16/algo_engine/profile/__init__.py +0 -121
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/LICENSE +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/PyAlgoEngine.egg-info/dependency_links.txt +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/PyAlgoEngine.egg-info/requires.txt +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/PyAlgoEngine.egg-info/top_level.txt +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/README.md +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/apps/__init__.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/apps/backtest/__init__.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/apps/backtest/doc_server.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/apps/backtest/web_app.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/apps/bokeh_server.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/apps/demo/__init__.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/apps/demo/test.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/apps/sim_input/__init__.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/apps/sim_input/client.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/apps/sim_input/sim_keyboard.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/apps/sim_input/sim_mouse.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/apps/sim_input/window.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/backtest/__main__.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/backtest/metrics.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/backtest/sim_match.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/base/candlestick.pyi +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/base/console_utils.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/base/finance_decimal.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/base/market_utils_nt.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/base/market_utils_posix.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/base/technical_analysis.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/base/telemetrics.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/base/tick.pyi +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/base/trade_utils.pyi +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/base/transaction.pyi +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/engine/__init__.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/engine/algo_engine.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/engine/event_engine.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/engine/market_engine.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/engine/trade_engine.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/monitor/__init__.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/monitor/advanced_data_interface.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/profile/cn.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/strategy/__init__.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/utils/__init__.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/utils/commit_regularizer.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/algo_engine/utils/data_utils.py +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/setup.cfg +0 -0
- {pyalgoengine-0.8.0a16 → pyalgoengine-0.8.0.post2}/setup.py +0 -0
|
@@ -37,6 +37,7 @@ algo_engine/base/technical_analysis.py
|
|
|
37
37
|
algo_engine/base/telemetrics.py
|
|
38
38
|
algo_engine/base/tick.pyi
|
|
39
39
|
algo_engine/base/trade_utils.pyi
|
|
40
|
+
algo_engine/base/trade_utils_native.py
|
|
40
41
|
algo_engine/base/transaction.pyi
|
|
41
42
|
algo_engine/engine/__init__.py
|
|
42
43
|
algo_engine/engine/algo_engine.py
|
|
@@ -8,7 +8,7 @@ import numpy as np
|
|
|
8
8
|
from algo_engine.backtest.metrics import TradeMetrics
|
|
9
9
|
from . import LOGGER
|
|
10
10
|
from .web_app import WebApp
|
|
11
|
-
from ...backtest import SimMatch,
|
|
11
|
+
from ...backtest import SimMatch, ProgressReplay
|
|
12
12
|
from ...base import MarketData, TradeReport, TradeInstruction
|
|
13
13
|
from ...profile import Profile, PROFILE
|
|
14
14
|
|
|
@@ -109,17 +109,18 @@ class Tester(object, metaclass=abc.ABCMeta):
|
|
|
109
109
|
pass
|
|
110
110
|
|
|
111
111
|
def run(self, **kwargs):
|
|
112
|
-
replay =
|
|
112
|
+
replay = ProgressReplay(
|
|
113
113
|
loader=self.load_data,
|
|
114
|
-
tickers=list(self.subscription),
|
|
115
|
-
dtype=['TickData', 'TradeData'],
|
|
116
114
|
start_date=self.start_date,
|
|
117
115
|
end_date=self.end_date,
|
|
118
116
|
bod=self.bod,
|
|
119
117
|
eod=self.eod,
|
|
120
|
-
tick_size=kwargs.get('progress_tick_size', 0.001),
|
|
121
118
|
)
|
|
122
119
|
|
|
120
|
+
for ticker in self.subscription:
|
|
121
|
+
replay.add_subscription(ticker, dtype='TickData')
|
|
122
|
+
replay.add_subscription(ticker, dtype='TradeData')
|
|
123
|
+
|
|
123
124
|
_start_ts = time.time()
|
|
124
125
|
|
|
125
126
|
for market_data in replay:
|
|
@@ -222,17 +223,18 @@ class StrategyTester(Tester):
|
|
|
222
223
|
if not self.event_engine.active:
|
|
223
224
|
self.event_engine.start()
|
|
224
225
|
|
|
225
|
-
replay =
|
|
226
|
+
replay = ProgressReplay(
|
|
226
227
|
loader=self.load_data,
|
|
227
|
-
tickers=list(self.subscription),
|
|
228
|
-
dtype=['TickData', 'TradeData'],
|
|
229
228
|
start_date=self.start_date,
|
|
230
229
|
end_date=self.end_date,
|
|
231
230
|
bod=self.bod,
|
|
232
231
|
eod=self.eod,
|
|
233
|
-
tick_size=kwargs.get('progress_tick_size', 0.001),
|
|
234
232
|
)
|
|
235
233
|
|
|
234
|
+
for ticker in self.subscription:
|
|
235
|
+
replay.add_subscription(ticker, dtype='TickData')
|
|
236
|
+
replay.add_subscription(ticker, dtype='TradeData')
|
|
237
|
+
|
|
236
238
|
_start_ts = time.time()
|
|
237
239
|
|
|
238
240
|
for market_data in replay:
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from .. import LOGGER
|
|
4
|
+
|
|
5
|
+
LOGGER = LOGGER.getChild('BackTest')
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def set_logger(logger: logging.Logger):
|
|
9
|
+
global LOGGER
|
|
10
|
+
LOGGER = logger
|
|
11
|
+
|
|
12
|
+
replay.LOGGER = LOGGER.getChild('Replay')
|
|
13
|
+
sim_match.LOGGER = LOGGER.getChild('SimMatch')
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
from .replay import PyDataScope, MarketDateCallable, MarketDataLoader, MarketDataBulkLoader, Replay, SimpleReplay, ProgressReplay
|
|
17
|
+
from .sim_match import SimMatch
|
|
18
|
+
|
|
19
|
+
__all__ = ['PyDataScope', 'MarketDateCallable', 'MarketDataLoader', 'MarketDataBulkLoader', 'Replay', 'SimpleReplay', 'ProgressReplay', 'SimMatch']
|
|
@@ -1,16 +1,97 @@
|
|
|
1
1
|
import abc
|
|
2
2
|
import datetime
|
|
3
|
+
import enum
|
|
3
4
|
import inspect
|
|
5
|
+
import logging
|
|
4
6
|
import operator
|
|
5
7
|
import warnings
|
|
6
8
|
from collections.abc import Sequence, Mapping, Iterable, Callable
|
|
7
|
-
from typing import Literal, Protocol, runtime_checkable, get_type_hints
|
|
9
|
+
from typing import Literal, Protocol, runtime_checkable, get_type_hints, Self
|
|
8
10
|
|
|
9
11
|
from . import LOGGER
|
|
10
12
|
from ..base import MarketData, DataType, MarketDataBuffer
|
|
11
13
|
|
|
12
14
|
LOGGER = LOGGER.getChild('Replay')
|
|
13
|
-
__all__ = ['MarketDateCallable', 'MarketDataLoader', 'MarketDataBulkLoader', 'Replay', 'SimpleReplay', 'ProgressReplay', 'ProgressiveReplay']
|
|
15
|
+
__all__ = ['PyDataScope', 'MarketDateCallable', 'MarketDataLoader', 'MarketDataBulkLoader', 'Replay', 'SimpleReplay', 'ProgressReplay', 'ProgressiveReplay']
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PyDataScope(enum.Flag):
|
|
19
|
+
SCOPE_TRANSACTION = enum.auto()
|
|
20
|
+
SCOPE_ORDER = enum.auto()
|
|
21
|
+
SCOPE_TICK = enum.auto()
|
|
22
|
+
SCOPE_TICK_LITE = enum.auto()
|
|
23
|
+
|
|
24
|
+
SCOPE_ALL = SCOPE_TRANSACTION | SCOPE_ORDER | SCOPE_TICK
|
|
25
|
+
|
|
26
|
+
@classmethod
|
|
27
|
+
def _missing_(cls, value: Literal['TickData', 'TickDataLite', 'OrderData', 'TransactionData']):
|
|
28
|
+
if isinstance(value, int):
|
|
29
|
+
return super()._missing_(value)
|
|
30
|
+
|
|
31
|
+
if isinstance(value, str):
|
|
32
|
+
dtypes = value.split(',')
|
|
33
|
+
elif isinstance(value, Iterable):
|
|
34
|
+
dtypes = value
|
|
35
|
+
else:
|
|
36
|
+
raise TypeError(value)
|
|
37
|
+
|
|
38
|
+
_ = PyDataScope(0)
|
|
39
|
+
for dtype in dtypes:
|
|
40
|
+
_ = _.from_str(dtype)
|
|
41
|
+
return _
|
|
42
|
+
|
|
43
|
+
@classmethod
|
|
44
|
+
def get_dtype(cls, dtype: DataType | str) -> str | Literal['TickData', 'TickDataLite', 'OrderData', 'TransactionData']:
|
|
45
|
+
match dtype:
|
|
46
|
+
case 'TickData' | 'TickDataLite' | 'OrderData' | 'TransactionData':
|
|
47
|
+
return str(dtype)
|
|
48
|
+
case 'TradeData': # handle the alias
|
|
49
|
+
return 'TransactionData'
|
|
50
|
+
case DataType.DTYPE_TICK | DataType.DTYPE_ORDER | DataType.DTYPE_TRANSACTION:
|
|
51
|
+
return DataType(dtype).name.removeprefix('DTYPE_').capitalize() + 'Data'
|
|
52
|
+
case DataType.DTYPE_TICK_LITE:
|
|
53
|
+
return 'Data'.join(_.capitalize() for _ in DataType(dtype).name.removeprefix('DTYPE_').split('_'))
|
|
54
|
+
case _:
|
|
55
|
+
raise ValueError(f'Invalid dtype {dtype}, expect str or int.')
|
|
56
|
+
|
|
57
|
+
def __iter__(self):
|
|
58
|
+
return iter(self.to_dtype())
|
|
59
|
+
|
|
60
|
+
def to_dtype(self) -> list[DataType]:
|
|
61
|
+
scope = list(super().__iter__())
|
|
62
|
+
scope_dtype = set()
|
|
63
|
+
|
|
64
|
+
for dtype in scope:
|
|
65
|
+
|
|
66
|
+
if dtype is PyDataScope.SCOPE_TRANSACTION:
|
|
67
|
+
scope_dtype.add(DataType.DTYPE_TRANSACTION)
|
|
68
|
+
elif dtype is PyDataScope.SCOPE_ORDER:
|
|
69
|
+
scope_dtype.add(DataType.DTYPE_ORDER)
|
|
70
|
+
elif dtype is PyDataScope.SCOPE_TICK_LITE:
|
|
71
|
+
scope_dtype.add(DataType.DTYPE_TICK_LITE)
|
|
72
|
+
elif dtype is PyDataScope.SCOPE_TICK:
|
|
73
|
+
scope_dtype.add(DataType.DTYPE_TICK)
|
|
74
|
+
|
|
75
|
+
return list(scope_dtype)
|
|
76
|
+
|
|
77
|
+
def to_int(self) -> list[int]:
|
|
78
|
+
return [int(_) for _ in self.to_dtype()]
|
|
79
|
+
|
|
80
|
+
def to_str(self) -> list[str]:
|
|
81
|
+
return [self.get_dtype(_) for _ in self.to_dtype()]
|
|
82
|
+
|
|
83
|
+
def from_str(self, dtype: Literal['TickData', 'TickDataLite', 'OrderData', 'TransactionData']) -> Self:
|
|
84
|
+
match dtype:
|
|
85
|
+
case 'TickData':
|
|
86
|
+
return self | self.SCOPE_TICK
|
|
87
|
+
case 'TickDataLite':
|
|
88
|
+
return self | self.SCOPE_TICK_LITE
|
|
89
|
+
case 'OrderData':
|
|
90
|
+
return self | self.SCOPE_ORDER
|
|
91
|
+
case 'TransactionData' | 'TradeData':
|
|
92
|
+
return self | self.SCOPE_TRANSACTION
|
|
93
|
+
case _:
|
|
94
|
+
raise ValueError(f'Invalid str {dtype}.')
|
|
14
95
|
|
|
15
96
|
|
|
16
97
|
@runtime_checkable
|
|
@@ -27,7 +108,7 @@ class MarketDataLoader(Protocol):
|
|
|
27
108
|
|
|
28
109
|
@runtime_checkable
|
|
29
110
|
class MarketDataBulkLoader(Protocol):
|
|
30
|
-
def __call__(self, market_date: datetime.date, tickers: Sequence[str], dtypes: Sequence[str | DataType]) -> Sequence[MarketData] | Mapping[float, MarketData] | MarketDataBuffer:
|
|
111
|
+
def __call__(self, market_date: datetime.date, tickers: Sequence[str], dtypes: Sequence[str | DataType] | PyDataScope) -> Sequence[MarketData] | Mapping[float, MarketData] | MarketDataBuffer:
|
|
31
112
|
pass
|
|
32
113
|
|
|
33
114
|
|
|
@@ -40,13 +121,16 @@ def check_protocol_signature(func: Callable, protocol: type) -> bool:
|
|
|
40
121
|
|
|
41
122
|
proto_params = list(proto_sig.parameters.values())[1:] # Skip 'self'
|
|
42
123
|
func_params = list(func_sig.parameters.values())
|
|
124
|
+
enable_keywords = False
|
|
43
125
|
|
|
44
126
|
# Check for *args (VAR_POSITIONAL) — not allowed
|
|
45
127
|
for p in func_params:
|
|
46
128
|
if p.kind == inspect.Parameter.VAR_POSITIONAL:
|
|
47
129
|
raise TypeError(f"{func.__name__} uses *args, which is not allowed")
|
|
130
|
+
elif p.kind == inspect.Parameter.VAR_KEYWORD:
|
|
131
|
+
enable_keywords = True
|
|
48
132
|
|
|
49
|
-
#
|
|
133
|
+
# Extract positional args (POSITIONAL_ONLY or POSITIONAL_OR_KEYWORD)
|
|
50
134
|
proto_arg_names = [p.name for p in proto_params if p.kind in (
|
|
51
135
|
inspect.Parameter.POSITIONAL_ONLY,
|
|
52
136
|
inspect.Parameter.POSITIONAL_OR_KEYWORD
|
|
@@ -57,8 +141,13 @@ def check_protocol_signature(func: Callable, protocol: type) -> bool:
|
|
|
57
141
|
inspect.Parameter.POSITIONAL_OR_KEYWORD
|
|
58
142
|
)]
|
|
59
143
|
|
|
60
|
-
if
|
|
61
|
-
|
|
144
|
+
# Check if required positional args match (ignore **kwargs)
|
|
145
|
+
if not enable_keywords and sorted(proto_arg_names) != sorted(func_arg_names):
|
|
146
|
+
warnings.warn(
|
|
147
|
+
f"{func} argument names {func_arg_names} do not match protocol {proto_arg_names}",
|
|
148
|
+
stacklevel=2
|
|
149
|
+
)
|
|
150
|
+
return False
|
|
62
151
|
|
|
63
152
|
# Type hint comparison (warn if mismatched, but allow)
|
|
64
153
|
proto_hints = get_type_hints(protocol.__call__)
|
|
@@ -104,32 +193,26 @@ class Replay(object, metaclass=abc.ABCMeta):
|
|
|
104
193
|
if eod is not None:
|
|
105
194
|
self.add_eod(eod)
|
|
106
195
|
|
|
107
|
-
def add_bod(self, func: MarketDateCallable):
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
196
|
+
def add_bod(self, func: MarketDateCallable, priority: int = None) -> None:
|
|
197
|
+
if priority is None:
|
|
198
|
+
self.bod.append(func)
|
|
199
|
+
else:
|
|
200
|
+
self.bod.insert(priority, func)
|
|
112
201
|
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
case DataType.DTYPE_TICK | DataType.DTYPE_ORDER | DataType.DTYPE_TRANSACTION:
|
|
119
|
-
return DataType(dtype).name.removeprefix('DTYPE_').capitalize() + 'Data'
|
|
120
|
-
case DataType.DTYPE_TICK_LITE:
|
|
121
|
-
return 'Data'.join(_.capitalize() for _ in DataType(dtype).name.removeprefix('DTYPE_').split('_'))
|
|
122
|
-
case _:
|
|
123
|
-
raise ValueError(f'Invalid dtype {dtype}, expect str or int.')
|
|
202
|
+
def add_eod(self, func: MarketDateCallable, priority: int = None):
|
|
203
|
+
if priority is None:
|
|
204
|
+
self.eod.append(func)
|
|
205
|
+
else:
|
|
206
|
+
self.eod.insert(priority, func)
|
|
124
207
|
|
|
125
208
|
def add_subscription(self, ticker: str, dtype: DataType | str):
|
|
126
|
-
dtype =
|
|
209
|
+
dtype = PyDataScope.get_dtype(dtype)
|
|
127
210
|
topic = f'{ticker}.{dtype}'
|
|
128
211
|
|
|
129
212
|
self.subscription[topic] = (ticker, dtype)
|
|
130
213
|
|
|
131
214
|
def remove_subscription(self, ticker: str, dtype: DataType | str):
|
|
132
|
-
dtype =
|
|
215
|
+
dtype = PyDataScope.get_dtype(dtype)
|
|
133
216
|
topic = f'{ticker}.{dtype}'
|
|
134
217
|
|
|
135
218
|
try:
|
|
@@ -162,7 +245,7 @@ class SimpleReplay(Replay):
|
|
|
162
245
|
|
|
163
246
|
def __iter__(self):
|
|
164
247
|
self._calendar = self.calendar or [self.start_date + datetime.timedelta(days=i) for i in range((self.end_date - self.start_date).days + 1)]
|
|
165
|
-
self._market_date =
|
|
248
|
+
self._market_date = sorted(_ for _ in self._calendar if _ >= self.market_date)[0]
|
|
166
249
|
self._status = {market_date: 'skipped' if market_date < self.market_date else 'idle' for market_date in self._calendar}
|
|
167
250
|
self._idx_buffer = 0
|
|
168
251
|
self._idx_date = sum([1 for _ in self._calendar if _ < self.market_date])
|
|
@@ -208,7 +291,9 @@ class SimpleReplay(Replay):
|
|
|
208
291
|
return f'{self.__class__.__name__}{{id={id(self)}, from={self.start_date}, to={self.end_date}}}'
|
|
209
292
|
|
|
210
293
|
def _bulk_load_protocol(self):
|
|
294
|
+
LOGGER.info(f'{self} loading {self._market_date} {(', '.join(self.dtypes)) if self.dtypes else 'data'} for {len(self.tickers)} tickers...')
|
|
211
295
|
buffer = self.loader(market_date=self._market_date, tickers=self.tickers, dtypes=self.dtypes)
|
|
296
|
+
LOGGER.info(f'{self} sorting {self._market_date} data...')
|
|
212
297
|
buffer.sort()
|
|
213
298
|
|
|
214
299
|
if isinstance(buffer, MarketDataBuffer):
|
|
@@ -220,11 +305,12 @@ class SimpleReplay(Replay):
|
|
|
220
305
|
elif isinstance(buffer, Mapping):
|
|
221
306
|
self._buffer = iter(buffer.values())
|
|
222
307
|
self._buffer_size = len(buffer)
|
|
308
|
+
LOGGER.info(f'{self} {self._market_date} total {self._buffer_size:,} items loaded.')
|
|
223
309
|
|
|
224
310
|
def _individual_load_protocol(self):
|
|
225
311
|
buffer = []
|
|
226
312
|
for topic, (_ticker, _dtype) in self.subscription.items():
|
|
227
|
-
LOGGER.info(f'{self} loading {self._market_date} {_ticker} {_dtype}')
|
|
313
|
+
LOGGER.info(f'{self} loading {self._market_date} {_ticker} {_dtype}...')
|
|
228
314
|
data = self.loader(market_date=self._market_date, ticker=_ticker, dtype=_dtype)
|
|
229
315
|
if isinstance(data, Mapping):
|
|
230
316
|
buffer.extend(list(data.values()))
|
|
@@ -232,14 +318,16 @@ class SimpleReplay(Replay):
|
|
|
232
318
|
buffer.extend(data)
|
|
233
319
|
else:
|
|
234
320
|
raise TypeError(f'The loader {self.loader} returned {type(data)}. Expect a sequence or mapping of MarketData')
|
|
321
|
+
LOGGER.info(f'{self} sorting {self._market_date} data...')
|
|
235
322
|
buffer.sort(key=operator.attrgetter('timestamp', 'ticker', '_dtype'))
|
|
236
323
|
self._buffer = iter(buffer)
|
|
237
324
|
self._buffer_size = len(buffer)
|
|
325
|
+
LOGGER.info(f'{self} {self._market_date} total {self._buffer_size:,} items loaded.')
|
|
238
326
|
|
|
239
327
|
def _safe_load(self):
|
|
240
328
|
if self.loader is None:
|
|
241
329
|
assert hasattr(self, '_buffer') and isinstance(self._buffer, Iterable), f'Without assigning a data loader, the _buffer of {self.__class__.__name__} should be set in bod process.'
|
|
242
|
-
return
|
|
330
|
+
return None
|
|
243
331
|
|
|
244
332
|
is_bulk_loader = check_protocol_signature(self.loader, MarketDataBulkLoader)
|
|
245
333
|
is_individual_loader = check_protocol_signature(self.loader, MarketDataLoader)
|
|
@@ -266,7 +354,7 @@ class SimpleReplay(Replay):
|
|
|
266
354
|
if not hasattr(self, '_buffer'):
|
|
267
355
|
raise RuntimeError(f'{self.__class__.__name__} not started yet.')
|
|
268
356
|
|
|
269
|
-
return (self._idx_date +
|
|
357
|
+
return (self._idx_date + self._idx_buffer / self._buffer_size) / len(self._calendar)
|
|
270
358
|
|
|
271
359
|
@property
|
|
272
360
|
def tickers(self) -> list[str]:
|
|
@@ -300,7 +388,7 @@ class ProgressReplay(SimpleReplay):
|
|
|
300
388
|
calendar: Sequence[datetime.date] = None,
|
|
301
389
|
bod: MarketDateCallable = None,
|
|
302
390
|
eod: MarketDateCallable = None,
|
|
303
|
-
**
|
|
391
|
+
**pbar_config
|
|
304
392
|
):
|
|
305
393
|
super().__init__(
|
|
306
394
|
loader=loader,
|
|
@@ -312,42 +400,134 @@ class ProgressReplay(SimpleReplay):
|
|
|
312
400
|
eod=eod
|
|
313
401
|
)
|
|
314
402
|
|
|
315
|
-
self.
|
|
316
|
-
'
|
|
317
|
-
'
|
|
318
|
-
'unit': 'percent',
|
|
319
|
-
'mininterval': 0.1,
|
|
320
|
-
'miniters': 0.001,
|
|
321
|
-
**tqdm_kwargs
|
|
403
|
+
self.pbar_config = {
|
|
404
|
+
'backend': pbar_config.pop('backend', 'tqdm'), # tqdm or native
|
|
405
|
+
'config': pbar_config,
|
|
322
406
|
}
|
|
323
|
-
self.
|
|
407
|
+
self._pbar = None
|
|
408
|
+
self.add_bod(self._update_pbar_prefix, priority=0)
|
|
409
|
+
|
|
410
|
+
def _init_pbar(self):
|
|
411
|
+
pbar_backend = self.pbar_config['backend']
|
|
412
|
+
match pbar_backend:
|
|
413
|
+
case 'tqdm':
|
|
414
|
+
from tqdm.auto import tqdm
|
|
415
|
+
from tqdm.std import tqdm as tqdm_std
|
|
416
|
+
from tqdm.contrib.logging import _TqdmLoggingHandler, _get_first_found_console_logging_handler, _is_console_logging_handler
|
|
417
|
+
|
|
418
|
+
tqdm_secondary_config = {
|
|
419
|
+
'total': 1,
|
|
420
|
+
'unit_scale': True,
|
|
421
|
+
'unit': 'percent',
|
|
422
|
+
'mininterval': 0.1,
|
|
423
|
+
'miniters': 0.001,
|
|
424
|
+
**self.pbar_config['config'],
|
|
425
|
+
}
|
|
426
|
+
self._pbar_secondary = tqdm(**tqdm_secondary_config)
|
|
427
|
+
|
|
428
|
+
tqdm_config = {
|
|
429
|
+
'total': 1,
|
|
430
|
+
'unit_scale': True,
|
|
431
|
+
'unit': 'percent',
|
|
432
|
+
'mininterval': 0.1,
|
|
433
|
+
'miniters': 0.001,
|
|
434
|
+
**self.pbar_config['config'],
|
|
435
|
+
}
|
|
436
|
+
self._pbar = tqdm(**tqdm_config)
|
|
437
|
+
|
|
438
|
+
self._update_pbar_progress = self._update_tqdm_progress
|
|
439
|
+
self.pbar_config['loggers'] = loggers = [LOGGER.root] + [_ for _ in LOGGER.root.manager.loggerDict.values() if isinstance(_, logging.Logger) and _.handlers]
|
|
440
|
+
self.pbar_config['original_handlers_list'] = [logger.handlers for logger in loggers]
|
|
441
|
+
for logger in loggers:
|
|
442
|
+
tqdm_handler = _TqdmLoggingHandler(tqdm_std)
|
|
443
|
+
orig_handler = _get_first_found_console_logging_handler(logger.handlers)
|
|
444
|
+
if orig_handler is not None:
|
|
445
|
+
tqdm_handler.setFormatter(orig_handler.formatter)
|
|
446
|
+
tqdm_handler.stream = orig_handler.stream
|
|
447
|
+
logger.handlers = [handler for handler in logger.handlers if not _is_console_logging_handler(handler)] + [tqdm_handler]
|
|
448
|
+
case 'native':
|
|
449
|
+
from ..base import Progress
|
|
450
|
+
|
|
451
|
+
progress_config = dict(
|
|
452
|
+
tasks=1,
|
|
453
|
+
tick_size=0.001,
|
|
454
|
+
**self.pbar_config['config'],
|
|
455
|
+
)
|
|
456
|
+
|
|
457
|
+
self._pbar = Progress(**progress_config)
|
|
458
|
+
self._update_pbar_progress = self._update_native_progress
|
|
459
|
+
case _:
|
|
460
|
+
raise NotImplementedError(f'Invalid pbar backend {pbar_backend}')
|
|
461
|
+
|
|
462
|
+
def _update_pbar_prefix(self, market_date: datetime.date):
|
|
463
|
+
pbar_backend = self.pbar_config['backend']
|
|
464
|
+
match pbar_backend:
|
|
465
|
+
case 'tqdm':
|
|
466
|
+
prompt = f'Progress Total ({self._idx_date + 1} / {len(self._calendar)})'
|
|
467
|
+
prompt_secondary = f'Progress [{market_date:%Y-%m-%d}]'
|
|
468
|
+
prompt_length = max(len(prompt), len(prompt_secondary))
|
|
469
|
+
self._pbar.set_description(prompt.ljust(prompt_length))
|
|
470
|
+
self._pbar.refresh()
|
|
471
|
+
self._pbar_secondary.n = 0
|
|
472
|
+
self._pbar_secondary.set_description(prompt_secondary.ljust(prompt_length))
|
|
473
|
+
self._pbar_secondary.refresh()
|
|
474
|
+
case 'native':
|
|
475
|
+
self._pbar.prompt = f'Replay {market_date:%Y-%m-%d} ({self._idx_date + 1} / {len(self._calendar)}):'
|
|
476
|
+
self._pbar.output()
|
|
477
|
+
case _:
|
|
478
|
+
raise NotImplementedError(f'Invalid pbar backend {pbar_backend}')
|
|
479
|
+
|
|
480
|
+
def _close_pbar(self):
|
|
481
|
+
pbar_backend = self.pbar_config['backend']
|
|
482
|
+
match pbar_backend:
|
|
483
|
+
case 'tqdm':
|
|
484
|
+
for logger, original_handlers in zip(self.pbar_config['loggers'], self.pbar_config['original_handlers_list']):
|
|
485
|
+
logger.handlers = original_handlers
|
|
486
|
+
|
|
487
|
+
self._pbar_secondary.n = 1
|
|
488
|
+
# self._pbar_secondary.refresh()
|
|
489
|
+
self._pbar_secondary.close()
|
|
490
|
+
self._pbar_secondary = None
|
|
491
|
+
|
|
492
|
+
self._pbar.n = 1
|
|
493
|
+
# self._pbar.refresh()
|
|
494
|
+
self._pbar.close()
|
|
495
|
+
self._pbar = None
|
|
496
|
+
case 'native':
|
|
497
|
+
self._pbar.done_tasks = 1
|
|
498
|
+
self._pbar.output()
|
|
499
|
+
case _:
|
|
500
|
+
raise NotImplementedError(f'Invalid pbar backend {pbar_backend}')
|
|
501
|
+
|
|
502
|
+
def _update_tqdm_progress(self):
|
|
503
|
+
self._pbar.n = self.progress
|
|
504
|
+
self._pbar.update(0)
|
|
505
|
+
|
|
506
|
+
self._pbar_secondary.n = self._idx_buffer / self._buffer_size
|
|
507
|
+
self._pbar_secondary.update(0)
|
|
508
|
+
|
|
509
|
+
def _update_native_progress(self):
|
|
510
|
+
self._pbar.done_tasks = self.progress
|
|
511
|
+
|
|
512
|
+
if (not self._pbar.tick_size) \
|
|
513
|
+
or self._pbar.progress >= self._pbar.tick_size + self._pbar.last_output \
|
|
514
|
+
or self._pbar.is_done:
|
|
515
|
+
self._pbar.output()
|
|
324
516
|
|
|
325
517
|
def __iter__(self):
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
iterator = super().__iter__()
|
|
518
|
+
self._init_pbar()
|
|
519
|
+
return super().__iter__()
|
|
329
520
|
|
|
521
|
+
def __next__(self) -> MarketData:
|
|
330
522
|
try:
|
|
331
|
-
|
|
332
|
-
try:
|
|
333
|
-
result = next(iterator)
|
|
334
|
-
if self._pbar:
|
|
335
|
-
self._pbar.update(self.progress)
|
|
336
|
-
yield result
|
|
337
|
-
except StopIteration:
|
|
338
|
-
break
|
|
339
|
-
finally:
|
|
523
|
+
result = super().__next__()
|
|
340
524
|
if self._pbar is not None:
|
|
341
|
-
self.
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
def _update_progress_bar(self, market_date: datetime.date):
|
|
348
|
-
if self._pbar:
|
|
349
|
-
self._pbar.set_description(f'Replay {market_date:%Y-%m-%d} ({self._idx_date + 1} / {len(self._calendar)})')
|
|
350
|
-
self._pbar.refresh()
|
|
525
|
+
self._update_pbar_progress()
|
|
526
|
+
return result
|
|
527
|
+
except StopIteration:
|
|
528
|
+
if self._pbar is not None:
|
|
529
|
+
self._close_pbar()
|
|
530
|
+
raise
|
|
351
531
|
|
|
352
532
|
|
|
353
533
|
class ProgressiveReplay(SimpleReplay):
|
|
@@ -378,7 +558,7 @@ class ProgressiveReplay(SimpleReplay):
|
|
|
378
558
|
eod: MarketDateCallable = None,
|
|
379
559
|
**progress_config
|
|
380
560
|
) -> None:
|
|
381
|
-
warnings.
|
|
561
|
+
warnings.warn('User ProgressReplay instead!', DeprecationWarning, stacklevel=2)
|
|
382
562
|
self.loader = loader
|
|
383
563
|
super().__init__(loader=loader, market_date=market_date, start_date=start_date, end_date=end_date, calendar=calendar, bod=bod, eod=eod)
|
|
384
564
|
|
|
@@ -410,6 +590,8 @@ class ProgressiveReplay(SimpleReplay):
|
|
|
410
590
|
tasks=1,
|
|
411
591
|
**progress_config
|
|
412
592
|
)
|
|
593
|
+
self._pbar = None
|
|
594
|
+
self.add_bod(self._update_progress_bar, priority=0)
|
|
413
595
|
|
|
414
596
|
def __iter__(self):
|
|
415
597
|
from ..base import Progress
|
|
@@ -419,16 +601,17 @@ class ProgressiveReplay(SimpleReplay):
|
|
|
419
601
|
def __next__(self) -> MarketData:
|
|
420
602
|
try:
|
|
421
603
|
result = super().__next__()
|
|
422
|
-
self._pbar
|
|
604
|
+
if self._pbar:
|
|
605
|
+
self._pbar.done_tasks = self.progress
|
|
423
606
|
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
607
|
+
if (not self._pbar.tick_size) \
|
|
608
|
+
or self._pbar.progress >= self._pbar.tick_size + self._pbar.last_output \
|
|
609
|
+
or self._pbar.is_done:
|
|
610
|
+
self._pbar.output()
|
|
428
611
|
|
|
429
612
|
return result
|
|
430
613
|
except StopIteration:
|
|
431
|
-
if not self._pbar.is_done:
|
|
614
|
+
if self._pbar is not None and not self._pbar.is_done:
|
|
432
615
|
self.progress.done_tasks = 1
|
|
433
616
|
self._pbar.output()
|
|
434
617
|
raise
|
|
@@ -5,6 +5,8 @@ import pathlib
|
|
|
5
5
|
from .telemetrics import LOGGER
|
|
6
6
|
from ..profile import PROFILE
|
|
7
7
|
|
|
8
|
+
USE_CYTHON = True
|
|
9
|
+
|
|
8
10
|
|
|
9
11
|
def set_logger(logger: logging.Logger):
|
|
10
12
|
global LOGGER
|
|
@@ -15,6 +17,9 @@ def set_logger(logger: logging.Logger):
|
|
|
15
17
|
|
|
16
18
|
|
|
17
19
|
def check_cython_module(cython_module) -> bool:
|
|
20
|
+
if not USE_CYTHON:
|
|
21
|
+
return False
|
|
22
|
+
|
|
18
23
|
for name in cython_module:
|
|
19
24
|
cython_ext = '.pyd' if os.name == 'nt' else '.so'
|
|
20
25
|
for file in pathlib.Path(__file__).parent.glob(f'*{cython_ext}'):
|
|
@@ -31,7 +31,7 @@ class MarketDataBuffer:
|
|
|
31
31
|
|
|
32
32
|
def to_bytes(self) -> bytes: ...
|
|
33
33
|
|
|
34
|
-
def update(self, dtype: int, **kwargs:
|
|
34
|
+
def update(self, dtype: int, **kwargs: Any) -> None: ...
|
|
35
35
|
|
|
36
36
|
def __getitem__(self, idx: int) -> MarketData | TransactionData | OrderData | TickDataLite | TickData | BarData: ...
|
|
37
37
|
|
|
@@ -79,6 +79,10 @@ class MarketDataConcurrentBuffer:
|
|
|
79
79
|
capacity: int = ...
|
|
80
80
|
) -> None: ...
|
|
81
81
|
|
|
82
|
+
def get_head(self, worker_id: int) -> int: ...
|
|
83
|
+
|
|
84
|
+
def min_head(self) -> int: ...
|
|
85
|
+
|
|
82
86
|
def is_empty(self, worker_id: int) -> bool: ...
|
|
83
87
|
|
|
84
88
|
def is_empty_all(self) -> bool: ...
|
|
@@ -89,4 +93,10 @@ class MarketDataConcurrentBuffer:
|
|
|
89
93
|
|
|
90
94
|
def get(self, idx: int) -> MarketData | TransactionData | OrderData | TickDataLite | TickData | BarData: ...
|
|
91
95
|
|
|
92
|
-
def listen(self, worker_id: int, timeout: float = ...) -> MarketData | TransactionData | OrderData | TickDataLite | TickData | BarData: ...
|
|
96
|
+
def listen(self, worker_id: int, block: bool = True, timeout: float = ...) -> MarketData | TransactionData | OrderData | TickDataLite | TickData | BarData: ...
|
|
97
|
+
|
|
98
|
+
@property
|
|
99
|
+
def head(self) -> list[int]: ...
|
|
100
|
+
|
|
101
|
+
@property
|
|
102
|
+
def tail(self) -> int: ...
|