continual-foragax 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.32.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.32.0.dist-info}/RECORD +8 -8
- foragax/env.py +783 -132
- foragax/objects.py +452 -5
- foragax/registry.py +83 -0
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.32.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.32.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.30.1.dist-info → continual_foragax-0.32.0.dist-info}/top_level.txt +0 -0
foragax/objects.py
CHANGED
|
@@ -18,6 +18,7 @@ class BaseForagaxObject:
|
|
|
18
18
|
color: Tuple[int, int, int] = (0, 0, 0),
|
|
19
19
|
random_respawn: bool = False,
|
|
20
20
|
max_reward_delay: int = 0,
|
|
21
|
+
expiry_time: Optional[int] = None,
|
|
21
22
|
):
|
|
22
23
|
self.name = name
|
|
23
24
|
self.blocking = blocking
|
|
@@ -25,10 +26,38 @@ class BaseForagaxObject:
|
|
|
25
26
|
self.color = color
|
|
26
27
|
self.random_respawn = random_respawn
|
|
27
28
|
self.max_reward_delay = max_reward_delay
|
|
29
|
+
self.expiry_time = expiry_time
|
|
30
|
+
|
|
31
|
+
def get_state(self, key: jax.Array) -> jax.Array:
    """Return this object's per-instance reward parameters.

    The base implementation has no per-instance parameters, so ``key`` is
    ignored and an empty float32 vector is returned. Subclasses that need
    per-instance parameters (stored in object_state_grid) override this.

    Args:
        key: JAX random key available for sampling parameters.

    Returns:
        A 1-D float32 array; empty for the base class.
    """
    del key  # unused by the base implementation
    return jnp.zeros((0,), dtype=jnp.float32)
|
|
28
44
|
|
|
29
45
|
@abc.abstractmethod
def reward(
    self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
) -> float:
    """Return the reward for collecting this object (must be overridden).

    Args:
        clock: Current time step.
        rng: JAX random key.
        params: Optional per-object parameters from object_state_grid, as
            produced by ``get_state``.
    """
    raise NotImplementedError
|
|
57
|
+
|
|
58
|
+
@abc.abstractmethod
def regen_delay(self, clock: int, rng: jax.Array) -> int:
    """Return the delay before a collected object regenerates (must be overridden)."""
    raise NotImplementedError
|
|
33
62
|
|
|
34
63
|
@abc.abstractmethod
|
|
@@ -36,6 +65,11 @@ class BaseForagaxObject:
|
|
|
36
65
|
"""Reward delay function."""
|
|
37
66
|
raise NotImplementedError
|
|
38
67
|
|
|
68
|
+
@abc.abstractmethod
def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
    """Return the delay before an expired object regenerates (must be overridden)."""
    raise NotImplementedError
|
|
72
|
+
|
|
39
73
|
|
|
40
74
|
class DefaultForagaxObject(BaseForagaxObject):
|
|
41
75
|
"""Base class for default objects in the Foragax environment."""
|
|
@@ -51,17 +85,28 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
|
51
85
|
random_respawn: bool = False,
|
|
52
86
|
reward_delay: int = 0,
|
|
53
87
|
max_reward_delay: Optional[int] = None,
|
|
88
|
+
expiry_time: Optional[int] = None,
|
|
89
|
+
expiry_regen_delay: Tuple[int, int] = (10, 100),
|
|
54
90
|
):
|
|
55
91
|
if max_reward_delay is None:
|
|
56
92
|
max_reward_delay = reward_delay
|
|
57
93
|
super().__init__(
|
|
58
|
-
name,
|
|
94
|
+
name,
|
|
95
|
+
blocking,
|
|
96
|
+
collectable,
|
|
97
|
+
color,
|
|
98
|
+
random_respawn,
|
|
99
|
+
max_reward_delay,
|
|
100
|
+
expiry_time,
|
|
59
101
|
)
|
|
60
102
|
self.reward_val = reward
|
|
61
103
|
self.regen_delay_range = regen_delay
|
|
62
104
|
self.reward_delay_val = reward_delay
|
|
105
|
+
self.expiry_regen_delay_range = expiry_regen_delay
|
|
63
106
|
|
|
64
|
-
def reward(
    self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
) -> float:
    """Return the object's constant reward.

    ``clock``, ``rng`` and ``params`` are ignored: a default object always
    yields ``self.reward_val``.
    """
    constant_reward = self.reward_val
    return constant_reward
|
|
67
112
|
|
|
@@ -74,6 +119,11 @@ class DefaultForagaxObject(BaseForagaxObject):
|
|
|
74
119
|
"""Default reward delay function."""
|
|
75
120
|
return self.reward_delay_val
|
|
76
121
|
|
|
122
|
+
def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
    """Sample the expiry regeneration delay uniformly from the configured range.

    The range is ``self.expiry_regen_delay_range = (low, high)``; as with
    ``jax.random.randint``, ``low`` is inclusive and ``high`` is exclusive.
    """
    low, high = self.expiry_regen_delay_range
    return jax.random.randint(rng, shape=(), minval=low, maxval=high)
|
|
126
|
+
|
|
77
127
|
|
|
78
128
|
class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
79
129
|
"""Object with regeneration delay from a normal distribution."""
|
|
@@ -89,7 +139,16 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
|
89
139
|
random_respawn: bool = False,
|
|
90
140
|
reward_delay: int = 0,
|
|
91
141
|
max_reward_delay: Optional[int] = None,
|
|
142
|
+
expiry_time: Optional[int] = None,
|
|
143
|
+
mean_expiry_regen_delay: Optional[int] = None,
|
|
144
|
+
std_expiry_regen_delay: Optional[int] = None,
|
|
92
145
|
):
|
|
146
|
+
# If expiry regen delays not provided, use same as normal regen
|
|
147
|
+
if mean_expiry_regen_delay is None:
|
|
148
|
+
mean_expiry_regen_delay = mean_regen_delay
|
|
149
|
+
if std_expiry_regen_delay is None:
|
|
150
|
+
std_expiry_regen_delay = std_regen_delay
|
|
151
|
+
|
|
93
152
|
super().__init__(
|
|
94
153
|
name=name,
|
|
95
154
|
reward=reward,
|
|
@@ -99,15 +158,27 @@ class NormalRegenForagaxObject(DefaultForagaxObject):
|
|
|
99
158
|
random_respawn=random_respawn,
|
|
100
159
|
reward_delay=reward_delay,
|
|
101
160
|
max_reward_delay=max_reward_delay,
|
|
161
|
+
expiry_time=expiry_time,
|
|
162
|
+
expiry_regen_delay=(mean_expiry_regen_delay, mean_expiry_regen_delay),
|
|
102
163
|
)
|
|
103
164
|
self.mean_regen_delay = mean_regen_delay
|
|
104
165
|
self.std_regen_delay = std_regen_delay
|
|
166
|
+
self.mean_expiry_regen_delay = mean_expiry_regen_delay
|
|
167
|
+
self.std_expiry_regen_delay = std_expiry_regen_delay
|
|
105
168
|
|
|
106
169
|
def regen_delay(self, clock: int, rng: jax.Array) -> int:
    """Sample a regeneration delay from N(mean_regen_delay, std_regen_delay**2).

    Negative samples are floored at 0, and the result is cast to int.
    """
    noise = jax.random.normal(rng) * self.std_regen_delay
    sampled = self.mean_regen_delay + noise
    non_negative = jnp.maximum(0, sampled)
    return non_negative.astype(int)
|
|
110
173
|
|
|
174
|
+
def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
    """Sample an expiry regeneration delay from a normal distribution.

    Mean and standard deviation come from ``mean_expiry_regen_delay`` and
    ``std_expiry_regen_delay``. Negative samples are floored at 0, and the
    result is cast to int.
    """
    noise = jax.random.normal(rng) * self.std_expiry_regen_delay
    sampled = self.mean_expiry_regen_delay + noise
    non_negative = jnp.maximum(0, sampled)
    return non_negative.astype(int)
|
|
181
|
+
|
|
111
182
|
|
|
112
183
|
class WeatherObject(NormalRegenForagaxObject):
|
|
113
184
|
"""Object with reward based on temperature data."""
|
|
@@ -124,6 +195,9 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
|
124
195
|
random_respawn: bool = False,
|
|
125
196
|
reward_delay: int = 0,
|
|
126
197
|
max_reward_delay: Optional[int] = None,
|
|
198
|
+
expiry_time: Optional[int] = None,
|
|
199
|
+
mean_expiry_regen_delay: Optional[int] = None,
|
|
200
|
+
std_expiry_regen_delay: Optional[int] = None,
|
|
127
201
|
):
|
|
128
202
|
super().__init__(
|
|
129
203
|
name=name,
|
|
@@ -134,15 +208,158 @@ class WeatherObject(NormalRegenForagaxObject):
|
|
|
134
208
|
random_respawn=random_respawn,
|
|
135
209
|
reward_delay=reward_delay,
|
|
136
210
|
max_reward_delay=max_reward_delay,
|
|
211
|
+
expiry_time=expiry_time,
|
|
212
|
+
mean_expiry_regen_delay=mean_expiry_regen_delay,
|
|
213
|
+
std_expiry_regen_delay=std_expiry_regen_delay,
|
|
137
214
|
)
|
|
138
215
|
self.rewards = rewards * multiplier
|
|
139
216
|
self.repeat = repeat
|
|
140
217
|
|
|
141
|
-
def reward(
    self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
) -> float:
    """Return the temperature-driven reward at ``clock``.

    ``self.rewards`` already includes the multiplier applied in ``__init__``.
    ``rng`` and ``params`` are accepted for interface compatibility and are
    unused.
    """
    series, repeat = self.rewards, self.repeat
    return get_temperature(series, clock, repeat)
|
|
144
223
|
|
|
145
224
|
|
|
225
|
+
class FourierObject(BaseForagaxObject):
    """Object with reward based on a Fourier series with per-instance parameters.

    This object doesn't respawn on its own. Instead, objects are respawned
    biome-wide when the consumption threshold is reached, with new random
    parameters.

    The reward function is a Fourier series with a random period and
    harmonics. Coefficients are scaled by 1/n, the series is min-max
    normalized to [-1, 1], and the result is scaled by ``base_magnitude``.
    """

    def __init__(
        self,
        name: str,
        num_fourier_terms: int = 10,
        base_magnitude: float = 1.0,
        color: Tuple[int, int, int] = (0, 0, 0),
        reward_delay: int = 0,
        max_reward_delay: Optional[int] = None,
        period_range: Tuple[int, int] = (10, 1000),
        num_normalization_samples: int = 1000,
    ):
        """Create a Fourier-reward object.

        Args:
            name: Object name.
            num_fourier_terms: Number of harmonics in the series.
            base_magnitude: Scale applied to the normalized [-1, 1] reward.
            color: RGB colour of the object.
            reward_delay: Fixed value returned by ``reward_delay``.
            max_reward_delay: Passed to the base class; defaults to
                ``reward_delay``.
            period_range: Inclusive (min, max) range from which each
                instance's integer period is drawn.
            num_normalization_samples: Number of points over one period used
                to estimate the series' min/max for normalization.

        Raises:
            ValueError: If ``period_range`` is not a valid range of positive
                periods.
        """
        min_period, max_period = period_range
        if min_period < 1 or max_period < min_period:
            raise ValueError(
                f"period_range must satisfy 1 <= min <= max, got {period_range}"
            )
        if max_reward_delay is None:
            max_reward_delay = reward_delay
        super().__init__(
            name=name,
            blocking=False,
            collectable=True,
            color=color,
            random_respawn=False,  # Objects don't respawn individually
            max_reward_delay=max_reward_delay,
            expiry_time=None,
        )
        self.num_fourier_terms = num_fourier_terms
        self.base_magnitude = base_magnitude
        self.reward_delay_val = reward_delay
        self.period_range = period_range
        self.num_normalization_samples = num_normalization_samples

    @staticmethod
    def _evaluate_series(
        a_coeffs: jax.Array, b_coeffs: jax.Array, t: jax.Array
    ) -> jax.Array:
        """Evaluate sum_n a_n*cos(n*t) + b_n*sin(n*t) for scalar or 1-D ``t``."""
        n_values = jnp.arange(1, a_coeffs.shape[0] + 1, dtype=jnp.float32)
        # Scalar t -> (terms,); 1-D t -> (len(t), terms).
        phases = jnp.asarray(t)[..., None] * n_values
        # Element-wise product + sum (rather than matmul) keeps full float32
        # precision on accelerators that default matmuls to reduced precision.
        return jnp.sum(
            a_coeffs * jnp.cos(phases) + b_coeffs * jnp.sin(phases), axis=-1
        )

    def get_state(self, key: jax.Array) -> jax.Array:
        """Generate random Fourier series parameters.

        Args:
            key: JAX random key.

        Returns:
            Float32 array of shape (3 + 2*num_fourier_terms,) containing
            [period, y_min, y_max, a1, b1, a2, b2, ...].
        """
        min_period, max_period = self.period_range
        # The key is split in the same order as before (period, a, b), so a
        # given key still yields the same period and coefficients.
        key, period_key = jax.random.split(key)
        # randint's upper bound is exclusive; +1 makes max_period reachable.
        period = jax.random.randint(
            period_key, (), min_period, max_period + 1
        ).astype(jnp.float32)

        # Coefficients are scaled by 1/n so higher harmonics contribute less.
        n_values = jnp.arange(1, self.num_fourier_terms + 1, dtype=jnp.float32)
        key, a_key = jax.random.split(key)
        a_coeffs = jax.random.normal(a_key, (self.num_fourier_terms,)) / n_values
        key, b_key = jax.random.split(key)
        b_coeffs = jax.random.normal(b_key, (self.num_fourier_terms,)) / n_values

        # Estimate the series' range over one full period for min-max
        # normalization.
        t_samples = jnp.linspace(0, 2 * jnp.pi, self.num_normalization_samples)
        y_samples = self._evaluate_series(a_coeffs, b_coeffs, t_samples)
        y_min = jnp.min(y_samples)
        y_max = jnp.max(y_samples)

        # Interleave as [a1, b1, a2, b2, ...].
        ab_interleaved = jnp.stack([a_coeffs, b_coeffs], axis=-1).reshape(-1)
        return jnp.concatenate([jnp.stack([period, y_min, y_max]), ab_interleaved])

    def reward(
        self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
    ) -> float:
        """Compute reward from Fourier series parameters.

        Args:
            clock: Current timestep.
            rng: Random key (unused for Fourier objects).
            params: Array of shape (3 + 2*num_fourier_terms,) containing
                [period, y_min, y_max, a1, b1, a2, b2, ...].

        Returns:
            The Fourier series value, min-max normalized and clipped to
            [-1, 1], then scaled by ``base_magnitude``. Returns 0.0 when
            ``params`` is missing or empty, and 0 for a constant series
            (y_min == y_max).
        """
        if params is None or len(params) == 0:
            return 0.0

        period, y_min, y_max = params[0], params[1], params[2]

        # Map time onto [0, 2*pi) using the object's period.
        t = 2.0 * jnp.pi * (clock % period) / period

        ab_coeffs = params[3:]
        n_terms = ab_coeffs.shape[0] // 2
        ab_pairs = ab_coeffs[: 2 * n_terms].reshape(n_terms, 2)
        value = self._evaluate_series(ab_pairs[:, 0], ab_pairs[:, 1], t)

        # Min-max normalization: 2 * (x - min) / (max - min) - 1.
        span = y_max - y_min
        is_constant = jnp.abs(span) < 1e-8
        normalized = 2.0 * (value - y_min) / jnp.maximum(span, 1e-8) - 1.0
        # y_min/y_max come from a finite sample grid, so the series can slightly
        # overshoot them at other time points; clip to honour the [-1, 1] range.
        normalized = jnp.clip(normalized, -1.0, 1.0)
        return jnp.where(is_constant, 0.0, normalized * self.base_magnitude)

    def reward_delay(self, clock: int, rng: jax.Array) -> int:
        """Reward delay function (fixed value from ``__init__``)."""
        return self.reward_delay_val

    def regen_delay(self, clock: int, rng: jax.Array) -> int:
        """No individual regeneration - returns the int32 maximum."""
        return jnp.iinfo(jnp.int32).max

    def expiry_regen_delay(self, clock: int, rng: jax.Array) -> int:
        """No expiry regeneration - returns the int32 maximum."""
        return jnp.iinfo(jnp.int32).max
|
|
361
|
+
|
|
362
|
+
|
|
146
363
|
EMPTY = DefaultForagaxObject()
|
|
147
364
|
WALL = DefaultForagaxObject(name="wall", blocking=True, color=(127, 127, 127))
|
|
148
365
|
FLOWER = DefaultForagaxObject(
|
|
@@ -332,6 +549,44 @@ GREEN_FAKE_UNIFORM_RANDOM = DefaultForagaxObject(
|
|
|
332
549
|
random_respawn=True,
|
|
333
550
|
)
|
|
334
551
|
|
|
552
|
+
# Random respawn variants with expiry.
# All four share collectable=True, random_respawn=True and expiry_time=500;
# they differ only in name, reward, colour and regeneration delay range.
BROWN_MOREL_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
    name="brown_morel",
    color=(63, 30, 25),
    reward=10.0,
    regen_delay=(90, 110),
    collectable=True,
    random_respawn=True,
    expiry_time=500,
)
BROWN_OYSTER_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
    name="brown_oyster",
    color=(63, 30, 25),
    reward=1.0,
    regen_delay=(9, 11),
    collectable=True,
    random_respawn=True,
    expiry_time=500,
)
GREEN_DEATHCAP_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
    name="green_deathcap",
    color=(0, 255, 0),
    reward=-5.0,
    regen_delay=(9, 11),
    collectable=True,
    random_respawn=True,
    expiry_time=500,
)
GREEN_FAKE_UNIFORM_RANDOM_EXPIRY = DefaultForagaxObject(
    name="green_fake",
    color=(0, 255, 0),
    reward=0.0,
    regen_delay=(9, 11),
    collectable=True,
    random_respawn=True,
    expiry_time=500,
)
|
|
589
|
+
|
|
335
590
|
|
|
336
591
|
def create_weather_objects(
|
|
337
592
|
file_index: int = 0,
|
|
@@ -340,6 +595,9 @@ def create_weather_objects(
|
|
|
340
595
|
same_color: bool = False,
|
|
341
596
|
random_respawn: bool = False,
|
|
342
597
|
reward_delay: int = 0,
|
|
598
|
+
expiry_time: Optional[int] = None,
|
|
599
|
+
mean_expiry_regen_delay: Optional[int] = None,
|
|
600
|
+
std_expiry_regen_delay: Optional[int] = None,
|
|
343
601
|
):
|
|
344
602
|
"""Create HOT and COLD WeatherObject instances using the specified file.
|
|
345
603
|
|
|
@@ -348,6 +606,11 @@ def create_weather_objects(
|
|
|
348
606
|
repeat: How many steps each temperature value repeats for.
|
|
349
607
|
multiplier: Base multiplier applied to HOT; COLD will use -multiplier.
|
|
350
608
|
same_color: If True, both HOT and COLD use the same color.
|
|
609
|
+
random_respawn: If True, objects respawn at random locations.
|
|
610
|
+
reward_delay: Number of steps before reward is delivered.
|
|
611
|
+
expiry_time: Time steps before object expires (None = no expiry).
|
|
612
|
+
mean_expiry_regen_delay: Mean delay for expiry respawn.
|
|
613
|
+
std_expiry_regen_delay: Standard deviation for expiry respawn delay.
|
|
351
614
|
|
|
352
615
|
Returns:
|
|
353
616
|
A tuple (HOT, COLD) of WeatherObject instances.
|
|
@@ -370,6 +633,9 @@ def create_weather_objects(
|
|
|
370
633
|
color=hot_color,
|
|
371
634
|
random_respawn=random_respawn,
|
|
372
635
|
reward_delay=reward_delay,
|
|
636
|
+
expiry_time=expiry_time,
|
|
637
|
+
mean_expiry_regen_delay=mean_expiry_regen_delay,
|
|
638
|
+
std_expiry_regen_delay=std_expiry_regen_delay,
|
|
373
639
|
)
|
|
374
640
|
|
|
375
641
|
cold_color = hot_color if same_color else (0, 255, 255)
|
|
@@ -381,6 +647,187 @@ def create_weather_objects(
|
|
|
381
647
|
color=cold_color,
|
|
382
648
|
random_respawn=random_respawn,
|
|
383
649
|
reward_delay=reward_delay,
|
|
650
|
+
expiry_time=expiry_time,
|
|
651
|
+
mean_expiry_regen_delay=mean_expiry_regen_delay,
|
|
652
|
+
std_expiry_regen_delay=std_expiry_regen_delay,
|
|
653
|
+
)
|
|
654
|
+
|
|
655
|
+
return hot, cold
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
class SineObject(DefaultForagaxObject):
    """Object with reward based on a sine wave with a base reward offset.

    The total reward is:

        base_reward + amplitude * sin(2*pi * clock / period + phase)

    This allows for objects that have different base behaviors
    (positive/negative) with an underlying sine wave that drives continual
    learning.

    Regeneration and expiry-regeneration delays are inherited from
    DefaultForagaxObject (configured via ``regen_delay`` and
    ``expiry_regen_delay`` ranges).
    """

    def __init__(
        self,
        name: str,
        base_reward: float = 0.0,
        amplitude: float = 1.0,
        period: int = 1000,
        phase: float = 0.0,
        regen_delay: Tuple[int, int] = (9, 11),
        color: Tuple[int, int, int] = (0, 0, 0),
        random_respawn: bool = False,
        reward_delay: int = 0,
        max_reward_delay: Optional[int] = None,
        expiry_time: Optional[int] = None,
        expiry_regen_delay: Tuple[int, int] = (9, 11),
    ):
        """Create a sine-reward object.

        Args:
            name: Object name.
            base_reward: Constant offset of the reward.
            amplitude: Amplitude of the sine component.
            period: Period of the sine component, in timesteps. Must be > 0.
            phase: Phase offset of the sine component, in radians.
            regen_delay: (min, max) range forwarded to DefaultForagaxObject.
            color: RGB colour of the object.
            random_respawn: Forwarded to DefaultForagaxObject.
            reward_delay: Forwarded to DefaultForagaxObject.
            max_reward_delay: Forwarded to DefaultForagaxObject.
            expiry_time: Time steps before the object expires (None = no expiry).
            expiry_regen_delay: (min, max) range forwarded to DefaultForagaxObject.

        Raises:
            ValueError: If ``period`` is not positive.
        """
        if period <= 0:
            raise ValueError(f"period must be positive, got {period}")
        super().__init__(
            name=name,
            reward=base_reward,
            collectable=True,
            regen_delay=regen_delay,
            color=color,
            random_respawn=random_respawn,
            reward_delay=reward_delay,
            max_reward_delay=max_reward_delay,
            expiry_time=expiry_time,
            expiry_regen_delay=expiry_regen_delay,
        )
        self.base_reward = base_reward
        self.amplitude = amplitude
        self.period = period
        self.phase = phase

    def reward(
        self, clock: int, rng: jax.Array, params: Optional[jax.Array] = None
    ) -> float:
        """Reward is base_reward + amplitude * sin(2*pi * clock / period + phase).

        ``rng`` and ``params`` are unused.
        """
        # sin is 2*pi-periodic, so reducing the clock modulo the period first is
        # mathematically identical. It also keeps the argument small, avoiding
        # float32 precision loss once clock grows large over long runs.
        cycle_fraction = (clock % self.period) / self.period
        sine_value = jnp.sin(2.0 * jnp.pi * cycle_fraction + self.phase)
        return self.base_reward + self.amplitude * sine_value
|
|
706
|
+
|
|
707
|
+
|
|
708
|
+
def create_fourier_objects(
    num_fourier_terms: int = 10,
    base_magnitude: float = 1.0,
    reward_delay: int = 0,
    hot_color: Tuple[int, int, int] = (0, 0, 0),
    cold_color: Tuple[int, int, int] = (0, 0, 0),
):
    """Create HOT and COLD FourierObject instances.

    Both objects share the same Fourier configuration and differ only in name
    (and colour, if requested). Per-instance reward parameters are produced
    by ``FourierObject.get_state``.

    Args:
        num_fourier_terms: Number of Fourier terms in the reward function
            (default: 10).
        base_magnitude: Scale applied to the normalized [-1, 1] reward, so
            rewards lie in [-base_magnitude, base_magnitude].
        reward_delay: Number of steps before reward is delivered.
        hot_color: RGB colour of the HOT object (default (0, 0, 0)).
        cold_color: RGB colour of the COLD object (default (0, 0, 0)).

    Returns:
        A tuple (HOT, COLD) of FourierObject instances.
    """
    # NOTE(review): the (0, 0, 0) defaults keep both objects visually
    # identical and may coincide with the empty-cell colour; confirm this is
    # intended for the environments that use them.
    hot = FourierObject(
        name="hot_fourier",
        num_fourier_terms=num_fourier_terms,
        base_magnitude=base_magnitude,
        color=hot_color,
        reward_delay=reward_delay,
    )

    cold = FourierObject(
        name="cold_fourier",
        num_fourier_terms=num_fourier_terms,
        base_magnitude=base_magnitude,
        color=cold_color,
        reward_delay=reward_delay,
    )

    return hot, cold
|
|
740
|
+
|
|
741
|
+
|
|
742
|
+
def create_sine_biome_objects(
    period: int = 1000,
    amplitude: float = 20.0,
    base_oyster_reward: float = 10.0,
    base_deathcap_reward: float = -10.0,
    regen_delay: Tuple[int, int] = (9, 11),
    reward_delay: int = 0,
    expiry_time: Optional[int] = 500,
    expiry_regen_delay: Tuple[int, int] = (9, 11),
):
    """Create objects for the sine-based two-biome environment.

    Biome 1: oyster with base reward ``base_oyster_reward`` and death cap with
    base reward ``base_deathcap_reward``.
    Biome 2: the same two objects with both base rewards negated.

    Both biomes share a sine component with the given amplitude and period.
    Biome 2's sine has a phase of pi, so it is the negative of biome 1's.

    All objects use random respawn, with uniform regeneration and expiry
    regeneration delay ranges.

    Args:
        period: Period of the sine wave in timesteps.
        amplitude: Amplitude of the sine wave.
        base_oyster_reward: Base reward for oyster in biome 1 (negated in
            biome 2).
        base_deathcap_reward: Base reward for death cap in biome 1 (negated in
            biome 2).
        regen_delay: Tuple of (min, max) for uniform regeneration delay.
        reward_delay: Number of steps before reward is delivered.
        expiry_time: Time steps before object expires (None = no expiry).
        expiry_regen_delay: Tuple of (min, max) for uniform expiry
            regeneration delay.

    Returns:
        A tuple of (biome1_oyster, biome1_deathcap, biome2_oyster,
        biome2_deathcap).
    """
    oyster_color = (124, 61, 81)
    deathcap_color = (0, 255, 0)

    def _make(name, base_reward, phase, color):
        # All four objects share every setting except name, base reward,
        # phase and colour.
        return SineObject(
            name=name,
            base_reward=base_reward,
            amplitude=amplitude,
            period=period,
            phase=phase,
            regen_delay=regen_delay,
            color=color,
            random_respawn=True,
            reward_delay=reward_delay,
            expiry_time=expiry_time,
            expiry_regen_delay=expiry_regen_delay,
        )

    # Biome 1 objects (phase = 0).
    biome1_oyster = _make("oyster_sine_1", base_oyster_reward, 0.0, oyster_color)
    biome1_deathcap = _make(
        "deathcap_sine_1", base_deathcap_reward, 0.0, deathcap_color
    )

    # Biome 2 objects: negated base rewards, phase = pi (negated sine).
    biome2_oyster = _make(
        "oyster_sine_2", -base_oyster_reward, jnp.pi, oyster_color
    )
    biome2_deathcap = _make(
        "deathcap_sine_2", -base_deathcap_reward, jnp.pi, deathcap_color
    )

    return biome1_oyster, biome1_deathcap, biome2_oyster, biome2_deathcap
|
foragax/registry.py
CHANGED
|
@@ -12,21 +12,27 @@ from foragax.objects import (
|
|
|
12
12
|
BROWN_MOREL_2,
|
|
13
13
|
BROWN_MOREL_UNIFORM,
|
|
14
14
|
BROWN_MOREL_UNIFORM_RANDOM,
|
|
15
|
+
BROWN_MOREL_UNIFORM_RANDOM_EXPIRY,
|
|
15
16
|
BROWN_OYSTER,
|
|
16
17
|
BROWN_OYSTER_UNIFORM,
|
|
17
18
|
BROWN_OYSTER_UNIFORM_RANDOM,
|
|
19
|
+
BROWN_OYSTER_UNIFORM_RANDOM_EXPIRY,
|
|
18
20
|
GREEN_DEATHCAP,
|
|
19
21
|
GREEN_DEATHCAP_2,
|
|
20
22
|
GREEN_DEATHCAP_3,
|
|
21
23
|
GREEN_DEATHCAP_UNIFORM,
|
|
22
24
|
GREEN_DEATHCAP_UNIFORM_RANDOM,
|
|
25
|
+
GREEN_DEATHCAP_UNIFORM_RANDOM_EXPIRY,
|
|
23
26
|
GREEN_FAKE,
|
|
24
27
|
GREEN_FAKE_2,
|
|
25
28
|
GREEN_FAKE_UNIFORM,
|
|
26
29
|
GREEN_FAKE_UNIFORM_RANDOM,
|
|
30
|
+
GREEN_FAKE_UNIFORM_RANDOM_EXPIRY,
|
|
27
31
|
LARGE_MOREL,
|
|
28
32
|
LARGE_OYSTER,
|
|
29
33
|
MEDIUM_MOREL,
|
|
34
|
+
create_fourier_objects,
|
|
35
|
+
create_sine_biome_objects,
|
|
30
36
|
create_weather_objects,
|
|
31
37
|
)
|
|
32
38
|
|
|
@@ -83,6 +89,21 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
|
83
89
|
"nowrap": False,
|
|
84
90
|
"deterministic_spawn": True,
|
|
85
91
|
},
|
|
92
|
+
"ForagaxDiwali-v1": {
|
|
93
|
+
"size": (15, 15),
|
|
94
|
+
"aperture_size": None,
|
|
95
|
+
"objects": None,
|
|
96
|
+
"biomes": (
|
|
97
|
+
# Hot biome
|
|
98
|
+
Biome(start=(0, 3), stop=(15, 5), object_frequencies=(0.5, 0.0)),
|
|
99
|
+
# Cold biome
|
|
100
|
+
Biome(start=(0, 10), stop=(15, 12), object_frequencies=(0.0, 0.5)),
|
|
101
|
+
),
|
|
102
|
+
"nowrap": False,
|
|
103
|
+
"deterministic_spawn": True,
|
|
104
|
+
"dynamic_biomes": True,
|
|
105
|
+
"biome_consumption_threshold": 0.9,
|
|
106
|
+
},
|
|
86
107
|
"ForagaxTwoBiome-v1": {
|
|
87
108
|
"size": (15, 15),
|
|
88
109
|
"aperture_size": None,
|
|
@@ -304,6 +325,26 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
|
304
325
|
"nowrap": True,
|
|
305
326
|
"deterministic_spawn": True,
|
|
306
327
|
},
|
|
328
|
+
"ForagaxTwoBiome-v17": {
|
|
329
|
+
"size": (15, 15),
|
|
330
|
+
"aperture_size": None,
|
|
331
|
+
"objects": (
|
|
332
|
+
BROWN_MOREL_UNIFORM_RANDOM_EXPIRY,
|
|
333
|
+
BROWN_OYSTER_UNIFORM_RANDOM_EXPIRY,
|
|
334
|
+
GREEN_DEATHCAP_UNIFORM_RANDOM_EXPIRY,
|
|
335
|
+
GREEN_FAKE_UNIFORM_RANDOM_EXPIRY,
|
|
336
|
+
),
|
|
337
|
+
"biomes": (
|
|
338
|
+
# Morel biome
|
|
339
|
+
Biome(start=(3, 0), stop=(5, 15), object_frequencies=(0.25, 0.0, 0.5, 0.0)),
|
|
340
|
+
# Oyster biome
|
|
341
|
+
Biome(
|
|
342
|
+
start=(10, 0), stop=(12, 15), object_frequencies=(0.0, 0.25, 0.0, 0.5)
|
|
343
|
+
),
|
|
344
|
+
),
|
|
345
|
+
"nowrap": False,
|
|
346
|
+
"deterministic_spawn": True,
|
|
347
|
+
},
|
|
307
348
|
"ForagaxTwoBiomeSmall-v1": {
|
|
308
349
|
"size": (16, 8),
|
|
309
350
|
"aperture_size": None,
|
|
@@ -340,6 +381,22 @@ ENV_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
|
340
381
|
),
|
|
341
382
|
"nowrap": True,
|
|
342
383
|
},
|
|
384
|
+
"ForagaxSineTwoBiome-v1": {
|
|
385
|
+
"size": (15, 15),
|
|
386
|
+
"aperture_size": None,
|
|
387
|
+
"objects": None,
|
|
388
|
+
"biomes": (
|
|
389
|
+
# Biome 1 (left): Oyster +10, Death Cap -10 with sine
|
|
390
|
+
Biome(
|
|
391
|
+
start=(3, 0), stop=(5, 15), object_frequencies=(0.25, 0.25, 0.0, 0.0)
|
|
392
|
+
),
|
|
393
|
+
# Biome 2 (right): Oyster -10, Death Cap +10 with inverted sine
|
|
394
|
+
Biome(
|
|
395
|
+
start=(10, 0), stop=(12, 15), object_frequencies=(0.0, 0.0, 0.25, 0.25)
|
|
396
|
+
),
|
|
397
|
+
),
|
|
398
|
+
"nowrap": False,
|
|
399
|
+
},
|
|
343
400
|
}
|
|
344
401
|
|
|
345
402
|
|
|
@@ -472,6 +529,32 @@ def make(
|
|
|
472
529
|
)
|
|
473
530
|
config["objects"] = (hot, cold)
|
|
474
531
|
|
|
532
|
+
if env_id == "ForagaxDiwali-v1":
|
|
533
|
+
config["objects"] = create_fourier_objects(
|
|
534
|
+
num_fourier_terms=10,
|
|
535
|
+
reward_delay=reward_delay,
|
|
536
|
+
)
|
|
537
|
+
|
|
538
|
+
if env_id == "ForagaxSineTwoBiome-v1":
|
|
539
|
+
biome1_oyster, biome1_deathcap, biome2_oyster, biome2_deathcap = (
|
|
540
|
+
create_sine_biome_objects(
|
|
541
|
+
period=1000,
|
|
542
|
+
amplitude=20.0,
|
|
543
|
+
base_oyster_reward=10.0,
|
|
544
|
+
base_deathcap_reward=-10.0,
|
|
545
|
+
regen_delay=(9, 11),
|
|
546
|
+
reward_delay=reward_delay,
|
|
547
|
+
expiry_time=500,
|
|
548
|
+
expiry_regen_delay=(9, 11),
|
|
549
|
+
)
|
|
550
|
+
)
|
|
551
|
+
config["objects"] = (
|
|
552
|
+
biome1_oyster,
|
|
553
|
+
biome1_deathcap,
|
|
554
|
+
biome2_oyster,
|
|
555
|
+
biome2_deathcap,
|
|
556
|
+
)
|
|
557
|
+
|
|
475
558
|
if env_id == "ForagaxTwoBiome-v16":
|
|
476
559
|
config["teleport_interval"] = 10000
|
|
477
560
|
|