continual-foragax 0.29.0__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
continual_foragax-{0.29.0 → 0.30.0}.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: continual-foragax
-Version: 0.29.0
+Version: 0.30.0
 Summary: A continual reinforcement learning benchmark
 Author-email: Steven Tang <stang5@ualberta.ca>
 Requires-Python: >=3.8
continual_foragax-{0.29.0 → 0.30.0}.dist-info/RECORD CHANGED
@@ -1,8 +1,8 @@
 foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
-foragax/env.py,sha256=OFtDT8c5nskflnXmMNbRUvu5pVs9vhIezDgZaICXSyE,27535
-foragax/objects.py,sha256=j7FivgT4uz6N4FkOTmpM0t-YjTUkYUBLAznWpVqqjrU,10509
-foragax/registry.py,sha256=bFSTDo7XU4G0njHjLTgPdzWyStkAGckimDYcnGYLIIg,15529
+foragax/env.py,sha256=4NZ5JsUGjAepmzw2uxu5_ikyVZnZ7vazy062Xzx22Zg,27481
+foragax/objects.py,sha256=0vb_iyr62BKaIxiE3JwtRhZhFE3VFM6PdxDZTaDtv24,10410
+foragax/registry.py,sha256=Dxg6cWIPwg91fNrCPxADJv35u6jFg_8dI5iTpCMFEFA,15229
 foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
 foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
 foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
 foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
 foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
 foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
-continual_foragax-0.29.0.dist-info/METADATA,sha256=vaiiCWr06OczH8LgsRotJXRt7q4KZMtnvuFj6up5v3U,4897
-continual_foragax-0.29.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-continual_foragax-0.29.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
-continual_foragax-0.29.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
-continual_foragax-0.29.0.dist-info/RECORD,,
+continual_foragax-0.30.0.dist-info/METADATA,sha256=d0xeSz0BvDVe1lOUGdhVyqnbkkYN7dNW4BPfCnDSZfQ,4897
+continual_foragax-0.30.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+continual_foragax-0.30.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
+continual_foragax-0.30.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
+continual_foragax-0.30.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -113,12 +113,12 @@ class ForagaxEnv(environment.Environment):
 
         self.reward_fns = [o.reward for o in objects]
         self.regen_delay_fns = [o.regen_delay for o in objects]
-        self.digestion_steps_fns = [o.digestion_steps for o in objects]
+        self.reward_delay_fns = [o.reward_delay for o in objects]
 
-        # Compute digestion steps per object (using max_digestion_steps attribute)
-        object_max_digestion_steps = jnp.array([o.max_digestion_steps for o in objects])
-        self.max_digestion_steps = (
-            int(jnp.max(object_max_digestion_steps)) + 1 if len(objects) > 0 else 0
+        # Compute reward delay per object (using the max_reward_delay attribute)
+        object_max_reward_delay = jnp.array([o.max_reward_delay for o in objects])
+        self.max_reward_delay = (
+            int(jnp.max(object_max_reward_delay)) + 1 if len(objects) > 0 else 0
         )
 
         self.biome_object_frequencies = jnp.array(
@@ -250,21 +250,21 @@ class ForagaxEnv(environment.Environment):
             obj_at_pos, self.reward_fns, state.time, reward_subkey
         )
         key, digestion_subkey = jax.random.split(key)
-        digestion_steps = jax.lax.switch(
-            obj_at_pos, self.digestion_steps_fns, state.time, digestion_subkey
+        reward_delay = jax.lax.switch(
+            obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
         )
-        reward = jnp.where(should_collect & (digestion_steps == 0), object_reward, 0.0)
-        if self.max_digestion_steps > 0:
+        reward = jnp.where(should_collect & (reward_delay == 0), object_reward, 0.0)
+        if self.max_reward_delay > 0:
             # Add delayed rewards to buffer
             digestion_buffer = jax.lax.cond(
-                should_collect & (digestion_steps > 0),
+                should_collect & (reward_delay > 0),
                 lambda: digestion_buffer.at[
-                    (state.time + digestion_steps) % self.max_digestion_steps
+                    (state.time + reward_delay) % self.max_reward_delay
                 ].add(object_reward),
                 lambda: digestion_buffer,
             )
             # Deliver current rewards
-            current_index = state.time % self.max_digestion_steps
+            current_index = state.time % self.max_reward_delay
             reward += digestion_buffer[current_index]
             digestion_buffer = digestion_buffer.at[current_index].set(0.0)
 
@@ -383,7 +383,7 @@ class ForagaxEnv(environment.Environment):
             object_grid=object_grid,
             biome_grid=biome_grid,
             time=0,
-            digestion_buffer=jnp.zeros((self.max_digestion_steps,)),
+            digestion_buffer=jnp.zeros((self.max_reward_delay,)),
         )
 
         return self.get_obs(state, params), state
@@ -447,7 +447,7 @@ class ForagaxEnv(environment.Environment):
             "digestion_buffer": spaces.Box(
                 -jnp.inf,
                 jnp.inf,
-                (self.max_digestion_steps,),
+                (self.max_reward_delay,),
                 float,
             ),
         }
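The renamed delay machinery above is a ring buffer: collecting an object with `reward_delay > 0` writes its reward `reward_delay` slots ahead of the current time, and every step pays out and clears the slot indexed by the current time. Sizing the buffer as the maximum per-object delay plus one guarantees a scheduled write never lands in the slot being cleared that same step. A minimal standalone sketch of the mechanic (illustrative values; only `jax.numpy` is assumed, not the environment itself):

```python
# Standalone sketch of the delayed-reward ring buffer in env.py above.
# Values are illustrative; only jax.numpy is assumed, not the environment.
import jax.numpy as jnp

max_reward_delay = 10 + 1  # largest per-object delay + 1 slot for "now"
buffer = jnp.zeros((max_reward_delay,))

time, reward_delay, object_reward = 3, 4, 1.0

# Schedule: write the reward `reward_delay` slots ahead of the current time.
buffer = buffer.at[(time + reward_delay) % max_reward_delay].add(object_reward)

# Deliver: each step pays out and clears the slot for the current time.
for t in range(time, time + max_reward_delay):
    idx = t % max_reward_delay
    delivered = float(buffer[idx])
    buffer = buffer.at[idx].set(0.0)
    if delivered:
        print(f"t={t}: reward {delivered} delivered")  # fires at t = 7
```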
foragax/objects.py CHANGED
@@ -17,14 +17,14 @@ class BaseForagaxObject:
         collectable: bool = False,
         color: Tuple[int, int, int] = (0, 0, 0),
         random_respawn: bool = False,
-        max_digestion_steps: int = 0,
+        max_reward_delay: int = 0,
     ):
         self.name = name
         self.blocking = blocking
         self.collectable = collectable
         self.color = color
         self.random_respawn = random_respawn
-        self.max_digestion_steps = max_digestion_steps
+        self.max_reward_delay = max_reward_delay
 
     @abc.abstractmethod
     def reward(self, clock: int, rng: jax.Array) -> float:
@@ -32,8 +32,8 @@ class BaseForagaxObject:
         raise NotImplementedError
 
     @abc.abstractmethod
-    def digestion_steps(self, clock: int, rng: jax.Array) -> int:
-        """Digestion steps function."""
+    def reward_delay(self, clock: int, rng: jax.Array) -> int:
+        """Reward delay function."""
         raise NotImplementedError
 
 
@@ -49,17 +49,17 @@ class DefaultForagaxObject(BaseForagaxObject):
         regen_delay: Tuple[int, int] = (10, 100),
         color: Tuple[int, int, int] = (255, 255, 255),
         random_respawn: bool = False,
-        digestion_steps: int = 0,
-        max_digestion_steps: Optional[int] = None,
+        reward_delay: int = 0,
+        max_reward_delay: Optional[int] = None,
     ):
-        if max_digestion_steps is None:
-            max_digestion_steps = digestion_steps
+        if max_reward_delay is None:
+            max_reward_delay = reward_delay
         super().__init__(
-            name, blocking, collectable, color, random_respawn, max_digestion_steps
+            name, blocking, collectable, color, random_respawn, max_reward_delay
         )
         self.reward_val = reward
         self.regen_delay_range = regen_delay
-        self.digestion_steps_val = digestion_steps
+        self.reward_delay_val = reward_delay
 
     def reward(self, clock: int, rng: jax.Array) -> float:
         """Default reward function."""
@@ -70,9 +70,9 @@ class DefaultForagaxObject(BaseForagaxObject):
         min_delay, max_delay = self.regen_delay_range
         return jax.random.randint(rng, (), min_delay, max_delay)
 
-    def digestion_steps(self, clock: int, rng: jax.Array) -> int:
-        """Default digestion steps function."""
-        return self.digestion_steps_val
+    def reward_delay(self, clock: int, rng: jax.Array) -> int:
+        """Default reward delay function."""
+        return self.reward_delay_val
 
 
 class NormalRegenForagaxObject(DefaultForagaxObject):
@@ -87,8 +87,8 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
         std_regen_delay: int = 1,
         color: Tuple[int, int, int] = (0, 0, 0),
         random_respawn: bool = False,
-        digestion_steps: int = 0,
-        max_digestion_steps: Optional[int] = None,
+        reward_delay: int = 0,
+        max_reward_delay: Optional[int] = None,
     ):
         super().__init__(
             name=name,
@@ -97,8 +97,8 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
             regen_delay=(mean_regen_delay, mean_regen_delay),
             color=color,
             random_respawn=random_respawn,
-            digestion_steps=digestion_steps,
-            max_digestion_steps=max_digestion_steps,
+            reward_delay=reward_delay,
+            max_reward_delay=max_reward_delay,
         )
         self.mean_regen_delay = mean_regen_delay
         self.std_regen_delay = std_regen_delay
@@ -122,8 +122,8 @@ class WeatherObject(NormalRegenForagaxObject):
         std_regen_delay: int = 1,
         color: Tuple[int, int, int] = (0, 0, 0),
         random_respawn: bool = False,
-        digestion_steps: int = 0,
-        max_digestion_steps: Optional[int] = None,
+        reward_delay: int = 0,
+        max_reward_delay: Optional[int] = None,
     ):
         super().__init__(
             name=name,
@@ -132,8 +132,8 @@ class WeatherObject(NormalRegenForagaxObject):
             std_regen_delay=std_regen_delay,
             color=color,
             random_respawn=random_respawn,
-            digestion_steps=digestion_steps,
-            max_digestion_steps=max_digestion_steps,
+            reward_delay=reward_delay,
+            max_reward_delay=max_reward_delay,
         )
         self.rewards = rewards
         self.repeat = repeat
@@ -340,7 +340,7 @@ def create_weather_objects(
     multiplier: float = 1.0,
     same_color: bool = False,
     random_respawn: bool = False,
-    digestion_steps: int = 0,
+    reward_delay: int = 0,
 ):
     """Create HOT and COLD WeatherObject instances using the specified file.
 
@@ -370,7 +370,7 @@ def create_weather_objects(
         multiplier=multiplier,
         color=hot_color,
         random_respawn=random_respawn,
-        digestion_steps=digestion_steps,
+        reward_delay=reward_delay,
    )
 
     cold_color = hot_color if same_color else (0, 255, 255)
@@ -381,7 +381,7 @@ def create_weather_objects(
         multiplier=-multiplier,
         color=cold_color,
         random_respawn=random_respawn,
-        digestion_steps=digestion_steps,
+        reward_delay=reward_delay,
     )
 
     return hot, cold
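All object classes now take `reward_delay` (and an optional `max_reward_delay`) in place of `digestion_steps`, with `max_reward_delay` still falling back to `reward_delay` when left as `None`. A hypothetical construction with the renamed keywords; the object name and argument values here are illustrative, not taken from the package:

```python
from foragax.objects import DefaultForagaxObject

# "berry" and these argument values are made up for illustration.
berry = DefaultForagaxObject(
    name="berry",
    collectable=True,
    reward=1.0,
    regen_delay=(10, 100),  # respawn delay sampled from this range
    reward_delay=5,         # reward arrives 5 steps after collection
)

# With max_reward_delay omitted, it falls back to reward_delay.
assert berry.max_reward_delay == 5
```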
foragax/registry.py CHANGED
@@ -83,19 +83,6 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
         "nowrap": False,
         "deterministic_spawn": True,
     },
-    "ForagaxWeather-v6": {
-        "size": (15, 15),
-        "aperture_size": None,
-        "objects": None,
-        "biomes": (
-            # Hot biome
-            Biome(start=(0, 3), stop=(15, 5), object_frequencies=(0.5, 0.0)),
-            # Cold biome
-            Biome(start=(0, 10), stop=(15, 12), object_frequencies=(0.0, 0.5)),
-        ),
-        "nowrap": False,
-        "deterministic_spawn": True,
-    },
     "ForagaxTwoBiome-v1": {
         "size": (15, 15),
         "aperture_size": None,
@@ -361,6 +348,8 @@
     observation_type: str = "color",
     aperture_size: Optional[Tuple[int, int]] = (5, 5),
     file_index: int = 0,
+    repeat: int = 500,
+    reward_delay: int = 0,
     **kwargs: Any,
 ) -> ForagaxEnv:
     """Create a Foragax environment.
@@ -371,6 +360,8 @@
         aperture_size: The size of the agent's observation aperture. If -1, full world observation.
             If None, the default for the environment is used.
         file_index: File index for weather objects.
+        repeat: How many steps each temperature value repeats for (weather environments).
+        reward_delay: Number of steps between collecting a food item and receiving its reward (weather environments).
         **kwargs: Additional keyword arguments to pass to the ForagaxEnv constructor.
 
     Returns:
@@ -467,19 +458,17 @@
         "ForagaxWeather-v3",
         "ForagaxWeather-v4",
         "ForagaxWeather-v5",
-        "ForagaxWeather-v6",
     )
     random_respawn = env_id in (
         "ForagaxWeather-v4",
         "ForagaxWeather-v5",
-        "ForagaxWeather-v6",
     )
-    digestion_steps = 10 if env_id in ("ForagaxWeather-v6") else 0
     hot, cold = create_weather_objects(
         file_index=file_index,
+        repeat=repeat,
         same_color=same_color,
         random_respawn=random_respawn,
-        digestion_steps=digestion_steps,
+        reward_delay=reward_delay,
     )
     config["objects"] = (hot, cold)
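The hard-coded ForagaxWeather-v6 entry is gone; since make() now forwards repeat and reward_delay to create_weather_objects, a comparable setup can presumably be recovered by passing reward_delay explicitly. A sketch under that assumption (and assuming v6's grid config otherwise matched v5's, which this diff does not show):

```python
from foragax.registry import make

# Hypothetical usage: ForagaxWeather-v5 with a 10-step reward delay,
# roughly standing in for the removed ForagaxWeather-v6 (which
# hard-coded digestion_steps = 10).
env = make(
    "ForagaxWeather-v5",
    repeat=500,       # steps each temperature value repeats for (the default)
    reward_delay=10,  # food rewards arrive 10 steps after collection
)
```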