continual-foragax 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foragax/objects.py CHANGED
@@ -18,6 +18,7 @@ class BaseForagaxObject:
18
18
  color: Tuple[int, int, int] = (0, 0, 0),
19
19
  random_respawn: bool = False,
20
20
  max_reward_delay: int = 0,
21
+ expiry_time: Optional[int] = None,
21
22
  ):
22
23
  self.name = name
23
24
  self.blocking = blocking
@@ -25,10 +26,38 @@ class BaseForagaxObject:
25
26
  self.color = color
26
27
  self.random_respawn = random_respawn
27
28
  self.max_reward_delay = max_reward_delay
29
+ self.expiry_time = expiry_time
30
+
31
+ def get_state(self, key: jax.Array) -> jax.Array:
32
+ """Generate per-object reward parameters to store in the environment state.
33
+
34
+ By default, objects don't use per-instance params. Override this method
35
+ to provide per-instance parameters that will be stored in object_state_grid.
36
+
37
+ Args:
38
+ key: JAX random key for parameter generation
39
+
40
+ Returns:
41
+ Array of parameters (can be empty array for objects without params)
42
+ """
43
+ return jnp.array([], dtype=jnp.float32)
28
44
 
29
45
  @abc.abstractmethod
30
- def reward(self, clock: int, rng: jax.Array) -> float:
31
- """Reward function."""
46
+ def reward(
47
+ self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
48
+ ) -> float:
49
+ """Reward function.
50
+
51
+ Args:
52
+ clock: Current time step
53
+ rng: JAX random key
54
+ params: Optional per-object parameters from object_state_grid
55
+ """
56
+ raise NotImplementedError
57
+
58
+ @abc.abstractmethod
59
+ def regen_delay(self, clock: int, rng: jax.Array) -> int:
60
+ """Regeneration delay function."""
32
61
  raise NotImplementedError
33
62
 
34
63
  @abc.abstractmethod
@@ -36,6 +65,11 @@ class BaseForagaxObject:
36
65
  """Reward delay function."""
37
66
  raise NotImplementedError
38
67
 
68
+ @abc.abstractmethod
69
+ def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
70
+ """Expiry regeneration delay function."""
71
+ raise NotImplementedError
72
+
39
73
 
40
74
  class DefaultForagaxObject(BaseForagaxObject):
41
75
  """Base class for default objects in the Foragax environment."""
@@ -51,17 +85,28 @@ class DefaultForagaxObject(BaseForagaxObject):
51
85
  random_respawn: bool = False,
52
86
  reward_delay: int = 0,
53
87
  max_reward_delay: Optional[int] = None,
88
+ expiry_time: Optional[int] = None,
89
+ expiry_regen_delay: Tuple[int, int] = (10, 100),
54
90
  ):
55
91
  if max_reward_delay is None:
56
92
  max_reward_delay = reward_delay
57
93
  super().__init__(
58
- name, blocking, collectable, color, random_respawn, max_reward_delay
94
+ name,
95
+ blocking,
96
+ collectable,
97
+ color,
98
+ random_respawn,
99
+ max_reward_delay,
100
+ expiry_time,
59
101
  )
60
102
  self.reward_val = reward
61
103
  self.regen_delay_range = regen_delay
62
104
  self.reward_delay_val = reward_delay
105
+ self.expiry_regen_delay_range = expiry_regen_delay
63
106
 
64
- def reward(self, clock: int, rng: jax.Array) -> float:
107
+ def reward(
108
+ self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
109
+ ) -> float:
65
110
  """Default reward function."""
66
111
  return self.reward_val
67
112
 
@@ -74,6 +119,11 @@ class DefaultForagaxObject(BaseForagaxObject):
74
119
  """Default reward delay function."""
75
120
  return self.reward_delay_val
76
121
 
122
+ def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
123
+ """Default expiry regeneration delay function."""
124
+ min_delay, max_delay = self.expiry_regen_delay_range
125
+ return jax.random.randint(rng, (), min_delay, max_delay)
126
+
77
127
 
78
128
  class NormalRegenForagaxObject(DefaultForagaxObject):
79
129
  """Object with regeneration delay from a normal distribution."""
@@ -89,7 +139,16 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
89
139
  random_respawn: bool = False,
90
140
  reward_delay: int = 0,
91
141
  max_reward_delay: Optional[int] = None,
142
+ expiry_time: Optional[int] = None,
143
+ mean_expiry_regen_delay: Optional[int] = None,
144
+ std_expiry_regen_delay: Optional[int] = None,
92
145
  ):
146
+ # If expiry regen delays not provided, use same as normal regen
147
+ if mean_expiry_regen_delay is None:
148
+ mean_expiry_regen_delay = mean_regen_delay
149
+ if std_expiry_regen_delay is None:
150
+ std_expiry_regen_delay = std_regen_delay
151
+
93
152
  super().__init__(
94
153
  name=name,
95
154
  reward=reward,
@@ -99,15 +158,27 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
99
158
  random_respawn=random_respawn,
100
159
  reward_delay=reward_delay,
101
160
  max_reward_delay=max_reward_delay,
161
+ expiry_time=expiry_time,
162
+ expiry_regen_delay=(mean_expiry_regen_delay, mean_expiry_regen_delay),
102
163
  )
103
164
  self.mean_regen_delay = mean_regen_delay
104
165
  self.std_regen_delay = std_regen_delay
166
+ self.mean_expiry_regen_delay = mean_expiry_regen_delay
167
+ self.std_expiry_regen_delay = std_expiry_regen_delay
105
168
 
106
169
  def regen_delay(self, clock: int, rng: jax.Array) -> int:
107
170
  """Regeneration delay from a normal distribution."""
108
171
  delay = self.mean_regen_delay + jax.random.normal(rng) * self.std_regen_delay
109
172
  return jnp.maximum(0, delay).astype(int)
110
173
 
174
+ def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
175
+ """Expiry regeneration delay from a normal distribution."""
176
+ delay = (
177
+ self.mean_expiry_regen_delay
178
+ + jax.random.normal(rng) * self.std_expiry_regen_delay
179
+ )
180
+ return jnp.maximum(0, delay).astype(int)
181
+
111
182
 
112
183
  class WeatherObject(NormalRegenForagaxObject):
113
184
  """Object with reward based on temperature data."""
@@ -124,6 +195,9 @@ class WeatherObject(NormalRegenForagaxObject):
124
195
  random_respawn: bool = False,
125
196
  reward_delay: int = 0,
126
197
  max_reward_delay: Optional[int] = None,
198
+ expiry_time: Optional[int] = None,
199
+ mean_expiry_regen_delay: Optional[int] = None,
200
+ std_expiry_regen_delay: Optional[int] = None,
127
201
  ):
128
202
  super().__init__(
129
203
  name=name,
@@ -134,15 +208,158 @@ class WeatherObject(NormalRegenForagaxObject):
134
208
  random_respawn=random_respawn,
135
209
  reward_delay=reward_delay,
136
210
  max_reward_delay=max_reward_delay,
211
+ expiry_time=expiry_time,
212
+ mean_expiry_regen_delay=mean_expiry_regen_delay,
213
+ std_expiry_regen_delay=std_expiry_regen_delay,
137
214
  )
138
215
  self.rewards = rewards * multiplier
139
216
  self.repeat = repeat
140
217
 
141
- def reward(self, clock: int, rng: jax.Array) -> float:
218
+ def reward(
219
+ self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
220
+ ) -> float:
142
221
  """Reward is based on temperature."""
143
222
  return get_temperature(self.rewards, clock, self.repeat)
144
223
 
145
224
 
225
class FourierObject(BaseForagaxObject):
    """Collectable object whose reward follows a per-instance Fourier series.

    Each instance carries its own parameter vector (see ``get_state``) laid
    out as ``[period, y_min, y_max, a1, b1, a2, b2, ...]``. The reward at a
    given clock value is the series evaluated at the phase
    ``2*pi*(clock % period)/period``, min-max normalised to ``[-1, 1]`` and
    scaled by ``base_magnitude``.

    The object never regenerates by itself: ``regen_delay`` and
    ``expiry_regen_delay`` both return the int32 maximum as a sentinel.
    Per the original author's notes, respawning is handled biome-wide by the
    environment, which draws new parameters at that point; that logic is not
    part of this class.
    """

    def __init__(
        self,
        name: str,
        num_fourier_terms: int = 10,
        base_magnitude: float = 1.0,
        color: Tuple[int, int, int] = (0, 0, 0),
        reward_delay: int = 0,
        max_reward_delay: Optional[int] = None,
    ):
        """Create a Fourier object.

        Args:
            name: Object name.
            num_fourier_terms: Number of harmonics in the series.
            base_magnitude: Scale applied to the normalised reward.
            color: RGB colour of the object.
            reward_delay: Value returned by ``reward_delay``.
            max_reward_delay: Upper bound on the reward delay; defaults to
                ``reward_delay`` when omitted.
        """
        if max_reward_delay is None:
            max_reward_delay = reward_delay
        super().__init__(
            name=name,
            blocking=False,
            collectable=True,
            color=color,
            random_respawn=False,  # Objects don't respawn individually
            max_reward_delay=max_reward_delay,
            expiry_time=None,
        )
        self.num_fourier_terms = num_fourier_terms
        self.base_magnitude = base_magnitude
        self.reward_delay_val = reward_delay

    def get_state(self, key: jax.Array) -> jax.Array:
        """Draw random Fourier series parameters for one object instance.

        Args:
            key: JAX random key.

        Returns:
            Array of shape ``(3 + 2*num_fourier_terms,)`` containing
            ``[period, y_min, y_max, a1, b1, a2, b2, ...]``.
        """
        # The key is split in the same order as before so that a given key
        # still yields the same period and coefficients.
        key, period_key = jax.random.split(key)
        # Integer period drawn from [10, 1000] (upper bound is exclusive).
        period = jax.random.randint(period_key, (), 10, 1001).astype(jnp.float32)

        # Coefficients of the n-th harmonic are scaled by 1/n.
        key, a_key = jax.random.split(key)
        n_values = jnp.arange(1, self.num_fourier_terms + 1, dtype=jnp.float32)
        a_coeffs = jax.random.normal(a_key, (self.num_fourier_terms,)) / n_values

        key, b_key = jax.random.split(key)
        b_coeffs = jax.random.normal(b_key, (self.num_fourier_terms,)) / n_values

        # Evaluate the series on a fixed grid over one cycle to estimate the
        # extrema used for min-max normalisation. All harmonics are computed
        # at once instead of unrolling a Python loop per term.
        num_samples = 1000
        t_samples = jnp.linspace(0, 2 * jnp.pi, num_samples)
        angles = n_values[:, None] * t_samples[None, :]
        y_samples = jnp.sum(
            a_coeffs[:, None] * jnp.cos(angles)
            + b_coeffs[:, None] * jnp.sin(angles),
            axis=0,
        )
        y_min = jnp.min(y_samples)
        y_max = jnp.max(y_samples)

        # Interleave as [a1, b1, a2, b2, ...] without going through an
        # uninitialised buffer.
        ab_interleaved = jnp.stack([a_coeffs, b_coeffs], axis=1).reshape(-1)
        ab_interleaved = ab_interleaved.astype(jnp.float32)

        return jnp.concatenate([jnp.array([period, y_min, y_max]), ab_interleaved])

    def reward(
        self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
    ) -> float:
        """Compute the reward from the per-instance Fourier parameters.

        Args:
            clock: Current timestep.
            rng: Random key (unused).
            params: Array ``[period, y_min, y_max, a1, b1, a2, b2, ...]`` as
                produced by ``get_state``.

        Returns:
            ``0.0`` when ``params`` is missing or empty; otherwise the series
            value normalised with ``y_min``/``y_max`` and scaled by
            ``base_magnitude``, bounded to
            ``[-|base_magnitude|, |base_magnitude|]``.
        """
        if params is None or len(params) == 0:
            return 0.0
        params = jnp.asarray(params)

        period = params[0]
        y_min = params[1]
        y_max = params[2]

        # Phase within the current cycle, in [0, 2*pi).
        t = 2.0 * jnp.pi * (clock % period) / period

        # Interleaved coefficients [a1, b1, a2, b2, ...]; a trailing unpaired
        # value (odd length) is ignored, as before.
        ab_coeffs = params[3:]
        n_terms = len(ab_coeffs) // 2
        pairs = ab_coeffs[: 2 * n_terms].reshape(n_terms, 2)
        harmonic_angles = jnp.arange(1, n_terms + 1, dtype=pairs.dtype) * t
        series = jnp.sum(
            pairs[:, 0] * jnp.cos(harmonic_angles)
            + pairs[:, 1] * jnp.sin(harmonic_angles)
        )

        # Min-max normalisation: 2 * (x - min) / (max - min) - 1.
        span = y_max - y_min
        is_constant = jnp.abs(span) < 1e-8
        safe_span = jnp.maximum(span, 1e-8)  # Avoid division by zero
        normalised = 2.0 * (series - y_min) / safe_span - 1.0
        # y_min / y_max come from a finite sample grid, so the series can
        # overshoot them slightly at other phases; clip to keep the
        # documented bound.
        normalised = jnp.clip(normalised, -1.0, 1.0)

        # A constant series (min == max) yields 0.
        return jnp.where(is_constant, 0.0, normalised * self.base_magnitude)

    def reward_delay(self, clock: int, rng: jax.Array) -> int:
        """Return the fixed reward delay configured at construction."""
        return self.reward_delay_val

    def regen_delay(self, clock: int, rng: jax.Array) -> int:
        """No individual regeneration: return the int32 maximum as a sentinel."""
        return jnp.iinfo(jnp.int32).max

    def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
        """No expiry regeneration: return the int32 maximum as a sentinel."""
        return jnp.iinfo(jnp.int32).max
361
+
362
+
146
363
  EMPTY = DefaultForagaxObject()
147
364
  WALL = DefaultForagaxObject(name="wall", blocking=True, color=(127, 127, 127))
148
365
  FLOWER = DefaultForagaxObject(
@@ -332,6 +549,44 @@ GREEN_FAKE_UNIFORM_RANDOM = DefaultForagaxObject(
332
549
  random_respawn=True,
333
550
  )
334
551
 
552
# Expiring variants of the uniform, randomly respawning objects.
# Every entry uses expiry_time=500 and does not pass expiry_regen_delay,
# so the DefaultForagaxObject constructor default applies.
BROWN_MOREL_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
    name="brown_morel",
    color=(63, 30, 25),
    reward=10.0,
    collectable=True,
    random_respawn=True,
    regen_delay=(90, 110),
    expiry_time=500,
)
BROWN_OYSTER_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
    name="brown_oyster",
    color=(63, 30, 25),
    reward=1.0,
    collectable=True,
    random_respawn=True,
    regen_delay=(9, 11),
    expiry_time=500,
)
GREEN_DEATHCAP_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
    name="green_deathcap",
    color=(0, 255, 0),
    reward=-5.0,
    collectable=True,
    random_respawn=True,
    regen_delay=(9, 11),
    expiry_time=500,
)
GREEN_FAKE_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
    name="green_fake",
    color=(0, 255, 0),
    reward=0.0,
    collectable=True,
    random_respawn=True,
    regen_delay=(9, 11),
    expiry_time=500,
)
589
+
335
590
 
336
591
  def create_weather_objects(
337
592
  file_index: int = 0,
@@ -340,6 +595,9 @@ def create_weather_objects(
340
595
  same_color: bool = False,
341
596
  random_respawn: bool = False,
342
597
  reward_delay: int = 0,
598
+ expiry_time: Optional[int] = None,
599
+ mean_expiry_regen_delay: Optional[int] = None,
600
+ std_expiry_regen_delay: Optional[int] = None,
343
601
  ):
344
602
  """Create HOT and COLD WeatherObject instances using the specified file.
345
603
 
@@ -348,6 +606,11 @@ def create_weather_objects(
348
606
  repeat: How many steps each temperature value repeats for.
349
607
  multiplier: Base multiplier applied to HOT; COLD will use -multiplier.
350
608
  same_color: If True, both HOT and COLD use the same color.
609
+ random_respawn: If True, objects respawn at random locations.
610
+ reward_delay: Number of steps before reward is delivered.
611
+ expiry_time: Time steps before object expires (None = no expiry).
612
+ mean_expiry_regen_delay: Mean delay for expiry respawn.
613
+ std_expiry_regen_delay: Standard deviation for expiry respawn delay.
351
614
 
352
615
  Returns:
353
616
  A tuple (HOT, COLD) of WeatherObject instances.
@@ -370,6 +633,9 @@ def create_weather_objects(
370
633
  color=hot_color,
371
634
  random_respawn=random_respawn,
372
635
  reward_delay=reward_delay,
636
+ expiry_time=expiry_time,
637
+ mean_expiry_regen_delay=mean_expiry_regen_delay,
638
+ std_expiry_regen_delay=std_expiry_regen_delay,
373
639
  )
374
640
 
375
641
  cold_color = hot_color if same_color else (0, 255, 255)
@@ -381,6 +647,187 @@ def create_weather_objects(
381
647
  color=cold_color,
382
648
  random_respawn=random_respawn,
383
649
  reward_delay=reward_delay,
650
+ expiry_time=expiry_time,
651
+ mean_expiry_regen_delay=mean_expiry_regen_delay,
652
+ std_expiry_regen_delay=std_expiry_regen_delay,
653
+ )
654
+
655
+ return hot, cold
656
+
657
+
658
class SineObject(DefaultForagaxObject):
    """Object whose reward is a sine wave added to a constant base reward.

    The total reward is
    ``base_reward + amplitude * sin(2*pi * clock / period + phase)``.
    Different base rewards give objects a positive or negative baseline while
    the shared sine term makes the reward drift over time.

    Regeneration and expiry-regeneration delays are taken from the
    ``(min, max)`` ranges forwarded to ``DefaultForagaxObject``.
    """

    def __init__(
        self,
        name: str,
        base_reward: float = 0.0,
        amplitude: float = 1.0,
        period: int = 1000,
        phase: float = 0.0,
        regen_delay: Tuple[int, int] = (9, 11),
        color: Tuple[int, int, int] = (0, 0, 0),
        random_respawn: bool = False,
        reward_delay: int = 0,
        max_reward_delay: Optional[int] = None,
        expiry_time: Optional[int] = None,
        expiry_regen_delay: Tuple[int, int] = (9, 11),
    ):
        """Create a sine-reward object.

        Args:
            name: Object name.
            base_reward: Constant offset added to the sine term.
            amplitude: Amplitude of the sine term.
            period: Length of one sine cycle in timesteps.
            phase: Phase offset in radians.
            regen_delay: ``(min, max)`` range for the regeneration delay.
            color: RGB colour of the object.
            random_respawn: Forwarded to the parent constructor.
            reward_delay: Forwarded to the parent constructor.
            max_reward_delay: Forwarded to the parent constructor.
            expiry_time: Forwarded to the parent constructor (None = no expiry).
            expiry_regen_delay: ``(min, max)`` range for the expiry
                regeneration delay.
        """
        super().__init__(
            name=name,
            reward=base_reward,
            collectable=True,
            regen_delay=regen_delay,
            color=color,
            random_respawn=random_respawn,
            reward_delay=reward_delay,
            max_reward_delay=max_reward_delay,
            expiry_time=expiry_time,
            expiry_regen_delay=expiry_regen_delay,
        )
        self.base_reward = base_reward
        self.amplitude = amplitude
        self.period = period
        self.phase = phase

    def reward(
        self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
    ) -> float:
        """Return base_reward + amplitude * sin(2*pi * clock / period + phase).

        Args:
            clock: Current timestep.
            rng: Random key (unused).
            params: Per-object parameters (unused).
        """
        # Reduce the clock modulo the period before converting to float.
        # The sine is unchanged mathematically, but the angle stays small, so
        # float32 does not lose phase resolution as the clock grows.
        clock_in_cycle = clock % self.period
        angle = 2.0 * jnp.pi * clock_in_cycle / self.period + self.phase
        return self.base_reward + self.amplitude * jnp.sin(angle)
706
+
707
+
708
def create_fourier_objects(
    num_fourier_terms: int = 10,
    base_magnitude: float = 1.0,
    reward_delay: int = 0,
):
    """Build the pair of FourierObject instances named hot and cold.

    Both objects receive identical settings and differ only in their name.

    Args:
        num_fourier_terms: Number of Fourier terms passed to each object
            (default: 10).
        base_magnitude: ``base_magnitude`` passed to each object.
        reward_delay: Number of steps before reward is delivered.

    Returns:
        A tuple (HOT, COLD) of FourierObject instances.
    """
    # Settings common to both objects.
    shared_kwargs = dict(
        num_fourier_terms=num_fourier_terms,
        base_magnitude=base_magnitude,
        color=(0, 0, 0),
        reward_delay=reward_delay,
    )

    hot = FourierObject(name="hot_fourier", **shared_kwargs)
    cold = FourierObject(name="cold_fourier", **shared_kwargs)

    return hot, cold
740
+
741
+
742
def create_sine_biome_objects(
    period: int = 1000,
    amplitude: float = 20.0,
    base_oyster_reward: float = 10.0,
    base_deathcap_reward: float = -10.0,
    regen_delay: Tuple[int, int] = (9, 11),
    reward_delay: int = 0,
    expiry_time: int = 500,
    expiry_regen_delay: Tuple[int, int] = (9, 11),
):
    """Build the four SineObject instances for the sine two-biome setup.

    Biome 1 uses the base rewards as given with phase 0. Biome 2 negates both
    base rewards and uses phase pi, so its sine term is the negative of
    biome 1's (sin(x + pi) == -sin(x)).

    All four objects are created with ``random_respawn=True`` and share the
    same period, amplitude, delays and expiry settings.

    Args:
        period: Period of the sine wave in timesteps.
        amplitude: Amplitude of the sine wave.
        base_oyster_reward: Base reward for the oyster in biome 1 (negated
            in biome 2).
        base_deathcap_reward: Base reward for the death cap in biome 1
            (negated in biome 2).
        regen_delay: Tuple of (min, max) for uniform regeneration delay.
        reward_delay: Number of steps before reward is delivered.
        expiry_time: Time steps before object expires (None = no expiry);
            forwarded unchanged to SineObject.
        expiry_regen_delay: Tuple of (min, max) for uniform expiry
            regeneration delay.

    Returns:
        A tuple of (biome1_oyster, biome1_deathcap, biome2_oyster,
        biome2_deathcap).
    """
    oyster_color = (124, 61, 81)
    deathcap_color = (0, 255, 0)

    def _build(name, base_reward, phase, color):
        # Only name, base reward, phase and colour vary between the objects.
        return SineObject(
            name=name,
            base_reward=base_reward,
            amplitude=amplitude,
            period=period,
            phase=phase,
            regen_delay=regen_delay,
            color=color,
            random_respawn=True,
            reward_delay=reward_delay,
            expiry_time=expiry_time,
            expiry_regen_delay=expiry_regen_delay,
        )

    # Tuple elements are evaluated left to right, so the objects are
    # constructed in the same order as the returned tuple.
    return (
        _build("oyster_sine_1", base_oyster_reward, 0.0, oyster_color),
        _build("deathcap_sine_1", base_deathcap_reward, 0.0, deathcap_color),
        _build("oyster_sine_2", -base_oyster_reward, jnp.pi, oyster_color),
        _build("deathcap_sine_2", -base_deathcap_reward, jnp.pi, deathcap_color),
    )
foragax/registry.py CHANGED
@@ -12,21 +12,27 @@ from foragax.objects import (
12
12
  BROWN_MOREL_2,
13
13
  BROWN_MOREL_UNIFORM,
14
14
  BROWN_MOREL_UNIFORM_RANDOM,
15
+ BROWN_MOREL_UNIFORM_RANDOM_EXPIRY,
15
16
  BROWN_OYSTER,
16
17
  BROWN_OYSTER_UNIFORM,
17
18
  BROWN_OYSTER_UNIFORM_RANDOM,
19
+ BROWN_OYSTER_UNIFORM_RANDOM_EXPIRY,
18
20
  GREEN_DEATHCAP,
19
21
  GREEN_DEATHCAP_2,
20
22
  GREEN_DEATHCAP_3,
21
23
  GREEN_DEATHCAP_UNIFORM,
22
24
  GREEN_DEATHCAP_UNIFORM_RANDOM,
25
+ GREEN_DEATHCAP_UNIFORM_RANDOM_EXPIRY,
23
26
  GREEN_FAKE,
24
27
  GREEN_FAKE_2,
25
28
  GREEN_FAKE_UNIFORM,
26
29
  GREEN_FAKE_UNIFORM_RANDOM,
30
+ GREEN_FAKE_UNIFORM_RANDOM_EXPIRY,
27
31
  LARGE_MOREL,
28
32
  LARGE_OYSTER,
29
33
  MEDIUM_MOREL,
34
+ create_fourier_objects,
35
+ create_sine_biome_objects,
30
36
  create_weather_objects,
31
37
  )
32
38
 
@@ -83,6 +89,21 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
83
89
  "nowrap": False,
84
90
  "deterministic_spawn": True,
85
91
  },
92
+ "ForagaxDiwali-v1": {
93
+ "size": (15, 15),
94
+ "aperture_size": None,
95
+ "objects": None,
96
+ "biomes": (
97
+ # Hot biome
98
+ Biome(start=(0, 3), stop=(15, 5), object_frequencies=(0.5, 0.0)),
99
+ # Cold biome
100
+ Biome(start=(0, 10), stop=(15, 12), object_frequencies=(0.0, 0.5)),
101
+ ),
102
+ "nowrap": False,
103
+ "deterministic_spawn": True,
104
+ "dynamic_biomes": True,
105
+ "biome_consumption_threshold": 0.9,
106
+ },
86
107
  "ForagaxTwoBiome-v1": {
87
108
  "size": (15, 15),
88
109
  "aperture_size": None,
@@ -304,6 +325,26 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
304
325
  "nowrap": True,
305
326
  "deterministic_spawn": True,
306
327
  },
328
+ "ForagaxTwoBiome-v17": {
329
+ "size": (15, 15),
330
+ "aperture_size": None,
331
+ "objects": (
332
+ BROWN_MOREL_UNIFORM_RANDOM_EXPIRY,
333
+ BROWN_OYSTER_UNIFORM_RANDOM_EXPIRY,
334
+ GREEN_DEATHCAP_UNIFORM_RANDOM_EXPIRY,
335
+ GREEN_FAKE_UNIFORM_RANDOM_EXPIRY,
336
+ ),
337
+ "biomes": (
338
+ # Morel biome
339
+ Biome(start=(3, 0), stop=(5, 15), object_frequencies=(0.25, 0.0, 0.5, 0.0)),
340
+ # Oyster biome
341
+ Biome(
342
+ start=(10, 0), stop=(12, 15), object_frequencies=(0.0, 0.25, 0.0, 0.5)
343
+ ),
344
+ ),
345
+ "nowrap": False,
346
+ "deterministic_spawn": True,
347
+ },
307
348
  "ForagaxTwoBiomeSmall-v1": {
308
349
  "size": (16, 8),
309
350
  "aperture_size": None,
@@ -340,6 +381,22 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
340
381
  ),
341
382
  "nowrap": True,
342
383
  },
384
+ "ForagaxSineTwoBiome-v1": {
385
+ "size": (15, 15),
386
+ "aperture_size": None,
387
+ "objects": None,
388
+ "biomes": (
389
+ # Biome 1 (left): Oyster +10, Death Cap -10 with sine
390
+ Biome(
391
+ start=(3, 0), stop=(5, 15), object_frequencies=(0.25, 0.25, 0.0, 0.0)
392
+ ),
393
+ # Biome 2 (right): Oyster -10, Death Cap +10 with inverted sine
394
+ Biome(
395
+ start=(10, 0), stop=(12, 15), object_frequencies=(0.0, 0.0, 0.25, 0.25)
396
+ ),
397
+ ),
398
+ "nowrap": False,
399
+ },
343
400
  }
344
401
 
345
402
 
@@ -472,6 +529,32 @@ def make(
472
529
  )
473
530
  config["objects"] = (hot, cold)
474
531
 
532
+ if env_id == "ForagaxDiwali-v1":
533
+ config["objects"] = create_fourier_objects(
534
+ num_fourier_terms=10,
535
+ reward_delay=reward_delay,
536
+ )
537
+
538
+ if env_id == "ForagaxSineTwoBiome-v1":
539
+ biome1_oyster, biome1_deathcap, biome2_oyster, biome2_deathcap = (
540
+ create_sine_biome_objects(
541
+ period=1000,
542
+ amplitude=20.0,
543
+ base_oyster_reward=10.0,
544
+ base_deathcap_reward=-10.0,
545
+ regen_delay=(9, 11),
546
+ reward_delay=reward_delay,
547
+ expiry_time=500,
548
+ expiry_regen_delay=(9, 11),
549
+ )
550
+ )
551
+ config["objects"] = (
552
+ biome1_oyster,
553
+ biome1_deathcap,
554
+ biome2_oyster,
555
+ biome2_deathcap,
556
+ )
557
+
475
558
  if env_id == "ForagaxTwoBiome-v16":
476
559
  config["teleport_interval"] = 10000
477
560