continual-foragax 0.28.1__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.28.1
3
+ Version: 0.29.0
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -1,8 +1,8 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=8V8HHfOBBB6adW3pGqEoO4YFWl2CzeKbgRyZ2W9Rpl4,25682
4
- foragax/objects.py,sha256=FCLZ-8d7qq9VMTG6G-TaRt842-sjgB0-DH0IoHwwngI,9503
5
- foragax/registry.py,sha256=HysNaZs1tcbAcr53l8Cb2NeZ-_FmE6OpUe_zIks-ObM,15089
3
+ foragax/env.py,sha256=OFtDT8c5nskflnXmMNbRUvu5pVs9vhIezDgZaICXSyE,27535
4
+ foragax/objects.py,sha256=j7FivgT4uz6N4FkOTmpM0t-YjTUkYUBLAznWpVqqjrU,10509
5
+ foragax/registry.py,sha256=bFSTDo7XU4G0njHjLTgPdzWyStkAGckimDYcnGYLIIg,15529
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
7
7
  foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
8
8
  foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.28.1.dist-info/METADATA,sha256=dO9WXb8d6s6PWMklUwrb_EGYsKzeU3rT-FyqRZtQRkQ,4897
132
- continual_foragax-0.28.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.28.1.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.28.1.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.28.1.dist-info/RECORD,,
131
+ continual_foragax-0.29.0.dist-info/METADATA,sha256=vaiiCWr06OczH8LgsRotJXRt7q4KZMtnvuFj6up5v3U,4897
132
+ continual_foragax-0.29.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.29.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.29.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.29.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -61,6 +61,7 @@ class EnvState(environment.EnvState):
61
61
  object_grid: jax.Array
62
62
  biome_grid: jax.Array
63
63
  time: int
64
+ digestion_buffer: jax.Array
64
65
 
65
66
 
66
67
  class ForagaxEnv(environment.Environment):
@@ -102,11 +103,6 @@ class ForagaxEnv(environment.Environment):
102
103
  if self.nowrap and not self.full_world:
103
104
  objects = objects + (PADDING,)
104
105
  self.objects = objects
105
- self.weather_object = None
106
- for o in objects:
107
- if isinstance(o, WeatherObject):
108
- self.weather_object = o
109
- break
110
106
 
111
107
  # JIT-compatible versions of object and biome properties
112
108
  self.object_ids = jnp.arange(len(objects))
@@ -117,6 +113,13 @@ class ForagaxEnv(environment.Environment):
117
113
 
118
114
  self.reward_fns = [o.reward for o in objects]
119
115
  self.regen_delay_fns = [o.regen_delay for o in objects]
116
+ self.digestion_steps_fns = [o.digestion_steps for o in objects]
117
+
118
+ # Compute digestion steps per object (using max_digestion_steps attribute)
119
+ object_max_digestion_steps = jnp.array([o.max_digestion_steps for o in objects])
120
+ self.max_digestion_steps = (
121
+ int(jnp.max(object_max_digestion_steps)) + 1 if len(objects) > 0 else 0
122
+ )
120
123
 
121
124
  self.biome_object_frequencies = jnp.array(
122
125
  [b.object_frequencies for b in biomes]
@@ -237,12 +240,36 @@ class ForagaxEnv(environment.Environment):
237
240
 
238
241
  # 2. HANDLE COLLISIONS AND REWARDS
239
242
  obj_at_pos = current_objects[pos[1], pos[0]]
240
- key, subkey = jax.random.split(key)
241
- reward = jax.lax.switch(obj_at_pos, self.reward_fns, state.time, subkey)
242
243
  is_collectable = self.object_collectable[obj_at_pos]
244
+ should_collect = is_collectable & (obj_at_pos > 0)
245
+
246
+ # Handle digestion: add reward to buffer if collected
247
+ digestion_buffer = state.digestion_buffer
248
+ key, reward_subkey = jax.random.split(key)
249
+ object_reward = jax.lax.switch(
250
+ obj_at_pos, self.reward_fns, state.time, reward_subkey
251
+ )
252
+ key, digestion_subkey = jax.random.split(key)
253
+ digestion_steps = jax.lax.switch(
254
+ obj_at_pos, self.digestion_steps_fns, state.time, digestion_subkey
255
+ )
256
+ reward = jnp.where(should_collect & (digestion_steps == 0), object_reward, 0.0)
257
+ if self.max_digestion_steps > 0:
258
+ # Add delayed rewards to buffer
259
+ digestion_buffer = jax.lax.cond(
260
+ should_collect & (digestion_steps > 0),
261
+ lambda: digestion_buffer.at[
262
+ (state.time + digestion_steps) % self.max_digestion_steps
263
+ ].add(object_reward),
264
+ lambda: digestion_buffer,
265
+ )
266
+ # Deliver current rewards
267
+ current_index = state.time % self.max_digestion_steps
268
+ reward += digestion_buffer[current_index]
269
+ digestion_buffer = digestion_buffer.at[current_index].set(0.0)
243
270
 
244
271
  # 3. HANDLE OBJECT COLLECTION AND RESPAWNING
245
- key, subkey, rand_key = jax.random.split(key, 3)
272
+ key, regen_subkey, rand_key = jax.random.split(key, 3)
246
273
 
247
274
  # Decrement timers (stored as negative values)
248
275
  is_timer = state.object_grid < 0
@@ -252,7 +279,7 @@ class ForagaxEnv(environment.Environment):
252
279
 
253
280
  # Collect object: set a timer
254
281
  regen_delay = jax.lax.switch(
255
- obj_at_pos, self.regen_delay_fns, state.time, subkey
282
+ obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
256
283
  )
257
284
  encoded_timer = obj_at_pos - ((regen_delay + 1) * num_obj_types)
258
285
 
@@ -301,10 +328,13 @@ class ForagaxEnv(environment.Environment):
301
328
  )
302
329
 
303
330
  info = {"discount": self.discount(state, params)}
304
- if self.weather_object is not None:
305
- info["temperature"] = get_temperature(
306
- self.weather_object.rewards, state.time, self.weather_object.repeat
307
- )
331
+ temperatures = jnp.zeros(len(self.objects))
332
+ for obj_index, obj in enumerate(self.objects):
333
+ if isinstance(obj, WeatherObject):
334
+ temperatures = temperatures.at[obj_index].set(
335
+ get_temperature(obj.rewards, state.time, obj.repeat)
336
+ )
337
+ info["temperatures"] = temperatures
308
338
  info["biome_id"] = state.biome_grid[pos[1], pos[0]]
309
339
  info["object_collected_id"] = jax.lax.select(should_collect, obj_at_pos, -1)
310
340
 
@@ -314,6 +344,7 @@ class ForagaxEnv(environment.Environment):
314
344
  object_grid=object_grid,
315
345
  biome_grid=state.biome_grid,
316
346
  time=state.time + 1,
347
+ digestion_buffer=digestion_buffer,
317
348
  )
318
349
 
319
350
  done = self.is_terminal(state, params)
@@ -352,6 +383,7 @@ class ForagaxEnv(environment.Environment):
352
383
  object_grid=object_grid,
353
384
  biome_grid=biome_grid,
354
385
  time=0,
386
+ digestion_buffer=jnp.zeros((self.max_digestion_steps,)),
355
387
  )
356
388
 
357
389
  return self.get_obs(state, params), state
@@ -412,6 +444,12 @@ class ForagaxEnv(environment.Environment):
412
444
  int,
413
445
  ),
414
446
  "time": spaces.Discrete(params.max_steps_in_episode),
447
+ "digestion_buffer": spaces.Box(
448
+ -jnp.inf,
449
+ jnp.inf,
450
+ (self.max_digestion_steps,),
451
+ float,
452
+ ),
415
453
  }
416
454
  )
417
455
 
foragax/objects.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import abc
2
- from typing import Tuple
2
+ from typing import Optional, Tuple
3
3
 
4
4
  import jax
5
5
  import jax.numpy as jnp
@@ -17,12 +17,14 @@ class BaseForagaxObject:
17
17
  collectable: bool = False,
18
18
  color: Tuple[int, int, int] = (0, 0, 0),
19
19
  random_respawn: bool = False,
20
+ max_digestion_steps: int = 0,
20
21
  ):
21
22
  self.name = name
22
23
  self.blocking = blocking
23
24
  self.collectable = collectable
24
25
  self.color = color
25
26
  self.random_respawn = random_respawn
27
+ self.max_digestion_steps = max_digestion_steps
26
28
 
27
29
  @abc.abstractmethod
28
30
  def reward(self, clock: int, rng: jax.Array) -> float:
@@ -30,8 +32,8 @@ class BaseForagaxObject:
30
32
  raise NotImplementedError
31
33
 
32
34
  @abc.abstractmethod
33
- def regen_delay(self, clock: int, rng: jax.Array) -> int:
34
- """Regeneration delay function."""
35
+ def digestion_steps(self, clock: int, rng: jax.Array) -> int:
36
+ """Digestion steps function."""
35
37
  raise NotImplementedError
36
38
 
37
39
 
@@ -47,10 +49,17 @@ class DefaultForagaxObject(BaseForagaxObject):
47
49
  regen_delay: Tuple[int, int] = (10, 100),
48
50
  color: Tuple[int, int, int] = (255, 255, 255),
49
51
  random_respawn: bool = False,
52
+ digestion_steps: int = 0,
53
+ max_digestion_steps: Optional[int] = None,
50
54
  ):
51
- super().__init__(name, blocking, collectable, color, random_respawn)
55
+ if max_digestion_steps is None:
56
+ max_digestion_steps = digestion_steps
57
+ super().__init__(
58
+ name, blocking, collectable, color, random_respawn, max_digestion_steps
59
+ )
52
60
  self.reward_val = reward
53
61
  self.regen_delay_range = regen_delay
62
+ self.digestion_steps_val = digestion_steps
54
63
 
55
64
  def reward(self, clock: int, rng: jax.Array) -> float:
56
65
  """Default reward function."""
@@ -61,6 +70,10 @@ class DefaultForagaxObject(BaseForagaxObject):
61
70
  min_delay, max_delay = self.regen_delay_range
62
71
  return jax.random.randint(rng, (), min_delay, max_delay)
63
72
 
73
+ def digestion_steps(self, clock: int, rng: jax.Array) -> int:
74
+ """Default digestion steps function."""
75
+ return self.digestion_steps_val
76
+
64
77
 
65
78
  class NormalRegenForagaxObject(DefaultForagaxObject):
66
79
  """Object with regeneration delay from a normal distribution."""
@@ -74,6 +87,8 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
74
87
  std_regen_delay: int = 1,
75
88
  color: Tuple[int, int, int] = (0, 0, 0),
76
89
  random_respawn: bool = False,
90
+ digestion_steps: int = 0,
91
+ max_digestion_steps: Optional[int] = None,
77
92
  ):
78
93
  super().__init__(
79
94
  name=name,
@@ -82,6 +97,8 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
82
97
  regen_delay=(mean_regen_delay, mean_regen_delay),
83
98
  color=color,
84
99
  random_respawn=random_respawn,
100
+ digestion_steps=digestion_steps,
101
+ max_digestion_steps=max_digestion_steps,
85
102
  )
86
103
  self.mean_regen_delay = mean_regen_delay
87
104
  self.std_regen_delay = std_regen_delay
@@ -105,6 +122,8 @@ class WeatherObject(NormalRegenForagaxObject):
105
122
  std_regen_delay: int = 1,
106
123
  color: Tuple[int, int, int] = (0, 0, 0),
107
124
  random_respawn: bool = False,
125
+ digestion_steps: int = 0,
126
+ max_digestion_steps: Optional[int] = None,
108
127
  ):
109
128
  super().__init__(
110
129
  name=name,
@@ -113,6 +132,8 @@ class WeatherObject(NormalRegenForagaxObject):
113
132
  std_regen_delay=std_regen_delay,
114
133
  color=color,
115
134
  random_respawn=random_respawn,
135
+ digestion_steps=digestion_steps,
136
+ max_digestion_steps=max_digestion_steps,
116
137
  )
117
138
  self.rewards = rewards
118
139
  self.repeat = repeat
@@ -319,6 +340,7 @@ def create_weather_objects(
319
340
  multiplier: float = 1.0,
320
341
  same_color: bool = False,
321
342
  random_respawn: bool = False,
343
+ digestion_steps: int = 0,
322
344
  ):
323
345
  """Create HOT and COLD WeatherObject instances using the specified file.
324
346
 
@@ -348,6 +370,7 @@ def create_weather_objects(
348
370
  multiplier=multiplier,
349
371
  color=hot_color,
350
372
  random_respawn=random_respawn,
373
+ digestion_steps=digestion_steps,
351
374
  )
352
375
 
353
376
  cold_color = hot_color if same_color else (0, 255, 255)
@@ -358,6 +381,7 @@ def create_weather_objects(
358
381
  multiplier=-multiplier,
359
382
  color=cold_color,
360
383
  random_respawn=random_respawn,
384
+ digestion_steps=digestion_steps,
361
385
  )
362
386
 
363
387
  return hot, cold
foragax/registry.py CHANGED
@@ -83,6 +83,19 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
83
83
  "nowrap": False,
84
84
  "deterministic_spawn": True,
85
85
  },
86
+ "ForagaxWeather-v6": {
87
+ "size": (15, 15),
88
+ "aperture_size": None,
89
+ "objects": None,
90
+ "biomes": (
91
+ # Hot biome
92
+ Biome(start=(0, 3), stop=(15, 5), object_frequencies=(0.5, 0.0)),
93
+ # Cold biome
94
+ Biome(start=(0, 10), stop=(15, 12), object_frequencies=(0.0, 0.5)),
95
+ ),
96
+ "nowrap": False,
97
+ "deterministic_spawn": True,
98
+ },
86
99
  "ForagaxTwoBiome-v1": {
87
100
  "size": (15, 15),
88
101
  "aperture_size": None,
@@ -348,7 +361,6 @@ def make(
348
361
  observation_type: str = "color",
349
362
  aperture_size: Optional[Tuple[int, int]] = (5, 5),
350
363
  file_index: int = 0,
351
- nowrap: Optional[bool] = None,
352
364
  **kwargs: Any,
353
365
  ) -> ForagaxEnv:
354
366
  """Create a Foragax environment.
@@ -358,9 +370,7 @@ def make(
358
370
  observation_type: The type of observation to use. One of "object", "rgb", or "color".
359
371
  aperture_size: The size of the agent's observation aperture. If -1, full world observation.
360
372
  If None, the default for the environment is used.
361
- file_index: File index for weather objects. nowrap: If True, disables
362
- wrapping around environment boundaries. If None, uses defaults per
363
- environment.
373
+ file_index: File index for weather objects.
364
374
  **kwargs: Additional keyword arguments to pass to the ForagaxEnv constructor.
365
375
 
366
376
  Returns:
@@ -376,8 +386,6 @@ def make(
376
386
  else:
377
387
  aperture_size = (aperture_size, aperture_size)
378
388
  config["aperture_size"] = aperture_size
379
- if nowrap is not None:
380
- config["nowrap"] = nowrap
381
389
 
382
390
  # Handle special size and biome configurations
383
391
  if env_id in (
@@ -459,10 +467,19 @@ def make(
459
467
  "ForagaxWeather-v3",
460
468
  "ForagaxWeather-v4",
461
469
  "ForagaxWeather-v5",
470
+ "ForagaxWeather-v6",
471
+ )
472
+ random_respawn = env_id in (
473
+ "ForagaxWeather-v4",
474
+ "ForagaxWeather-v5",
475
+ "ForagaxWeather-v6",
462
476
  )
463
- random_respawn = env_id in ("ForagaxWeather-v4", "ForagaxWeather-v5")
477
+ digestion_steps = 10 if env_id in ("ForagaxWeather-v6") else 0
464
478
  hot, cold = create_weather_objects(
465
- file_index=file_index, same_color=same_color, random_respawn=random_respawn
479
+ file_index=file_index,
480
+ same_color=same_color,
481
+ random_respawn=random_respawn,
482
+ digestion_steps=digestion_steps,
466
483
  )
467
484
  config["objects"] = (hot, cold)
468
485