PyAlgoEngine 0.8.0a14__tar.gz → 0.8.0a17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/PKG-INFO +1 -1
  2. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/PyAlgoEngine.egg-info/PKG-INFO +1 -1
  3. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/__init__.py +1 -1
  4. pyalgoengine-0.8.0a17/algo_engine/backtest/__init__.py +19 -0
  5. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/backtest/replay.py +241 -65
  6. pyalgoengine-0.8.0a14/algo_engine/backtest/__init__.py +0 -19
  7. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/LICENSE +0 -0
  8. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/PyAlgoEngine.egg-info/SOURCES.txt +0 -0
  9. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/PyAlgoEngine.egg-info/dependency_links.txt +0 -0
  10. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/PyAlgoEngine.egg-info/requires.txt +0 -0
  11. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/PyAlgoEngine.egg-info/top_level.txt +0 -0
  12. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/README.md +0 -0
  13. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/apps/__init__.py +0 -0
  14. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/apps/backtest/__init__.py +0 -0
  15. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/apps/backtest/doc_server.py +0 -0
  16. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/apps/backtest/tester.py +0 -0
  17. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/apps/backtest/web_app.py +0 -0
  18. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/apps/bokeh_server.py +0 -0
  19. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/apps/demo/__init__.py +0 -0
  20. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/apps/demo/test.py +0 -0
  21. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/apps/sim_input/__init__.py +0 -0
  22. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/apps/sim_input/client.py +0 -0
  23. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/apps/sim_input/sim_keyboard.py +0 -0
  24. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/apps/sim_input/sim_mouse.py +0 -0
  25. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/apps/sim_input/window.py +0 -0
  26. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/backtest/__main__.py +0 -0
  27. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/backtest/metrics.py +0 -0
  28. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/backtest/sim_match.py +0 -0
  29. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/base/__init__.py +0 -0
  30. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/base/candlestick.pyi +0 -0
  31. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/base/console_utils.py +0 -0
  32. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/base/finance_decimal.py +0 -0
  33. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/base/market_data.pyi +0 -0
  34. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/base/market_data_buffer.pyi +0 -0
  35. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/base/market_utils_nt.py +0 -0
  36. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/base/market_utils_posix.py +0 -0
  37. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/base/technical_analysis.py +0 -0
  38. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/base/telemetrics.py +0 -0
  39. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/base/tick.pyi +0 -0
  40. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/base/trade_utils.pyi +0 -0
  41. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/base/transaction.pyi +0 -0
  42. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/engine/__init__.py +0 -0
  43. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/engine/algo_engine.py +0 -0
  44. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/engine/event_engine.py +0 -0
  45. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/engine/market_engine.py +0 -0
  46. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/engine/trade_engine.py +0 -0
  47. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/monitor/__init__.py +0 -0
  48. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/monitor/advanced_data_interface.py +0 -0
  49. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/profile/__init__.py +0 -0
  50. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/profile/cn.py +0 -0
  51. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/strategy/__init__.py +0 -0
  52. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/strategy/strategy_engine.py +0 -0
  53. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/utils/__init__.py +0 -0
  54. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/utils/commit_regularizer.py +0 -0
  55. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/algo_engine/utils/data_utils.py +0 -0
  56. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/setup.cfg +0 -0
  57. {pyalgoengine-0.8.0a14 → pyalgoengine-0.8.0a17}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PyAlgoEngine
3
- Version: 0.8.0a14
3
+ Version: 0.8.0a17
4
4
  Summary: Basic algo engine
5
5
  Home-page: https://github.com/BolunHan/PyAlgoEngine
6
6
  Author: Bolun.Han
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PyAlgoEngine
3
- Version: 0.8.0a14
3
+ Version: 0.8.0a17
4
4
  Summary: Basic algo engine
5
5
  Home-page: https://github.com/BolunHan/PyAlgoEngine
6
6
  Author: Bolun.Han
@@ -1,4 +1,4 @@
1
- __version__ = "0.8.0.alpha14"
1
+ __version__ = "0.8.0.alpha17"
2
2
 
3
3
  import logging
4
4
  import os
@@ -0,0 +1,19 @@
1
+ import logging
2
+
3
+ from .. import LOGGER
4
+
5
+ LOGGER = LOGGER.getChild('BackTest')
6
+
7
+
8
def set_logger(logger: logging.Logger):
    """Replace the backtest package logger and re-parent the submodule loggers.

    After the call the package-level ``LOGGER`` is *logger*, and the
    ``replay`` / ``sim_match`` submodules log through its ``Replay`` /
    ``SimMatch`` children.

    Args:
        logger: The logger to install as the package-level ``LOGGER``.
    """
    global LOGGER
    LOGGER = logger

    # Each submodule keeps its own module-level LOGGER, so rebinding the
    # package global alone would not reach them.
    for submodule, child_name in ((replay, 'Replay'), (sim_match, 'SimMatch')):
        submodule.LOGGER = LOGGER.getChild(child_name)
14
+
15
+
16
+ from .replay import PyDataScope, MarketDateCallable, MarketDataLoader, MarketDataBulkLoader, Replay, SimpleReplay, ProgressReplay, ProgressiveReplay
17
+ from .sim_match import SimMatch
18
+
19
+ __all__ = ['PyDataScope', 'MarketDateCallable', 'MarketDataLoader', 'MarketDataBulkLoader', 'Replay', 'SimpleReplay', 'ProgressReplay', 'ProgressiveReplay', 'SimMatch']
@@ -1,15 +1,96 @@
1
1
  import abc
2
2
  import datetime
3
+ import enum
4
+ import inspect
3
5
  import operator
4
6
  import warnings
5
- from collections.abc import Sequence, Mapping, Iterable
6
- from typing import Literal, Protocol, runtime_checkable
7
+ from collections.abc import Sequence, Mapping, Iterable, Callable
8
+ from typing import Literal, Protocol, runtime_checkable, get_type_hints, Self
7
9
 
8
10
  from . import LOGGER
9
11
  from ..base import MarketData, DataType, MarketDataBuffer
10
12
 
11
13
  LOGGER = LOGGER.getChild('Replay')
12
- __all__ = ['MarketDateCallable', 'MarketDataLoader', 'MarketDataBulkLoader', 'Replay', 'SimpleReplay', 'ProgressReplay', 'ProgressiveReplay']
14
+ __all__ = ['PyDataScope', 'MarketDateCallable', 'MarketDataLoader', 'MarketDataBulkLoader', 'Replay', 'SimpleReplay', 'ProgressReplay', 'ProgressiveReplay']
15
+
16
+
17
+ class PyDataScope(enum.Flag):
18
+ SCOPE_TRANSACTION = enum.auto()
19
+ SCOPE_ORDER = enum.auto()
20
+ SCOPE_TICK = enum.auto()
21
+ SCOPE_TICK_LITE = enum.auto()
22
+
23
+ SCOPE_ALL = SCOPE_TRANSACTION | SCOPE_ORDER | SCOPE_TICK
24
+
25
+ @classmethod
26
+ def _missing_(cls, value: Literal['TickData', 'TickDataLite', 'OrderData', 'TransactionData']):
27
+ if isinstance(value, int):
28
+ return super()._missing_(value)
29
+
30
+ if isinstance(value, str):
31
+ dtypes = value.split(',')
32
+ elif isinstance(value, Iterable):
33
+ dtypes = value
34
+ else:
35
+ raise TypeError(value)
36
+
37
+ _ = PyDataScope(0)
38
+ for dtype in dtypes:
39
+ _ = _.from_str(dtype)
40
+ return _
41
+
42
+ @classmethod
43
+ def get_dtype(cls, dtype: DataType | str) -> str | Literal['TickData', 'TickDataLite', 'OrderData', 'TransactionData']:
44
+ match dtype:
45
+ case 'TickData' | 'TickDataLite' | 'OrderData' | 'TransactionData':
46
+ return str(dtype)
47
+ case 'TradeData': # handle the alias
48
+ return 'TransactionData'
49
+ case DataType.DTYPE_TICK | DataType.DTYPE_ORDER | DataType.DTYPE_TRANSACTION:
50
+ return DataType(dtype).name.removeprefix('DTYPE_').capitalize() + 'Data'
51
+ case DataType.DTYPE_TICK_LITE:
52
+ return 'Data'.join(_.capitalize() for _ in DataType(dtype).name.removeprefix('DTYPE_').split('_'))
53
+ case _:
54
+ raise ValueError(f'Invalid dtype {dtype}, expect str or int.')
55
+
56
+ def __iter__(self):
57
+ return iter(self.to_dtype())
58
+
59
+ def to_dtype(self) -> list[DataType]:
60
+ scope = list(super().__iter__())
61
+ scope_dtype = set()
62
+
63
+ for dtype in scope:
64
+
65
+ if dtype is PyDataScope.SCOPE_TRANSACTION:
66
+ scope_dtype.add(DataType.DTYPE_TRANSACTION)
67
+ elif dtype is PyDataScope.SCOPE_ORDER:
68
+ scope_dtype.add(DataType.DTYPE_ORDER)
69
+ elif dtype is PyDataScope.SCOPE_TICK_LITE:
70
+ scope_dtype.add(DataType.DTYPE_TICK_LITE)
71
+ elif dtype is PyDataScope.SCOPE_TICK:
72
+ scope_dtype.add(DataType.DTYPE_TICK)
73
+
74
+ return list(scope_dtype)
75
+
76
+ def to_int(self) -> list[int]:
77
+ return [int(_) for _ in self.to_dtype()]
78
+
79
+ def to_str(self) -> list[str]:
80
+ return [self.get_dtype(_) for _ in self.to_dtype()]
81
+
82
+ def from_str(self, dtype: Literal['TickData', 'TickDataLite', 'OrderData', 'TransactionData']) -> Self:
83
+ match dtype:
84
+ case 'TickData':
85
+ return self | self.SCOPE_TICK
86
+ case 'TickDataLite':
87
+ return self | self.SCOPE_TICK_LITE
88
+ case 'OrderData':
89
+ return self | self.SCOPE_ORDER
90
+ case 'TransactionData' | 'TradeData':
91
+ return self | self.SCOPE_TRANSACTION
92
+ case _:
93
+ raise ValueError(f'Invalid str {dtype}.')
13
94
 
14
95
 
15
96
  @runtime_checkable
@@ -26,10 +107,72 @@ class MarketDataLoader(Protocol):
26
107
 
27
108
@runtime_checkable
class MarketDataBulkLoader(Protocol):
    """Loader that returns one market day of data for many tickers at once.

    The replay calls it with the keyword arguments ``market_date``, ``tickers``
    and ``dtypes``; ``dtypes`` may be a sequence of names / ``DataType`` values
    or a ``PyDataScope``.
    """

    def __call__(self, market_date: datetime.date, tickers: Sequence[str], dtypes: Sequence[str | DataType] | PyDataScope) -> Sequence[MarketData] | Mapping[float, MarketData] | MarketDataBuffer:
        """Return the requested data as a sequence, a mapping or a ``MarketDataBuffer``."""
        ...
31
112
 
32
113
 
114
def _safe_type_hints(obj) -> dict:
    """Return the resolved type hints of callable *obj*, or ``{}`` when unavailable.

    ``typing.get_type_hints`` raises ``TypeError`` for callable instances and
    ``functools.partial`` objects, and ``NameError`` for unresolvable forward
    references.  The hints are only used for advisory warnings, so use the
    type's ``__call__`` for non-routine callables and fall back to ``{}``.
    """
    target = obj if inspect.isroutine(obj) or inspect.isclass(obj) else type(obj).__call__
    try:
        return get_type_hints(target)
    except (TypeError, NameError):
        return {}


def check_protocol_signature(func: Callable, protocol: type) -> bool:
    """Check whether *func* can be called the way *protocol*'s ``__call__`` is.

    Loaders are invoked with keyword arguments only, so *func* is considered
    compatible when

    * every positional-or-keyword parameter of ``protocol.__call__`` (after
      ``self``) can be passed to *func* by keyword -- *func* declares it, or
      *func* accepts ``**kwargs``; and
    * every parameter of *func* without a default is one the protocol supplies
      (a required positional-only parameter can never be supplied by keyword).

    Parameter order is irrelevant.  Type-hint mismatches only emit warnings.

    Args:
        func: Candidate callable (function, method, callable instance, partial, ...).
        protocol: A ``Protocol`` class defining ``__call__``.

    Returns:
        ``True`` if *func* is call-compatible with *protocol*; otherwise a
        warning is emitted and ``False`` is returned.

    Raises:
        TypeError: if *func* is not callable or declares ``*args``.
    """
    if not callable(func):
        raise TypeError(f"{func} is not callable")

    # Callable instances and partials have no __name__.
    func_name = getattr(func, '__name__', repr(func))
    proto_sig = inspect.signature(protocol.__call__)
    func_sig = inspect.signature(func)

    proto_params = list(proto_sig.parameters.values())[1:]  # Skip 'self'
    func_params = list(func_sig.parameters.values())

    # *args is not allowed; **kwargs means any keyword argument is accepted.
    enable_keywords = False
    for p in func_params:
        if p.kind == inspect.Parameter.VAR_POSITIONAL:
            raise TypeError(f"{func_name} uses *args, which is not allowed")
        elif p.kind == inspect.Parameter.VAR_KEYWORD:
            enable_keywords = True

    proto_arg_names = [p.name for p in proto_params if p.kind in (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD
    )]
    func_arg_names = [p.name for p in func_params if p.kind != inspect.Parameter.VAR_KEYWORD]

    # Names *func* can receive as keyword arguments.
    keyword_names = {p.name for p in func_params if p.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY
    )}
    missing = [] if enable_keywords else [n for n in proto_arg_names if n not in keyword_names]
    # Required parameters of *func* that a keyword call per the protocol leaves unset.
    unsupplied = [
        p.name for p in func_params
        if p.kind != inspect.Parameter.VAR_KEYWORD
        and p.default is inspect.Parameter.empty
        and (p.kind == inspect.Parameter.POSITIONAL_ONLY or p.name not in proto_arg_names)
    ]

    if missing or unsupplied:
        warnings.warn(
            f"{func} argument names {func_arg_names} do not match protocol {proto_arg_names}",
            stacklevel=2
        )
        return False

    # Type hint comparison (warn if mismatched, but allow)
    proto_hints = _safe_type_hints(protocol.__call__)
    func_hints = _safe_type_hints(func)

    for pname in proto_arg_names:
        expected = proto_hints.get(pname)
        actual = func_hints.get(pname)
        if expected and actual and expected != actual:
            warnings.warn(
                f"Type hint mismatch for parameter '{pname}': expected {expected}, got {actual}",
                stacklevel=2
            )

    expected_ret = proto_hints.get("return")
    actual_ret = func_hints.get("return")
    if expected_ret and actual_ret and expected_ret != actual_ret:
        warnings.warn(
            f"Return type mismatch: expected {expected_ret}, got {actual_ret}",
            stacklevel=2
        )

    return True
174
+
175
+
33
176
  class Replay(object, metaclass=abc.ABCMeta):
34
177
  # __slots__ = ('start_date', 'end_date', 'market_date', 'calendar', 'bod', 'eod', 'subscription', '_calendar', '_market_date', '_status', '_progress')
35
178
 
@@ -49,32 +192,26 @@ class Replay(object, metaclass=abc.ABCMeta):
49
192
  if eod is not None:
50
193
  self.add_eod(eod)
51
194
 
52
- def add_bod(self, func: MarketDateCallable):
53
- self.bod.append(func)
54
-
55
- def add_eod(self, func: MarketDateCallable):
56
- self.eod.append(func)
195
+ def add_bod(self, func: MarketDateCallable, priority: int = None) -> None:
196
+ if priority is None:
197
+ self.bod.append(func)
198
+ else:
199
+ self.bod.insert(priority, func)
57
200
 
58
- @classmethod
59
- def get_dtype(cls, dtype: DataType | str) -> str | Literal['TickData', 'TickDataLite', 'OrderData', 'TransactionData']:
60
- match dtype:
61
- case 'TickData' | 'TickDataLite' | 'OrderData' | 'TransactionData':
62
- return str(dtype)
63
- case DataType.DTYPE_TICK | DataType.DTYPE_ORDER | DataType.DTYPE_TRANSACTION:
64
- return DataType(dtype).name.removeprefix('DTYPE_').capitalize() + 'Data'
65
- case DataType.DTYPE_TICK_LITE:
66
- return 'Data'.join(_.capitalize() for _ in DataType(dtype).name.removeprefix('DTYPE_').split('_'))
67
- case _:
68
- raise ValueError(f'Invalid dtype {dtype}, expect str or int.')
201
+ def add_eod(self, func: MarketDateCallable, priority: int = None):
202
+ if priority is None:
203
+ self.eod.append(func)
204
+ else:
205
+ self.eod.insert(priority, func)
69
206
 
70
207
  def add_subscription(self, ticker: str, dtype: DataType | str):
71
- dtype = self.get_dtype(dtype)
208
+ dtype = PyDataScope.get_dtype(dtype)
72
209
  topic = f'{ticker}.{dtype}'
73
210
 
74
211
  self.subscription[topic] = (ticker, dtype)
75
212
 
76
213
  def remove_subscription(self, ticker: str, dtype: DataType | str):
77
- dtype = self.get_dtype(dtype)
214
+ dtype = PyDataScope.get_dtype(dtype)
78
215
  topic = f'{ticker}.{dtype}'
79
216
 
80
217
  try:
@@ -115,27 +252,7 @@ class SimpleReplay(Replay):
115
252
  for func in self.bod:
116
253
  func(self._market_date)
117
254
 
118
- if self.loader is None:
119
- assert hasattr(self, '_buffer') and isinstance(self._buffer, Iterable), f'Without assigning a data loader, the _buffer of {self.__class__.__name__} should be set in bod process.'
120
- elif isinstance(self.loader, MarketDataLoader):
121
- md_list = []
122
- for topic, (_ticker, _dtype) in self.subscription.items():
123
- LOGGER.info(f'{self} loading {self._market_date} {_ticker} {_dtype}')
124
- data = self.loader(market_date=self._market_date, ticker=_ticker, dtype=_dtype)
125
- if isinstance(data, Mapping):
126
- md_list.extend(list(data.values()))
127
- elif isinstance(data, Sequence):
128
- md_list.extend(data)
129
- else:
130
- raise TypeError(f'The loader {self.loader} returned {type(data)}. Expect a sequence or mapping of MarketData')
131
- md_list.sort(key=operator.attrgetter('timestamp', 'ticker', '_dtype'))
132
- self._buffer = iter(md_list)
133
- self._buffer_size = len(md_list)
134
- elif isinstance(self.loader, MarketDataBulkLoader):
135
- self._buffer = self.loader(market_date=self._market_date, tickers=self.tickers, dtypes=self.dtypes)
136
- self._buffer_size = len(self._buffer)
137
- else:
138
- raise NotImplementedError()
255
+ self._safe_load()
139
256
 
140
257
  return self
141
258
 
@@ -166,12 +283,71 @@ class SimpleReplay(Replay):
166
283
  for func in self.bod:
167
284
  func(self._market_date)
168
285
 
169
- self._buffer = self.loader(market_date=self._market_date, tickers=self.tickers, dtypes=self.dtypes)
286
+ self._safe_load()
170
287
  return self.__next__()
171
288
 
172
289
  def __repr__(self):
173
290
  return f'{self.__class__.__name__}{{id={id(self)}, from={self.start_date}, to={self.end_date}}}'
174
291
 
292
+ def _bulk_load_protocol(self):
293
+ LOGGER.info(f'{self} loading {self._market_date} {(', '.join(self.dtypes)) if self.dtypes else 'data'} for {len(self.tickers)} tickers...')
294
+ buffer = self.loader(market_date=self._market_date, tickers=self.tickers, dtypes=self.dtypes)
295
+ LOGGER.info(f'{self} sorting {self._market_date} data...')
296
+ buffer.sort()
297
+
298
+ if isinstance(buffer, MarketDataBuffer):
299
+ self._buffer = buffer
300
+ self._buffer_size = len(self._buffer)
301
+ elif isinstance(buffer, Sequence):
302
+ self._buffer = iter(buffer)
303
+ self._buffer_size = len(buffer)
304
+ elif isinstance(buffer, Mapping):
305
+ self._buffer = iter(buffer.values())
306
+ self._buffer_size = len(buffer)
307
+ LOGGER.info(f'{self} {self._market_date} total {self._buffer_size:,} items loaded.')
308
+
309
+ def _individual_load_protocol(self):
310
+ buffer = []
311
+ for topic, (_ticker, _dtype) in self.subscription.items():
312
+ LOGGER.info(f'{self} loading {self._market_date} {_ticker} {_dtype}...')
313
+ data = self.loader(market_date=self._market_date, ticker=_ticker, dtype=_dtype)
314
+ if isinstance(data, Mapping):
315
+ buffer.extend(list(data.values()))
316
+ elif isinstance(data, Sequence):
317
+ buffer.extend(data)
318
+ else:
319
+ raise TypeError(f'The loader {self.loader} returned {type(data)}. Expect a sequence or mapping of MarketData')
320
+ LOGGER.info(f'{self} sorting {self._market_date} data...')
321
+ buffer.sort(key=operator.attrgetter('timestamp', 'ticker', '_dtype'))
322
+ self._buffer = iter(buffer)
323
+ self._buffer_size = len(buffer)
324
+ LOGGER.info(f'{self} {self._market_date} total {self._buffer_size:,} items loaded.')
325
+
326
+ def _safe_load(self):
327
+ if self.loader is None:
328
+ assert hasattr(self, '_buffer') and isinstance(self._buffer, Iterable), f'Without assigning a data loader, the _buffer of {self.__class__.__name__} should be set in bod process.'
329
+ return None
330
+
331
+ is_bulk_loader = check_protocol_signature(self.loader, MarketDataBulkLoader)
332
+ is_individual_loader = check_protocol_signature(self.loader, MarketDataLoader)
333
+
334
+ if (is_bulk_loader and is_individual_loader) or (not is_bulk_loader and not is_individual_loader):
335
+ try:
336
+ return self._bulk_load_protocol()
337
+ except Exception as e:
338
+ LOGGER.info('Failed to load data using MarketDataBulkLoader protocol!')
339
+
340
+ try:
341
+ return self._individual_load_protocol()
342
+ except Exception as e:
343
+ LOGGER.info('Failed to load data using MarketDataLoader protocol!')
344
+ raise
345
+
346
+ if is_bulk_loader:
347
+ return self._bulk_load_protocol()
348
+
349
+ return self._individual_load_protocol()
350
+
175
351
  @property
176
352
  def progress(self) -> float:
177
353
  if not hasattr(self, '_buffer'):
@@ -231,29 +407,26 @@ class ProgressReplay(SimpleReplay):
231
407
  'miniters': 0.001,
232
408
  **tqdm_kwargs
233
409
  }
234
- self.add_bod(self._update_progress_bar)
410
+ self._pbar = None
411
+ self.add_bod(self._update_progress_bar, priority=0)
235
412
 
236
413
  def __iter__(self):
237
414
  from tqdm.auto import tqdm
238
415
  self._pbar = tqdm(**self._tqdm_kwargs)
239
- iterator = super().__iter__()
416
+ return super().__iter__()
240
417
 
418
+ def __next__(self) -> MarketData:
241
419
  try:
242
- while True:
243
- try:
244
- result = next(iterator)
245
- if self._pbar:
246
- self._pbar.update(self.progress)
247
- yield result
248
- except StopIteration:
249
- break
250
- finally:
420
+ result = super().__next__()
421
+ if self._pbar:
422
+ self._pbar.update(self.progress)
423
+ self._pbar.refresh()
424
+ return result
425
+ except StopIteration:
251
426
  if self._pbar is not None:
252
427
  self._pbar.close()
253
428
  self._pbar = None
254
-
255
- def __next__(self) -> MarketData:
256
- raise RuntimeError("MarketDataBufferReplay should be used as an iterator context")
429
+ raise
257
430
 
258
431
  def _update_progress_bar(self, market_date: datetime.date):
259
432
  if self._pbar:
@@ -289,7 +462,7 @@ class ProgressiveReplay(SimpleReplay):
289
462
  eod: MarketDateCallable = None,
290
463
  **progress_config
291
464
  ) -> None:
292
- warnings.deprecated('User ProgressReplay instead!')
465
+ warnings.warn('User ProgressReplay instead!', DeprecationWarning, stacklevel=2)
293
466
  self.loader = loader
294
467
  super().__init__(loader=loader, market_date=market_date, start_date=start_date, end_date=end_date, calendar=calendar, bod=bod, eod=eod)
295
468
 
@@ -321,6 +494,8 @@ class ProgressiveReplay(SimpleReplay):
321
494
  tasks=1,
322
495
  **progress_config
323
496
  )
497
+ self._pbar = None
498
+ self.add_bod(self._update_progress_bar, priority=0)
324
499
 
325
500
  def __iter__(self):
326
501
  from ..base import Progress
@@ -330,16 +505,17 @@ class ProgressiveReplay(SimpleReplay):
330
505
  def __next__(self) -> MarketData:
331
506
  try:
332
507
  result = super().__next__()
333
- self._pbar.done_tasks = self.progress
508
+ if self._pbar:
509
+ self._pbar.done_tasks = self.progress
334
510
 
335
- if (not self._pbar.tick_size) \
336
- or self._pbar.progress >= self._pbar.tick_size + self._pbar.last_output \
337
- or self._pbar.is_done:
338
- self._pbar.output()
511
+ if (not self._pbar.tick_size) \
512
+ or self._pbar.progress >= self._pbar.tick_size + self._pbar.last_output \
513
+ or self._pbar.is_done:
514
+ self._pbar.output()
339
515
 
340
516
  return result
341
517
  except StopIteration:
342
- if not self._pbar.is_done:
518
+ if self._pbar is not None and not self._pbar.is_done:
343
519
  self.progress.done_tasks = 1
344
520
  self._pbar.output()
345
521
  raise
@@ -1,19 +0,0 @@
1
- import logging
2
-
3
- from .. import LOGGER
4
-
5
- LOGGER = LOGGER.getChild('BackTest')
6
-
7
-
8
- def set_logger(logger: logging.Logger):
9
- global LOGGER
10
- LOGGER = logger
11
-
12
- replay.LOGGER = LOGGER.getChild('Replay')
13
- sim_match.LOGGER = LOGGER.getChild('SimMatch')
14
-
15
-
16
- from .replay import Replay, SimpleReplay, ProgressiveReplay
17
- from .sim_match import SimMatch
18
-
19
- __all__ = ['Replay', 'SimpleReplay', 'ProgressiveReplay', 'SimMatch']
File without changes