PyAlgoEngine 0.8.0a12__tar.gz → 0.8.0a14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57):
  1. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/PKG-INFO +1 -1
  2. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/PyAlgoEngine.egg-info/PKG-INFO +1 -1
  3. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/__init__.py +1 -1
  4. pyalgoengine-0.8.0a14/algo_engine/backtest/replay.py +350 -0
  5. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/base/market_data_buffer.pyi +2 -0
  6. pyalgoengine-0.8.0a12/algo_engine/backtest/replay.py +0 -292
  7. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/LICENSE +0 -0
  8. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/PyAlgoEngine.egg-info/SOURCES.txt +0 -0
  9. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/PyAlgoEngine.egg-info/dependency_links.txt +0 -0
  10. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/PyAlgoEngine.egg-info/requires.txt +0 -0
  11. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/PyAlgoEngine.egg-info/top_level.txt +0 -0
  12. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/README.md +0 -0
  13. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/apps/__init__.py +0 -0
  14. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/apps/backtest/__init__.py +0 -0
  15. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/apps/backtest/doc_server.py +0 -0
  16. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/apps/backtest/tester.py +0 -0
  17. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/apps/backtest/web_app.py +0 -0
  18. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/apps/bokeh_server.py +0 -0
  19. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/apps/demo/__init__.py +0 -0
  20. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/apps/demo/test.py +0 -0
  21. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/apps/sim_input/__init__.py +0 -0
  22. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/apps/sim_input/client.py +0 -0
  23. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/apps/sim_input/sim_keyboard.py +0 -0
  24. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/apps/sim_input/sim_mouse.py +0 -0
  25. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/apps/sim_input/window.py +0 -0
  26. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/backtest/__init__.py +0 -0
  27. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/backtest/__main__.py +0 -0
  28. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/backtest/metrics.py +0 -0
  29. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/backtest/sim_match.py +0 -0
  30. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/base/__init__.py +0 -0
  31. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/base/candlestick.pyi +0 -0
  32. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/base/console_utils.py +0 -0
  33. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/base/finance_decimal.py +0 -0
  34. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/base/market_data.pyi +0 -0
  35. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/base/market_utils_nt.py +0 -0
  36. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/base/market_utils_posix.py +0 -0
  37. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/base/technical_analysis.py +0 -0
  38. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/base/telemetrics.py +0 -0
  39. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/base/tick.pyi +0 -0
  40. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/base/trade_utils.pyi +0 -0
  41. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/base/transaction.pyi +0 -0
  42. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/engine/__init__.py +0 -0
  43. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/engine/algo_engine.py +0 -0
  44. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/engine/event_engine.py +0 -0
  45. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/engine/market_engine.py +0 -0
  46. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/engine/trade_engine.py +0 -0
  47. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/monitor/__init__.py +0 -0
  48. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/monitor/advanced_data_interface.py +0 -0
  49. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/profile/__init__.py +0 -0
  50. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/profile/cn.py +0 -0
  51. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/strategy/__init__.py +0 -0
  52. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/strategy/strategy_engine.py +0 -0
  53. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/utils/__init__.py +0 -0
  54. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/utils/commit_regularizer.py +0 -0
  55. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/algo_engine/utils/data_utils.py +0 -0
  56. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/setup.cfg +0 -0
  57. {pyalgoengine-0.8.0a12 → pyalgoengine-0.8.0a14}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PyAlgoEngine
3
- Version: 0.8.0a12
3
+ Version: 0.8.0a14
4
4
  Summary: Basic algo engine
5
5
  Home-page: https://github.com/BolunHan/PyAlgoEngine
6
6
  Author: Bolun.Han
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PyAlgoEngine
3
- Version: 0.8.0a12
3
+ Version: 0.8.0a14
4
4
  Summary: Basic algo engine
5
5
  Home-page: https://github.com/BolunHan/PyAlgoEngine
6
6
  Author: Bolun.Han
@@ -1,4 +1,4 @@
1
- __version__ = "0.8.0.alpha12"
1
+ __version__ = "0.8.0.alpha14"
2
2
 
3
3
  import logging
4
4
  import os
@@ -0,0 +1,350 @@
1
+ import abc
2
+ import datetime
3
+ import operator
4
+ import warnings
5
+ from collections.abc import Sequence, Mapping, Iterable
6
+ from typing import Literal, Protocol, runtime_checkable
7
+
8
+ from . import LOGGER
9
+ from ..base import MarketData, DataType, MarketDataBuffer
10
+
11
# Child logger of the package-level LOGGER imported from `.` at the top of this module.
LOGGER = LOGGER.getChild('Replay')
# Names exported by a star-import of this module.
__all__ = ['MarketDateCallable', 'MarketDataLoader', 'MarketDataBulkLoader', 'Replay', 'SimpleReplay', 'ProgressReplay', 'ProgressiveReplay']
13
+
14
+
15
@runtime_checkable
class MarketDateCallable(Protocol):
    """Structural type of the begin-of-day / end-of-day hooks.

    Any callable taking a single ``market_date`` argument fits; the replay
    classes call the hook with the date and ignore whatever it returns.
    """

    def __call__(self, market_date: datetime.date) -> None:
        """React to the replay entering or leaving ``market_date``."""
19
+
20
+
21
@runtime_checkable
class MarketDataLoader(Protocol):
    """Structural type of a per-subscription data loader.

    The loader is called once per subscribed ``(ticker, dtype)`` pair with
    keyword arguments ``market_date``, ``ticker`` and ``dtype``. It returns that
    day's data either as a sequence of market data or as a mapping whose
    values are the market data entries.

    Note: as a ``runtime_checkable`` protocol, ``isinstance(obj,
    MarketDataLoader)`` only verifies that ``obj`` has a ``__call__``
    attribute; the call signature is not checked.
    """

    def __call__(self, market_date: datetime.date, ticker: str, dtype: str | DataType) -> Sequence[MarketData] | Mapping[float, MarketData]:
        """Return the ``dtype`` data of ``ticker`` on ``market_date``."""
25
+
26
+
27
@runtime_checkable
class MarketDataBulkLoader(Protocol):
    """Structural type of a bulk data loader.

    The loader is called once per market date with keyword arguments
    ``market_date``, ``tickers`` and ``dtypes``. It returns the whole day's
    data as a sequence, as a mapping whose values are market data, or as a
    ``MarketDataBuffer``.

    Note: ``isinstance`` checks against this protocol only test for a
    ``__call__`` attribute, so they cannot tell it apart from
    ``MarketDataLoader``.
    """

    def __call__(self, market_date: datetime.date, tickers: Sequence[str], dtypes: Sequence[str | DataType]) -> Sequence[MarketData] | Mapping[float, MarketData] | MarketDataBuffer:
        """Return the data of every ticker/dtype combination on ``market_date``."""
31
+
32
+
33
+ class Replay(object, metaclass=abc.ABCMeta):
34
+ # __slots__ = ('start_date', 'end_date', 'market_date', 'calendar', 'bod', 'eod', 'subscription', '_calendar', '_market_date', '_status', '_progress')
35
+
36
+ def __init__(self, start_date: datetime.date = None, end_date: datetime.date = None, market_date: datetime.date = None, calendar: Sequence[datetime.date] = None, bod: MarketDateCallable = None, eod: MarketDateCallable = None) -> None:
37
+ self.start_date = start_date or market_date or calendar[0]
38
+ self.end_date = end_date or calendar[-1]
39
+ self.market_date = market_date or start_date
40
+ self.calendar = calendar or []
41
+
42
+ self.bod = []
43
+ self.eod = []
44
+ self.subscription = {}
45
+
46
+ if bod is not None:
47
+ self.add_bod(bod)
48
+
49
+ if eod is not None:
50
+ self.add_eod(eod)
51
+
52
+ def add_bod(self, func: MarketDateCallable):
53
+ self.bod.append(func)
54
+
55
+ def add_eod(self, func: MarketDateCallable):
56
+ self.eod.append(func)
57
+
58
+ @classmethod
59
+ def get_dtype(cls, dtype: DataType | str) -> str | Literal['TickData', 'TickDataLite', 'OrderData', 'TransactionData']:
60
+ match dtype:
61
+ case 'TickData' | 'TickDataLite' | 'OrderData' | 'TransactionData':
62
+ return str(dtype)
63
+ case DataType.DTYPE_TICK | DataType.DTYPE_ORDER | DataType.DTYPE_TRANSACTION:
64
+ return DataType(dtype).name.removeprefix('DTYPE_').capitalize() + 'Data'
65
+ case DataType.DTYPE_TICK_LITE:
66
+ return 'Data'.join(_.capitalize() for _ in DataType(dtype).name.removeprefix('DTYPE_').split('_'))
67
+ case _:
68
+ raise ValueError(f'Invalid dtype {dtype}, expect str or int.')
69
+
70
+ def add_subscription(self, ticker: str, dtype: DataType | str):
71
+ dtype = self.get_dtype(dtype)
72
+ topic = f'{ticker}.{dtype}'
73
+
74
+ self.subscription[topic] = (ticker, dtype)
75
+
76
+ def remove_subscription(self, ticker: str, dtype: DataType | str):
77
+ dtype = self.get_dtype(dtype)
78
+ topic = f'{ticker}.{dtype}'
79
+
80
+ try:
81
+ self.subscription.pop(topic)
82
+ except KeyError as _:
83
+ LOGGER.info(f'{topic} not in {self.subscription}')
84
+
85
+ @abc.abstractmethod
86
+ def __next__(self):
87
+ ...
88
+
89
+ @abc.abstractmethod
90
+ def __iter__(self):
91
+ ...
92
+
93
+
94
class SimpleReplay(Replay):
    """Replay market data day by day through a data loader.

    For every calendar date from ``market_date`` onwards the replay:

    1. runs the begin-of-day hooks;
    2. loads the data of all subscriptions for that date;
    3. yields the entries one by one;
    4. runs the end-of-day hooks.

    ``loader`` may follow either protocol:

    * per-subscription (``MarketDataLoader``): called once per subscribed
      ``(ticker, dtype)`` with ``market_date``, ``ticker`` and ``dtype``; the
      results are merged and sorted by ``(timestamp, ticker, _dtype)``.
    * bulk (``MarketDataBulkLoader``): called once per day with
      ``market_date``, ``tickers`` and ``dtypes``; the result is replayed in
      the order it is returned.

    Both protocols only declare ``__call__``, so ``isinstance`` cannot tell
    them apart; the loader's signature is inspected instead.

    With ``loader=None`` a begin-of-day hook must assign an iterable of
    market data to ``self._buffer`` for every replayed day.
    """

    def __init__(
            self,
            loader: MarketDataBulkLoader | MarketDataLoader = None,
            market_date: datetime.date = None,
            start_date: datetime.date = None,
            end_date: datetime.date = None,
            calendar: Sequence[datetime.date] = None,
            bod: MarketDateCallable = None,
            eod: MarketDateCallable = None
    ):
        super().__init__(market_date=market_date, start_date=start_date, end_date=end_date, calendar=calendar, bod=bod, eod=eod)
        self.loader = loader

    @staticmethod
    def _is_bulk_loader(loader) -> bool:
        """Return True when ``loader`` takes ``tickers`` (bulk protocol) rather than ``ticker``."""
        import inspect

        try:
            params = inspect.signature(loader).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: fall back to the per-subscription protocol.
            return False
        return 'tickers' in params and 'ticker' not in params

    def __iter__(self):
        """Reset the replay state, start the first day to replay and return ``self``."""
        if self.calendar:
            # Private sorted copy: the replay must never mutate the caller's calendar.
            self._calendar = sorted(self.calendar)
        else:
            n_days = (self.end_date - self.start_date).days + 1
            self._calendar = [self.start_date + datetime.timedelta(days=i) for i in range(n_days)]

        # Days before the resume date are skipped; the calendar is sorted, so the
        # count of earlier days is also the index of the first day to replay.
        resume_date = self.market_date or self.start_date
        self._status = {market_date: 'skipped' if market_date < resume_date else 'idle' for market_date in self._calendar}
        self._idx_date = sum(1 for market_date in self._calendar if market_date < resume_date)
        self._idx_buffer = 0
        self._buffer_size = 0

        if self._idx_date < len(self._calendar):
            self._start_day()

        return self

    def __next__(self) -> MarketData:
        """Return the next entry, rolling over to the following day(s) when needed.

        Raises:
            StopIteration: Once every remaining calendar day has been replayed.
        """
        if not hasattr(self, '_calendar'):
            # Not started, or already exhausted and cleaned up.
            raise StopIteration()

        # A loop (not recursion) so long runs of empty days cannot exhaust the stack.
        while True:
            if self._idx_buffer < self._buffer_size:
                self._idx_buffer += 1
                return next(self._buffer)

            if self._idx_date < len(self._calendar):
                self._end_day()

            if self._idx_date >= len(self._calendar):
                self._cleanup()
                raise StopIteration()

            self._start_day()

    def _start_day(self) -> None:
        """Enter ``_calendar[_idx_date]``: mark it started, run the bod hooks, load its data."""
        self._market_date = self._calendar[self._idx_date]
        self._status[self._market_date] = 'started'

        for func in self.bod:
            func(self._market_date)

        self._load(self._market_date)

    def _end_day(self) -> None:
        """Leave the current day: run the eod hooks, mark it done and advance the date index."""
        for func in self.eod:
            func(self._market_date)

        self._status[self._market_date] = 'done'
        self._idx_buffer = 0
        self._buffer_size = 0
        self._idx_date += 1

        if self.loader is None:
            # Without a loader the next day's bod hook must supply a fresh buffer.
            self._buffer = None

    def _load(self, market_date: datetime.date) -> None:
        """Fill ``_buffer`` / ``_buffer_size`` with the data of ``market_date``.

        Raises:
            RuntimeError: ``loader`` is None and no bod hook assigned ``_buffer``.
            TypeError: A per-subscription loader returned neither a Sequence nor a Mapping.
        """
        if self.loader is None:
            data = getattr(self, '_buffer', None)
            if not isinstance(data, Iterable):
                raise RuntimeError(f'Without assigning a data loader, the _buffer of {self.__class__.__name__} should be set in bod process.')
        elif self._is_bulk_loader(self.loader):
            data = self.loader(market_date=market_date, tickers=self.tickers, dtypes=self.dtypes)
            if isinstance(data, Mapping):
                data = list(data.values())
        else:
            md_list = []
            for _ticker, _dtype in self.subscription.values():
                LOGGER.info(f'{self} loading {market_date} {_ticker} {_dtype}')
                data = self.loader(market_date=market_date, ticker=_ticker, dtype=_dtype)
                if isinstance(data, Mapping):
                    md_list.extend(data.values())
                elif isinstance(data, Sequence):
                    md_list.extend(data)
                else:
                    raise TypeError(f'The loader {self.loader} returned {type(data)}. Expect a sequence or mapping of MarketData')
            md_list.sort(key=operator.attrgetter('timestamp', 'ticker', '_dtype'))
            data = md_list

        try:
            size = len(data)
        except TypeError:
            # Unsized iterables (e.g. generators) are materialised so progress can be computed.
            data = list(data)
            size = len(data)

        self._buffer = iter(data)
        self._buffer_size = size

    def _cleanup(self) -> None:
        """Drop the per-run state so that a later ``iter()`` starts from scratch."""
        for name in ('_calendar', '_market_date', '_status', '_idx_buffer', '_idx_date', '_buffer', '_buffer_size'):
            self.__dict__.pop(name, None)

    def __repr__(self):
        return f'{self.__class__.__name__}{{id={id(self)}, from={self.start_date}, to={self.end_date}}}'

    @property
    def progress(self) -> float:
        """Replayed fraction of the calendar in ``[0, 1]``, skipped days included.

        Raises:
            RuntimeError: If the replay is not running.
        """
        if not hasattr(self, '_buffer'):
            raise RuntimeError(f'{self.__class__.__name__} not started yet.')

        n_days = len(self._calendar)
        if not n_days:
            return 1.0

        intraday = self._idx_buffer / self._buffer_size if self._buffer_size else 0.0
        return (self._idx_date + intraday) / n_days

    @property
    def tickers(self) -> list[str]:
        """Distinct subscribed tickers, in subscription order."""
        return list(dict.fromkeys(ticker for ticker, _ in self.subscription.values()))

    @property
    def dtypes(self) -> list[str]:
        """Distinct subscribed dtypes, in subscription order."""
        return list(dict.fromkeys(dtype for _, dtype in self.subscription.values()))

    @property
    def status(self) -> dict[datetime.date, str]:
        """Per-date state: 'skipped', 'idle', 'started' or 'done'.

        Raises:
            RuntimeError: If the replay is not running.
        """
        if not hasattr(self, '_status'):
            raise RuntimeError(f'{self.__class__.__name__} not started yet.')

        return self._status
202
+
203
+
204
class ProgressReplay(SimpleReplay):
    """``SimpleReplay`` that displays its progress with a ``tqdm`` bar.

    The bar tracks ``self.progress``, the replayed fraction of the calendar,
    which is why the default ``total`` is 1. Its description is refreshed at
    the beginning of every day. Consume the replay with ``for md in replay``;
    calling ``next()`` on the instance directly is not supported.

    Extra keyword arguments are forwarded to ``tqdm`` and override the
    defaults below.
    """

    def __init__(
            self,
            loader: MarketDataBulkLoader | MarketDataLoader = None,
            market_date: datetime.date = None,
            start_date: datetime.date = None,
            end_date: datetime.date = None,
            calendar: Sequence[datetime.date] = None,
            bod: MarketDateCallable = None,
            eod: MarketDateCallable = None,
            **tqdm_kwargs
    ):
        super().__init__(
            loader=loader,
            market_date=market_date,
            start_date=start_date,
            end_date=end_date,
            calendar=calendar,
            bod=bod,
            eod=eod
        )

        self._tqdm_kwargs = {
            'total': 1,
            'unit_scale': True,
            'unit': 'percent',
            'mininterval': 0.1,
            'miniters': 0.001,
            **tqdm_kwargs
        }
        # The bar only exists while an iteration is running.
        self._pbar = None
        self.add_bod(self._update_progress_bar)

    def __iter__(self):
        """Yield the replayed market data while advancing a fresh progress bar.

        The bar is always closed, whether the replay completes, fails, or the
        consumer stops early.
        """
        from tqdm.auto import tqdm
        self._pbar = tqdm(**self._tqdm_kwargs)

        try:
            super().__iter__()

            while True:
                try:
                    # Call SimpleReplay.__next__ explicitly: next(self) would dispatch to the
                    # overriding __next__ below, which always raises.
                    result = super().__next__()
                except StopIteration:
                    break

                # tqdm.update() takes an increment while self.progress is absolute.
                self._pbar.update(self.progress - self._pbar.n)
                yield result

            # Replay completed normally: fill the bar.
            if self._pbar.total:
                self._pbar.update(self._pbar.total - self._pbar.n)
        finally:
            if self._pbar is not None:
                self._pbar.close()
                self._pbar = None

    def __next__(self) -> MarketData:
        raise RuntimeError(f'{self.__class__.__name__} should be used as an iterator context')

    def _update_progress_bar(self, market_date: datetime.date):
        """Begin-of-day hook: show the replayed date and day counter in the bar description."""
        # Compare with None instead of relying on tqdm.__bool__, which raises when total is None.
        pbar = getattr(self, '_pbar', None)
        if pbar is not None:
            pbar.set_description(f'Replay {market_date:%Y-%m-%d} ({self._idx_date + 1} / {len(self._calendar)})')
            pbar.refresh()
262
+
263
+
264
class ProgressiveReplay(SimpleReplay):
    """
    Deprecated: use ``ProgressReplay`` instead.

    Progressively load and replay market data, reporting progress with the
    console ``Progress`` bar.

    Required arguments:
        loader: per-subscription loading function,
            ``loader(market_date=..., ticker=..., dtype=...) -> Sequence[MarketData] | Mapping[float, MarketData]``.
        start_date & end_date: the replay period,
        or calendar: the replay calendar.

    Optional arguments:
        tickers: the symbol(s) to replay, a ``str`` or a sequence of ``str``.
        dtypes: the data type(s) to replay, a ``str`` / ``DataType`` or a sequence of them.
            Defaults to ``['TransactionData', 'TickData', 'OrderData']``.
        market_date, bod, eod: see ``SimpleReplay``.
        **progress_config: forwarded to ``Progress``; ``tasks`` defaults to 1.
    """

    def __init__(
            self,
            loader: MarketDataLoader,
            tickers: str | Sequence[str] = None,
            dtypes: str | DataType | Sequence[str] | Sequence[DataType] = None,
            market_date: datetime.date = None,
            start_date: datetime.date = None,
            end_date: datetime.date = None,
            calendar: Sequence[datetime.date] = None,
            bod: MarketDateCallable = None,
            eod: MarketDateCallable = None,
            **progress_config
    ) -> None:
        # warnings.deprecated is a decorator (3.13+) and emits nothing when merely called.
        warnings.warn('Use ProgressReplay instead!', DeprecationWarning, stacklevel=2)
        super().__init__(loader=loader, market_date=market_date, start_date=start_date, end_date=end_date, calendar=calendar, bod=bod, eod=eod)

        tickers = tickers or []
        dtypes = dtypes or ['TransactionData', 'TickData', 'OrderData']

        # NOTE: runtime Protocol checks only verify that the loader is callable.
        if not isinstance(loader, MarketDataLoader):
            raise TypeError('loader function has 3 requires args, market_date, ticker and dtype.')

        if isinstance(tickers, str):
            tickers = [tickers]
        elif isinstance(tickers, Iterable):
            tickers = list(tickers)
        else:
            raise TypeError(f'Invalid ticker {tickers}, expect str or list[str]')

        if isinstance(dtypes, (str, int, DataType)):
            dtypes = [dtypes]
        elif isinstance(dtypes, Iterable):
            dtypes = list(dtypes)
        else:
            raise TypeError(f'Invalid dtype {dtypes}, expect str or list[str]')

        for ticker in tickers:
            for dtype in dtypes:
                self.add_subscription(ticker=ticker, dtype=dtype)

        # Let callers override ``tasks`` instead of hitting a duplicate-keyword TypeError.
        self.progress_config = {'tasks': 1, **progress_config}
        # The Progress bar is created per iteration in __iter__.
        self._pbar = None
        # Show the replayed date in the bar at the beginning of every day.
        self.add_bod(self._update_progress_bar)

    def __iter__(self):
        """Create a fresh ``Progress`` bar and start the replay."""
        from ..base import Progress
        self._pbar = Progress(**self.progress_config)
        return super().__iter__()

    def __next__(self) -> MarketData:
        """Return the next entry and refresh the progress bar when it advanced by a tick."""
        try:
            result = super().__next__()
        except StopIteration:
            if self._pbar is not None and not self._pbar.is_done:
                self._pbar.done_tasks = 1
                self._pbar.output()
            raise

        if self._pbar is not None:
            self._pbar.done_tasks = self.progress

            if (not self._pbar.tick_size) \
                    or self._pbar.progress >= self._pbar.tick_size + self._pbar.last_output \
                    or self._pbar.is_done:
                self._pbar.output()

        return result

    def _update_progress_bar(self, market_date: datetime.date):
        """Begin-of-day hook: show the replayed date and day counter as the bar prompt."""
        pbar = getattr(self, '_pbar', None)
        if pbar is not None:
            pbar.prompt = f'Replay {market_date:%Y-%m-%d} ({self._idx_date + 1} / {len(self._calendar)}):'
            pbar.output()
@@ -33,6 +33,8 @@ class MarketDataBuffer:
33
33
 
34
34
    def update(self, dtype: int, **kwargs: dict[str, Any]) -> None: ...

    # Indexed access returning a single buffered entry (stub added in 0.8.0a14).
    # NOTE(review): bounds / negative-index behaviour lives in the compiled implementation — confirm there.
    def __getitem__(self, idx: int) -> MarketData | TransactionData | OrderData | TickDataLite | TickData | BarData: ...

    # Typed to return a MarketDataBuffer; presumably the buffer acts as its own iterator — confirm.
    def __iter__(self) -> MarketDataBuffer: ...

    def __next__(self) -> MarketData | TransactionData | OrderData | TickDataLite | TickData | BarData: ...
@@ -1,292 +0,0 @@
1
- import abc
2
- import datetime
3
- import inspect
4
- import operator
5
- from collections.abc import Mapping, Sequence, Iterator
6
- from typing import Iterable, Protocol
7
-
8
- from . import LOGGER
9
- from ..base import Progress, TickData, TransactionData, TradeData, OrderData, MarketData, MarketDataBuffer
10
-
11
# Child logger of the package-level LOGGER imported from `.` at the top of this module.
LOGGER = LOGGER.getChild('Replay')
12
-
13
-
14
class Replay(object, metaclass=abc.ABCMeta):
    """Iterator interface shared by every replay implementation."""

    @abc.abstractmethod
    def __next__(self):
        """Produce the next replayed entry or raise StopIteration."""

    @abc.abstractmethod
    def __iter__(self):
        """Prepare the replay and return the iterator."""
20
-
21
-
22
class SimpleReplay(Replay):
    """Replay an in-memory list of market data with console progress output.

    Entries queued through ``load`` are replayed in load order. When the
    market date of the next entry differs from the current one, ``eod`` is
    called for the previous date and ``bod`` for the new one. Once the queue
    is exhausted, ``eod`` is also called for the last date.

    Keyword arguments:
        bod / eod: optional callables receiving the market date.
        all other keyword arguments are forwarded to ``Progress``.
    """

    def __init__(self, **kwargs):
        self.eod = kwargs.pop('eod', None)
        self.bod = kwargs.pop('bod', None)

        self.replay_task = []
        self.task_progress = 0
        self.task_date = None
        self.progress = Progress(tasks=1, **kwargs)

    def load(self, data):
        """Queue ``data``: a mapping (its values are queued) or an iterable of market data."""
        if isinstance(data, Mapping):
            self.replay_task.extend(data.values())
        else:
            self.replay_task.extend(data)

    def reset(self):
        """Clear the queue and rewind the replay and progress state."""
        self.replay_task.clear()
        self.task_progress = 0
        self.task_date = None
        self.progress.reset()

    def next_task(self):
        """Return the next queued entry, firing the eod/bod hooks on date changes.

        Raises:
            StopIteration: When the queue is exhausted.
        """
        if self.task_progress >= len(self.replay_task):
            raise StopIteration()

        market_data = self.replay_task[self.task_progress]
        market_time = market_data.market_time
        # market_time may be a datetime or already a date.
        market_date = market_time.date() if isinstance(market_time, datetime.datetime) else market_time

        if market_date != self.task_date:
            if callable(self.eod) and self.task_date:
                self.eod(self.task_date)

            self.task_date = market_date
            self.progress.prompt = f'Replay {market_date:%Y-%m-%d}:'

            if callable(self.bod):
                self.bod(market_date)

        self.progress.done_tasks = self.task_progress / len(self.replay_task)

        if (not self.progress.tick_size) or self.progress.progress >= self.progress.tick_size + self.progress.last_output:
            self.progress.output()

        self.task_progress += 1
        return market_data

    def __next__(self):
        try:
            return self.next_task()
        except StopIteration:
            # next_task only closes a day when the date changes, so close the final day here.
            if callable(self.eod) and self.task_date:
                self.eod(self.task_date)

            if not self.progress.is_done:
                self.progress.done_tasks = 1
                self.progress.output()

            self.reset()
            raise

    def __iter__(self):
        return self
88
-
89
-
90
class DataLoader(Protocol):
    """Structural type of a loader returning one day's data for one ticker/dtype pair."""

    def __call__(self, market_date: datetime.date, ticker: str, dtype: str) -> Mapping[float, MarketData] | Sequence[MarketData] | MarketDataBuffer:
        """Load the ``dtype`` data of ``ticker`` on ``market_date``."""
93
-
94
-
95
class ProgressiveReplay(Replay):
    """
    Progressively load and replay market data, one calendar day at a time.

    Required arguments:
        loader: data loading function,
            ``loader(market_date: datetime.date, ticker: str, dtype: str) -> Mapping[float, MarketData] | Sequence[MarketData] | MarketDataBuffer``.
        start_date & end_date: the replay period,
        or calendar: the replay calendar.

    Accepted kwargs:
        market_date: date to start (or resume) replaying from.
        bod / eod: hooks called as ``hook(market_date=..., replay=self)``.
        ticker / tickers: the symbol(s) to replay, a ``str`` or ``list[str]``.
        dtype / dtypes: the data type(s) to replay, a ``str`` / class or a list of them.
            Defaults to (TradeData, TransactionData, OrderData, TickData).
        subscription / subscribe: ticker-dtype pairs to replay, a ``dict`` or ``list[dict]``
            with ``ticker`` and ``dtype`` keys.
        any remaining kwargs are forwarded to ``Progress``.
    """

    # Ordering used within a subscription and to interleave entries of different subscriptions.
    _SORT_KEY = operator.attrgetter('timestamp', 'ticker', '__class__.__name__')

    def __init__(
            self,
            loader: DataLoader,
            **kwargs
    ):
        self.loader = loader
        self.market_date: datetime.date | None = kwargs.pop('market_date', None)
        self.start_date: datetime.date | None = kwargs.pop('start_date', None)
        self.end_date: datetime.date | None = kwargs.pop('end_date', None)
        self.calendar: list[datetime.date] | None = kwargs.pop('calendar', None)

        self.eod = kwargs.pop('eod', None)
        self.bod = kwargs.pop('bod', None)

        # Pop every replay-specific argument *before* the remainder is forwarded to Progress.
        tickers = kwargs.pop('ticker', kwargs.pop('tickers', []))
        dtypes = kwargs.pop('dtype', kwargs.pop('dtypes', [TradeData, TransactionData, OrderData, TickData]))
        subscription = kwargs.pop('subscription', kwargs.pop('subscribe', []))

        self.replay_subscription = {}
        self.replay_calendar = []
        # An exhausted iterator (not None) so the first next_task() rolls into the first day.
        self.replay_task: Iterator = iter(())
        self.replay_task_length: int = 0
        self.replay_status = {}

        self.date_progress = 0
        self.task_progress = 0
        self.progress = Progress(tasks=1, **kwargs)

        if not all(arg_name in inspect.getfullargspec(loader).args for arg_name in ('market_date', 'ticker', 'dtype')):
            raise TypeError('loader function has 3 requires args, market_date, ticker and dtype.')

        if isinstance(tickers, str):
            tickers = [tickers]
        elif isinstance(tickers, Iterable):
            tickers = list(tickers)
        else:
            raise TypeError(f'Invalid ticker {tickers}, expect str or list[str]')

        if isinstance(dtypes, str) or inspect.isclass(dtypes):
            dtypes = [dtypes]
        elif isinstance(dtypes, Iterable):
            dtypes = list(dtypes)
        else:
            raise TypeError(f'Invalid dtype {dtypes}, expect str or list[str]')

        for ticker in tickers:
            for dtype in dtypes:
                self.add_subscription(ticker=ticker, dtype=dtype)

        if isinstance(subscription, dict):
            subscription = [subscription]

        for sub in subscription:
            self.add_subscription(**sub)

        self.reset()

    def add_subscription(self, ticker: str, dtype: type | str):
        """Subscribe to ``dtype`` (a class or its name) data of ``ticker``.

        Raises:
            ValueError: If ``dtype`` is neither a ``str`` nor a class.
        """
        if isinstance(dtype, str):
            pass
        elif inspect.isclass(dtype):
            dtype = dtype.__name__
        else:
            raise ValueError(f'Invalid dtype {dtype}, expect str or class.')

        topic = f'{ticker}.{dtype}'
        self.replay_subscription[topic] = (ticker, dtype)

    def remove_subscription(self, ticker: str, dtype: type | str):
        """Unsubscribe ``ticker`` / ``dtype``; unknown topics are ignored."""
        # Same normalisation as add_subscription, so both resolve to the same topic.
        if inspect.isclass(dtype):
            dtype = dtype.__name__

        topic = f'{ticker}.{dtype}'
        self.replay_subscription.pop(topic, None)

    def reset(self):
        """Rebuild the replay calendar and rewind to ``market_date``."""
        if self.calendar is None:
            self.replay_calendar = [self.start_date + datetime.timedelta(days=i) for i in range((self.end_date - self.start_date).days + 1)]
        else:
            self.replay_calendar = self.calendar

        if self.market_date is None:
            self.market_date = self.replay_calendar[0] if self.replay_calendar else self.start_date
        else:
            date_to_replay = [_ for _ in self.replay_calendar if _ >= self.market_date]
            self.market_date = date_to_replay[0] if date_to_replay else self.end_date

        self.replay_status = {market_date: 'skipped' if market_date < self.market_date else 'idle' for market_date in self.replay_calendar}

        self.task_progress = 0
        self.replay_task_length = 0
        self.replay_task = iter(())
        self.date_progress = sum(1 for _ in self.replay_calendar if _ < self.market_date)
        self.progress.reset()

        if self.date_progress:
            self.progress.done_tasks = self.date_progress / len(self.replay_calendar)

    def _load_subscription(self, market_date: datetime.date, ticker: str, dtype: str):
        """Load one subscription's data for ``market_date``, ordered for replay.

        Raises:
            TypeError: If the loader returns an unsupported type.
        """
        LOGGER.info(f'{self} loading {market_date} {ticker} {dtype}...')
        data = self.loader(market_date=market_date, ticker=ticker, dtype=dtype)

        if isinstance(data, Mapping):
            return [data[ts] for ts in sorted(data)]  # expect to be a mapping of ts and data
        if isinstance(data, Sequence):
            return sorted(data, key=self._SORT_KEY)
        if isinstance(data, MarketDataBuffer):
            data.sort()
            return data
        raise TypeError(f'Invalid return type of dataloader, expect list, tuple, dict or MarketDataBuffer, got {type(data)}.')

    def next_trade_day(self):
        """Advance to the next calendar day and load the data of every subscription.

        Raises:
            StopIteration: When the calendar is exhausted.
        """
        if self.date_progress >= len(self.replay_calendar):
            raise StopIteration()

        self.market_date = market_date = self.replay_calendar[self.date_progress]
        self.replay_status[market_date] = 'started'
        self.progress.prompt = f'Replay {market_date:%Y-%m-%d} ({self.date_progress + 1} / {len(self.replay_calendar)}):'

        chunks = [self._load_subscription(market_date, ticker, dtype) for ticker, dtype in self.replay_subscription.values()]

        if len(chunks) == 1:
            data = chunks[0]
        else:
            # Several subscriptions: merge them into one time-ordered stream
            # instead of letting each subscription overwrite the previous one.
            data = sorted((md for chunk in chunks for md in chunk), key=self._SORT_KEY)

        self.replay_task = iter(data)
        self.replay_task_length = len(data)

        LOGGER.info(f'{market_date} data loaded! {self.replay_task_length:,} entries.')
        self.date_progress += 1

    def next_task(self):
        """Return the next entry, rolling over to the next trade day(s) when needed.

        Raises:
            StopIteration: When the calendar is exhausted.
        """
        while True:
            try:
                data = next(self.replay_task)
            except StopIteration:
                # Close the current day; .get() because market_date may fall outside the calendar.
                if self.replay_status.get(self.market_date) == 'started':
                    if self.eod is not None:
                        self.eod(market_date=self.market_date, replay=self)
                    self.replay_status[self.market_date] = 'done'

                self.replay_task = iter(())
                self.task_progress = 0

                if self.bod is not None and self.date_progress < len(self.replay_calendar):
                    self.bod(market_date=self.replay_calendar[self.date_progress], replay=self)

                # By design, the new data is loaded only after the bod hook is done.
                self.next_trade_day()
                continue

            self.task_progress += 1
            break

        if self.replay_task_length and self.replay_calendar:
            current_progress = (self.date_progress - 1 + (self.task_progress / self.replay_task_length)) / len(self.replay_calendar)
            self.progress.done_tasks = current_progress
        else:
            self.progress.done_tasks = 1

        if (not self.progress.tick_size) \
                or self.progress.progress >= self.progress.tick_size + self.progress.last_output \
                or self.progress.is_done:
            self.progress.output()

        return data

    def __next__(self) -> MarketData:
        try:
            return self.next_task()
        except StopIteration:
            if not self.progress.is_done:
                self.progress.done_tasks = 1
                self.progress.output()

            self.reset()
            raise

    def __iter__(self):
        self.reset()
        return self

    def __repr__(self):
        return f'{self.__class__.__name__}{{id={id(self)}, from={self.start_date}, to={self.end_date}}}'
File without changes