prosperity3bt 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prosperity3bt/data.py CHANGED
@@ -49,6 +49,18 @@ def get_column_values(columns: list[str], indices: list[int]) -> list[int]:
49
49
  return values
50
50
 
51
51
 
52
@dataclass
class ObservationRow:
    """One data row of a per-day observations CSV.

    Instances are built by ``read_observations``, which fills the fields
    positionally from columns 0-7 of each data line (in the order declared
    below), and ``create_backtest_data`` stores them in
    ``BacktestData.observations`` keyed by ``timestamp``.

    NOTE(review): field names are camelCase rather than PEP 8 snake_case,
    presumably to mirror the CSV header / downstream consumers — confirm
    before renaming, since the field order and names form the constructor API.
    """

    timestamp: int  # column 0; used as the dict key in BacktestData.observations
    bidPrice: float  # column 1
    askPrice: float  # column 2
    transportFees: float  # column 3
    exportTariff: float  # column 4
    importTariff: float  # column 5
    sugarPrice: float  # column 6
    sunlightIndex: float  # column 7
62
+
63
+
52
64
  @dataclass
53
65
  class BacktestData:
54
66
  round_num: int
@@ -56,11 +68,14 @@ class BacktestData:
56
68
 
57
69
  prices: dict[int, dict[Symbol, PriceRow]]
58
70
  trades: dict[int, dict[Symbol, list[Trade]]]
71
+ observations: dict[int, ObservationRow]
59
72
  products: list[Symbol]
60
73
  profit_loss: dict[Symbol, float]
61
74
 
62
75
 
63
- def create_backtest_data(round_num: int, day_num: int, prices: list[PriceRow], trades: list[Trade]) -> BacktestData:
76
+ def create_backtest_data(
77
+ round_num: int, day_num: int, prices: list[PriceRow], trades: list[Trade], observations: list[ObservationRow]
78
+ ) -> BacktestData:
64
79
  prices_by_timestamp: dict[int, dict[Symbol, PriceRow]] = defaultdict(dict)
65
80
  for row in prices:
66
81
  prices_by_timestamp[row.timestamp][row.product] = row
@@ -72,11 +87,14 @@ def create_backtest_data(round_num: int, day_num: int, prices: list[PriceRow], t
72
87
  products = sorted(set(row.product for row in prices))
73
88
  profit_loss = {product: 0.0 for product in products}
74
89
 
90
+ observations_by_timestamp = {row.timestamp: row for row in observations}
91
+
75
92
  return BacktestData(
76
93
  round_num=round_num,
77
94
  day_num=day_num,
78
95
  prices=prices_by_timestamp,
79
96
  trades=trades_by_timestamp,
97
+ observations=observations_by_timestamp,
80
98
  products=products,
81
99
  profit_loss=profit_loss,
82
100
  )
@@ -87,6 +105,31 @@ def has_day_data(file_reader: FileReader, round_num: int, day_num: int) -> bool:
87
105
  return file is not None
88
106
 
89
107
 
108
def read_observations(file_reader: FileReader, round_num: int, day_num: int) -> list[ObservationRow]:
    """Read the observations CSV for one round/day.

    The file is requested from ``file_reader`` as
    ``round{round_num}/observations_round_{round_num}_day_{day_num}.csv``.
    The first line is treated as a header and skipped. Every other non-blank
    line is split on commas and read positionally as timestamp, bidPrice,
    askPrice, transportFees, exportTariff, importTariff, sugarPrice and
    sunlightIndex.

    Args:
        file_reader: Source of data files. ``file_reader.file(...)`` is used
            as a context manager that yields either ``None`` (presumably the
            file is absent — confirm against FileReader) or an object with
            ``read_text(encoding=...)``.
        round_num: Round number used to build the directory and file name.
        day_num: Day number used to build the file name.

    Returns:
        The parsed rows in file order, or an empty list when
        ``file_reader.file`` yields ``None``.

    Raises:
        ValueError: If a value cannot be converted to ``int``/``float``.
        IndexError: If a non-blank data line has fewer than eight columns.
    """
    with file_reader.file([f"round{round_num}", f"observations_round_{round_num}_day_{day_num}.csv"]) as file:
        if file is None:
            return []

        # Read the whole text while the file context is still open; parsing
        # can then happen independently of the reader's resource lifetime.
        lines = file.read_text(encoding="utf-8").splitlines()

    observations: list[ObservationRow] = []
    for line in lines[1:]:
        # Blank lines (e.g. extra empty lines at the end of the file) would
        # otherwise reach int("") and abort the whole load with ValueError.
        if not line.strip():
            continue

        columns = line.split(",")
        observations.append(
            ObservationRow(
                timestamp=int(columns[0]),
                bidPrice=float(columns[1]),
                askPrice=float(columns[2]),
                transportFees=float(columns[3]),
                exportTariff=float(columns[4]),
                importTariff=float(columns[5]),
                sugarPrice=float(columns[6]),
                sunlightIndex=float(columns[7]),
            )
        )

    return observations
131
+
132
+
90
133
  def read_day_data(file_reader: FileReader, round_num: int, day_num: int, no_names: bool) -> BacktestData:
91
134
  prices = []
92
135
  with file_reader.file([f"round{round_num}", f"prices_round_{round_num}_day_{day_num}.csv"]) as file:
@@ -134,4 +177,6 @@ def read_day_data(file_reader: FileReader, round_num: int, day_num: int, no_name
134
177
 
135
178
  break
136
179
 
137
- return create_backtest_data(round_num, day_num, prices, trades)
180
+ observations = read_observations(file_reader, round_num, day_num)
181
+
182
+ return create_backtest_data(round_num, day_num, prices, trades, observations)