continual-foragax 0.10.3__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.10.3
3
+ Version: 0.12.0
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -1,8 +1,8 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=EyWJ67GLCJl2Re8s4AoVRGXbccgOGn0J_XxuuWJaGrE,18702
4
- foragax/objects.py,sha256=CyBxrykTxpHCI_2hE9jE8mG4TU8R7VxzKdQ5mtxkEqU,6004
5
- foragax/registry.py,sha256=7_RDXvm_3RNO7culBLGkE0jH8Wk_q6jbMv72dZx4JO8,2722
3
+ foragax/env.py,sha256=EyT6KY0d0mXNh6yw10V-8SJVAdyPAGKtRdFV4wXq-JM,19836
4
+ foragax/objects.py,sha256=_TO7tBFCzH5L3JwzHK4bPIh090mtlWBSXcPsZ4y0gHg,6745
5
+ foragax/registry.py,sha256=CjD1eRlY5956royMUPLYuZ3twVcuI9pbRN9M0TEagmo,3437
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
7
7
  foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
8
8
  foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.10.3.dist-info/METADATA,sha256=km5hYBXYDVBJ5a6VhUN1Fjb92QHEQvjG-eThHhlPp84,4897
132
- continual_foragax-0.10.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.10.3.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.10.3.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.10.3.dist-info/RECORD,,
131
+ continual_foragax-0.12.0.dist-info/METADATA,sha256=yHZYcu0knPEtTUHbAcIGNx9EEqkZpCs_UT3TwAh6bOM,4897
132
+ continual_foragax-0.12.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.12.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.12.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.12.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -13,7 +13,13 @@ import jax.numpy as jnp
13
13
  from flax import struct
14
14
  from gymnax.environments import environment, spaces
15
15
 
16
- from foragax.objects import AGENT, EMPTY, BaseForagaxObject, WeatherObject
16
+ from foragax.objects import (
17
+ AGENT,
18
+ EMPTY,
19
+ PADDING,
20
+ BaseForagaxObject,
21
+ WeatherObject,
22
+ )
17
23
  from foragax.rendering import apply_true_borders
18
24
  from foragax.weather import get_temperature
19
25
 
@@ -64,6 +70,7 @@ class ForagaxEnv(environment.Environment):
64
70
  aperture_size: Union[Tuple[int, int], int] = (5, 5),
65
71
  objects: Tuple[BaseForagaxObject, ...] = (),
66
72
  biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
73
+ nowrap: bool = False,
67
74
  ):
68
75
  super().__init__()
69
76
  if isinstance(size, int):
@@ -73,7 +80,10 @@ class ForagaxEnv(environment.Environment):
73
80
  if isinstance(aperture_size, int):
74
81
  aperture_size = (aperture_size, aperture_size)
75
82
  self.aperture_size = aperture_size
83
+ self.nowrap = nowrap
76
84
  objects = (EMPTY,) + objects
85
+ if self.nowrap:
86
+ objects = objects + (PADDING,)
77
87
  self.objects = objects
78
88
  self.weather_object = None
79
89
  for o in objects:
@@ -122,8 +132,12 @@ class ForagaxEnv(environment.Environment):
122
132
  direction = DIRECTIONS[action]
123
133
  new_pos = state.pos + direction
124
134
 
125
- # Wrap around boundaries
126
- new_pos = jnp.mod(new_pos, jnp.array(self.size))
135
+ if self.nowrap:
136
+ in_bounds = jnp.all((new_pos >= 0) & (new_pos < jnp.array(self.size)))
137
+ new_pos = jnp.where(in_bounds, new_pos, state.pos)
138
+ else:
139
+ # Wrap around boundaries
140
+ new_pos = jnp.mod(new_pos, jnp.array(self.size))
127
141
 
128
142
  # Check for blocking objects
129
143
  obj_at_new_pos = current_objects[new_pos[1], new_pos[0]]
@@ -288,10 +302,26 @@ class ForagaxEnv(environment.Environment):
288
302
 
289
303
  y_offsets = jnp.arange(ap_h)
290
304
  x_offsets = jnp.arange(ap_w)
291
- y_coords = jnp.mod(start_y + y_offsets[:, None], self.size[1])
292
- x_coords = jnp.mod(start_x + x_offsets, self.size[0])
305
+ y_coords = start_y + y_offsets[:, None]
306
+ x_coords = start_x + x_offsets
307
+
308
+ if self.nowrap:
309
+ # Clamp coordinates to bounds
310
+ y_coords_clamped = jnp.clip(y_coords, 0, self.size[1] - 1)
311
+ x_coords_clamped = jnp.clip(x_coords, 0, self.size[0] - 1)
312
+ values = object_grid[y_coords_clamped, x_coords_clamped]
313
+ # Mark out-of-bounds positions with -1
314
+ y_out = (y_coords < 0) | (y_coords >= self.size[1])
315
+ x_out = (x_coords < 0) | (x_coords >= self.size[0])
316
+ out_of_bounds = y_out | x_out
317
+ padding_index = self.object_ids[-1]
318
+ aperture = jnp.where(out_of_bounds, padding_index, values)
319
+ else:
320
+ y_coords_mod = jnp.mod(y_coords, self.size[1])
321
+ x_coords_mod = jnp.mod(x_coords, self.size[0])
322
+ aperture = object_grid[y_coords_mod, x_coords_mod]
293
323
 
294
- return object_grid[y_coords, x_coords]
324
+ return aperture
295
325
 
296
326
  @partial(jax.jit, static_argnames=("self", "render_mode"))
297
327
  def render(self, state: EnvState, params: EnvParams, render_mode: str = "world"):
@@ -404,8 +434,9 @@ class ForagaxObjectEnv(ForagaxEnv):
404
434
  aperture_size: Union[Tuple[int, int], int] = (5, 5),
405
435
  objects: Tuple[BaseForagaxObject, ...] = (),
406
436
  biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
437
+ nowrap: bool = False,
407
438
  ):
408
- super().__init__(size, aperture_size, objects, biomes)
439
+ super().__init__(size, aperture_size, objects, biomes, nowrap)
409
440
 
410
441
  # Compute unique colors and mapping for partial observability
411
442
  # Exclude EMPTY (index 0) from color channels
foragax/objects.py CHANGED
@@ -179,6 +179,39 @@ DEATHCAP = DefaultForagaxObject(
179
179
  )
180
180
  AGENT = DefaultForagaxObject(name="agent", blocking=True, color=(0, 0, 255))
181
181
 
182
+ PADDING = DefaultForagaxObject(name="padding", blocking=True, color=(0, 0, 0))
183
+
184
+ BROWN_MOREL = NormalRegenForagaxObject(
185
+ name="brown_morel",
186
+ reward=30.0,
187
+ collectable=True,
188
+ color=(63, 30, 25),
189
+ mean_regen_delay=300,
190
+ std_regen_delay=30,
191
+ )
192
+ BROWN_OYSTER = NormalRegenForagaxObject(
193
+ name="brown_oyster",
194
+ reward=1.0,
195
+ collectable=True,
196
+ color=(63, 30, 25),
197
+ mean_regen_delay=10,
198
+ std_regen_delay=1,
199
+ )
200
+ GREEN_DEATHCAP = DefaultForagaxObject(
201
+ name="green_deathcap",
202
+ reward=-1.0,
203
+ collectable=True,
204
+ color=(0, 255, 0),
205
+ regen_delay=(10, 10),
206
+ )
207
+ GREEN_FAKE = DefaultForagaxObject(
208
+ name="green_fake",
209
+ reward=0.0,
210
+ collectable=True,
211
+ color=(0, 255, 0),
212
+ regen_delay=(10, 10),
213
+ )
214
+
182
215
 
183
216
  def create_weather_objects(
184
217
  file_index: int = 0, repeat: int = 500, multiplier: float = 1.0
foragax/registry.py CHANGED
@@ -10,6 +10,10 @@ from foragax.env import (
10
10
  ForagaxWorldEnv,
11
11
  )
12
12
  from foragax.objects import (
13
+ BROWN_MOREL,
14
+ BROWN_OYSTER,
15
+ GREEN_DEATHCAP,
16
+ GREEN_FAKE,
13
17
  LARGE_MOREL,
14
18
  LARGE_OYSTER,
15
19
  MEDIUM_MOREL,
@@ -28,6 +32,19 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
28
32
  Biome(start=(0, 10), stop=(15, 12), object_frequencies=(0.0, 0.5)),
29
33
  ),
30
34
  },
35
+ "ForagaxTwoBiome-v1": {
36
+ "size": (15, 15),
37
+ "aperture_size": None,
38
+ "objects": (BROWN_MOREL, BROWN_OYSTER, GREEN_DEATHCAP, GREEN_FAKE),
39
+ "biomes": (
40
+ # Morel biome
41
+ Biome(start=(3, 0), stop=(5, 15), object_frequencies=(0.5, 0.0, 0.25, 0.0)),
42
+ # Oyster biome
43
+ Biome(
44
+ start=(10, 0), stop=(12, 15), object_frequencies=(0.0, 0.5, 0.0, 0.25)
45
+ ),
46
+ ),
47
+ },
31
48
  "ForagaxTwoBiomeSmall-v1": {
32
49
  "size": (16, 8),
33
50
  "aperture_size": None,
@@ -58,6 +75,7 @@ def make(
58
75
  observation_type: str = "object",
59
76
  aperture_size: Optional[Tuple[int, int]] = (5, 5),
60
77
  file_index: int = 0,
78
+ nowrap: bool = False,
61
79
  ) -> ForagaxEnv:
62
80
  """Create a Foragax environment.
63
81
 
@@ -66,6 +84,8 @@ def make(
66
84
  observation_type: The type of observation to use. One of "object", "rgb", or "world".
67
85
  aperture_size: The size of the agent's observation aperture. If None, the default
68
86
  for the environment is used.
87
+ file_index: File index for weather objects.
88
+ nowrap: If True, disables wrapping around environment boundaries.
69
89
 
70
90
  Returns:
71
91
  A Foragax environment instance.
@@ -76,6 +96,7 @@ def make(
76
96
  config = ENV_CONFIGS[env_id].copy()
77
97
 
78
98
  config["aperture_size"] = aperture_size
99
+ config["nowrap"] = nowrap
79
100
 
80
101
  if env_id.startswith("ForagaxWeather"):
81
102
  hot, cold = create_weather_objects(file_index=file_index)