continual-foragax 0.24.2__py3-none-any.whl → 0.26.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: continual-foragax
3
- Version: 0.24.2
3
+ Version: 0.26.0
4
4
  Summary: A continual reinforcement learning benchmark
5
5
  Author-email: Steven Tang <stang5@ualberta.ca>
6
6
  Requires-Python: >=3.8
@@ -1,8 +1,8 @@
1
1
  foragax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  foragax/colors.py,sha256=rqNPiywP4Nvr0POhsGpasRk-nMMTS3DOwFRUgperlUk,2065
3
- foragax/env.py,sha256=t0tb4IiDFiC3xJLLv7s_IFOYgoyRdje1PcnEvJNcyVM,23578
4
- foragax/objects.py,sha256=CblI0NI7PQzeKk3MZa8sbaa9wB4pc_8CyOGbJOFWytE,9391
5
- foragax/registry.py,sha256=iK-Pc9m7NGTds6UidE92LGNMGoMDwJHP09RcO5AwOZ0,12971
3
+ foragax/env.py,sha256=w9KDLRm0xvJjnRP0C4gynk1At6LVdrz5xBEtDC3ePvM,24958
4
+ foragax/objects.py,sha256=FCLZ-8d7qq9VMTG6G-TaRt842-sjgB0-DH0IoHwwngI,9503
5
+ foragax/registry.py,sha256=7-6VDN1MKVEvX_1u5G8NkSpv9BccEmtjJa77-OTNg3A,14324
6
6
  foragax/rendering.py,sha256=bms7wvBZTofoR-K-2QD2Ggeed7Viw8uwAEiEpEM3eSo,2768
7
7
  foragax/weather.py,sha256=KNAiwuFz8V__6G75vZIWQKPocLzXqxXn-Vt4TbHIpcA,1258
8
8
  foragax/data/ECA_non-blended_custom/TG_SOUID100897.txt,sha256=N7URbX6VlCZvCboUogYjMzy1I-0cfNPOn0QTLSHHfQ0,1776751
@@ -128,8 +128,8 @@ foragax/data/ECA_non-blended_custom/TG_SOUID156887.txt,sha256=juzTPgJoJxfqmZkorL
128
128
  foragax/data/ECA_non-blended_custom/elements.txt,sha256=OtcUBoDAHxuln79BPKGu0tsQxG_5G2BfAX3Ck130kEA,4507
129
129
  foragax/data/ECA_non-blended_custom/metadata.txt,sha256=nudnmOCy5cPJfSXt_IjyX0S5-T7NkCZREICZSimqeqc,48260
130
130
  foragax/data/ECA_non-blended_custom/sources.txt,sha256=1j3lSmINAoCMqPqFrHfZJriOz6sTYZNOhXzUwvTLas0,20857
131
- continual_foragax-0.24.2.dist-info/METADATA,sha256=UcZgs3ahcEv6m-4DIRv8E9giWBeR_g9adHROvoASUpI,4897
132
- continual_foragax-0.24.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
- continual_foragax-0.24.2.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
- continual_foragax-0.24.2.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
- continual_foragax-0.24.2.dist-info/RECORD,,
131
+ continual_foragax-0.26.0.dist-info/METADATA,sha256=9cCIfMlbC9CAiRZBOoeiZwO45j-v60ZHjn3kjRTu2IE,4897
132
+ continual_foragax-0.26.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
133
+ continual_foragax-0.26.0.dist-info/entry_points.txt,sha256=Qiu6iE_XudrDO_bVAMeA435h4PO9ourt8huvSHiuMPc,41
134
+ continual_foragax-0.26.0.dist-info/top_level.txt,sha256=-z3SDK6RfLIcLI24n8rdbeFzlVY3hunChzlu-v1Fncs,8
135
+ continual_foragax-0.26.0.dist-info/RECORD,,
foragax/env.py CHANGED
@@ -1,12 +1,12 @@
1
- """JAX implementation of Foragax environment.
1
+ """JAX implementation of Forager environment.
2
2
 
3
- Source: https://github.com/andnp/Foragax
3
+ Source: https://github.com/andnp/Forager
4
4
  """
5
5
 
6
6
  from dataclasses import dataclass
7
7
  from enum import IntEnum
8
8
  from functools import partial
9
- from typing import Any, Dict, Tuple, Union
9
+ from typing import Any, Dict, Optional, Tuple, Union
10
10
 
11
11
  import jax
12
12
  import jax.numpy as jnp
@@ -75,6 +75,7 @@ class ForagaxEnv(environment.Environment):
75
75
  biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
76
76
  nowrap: bool = False,
77
77
  deterministic_spawn: bool = False,
78
+ teleport_interval: Optional[int] = None,
78
79
  ):
79
80
  super().__init__()
80
81
  self._name = name
@@ -87,6 +88,7 @@ class ForagaxEnv(environment.Environment):
87
88
  self.aperture_size = aperture_size
88
89
  self.nowrap = nowrap
89
90
  self.deterministic_spawn = deterministic_spawn
91
+ self.teleport_interval = teleport_interval
90
92
  objects = (EMPTY,) + objects
91
93
  if self.nowrap:
92
94
  objects = objects + (PADDING,)
@@ -117,6 +119,16 @@ class ForagaxEnv(environment.Environment):
117
119
  [b.stop if b.stop is not None else (-1, -1) for b in biomes]
118
120
  )
119
121
  self.biome_sizes = np.prod(self.biome_stops - self.biome_starts, axis=1)
122
+ self.biome_starts_jax = jnp.array(self.biome_starts)
123
+ self.biome_stops_jax = jnp.array(self.biome_stops)
124
+ biome_centers = []
125
+ for i in range(len(self.biome_starts)):
126
+ start = self.biome_starts[i]
127
+ stop = self.biome_stops[i]
128
+ center_x = (start[0] + stop[0] - 1) // 2
129
+ center_y = (start[1] + stop[1] - 1) // 2
130
+ biome_centers.append((center_x, center_y))
131
+ self.biome_centers_jax = jnp.array(biome_centers)
120
132
  self.biome_masks = []
121
133
  for i in range(self.biome_object_frequencies.shape[0]):
122
134
  # Create mask for the biome
@@ -174,6 +186,23 @@ class ForagaxEnv(environment.Environment):
174
186
  is_blocking = self.object_blocking[obj_at_new_pos]
175
187
  pos = jax.lax.select(is_blocking, state.pos, new_pos)
176
188
 
189
+ # Check for automatic teleport
190
+ if self.teleport_interval is not None:
191
+ should_teleport = jnp.mod(state.time + 1, self.teleport_interval) == 0
192
+ else:
193
+ should_teleport = False
194
+
195
+ def teleport_fn():
196
+ # Calculate squared distances from current position to each biome center
197
+ diffs = self.biome_centers_jax - pos
198
+ distances = jnp.sum(diffs**2, axis=1)
199
+ # Find the index of the furthest biome center
200
+ furthest_idx = jnp.argmax(distances)
201
+ new_pos = self.biome_centers_jax[furthest_idx]
202
+ return new_pos
203
+
204
+ pos = jax.lax.cond(should_teleport, teleport_fn, lambda: pos)
205
+
177
206
  # 2. HANDLE COLLISIONS AND REWARDS
178
207
  obj_at_pos = current_objects[pos[1], pos[0]]
179
208
  key, subkey = jax.random.split(key)
@@ -503,6 +532,7 @@ class ForagaxObjectEnv(ForagaxEnv):
503
532
  biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
504
533
  nowrap: bool = False,
505
534
  deterministic_spawn: bool = False,
535
+ teleport_interval: Optional[int] = None,
506
536
  ):
507
537
  super().__init__(
508
538
  name,
@@ -512,6 +542,7 @@ class ForagaxObjectEnv(ForagaxEnv):
512
542
  biomes,
513
543
  nowrap,
514
544
  deterministic_spawn,
545
+ teleport_interval,
515
546
  )
516
547
 
517
548
  # Compute unique colors and mapping for partial observability
foragax/objects.py CHANGED
@@ -318,6 +318,7 @@ def create_weather_objects(
318
318
  repeat: int = 500,
319
319
  multiplier: float = 1.0,
320
320
  same_color: bool = False,
321
+ random_respawn: bool = False,
321
322
  ):
322
323
  """Create HOT and COLD WeatherObject instances using the specified file.
323
324
 
@@ -346,6 +347,7 @@ def create_weather_objects(
346
347
  repeat=repeat,
347
348
  multiplier=multiplier,
348
349
  color=hot_color,
350
+ random_respawn=random_respawn,
349
351
  )
350
352
 
351
353
  cold_color = hot_color if same_color else (0, 255, 255)
@@ -355,6 +357,7 @@ def create_weather_objects(
355
357
  repeat=repeat,
356
358
  multiplier=-multiplier,
357
359
  color=cold_color,
360
+ random_respawn=random_respawn,
358
361
  )
359
362
 
360
363
  return hot, cold
foragax/registry.py CHANGED
@@ -64,6 +64,27 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
64
64
  "biomes": None,
65
65
  "nowrap": True,
66
66
  },
67
+ "ForagaxWeather-v4": {
68
+ "size": None,
69
+ "aperture_size": None,
70
+ "objects": None,
71
+ "biomes": None,
72
+ "nowrap": True,
73
+ "deterministic_spawn": True,
74
+ },
75
+ "ForagaxWeather-v5": {
76
+ "size": (15, 15),
77
+ "aperture_size": None,
78
+ "objects": None,
79
+ "biomes": (
80
+ # Hot biome
81
+ Biome(start=(0, 3), stop=(15, 5), object_frequencies=(0.5, 0.0)),
82
+ # Cold biome
83
+ Biome(start=(0, 10), stop=(15, 12), object_frequencies=(0.0, 0.5)),
84
+ ),
85
+ "nowrap": False,
86
+ "deterministic_spawn": True,
87
+ },
67
88
  "ForagaxTwoBiome-v1": {
68
89
  "size": (15, 15),
69
90
  "aperture_size": None,
@@ -272,6 +293,19 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
272
293
  "nowrap": True,
273
294
  "deterministic_spawn": True,
274
295
  },
296
+ "ForagaxTwoBiome-v16": {
297
+ "size": None,
298
+ "aperture_size": None,
299
+ "objects": (
300
+ BROWN_MOREL_UNIFORM_RANDOM,
301
+ BROWN_OYSTER_UNIFORM_RANDOM,
302
+ GREEN_DEATHCAP_UNIFORM_RANDOM,
303
+ GREEN_FAKE_UNIFORM_RANDOM,
304
+ ),
305
+ "biomes": None,
306
+ "nowrap": True,
307
+ "deterministic_spawn": True,
308
+ },
275
309
  "ForagaxTwoBiomeSmall-v1": {
276
310
  "size": (16, 8),
277
311
  "aperture_size": None,
@@ -348,6 +382,7 @@ def make(
348
382
  "ForagaxTwoBiome-v9",
349
383
  "ForagaxTwoBiome-v10",
350
384
  "ForagaxTwoBiome-v15",
385
+ "ForagaxTwoBiome-v16",
351
386
  ):
352
387
  margin = aperture_size[1] // 2 + 1
353
388
  width = 2 * margin + 9
@@ -386,7 +421,7 @@ def make(
386
421
  ),
387
422
  )
388
423
 
389
- if env_id == "ForagaxWeather-v3":
424
+ if env_id in ("ForagaxWeather-v3", "ForagaxWeather-v4"):
390
425
  margin = aperture_size[1] // 2 + 1
391
426
  width = 2 * margin + 9
392
427
  config["size"] = (15, width)
@@ -406,10 +441,21 @@ def make(
406
441
  )
407
442
 
408
443
  if env_id.startswith("ForagaxWeather"):
409
- same_color = env_id in ("ForagaxWeather-v2", "ForagaxWeather-v3")
410
- hot, cold = create_weather_objects(file_index=file_index, same_color=same_color)
444
+ same_color = env_id in (
445
+ "ForagaxWeather-v2",
446
+ "ForagaxWeather-v3",
447
+ "ForagaxWeather-v4",
448
+ "ForagaxWeather-v5",
449
+ )
450
+ random_respawn = env_id in ("ForagaxWeather-v4", "ForagaxWeather-v5")
451
+ hot, cold = create_weather_objects(
452
+ file_index=file_index, same_color=same_color, random_respawn=random_respawn
453
+ )
411
454
  config["objects"] = (hot, cold)
412
455
 
456
+ if env_id == "ForagaxTwoBiome-v16":
457
+ config["teleport_interval"] = 10000
458
+
413
459
  env_class_map = {
414
460
  "object": ForagaxObjectEnv,
415
461
  "rgb": ForagaxRGBEnv,