continual-foragax 0.22.0__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.22.0
3
+ Version: 0.23.0
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -1,8 +1,8 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=Q-96fMoA_51TIJky2JXApITQHXC-1QfdeB5VZvNwe0o,21362
4
- foragax/objects.py,sha256=8tBFMiquWCkhOpNndNmzovMjw7lE5P81OOlUvN2F65w,8301
5
- foragax/registry.py,sha256=pRBWGP18jd4NKl1H-rwDYaAJKUgRWVfENQ9pvTS0tAw,9462
3
+ foragax/env.py,sha256=7csniZeHv2ZX6-8CoYm0pMue7tXmmDwr2XU5msH0n0Q,23680
4
+ foragax/objects.py,sha256=CblI0NI7PQzeKk3MZa8sbaa9wB4pc_8CyOGbJOFWytE,9391
5
+ foragax/registry.py,sha256=EbvPn2IpEMZ7HRjkAN7_JWaTJfYgjLhHkqspcT88DiY,11339
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
7
7
  foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
8
8
  foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.22.0.dist-info/METADATA,sha256=WCVeg6996zpBsLWNghDpibCAs70CDaI8KStzCFnSPNM,4897
132
- continual_foragax-0.22.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.22.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.22.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.22.0.dist-info/RECORD,,
131
+ continual_foragax-0.23.0.dist-info/METADATA,sha256=EIaFemphhlG9KTITgvT_qTQfjKS7sQl7hdrxCceAsW4,4897
132
+ continual_foragax-0.23.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.23.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.23.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.23.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -59,6 +59,7 @@ class EnvParams(environment.EnvParams):
59
59
  class EnvState(environment.EnvState):
60
60
  pos: jax.Array
61
61
  object_grid: jax.Array
62
+ biome_grid: jax.Array
62
63
  time: int
63
64
 
64
65
 
@@ -101,6 +102,7 @@ class ForagaxEnv(environment.Environment):
101
102
  self.object_blocking = jnp.array([o.blocking for o in objects])
102
103
  self.object_collectable = jnp.array([o.collectable for o in objects])
103
104
  self.object_colors = jnp.array([o.color for o in objects])
105
+ self.object_random_respawn = jnp.array([o.random_respawn for o in objects])
104
106
 
105
107
  self.reward_fns = [o.reward for o in objects]
106
108
  self.regen_delay_fns = [o.regen_delay for o in objects]
@@ -179,7 +181,7 @@ class ForagaxEnv(environment.Environment):
179
181
  is_collectable = self.object_collectable[obj_at_pos]
180
182
 
181
183
  # 3. HANDLE OBJECT COLLECTION AND RESPAWNING
182
- key, subkey = jax.random.split(key)
184
+ key, subkey, rand_key = jax.random.split(key, 3)
183
185
 
184
186
  # Decrement timers (stored as negative values)
185
187
  is_timer = state.object_grid < 0
@@ -193,10 +195,49 @@ class ForagaxEnv(environment.Environment):
193
195
  )
194
196
  encoded_timer = obj_at_pos - ((regen_delay + 1) * num_obj_types)
195
197
 
198
+ def place_at_current_pos(current_grid, timer_val):
199
+ return current_grid.at[pos[1], pos[0]].set(timer_val)
200
+
201
+ def place_at_random_pos(current_grid, timer_val):
202
+ # Set the collected position to empty temporarily
203
+ grid = current_grid.at[pos[1], pos[0]].set(0)
204
+
205
+ # Find all valid spawn locations (empty cells within the same biome)
206
+ biome_id = state.biome_grid[pos[1], pos[0]]
207
+ biome_mask = state.biome_grid == biome_id
208
+ empty_mask = grid == 0
209
+ valid_spawn_mask = biome_mask & empty_mask
210
+
211
+ num_valid_spawns = jnp.sum(valid_spawn_mask)
212
+
213
+ # Get indices of valid spawn locations, padded to a static size
214
+ y_indices, x_indices = jnp.nonzero(
215
+ valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
216
+ )
217
+ valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
218
+
219
+ # Select a random valid location
220
+ random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
221
+ new_spawn_pos = valid_spawn_indices[random_idx]
222
+
223
+ # Place the timer at the new random position
224
+ return grid.at[new_spawn_pos[0], new_spawn_pos[1]].set(timer_val)
225
+
196
226
  # If collected, replace object with timer; otherwise, keep it
197
227
  val_at_pos = object_grid[pos[1], pos[0]]
198
- new_val_at_pos = jax.lax.select(is_collectable, encoded_timer, val_at_pos)
199
- object_grid = object_grid.at[pos[1], pos[0]].set(new_val_at_pos)
228
+ should_collect = is_collectable & (val_at_pos > 0)
229
+
230
+ # When not collecting, the value at the position remains unchanged.
231
+ # When collecting, we either place the timer at the current position or a random one.
232
+ object_grid = jax.lax.cond(
233
+ should_collect,
234
+ lambda: jax.lax.cond(
235
+ self.object_random_respawn[obj_at_pos],
236
+ lambda: place_at_random_pos(object_grid, encoded_timer),
237
+ lambda: place_at_current_pos(object_grid, encoded_timer),
238
+ ),
239
+ lambda: object_grid,
240
+ )
200
241
 
201
242
  info = {"discount": self.discount(state, params)}
202
243
  if self.weather_object is not None:
@@ -208,6 +249,7 @@ class ForagaxEnv(environment.Environment):
208
249
  state = EnvState(
209
250
  pos=pos,
210
251
  object_grid=object_grid,
252
+ biome_grid=state.biome_grid,
211
253
  time=state.time + 1,
212
254
  )
213
255
 
@@ -225,10 +267,12 @@ class ForagaxEnv(environment.Environment):
225
267
  ) -> Tuple[jax.Array, EnvState]:
226
268
  """Reset environment state."""
227
269
  object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
270
+ biome_grid = jnp.full((self.size[1], self.size[0]), -1, dtype=int)
228
271
  key, iter_key = jax.random.split(key)
229
272
  for i in range(self.biome_object_frequencies.shape[0]):
230
273
  iter_key, biome_key = jax.random.split(iter_key)
231
274
  mask = self.biome_masks[i]
275
+ biome_grid = jnp.where(mask, i, biome_grid)
232
276
 
233
277
  if self.deterministic_spawn:
234
278
  biome_objects = self.generate_biome_new(i, biome_key)
@@ -244,6 +288,7 @@ class ForagaxEnv(environment.Environment):
244
288
  state = EnvState(
245
289
  pos=agent_pos,
246
290
  object_grid=object_grid,
291
+ biome_grid=biome_grid,
247
292
  time=0,
248
293
  )
249
294
 
@@ -298,6 +343,12 @@ class ForagaxEnv(environment.Environment):
298
343
  (self.size[1], self.size[0]),
299
344
  int,
300
345
  ),
346
+ "biome_grid": spaces.Box(
347
+ 0,
348
+ self.biome_object_frequencies.shape[0],
349
+ (self.size[1], self.size[0]),
350
+ int,
351
+ ),
301
352
  "time": spaces.Discrete(params.max_steps_in_episode),
302
353
  }
303
354
  )
@@ -455,7 +506,13 @@ class ForagaxObjectEnv(ForagaxEnv):
455
506
  deterministic_spawn: bool = False,
456
507
  ):
457
508
  super().__init__(
458
- name, size, aperture_size, objects, biomes, nowrap, deterministic_spawn
509
+ name,
510
+ size,
511
+ aperture_size,
512
+ objects,
513
+ biomes,
514
+ nowrap,
515
+ deterministic_spawn,
459
516
  )
460
517
 
461
518
  # Compute unique colors and mapping for partial observability
foragax/objects.py CHANGED
@@ -16,11 +16,13 @@ class BaseForagaxObject:
16
16
  blocking: bool = False,
17
17
  collectable: bool = False,
18
18
  color: Tuple[int, int, int] = (0, 0, 0),
19
+ random_respawn: bool = False,
19
20
  ):
20
21
  self.name = name
21
22
  self.blocking = blocking
22
23
  self.collectable = collectable
23
24
  self.color = color
25
+ self.random_respawn = random_respawn
24
26
 
25
27
  @abc.abstractmethod
26
28
  def reward(self, clock: int, rng: jax.Array) -> float:
@@ -44,8 +46,9 @@ class DefaultForagaxObject(BaseForagaxObject):
44
46
  collectable: bool = False,
45
47
  regen_delay: Tuple[int, int] = (10, 100),
46
48
  color: Tuple[int, int, int] = (255, 255, 255),
49
+ random_respawn: bool = False,
47
50
  ):
48
- super().__init__(name, blocking, collectable, color)
51
+ super().__init__(name, blocking, collectable, color, random_respawn)
49
52
  self.reward_val = reward
50
53
  self.regen_delay_range = regen_delay
51
54
 
@@ -70,6 +73,7 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
70
73
  mean_regen_delay: int = 10,
71
74
  std_regen_delay: int = 1,
72
75
  color: Tuple[int, int, int] = (0, 0, 0),
76
+ random_respawn: bool = False,
73
77
  ):
74
78
  super().__init__(
75
79
  name=name,
@@ -77,6 +81,7 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
77
81
  collectable=collectable,
78
82
  regen_delay=(mean_regen_delay, mean_regen_delay),
79
83
  color=color,
84
+ random_respawn=random_respawn,
80
85
  )
81
86
  self.mean_regen_delay = mean_regen_delay
82
87
  self.std_regen_delay = std_regen_delay
@@ -99,6 +104,7 @@ class WeatherObject(NormalRegenForagaxObject):
99
104
  mean_regen_delay: int = 10,
100
105
  std_regen_delay: int = 1,
101
106
  color: Tuple[int, int, int] = (0, 0, 0),
107
+ random_respawn: bool = False,
102
108
  ):
103
109
  super().__init__(
104
110
  name=name,
@@ -106,6 +112,7 @@ class WeatherObject(NormalRegenForagaxObject):
106
112
  mean_regen_delay=mean_regen_delay,
107
113
  std_regen_delay=std_regen_delay,
108
114
  color=color,
115
+ random_respawn=random_respawn,
109
116
  )
110
117
  self.rewards = rewards
111
118
  self.repeat = repeat
@@ -271,6 +278,40 @@ GREEN_FAKE_UNIFORM = DefaultForagaxObject(
271
278
  regen_delay=(9, 11),
272
279
  )
273
280
 
281
+ # Random respawn variants
282
+ BROWN_MOREL_UNIFORM_RANDOM = DefaultForagaxObject(
283
+ name="brown_morel",
284
+ reward=10.0,
285
+ collectable=True,
286
+ color=(63, 30, 25),
287
+ regen_delay=(90, 110),
288
+ random_respawn=True,
289
+ )
290
+ BROWN_OYSTER_UNIFORM_RANDOM = DefaultForagaxObject(
291
+ name="brown_oyster",
292
+ reward=1.0,
293
+ collectable=True,
294
+ color=(63, 30, 25),
295
+ regen_delay=(9, 11),
296
+ random_respawn=True,
297
+ )
298
+ GREEN_DEATHCAP_UNIFORM_RANDOM = DefaultForagaxObject(
299
+ name="green_deathcap",
300
+ reward=-5.0,
301
+ collectable=True,
302
+ color=(0, 255, 0),
303
+ regen_delay=(9, 11),
304
+ random_respawn=True,
305
+ )
306
+ GREEN_FAKE_UNIFORM_RANDOM = DefaultForagaxObject(
307
+ name="green_fake",
308
+ reward=0.0,
309
+ collectable=True,
310
+ color=(0, 255, 0),
311
+ regen_delay=(9, 11),
312
+ random_respawn=True,
313
+ )
314
+
274
315
 
275
316
  def create_weather_objects(
276
317
  file_index: int = 0,
foragax/registry.py CHANGED
@@ -13,15 +13,19 @@ from foragax.objects import (
13
13
  BROWN_MOREL,
14
14
  BROWN_MOREL_2,
15
15
  BROWN_MOREL_UNIFORM,
16
+ BROWN_MOREL_UNIFORM_RANDOM,
16
17
  BROWN_OYSTER,
17
18
  BROWN_OYSTER_UNIFORM,
19
+ BROWN_OYSTER_UNIFORM_RANDOM,
18
20
  GREEN_DEATHCAP,
19
21
  GREEN_DEATHCAP_2,
20
22
  GREEN_DEATHCAP_3,
21
23
  GREEN_DEATHCAP_UNIFORM,
24
+ GREEN_DEATHCAP_UNIFORM_RANDOM,
22
25
  GREEN_FAKE,
23
26
  GREEN_FAKE_2,
24
27
  GREEN_FAKE_UNIFORM,
28
+ GREEN_FAKE_UNIFORM_RANDOM,
25
29
  LARGE_MOREL,
26
30
  LARGE_OYSTER,
27
31
  MEDIUM_MOREL,
@@ -176,6 +180,45 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
176
180
  "nowrap": True,
177
181
  "deterministic_spawn": True,
178
182
  },
183
+ "ForagaxTwoBiome-v10": {
184
+ "size": None,
185
+ "aperture_size": None,
186
+ "objects": (
187
+ BROWN_MOREL_UNIFORM_RANDOM,
188
+ BROWN_OYSTER_UNIFORM_RANDOM,
189
+ GREEN_DEATHCAP_UNIFORM_RANDOM,
190
+ GREEN_FAKE_UNIFORM_RANDOM,
191
+ ),
192
+ "biomes": None,
193
+ "nowrap": True,
194
+ "deterministic_spawn": True,
195
+ },
196
+ "ForagaxTwoBiome-v11": {
197
+ "size": None,
198
+ "aperture_size": None,
199
+ "objects": (
200
+ BROWN_MOREL_UNIFORM,
201
+ BROWN_OYSTER_UNIFORM,
202
+ GREEN_DEATHCAP_UNIFORM,
203
+ GREEN_FAKE_UNIFORM,
204
+ ),
205
+ "biomes": None,
206
+ "nowrap": True,
207
+ "deterministic_spawn": True,
208
+ },
209
+ "ForagaxTwoBiome-v12": {
210
+ "size": None,
211
+ "aperture_size": None,
212
+ "objects": (
213
+ BROWN_MOREL_UNIFORM_RANDOM,
214
+ BROWN_OYSTER_UNIFORM_RANDOM,
215
+ GREEN_DEATHCAP_UNIFORM_RANDOM,
216
+ GREEN_FAKE_UNIFORM_RANDOM,
217
+ ),
218
+ "biomes": None,
219
+ "nowrap": True,
220
+ "deterministic_spawn": True,
221
+ },
179
222
  "ForagaxTwoBiomeSmall-v1": {
180
223
  "size": (16, 8),
181
224
  "aperture_size": None,
@@ -246,7 +289,12 @@ def make(
246
289
  if nowrap is not None:
247
290
  config["nowrap"] = nowrap
248
291
 
249
- if env_id in ("ForagaxTwoBiome-v7", "ForagaxTwoBiome-v8", "ForagaxTwoBiome-v9"):
292
+ if env_id in (
293
+ "ForagaxTwoBiome-v7",
294
+ "ForagaxTwoBiome-v8",
295
+ "ForagaxTwoBiome-v9",
296
+ "ForagaxTwoBiome-v10",
297
+ ):
250
298
  margin = aperture_size[1] // 2 + 1
251
299
  width = 2 * margin + 9
252
300
  config["size"] = (width, 15)
@@ -265,6 +313,25 @@ def make(
265
313
  ),
266
314
  )
267
315
 
316
+ if env_id in ("ForagaxTwoBiome-v11", "ForagaxTwoBiome-v12"):
317
+ margin = aperture_size[1] // 2 + 1
318
+ width = 2 * margin + 9
319
+ config["size"] = (width, 15)
320
+ config["biomes"] = (
321
+ # Morel biome
322
+ Biome(
323
+ start=(margin, 0),
324
+ stop=(margin + 2, 15),
325
+ object_frequencies=(0.5, 0.0, 0.25, 0.0),
326
+ ),
327
+ # Oyster biome
328
+ Biome(
329
+ start=(margin + 7, 0),
330
+ stop=(margin + 9, 15),
331
+ object_frequencies=(0.0, 0.5, 0.0, 0.25),
332
+ ),
333
+ )
334
+
268
335
  if env_id == "ForagaxWeather-v3":
269
336
  margin = aperture_size[1] // 2 + 1
270
337
  width = 2 * margin + 9