bbstrader-0.3.5-py3-none-any.whl → bbstrader-0.3.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bbstrader has been flagged by the registry as potentially problematic; see the registry's advisory page for more details.

Files changed (45):
  1. bbstrader/__init__.py +11 -2
  2. bbstrader/__main__.py +6 -1
  3. bbstrader/apps/_copier.py +43 -40
  4. bbstrader/btengine/backtest.py +33 -28
  5. bbstrader/btengine/data.py +105 -81
  6. bbstrader/btengine/event.py +21 -22
  7. bbstrader/btengine/execution.py +51 -24
  8. bbstrader/btengine/performance.py +23 -12
  9. bbstrader/btengine/portfolio.py +40 -30
  10. bbstrader/btengine/scripts.py +13 -12
  11. bbstrader/btengine/strategy.py +396 -134
  12. bbstrader/compat.py +4 -3
  13. bbstrader/config.py +20 -36
  14. bbstrader/core/data.py +76 -48
  15. bbstrader/core/scripts.py +22 -21
  16. bbstrader/core/utils.py +13 -12
  17. bbstrader/metatrader/account.py +51 -26
  18. bbstrader/metatrader/analysis.py +30 -16
  19. bbstrader/metatrader/copier.py +75 -40
  20. bbstrader/metatrader/trade.py +29 -39
  21. bbstrader/metatrader/utils.py +5 -4
  22. bbstrader/models/nlp.py +83 -66
  23. bbstrader/trading/execution.py +45 -22
  24. bbstrader/tseries.py +158 -166
  25. {bbstrader-0.3.5.dist-info → bbstrader-0.3.7.dist-info}/METADATA +7 -21
  26. bbstrader-0.3.7.dist-info/RECORD +62 -0
  27. bbstrader-0.3.7.dist-info/top_level.txt +3 -0
  28. docs/conf.py +56 -0
  29. tests/__init__.py +0 -0
  30. tests/engine/__init__.py +1 -0
  31. tests/engine/test_backtest.py +58 -0
  32. tests/engine/test_data.py +536 -0
  33. tests/engine/test_events.py +300 -0
  34. tests/engine/test_execution.py +219 -0
  35. tests/engine/test_portfolio.py +308 -0
  36. tests/metatrader/__init__.py +0 -0
  37. tests/metatrader/test_account.py +1769 -0
  38. tests/metatrader/test_rates.py +292 -0
  39. tests/metatrader/test_risk_management.py +700 -0
  40. tests/metatrader/test_trade.py +439 -0
  41. bbstrader-0.3.5.dist-info/RECORD +0 -49
  42. bbstrader-0.3.5.dist-info/top_level.txt +0 -1
  43. {bbstrader-0.3.5.dist-info → bbstrader-0.3.7.dist-info}/WHEEL +0 -0
  44. {bbstrader-0.3.5.dist-info → bbstrader-0.3.7.dist-info}/entry_points.txt +0 -0
  45. {bbstrader-0.3.5.dist-info → bbstrader-0.3.7.dist-info}/licenses/LICENSE +0 -0
@@ -1,15 +1,16 @@
1
1
  import os.path
2
2
  from abc import ABCMeta, abstractmethod
3
3
  from datetime import datetime
4
+ from pathlib import Path
4
5
  from queue import Queue
5
- from typing import Dict, List
6
+ from typing import Any, Dict, Generator, List, Optional, Tuple, Union
6
7
 
7
8
  import numpy as np
8
- from numpy.typing import NDArray
9
9
  import pandas as pd
10
10
  import yfinance as yf
11
11
  from eodhd import APIClient
12
12
  from financetoolkit import Toolkit
13
+ from numpy.typing import NDArray
13
14
  from pytz import timezone
14
15
 
15
16
  from bbstrader.btengine.event import MarketEvent
@@ -59,54 +60,56 @@ class DataHandler(metaclass=ABCMeta):
59
60
  pass
60
61
 
61
62
  @property
62
- def index(self) -> str | List[str]:
63
+ def index(self) -> Union[str, List[str]]:
63
64
  pass
64
65
 
65
66
  @abstractmethod
66
- def get_latest_bar(self, symbol) -> pd.Series:
67
+ def get_latest_bar(self, symbol: str) -> pd.Series:
67
68
  """
68
69
  Returns the last bar updated.
69
70
  """
70
- pass
71
+ raise NotImplementedError("Should implement get_latest_bar()")
71
72
 
72
73
  @abstractmethod
73
- def get_latest_bars(self, symbol, N=1, df=True) -> pd.DataFrame | List[pd.Series]:
74
+ def get_latest_bars(
75
+ self, symbol: str, N: int = 1, df: bool = True
76
+ ) -> Union[pd.DataFrame, List[pd.Series]]:
74
77
  """
75
78
  Returns the last N bars updated.
76
79
  """
77
- pass
80
+ raise NotImplementedError("Should implement get_latest_bars()")
78
81
 
79
82
  @abstractmethod
80
- def get_latest_bar_datetime(self, symbol) -> datetime | pd.Timestamp:
83
+ def get_latest_bar_datetime(self, symbol: str) -> Union[datetime, pd.Timestamp]:
81
84
  """
82
85
  Returns a Python datetime object for the last bar.
83
86
  """
84
- pass
87
+ raise NotImplementedError("Should implement get_latest_bar_datetime()")
85
88
 
86
89
  @abstractmethod
87
- def get_latest_bar_value(self, symbol, val_type) -> float:
90
+ def get_latest_bar_value(self, symbol: str, val_type: str) -> float:
88
91
  """
89
92
  Returns one of the Open, High, Low, Close, Adj Close, Volume or Returns
90
93
  from the last bar.
91
94
  """
92
- pass
95
+ raise NotImplementedError("Should implement get_latest_bar_value()")
93
96
 
94
97
  @abstractmethod
95
- def get_latest_bars_values(self, symbol, val_type, N=1) -> NDArray:
98
+ def get_latest_bars_values(self, symbol: str, val_type: str, N: int = 1) -> NDArray:
96
99
  """
97
100
  Returns the last N bar values from the
98
101
  latest_symbol list, or N-k if less available.
99
102
  """
100
- pass
103
+ raise NotImplementedError("Should implement get_latest_bars_values()")
101
104
 
102
105
  @abstractmethod
103
- def update_bars(self):
106
+ def update_bars(self) -> None:
104
107
  """
105
108
  Pushes the latest bars to the bars_queue for each symbol
106
109
  in a tuple OHLCVI format: (datetime, Open, High, Low,
107
110
  Close, Adj Close, Volume, Retruns).
108
111
  """
109
- pass
112
+ raise NotImplementedError("Should implement update_bars()")
110
113
 
111
114
 
112
115
  class BaseCSVDataHandler(DataHandler):
@@ -117,12 +120,12 @@ class BaseCSVDataHandler(DataHandler):
117
120
 
118
121
  def __init__(
119
122
  self,
120
- events: Queue,
123
+ events: "Queue[MarketEvent]",
121
124
  symbol_list: List[str],
122
125
  csv_dir: str,
123
- columns: List[str] = None,
124
- index_col: str | int | List[str] | List[int] = 0,
125
- ):
126
+ columns: Optional[List[str]] = None,
127
+ index_col: Union[str, int, List[str], List[int]] = 0,
128
+ ) -> None:
126
129
  """
127
130
  Initialises the data handler by requesting the location of the CSV files
128
131
  and a list of symbols.
@@ -139,10 +142,10 @@ class BaseCSVDataHandler(DataHandler):
139
142
  self.csv_dir = csv_dir
140
143
  self.columns = columns
141
144
  self.index_col = index_col
142
- self.symbol_data = {}
143
- self.latest_symbol_data = {}
145
+ self.symbol_data: Dict[str, Union[pd.DataFrame, Generator]] = {}
146
+ self.latest_symbol_data: Dict[str, List[Any]] = {}
144
147
  self.continue_backtest = True
145
- self._index = None
148
+ self._index: Optional[Union[str, List[str]]] = None
146
149
  self._load_and_process_data()
147
150
 
148
151
  @property
@@ -151,7 +154,7 @@ class BaseCSVDataHandler(DataHandler):
151
154
 
152
155
  @property
153
156
  def data(self) -> Dict[str, pd.DataFrame]:
154
- return self.symbol_data
157
+ return self.symbol_data # type: ignore
155
158
 
156
159
  @property
157
160
  def datadir(self) -> str:
@@ -159,13 +162,13 @@ class BaseCSVDataHandler(DataHandler):
159
162
 
160
163
  @property
161
164
  def labels(self) -> List[str]:
162
- return self.columns
165
+ return self.columns # type: ignore
163
166
 
164
167
  @property
165
- def index(self) -> str | List[str]:
166
- return self._index
168
+ def index(self) -> Union[str, List[str]]:
169
+ return self._index # type: ignore
167
170
 
168
- def _load_and_process_data(self):
171
+ def _load_and_process_data(self) -> None:
169
172
  """
170
173
  Opens the CSV files from the data directory, converting
171
174
  them into pandas DataFrames within a symbol dictionary.
@@ -201,29 +204,29 @@ class BaseCSVDataHandler(DataHandler):
201
204
 
202
205
  # Reindex the dataframes
203
206
  for s in self.symbol_list:
204
- self.symbol_data[s] = self.symbol_data[s].reindex(
207
+ self.symbol_data[s] = self.symbol_data[s].reindex( # type: ignore
205
208
  index=comb_index, method="pad"
206
209
  )
207
210
  if "adj_close" not in new_names:
208
211
  self.columns.append("adj_close")
209
- self.symbol_data[s]["adj_close"] = self.symbol_data[s]["close"]
210
- self.symbol_data[s]["returns"] = (
211
- self.symbol_data[s][
212
+ self.symbol_data[s]["adj_close"] = self.symbol_data[s]["close"] # type: ignore
213
+ self.symbol_data[s]["returns"] = ( # type: ignore
214
+ self.symbol_data[s][ # type: ignore
212
215
  "adj_close" if "adj_close" in new_names else "close"
213
216
  ]
214
217
  .pct_change()
215
218
  .dropna()
216
219
  )
217
- self._index = self.symbol_data[s].index.name
218
- self.symbol_data[s].to_csv(os.path.join(self.csv_dir, f"{s}.csv"))
220
+ self._index = self.symbol_data[s].index.name # type: ignore
221
+ self.symbol_data[s].to_csv(os.path.join(self.csv_dir, f"{s}.csv")) # type: ignore
219
222
  if self.events is not None:
220
- self.symbol_data[s] = self.symbol_data[s].iterrows()
223
+ self.symbol_data[s] = self.symbol_data[s].iterrows() # type: ignore
221
224
 
222
- def _get_new_bar(self, symbol: str):
225
+ def _get_new_bar(self, symbol: str) -> Generator[Tuple[Any, Any], Any, None]:
223
226
  """
224
227
  Returns the latest bar from the data feed.
225
228
  """
226
- for b in self.symbol_data[symbol]:
229
+ for b in self.symbol_data[symbol]: # type: ignore
227
230
  yield b
228
231
 
229
232
  def get_latest_bar(self, symbol: str) -> pd.Series:
@@ -239,8 +242,8 @@ class BaseCSVDataHandler(DataHandler):
239
242
  return bars_list[-1]
240
243
 
241
244
  def get_latest_bars(
242
- self, symbol: str, N=1, df=True
243
- ) -> pd.DataFrame | List[pd.Series]:
245
+ self, symbol: str, N: int = 1, df: bool = True
246
+ ) -> Union[pd.DataFrame, List[pd.Series]]:
244
247
  """
245
248
  Returns the last N bars from the latest_symbol list,
246
249
  or N-k if less available.
@@ -252,12 +255,12 @@ class BaseCSVDataHandler(DataHandler):
252
255
  raise
253
256
  else:
254
257
  if df:
255
- df = pd.DataFrame([bar[1] for bar in bars_list[-N:]])
256
- df.index.name = self._index
257
- return df
258
+ df_ = pd.DataFrame([bar[1] for bar in bars_list[-N:]])
259
+ df_.index.name = self._index # type: ignore
260
+ return df_
258
261
  return bars_list[-N:]
259
262
 
260
- def get_latest_bar_datetime(self, symbol: str) -> datetime | pd.Timestamp:
263
+ def get_latest_bar_datetime(self, symbol: str) -> Union[datetime, pd.Timestamp]:
261
264
  """
262
265
  Returns a Python datetime object for the last bar.
263
266
  """
@@ -270,18 +273,18 @@ class BaseCSVDataHandler(DataHandler):
270
273
  return bars_list[-1][0]
271
274
 
272
275
  def get_latest_bars_datetime(
273
- self, symbol: str, N=1
274
- ) -> List[datetime | pd.Timestamp]:
276
+ self, symbol: str, N: int = 1
277
+ ) -> List[Union[datetime, pd.Timestamp]]:
275
278
  """
276
279
  Returns a list of Python datetime objects for the last N bars.
277
280
  """
278
281
  try:
279
- bars_list = self.get_latest_bars(symbol, N)
282
+ bars_list = self.get_latest_bars(symbol, N) # type: ignore
280
283
  except KeyError:
281
284
  print(f"{symbol} not available in the historical data set for .")
282
285
  raise
283
286
  else:
284
- return [b[0] for b in bars_list]
287
+ return [b[0] for b in bars_list] # type: ignore
285
288
 
286
289
  def get_latest_bar_value(self, symbol: str, val_type: str) -> float:
287
290
  """
@@ -302,7 +305,7 @@ class BaseCSVDataHandler(DataHandler):
302
305
  )
303
306
  raise
304
307
 
305
- def get_latest_bars_values(self, symbol: str, val_type: str, N=1) -> NDArray:
308
+ def get_latest_bars_values(self, symbol: str, val_type: str, N: int = 1) -> NDArray:
306
309
  """
307
310
  Returns the last N bar values from the
308
311
  latest_symbol list, or N-k if less available.
@@ -321,7 +324,7 @@ class BaseCSVDataHandler(DataHandler):
321
324
  )
322
325
  raise
323
326
 
324
- def update_bars(self):
327
+ def update_bars(self) -> None:
325
328
  """
326
329
  Pushes the latest bar to the latest_symbol_data structure
327
330
  for all symbols in the symbol list.
@@ -348,7 +351,9 @@ class CSVDataHandler(BaseCSVDataHandler):
348
351
  to cutomize specific data in some form based on your `Strategy()` .
349
352
  """
350
353
 
351
- def __init__(self, events: Queue, symbol_list: List[str], **kwargs):
354
+ def __init__(
355
+ self, events: "Queue[MarketEvent]", symbol_list: List[str], **kwargs: Any
356
+ ) -> None:
352
357
  """
353
358
  Initialises the historic data handler by requesting
354
359
  the location of the CSV files and a list of symbols.
@@ -369,7 +374,7 @@ class CSVDataHandler(BaseCSVDataHandler):
369
374
  super().__init__(
370
375
  events,
371
376
  symbol_list,
372
- csv_dir,
377
+ str(csv_dir),
373
378
  columns=kwargs.get("columns"),
374
379
  index_col=kwargs.get("index_col", 0),
375
380
  )
@@ -388,7 +393,9 @@ class MT5DataHandler(BaseCSVDataHandler):
388
393
  for different time frames.
389
394
  """
390
395
 
391
- def __init__(self, events: Queue, symbol_list: List[str], **kwargs):
396
+ def __init__(
397
+ self, events: "Queue[MarketEvent]", symbol_list: List[str], **kwargs: Any
398
+ ) -> None:
392
399
  """
393
400
  Args:
394
401
  events (Queue): The Event Queue for passing market events.
@@ -414,19 +421,23 @@ class MT5DataHandler(BaseCSVDataHandler):
414
421
  self.data_dir = kwargs.get("data_dir")
415
422
  self.symbol_list = symbol_list
416
423
  self.kwargs = kwargs
417
- self.kwargs["backtest"] = True # Ensure backtest mode is set to avoid InvalidBroker errors
424
+ self.kwargs["backtest"] = (
425
+ True # Ensure backtest mode is set to avoid InvalidBroker errors
426
+ )
418
427
 
419
428
  csv_dir = self._download_and_cache_data(self.data_dir)
420
429
  super().__init__(
421
430
  events,
422
431
  symbol_list,
423
- csv_dir,
432
+ str(csv_dir),
424
433
  columns=kwargs.get("columns"),
425
434
  index_col=kwargs.get("index_col", 0),
426
435
  )
427
436
 
428
- def _download_and_cache_data(self, cache_dir: str):
429
- data_dir = cache_dir or BBSTRADER_DIR / "data" / "mt5" / self.tf
437
+ def _download_and_cache_data(self, cache_dir: Optional[str]) -> Path:
438
+ data_dir = (
439
+ Path(cache_dir) if cache_dir else BBSTRADER_DIR / "data" / "mt5" / self.tf
440
+ )
430
441
  data_dir.mkdir(parents=True, exist_ok=True)
431
442
  for symbol in self.symbol_list:
432
443
  try:
@@ -461,7 +472,9 @@ class YFDataHandler(BaseCSVDataHandler):
461
472
  This class is useful when working with historical daily prices.
462
473
  """
463
474
 
464
- def __init__(self, events: Queue, symbol_list: List[str], **kwargs):
475
+ def __init__(
476
+ self, events: "Queue[MarketEvent]", symbol_list: List[str], **kwargs: Any
477
+ ) -> None:
465
478
  """
466
479
  Args:
467
480
  events (Queue): The Event Queue for passing market events.
@@ -483,17 +496,17 @@ class YFDataHandler(BaseCSVDataHandler):
483
496
  super().__init__(
484
497
  events,
485
498
  symbol_list,
486
- csv_dir,
499
+ str(csv_dir),
487
500
  columns=kwargs.get("columns"),
488
501
  index_col=kwargs.get("index_col", 0),
489
502
  )
490
503
 
491
- def _download_and_cache_data(self, cache_dir: str):
504
+ def _download_and_cache_data(self, cache_dir: Optional[str]) -> str:
492
505
  """Downloads and caches historical data as CSV files."""
493
- cache_dir = cache_dir or BBSTRADER_DIR / "data" / "yfinance" / "daily"
494
- os.makedirs(cache_dir, exist_ok=True)
506
+ _cache_dir = cache_dir or BBSTRADER_DIR / "data" / "yfinance" / "daily"
507
+ os.makedirs(_cache_dir, exist_ok=True)
495
508
  for symbol in self.symbol_list:
496
- filepath = os.path.join(cache_dir, f"{symbol}.csv")
509
+ filepath = os.path.join(_cache_dir, f"{symbol}.csv")
497
510
  try:
498
511
  data = yf.download(
499
512
  symbol,
@@ -510,7 +523,7 @@ class YFDataHandler(BaseCSVDataHandler):
510
523
  data.to_csv(filepath)
511
524
  except Exception as e:
512
525
  raise ValueError(f"Error downloading {symbol}: {e}")
513
- return cache_dir
526
+ return str(_cache_dir)
514
527
 
515
528
 
516
529
  class EODHDataHandler(BaseCSVDataHandler):
@@ -522,7 +535,9 @@ class EODHDataHandler(BaseCSVDataHandler):
522
535
  https://eodhistoricaldata.com/ and provide the key as an argument.
523
536
  """
524
537
 
525
- def __init__(self, events: Queue, symbol_list: List[str], **kwargs):
538
+ def __init__(
539
+ self, events: "Queue[MarketEvent]", symbol_list: List[str], **kwargs: Any
540
+ ) -> None:
526
541
  """
527
542
  Args:
528
543
  events (Queue): The Event Queue for passing market events.
@@ -548,12 +563,14 @@ class EODHDataHandler(BaseCSVDataHandler):
548
563
  super().__init__(
549
564
  events,
550
565
  symbol_list,
551
- csv_dir,
566
+ str(csv_dir),
552
567
  columns=kwargs.get("columns"),
553
568
  index_col=kwargs.get("index_col", 0),
554
569
  )
555
570
 
556
- def _get_data(self, symbol: str, period) -> pd.DataFrame | List[Dict]:
571
+ def _get_data(
572
+ self, symbol: str, period: str
573
+ ) -> Union[pd.DataFrame, List[Dict[str, Any]]]:
557
574
  if not self.__api_key:
558
575
  raise ValueError("API key is required for EODHD data.")
559
576
  client = APIClient(api_key=self.__api_key)
@@ -567,8 +584,8 @@ class EODHDataHandler(BaseCSVDataHandler):
567
584
  elif period in ["1m", "5m", "1h"]:
568
585
  hms = " 00:00:00"
569
586
  fmt = "%Y-%m-%d %H:%M:%S"
570
- startdt = datetime.strptime(self.start_date + hms, fmt)
571
- enddt = datetime.strptime(self.end_date + hms, fmt)
587
+ startdt = datetime.strptime(str(self.start_date) + hms, fmt)
588
+ enddt = datetime.strptime(str(self.end_date) + hms, fmt)
572
589
  startdt = startdt.replace(tzinfo=timezone("UTC"))
573
590
  enddt = enddt.replace(tzinfo=timezone("UTC"))
574
591
  unix_start = int(startdt.timestamp())
@@ -579,8 +596,11 @@ class EODHDataHandler(BaseCSVDataHandler):
579
596
  from_unix_time=unix_start,
580
597
  to_unix_time=unix_end,
581
598
  )
599
+ raise ValueError(f"Unsupported period: {period}")
582
600
 
583
- def _format_data(self, data: List[Dict] | pd.DataFrame) -> pd.DataFrame:
601
+ def _format_data(
602
+ self, data: Union[List[Dict[str, Any]], pd.DataFrame]
603
+ ) -> pd.DataFrame:
584
604
  if isinstance(data, pd.DataFrame):
585
605
  if data.empty or len(data) == 0:
586
606
  raise ValueError("No data found.")
@@ -599,20 +619,21 @@ class EODHDataHandler(BaseCSVDataHandler):
599
619
  df.date = pd.to_datetime(df.date)
600
620
  df = df.set_index("date")
601
621
  return df
622
+ raise TypeError(f"Unsupported data type: {type(data)}")
602
623
 
603
- def _download_and_cache_data(self, cache_dir: str):
624
+ def _download_and_cache_data(self, cache_dir: Optional[str]) -> str:
604
625
  """Downloads and caches historical data as CSV files."""
605
- cache_dir = cache_dir or BBSTRADER_DIR / "data" / "eodhd" / self.period
606
- os.makedirs(cache_dir, exist_ok=True)
626
+ _cache_dir = cache_dir or BBSTRADER_DIR / "data" / "eodhd" / self.period
627
+ os.makedirs(_cache_dir, exist_ok=True)
607
628
  for symbol in self.symbol_list:
608
- filepath = os.path.join(cache_dir, f"{symbol}.csv")
629
+ filepath = os.path.join(_cache_dir, f"{symbol}.csv")
609
630
  try:
610
631
  data = self._get_data(symbol, self.period)
611
632
  data = self._format_data(data)
612
633
  data.to_csv(filepath)
613
634
  except Exception as e:
614
635
  raise ValueError(f"Error downloading {symbol}: {e}")
615
- return cache_dir
636
+ return str(_cache_dir)
616
637
 
617
638
 
618
639
  class FMPDataHandler(BaseCSVDataHandler):
@@ -626,7 +647,9 @@ class FMPDataHandler(BaseCSVDataHandler):
626
647
 
627
648
  """
628
649
 
629
- def __init__(self, events: Queue, symbol_list: List[str], **kwargs):
650
+ def __init__(
651
+ self, events: "Queue[MarketEvent]", symbol_list: List[str], **kwargs: Any
652
+ ) -> None:
630
653
  """
631
654
  Args:
632
655
  events (Queue): The Event Queue for passing market events.
@@ -653,7 +676,7 @@ class FMPDataHandler(BaseCSVDataHandler):
653
676
  super().__init__(
654
677
  events,
655
678
  symbol_list,
656
- csv_dir,
679
+ str(csv_dir),
657
680
  columns=kwargs.get("columns"),
658
681
  index_col=kwargs.get("index_col", 0),
659
682
  )
@@ -673,6 +696,7 @@ class FMPDataHandler(BaseCSVDataHandler):
673
696
  return toolkit.get_historical_data(period=period, progress_bar=False)
674
697
  elif period in ["1min", "5min", "15min", "30min", "1hour"]:
675
698
  return toolkit.get_intraday_data(period=period, progress_bar=False)
699
+ raise ValueError(f"Unsupported period: {period}")
676
700
 
677
701
  def _format_data(self, data: pd.DataFrame, period: str) -> pd.DataFrame:
678
702
  if data.empty or len(data) == 0:
@@ -696,24 +720,24 @@ class FMPDataHandler(BaseCSVDataHandler):
696
720
  data = data.reset_index()
697
721
  if "Adj Close" not in data.columns:
698
722
  data["Adj Close"] = data["Close"]
699
- data["date"] = data["date"].dt.to_timestamp()
723
+ data["date"] = data["date"].dt.to_timestamp() # type: ignore
700
724
  data["date"] = pd.to_datetime(data["date"])
701
725
  data.set_index("date", inplace=True)
702
726
  return data
703
727
 
704
- def _download_and_cache_data(self, cache_dir: str):
728
+ def _download_and_cache_data(self, cache_dir: Optional[str]) -> str:
705
729
  """Downloads and caches historical data as CSV files."""
706
- cache_dir = cache_dir or BBSTRADER_DIR / "data" / "fmp" / self.period
707
- os.makedirs(cache_dir, exist_ok=True)
730
+ _cache_dir = cache_dir or BBSTRADER_DIR / "data" / "fmp" / self.period
731
+ os.makedirs(_cache_dir, exist_ok=True)
708
732
  for symbol in self.symbol_list:
709
- filepath = os.path.join(cache_dir, f"{symbol}.csv")
733
+ filepath = os.path.join(_cache_dir, f"{symbol}.csv")
710
734
  try:
711
735
  data = self._get_data(symbol, self.period)
712
736
  data = self._format_data(data, self.period)
713
737
  data.to_csv(filepath)
714
738
  except Exception as e:
715
739
  raise ValueError(f"Error downloading {symbol}: {e}")
716
- return cache_dir
740
+ return str(_cache_dir)
717
741
 
718
742
 
719
743
  # TODO Add data Handlers for Interactive Brokers
@@ -1,11 +1,11 @@
1
1
  from datetime import datetime
2
2
  from enum import Enum
3
- from typing import Literal
3
+ from typing import Literal, Optional, Union
4
4
 
5
- __all__ = ["Event", "Events", "MarketEvent", "SignalEvent", "OrderEvent", "FillEvent"]
5
+ __all__ = ["Event", "Events", "MarketEvent", "SignalEvent", "OrderEvent", "FillEvent"]
6
6
 
7
7
 
8
- class Event(object):
8
+ class Event:
9
9
  """
10
10
  Event is base class providing an interface for all subsequent
11
11
  (inherited) events, that will trigger further events in the
@@ -36,7 +36,7 @@ class MarketEvent(Event):
36
36
  that it is a market event, with no other structure.
37
37
  """
38
38
 
39
- def __init__(self):
39
+ def __init__(self) -> None:
40
40
  """
41
41
  Initialises the MarketEvent.
42
42
  """
@@ -59,11 +59,11 @@ class SignalEvent(Event):
59
59
  symbol: str,
60
60
  datetime: datetime,
61
61
  signal_type: Literal["LONG", "SHORT", "EXIT"],
62
- quantity: int | float = 100,
63
- strength: int | float = 1.0,
64
- price: int | float = None,
65
- stoplimit: int | float = None,
66
- ):
62
+ quantity: Union[int, float] = 100,
63
+ strength: Union[int, float] = 1.0,
64
+ price: Optional[Union[int, float]] = None,
65
+ stoplimit: Optional[Union[int, float]] = None,
66
+ ) -> None:
67
67
  """
68
68
  Initialises the SignalEvent.
69
69
 
@@ -108,11 +108,11 @@ class OrderEvent(Event):
108
108
  self,
109
109
  symbol: str,
110
110
  order_type: Literal["MKT", "LMT", "STP", "STPLMT"],
111
- quantity: int | float,
111
+ quantity: Union[int, float],
112
112
  direction: Literal["BUY", "SELL"],
113
- price: int | float = None,
114
- signal: str = None,
115
- ):
113
+ price: Optional[Union[int, float]] = None,
114
+ signal: Optional[str] = None,
115
+ ) -> None:
116
116
  """
117
117
  Initialises the order type, setting whether it is
118
118
  a Market order ('MKT') or Limit order ('LMT'), or Stop order ('STP').
@@ -134,7 +134,7 @@ class OrderEvent(Event):
134
134
  self.price = price
135
135
  self.signal = signal
136
136
 
137
- def print_order(self):
137
+ def print_order(self) -> None:
138
138
  """
139
139
  Outputs the values within the Order.
140
140
  """
@@ -150,7 +150,6 @@ class OrderEvent(Event):
150
150
  )
151
151
 
152
152
 
153
-
154
153
  class FillEvent(Event):
155
154
  """
156
155
  When an `ExecutionHandler` receives an `OrderEvent` it must transact the order.
@@ -175,12 +174,12 @@ class FillEvent(Event):
175
174
  timeindex: datetime,
176
175
  symbol: str,
177
176
  exchange: str,
178
- quantity: int | float,
177
+ quantity: Union[int, float],
179
178
  direction: Literal["BUY", "SELL"],
180
- fill_cost: int | float | None,
181
- commission: float | None = None,
182
- order: str = None,
183
- ):
179
+ fill_cost: Optional[Union[int, float]],
180
+ commission: Optional[float] = None,
181
+ order: Optional[str] = None,
182
+ ) -> None:
184
183
  """
185
184
  Initialises the FillEvent object. Sets the symbol, exchange,
186
185
  quantity, direction, cost of fill and an optional
@@ -209,12 +208,12 @@ class FillEvent(Event):
209
208
  self.fill_cost = fill_cost
210
209
  # Calculate commission
211
210
  if commission is None:
212
- self.commission = self.calculate_ib_commission()
211
+ self.commission: float = self.calculate_ib_commission()
213
212
  else:
214
213
  self.commission = commission
215
214
  self.order = order
216
215
 
217
- def calculate_ib_commission(self):
216
+ def calculate_ib_commission(self) -> float:
218
217
  """
219
218
  Calculates the fees of trading based on an Interactive
220
219
  Brokers fee structure for API, in USD.