continual-foragax 0.17.0__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.17.0
3
+ Version: 0.19.0
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -1,8 +1,8 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=EyT6KY0d0mXNh6yw10V-8SJVAdyPAGKtRdFV4wXq-JM,19836
3
+ foragax/env.py,sha256=APPu0r31A9SlvmJYkoZzsrh59YFVoLW3Rmyb3JT-dHw,20613
4
4
  foragax/objects.py,sha256=iDFo_2CjpgErAm3QdQt5ixmQ8jdIvl7siIPqzGwVcGk,7665
5
- foragax/registry.py,sha256=HDx820l1V3opeWfZ9altCwQgwZrseeVI9Rgev3ZAnRo,6525
5
+ foragax/registry.py,sha256=DbOCKcoE9bfLEAVxq4ts1MasYYdcQZVxCK_WuQvCVlc,7781
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
7
7
  foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
8
8
  foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.17.0.dist-info/METADATA,sha256=9ZfDpyM1n9Bp6SdL02kIGJkMM0HCgS_fpUjWZsU0aLE,4897
132
- continual_foragax-0.17.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.17.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.17.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.17.0.dist-info/RECORD,,
131
+ continual_foragax-0.19.0.dist-info/METADATA,sha256=y9A7sTHfOB84ycKWis5RsA9qkhyIpRdNw9t0fbuA8eo,4897
132
+ continual_foragax-0.19.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.19.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.19.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.19.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -355,17 +355,31 @@ class ForagaxEnv(environment.Environment):
355
355
  # Create indices for the aperture
356
356
  y_offsets = jnp.arange(ap_h)
357
357
  x_offsets = jnp.arange(ap_w)
358
- y_coords = jnp.mod(start_y + y_offsets[:, None], self.size[1])
359
- x_coords = jnp.mod(start_x + x_offsets, self.size[0])
360
-
361
- # Get original colors from the aperture area
362
- original_colors = img[y_coords, x_coords]
363
-
364
- # Calculate tinted colors
365
- tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
366
-
367
- # Update the image with tinted colors
368
- img = img.at[y_coords, x_coords].set(tinted_colors)
358
+ y_coords_original = start_y + y_offsets[:, None]
359
+ x_coords_original = start_x + x_offsets
360
+
361
+ if self.nowrap:
362
+ y_coords = jnp.clip(y_coords_original, 0, self.size[1] - 1)
363
+ x_coords = jnp.clip(x_coords_original, 0, self.size[0] - 1)
364
+ in_bounds = (
365
+ (y_coords_original >= 0)
366
+ & (y_coords_original < self.size[1])
367
+ & (x_coords_original >= 0)
368
+ & (x_coords_original < self.size[0])
369
+ )
370
+ original_colors = img[y_coords, x_coords]
371
+ tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
372
+ img = img.at[y_coords, x_coords].set(
373
+ jnp.where(
374
+ in_bounds[..., None], tinted_colors, img[y_coords, x_coords]
375
+ )
376
+ )
377
+ else:
378
+ y_coords = jnp.mod(y_coords_original, self.size[1])
379
+ x_coords = jnp.mod(x_coords_original, self.size[0])
380
+ original_colors = img[y_coords, x_coords]
381
+ tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
382
+ img = img.at[y_coords, x_coords].set(tinted_colors)
369
383
 
370
384
  # Agent color
371
385
  img = img.at[state.pos[1], state.pos[0]].set(jnp.array(AGENT.color))
foragax/registry.py CHANGED
@@ -119,6 +119,27 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
119
119
  ),
120
120
  "nowrap": True,
121
121
  },
122
+ "ForagaxTwoBiome-v6": {
123
+ "size": (15, 15),
124
+ "aperture_size": None,
125
+ "objects": (BROWN_MOREL_2, BROWN_OYSTER, GREEN_DEATHCAP_3, GREEN_FAKE_2),
126
+ "biomes": (
127
+ # Morel biome
128
+ Biome(start=(3, 0), stop=(5, 15), object_frequencies=(0.25, 0.0, 0.5, 0.0)),
129
+ # Oyster biome
130
+ Biome(
131
+ start=(10, 0), stop=(12, 15), object_frequencies=(0.0, 0.25, 0.0, 0.5)
132
+ ),
133
+ ),
134
+ "nowrap": True,
135
+ },
136
+ "ForagaxTwoBiome-v7": {
137
+ "size": None,
138
+ "aperture_size": None,
139
+ "objects": (BROWN_MOREL_2, BROWN_OYSTER, GREEN_DEATHCAP_3, GREEN_FAKE_2),
140
+ "biomes": None,
141
+ "nowrap": True,
142
+ },
122
143
  "ForagaxTwoBiomeSmall-v1": {
123
144
  "size": (16, 8),
124
145
  "aperture_size": None,
@@ -183,11 +204,23 @@ def make(
183
204
  raise ValueError(f"Unknown env_id: {env_id}")
184
205
 
185
206
  config = ENV_CONFIGS[env_id].copy()
186
-
207
+ if isinstance(aperture_size, int):
208
+ aperture_size = (aperture_size, aperture_size)
187
209
  config["aperture_size"] = aperture_size
188
210
  if nowrap is not None:
189
211
  config["nowrap"] = nowrap
190
212
 
213
+ if env_id == "ForagaxTwoBiome-v7":
214
+ margin = aperture_size[1] // 2 + 1
215
+ width = 2 * margin + 9
216
+ config["size"] = (width, 15)
217
+ config["biomes"] = (
218
+ # Morel biome
219
+ Biome(start=(margin, 0), stop=(margin + 2, 15), object_frequencies=(0.25, 0.0, 0.5, 0.0)),
220
+ # Oyster biome
221
+ Biome(start=(margin + 7, 0), stop=(margin + 9, 15), object_frequencies=(0.0, 0.25, 0.0, 0.5)),
222
+ )
223
+
191
224
  if env_id.startswith("ForagaxWeather"):
192
225
  same_color = env_id == "ForagaxWeather-v2"
193
226
  hot, cold = create_weather_objects(file_index=file_index, same_color=same_color)