continual-foragax 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foragax/objects.py CHANGED
@@ -28,9 +28,36 @@ class BaseForagaxObject:
28
28
  self.max_reward_delay = max_reward_delay
29
29
  self.expiry_time = expiry_time
30
30
 
31
+ def get_state(self, key: jax.Array) -> jax.Array:
32
+ """Generate per-object reward parameters to store in the environment state.
33
+
34
+ By default, objects don't use per-instance params. Override this method
35
+ to provide per-instance parameters that will be stored in object_state_grid.
36
+
37
+ Args:
38
+ key: JAX random key for parameter generation
39
+
40
+ Returns:
41
+ Array of parameters (can be empty array for objects without params)
42
+ """
43
+ return jnp.array([], dtype=jnp.float32)
44
+
45
+ @abc.abstractmethod
46
+ def reward(
47
+ self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
48
+ ) -> float:
49
+ """Reward function.
50
+
51
+ Args:
52
+ clock: Current time step
53
+ rng: JAX random key
54
+ params: Optional per-object parameters from object_state_grid
55
+ """
56
+ raise NotImplementedError
57
+
31
58
  @abc.abstractmethod
32
- def reward(self, clock: int, rng: jax.Array) -> float:
33
- """Reward function."""
59
+ def regen_delay(self, clock: int, rng: jax.Array) -> int:
60
+ """Regeneration delay function."""
34
61
  raise NotImplementedError
35
62
 
36
63
  @abc.abstractmethod
@@ -77,7 +104,9 @@ class DefaultForagaxObject(BaseForagaxObject):
77
104
  self.reward_delay_val = reward_delay
78
105
  self.expiry_regen_delay_range = expiry_regen_delay
79
106
 
80
- def reward(self, clock: int, rng: jax.Array) -> float:
107
+ def reward(
108
+ self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
109
+ ) -> float:
81
110
  """Default reward function."""
82
111
  return self.reward_val
83
112
 
@@ -186,11 +215,151 @@ class WeatherObject(NormalRegenForagaxObject):
186
215
  self.rewards = rewards * multiplier
187
216
  self.repeat = repeat
188
217
 
189
- def reward(self, clock: int, rng: jax.Array) -> float:
218
+ def reward(
219
+ self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
220
+ ) -> float:
190
221
  """Reward is based on temperature."""
191
222
  return get_temperature(self.rewards, clock, self.repeat)
192
223
 
193
224
 
225
class FourierObject(BaseForagaxObject):
    """Collectable object whose reward follows a per-instance Fourier series.

    Each instance's parameters (period, normalisation bounds and harmonic
    coefficients) are produced by :meth:`get_state` and handed back to
    :meth:`reward` through ``params``.

    The series has ``num_fourier_terms`` harmonics whose coefficients are
    standard-normal draws scaled by ``1/n``.  Rewards are min-max normalised
    to roughly ``[-1, 1]`` and then multiplied by ``base_magnitude``.

    The object never regenerates on its own (see :meth:`regen_delay`).
    NOTE(review): respawning is presumably driven biome-wide by the
    environment once a consumption threshold is reached; that logic is not
    part of this class -- confirm against the environment code.
    """

    # Number of points on [0, 2*pi] used to estimate the series' min and max.
    _NUM_NORMALIZATION_SAMPLES = 1000

    def __init__(
        self,
        name: str,
        num_fourier_terms: int = 10,
        base_magnitude: float = 1.0,
        color: Tuple[int, int, int] = (0, 0, 0),
        reward_delay: int = 0,
        max_reward_delay: Optional[int] = None,
    ):
        """Create a Fourier-reward object.

        Args:
            name: Object name.
            num_fourier_terms: Number of harmonics in the series.
            base_magnitude: Factor applied to the normalised reward.
            color: RGB colour of the object.
            reward_delay: Value returned by :meth:`reward_delay`.
            max_reward_delay: Upper bound on the reward delay; defaults to
                ``reward_delay`` when omitted.
        """
        if max_reward_delay is None:
            max_reward_delay = reward_delay
        super().__init__(
            name=name,
            blocking=False,
            collectable=True,
            color=color,
            random_respawn=False,  # Objects don't respawn individually
            max_reward_delay=max_reward_delay,
            expiry_time=None,
        )
        self.num_fourier_terms = num_fourier_terms
        self.base_magnitude = base_magnitude
        self.reward_delay_val = reward_delay

    @staticmethod
    def _evaluate_series(
        a_coeffs: jax.Array, b_coeffs: jax.Array, t: jax.Array
    ) -> jax.Array:
        """Evaluate ``sum_n a_n*cos(n*t) + b_n*sin(n*t)`` for ``n = 1..N``.

        Args:
            a_coeffs: Cosine coefficients, shape ``(N,)``.
            b_coeffs: Sine coefficients, shape ``(N,)``.
            t: Angle(s) in radians; a scalar or an array of any shape.

        Returns:
            Array with the same shape as ``t``.
        """
        t = jnp.asarray(t)
        num_terms = a_coeffs.shape[0]
        # Put the harmonic index on a new leading axis so it broadcasts
        # against every element of ``t``.
        lead_shape = (num_terms,) + (1,) * t.ndim
        harmonics = jnp.arange(1, num_terms + 1, dtype=jnp.float32).reshape(
            lead_shape
        )
        angles = harmonics * t
        # Element-wise products followed by a sum (rather than a matmul) so
        # the result is not subject to reduced matmul precision on
        # accelerators.
        terms = a_coeffs.reshape(lead_shape) * jnp.cos(angles) + b_coeffs.reshape(
            lead_shape
        ) * jnp.sin(angles)
        return jnp.sum(terms, axis=0)

    def get_state(self, key: jax.Array) -> jax.Array:
        """Generate random Fourier series parameters.

        Args:
            key: JAX random key used to sample the period and coefficients.

        Returns:
            Array of shape ``(3 + 2*num_fourier_terms,)`` laid out as
            ``[period, y_min, y_max, a1, b1, a2, b2, ...]``.
        """
        # The key is split sequentially, in this exact order, so that a given
        # key always yields the same random draws.
        key, period_key = jax.random.split(key)
        # Integer period drawn uniformly from [10, 1000] (upper bound of
        # randint is exclusive), stored as float32.
        period = jax.random.randint(period_key, (), 10, 1001).astype(jnp.float32)

        # Coefficients: standard normal scaled by 1/n.
        key, a_key = jax.random.split(key)
        n_values = jnp.arange(1, self.num_fourier_terms + 1, dtype=jnp.float32)
        a_coeffs = jax.random.normal(a_key, (self.num_fourier_terms,)) / n_values

        key, b_key = jax.random.split(key)
        b_coeffs = jax.random.normal(b_key, (self.num_fourier_terms,)) / n_values

        # Estimate the series' range over one full cycle for normalisation.
        t_samples = jnp.linspace(0, 2 * jnp.pi, self._NUM_NORMALIZATION_SAMPLES)
        y_samples = self._evaluate_series(a_coeffs, b_coeffs, t_samples)
        y_min = jnp.min(y_samples)
        y_max = jnp.max(y_samples)

        # Interleave to [a1, b1, a2, b2, ...] as float32.
        ab_interleaved = (
            jnp.stack([a_coeffs, b_coeffs], axis=1).reshape(-1).astype(jnp.float32)
        )
        return jnp.concatenate([jnp.array([period, y_min, y_max]), ab_interleaved])

    def reward(
        self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
    ) -> float:
        """Compute reward from Fourier series parameters.

        Args:
            clock: Current timestep.
            rng: Random key (unused for Fourier objects).
            params: Array laid out as
                ``[period, y_min, y_max, a1, b1, a2, b2, ...]`` (the format
                produced by :meth:`get_state`).

        Returns:
            ``0.0`` when ``params`` is missing or empty.  Otherwise the series
            value min-max normalised with the stored ``y_min``/``y_max`` and
            multiplied by ``base_magnitude``.  Because the stored bounds are
            estimated from a finite sample grid, the result lies in
            ``[-base_magnitude, base_magnitude]`` only approximately.
        """
        if params is None or len(params) == 0:
            return 0.0

        period = params[0]
        y_min = params[1]
        y_max = params[2]

        # Map the clock onto one cycle of the object's period: t in [0, 2*pi).
        t = 2.0 * jnp.pi * (clock % period) / period

        # De-interleave [a1, b1, a2, b2, ...]; a trailing unpaired value is
        # ignored.
        ab_coeffs = params[3:]
        n_terms = len(ab_coeffs) // 2
        a_coeffs = ab_coeffs[0 : 2 * n_terms : 2]
        b_coeffs = ab_coeffs[1 : 2 * n_terms : 2]

        raw = self._evaluate_series(a_coeffs, b_coeffs, t)

        # Min-max normalisation: 2 * (x - min) / (max - min) - 1.
        span = y_max - y_min
        # A (near-)constant series carries no signal; report 0 for it.
        is_constant = jnp.abs(span) < 1e-8
        safe_span = jnp.maximum(span, 1e-8)  # Avoid division by zero
        return jnp.where(
            is_constant,
            0.0,
            (2.0 * (raw - y_min) / safe_span - 1.0) * self.base_magnitude,
        )

    def reward_delay(self, clock: int, rng: jax.Array) -> int:
        """Return the fixed reward delay configured at construction."""
        return self.reward_delay_val

    def regen_delay(self, clock: int, rng: jax.Array) -> int:
        """No individual regeneration: return the largest int32 value."""
        return jnp.iinfo(jnp.int32).max

    def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
        """No expiry regeneration: return the largest int32 value."""
        return jnp.iinfo(jnp.int32).max
361
+
362
+
194
363
  EMPTY = DefaultForagaxObject()
195
364
  WALL = DefaultForagaxObject(name="wall", blocking=True, color=(127, 127, 127))
196
365
  FLOWER = DefaultForagaxObject(
@@ -484,3 +653,181 @@ def create_weather_objects(
484
653
  )
485
654
 
486
655
  return hot, cold
656
+
657
+
658
class SineObject(DefaultForagaxObject):
    """Object whose reward is a constant offset plus a sine wave over time.

    reward(clock) = base_reward + amplitude * sin(2*pi * clock / period + phase)

    The offset sets the sign/level of the object's typical reward while the
    sine term makes it drift over time.  Regeneration and expiry behaviour is
    inherited unchanged from ``DefaultForagaxObject``.
    """

    def __init__(
        self,
        name: str,
        base_reward: float = 0.0,
        amplitude: float = 1.0,
        period: int = 1000,
        phase: float = 0.0,
        regen_delay: Tuple[int, int] = (9, 11),
        color: Tuple[int, int, int] = (0, 0, 0),
        random_respawn: bool = False,
        reward_delay: int = 0,
        max_reward_delay: Optional[int] = None,
        expiry_time: Optional[int] = None,
        expiry_regen_delay: Tuple[int, int] = (9, 11),
    ):
        """Create a sine-reward object.

        Args:
            name: Object name.
            base_reward: Constant offset added to the sine term.
            amplitude: Multiplier of the sine term.
            period: Number of timesteps in one full sine cycle.
            phase: Phase offset of the sine term, in radians.
            regen_delay: Forwarded to ``DefaultForagaxObject``.
            color: Forwarded to ``DefaultForagaxObject``.
            random_respawn: Forwarded to ``DefaultForagaxObject``.
            reward_delay: Forwarded to ``DefaultForagaxObject``.
            max_reward_delay: Forwarded to ``DefaultForagaxObject``.
            expiry_time: Forwarded to ``DefaultForagaxObject``.
            expiry_regen_delay: Forwarded to ``DefaultForagaxObject``.
        """
        # Everything the parent class handles; sine objects are always
        # collectable and use the offset as the parent's nominal reward.
        parent_kwargs = dict(
            name=name,
            reward=base_reward,
            collectable=True,
            regen_delay=regen_delay,
            color=color,
            random_respawn=random_respawn,
            reward_delay=reward_delay,
            max_reward_delay=max_reward_delay,
            expiry_time=expiry_time,
            expiry_regen_delay=expiry_regen_delay,
        )
        super().__init__(**parent_kwargs)

        # Wave parameters used by reward().
        self.base_reward, self.amplitude = base_reward, amplitude
        self.period, self.phase = period, phase

    def reward(
        self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
    ) -> float:
        """Return base_reward + amplitude * sin(2*pi * clock / period + phase).

        Args:
            clock: Current timestep.
            rng: Random key (not used).
            params: Per-object parameters (not used).
        """
        angle = 2.0 * jnp.pi * clock / self.period + self.phase
        return self.base_reward + self.amplitude * jnp.sin(angle)
706
+
707
+
708
def create_fourier_objects(
    num_fourier_terms: int = 10,
    base_magnitude: float = 1.0,
    reward_delay: int = 0,
):
    """Build the "hot" and "cold" FourierObject pair.

    Both objects are configured identically apart from their names
    (``"hot_fourier"`` and ``"cold_fourier"``).

    Args:
        num_fourier_terms: Number of harmonics in each object's reward series.
        base_magnitude: Factor applied to each object's normalised reward.
        reward_delay: Number of steps before reward is delivered.

    Returns:
        A tuple ``(hot, cold)`` of FourierObject instances.
    """
    common = dict(
        num_fourier_terms=num_fourier_terms,
        base_magnitude=base_magnitude,
        color=(0, 0, 0),
        reward_delay=reward_delay,
    )
    return tuple(
        FourierObject(name=object_name, **common)
        for object_name in ("hot_fourier", "cold_fourier")
    )
740
+
741
+
742
def create_sine_biome_objects(
    period: int = 1000,
    amplitude: float = 20.0,
    base_oyster_reward: float = 10.0,
    base_deathcap_reward: float = -10.0,
    regen_delay: Tuple[int, int] = (9, 11),
    reward_delay: int = 0,
    expiry_time: Optional[int] = 500,
    expiry_regen_delay: Tuple[int, int] = (9, 11),
) -> Tuple["SineObject", "SineObject", "SineObject", "SineObject"]:
    """Create objects for the sine-based two-biome environment.

    Biome 1: oyster uses ``+base_oyster_reward`` and death cap uses
    ``+base_deathcap_reward``, both with sine phase 0.
    Biome 2: both base rewards are negated and the sine phase is pi, so its
    sine term is the negative of biome 1's.

    All four objects share the same period, amplitude, delays and expiry
    settings, and are created with ``random_respawn=True``.

    Args:
        period: Period of the sine wave in timesteps.
        amplitude: Amplitude of the sine wave.
        base_oyster_reward: Base reward for oyster in biome 1 (negated in
            biome 2).
        base_deathcap_reward: Base reward for death cap in biome 1 (negated
            in biome 2).
        regen_delay: Tuple of (min, max) for regeneration delay.
        reward_delay: Number of steps before reward is delivered.
        expiry_time: Time steps before object expires (None = no expiry).
        expiry_regen_delay: Tuple of (min, max) for expiry regeneration delay.

    Returns:
        A tuple of (biome1_oyster, biome1_deathcap, biome2_oyster,
        biome2_deathcap).
    """
    # Settings common to every object in both biomes.
    shared = dict(
        amplitude=amplitude,
        period=period,
        regen_delay=regen_delay,
        random_respawn=True,
        reward_delay=reward_delay,
        expiry_time=expiry_time,
        expiry_regen_delay=expiry_regen_delay,
    )
    oyster_color = (124, 61, 81)
    deathcap_color = (0, 255, 0)  # Green

    # (name, base reward, phase, colour) in the order of the returned tuple.
    # Biome 2 negates the base rewards and shifts the phase by 180 degrees.
    specs = (
        ("oyster_sine_1", base_oyster_reward, 0.0, oyster_color),
        ("deathcap_sine_1", base_deathcap_reward, 0.0, deathcap_color),
        ("oyster_sine_2", -base_oyster_reward, jnp.pi, oyster_color),
        ("deathcap_sine_2", -base_deathcap_reward, jnp.pi, deathcap_color),
    )
    biome1_oyster, biome1_deathcap, biome2_oyster, biome2_deathcap = (
        SineObject(
            name=object_name,
            base_reward=base_reward,
            phase=phase,
            color=color,
            **shared,
        )
        for object_name, base_reward, phase, color in specs
    )
    return biome1_oyster, biome1_deathcap, biome2_oyster, biome2_deathcap
foragax/registry.py CHANGED
@@ -31,6 +31,8 @@ from foragax.objects import (
31
31
  LARGE_MOREL,
32
32
  LARGE_OYSTER,
33
33
  MEDIUM_MOREL,
34
+ create_fourier_objects,
35
+ create_sine_biome_objects,
34
36
  create_weather_objects,
35
37
  )
36
38
 
@@ -87,6 +89,21 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
87
89
  "nowrap": False,
88
90
  "deterministic_spawn": True,
89
91
  },
92
+ "ForagaxDiwali-v1": {
93
+ "size": (15, 15),
94
+ "aperture_size": None,
95
+ "objects": None,
96
+ "biomes": (
97
+ # Hot biome
98
+ Biome(start=(0, 3), stop=(15, 5), object_frequencies=(0.5, 0.0)),
99
+ # Cold biome
100
+ Biome(start=(0, 10), stop=(15, 12), object_frequencies=(0.0, 0.5)),
101
+ ),
102
+ "nowrap": False,
103
+ "deterministic_spawn": True,
104
+ "dynamic_biomes": True,
105
+ "biome_consumption_threshold": 0.9,
106
+ },
90
107
  "ForagaxTwoBiome-v1": {
91
108
  "size": (15, 15),
92
109
  "aperture_size": None,
@@ -364,6 +381,22 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
364
381
  ),
365
382
  "nowrap": True,
366
383
  },
384
+ "ForagaxSineTwoBiome-v1": {
385
+ "size": (15, 15),
386
+ "aperture_size": None,
387
+ "objects": None,
388
+ "biomes": (
389
+ # Biome 1 (left): Oyster +10, Death Cap -10 with sine
390
+ Biome(
391
+ start=(3, 0), stop=(5, 15), object_frequencies=(0.25, 0.25, 0.0, 0.0)
392
+ ),
393
+ # Biome 2 (right): Oyster -10, Death Cap +10 with inverted sine
394
+ Biome(
395
+ start=(10, 0), stop=(12, 15), object_frequencies=(0.0, 0.0, 0.25, 0.25)
396
+ ),
397
+ ),
398
+ "nowrap": False,
399
+ },
367
400
  }
368
401
 
369
402
 
@@ -496,6 +529,32 @@ def make(
496
529
  )
497
530
  config["objects"] = (hot, cold)
498
531
 
532
+ if env_id == "ForagaxDiwali-v1":
533
+ config["objects"] = create_fourier_objects(
534
+ num_fourier_terms=10,
535
+ reward_delay=reward_delay,
536
+ )
537
+
538
+ if env_id == "ForagaxSineTwoBiome-v1":
539
+ biome1_oyster, biome1_deathcap, biome2_oyster, biome2_deathcap = (
540
+ create_sine_biome_objects(
541
+ period=1000,
542
+ amplitude=20.0,
543
+ base_oyster_reward=10.0,
544
+ base_deathcap_reward=-10.0,
545
+ regen_delay=(9, 11),
546
+ reward_delay=reward_delay,
547
+ expiry_time=500,
548
+ expiry_regen_delay=(9, 11),
549
+ )
550
+ )
551
+ config["objects"] = (
552
+ biome1_oyster,
553
+ biome1_deathcap,
554
+ biome2_oyster,
555
+ biome2_deathcap,
556
+ )
557
+
499
558
  if env_id == "ForagaxTwoBiome-v16":
500
559
  config["teleport_interval"] = 10000
501
560