continual-foragax 0.28.1__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: continual-foragax
-Version: 0.28.1
+Version: 0.30.0
 Summary: A continual reinforcement learning benchmark
 Author-email: Steven Tang <stang5@ualberta.ca>
 Requires-Python: >=3.8
@@ -1,8 +1,8 @@
 foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
-foragax/env.py,sha256=8V8HHfOBBB6adW3pGqEoO4YFWl2CzeKbgRyZ2W9Rpl4,25682
-foragax/objects.py,sha256=FCLZ-8d7qq9VMTG6G-TaRt842-sjgB0-DH0IoHwwngI,9503
-foragax/registry.py,sha256=HysNaZs1tcbAcr53l8Cb2NeZ-_FmE6OpUe_zIks-ObM,15089
+foragax/env.py,sha256=4NZ5JsUGjAepmzw2uxu5_ikyVZnZ7vazy062Xzx22Zg,27481
+foragax/objects.py,sha256=0vb_iyr62BKaIxiE3JwtRhZhFE3VFM6PdxDZTaDtv24,10410
+foragax/registry.py,sha256=Dxg6cWIPwg91fNrCPxADJv35u6jFg_8dI5iTpCMFEFA,15229
 foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
 foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
 foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
 foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
 foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
 foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
-continual_foragax-0.28.1.dist-info/METADATA,sha256=dO9WXb8d6s6PWMklUwrb_EGYsKzeU3rT-FyqRZtQRkQ,4897
-continual_foragax-0.28.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-continual_foragax-0.28.1.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
-continual_foragax-0.28.1.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
-continual_foragax-0.28.1.dist-info/RECORD,,
+continual_foragax-0.30.0.dist-info/METADATA,sha256=d0xeSz0BvDVe1lOUGdhVyqnbkkYN7dNW4BPfCnDSZfQ,4897
+continual_foragax-0.30.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+continual_foragax-0.30.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
+continual_foragax-0.30.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
+continual_foragax-0.30.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -61,6 +61,7 @@ class EnvState(environment.EnvState):
     object_grid: jax.Array
     biome_grid: jax.Array
     time: int
+    digestion_buffer: jax.Array


 class ForagaxEnv(environment.Environment):
@@ -102,11 +103,6 @@ class ForagaxEnv(environment.Environment):
         if self.nowrap and not self.full_world:
             objects = objects + (PADDING,)
         self.objects = objects
-        self.weather_object = None
-        for o in objects:
-            if isinstance(o, WeatherObject):
-                self.weather_object = o
-                break

         # JIT-compatible versions of object and biome properties
         self.object_ids = jnp.arange(len(objects))
@@ -117,6 +113,13 @@

         self.reward_fns = [o.reward for o in objects]
         self.regen_delay_fns = [o.regen_delay for o in objects]
+        self.reward_delay_fns = [o.reward_delay for o in objects]
+
+        # Compute reward steps per object (using max_reward_delay attribute)
+        object_max_reward_delay = jnp.array([o.max_reward_delay for o in objects])
+        self.max_reward_delay = (
+            int(jnp.max(object_max_reward_delay)) + 1 if len(objects) > 0 else 0
+        )

         self.biome_object_frequencies = jnp.array(
             [b.object_frequencies for b in biomes]
@@ -237,12 +240,36 @@

         # 2. HANDLE COLLISIONS AND REWARDS
         obj_at_pos = current_objects[pos[1], pos[0]]
-        key, subkey = jax.random.split(key)
-        reward = jax.lax.switch(obj_at_pos, self.reward_fns, state.time, subkey)
         is_collectable = self.object_collectable[obj_at_pos]
+        should_collect = is_collectable & (obj_at_pos > 0)
+
+        # Handle digestion: add reward to buffer if collected
+        digestion_buffer = state.digestion_buffer
+        key, reward_subkey = jax.random.split(key)
+        object_reward = jax.lax.switch(
+            obj_at_pos, self.reward_fns, state.time, reward_subkey
+        )
+        key, digestion_subkey = jax.random.split(key)
+        reward_delay = jax.lax.switch(
+            obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
+        )
+        reward = jnp.where(should_collect & (reward_delay == 0), object_reward, 0.0)
+        if self.max_reward_delay > 0:
+            # Add delayed rewards to buffer
+            digestion_buffer = jax.lax.cond(
+                should_collect & (reward_delay > 0),
+                lambda: digestion_buffer.at[
+                    (state.time + reward_delay) % self.max_reward_delay
+                ].add(object_reward),
+                lambda: digestion_buffer,
+            )
+            # Deliver current rewards
+            current_index = state.time % self.max_reward_delay
+            reward += digestion_buffer[current_index]
+            digestion_buffer = digestion_buffer.at[current_index].set(0.0)

         # 3. HANDLE OBJECT COLLECTION AND RESPAWNING
-        key, subkey, rand_key = jax.random.split(key, 3)
+        key, regen_subkey, rand_key = jax.random.split(key, 3)

         # Decrement timers (stored as negative values)
         is_timer = state.object_grid < 0
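Note on the digestion hunk above: the buffer is a fixed-size circular array indexed by time modulo max_reward_delay. A reward collected at step t with delay d is queued at slot (t + d) % max_reward_delay and paid out when the clock wraps back to that slot; sizing the buffer as the largest per-object delay plus one (see the constructor hunk) guarantees a freshly queued reward never lands in the slot being paid out on the same step. A minimal standalone sketch of the idea, with made-up sizes rather than the package's own code:

    import jax
    import jax.numpy as jnp

    MAX_REWARD_DELAY = 4  # hypothetical: largest per-object delay (3) plus one

    def digest(buffer: jax.Array, t: int, reward: float, delay: int):
        """One step of a circular delayed-reward buffer."""
        immediate = jnp.where(delay == 0, reward, 0.0)
        # Queue delayed rewards `delay` steps ahead (circular index).
        buffer = jnp.where(
            delay > 0,
            buffer.at[(t + delay) % MAX_REWARD_DELAY].add(reward),
            buffer,
        )
        # Pay out whatever matured at the current slot, then clear it.
        idx = t % MAX_REWARD_DELAY
        paid = immediate + buffer[idx]
        return buffer.at[idx].set(0.0), paid

    buffer = jnp.zeros((MAX_REWARD_DELAY,))
    buffer, r0 = digest(buffer, t=0, reward=1.0, delay=2)  # queued, r0 == 0.0
    buffer, r1 = digest(buffer, t=1, reward=0.0, delay=0)  # r1 == 0.0
    buffer, r2 = digest(buffer, t=2, reward=0.0, delay=0)  # r2 == 1.0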
@@ -252,7 +279,7 @@

         # Collect object: set a timer
         regen_delay = jax.lax.switch(
-            obj_at_pos, self.regen_delay_fns, state.time, subkey
+            obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
         )
         encoded_timer = obj_at_pos - ((regen_delay + 1) * num_obj_types)

@@ -301,10 +328,13 @@
         )

         info = {"discount": self.discount(state, params)}
-        if self.weather_object is not None:
-            info["temperature"] = get_temperature(
-                self.weather_object.rewards, state.time, self.weather_object.repeat
-            )
+        temperatures = jnp.zeros(len(self.objects))
+        for obj_index, obj in enumerate(self.objects):
+            if isinstance(obj, WeatherObject):
+                temperatures = temperatures.at[obj_index].set(
+                    get_temperature(obj.rewards, state.time, obj.repeat)
+                )
+        info["temperatures"] = temperatures
         info["biome_id"] = state.biome_grid[pos[1], pos[0]]
         info["object_collected_id"] = jax.lax.select(should_collect, obj_at_pos, -1)

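The hunk above replaces the scalar info["temperature"] (taken from the single tracked weather object) with a per-object info["temperatures"] array, zero everywhere except at WeatherObject indices. A hedged sketch of consuming the new field, assuming `env` and a step's `info` dict are already in scope:

    from foragax.objects import WeatherObject

    temps = info["temperatures"]  # shape: (num_objects,)
    for i, obj in enumerate(env.objects):
        if isinstance(obj, WeatherObject):
            print(obj.name, float(temps[i]))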
@@ -314,6 +344,7 @@
             object_grid=object_grid,
             biome_grid=state.biome_grid,
             time=state.time + 1,
+            digestion_buffer=digestion_buffer,
         )

         done = self.is_terminal(state, params)
@@ -352,6 +383,7 @@
             object_grid=object_grid,
             biome_grid=biome_grid,
             time=0,
+            digestion_buffer=jnp.zeros((self.max_reward_delay,)),
         )

         return self.get_obs(state, params), state
@@ -412,6 +444,12 @@
                     int,
                 ),
                 "time": spaces.Discrete(params.max_steps_in_episode),
+                "digestion_buffer": spaces.Box(
+                    -jnp.inf,
+                    jnp.inf,
+                    (self.max_reward_delay,),
+                    float,
+                ),
             }
         )

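Since digestion_buffer is now a field of EnvState, zero-initialized on reset with max_reward_delay slots and exposed in the state space, downstream code that serializes or checkpoints states needs to account for it. A quick sanity check, assuming a constructed `env` and `params` and the usual gymnax-style reset signature:

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    obs, state = env.reset(key, params)  # `env`, `params` assumed in scope
    assert state.digestion_buffer.shape == (env.max_reward_delay,)
    assert bool(jnp.all(state.digestion_buffer == 0.0))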
foragax/objects.py CHANGED
@@ -1,5 +1,5 @@
 import abc
-from typing import Tuple
+from typing import Optional, Tuple

 import jax
 import jax.numpy as jnp
@@ -17,12 +17,14 @@ class BaseForagaxObject:
         collectable: bool = False,
         color: Tuple[int, int, int] = (0, 0, 0),
         random_respawn: bool = False,
+        max_reward_delay: int = 0,
     ):
         self.name = name
         self.blocking = blocking
         self.collectable = collectable
         self.color = color
         self.random_respawn = random_respawn
+        self.max_reward_delay = max_reward_delay

     @abc.abstractmethod
     def reward(self, clock: int, rng: jax.Array) -> float:
@@ -30,8 +32,8 @@ class BaseForagaxObject:
         raise NotImplementedError

     @abc.abstractmethod
-    def regen_delay(self, clock: int, rng: jax.Array) -> int:
-        """Regeneration delay function."""
+    def reward_delay(self, clock: int, rng: jax.Array) -> int:
+        """Reward delay function."""
         raise NotImplementedError


@@ -47,10 +49,17 @@ class DefaultForagaxObject(BaseForagaxObject):
         regen_delay: Tuple[int, int] = (10, 100),
         color: Tuple[int, int, int] = (255, 255, 255),
         random_respawn: bool = False,
+        reward_delay: int = 0,
+        max_reward_delay: Optional[int] = None,
     ):
-        super().__init__(name, blocking, collectable, color, random_respawn)
+        if max_reward_delay is None:
+            max_reward_delay = reward_delay
+        super().__init__(
+            name, blocking, collectable, color, random_respawn, max_reward_delay
+        )
         self.reward_val = reward
         self.regen_delay_range = regen_delay
+        self.reward_delay_val = reward_delay

     def reward(self, clock: int, rng: jax.Array) -> float:
         """Default reward function."""
@@ -61,6 +70,10 @@
         min_delay, max_delay = self.regen_delay_range
         return jax.random.randint(rng, (), min_delay, max_delay)

+    def reward_delay(self, clock: int, rng: jax.Array) -> int:
+        """Default reward delay function."""
+        return self.reward_delay_val
+

 class NormalRegenForagaxObject(DefaultForagaxObject):
     """Object with regeneration delay from a normal distribution."""
@@ -74,6 +87,8 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
         std_regen_delay: int = 1,
         color: Tuple[int, int, int] = (0, 0, 0),
         random_respawn: bool = False,
+        reward_delay: int = 0,
+        max_reward_delay: Optional[int] = None,
     ):
         super().__init__(
             name=name,
@@ -82,6 +97,8 @@
             regen_delay=(mean_regen_delay, mean_regen_delay),
             color=color,
             random_respawn=random_respawn,
+            reward_delay=reward_delay,
+            max_reward_delay=max_reward_delay,
         )
         self.mean_regen_delay = mean_regen_delay
         self.std_regen_delay = std_regen_delay
@@ -105,6 +122,8 @@ class WeatherObject(NormalRegenForagaxObject):
         std_regen_delay: int = 1,
         color: Tuple[int, int, int] = (0, 0, 0),
         random_respawn: bool = False,
+        reward_delay: int = 0,
+        max_reward_delay: Optional[int] = None,
     ):
         super().__init__(
             name=name,
@@ -113,6 +132,8 @@
             std_regen_delay=std_regen_delay,
             color=color,
             random_respawn=random_respawn,
+            reward_delay=reward_delay,
+            max_reward_delay=max_reward_delay,
         )
         self.rewards = rewards
         self.repeat = repeat
@@ -319,6 +340,7 @@ def create_weather_objects(
     multiplier: float = 1.0,
     same_color: bool = False,
     random_respawn: bool = False,
+    reward_delay: int = 0,
 ):
     """Create HOT and COLD WeatherObject instances using the specified file.

@@ -348,6 +370,7 @@
         multiplier=multiplier,
         color=hot_color,
         random_respawn=random_respawn,
+        reward_delay=reward_delay,
     )

     cold_color = hot_color if same_color else (0, 255, 255)
@@ -358,6 +381,7 @@
         multiplier=-multiplier,
         color=cold_color,
         random_respawn=random_respawn,
+        reward_delay=reward_delay,
     )

     return hot, cold
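create_weather_objects simply forwards the new reward_delay to both objects, so HOT and COLD always share one digestion delay. A usage sketch with illustrative values (the attribute name in the assert follows the objects.py hunk above):

    from foragax.objects import create_weather_objects

    hot, cold = create_weather_objects(
        file_index=0,
        repeat=500,
        random_respawn=False,
        reward_delay=8,
    )
    assert hot.reward_delay_val == cold.reward_delay_val == 8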
foragax/registry.py CHANGED
@@ -348,7 +348,8 @@ def make(
     observation_type: str = "color",
     aperture_size: Optional[Tuple[int, int]] = (5, 5),
     file_index: int = 0,
-    nowrap: Optional[bool] = None,
+    repeat: int = 500,
+    reward_delay: int = 0,
     **kwargs: Any,
 ) -> ForagaxEnv:
     """Create a Foragax environment.
@@ -358,9 +359,9 @@
         observation_type: The type of observation to use. One of "object", "rgb", or "color".
         aperture_size: The size of the agent's observation aperture. If -1, full world observation.
             If None, the default for the environment is used.
-        file_index: File index for weather objects. nowrap: If True, disables
-            wrapping around environment boundaries. If None, uses defaults per
-            environment.
+        file_index: File index for weather objects.
+        repeat: How many steps each temperature value repeats for (weather environments).
+        reward_delay: Number of steps required to digest food items (weather environments).
         **kwargs: Additional keyword arguments to pass to the ForagaxEnv constructor.

     Returns:
@@ -376,8 +377,6 @@
     else:
         aperture_size = (aperture_size, aperture_size)
     config["aperture_size"] = aperture_size
-    if nowrap is not None:
-        config["nowrap"] = nowrap

     # Handle special size and biome configurations
     if env_id in (
@@ -460,9 +459,16 @@
         "ForagaxWeather-v4",
         "ForagaxWeather-v5",
     )
-    random_respawn = env_id in ("ForagaxWeather-v4", "ForagaxWeather-v5")
+    random_respawn = env_id in (
+        "ForagaxWeather-v4",
+        "ForagaxWeather-v5",
+    )
     hot, cold = create_weather_objects(
-        file_index=file_index, same_color=same_color, random_respawn=random_respawn
+        file_index=file_index,
+        repeat=repeat,
+        same_color=same_color,
+        random_respawn=random_respawn,
+        reward_delay=reward_delay,
     )
     config["objects"] = (hot, cold)
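Net effect at the registry level: nowrap is no longer a make() keyword (per-environment defaults apply, and it can presumably still be forwarded to ForagaxEnv via **kwargs), while repeat and reward_delay are threaded through to the weather objects. A hedged call sketch, assuming the usual make(env_id, ...) signature and made-up values:

    from foragax.registry import make

    env = make(
        "ForagaxWeather-v4",
        observation_type="color",
        aperture_size=(5, 5),
        repeat=500,       # steps each temperature value persists
        reward_delay=10,  # digestion delay for the HOT/COLD objects
    )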