continual-foragax 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.31.0.dist-info → continual_foragax-0.32.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.31.0.dist-info → continual_foragax-0.32.0.dist-info}/RECORD +8 -8
- foragax/env.py +759 -242
- foragax/objects.py +351 -4
- foragax/registry.py +59 -0
- {continual_foragax-0.31.0.dist-info → continual_foragax-0.32.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.31.0.dist-info → continual_foragax-0.32.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.31.0.dist-info → continual_foragax-0.32.0.dist-info}/top_level.txt +0 -0
foragax/env.py
CHANGED
|
@@ -19,6 +19,7 @@ from foragax.objects import (
|
|
|
19
19
|
EMPTY,
|
|
20
20
|
PADDING,
|
|
21
21
|
BaseForagaxObject,
|
|
22
|
+
FourierObject,
|
|
22
23
|
WeatherObject,
|
|
23
24
|
)
|
|
24
25
|
from foragax.rendering import apply_true_borders
|
|
@@ -42,6 +43,48 @@ DIRECTIONS = jnp.array(
|
|
|
42
43
|
)
|
|
43
44
|
|
|
44
45
|
|
|
46
|
+
@struct.dataclass
class ObjectState:
    """Per-cell record of every object in the grid.

    Each array is indexed ``[row, column]`` over an ``(H, W)`` grid:

    - object_id: object type in the cell (0 = empty, >0 = object type).
    - respawn_timer: remaining respawn countdown (0 = no pending respawn).
    - respawn_object_id: object type that appears when the countdown ends.
    - spawn_time: timestep at which the cell's object appeared (for expiry).
    - color: ``(H, W, 3)`` RGB colour of each object instance (dynamic biomes).
    - generation: biome generation the cell's object belongs to.
    - state_params: ``(H, W, N)`` per-instance parameters (e.g. Fourier
      coefficients).
    - biome_id: biome the cell belongs to (-1 = no biome assigned).
    """

    # Object type per cell; 0 marks an empty cell.
    object_id: jax.Array
    # Countdown until respawn; 0 means no timer is running.
    respawn_timer: jax.Array
    # Object type to place once respawn_timer reaches 0.
    respawn_object_id: jax.Array
    # Timestep at which each object was spawned.
    spawn_time: jax.Array
    # RGB colour per object instance, shape (H, W, 3).
    color: jax.Array
    # Biome generation number per cell.
    generation: jax.Array
    # Per-instance parameters, shape (H, W, N).
    state_params: jax.Array
    # Biome assignment per cell.
    biome_id: jax.Array

    @classmethod
    def create_empty(cls, size: Tuple[int, int], num_params: int) -> "ObjectState":
        """Return an all-empty state for a world of ``size = (width, height)``.

        Integer grids start at zero, colours at white (255), parameters at
        zero, and every cell's ``biome_id`` at -1.
        """
        width, height = size[0], size[1]
        grid = (height, width)

        def int_grid(fill_value=0):
            # Integer-valued (H, W) grid filled with a constant.
            return jnp.full(grid, fill_value, dtype=int)

        return cls(
            object_id=int_grid(),
            respawn_timer=int_grid(),
            respawn_object_id=int_grid(),
            spawn_time=int_grid(),
            color=jnp.full(grid + (3,), 255, dtype=jnp.uint8),
            generation=int_grid(),
            state_params=jnp.zeros(grid + (num_params,), dtype=jnp.float32),
            biome_id=int_grid(-1),
        )
|
|
86
|
+
|
|
87
|
+
|
|
45
88
|
@dataclass
|
|
46
89
|
class Biome:
|
|
47
90
|
# Object generation frequencies for this biome
|
|
@@ -55,14 +98,22 @@ class EnvParams(environment.EnvParams):
|
|
|
55
98
|
max_steps_in_episode: Union[int, None]
|
|
56
99
|
|
|
57
100
|
|
|
101
|
+
@struct.dataclass
class BiomeState:
    """Per-biome counters; every field is an array indexed by biome.

    Fields:
        consumption_count: objects consumed in each biome since its last
            regeneration.
        total_objects: objects spawned in each biome's current layout.
        generation: generation counter, incremented each time a biome is
            regenerated.
    """

    consumption_count: jax.Array  # consumed objects, one entry per biome
    total_objects: jax.Array  # spawned objects, one entry per biome
    generation: jax.Array  # current generation number, one entry per biome
|
|
108
|
+
|
|
109
|
+
|
|
58
110
|
@struct.dataclass
class EnvState(environment.EnvState):
    """Complete Foragax state carried from one step to the next.

    Fields:
        pos: agent position as ``(x, y)``; grids are indexed
            ``[pos[1], pos[0]]``.
        time: timestep counter, advanced by one on every step.
        digestion_buffer: buffer of delayed rewards; collected rewards are
            added to it during a step.
        object_state: per-cell object bookkeeping.
        biome_state: per-biome consumption and generation counters.
    """

    pos: jax.Array
    time: int
    digestion_buffer: jax.Array
    object_state: ObjectState
    biome_state: BiomeState
|
|
66
117
|
|
|
67
118
|
|
|
68
119
|
class ForagaxEnv(environment.Environment):
|
|
@@ -79,6 +130,8 @@ class ForagaxEnv(environment.Environment):
|
|
|
79
130
|
deterministic_spawn: bool = False,
|
|
80
131
|
teleport_interval: Optional[int] = None,
|
|
81
132
|
observation_type: str = "object",
|
|
133
|
+
dynamic_biomes: bool = False,
|
|
134
|
+
biome_consumption_threshold: float = 0.9,
|
|
82
135
|
):
|
|
83
136
|
super().__init__()
|
|
84
137
|
self._name = name
|
|
@@ -100,11 +153,24 @@ class ForagaxEnv(environment.Environment):
|
|
|
100
153
|
self.nowrap = nowrap
|
|
101
154
|
self.deterministic_spawn = deterministic_spawn
|
|
102
155
|
self.teleport_interval = teleport_interval
|
|
156
|
+
self.dynamic_biomes = dynamic_biomes
|
|
157
|
+
self.biome_consumption_threshold = biome_consumption_threshold
|
|
158
|
+
|
|
103
159
|
objects = (EMPTY,) + objects
|
|
104
160
|
if self.nowrap and not self.full_world:
|
|
105
161
|
objects = objects + (PADDING,)
|
|
106
162
|
self.objects = objects
|
|
107
163
|
|
|
164
|
+
# Infer num_fourier_terms from objects
|
|
165
|
+
self.num_fourier_terms = max(
|
|
166
|
+
(
|
|
167
|
+
obj.num_fourier_terms
|
|
168
|
+
for obj in self.objects
|
|
169
|
+
if isinstance(obj, FourierObject)
|
|
170
|
+
),
|
|
171
|
+
default=0,
|
|
172
|
+
)
|
|
173
|
+
|
|
108
174
|
# JIT-compatible versions of object and biome properties
|
|
109
175
|
self.object_ids = jnp.arange(len(objects))
|
|
110
176
|
self.object_blocking = jnp.array([o.blocking for o in objects])
|
|
@@ -122,6 +188,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
122
188
|
[o.expiry_time if o.expiry_time is not None else -1 for o in objects]
|
|
123
189
|
)
|
|
124
190
|
|
|
191
|
+
# Check if any objects can expire
|
|
192
|
+
self.has_expiring_objects = jnp.any(self.object_expiry_time >= 0)
|
|
193
|
+
|
|
125
194
|
# Compute reward steps per object (using max_reward_delay attribute)
|
|
126
195
|
object_max_reward_delay = jnp.array([o.max_reward_delay for o in objects])
|
|
127
196
|
self.max_reward_delay = (
|
|
@@ -138,6 +207,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
138
207
|
[b.stop if b.stop is not None else (-1, -1) for b in biomes]
|
|
139
208
|
)
|
|
140
209
|
self.biome_sizes = np.prod(self.biome_stops - self.biome_starts, axis=1)
|
|
210
|
+
self.biome_sizes_jax = jnp.array(self.biome_sizes) # JAX version for indexing
|
|
141
211
|
self.biome_starts_jax = jnp.array(self.biome_starts)
|
|
142
212
|
self.biome_stops_jax = jnp.array(self.biome_stops)
|
|
143
213
|
biome_centers = []
|
|
@@ -149,6 +219,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
149
219
|
biome_centers.append((center_x, center_y))
|
|
150
220
|
self.biome_centers_jax = jnp.array(biome_centers)
|
|
151
221
|
self.biome_masks = []
|
|
222
|
+
biome_masks_array = []
|
|
152
223
|
for i in range(self.biome_object_frequencies.shape[0]):
|
|
153
224
|
# Create mask for the biome
|
|
154
225
|
start = jax.lax.select(
|
|
@@ -170,6 +241,10 @@ class ForagaxEnv(environment.Environment):
|
|
|
170
241
|
& (cols < stop[0])
|
|
171
242
|
)
|
|
172
243
|
self.biome_masks.append(mask)
|
|
244
|
+
biome_masks_array.append(mask)
|
|
245
|
+
|
|
246
|
+
# Convert to JAX array for indexing in JIT-compiled code
|
|
247
|
+
self.biome_masks_array = jnp.array(biome_masks_array)
|
|
173
248
|
|
|
174
249
|
# Compute unique colors and mapping for partial observability (for 'color' observation_type)
|
|
175
250
|
# Exclude EMPTY (index 0) from color channels
|
|
@@ -200,46 +275,89 @@ class ForagaxEnv(environment.Environment):
|
|
|
200
275
|
max_steps_in_episode=None,
|
|
201
276
|
)
|
|
202
277
|
|
|
203
|
-
def
|
|
204
|
-
self, grid: jax.Array, y: int, x: int, timer_val: int
|
|
205
|
-
) -> jax.Array:
|
|
206
|
-
"""Place a timer at a specific position."""
|
|
207
|
-
return grid.at[y, x].set(timer_val)
|
|
208
|
-
|
|
209
|
-
def _place_timer_at_random_position(
|
|
278
|
+
def _place_timer(
    self,
    object_state: "ObjectState",
    y: int,
    x: int,
    object_type: int,
    timer_val: int,
    random_respawn: bool,
    rand_key: jax.Array,
) -> "ObjectState":
    """Remove the object at ``(y, x)`` and schedule its respawn.

    Args:
        object_state: Current object state.
        y, x: Grid position (row, column) of the object being removed.
        object_type: Object type ID that will respawn when the timer ends.
        timer_val: Countdown value; 0 means permanent removal (no timer).
        random_respawn: If True, the timer is placed on a uniformly random
            empty, timer-free cell of the same biome instead of at ``(y, x)``.
        rand_key: PRNG key used for the random placement.

    Returns:
        Updated object state. ``object_id`` at ``(y, x)`` is always cleared.
        When the timer is relocated, the instance's ``state_params`` and
        ``generation`` are copied to the chosen cell. The respawned object
        therefore keeps its own parameters and generation, exactly as it
        does when it respawns in place.
    """

    def place_empty():
        # Permanent removal: clear the object and any pending respawn.
        return object_state.replace(
            object_id=object_state.object_id.at[y, x].set(0),
            respawn_timer=object_state.respawn_timer.at[y, x].set(0),
            respawn_object_id=object_state.respawn_object_id.at[y, x].set(0),
        )

    def place_at_position():
        # Respawn in place: the cell keeps its own state_params/generation.
        return object_state.replace(
            object_id=object_state.object_id.at[y, x].set(0),
            respawn_timer=object_state.respawn_timer.at[y, x].set(timer_val),
            respawn_object_id=object_state.respawn_object_id.at[y, x].set(
                object_type
            ),
        )

    def place_randomly():
        # Clear the removed object's cell first so it is itself a candidate.
        new_object_id = object_state.object_id.at[y, x].set(0)
        new_respawn_timer = object_state.respawn_timer.at[y, x].set(0)
        new_respawn_object_id = object_state.respawn_object_id.at[y, x].set(0)

        # Candidates: same biome, empty, and no pending timer. (y, x) always
        # qualifies after the clear above, so there is at least one.
        same_biome = object_state.biome_id == object_state.biome_id[y, x]
        valid_spawn_mask = (
            same_biome & (new_object_id == 0) & (new_respawn_timer == 0)
        )
        num_valid_spawns = jnp.sum(valid_spawn_mask)

        # Fixed-size nonzero keeps shapes static under jit; the valid entries
        # come first, followed by -1 padding.
        y_indices, x_indices = jnp.nonzero(
            valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
        )
        random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
        new_y = y_indices[random_idx]
        new_x = x_indices[random_idx]

        new_respawn_timer = new_respawn_timer.at[new_y, new_x].set(timer_val)
        new_respawn_object_id = new_respawn_object_id.at[new_y, new_x].set(
            object_type
        )
        # Move the instance's per-object data with its timer. Otherwise the
        # respawned object would inherit the target cell's stale params and
        # generation.
        new_state_params = object_state.state_params.at[new_y, new_x].set(
            object_state.state_params[y, x]
        )
        new_generation = object_state.generation.at[new_y, new_x].set(
            object_state.generation[y, x]
        )

        return object_state.replace(
            object_id=new_object_id,
            respawn_timer=new_respawn_timer,
            respawn_object_id=new_respawn_object_id,
            state_params=new_state_params,
            generation=new_generation,
        )

    def place_timer():
        return jax.lax.cond(random_respawn, place_randomly, place_at_position)

    return jax.lax.cond(timer_val == 0, place_empty, place_timer)
|
|
243
361
|
|
|
244
362
|
def step_env(
|
|
245
363
|
self,
|
|
@@ -249,9 +367,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
249
367
|
params: EnvParams,
|
|
250
368
|
) -> Tuple[jax.Array, EnvState, jax.Array, jax.Array, Dict[Any, Any]]:
|
|
251
369
|
"""Perform single timestep state transition."""
|
|
252
|
-
|
|
253
|
-
# Decode the object grid: positive values are objects, negative are timers (treat as empty)
|
|
254
|
-
current_objects = jnp.maximum(0, state.object_grid)
|
|
370
|
+
current_objects = state.object_state.object_id
|
|
255
371
|
|
|
256
372
|
# 1. UPDATE AGENT POSITION
|
|
257
373
|
direction = DIRECTIONS[action]
|
|
@@ -294,9 +410,12 @@ class ForagaxEnv(environment.Environment):
|
|
|
294
410
|
# Handle digestion: add reward to buffer if collected
|
|
295
411
|
digestion_buffer = state.digestion_buffer
|
|
296
412
|
key, reward_subkey = jax.random.split(key)
|
|
413
|
+
|
|
414
|
+
object_params = state.object_state.state_params[pos[1], pos[0]]
|
|
297
415
|
object_reward = jax.lax.switch(
|
|
298
|
-
obj_at_pos, self.reward_fns, state.time, reward_subkey
|
|
416
|
+
obj_at_pos, self.reward_fns, state.time, reward_subkey, object_params
|
|
299
417
|
)
|
|
418
|
+
|
|
300
419
|
key, digestion_subkey = jax.random.split(key)
|
|
301
420
|
reward_delay = jax.lax.switch(
|
|
302
421
|
obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
|
|
@@ -319,136 +438,120 @@ class ForagaxEnv(environment.Environment):
|
|
|
319
438
|
# 3. HANDLE OBJECT COLLECTION AND RESPAWNING
|
|
320
439
|
key, regen_subkey, rand_key = jax.random.split(key, 3)
|
|
321
440
|
|
|
322
|
-
# Decrement timers
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
441
|
+
# Decrement respawn timers
|
|
442
|
+
has_timer = state.object_state.respawn_timer > 0
|
|
443
|
+
new_respawn_timer = jnp.where(
|
|
444
|
+
has_timer,
|
|
445
|
+
state.object_state.respawn_timer - 1,
|
|
446
|
+
state.object_state.respawn_timer,
|
|
326
447
|
)
|
|
327
448
|
|
|
328
|
-
# Track which
|
|
329
|
-
|
|
330
|
-
|
|
449
|
+
# Track which cells have timers that just reached 0
|
|
450
|
+
just_respawned = has_timer & (new_respawn_timer == 0)
|
|
451
|
+
|
|
452
|
+
# Respawn objects where timer reached 0
|
|
453
|
+
new_object_id = jnp.where(
|
|
454
|
+
just_respawned,
|
|
455
|
+
state.object_state.respawn_object_id,
|
|
456
|
+
state.object_state.object_id,
|
|
457
|
+
)
|
|
458
|
+
|
|
459
|
+
# Clear respawn_object_id for cells that just respawned
|
|
460
|
+
new_respawn_object_id = jnp.where(
|
|
461
|
+
just_respawned,
|
|
462
|
+
0,
|
|
463
|
+
state.object_state.respawn_object_id,
|
|
464
|
+
)
|
|
331
465
|
|
|
332
466
|
# Update spawn times for objects that just respawned
|
|
333
|
-
|
|
334
|
-
just_respawned, state.time, state.
|
|
467
|
+
spawn_time = jnp.where(
|
|
468
|
+
just_respawned, state.time, state.object_state.spawn_time
|
|
335
469
|
)
|
|
336
470
|
|
|
337
471
|
# Collect object: set a timer
|
|
338
472
|
regen_delay = jax.lax.switch(
|
|
339
473
|
obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
|
|
340
474
|
)
|
|
341
|
-
|
|
475
|
+
timer_countdown = jax.lax.cond(
|
|
476
|
+
regen_delay == jnp.iinfo(jnp.int32).max,
|
|
477
|
+
lambda: 0, # No timer (permanent removal)
|
|
478
|
+
lambda: regen_delay + 1, # Timer countdown
|
|
479
|
+
)
|
|
342
480
|
|
|
343
481
|
# If collected, replace object with timer; otherwise, keep it
|
|
344
|
-
val_at_pos =
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
#
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
pos[1],
|
|
355
|
-
pos[0],
|
|
356
|
-
encoded_timer,
|
|
357
|
-
state.biome_grid,
|
|
358
|
-
rand_key,
|
|
359
|
-
),
|
|
360
|
-
lambda: self._place_timer_at_position(
|
|
361
|
-
object_grid, pos[1], pos[0], encoded_timer
|
|
362
|
-
),
|
|
363
|
-
)
|
|
364
|
-
|
|
365
|
-
def no_collection():
|
|
366
|
-
return object_grid
|
|
367
|
-
|
|
368
|
-
object_grid = jax.lax.cond(
|
|
369
|
-
should_collect,
|
|
370
|
-
do_collection,
|
|
371
|
-
no_collection,
|
|
482
|
+
val_at_pos = current_objects[pos[1], pos[0]]
|
|
483
|
+
# Use original should_collect for consumption tracking
|
|
484
|
+
should_collect_now = is_collectable & (val_at_pos > 0)
|
|
485
|
+
|
|
486
|
+
# Create updated object state with new respawn_timer, object_id, and spawn_time
|
|
487
|
+
object_state = state.object_state.replace(
|
|
488
|
+
object_id=new_object_id,
|
|
489
|
+
respawn_timer=new_respawn_timer,
|
|
490
|
+
respawn_object_id=new_respawn_object_id,
|
|
491
|
+
spawn_time=spawn_time,
|
|
372
492
|
)
|
|
373
493
|
|
|
374
|
-
#
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
& (expiry_times >= 0)
|
|
388
|
-
& (current_objects_for_expiry > 0)
|
|
494
|
+
# Place timer on collection
|
|
495
|
+
object_state = jax.lax.cond(
|
|
496
|
+
should_collect_now,
|
|
497
|
+
lambda: self._place_timer(
|
|
498
|
+
object_state,
|
|
499
|
+
pos[1],
|
|
500
|
+
pos[0],
|
|
501
|
+
obj_at_pos, # object type
|
|
502
|
+
timer_countdown, # timer value
|
|
503
|
+
self.object_random_respawn[obj_at_pos],
|
|
504
|
+
rand_key,
|
|
505
|
+
),
|
|
506
|
+
lambda: object_state,
|
|
389
507
|
)
|
|
390
508
|
|
|
391
|
-
#
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
should_exp = should_expire[y, x]
|
|
398
|
-
|
|
399
|
-
def expire_object():
|
|
400
|
-
# Get expiry regen delay for this object
|
|
401
|
-
exp_key = jax.random.fold_in(key, y * self.size[0] + x)
|
|
402
|
-
exp_delay = jax.lax.switch(
|
|
403
|
-
obj_id, self.expiry_regen_delay_fns, state.time, exp_key
|
|
509
|
+
# Clear color grid when object is collected
|
|
510
|
+
object_state = jax.lax.cond(
|
|
511
|
+
should_collect_now,
|
|
512
|
+
lambda: object_state.replace(
|
|
513
|
+
color=object_state.color.at[pos[1], pos[0]].set(
|
|
514
|
+
jnp.full((3,), 255, dtype=jnp.uint8)
|
|
404
515
|
)
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
should_random_respawn = self.object_random_respawn[obj_id]
|
|
409
|
-
|
|
410
|
-
# Use second split for randomness in random placement
|
|
411
|
-
rand_key = jax.random.split(exp_key)[1]
|
|
412
|
-
|
|
413
|
-
# Place timer either at current position or random position
|
|
414
|
-
new_grid = jax.lax.cond(
|
|
415
|
-
should_random_respawn,
|
|
416
|
-
lambda: self._place_timer_at_random_position(
|
|
417
|
-
grid, y, x, encoded_exp_timer, state.biome_grid, rand_key
|
|
418
|
-
),
|
|
419
|
-
lambda: self._place_timer_at_position(
|
|
420
|
-
grid, y, x, encoded_exp_timer
|
|
421
|
-
),
|
|
422
|
-
)
|
|
423
|
-
|
|
424
|
-
return new_grid, spawn_grid
|
|
425
|
-
|
|
426
|
-
def no_expire():
|
|
427
|
-
return grid, spawn_grid
|
|
428
|
-
|
|
429
|
-
return jax.lax.cond(should_exp, expire_object, no_expire)
|
|
430
|
-
|
|
431
|
-
# Apply expiry to all cells (vectorized)
|
|
432
|
-
def scan_expiry_row(carry, y):
|
|
433
|
-
grid, spawn_grid, key = carry
|
|
434
|
-
|
|
435
|
-
def scan_expiry_col(carry_col, x):
|
|
436
|
-
grid_col, spawn_grid_col, key_col = carry_col
|
|
437
|
-
grid_col, spawn_grid_col = process_expiry(
|
|
438
|
-
y, x, grid_col, spawn_grid_col, key_col
|
|
439
|
-
)
|
|
440
|
-
return (grid_col, spawn_grid_col, key_col), None
|
|
516
|
+
),
|
|
517
|
+
lambda: object_state,
|
|
518
|
+
)
|
|
441
519
|
|
|
442
|
-
|
|
443
|
-
|
|
520
|
+
# 3.5. HANDLE OBJECT EXPIRY
|
|
521
|
+
# Only process expiry if there are objects that can expire
|
|
522
|
+
key, object_state = self.expire_objects(key, state, object_state)
|
|
523
|
+
|
|
524
|
+
# 3.6. HANDLE DYNAMIC BIOME CONSUMPTION AND RESPAWNING
|
|
525
|
+
if self.dynamic_biomes:
|
|
526
|
+
# Update consumption count if an object was collected
|
|
527
|
+
# Only count if the object belongs to the current generation of its biome
|
|
528
|
+
collected_biome_id = object_state.biome_id[pos[1], pos[0]]
|
|
529
|
+
object_gen_at_pos = object_state.generation[pos[1], pos[0]]
|
|
530
|
+
current_biome_gen = state.biome_state.generation[collected_biome_id]
|
|
531
|
+
is_current_generation = object_gen_at_pos == current_biome_gen
|
|
532
|
+
|
|
533
|
+
biome_consumption_count = state.biome_state.consumption_count
|
|
534
|
+
biome_consumption_count = jax.lax.cond(
|
|
535
|
+
should_collect & is_current_generation,
|
|
536
|
+
lambda: biome_consumption_count.at[collected_biome_id].add(1),
|
|
537
|
+
lambda: biome_consumption_count,
|
|
444
538
|
)
|
|
445
|
-
return (grid, spawn_grid, key), None
|
|
446
539
|
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
540
|
+
# Check each biome for threshold crossing and respawn if needed
|
|
541
|
+
key, respawn_key = jax.random.split(key)
|
|
542
|
+
biome_state = BiomeState(
|
|
543
|
+
consumption_count=biome_consumption_count,
|
|
544
|
+
total_objects=state.biome_state.total_objects,
|
|
545
|
+
generation=state.biome_state.generation,
|
|
546
|
+
)
|
|
547
|
+
object_state, biome_state, respawn_key = self._check_and_respawn_biomes(
|
|
548
|
+
object_state,
|
|
549
|
+
biome_state,
|
|
550
|
+
state.time,
|
|
551
|
+
respawn_key,
|
|
552
|
+
)
|
|
553
|
+
else:
|
|
554
|
+
biome_state = state.biome_state
|
|
452
555
|
|
|
453
556
|
info = {"discount": self.discount(state, params)}
|
|
454
557
|
temperatures = jnp.zeros(len(self.objects))
|
|
@@ -458,17 +561,16 @@ class ForagaxEnv(environment.Environment):
|
|
|
458
561
|
get_temperature(obj.rewards, state.time, obj.repeat)
|
|
459
562
|
)
|
|
460
563
|
info["temperatures"] = temperatures
|
|
461
|
-
info["biome_id"] =
|
|
564
|
+
info["biome_id"] = object_state.biome_id[pos[1], pos[0]]
|
|
462
565
|
info["object_collected_id"] = jax.lax.select(should_collect, obj_at_pos, -1)
|
|
463
566
|
|
|
464
567
|
# 4. UPDATE STATE
|
|
465
568
|
state = EnvState(
|
|
466
569
|
pos=pos,
|
|
467
|
-
object_grid=object_grid,
|
|
468
|
-
biome_grid=state.biome_grid,
|
|
469
570
|
time=state.time + 1,
|
|
470
571
|
digestion_buffer=digestion_buffer,
|
|
471
|
-
|
|
572
|
+
object_state=object_state,
|
|
573
|
+
biome_state=biome_state,
|
|
472
574
|
)
|
|
473
575
|
|
|
474
576
|
done = self.is_terminal(state, params)
|
|
@@ -480,60 +582,401 @@ class ForagaxEnv(environment.Environment):
|
|
|
480
582
|
info,
|
|
481
583
|
)
|
|
482
584
|
|
|
585
|
+
def expire_objects(
    self, key: jax.Array, state: "EnvState", object_state: "ObjectState"
) -> Tuple[jax.Array, "ObjectState"]:
    """Expire objects whose age has reached their type's expiry time.

    The object in cell (y, x) expires when
    ``state.time - spawn_time >= expiry_time`` for its type, the expiry time
    is non-negative, and the cell is not empty. Each expired object is
    removed via ``self._place_timer``. Its respawn delay comes from its
    type's ``expiry_regen_delay_fns`` entry, where ``int32`` max means
    permanent removal. Its colour is reset to white.

    Args:
        key: PRNG key. It is split only when at least one object expires.
        state: Environment state; only ``state.time`` is read.
        object_state: Object state to update.

    Returns:
        ``(key, object_state)``: the (possibly advanced) key and the updated
        object state.
    """
    if not self.has_expiring_objects:
        # Static short-circuit: no object type can ever expire.
        return key, object_state

    object_ids = object_state.object_id
    object_ages = state.time - object_state.spawn_time
    expiry_times = self.object_expiry_time[object_ids]
    should_expire = (
        (object_ages >= expiry_times) & (expiry_times >= 0) & (object_ids > 0)
    )
    num_expiring = jnp.sum(should_expire)

    def process_expiries():
        # Fixed-size nonzero keeps shapes static under jit; the expiring cells
        # are listed first (row-major), followed by -1 padding.
        max_objects = self.size[0] * self.size[1]
        y_indices, x_indices = jnp.nonzero(
            should_expire, size=max_objects, fill_value=-1
        )
        key_local, expiry_key = jax.random.split(key)

        def expire_one(i, obj_state):
            y = y_indices[i]
            x = x_indices[i]
            obj_id = object_ids[y, x]
            exp_key = jax.random.fold_in(expiry_key, y * self.size[0] + x)
            exp_delay = jax.lax.switch(
                obj_id, self.expiry_regen_delay_fns, state.time, exp_key
            )
            timer_countdown = jax.lax.cond(
                exp_delay == jnp.iinfo(jnp.int32).max,
                lambda: 0,  # No timer (permanent removal)
                lambda: exp_delay + 1,  # Timer countdown
            )
            rand_key = jax.random.split(exp_key)[1]
            obj_state = self._place_timer(
                obj_state,
                y,
                x,
                obj_id,
                timer_countdown,
                self.object_random_respawn[obj_id],
                rand_key,
            )
            # Clear the colour of the expired object's cell.
            empty_color = jnp.full((3,), 255, dtype=jnp.uint8)
            return obj_state.replace(color=obj_state.color.at[y, x].set(empty_color))

        # Iterate only over the real entries rather than all H*W padded slots.
        # With a traced bound this lowers to a while_loop. Under vmap this
        # avoids running _place_timer for every cell via select.
        new_object_state = jax.lax.fori_loop(
            0, num_expiring, expire_one, object_state
        )
        return key_local, new_object_state

    def no_expiries():
        return key, object_state

    return jax.lax.cond(num_expiring > 0, process_expiries, no_expiries)
|
|
681
|
+
|
|
682
|
+
def _check_and_respawn_biomes(
    self,
    object_state: "ObjectState",
    biome_state: "BiomeState",
    current_time: int,
    key: jax.Array,
) -> Tuple["ObjectState", "BiomeState", jax.Array]:
    """Regenerate every biome whose consumption rate reached the threshold.

    A biome regenerates when ``consumption_count / max(1, total_objects)`` is
    at least ``self.biome_consumption_threshold``. For each such biome a fresh
    layout is drawn with ``self._spawn_biome_objects``. Cells where the new
    layout has an object receive its id, colour and parameters, the biome's
    new generation number, and ``spawn_time = current_time``. Cells where the
    new layout is empty keep their previous contents.

    The biome's respawn timers are cleared and its consumption count is
    reset to 0. Its generation is incremented, and ``total_objects`` becomes
    the number of newly spawned objects.

    Args:
        object_state: Current object state.
        biome_state: Current per-biome counters.
        current_time: Timestep recorded as the spawn time of new objects.
        key: PRNG key.

    Returns:
        ``(object_state, biome_state, key)`` with ``key`` advanced by one split.
    """
    num_biomes = self.biome_object_frequencies.shape[0]

    consumption_rates = biome_state.consumption_count / jnp.maximum(
        1.0, biome_state.total_objects.astype(float)
    )
    should_respawn = consumption_rates >= self.biome_consumption_threshold

    key, subkey = jax.random.split(key)
    biome_keys = jax.random.split(subkey, num_biomes)

    # New layouts are computed for every biome and masked afterwards, which
    # keeps shapes independent of which biomes actually regenerate.
    if self.deterministic_spawn:
        # The biome index is a Python int here, so call the spawner directly.
        # lax.switch would trace all num_biomes branches for every biome.
        spawned = [
            self._spawn_biome_objects(i, biome_keys[i], deterministic=True)
            for i in range(num_biomes)
        ]
        all_new_objects = jnp.stack([s[0] for s in spawned])
        all_new_colors = jnp.stack([s[1] for s in spawned])
        all_new_params = jnp.stack([s[2] for s in spawned])
    else:
        # Random spawn works with vmap over (biome index, key).
        all_new_objects, all_new_colors, all_new_params = jax.vmap(
            lambda i, k: self._spawn_biome_objects(i, k, deterministic=False)
        )(jnp.arange(num_biomes), biome_keys)

    # Updated per-biome counters.
    new_generation = biome_state.generation + should_respawn.astype(int)
    new_object_counts = jnp.sum(
        (all_new_objects > 0) & self.biome_masks_array, axis=(1, 2)
    )
    new_biome_state = biome_state.replace(
        consumption_count=jnp.where(
            should_respawn, 0, biome_state.consumption_count
        ),
        total_objects=jnp.where(
            should_respawn, new_object_counts, biome_state.total_objects
        ),
        generation=new_generation,
    )

    new_obj_id = object_state.object_id
    new_color = object_state.color
    new_params = object_state.state_params
    new_spawn = object_state.spawn_time
    new_gen = object_state.generation
    new_respawn_timer = object_state.respawn_timer
    new_respawn_object_id = object_state.respawn_object_id

    for i in range(num_biomes):
        # Cells of biome i are touched only when that biome regenerates.
        in_respawning_biome = self.biome_masks_array[i] & should_respawn[i]
        is_new_object = in_respawning_biome & (all_new_objects[i] > 0)

        new_obj_id = jnp.where(is_new_object, all_new_objects[i], new_obj_id)
        new_color = jnp.where(is_new_object[..., None], all_new_colors[i], new_color)
        new_params = jnp.where(
            is_new_object[..., None], all_new_params[i], new_params
        )
        new_gen = jnp.where(is_new_object, new_generation[i], new_gen)
        new_spawn = jnp.where(is_new_object, current_time, new_spawn)

        # Pending respawns in a regenerated biome are discarded.
        new_respawn_timer = jnp.where(in_respawning_biome, 0, new_respawn_timer)
        new_respawn_object_id = jnp.where(
            in_respawning_biome, 0, new_respawn_object_id
        )

    object_state = object_state.replace(
        object_id=new_obj_id,
        respawn_timer=new_respawn_timer,
        respawn_object_id=new_respawn_object_id,
        color=new_color,
        state_params=new_params,
        generation=new_gen,
        spawn_time=new_spawn,
    )

    return object_state, new_biome_state, key
|
|
801
|
+
|
|
483
802
|
def reset_env(
    self, key: jax.Array, params: EnvParams
) -> Tuple[jax.Array, EnvState]:
    """Reset environment state.

    Starts from an empty ``ObjectState``. For every biome it then:
      1. writes that biome's index into ``biome_id`` under its mask;
      2. writes the objects, colors and state params returned by
         ``_spawn_biome_objects`` into the cells selected by that mask.

    Afterwards it places the agent at the world center and initializes the
    per-biome bookkeeping (consumption count, object total, generation).

    Returns:
        The initial observation and the new ``EnvState``.
    """
    n_biomes = self.biome_object_frequencies.shape[0]
    n_params = 2 + 2 * self.num_fourier_terms
    objects = ObjectState.create_empty(self.size, n_params)

    key, spawn_key = jax.random.split(key)

    for b in range(n_biomes):
        spawn_key, this_key = jax.random.split(spawn_key)
        in_biome = self.biome_masks[b]

        objects = objects.replace(
            biome_id=jnp.where(in_biome, b, objects.biome_id)
        )

        ids, colors, obj_params = self._spawn_biome_objects(
            b, this_key, self.deterministic_spawn
        )
        in_biome_rgb = in_biome[..., None]
        objects = objects.replace(
            object_id=jnp.where(in_biome, ids, objects.object_id),
            color=jnp.where(in_biome_rgb, colors, objects.color),
            state_params=jnp.where(
                in_biome_rgb, obj_params, objects.state_params
            ),
        )

    # The agent always starts at the center cell of the world.
    start_pos = jnp.array([self.size[0] // 2, self.size[1] // 2])

    # Per-biome count of non-empty cells (object_id > 0) on the final grid.
    occupied = objects.object_id > 0
    totals = jnp.array(
        [jnp.sum(occupied & self.biome_masks[b]) for b in range(n_biomes)],
        dtype=int,
    )

    state = EnvState(
        pos=start_pos,
        time=0,
        digestion_buffer=jnp.zeros((self.max_reward_delay,)),
        object_state=objects,
        biome_state=BiomeState(
            consumption_count=jnp.zeros(n_biomes, dtype=int),
            total_objects=totals,
            generation=jnp.zeros(n_biomes, dtype=int),
        ),
    )
    return self.get_obs(state, params), state
|
|
518
865
|
|
|
519
|
-
def
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
866
|
+
def _spawn_biome_objects(
    self,
    biome_idx: int,
    key: jax.Array,
    deterministic: bool = False,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    """Spawn objects in a biome.

    Object ids are laid out so that ``biome_freqs[j - 1]`` is the spawn
    frequency of object id ``j``. Id ``0`` is EMPTY and receives the
    remaining probability mass.

    Args:
        biome_idx: Index of the biome to spawn into. When ``deterministic``
            is True this must be a concrete Python int. It selects the
            static biome bounds and sizes used to build arrays at trace time.
        key: PRNG key used for placement, colors and state params.
        deterministic: If True, fill the biome's rectangular bounds with a
            fixed count of each object type and shuffle them. The counts come
            from evenly spaced quantiles bucketed by the frequencies.
            Otherwise, sample every cell independently.

    Returns:
        object_grid: (H, W) array of object IDs. In deterministic mode, cells
            outside the biome bounds are 0. In random mode every cell of the
            world is sampled, so callers must apply the biome mask.
        color_grid: (H, W, 3) uint8 array of RGB colors. One random color is
            drawn per object type and applied to that type's cells inside the
            biome mask. Every other cell is white (255).
        state_grid: (H, W, num_state_params) float32 array of object state
            parameters. Values are set per object type inside the biome mask
            and are zero elsewhere.
    """
    biome_freqs = self.biome_object_frequencies[biome_idx]
    biome_mask = self.biome_masks_array[biome_idx]
    num_freqs = len(biome_freqs)

    # Split into 4 (first sub-key unused) so the derived keys match the
    # original 4-way split.
    _, spawn_key, color_key, params_key = jax.random.split(key, 4)

    if deterministic:
        # NOTE: Requires concrete biome_idx to compute size at trace time.
        biome_start = self.biome_starts[biome_idx]
        biome_stop = self.biome_stops[biome_idx]
        biome_height = biome_stop[1] - biome_start[1]
        biome_width = biome_stop[0] - biome_start[0]
        biome_size = int(self.biome_sizes[biome_idx])

        # Evenly spaced quantiles in [0, 1) are bucketed by the reversed
        # cumulative frequencies. The first f_k fraction becomes id k, the
        # next f_{k-1} becomes id k-1, and so on. Any leftover fraction
        # becomes id 0 (EMPTY).
        quantiles = jnp.linspace(0, 1, biome_size, endpoint=False)
        biome_objects_flat = num_freqs - jnp.searchsorted(
            jnp.cumsum(biome_freqs[::-1]), quantiles, side="right"
        )
        biome_objects_flat = jax.random.permutation(spawn_key, biome_objects_flat)
        biome_objects = biome_objects_flat.reshape(biome_height, biome_width)

        # Place in full grid using slicing with static bounds.
        object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
        object_grid = object_grid.at[
            biome_start[1] : biome_stop[1], biome_start[0] : biome_stop[0]
        ].set(biome_objects)
    else:
        # Random spawn: per-cell inverse-CDF sampling (works with a traced
        # biome_idx).
        grid_rand = jax.random.uniform(spawn_key, (self.size[1], self.size[0]))
        empty_freq = 1.0 - jnp.sum(biome_freqs)
        all_freqs = jnp.concatenate([jnp.array([empty_freq]), biome_freqs])
        cumulative_freqs = jnp.cumsum(
            jnp.concatenate([jnp.array([0.0]), all_freqs])
        )
        object_grid = (
            jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
        )
        # Float rounding in the cumulative sum can leave the final bound a
        # hair below 1.0. A draw above it would yield id num_freqs + 1, which
        # is not a valid object, so clamp such draws into the last bucket.
        object_grid = jnp.clip(object_grid, 0, num_freqs)

    # Every cell starts white; only object cells inside the biome are recolored.
    color_grid = jnp.full((self.size[1], self.size[0], 3), 255, dtype=jnp.uint8)

    # Sample ONE color per object type in this biome (not per instance), so
    # objects of the same type share a color within a biome generation.
    # Index 0 (EMPTY) gets no sampled color.
    num_object_types = len(self.objects)
    num_actual_objects = num_object_types - 1

    if num_actual_objects > 0:
        biome_object_colors = jax.random.randint(
            color_key,
            (num_actual_objects, 3),
            minval=0,
            maxval=256,
            dtype=jnp.uint8,
        )
        for obj_idx in range(1, num_object_types):
            obj_mask = (object_grid == obj_idx) & biome_mask
            # Row obj_idx - 1 because EMPTY has no sampled color.
            color_grid = jnp.where(
                obj_mask[..., None], biome_object_colors[obj_idx - 1], color_grid
            )

    # Per-type state params, padded or truncated to the fixed param width.
    num_object_params = 2 + 2 * self.num_fourier_terms
    params_grid = jnp.zeros(
        (self.size[1], self.size[0], num_object_params), dtype=jnp.float32
    )

    for obj_idx in range(num_object_types):
        # Params are drawn once per object type (at trace time), not per cell.
        params_key, obj_key = jax.random.split(params_key)
        obj_params = self.objects[obj_idx].get_state(obj_key)
        n_obj_params = len(obj_params)

        # Objects without state params (e.g. EMPTY) leave the grid untouched.
        if n_obj_params == 0:
            continue
        if n_obj_params < num_object_params:
            obj_params = jnp.pad(
                obj_params,
                (0, num_object_params - n_obj_params),
                constant_values=0.0,
            )
        elif n_obj_params > num_object_params:
            obj_params = obj_params[:num_object_params]

        # Assign to all objects of this type in this biome.
        obj_mask = (object_grid == obj_idx) & biome_mask
        params_grid = jnp.where(obj_mask[..., None], obj_params, params_grid)

    return object_grid, color_grid, params_grid
|
|
537
980
|
|
|
538
981
|
def is_terminal(self, state: EnvState, params: EnvParams) -> jax.Array:
|
|
539
982
|
"""Foragax is a continuing environment."""
|
|
@@ -556,21 +999,10 @@ class ForagaxEnv(environment.Environment):
|
|
|
556
999
|
|
|
557
1000
|
def state_space(self, params: EnvParams) -> spaces.Dict:
|
|
558
1001
|
"""State space of the environment."""
|
|
1002
|
+
num_object_params = 2 + 2 * self.num_fourier_terms
|
|
559
1003
|
return spaces.Dict(
|
|
560
1004
|
{
|
|
561
1005
|
"pos": spaces.Box(0, max(self.size), (2,), int),
|
|
562
|
-
"object_grid": spaces.Box(
|
|
563
|
-
-1000 * len(self.object_ids),
|
|
564
|
-
len(self.object_ids),
|
|
565
|
-
(self.size[1], self.size[0]),
|
|
566
|
-
int,
|
|
567
|
-
),
|
|
568
|
-
"biome_grid": spaces.Box(
|
|
569
|
-
0,
|
|
570
|
-
self.biome_object_frequencies.shape[0],
|
|
571
|
-
(self.size[1], self.size[0]),
|
|
572
|
-
int,
|
|
573
|
-
),
|
|
574
1006
|
"time": spaces.Discrete(params.max_steps_in_episode),
|
|
575
1007
|
"digestion_buffer": spaces.Box(
|
|
576
1008
|
-jnp.inf,
|
|
@@ -578,17 +1010,79 @@ class ForagaxEnv(environment.Environment):
|
|
|
578
1010
|
(self.max_reward_delay,),
|
|
579
1011
|
float,
|
|
580
1012
|
),
|
|
581
|
-
"
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
1013
|
+
"object_state": spaces.Dict(
|
|
1014
|
+
{
|
|
1015
|
+
"object_id": spaces.Box(
|
|
1016
|
+
-1000 * len(self.object_ids),
|
|
1017
|
+
len(self.object_ids),
|
|
1018
|
+
(self.size[1], self.size[0]),
|
|
1019
|
+
int,
|
|
1020
|
+
),
|
|
1021
|
+
"spawn_time": spaces.Box(
|
|
1022
|
+
0,
|
|
1023
|
+
jnp.inf,
|
|
1024
|
+
(self.size[1], self.size[0]),
|
|
1025
|
+
int,
|
|
1026
|
+
),
|
|
1027
|
+
"color": spaces.Box(
|
|
1028
|
+
0,
|
|
1029
|
+
255,
|
|
1030
|
+
(self.size[1], self.size[0], 3),
|
|
1031
|
+
int,
|
|
1032
|
+
),
|
|
1033
|
+
"generation": spaces.Box(
|
|
1034
|
+
0,
|
|
1035
|
+
jnp.inf,
|
|
1036
|
+
(self.size[1], self.size[0]),
|
|
1037
|
+
int,
|
|
1038
|
+
),
|
|
1039
|
+
"state_params": spaces.Box(
|
|
1040
|
+
-jnp.inf,
|
|
1041
|
+
jnp.inf,
|
|
1042
|
+
(self.size[1], self.size[0], num_object_params),
|
|
1043
|
+
float,
|
|
1044
|
+
),
|
|
1045
|
+
"biome_id": spaces.Box(
|
|
1046
|
+
-1,
|
|
1047
|
+
self.biome_object_frequencies.shape[0],
|
|
1048
|
+
(self.size[1], self.size[0]),
|
|
1049
|
+
int,
|
|
1050
|
+
),
|
|
1051
|
+
}
|
|
1052
|
+
),
|
|
1053
|
+
"biome_state": spaces.Dict(
|
|
1054
|
+
{
|
|
1055
|
+
"consumption_count": spaces.Box(
|
|
1056
|
+
0,
|
|
1057
|
+
jnp.inf,
|
|
1058
|
+
(self.biome_object_frequencies.shape[0],),
|
|
1059
|
+
int,
|
|
1060
|
+
),
|
|
1061
|
+
"total_objects": spaces.Box(
|
|
1062
|
+
0,
|
|
1063
|
+
jnp.inf,
|
|
1064
|
+
(self.biome_object_frequencies.shape[0],),
|
|
1065
|
+
int,
|
|
1066
|
+
),
|
|
1067
|
+
"generation": spaces.Box(
|
|
1068
|
+
0,
|
|
1069
|
+
jnp.inf,
|
|
1070
|
+
(self.biome_object_frequencies.shape[0],),
|
|
1071
|
+
int,
|
|
1072
|
+
),
|
|
1073
|
+
}
|
|
586
1074
|
),
|
|
587
1075
|
}
|
|
588
1076
|
)
|
|
589
1077
|
|
|
590
|
-
def
|
|
591
|
-
|
|
1078
|
+
def _compute_aperture_coordinates(
|
|
1079
|
+
self, pos: jax.Array
|
|
1080
|
+
) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
|
|
1081
|
+
"""Compute aperture coordinates for the given position.
|
|
1082
|
+
|
|
1083
|
+
Returns:
|
|
1084
|
+
(y_coords, x_coords, y_coords_clamped/mod, x_coords_clamped/mod)
|
|
1085
|
+
"""
|
|
592
1086
|
ap_h, ap_w = self.aperture_size
|
|
593
1087
|
start_y = pos[1] - ap_h // 2
|
|
594
1088
|
start_x = pos[0] - ap_w // 2
|
|
@@ -599,27 +1093,37 @@ class ForagaxEnv(environment.Environment):
|
|
|
599
1093
|
x_coords = start_x + x_offsets
|
|
600
1094
|
|
|
601
1095
|
if self.nowrap:
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
1096
|
+
y_coords_adj = jnp.clip(y_coords, 0, self.size[1] - 1)
|
|
1097
|
+
x_coords_adj = jnp.clip(x_coords, 0, self.size[0] - 1)
|
|
1098
|
+
else:
|
|
1099
|
+
y_coords_adj = jnp.mod(y_coords, self.size[1])
|
|
1100
|
+
x_coords_adj = jnp.mod(x_coords, self.size[0])
|
|
1101
|
+
|
|
1102
|
+
return y_coords, x_coords, y_coords_adj, x_coords_adj
|
|
1103
|
+
|
|
1104
|
+
def _get_aperture(self, object_id_grid: jax.Array, pos: jax.Array) -> jax.Array:
    """Return the window of object ids around ``pos`` seen by the agent.

    Cells are gathered with the clamped or wrapped coordinates produced by
    ``_compute_aperture_coordinates``. When ``self.nowrap`` is set, any cell
    whose unadjusted coordinate falls outside the world is replaced by the
    last entry of ``self.object_ids`` (the padding object id).
    """
    raw_y, raw_x, safe_y, safe_x = self._compute_aperture_coordinates(pos)
    window = object_id_grid[safe_y, safe_x]
    if not self.nowrap:
        return window

    # In nowrap mode the coordinates are clamped, which would repeat edge
    # cells. Flag those positions and show padding instead.
    height, width = self.size[1], self.size[0]
    beyond_edge = (raw_y < 0) | (raw_y >= height) | (raw_x < 0) | (raw_x >= width)
    return jnp.where(beyond_edge, self.object_ids[-1], window)
|
|
618
1123
|
|
|
619
1124
|
def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
|
|
620
1125
|
"""Get observation based on observation_type and full_world."""
|
|
621
|
-
|
|
622
|
-
obs_grid = jnp.maximum(0, state.object_grid)
|
|
1126
|
+
obs_grid = state.object_state.object_id
|
|
623
1127
|
|
|
624
1128
|
if self.full_world:
|
|
625
1129
|
return self._get_world_obs(obs_grid, state)
|
|
@@ -721,48 +1225,43 @@ class ForagaxEnv(environment.Environment):
|
|
|
721
1225
|
|
|
722
1226
|
if is_world_mode:
|
|
723
1227
|
# Create an RGB image from the object grid
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
1228
|
+
# Use stateful object colors if dynamic_biomes is enabled, else use default colors
|
|
1229
|
+
if self.dynamic_biomes:
|
|
1230
|
+
# Use per-instance colors from state
|
|
1231
|
+
img = state.object_state.color.copy()
|
|
1232
|
+
else:
|
|
1233
|
+
# Use default object colors
|
|
1234
|
+
img = jnp.zeros((self.size[1], self.size[0], 3))
|
|
1235
|
+
render_grid = state.object_state.object_id
|
|
727
1236
|
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
1237
|
+
def update_image(i, img):
|
|
1238
|
+
color = self.object_colors[i]
|
|
1239
|
+
mask = render_grid == i
|
|
1240
|
+
img = jnp.where(mask[..., None], color, img)
|
|
1241
|
+
return img
|
|
733
1242
|
|
|
734
|
-
|
|
1243
|
+
img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
|
|
735
1244
|
|
|
736
1245
|
# Tint the agent's aperture
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
1246
|
+
y_coords, x_coords, y_coords_adj, x_coords_adj = (
|
|
1247
|
+
self._compute_aperture_coordinates(state.pos)
|
|
1248
|
+
)
|
|
740
1249
|
|
|
741
1250
|
alpha = 0.2
|
|
742
1251
|
agent_color = jnp.array(AGENT.color)
|
|
743
1252
|
|
|
744
|
-
# Create indices for the aperture
|
|
745
|
-
y_offsets = jnp.arange(ap_h)
|
|
746
|
-
x_offsets = jnp.arange(ap_w)
|
|
747
|
-
y_coords_original = start_y + y_offsets[:, None]
|
|
748
|
-
x_coords_original = start_x + x_offsets
|
|
749
|
-
|
|
750
1253
|
if self.nowrap:
|
|
751
|
-
y_coords = jnp.clip(y_coords_original, 0, self.size[1] - 1)
|
|
752
|
-
x_coords = jnp.clip(x_coords_original, 0, self.size[0] - 1)
|
|
753
1254
|
# Create tint mask: any in-bounds original position maps to a cell makes it tinted
|
|
754
1255
|
tint_mask = jnp.zeros((self.size[1], self.size[0]), dtype=int)
|
|
755
|
-
tint_mask = tint_mask.at[
|
|
1256
|
+
tint_mask = tint_mask.at[y_coords_adj, x_coords_adj].set(1)
|
|
756
1257
|
# Apply tint to masked positions
|
|
757
1258
|
original_colors = img
|
|
758
1259
|
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
759
1260
|
img = jnp.where(tint_mask[..., None], tinted_colors, img)
|
|
760
1261
|
else:
|
|
761
|
-
|
|
762
|
-
x_coords = jnp.mod(x_coords_original, self.size[0])
|
|
763
|
-
original_colors = img[y_coords, x_coords]
|
|
1262
|
+
original_colors = img[y_coords_adj, x_coords_adj]
|
|
764
1263
|
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
765
|
-
img = img.at[
|
|
1264
|
+
img = img.at[y_coords_adj, x_coords_adj].set(tinted_colors)
|
|
766
1265
|
|
|
767
1266
|
# Agent color
|
|
768
1267
|
img = img.at[state.pos[1], state.pos[0]].set(jnp.array(AGENT.color))
|
|
@@ -775,6 +1274,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
775
1274
|
|
|
776
1275
|
if is_true_mode:
|
|
777
1276
|
# Apply true object borders by overlaying true colors on border pixels
|
|
1277
|
+
render_grid = state.object_state.object_id
|
|
778
1278
|
img = apply_true_borders(
|
|
779
1279
|
img, render_grid, self.size, len(self.object_ids)
|
|
780
1280
|
)
|
|
@@ -787,10 +1287,27 @@ class ForagaxEnv(environment.Environment):
|
|
|
787
1287
|
img = img.at[:, col_indices].set(grid_color)
|
|
788
1288
|
|
|
789
1289
|
elif is_aperture_mode:
|
|
790
|
-
obs_grid =
|
|
1290
|
+
obs_grid = state.object_state.object_id
|
|
791
1291
|
aperture = self._get_aperture(obs_grid, state.pos)
|
|
792
|
-
|
|
793
|
-
|
|
1292
|
+
|
|
1293
|
+
if self.dynamic_biomes:
|
|
1294
|
+
# Use per-instance colors from state - extract aperture view
|
|
1295
|
+
y_coords, x_coords, y_coords_adj, x_coords_adj = (
|
|
1296
|
+
self._compute_aperture_coordinates(state.pos)
|
|
1297
|
+
)
|
|
1298
|
+
img = state.object_state.color[y_coords_adj, x_coords_adj]
|
|
1299
|
+
|
|
1300
|
+
if self.nowrap:
|
|
1301
|
+
# For out-of-bounds, use padding object color
|
|
1302
|
+
y_out = (y_coords < 0) | (y_coords >= self.size[1])
|
|
1303
|
+
x_out = (x_coords < 0) | (x_coords >= self.size[0])
|
|
1304
|
+
out_of_bounds = y_out | x_out
|
|
1305
|
+
padding_color = jnp.array(self.objects[-1].color, dtype=jnp.float32)
|
|
1306
|
+
img = jnp.where(out_of_bounds[..., None], padding_color, img)
|
|
1307
|
+
else:
|
|
1308
|
+
# Use default object colors
|
|
1309
|
+
aperture_one_hot = jax.nn.one_hot(aperture, len(self.object_ids))
|
|
1310
|
+
img = jnp.tensordot(aperture_one_hot, self.object_colors, axes=1)
|
|
794
1311
|
|
|
795
1312
|
# Draw agent in the center
|
|
796
1313
|
center_y, center_x = self.aperture_size[1] // 2, self.aperture_size[0] // 2
|