continual-foragax 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.31.0.dist-info → continual_foragax-0.32.1.dist-info}/METADATA +1 -6
- {continual_foragax-0.31.0.dist-info → continual_foragax-0.32.1.dist-info}/RECORD +8 -8
- foragax/env.py +786 -255
- foragax/objects.py +351 -4
- foragax/registry.py +59 -0
- {continual_foragax-0.31.0.dist-info → continual_foragax-0.32.1.dist-info}/WHEEL +0 -0
- {continual_foragax-0.31.0.dist-info → continual_foragax-0.32.1.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.31.0.dist-info → continual_foragax-0.32.1.dist-info}/top_level.txt +0 -0
foragax/env.py
CHANGED
|
@@ -19,6 +19,7 @@ from foragax.objects import (
|
|
|
19
19
|
EMPTY,
|
|
20
20
|
PADDING,
|
|
21
21
|
BaseForagaxObject,
|
|
22
|
+
FourierObject,
|
|
22
23
|
WeatherObject,
|
|
23
24
|
)
|
|
24
25
|
from foragax.rendering import apply_true_borders
|
|
@@ -42,6 +43,48 @@ DIRECTIONS = jnp.array(
|
|
|
42
43
|
)
|
|
43
44
|
|
|
44
45
|
|
|
46
|
+
@struct.dataclass
class ObjectState:
    """Per-cell bookkeeping for everything that lives on the grid.

    Every field is an array whose leading two axes are ``(H, W)``; cells are
    addressed as ``[row, column]``.

    Fields:
        object_id: Object type in each cell (0 = empty, >0 = object type).
        respawn_timer: Steps left until a respawn (0 = no timer pending).
        respawn_object_id: Object type that appears when the timer hits 0.
        spawn_time: Timestep at which the current object was spawned; used
            for expiry tracking.
        color: ``(H, W, 3)`` RGB colour per object instance (dynamic biomes).
        generation: Biome generation the object in each cell belongs to.
        state_params: ``(H, W, N)`` per-instance parameters, e.g. Fourier
            coefficients.
        biome_id: Biome each cell is assigned to.
    """

    object_id: jax.Array  # (H, W) object type; 0 means the cell is empty
    respawn_timer: jax.Array  # (H, W) countdown; 0 means no timer is running
    respawn_object_id: jax.Array  # (H, W) type to spawn once the timer expires
    spawn_time: jax.Array  # (H, W) timestep at which the object was spawned
    color: jax.Array  # (H, W, 3) RGB colour of each instance
    generation: jax.Array  # (H, W) biome generation of each instance
    state_params: jax.Array  # (H, W, N) per-instance parameters
    biome_id: jax.Array  # (H, W) biome assignment of each cell

    @classmethod
    def create_empty(cls, size: Tuple[int, int], num_params: int) -> "ObjectState":
        """Build a state with no objects, no timers and no biome assignment.

        Args:
            size: Grid size given as ``(width, height)``; the arrays are
                allocated as ``(height, width)``.
            num_params: Length ``N`` of the per-instance parameter vector.

        Returns:
            An ``ObjectState`` whose integer grids are all zero, whose colour
            grid is filled with 255 (uint8), whose parameters are float32
            zeros, and whose ``biome_id`` is -1 everywhere.
        """
        width, height = size[0], size[1]
        grid_shape = (height, width)

        # JAX arrays are immutable, so one zero grid can safely back every
        # zero-initialised integer field.
        zero_grid = jnp.zeros(grid_shape, dtype=int)

        return cls(
            object_id=zero_grid,
            respawn_timer=zero_grid,
            respawn_object_id=zero_grid,
            spawn_time=zero_grid,
            color=jnp.full(grid_shape + (3,), 255, dtype=jnp.uint8),
            generation=zero_grid,
            state_params=jnp.zeros(grid_shape + (num_params,), dtype=jnp.float32),
            biome_id=jnp.full(grid_shape, -1, dtype=int),
        )
|
|
86
|
+
|
|
87
|
+
|
|
45
88
|
@dataclass
|
|
46
89
|
class Biome:
|
|
47
90
|
# Object generation frequencies for this biome
|
|
@@ -55,14 +98,22 @@ class EnvParams(environment.EnvParams):
|
|
|
55
98
|
max_steps_in_episode: Union[int, None]
|
|
56
99
|
|
|
57
100
|
|
|
101
|
+
@struct.dataclass
class BiomeState:
    """Per-biome counters; every field has shape ``(num_biomes,)``."""

    # Number of objects consumed in each biome.
    consumption_count: jax.Array
    # Number of objects spawned in each biome.
    total_objects: jax.Array
    # Current generation index of each biome.
    generation: jax.Array
|
|
108
|
+
|
|
109
|
+
|
|
58
110
|
@struct.dataclass
class EnvState(environment.EnvState):
    """Complete environment state carried from one step to the next."""

    # Agent position; the grids are indexed as [pos[1], pos[0]].
    pos: jax.Array
    # Current timestep.
    time: int
    # Buffer that collected rewards are added to during digestion handling.
    digestion_buffer: jax.Array
    # Per-cell object information (ids, timers, colours, parameters, biomes).
    object_state: ObjectState
    # Per-biome consumption and generation tracking.
    biome_state: BiomeState
|
|
66
117
|
|
|
67
118
|
|
|
68
119
|
class ForagaxEnv(environment.Environment):
|
|
@@ -79,6 +130,8 @@ class ForagaxEnv(environment.Environment):
|
|
|
79
130
|
deterministic_spawn: bool = False,
|
|
80
131
|
teleport_interval: Optional[int] = None,
|
|
81
132
|
observation_type: str = "object",
|
|
133
|
+
dynamic_biomes: bool = False,
|
|
134
|
+
biome_consumption_threshold: float = 0.9,
|
|
82
135
|
):
|
|
83
136
|
super().__init__()
|
|
84
137
|
self._name = name
|
|
@@ -100,11 +153,24 @@ class ForagaxEnv(environment.Environment):
|
|
|
100
153
|
self.nowrap = nowrap
|
|
101
154
|
self.deterministic_spawn = deterministic_spawn
|
|
102
155
|
self.teleport_interval = teleport_interval
|
|
156
|
+
self.dynamic_biomes = dynamic_biomes
|
|
157
|
+
self.biome_consumption_threshold = biome_consumption_threshold
|
|
158
|
+
|
|
103
159
|
objects = (EMPTY,) + objects
|
|
104
160
|
if self.nowrap and not self.full_world:
|
|
105
161
|
objects = objects + (PADDING,)
|
|
106
162
|
self.objects = objects
|
|
107
163
|
|
|
164
|
+
# Infer num_fourier_terms from objects
|
|
165
|
+
self.num_fourier_terms = max(
|
|
166
|
+
(
|
|
167
|
+
obj.num_fourier_terms
|
|
168
|
+
for obj in self.objects
|
|
169
|
+
if isinstance(obj, FourierObject)
|
|
170
|
+
),
|
|
171
|
+
default=0,
|
|
172
|
+
)
|
|
173
|
+
|
|
108
174
|
# JIT-compatible versions of object and biome properties
|
|
109
175
|
self.object_ids = jnp.arange(len(objects))
|
|
110
176
|
self.object_blocking = jnp.array([o.blocking for o in objects])
|
|
@@ -122,6 +188,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
122
188
|
[o.expiry_time if o.expiry_time is not None else -1 for o in objects]
|
|
123
189
|
)
|
|
124
190
|
|
|
191
|
+
# Check if any objects can expire
|
|
192
|
+
self.has_expiring_objects = jnp.any(self.object_expiry_time >= 0)
|
|
193
|
+
|
|
125
194
|
# Compute reward steps per object (using max_reward_delay attribute)
|
|
126
195
|
object_max_reward_delay = jnp.array([o.max_reward_delay for o in objects])
|
|
127
196
|
self.max_reward_delay = (
|
|
@@ -138,6 +207,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
138
207
|
[b.stop if b.stop is not None else (-1, -1) for b in biomes]
|
|
139
208
|
)
|
|
140
209
|
self.biome_sizes = np.prod(self.biome_stops - self.biome_starts, axis=1)
|
|
210
|
+
self.biome_sizes_jax = jnp.array(self.biome_sizes) # JAX version for indexing
|
|
141
211
|
self.biome_starts_jax = jnp.array(self.biome_starts)
|
|
142
212
|
self.biome_stops_jax = jnp.array(self.biome_stops)
|
|
143
213
|
biome_centers = []
|
|
@@ -149,6 +219,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
149
219
|
biome_centers.append((center_x, center_y))
|
|
150
220
|
self.biome_centers_jax = jnp.array(biome_centers)
|
|
151
221
|
self.biome_masks = []
|
|
222
|
+
biome_masks_array = []
|
|
152
223
|
for i in range(self.biome_object_frequencies.shape[0]):
|
|
153
224
|
# Create mask for the biome
|
|
154
225
|
start = jax.lax.select(
|
|
@@ -170,6 +241,10 @@ class ForagaxEnv(environment.Environment):
|
|
|
170
241
|
& (cols < stop[0])
|
|
171
242
|
)
|
|
172
243
|
self.biome_masks.append(mask)
|
|
244
|
+
biome_masks_array.append(mask)
|
|
245
|
+
|
|
246
|
+
# Convert to JAX array for indexing in JIT-compiled code
|
|
247
|
+
self.biome_masks_array = jnp.array(biome_masks_array)
|
|
173
248
|
|
|
174
249
|
# Compute unique colors and mapping for partial observability (for 'color' observation_type)
|
|
175
250
|
# Exclude EMPTY (index 0) from color channels
|
|
@@ -200,46 +275,89 @@ class ForagaxEnv(environment.Environment):
|
|
|
200
275
|
max_steps_in_episode=None,
|
|
201
276
|
)
|
|
202
277
|
|
|
203
|
-
def
|
|
204
|
-
self, grid: jax.Array, y: int, x: int, timer_val: int
|
|
205
|
-
) -> jax.Array:
|
|
206
|
-
"""Place a timer at a specific position."""
|
|
207
|
-
return grid.at[y, x].set(timer_val)
|
|
208
|
-
|
|
209
|
-
def _place_timer_at_random_position(
|
|
278
|
+
def _place_timer(
    self,
    object_state: ObjectState,
    y: int,
    x: int,
    object_type: int,
    timer_val: int,
    random_respawn: bool,
    rand_key: jax.Array,
) -> ObjectState:
    """Remove the object at ``(y, x)`` and schedule its respawn, if any.

    Args:
        object_state: Object state to update.
        y, x: Row and column of the object being removed.
        object_type: Object type id to bring back when the timer runs out.
        timer_val: Countdown to store; 0 means the object is removed
            permanently and no timer is placed.
        random_respawn: When true, the timer goes to a random free cell of
            the same biome instead of the original cell.
        rand_key: PRNG key used to choose the random cell.

    Returns:
        A new ``ObjectState`` with the cell emptied and, when
        ``timer_val`` is non-zero, a respawn timer recorded.
    """
    # Every outcome empties the source cell, so compute that once.
    cleared_ids = object_state.object_id.at[y, x].set(0)

    def _remove_permanently():
        # timer_val == 0: leave no timer and no pending object behind.
        return object_state.replace(
            object_id=cleared_ids,
            respawn_timer=object_state.respawn_timer.at[y, x].set(0),
            respawn_object_id=object_state.respawn_object_id.at[y, x].set(0),
        )

    def _respawn_in_place():
        # The timer and the pending object type stay on the original cell.
        timers = object_state.respawn_timer.at[y, x].set(timer_val)
        pending = object_state.respawn_object_id.at[y, x].set(object_type)
        return object_state.replace(
            object_id=cleared_ids,
            respawn_timer=timers,
            respawn_object_id=pending,
        )

    def _respawn_elsewhere():
        # Start from a fully cleared source cell.
        timers = object_state.respawn_timer.at[y, x].set(0)
        pending = object_state.respawn_object_id.at[y, x].set(0)

        # Candidates: same biome as the source cell, no object, no timer.
        # The source cell itself always qualifies because it was just cleared.
        same_biome = object_state.biome_id == object_state.biome_id[y, x]
        candidates = same_biome & (cleared_ids == 0) & (timers == 0)

        # Fixed-size nonzero keeps shapes static; unused slots hold -1 and
        # are never selected because the draw is bounded by the true count.
        rows, cols = jnp.nonzero(
            candidates, size=self.size[0] * self.size[1], fill_value=-1
        )
        pick = jax.random.randint(rand_key, (), 0, jnp.sum(candidates))
        target_y, target_x = rows[pick], cols[pick]

        return object_state.replace(
            object_id=cleared_ids,
            respawn_timer=timers.at[target_y, target_x].set(timer_val),
            respawn_object_id=pending.at[target_y, target_x].set(object_type),
        )

    def _schedule_respawn():
        return jax.lax.cond(random_respawn, _respawn_elsewhere, _respawn_in_place)

    return jax.lax.cond(timer_val == 0, _remove_permanently, _schedule_respawn)
|
|
243
361
|
|
|
244
362
|
def step_env(
|
|
245
363
|
self,
|
|
@@ -249,9 +367,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
249
367
|
params: EnvParams,
|
|
250
368
|
) -> Tuple[jax.Array, EnvState, jax.Array, jax.Array, Dict[Any, Any]]:
|
|
251
369
|
"""Perform single timestep state transition."""
|
|
252
|
-
|
|
253
|
-
# Decode the object grid: positive values are objects, negative are timers (treat as empty)
|
|
254
|
-
current_objects = jnp.maximum(0, state.object_grid)
|
|
370
|
+
current_objects = state.object_state.object_id
|
|
255
371
|
|
|
256
372
|
# 1. UPDATE AGENT POSITION
|
|
257
373
|
direction = DIRECTIONS[action]
|
|
@@ -294,9 +410,12 @@ class ForagaxEnv(environment.Environment):
|
|
|
294
410
|
# Handle digestion: add reward to buffer if collected
|
|
295
411
|
digestion_buffer = state.digestion_buffer
|
|
296
412
|
key, reward_subkey = jax.random.split(key)
|
|
413
|
+
|
|
414
|
+
object_params = state.object_state.state_params[pos[1], pos[0]]
|
|
297
415
|
object_reward = jax.lax.switch(
|
|
298
|
-
obj_at_pos, self.reward_fns, state.time, reward_subkey
|
|
416
|
+
obj_at_pos, self.reward_fns, state.time, reward_subkey, object_params
|
|
299
417
|
)
|
|
418
|
+
|
|
300
419
|
key, digestion_subkey = jax.random.split(key)
|
|
301
420
|
reward_delay = jax.lax.switch(
|
|
302
421
|
obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
|
|
@@ -319,136 +438,109 @@ class ForagaxEnv(environment.Environment):
|
|
|
319
438
|
# 3. HANDLE OBJECT COLLECTION AND RESPAWNING
|
|
320
439
|
key, regen_subkey, rand_key = jax.random.split(key, 3)
|
|
321
440
|
|
|
322
|
-
# Decrement timers
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
441
|
+
# Decrement respawn timers
|
|
442
|
+
has_timer = state.object_state.respawn_timer > 0
|
|
443
|
+
new_respawn_timer = jnp.where(
|
|
444
|
+
has_timer,
|
|
445
|
+
state.object_state.respawn_timer - 1,
|
|
446
|
+
state.object_state.respawn_timer,
|
|
326
447
|
)
|
|
327
448
|
|
|
328
|
-
# Track which
|
|
329
|
-
|
|
330
|
-
|
|
449
|
+
# Track which cells have timers that just reached 0
|
|
450
|
+
just_respawned = has_timer & (new_respawn_timer == 0)
|
|
451
|
+
|
|
452
|
+
# Respawn objects where timer reached 0
|
|
453
|
+
new_object_id = jnp.where(
|
|
454
|
+
just_respawned,
|
|
455
|
+
state.object_state.respawn_object_id,
|
|
456
|
+
state.object_state.object_id,
|
|
457
|
+
)
|
|
458
|
+
|
|
459
|
+
# Clear respawn_object_id for cells that just respawned
|
|
460
|
+
new_respawn_object_id = jnp.where(
|
|
461
|
+
just_respawned,
|
|
462
|
+
0,
|
|
463
|
+
state.object_state.respawn_object_id,
|
|
464
|
+
)
|
|
331
465
|
|
|
332
466
|
# Update spawn times for objects that just respawned
|
|
333
|
-
|
|
334
|
-
just_respawned, state.time, state.
|
|
467
|
+
spawn_time = jnp.where(
|
|
468
|
+
just_respawned, state.time, state.object_state.spawn_time
|
|
335
469
|
)
|
|
336
470
|
|
|
337
471
|
# Collect object: set a timer
|
|
338
472
|
regen_delay = jax.lax.switch(
|
|
339
473
|
obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
|
|
340
474
|
)
|
|
341
|
-
|
|
475
|
+
timer_countdown = jax.lax.cond(
|
|
476
|
+
regen_delay == jnp.iinfo(jnp.int32).max,
|
|
477
|
+
lambda: 0, # No timer (permanent removal)
|
|
478
|
+
lambda: regen_delay + 1, # Timer countdown
|
|
479
|
+
)
|
|
342
480
|
|
|
343
481
|
# If collected, replace object with timer; otherwise, keep it
|
|
344
|
-
val_at_pos =
|
|
345
|
-
|
|
482
|
+
val_at_pos = current_objects[pos[1], pos[0]]
|
|
483
|
+
# Use original should_collect for consumption tracking
|
|
484
|
+
should_collect_now = is_collectable & (val_at_pos > 0)
|
|
485
|
+
|
|
486
|
+
# Create updated object state with new respawn_timer, object_id, and spawn_time
|
|
487
|
+
object_state = state.object_state.replace(
|
|
488
|
+
object_id=new_object_id,
|
|
489
|
+
respawn_timer=new_respawn_timer,
|
|
490
|
+
respawn_object_id=new_respawn_object_id,
|
|
491
|
+
spawn_time=spawn_time,
|
|
492
|
+
)
|
|
346
493
|
|
|
347
|
-
#
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
494
|
+
# Place timer on collection
|
|
495
|
+
object_state = jax.lax.cond(
|
|
496
|
+
should_collect_now,
|
|
497
|
+
lambda: self._place_timer(
|
|
498
|
+
object_state,
|
|
499
|
+
pos[1],
|
|
500
|
+
pos[0],
|
|
501
|
+
obj_at_pos, # object type
|
|
502
|
+
timer_countdown, # timer value
|
|
351
503
|
self.object_random_respawn[obj_at_pos],
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
pos[0],
|
|
356
|
-
encoded_timer,
|
|
357
|
-
state.biome_grid,
|
|
358
|
-
rand_key,
|
|
359
|
-
),
|
|
360
|
-
lambda: self._place_timer_at_position(
|
|
361
|
-
object_grid, pos[1], pos[0], encoded_timer
|
|
362
|
-
),
|
|
363
|
-
)
|
|
364
|
-
|
|
365
|
-
def no_collection():
|
|
366
|
-
return object_grid
|
|
367
|
-
|
|
368
|
-
object_grid = jax.lax.cond(
|
|
369
|
-
should_collect,
|
|
370
|
-
do_collection,
|
|
371
|
-
no_collection,
|
|
504
|
+
rand_key,
|
|
505
|
+
),
|
|
506
|
+
lambda: object_state,
|
|
372
507
|
)
|
|
373
508
|
|
|
374
509
|
# 3.5. HANDLE OBJECT EXPIRY
|
|
375
|
-
#
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
#
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
key, expiry_key = jax.random.split(key)
|
|
393
|
-
|
|
394
|
-
# Process expiry for all positions that need it
|
|
395
|
-
def process_expiry(y, x, grid, spawn_grid, key):
|
|
396
|
-
obj_id = current_objects_for_expiry[y, x]
|
|
397
|
-
should_exp = should_expire[y, x]
|
|
398
|
-
|
|
399
|
-
def expire_object():
|
|
400
|
-
# Get expiry regen delay for this object
|
|
401
|
-
exp_key = jax.random.fold_in(key, y * self.size[0] + x)
|
|
402
|
-
exp_delay = jax.lax.switch(
|
|
403
|
-
obj_id, self.expiry_regen_delay_fns, state.time, exp_key
|
|
404
|
-
)
|
|
405
|
-
encoded_exp_timer = obj_id - ((exp_delay + 1) * num_obj_types)
|
|
406
|
-
|
|
407
|
-
# Check if this object should respawn randomly
|
|
408
|
-
should_random_respawn = self.object_random_respawn[obj_id]
|
|
409
|
-
|
|
410
|
-
# Use second split for randomness in random placement
|
|
411
|
-
rand_key = jax.random.split(exp_key)[1]
|
|
412
|
-
|
|
413
|
-
# Place timer either at current position or random position
|
|
414
|
-
new_grid = jax.lax.cond(
|
|
415
|
-
should_random_respawn,
|
|
416
|
-
lambda: self._place_timer_at_random_position(
|
|
417
|
-
grid, y, x, encoded_exp_timer, state.biome_grid, rand_key
|
|
418
|
-
),
|
|
419
|
-
lambda: self._place_timer_at_position(
|
|
420
|
-
grid, y, x, encoded_exp_timer
|
|
421
|
-
),
|
|
422
|
-
)
|
|
423
|
-
|
|
424
|
-
return new_grid, spawn_grid
|
|
425
|
-
|
|
426
|
-
def no_expire():
|
|
427
|
-
return grid, spawn_grid
|
|
428
|
-
|
|
429
|
-
return jax.lax.cond(should_exp, expire_object, no_expire)
|
|
430
|
-
|
|
431
|
-
# Apply expiry to all cells (vectorized)
|
|
432
|
-
def scan_expiry_row(carry, y):
|
|
433
|
-
grid, spawn_grid, key = carry
|
|
434
|
-
|
|
435
|
-
def scan_expiry_col(carry_col, x):
|
|
436
|
-
grid_col, spawn_grid_col, key_col = carry_col
|
|
437
|
-
grid_col, spawn_grid_col = process_expiry(
|
|
438
|
-
y, x, grid_col, spawn_grid_col, key_col
|
|
439
|
-
)
|
|
440
|
-
return (grid_col, spawn_grid_col, key_col), None
|
|
441
|
-
|
|
442
|
-
(grid, spawn_grid, key), _ = jax.lax.scan(
|
|
443
|
-
scan_expiry_col, (grid, spawn_grid, key), jnp.arange(self.size[0])
|
|
510
|
+
# Only process expiry if there are objects that can expire
|
|
511
|
+
key, object_state = self.expire_objects(key, state, object_state)
|
|
512
|
+
|
|
513
|
+
# 3.6. HANDLE DYNAMIC BIOME CONSUMPTION AND RESPAWNING
|
|
514
|
+
if self.dynamic_biomes:
|
|
515
|
+
# Update consumption count if an object was collected
|
|
516
|
+
# Only count if the object belongs to the current generation of its biome
|
|
517
|
+
collected_biome_id = object_state.biome_id[pos[1], pos[0]]
|
|
518
|
+
object_gen_at_pos = object_state.generation[pos[1], pos[0]]
|
|
519
|
+
current_biome_gen = state.biome_state.generation[collected_biome_id]
|
|
520
|
+
is_current_generation = object_gen_at_pos == current_biome_gen
|
|
521
|
+
|
|
522
|
+
biome_consumption_count = state.biome_state.consumption_count
|
|
523
|
+
biome_consumption_count = jax.lax.cond(
|
|
524
|
+
should_collect & is_current_generation,
|
|
525
|
+
lambda: biome_consumption_count.at[collected_biome_id].add(1),
|
|
526
|
+
lambda: biome_consumption_count,
|
|
444
527
|
)
|
|
445
|
-
return (grid, spawn_grid, key), None
|
|
446
528
|
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
529
|
+
# Check each biome for threshold crossing and respawn if needed
|
|
530
|
+
key, respawn_key = jax.random.split(key)
|
|
531
|
+
biome_state = BiomeState(
|
|
532
|
+
consumption_count=biome_consumption_count,
|
|
533
|
+
total_objects=state.biome_state.total_objects,
|
|
534
|
+
generation=state.biome_state.generation,
|
|
535
|
+
)
|
|
536
|
+
object_state, biome_state, respawn_key = self._check_and_respawn_biomes(
|
|
537
|
+
object_state,
|
|
538
|
+
biome_state,
|
|
539
|
+
state.time,
|
|
540
|
+
respawn_key,
|
|
541
|
+
)
|
|
542
|
+
else:
|
|
543
|
+
biome_state = state.biome_state
|
|
452
544
|
|
|
453
545
|
info = {"discount": self.discount(state, params)}
|
|
454
546
|
temperatures = jnp.zeros(len(self.objects))
|
|
@@ -458,17 +550,16 @@ class ForagaxEnv(environment.Environment):
|
|
|
458
550
|
get_temperature(obj.rewards, state.time, obj.repeat)
|
|
459
551
|
)
|
|
460
552
|
info["temperatures"] = temperatures
|
|
461
|
-
info["biome_id"] =
|
|
553
|
+
info["biome_id"] = object_state.biome_id[pos[1], pos[0]]
|
|
462
554
|
info["object_collected_id"] = jax.lax.select(should_collect, obj_at_pos, -1)
|
|
463
555
|
|
|
464
556
|
# 4. UPDATE STATE
|
|
465
557
|
state = EnvState(
|
|
466
558
|
pos=pos,
|
|
467
|
-
object_grid=object_grid,
|
|
468
|
-
biome_grid=state.biome_grid,
|
|
469
559
|
time=state.time + 1,
|
|
470
560
|
digestion_buffer=digestion_buffer,
|
|
471
|
-
|
|
561
|
+
object_state=object_state,
|
|
562
|
+
biome_state=biome_state,
|
|
472
563
|
)
|
|
473
564
|
|
|
474
565
|
done = self.is_terminal(state, params)
|
|
@@ -480,60 +571,395 @@ class ForagaxEnv(environment.Environment):
|
|
|
480
571
|
info,
|
|
481
572
|
)
|
|
482
573
|
|
|
574
|
+
def expire_objects(
    self, key, state, object_state: ObjectState
) -> Tuple[jax.Array, ObjectState]:
    """Replace objects that have outlived their expiry time with timers.

    Args:
        key: PRNG key; it is split only when at least one object expires.
        state: Environment state, read for the current ``time``.
        object_state: Object state to scan and update.

    Returns:
        ``(key, object_state)``: the (possibly advanced) key and the object
        state with every expired object removed via ``_place_timer``.
    """
    # Python-level flag: with no expiring object types nothing is traced.
    if not self.has_expiring_objects:
        return key, object_state

    # Snapshot of the ids before any expiry is applied.
    grid_ids = object_state.object_id
    lifetimes = self.object_expiry_time[grid_ids]
    ages = state.time - object_state.spawn_time

    # Expired: a real object whose type has an expiry time (>= 0) and whose
    # age has reached it.
    expired = (ages >= lifetimes) & (lifetimes >= 0) & (grid_ids > 0)

    cell_count = self.size[0] * self.size[1]

    def _expire_all():
        # Fixed-size coordinate list for JIT compatibility; padding is -1.
        rows, cols = jnp.nonzero(expired, size=cell_count, fill_value=-1)
        next_key, expiry_key = jax.random.split(key)

        def _step(carried_state, cell):
            row, col = cell

            def _expire_cell():
                kind = grid_ids[row, col]
                # Distinct key per cell, derived from its flattened index.
                cell_key = jax.random.fold_in(expiry_key, row * self.size[0] + col)
                delay = jax.lax.switch(
                    kind, self.expiry_regen_delay_fns, state.time, cell_key
                )
                # An int32-max delay means "never respawn" (timer value 0).
                countdown = jax.lax.cond(
                    delay == jnp.iinfo(jnp.int32).max,
                    lambda: 0,
                    lambda: delay + 1,
                )
                return self._place_timer(
                    carried_state,
                    row,
                    col,
                    kind,
                    countdown,
                    self.object_random_respawn[kind],
                    jax.random.split(cell_key)[1],
                )

            # Padding slots (negative coordinates) leave the state untouched.
            is_real_cell = (row >= 0) & (col >= 0)
            updated = jax.lax.cond(is_real_cell, _expire_cell, lambda: carried_state)
            return updated, None

        final_state, _ = jax.lax.scan(_step, object_state, (rows, cols))
        return next_key, final_state

    # Skip the whole scan when nothing has expired on this step.
    return jax.lax.cond(
        jnp.any(expired),
        _expire_all,
        lambda: (key, object_state),
    )
|
|
664
|
+
|
|
665
|
+
def _check_and_respawn_biomes(
    self,
    object_state: ObjectState,
    biome_state: BiomeState,
    current_time: int,
    key: jax.Array,
) -> Tuple[ObjectState, BiomeState, jax.Array]:
    """Respawn every biome whose consumption rate reached the threshold.

    A biome's consumption rate is ``consumption_count / max(1, total_objects)``.
    For each biome at or above ``self.biome_consumption_threshold``:

    * a fresh set of objects is spawned and written into the cells where the
      new spawn has an object (cells where it is empty keep their content);
    * those cells receive the new colour, parameters, generation and
      ``current_time`` as spawn time;
    * all respawn timers inside the biome mask are cleared;
    * the biome's consumption count is reset, its generation is incremented
      and its total object count is set to the number of new objects.

    Args:
        object_state: Per-cell object state to update.
        biome_state: Per-biome counters to update.
        current_time: Timestep recorded as spawn time of the new objects.
        key: PRNG key; one sub-key per biome is derived from it.

    Returns:
        ``(object_state, biome_state, key)`` with the updated states and the
        advanced key.
    """
    num_biomes = self.biome_object_frequencies.shape[0]

    # Fraction of each biome's objects consumed so far (guard against /0).
    consumption_rates = biome_state.consumption_count / jnp.maximum(
        1.0, biome_state.total_objects.astype(float)
    )
    should_respawn = consumption_rates >= self.biome_consumption_threshold

    key, subkey = jax.random.split(key)
    biome_keys = jax.random.split(subkey, num_biomes)

    # Candidate spawns are computed for every biome; they are only applied
    # below where should_respawn is set.
    if self.deterministic_spawn:
        # The biome index must be a concrete Python int here, so each biome
        # is spawned with a direct call. Going through jax.lax.switch with a
        # Python-int index would re-trace every branch on every iteration.
        spawned = [
            self._spawn_biome_objects(idx, biome_keys[idx], deterministic=True)
            for idx in range(num_biomes)
        ]
        all_new_objects = jnp.stack([objs for objs, _, _ in spawned])
        all_new_colors = jnp.stack([cols for _, cols, _ in spawned])
        all_new_params = jnp.stack([pars for _, _, pars in spawned])
    else:
        # Random spawning accepts a traced biome index, so vmap over biomes.
        all_new_objects, all_new_colors, all_new_params = jax.vmap(
            lambda i, k: self._spawn_biome_objects(i, k, deterministic=False)
        )(jnp.arange(num_biomes), biome_keys)

    # Per-biome bookkeeping for the biomes that respawn.
    new_object_counts = jnp.sum(
        (all_new_objects > 0) & self.biome_masks_array, axis=(1, 2)
    )
    new_biome_state = BiomeState(
        consumption_count=jnp.where(
            should_respawn, 0, biome_state.consumption_count
        ),
        total_objects=jnp.where(
            should_respawn, new_object_counts, biome_state.total_objects
        ),
        generation=biome_state.generation + should_respawn.astype(int),
    )

    new_obj_id = object_state.object_id
    new_color = object_state.color
    new_params = object_state.state_params
    new_spawn = object_state.spawn_time
    new_gen = object_state.generation
    new_respawn_timer = object_state.respawn_timer
    new_respawn_object_id = object_state.respawn_object_id

    for i in range(num_biomes):
        # Cells of biome i, but only if this biome is respawning.
        respawning_cells = self.biome_masks_array[i] & should_respawn[i]
        # Only cells where the new spawn actually holds an object are
        # overwritten.
        is_new_object = (all_new_objects[i] > 0) & respawning_cells
        new_gen_value = new_biome_state.generation[i]

        new_obj_id = jnp.where(is_new_object, all_new_objects[i], new_obj_id)
        new_color = jnp.where(
            is_new_object[..., None], all_new_colors[i], new_color
        )
        new_params = jnp.where(
            is_new_object[..., None], all_new_params[i], new_params
        )
        new_gen = jnp.where(is_new_object, new_gen_value, new_gen)
        new_spawn = jnp.where(is_new_object, current_time, new_spawn)

        # Pending respawns are dropped everywhere in a respawning biome.
        new_respawn_timer = jnp.where(respawning_cells, 0, new_respawn_timer)
        new_respawn_object_id = jnp.where(
            respawning_cells, 0, new_respawn_object_id
        )

    object_state = object_state.replace(
        object_id=new_obj_id,
        respawn_timer=new_respawn_timer,
        respawn_object_id=new_respawn_object_id,
        color=new_color,
        state_params=new_params,
        generation=new_gen,
        spawn_time=new_spawn,
    )

    return object_state, new_biome_state, key
|
|
784
|
+
|
|
483
785
|
def reset_env(
|
|
484
786
|
self, key: jax.Array, params: EnvParams
|
|
485
787
|
) -> Tuple[jax.Array, EnvState]:
|
|
486
788
|
"""Reset environment state."""
|
|
487
|
-
|
|
488
|
-
|
|
789
|
+
num_object_params = 2 + 2 * self.num_fourier_terms
|
|
790
|
+
object_state = ObjectState.create_empty(self.size, num_object_params)
|
|
791
|
+
|
|
489
792
|
key, iter_key = jax.random.split(key)
|
|
793
|
+
|
|
794
|
+
# Spawn objects in each biome using unified method
|
|
490
795
|
for i in range(self.biome_object_frequencies.shape[0]):
|
|
491
796
|
iter_key, biome_key = jax.random.split(iter_key)
|
|
492
797
|
mask = self.biome_masks[i]
|
|
493
|
-
biome_grid = jnp.where(mask, i, biome_grid)
|
|
494
798
|
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
799
|
+
# Set biome_id
|
|
800
|
+
object_state = object_state.replace(
|
|
801
|
+
biome_id=jnp.where(mask, i, object_state.biome_id)
|
|
802
|
+
)
|
|
803
|
+
|
|
804
|
+
# Use unified spawn method
|
|
805
|
+
biome_objects, biome_colors, biome_object_params = (
|
|
806
|
+
self._spawn_biome_objects(i, biome_key, self.deterministic_spawn)
|
|
807
|
+
)
|
|
808
|
+
|
|
809
|
+
# Merge biome objects/colors/params into object_state
|
|
810
|
+
object_state = object_state.replace(
|
|
811
|
+
object_id=jnp.where(mask, biome_objects, object_state.object_id),
|
|
812
|
+
color=jnp.where(mask[..., None], biome_colors, object_state.color),
|
|
813
|
+
state_params=jnp.where(
|
|
814
|
+
mask[..., None], biome_object_params, object_state.state_params
|
|
815
|
+
),
|
|
816
|
+
)
|
|
501
817
|
|
|
502
818
|
# Place agent in the center of the world
|
|
503
819
|
agent_pos = jnp.array([self.size[0] // 2, self.size[1] // 2])
|
|
504
820
|
|
|
505
|
-
# Initialize
|
|
506
|
-
|
|
821
|
+
# Initialize biome consumption tracking
|
|
822
|
+
num_biomes = self.biome_object_frequencies.shape[0]
|
|
823
|
+
biome_consumption_count = jnp.zeros(num_biomes, dtype=int)
|
|
824
|
+
biome_total_objects = jnp.zeros(num_biomes, dtype=int)
|
|
825
|
+
|
|
826
|
+
# Count objects in each biome
|
|
827
|
+
for i in range(num_biomes):
|
|
828
|
+
mask = self.biome_masks[i]
|
|
829
|
+
# Count non-empty objects (object_id > 0)
|
|
830
|
+
total = jnp.sum((object_state.object_id > 0) & mask)
|
|
831
|
+
biome_total_objects = biome_total_objects.at[i].set(total)
|
|
832
|
+
|
|
833
|
+
biome_generation = jnp.zeros(num_biomes, dtype=int)
|
|
507
834
|
|
|
508
835
|
state = EnvState(
|
|
509
836
|
pos=agent_pos,
|
|
510
|
-
object_grid=object_grid,
|
|
511
|
-
biome_grid=biome_grid,
|
|
512
837
|
time=0,
|
|
513
838
|
digestion_buffer=jnp.zeros((self.max_reward_delay,)),
|
|
514
|
-
|
|
839
|
+
object_state=object_state,
|
|
840
|
+
biome_state=BiomeState(
|
|
841
|
+
consumption_count=biome_consumption_count,
|
|
842
|
+
total_objects=biome_total_objects,
|
|
843
|
+
generation=biome_generation,
|
|
844
|
+
),
|
|
515
845
|
)
|
|
516
846
|
|
|
517
847
|
return self.get_obs(state, params), state
|
|
518
848
|
|
|
519
|
-
def
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
849
|
+
def _spawn_biome_objects(
|
|
850
|
+
self,
|
|
851
|
+
biome_idx: int,
|
|
852
|
+
key: jax.Array,
|
|
853
|
+
deterministic: bool = False,
|
|
854
|
+
) -> Tuple[jax.Array, jax.Array, jax.Array]:
|
|
855
|
+
"""Spawn objects in a biome.
|
|
856
|
+
|
|
857
|
+
Returns:
|
|
858
|
+
object_grid: (H, W) array of object IDs
|
|
859
|
+
color_grid: (H, W, 3) array of RGB colors
|
|
860
|
+
state_grid: (H, W, num_state_params) array of object state parameters
|
|
861
|
+
"""
|
|
862
|
+
biome_freqs = self.biome_object_frequencies[biome_idx]
|
|
863
|
+
biome_mask = self.biome_masks_array[biome_idx]
|
|
864
|
+
|
|
865
|
+
key, spawn_key, color_key, params_key = jax.random.split(key, 4)
|
|
866
|
+
|
|
867
|
+
# Generate object IDs using deterministic or random spawn
|
|
868
|
+
if deterministic:
|
|
869
|
+
# Deterministic spawn: exact number of each object type
|
|
870
|
+
# NOTE: Requires concrete biome_idx to compute size at trace time
|
|
871
|
+
# Get static biome bounds
|
|
872
|
+
biome_start = self.biome_starts[biome_idx]
|
|
873
|
+
biome_stop = self.biome_stops[biome_idx]
|
|
874
|
+
biome_height = biome_stop[1] - biome_start[1]
|
|
875
|
+
biome_width = biome_stop[0] - biome_start[0]
|
|
876
|
+
biome_size = int(self.biome_sizes[biome_idx])
|
|
877
|
+
|
|
878
|
+
grid = jnp.linspace(0, 1, biome_size, endpoint=False)
|
|
879
|
+
biome_objects_flat = len(biome_freqs) - jnp.searchsorted(
|
|
880
|
+
jnp.cumsum(biome_freqs[::-1]), grid, side="right"
|
|
881
|
+
)
|
|
882
|
+
biome_objects_flat = jax.random.permutation(spawn_key, biome_objects_flat)
|
|
883
|
+
|
|
884
|
+
# Reshape to match biome dimensions (use concrete dimensions)
|
|
885
|
+
biome_objects = biome_objects_flat.reshape(biome_height, biome_width)
|
|
886
|
+
|
|
887
|
+
# Place in full grid using slicing with static bounds
|
|
888
|
+
object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
|
|
889
|
+
object_grid = object_grid.at[
|
|
890
|
+
biome_start[1] : biome_stop[1], biome_start[0] : biome_stop[0]
|
|
891
|
+
].set(biome_objects)
|
|
892
|
+
else:
|
|
893
|
+
# Random spawn: probabilistic placement (works with traced biome_idx)
|
|
894
|
+
grid_rand = jax.random.uniform(spawn_key, (self.size[1], self.size[0]))
|
|
895
|
+
empty_freq = 1.0 - jnp.sum(biome_freqs)
|
|
896
|
+
all_freqs = jnp.concatenate([jnp.array([empty_freq]), biome_freqs])
|
|
897
|
+
cumulative_freqs = jnp.cumsum(
|
|
898
|
+
jnp.concatenate([jnp.array([0.0]), all_freqs])
|
|
899
|
+
)
|
|
900
|
+
object_grid = (
|
|
901
|
+
jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
|
|
902
|
+
)
|
|
903
|
+
|
|
904
|
+
# Initialize color grid
|
|
905
|
+
color_grid = jnp.full((self.size[1], self.size[0], 3), 255, dtype=jnp.uint8)
|
|
906
|
+
|
|
907
|
+
# Sample ONE color per object type in this biome (not per instance)
|
|
908
|
+
# This gives objects of the same type the same color within a biome generation
|
|
909
|
+
# Skip index 0 (EMPTY object) - only sample colors for actual objects
|
|
910
|
+
num_object_types = len(self.objects)
|
|
911
|
+
num_actual_objects = num_object_types - 1 # Exclude EMPTY
|
|
912
|
+
|
|
913
|
+
if num_actual_objects > 0:
|
|
914
|
+
biome_object_colors = jax.random.randint(
|
|
915
|
+
color_key,
|
|
916
|
+
(num_actual_objects, 3),
|
|
917
|
+
minval=0,
|
|
918
|
+
maxval=256,
|
|
919
|
+
dtype=jnp.uint8,
|
|
920
|
+
)
|
|
921
|
+
|
|
922
|
+
# Assign colors based on object type (starting from index 1)
|
|
923
|
+
for obj_idx in range(1, num_object_types):
|
|
924
|
+
obj_mask = (object_grid == obj_idx) & biome_mask
|
|
925
|
+
obj_color = biome_object_colors[
|
|
926
|
+
obj_idx - 1
|
|
927
|
+
] # Offset by 1 since we skip EMPTY
|
|
928
|
+
color_grid = jnp.where(obj_mask[..., None], obj_color, color_grid)
|
|
929
|
+
|
|
930
|
+
# Initialize parameters grid
|
|
931
|
+
num_object_params = 2 + 2 * self.num_fourier_terms
|
|
932
|
+
params_grid = jnp.zeros(
|
|
933
|
+
(self.size[1], self.size[0], num_object_params), dtype=jnp.float32
|
|
533
934
|
)
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
935
|
+
|
|
936
|
+
# Generate per-object parameters for each object type
|
|
937
|
+
for obj_idx in range(num_object_types):
|
|
938
|
+
# Get params for this object type - this happens at trace time
|
|
939
|
+
params_key, obj_key = jax.random.split(params_key)
|
|
940
|
+
obj_params = self.objects[obj_idx].get_state(obj_key)
|
|
941
|
+
|
|
942
|
+
# Skip if no params (e.g., for EMPTY or default objects)
|
|
943
|
+
if len(obj_params) == 0:
|
|
944
|
+
continue
|
|
945
|
+
|
|
946
|
+
# Ensure params match expected size
|
|
947
|
+
if len(obj_params) != num_object_params:
|
|
948
|
+
if len(obj_params) < num_object_params:
|
|
949
|
+
obj_params = jnp.pad(
|
|
950
|
+
obj_params,
|
|
951
|
+
(0, num_object_params - len(obj_params)),
|
|
952
|
+
constant_values=0.0,
|
|
953
|
+
)
|
|
954
|
+
else:
|
|
955
|
+
# Truncate if too long
|
|
956
|
+
obj_params = obj_params[:num_object_params]
|
|
957
|
+
|
|
958
|
+
# Assign to all objects of this type in this biome
|
|
959
|
+
obj_mask = (object_grid == obj_idx) & biome_mask
|
|
960
|
+
params_grid = jnp.where(obj_mask[..., None], obj_params, params_grid)
|
|
961
|
+
|
|
962
|
+
return object_grid, color_grid, params_grid
|
|
537
963
|
|
|
538
964
|
def is_terminal(self, state: EnvState, params: EnvParams) -> jax.Array:
|
|
539
965
|
"""Foragax is a continuing environment."""
|
|
@@ -556,21 +982,10 @@ class ForagaxEnv(environment.Environment):
|
|
|
556
982
|
|
|
557
983
|
def state_space(self, params: EnvParams) -> spaces.Dict:
|
|
558
984
|
"""State space of the environment."""
|
|
985
|
+
num_object_params = 2 + 2 * self.num_fourier_terms
|
|
559
986
|
return spaces.Dict(
|
|
560
987
|
{
|
|
561
988
|
"pos": spaces.Box(0, max(self.size), (2,), int),
|
|
562
|
-
"object_grid": spaces.Box(
|
|
563
|
-
-1000 * len(self.object_ids),
|
|
564
|
-
len(self.object_ids),
|
|
565
|
-
(self.size[1], self.size[0]),
|
|
566
|
-
int,
|
|
567
|
-
),
|
|
568
|
-
"biome_grid": spaces.Box(
|
|
569
|
-
0,
|
|
570
|
-
self.biome_object_frequencies.shape[0],
|
|
571
|
-
(self.size[1], self.size[0]),
|
|
572
|
-
int,
|
|
573
|
-
),
|
|
574
989
|
"time": spaces.Discrete(params.max_steps_in_episode),
|
|
575
990
|
"digestion_buffer": spaces.Box(
|
|
576
991
|
-jnp.inf,
|
|
@@ -578,17 +993,79 @@ class ForagaxEnv(environment.Environment):
|
|
|
578
993
|
(self.max_reward_delay,),
|
|
579
994
|
float,
|
|
580
995
|
),
|
|
581
|
-
"
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
996
|
+
"object_state": spaces.Dict(
|
|
997
|
+
{
|
|
998
|
+
"object_id": spaces.Box(
|
|
999
|
+
-1000 * len(self.object_ids),
|
|
1000
|
+
len(self.object_ids),
|
|
1001
|
+
(self.size[1], self.size[0]),
|
|
1002
|
+
int,
|
|
1003
|
+
),
|
|
1004
|
+
"spawn_time": spaces.Box(
|
|
1005
|
+
0,
|
|
1006
|
+
jnp.inf,
|
|
1007
|
+
(self.size[1], self.size[0]),
|
|
1008
|
+
int,
|
|
1009
|
+
),
|
|
1010
|
+
"color": spaces.Box(
|
|
1011
|
+
0,
|
|
1012
|
+
255,
|
|
1013
|
+
(self.size[1], self.size[0], 3),
|
|
1014
|
+
int,
|
|
1015
|
+
),
|
|
1016
|
+
"generation": spaces.Box(
|
|
1017
|
+
0,
|
|
1018
|
+
jnp.inf,
|
|
1019
|
+
(self.size[1], self.size[0]),
|
|
1020
|
+
int,
|
|
1021
|
+
),
|
|
1022
|
+
"state_params": spaces.Box(
|
|
1023
|
+
-jnp.inf,
|
|
1024
|
+
jnp.inf,
|
|
1025
|
+
(self.size[1], self.size[0], num_object_params),
|
|
1026
|
+
float,
|
|
1027
|
+
),
|
|
1028
|
+
"biome_id": spaces.Box(
|
|
1029
|
+
-1,
|
|
1030
|
+
self.biome_object_frequencies.shape[0],
|
|
1031
|
+
(self.size[1], self.size[0]),
|
|
1032
|
+
int,
|
|
1033
|
+
),
|
|
1034
|
+
}
|
|
1035
|
+
),
|
|
1036
|
+
"biome_state": spaces.Dict(
|
|
1037
|
+
{
|
|
1038
|
+
"consumption_count": spaces.Box(
|
|
1039
|
+
0,
|
|
1040
|
+
jnp.inf,
|
|
1041
|
+
(self.biome_object_frequencies.shape[0],),
|
|
1042
|
+
int,
|
|
1043
|
+
),
|
|
1044
|
+
"total_objects": spaces.Box(
|
|
1045
|
+
0,
|
|
1046
|
+
jnp.inf,
|
|
1047
|
+
(self.biome_object_frequencies.shape[0],),
|
|
1048
|
+
int,
|
|
1049
|
+
),
|
|
1050
|
+
"generation": spaces.Box(
|
|
1051
|
+
0,
|
|
1052
|
+
jnp.inf,
|
|
1053
|
+
(self.biome_object_frequencies.shape[0],),
|
|
1054
|
+
int,
|
|
1055
|
+
),
|
|
1056
|
+
}
|
|
586
1057
|
),
|
|
587
1058
|
}
|
|
588
1059
|
)
|
|
589
1060
|
|
|
590
|
-
def
|
|
591
|
-
|
|
1061
|
+
def _compute_aperture_coordinates(
|
|
1062
|
+
self, pos: jax.Array
|
|
1063
|
+
) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
|
|
1064
|
+
"""Compute aperture coordinates for the given position.
|
|
1065
|
+
|
|
1066
|
+
Returns:
|
|
1067
|
+
(y_coords, x_coords, y_coords_clamped/mod, x_coords_clamped/mod)
|
|
1068
|
+
"""
|
|
592
1069
|
ap_h, ap_w = self.aperture_size
|
|
593
1070
|
start_y = pos[1] - ap_h // 2
|
|
594
1071
|
start_x = pos[0] - ap_w // 2
|
|
@@ -599,33 +1076,53 @@ class ForagaxEnv(environment.Environment):
|
|
|
599
1076
|
x_coords = start_x + x_offsets
|
|
600
1077
|
|
|
601
1078
|
if self.nowrap:
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
1079
|
+
y_coords_adj = jnp.clip(y_coords, 0, self.size[1] - 1)
|
|
1080
|
+
x_coords_adj = jnp.clip(x_coords, 0, self.size[0] - 1)
|
|
1081
|
+
else:
|
|
1082
|
+
y_coords_adj = jnp.mod(y_coords, self.size[1])
|
|
1083
|
+
x_coords_adj = jnp.mod(x_coords, self.size[0])
|
|
1084
|
+
|
|
1085
|
+
return y_coords, x_coords, y_coords_adj, x_coords_adj
|
|
1086
|
+
|
|
1087
|
+
def _get_aperture(self, object_id_grid: jax.Array, pos: jax.Array) -> jax.Array:
|
|
1088
|
+
"""Extract the aperture view from the object id grid."""
|
|
1089
|
+
y_coords, x_coords, y_coords_adj, x_coords_adj = (
|
|
1090
|
+
self._compute_aperture_coordinates(pos)
|
|
1091
|
+
)
|
|
1092
|
+
|
|
1093
|
+
values = object_id_grid[y_coords_adj, x_coords_adj]
|
|
1094
|
+
|
|
1095
|
+
if self.nowrap:
|
|
1096
|
+
# Mark out-of-bounds positions with padding
|
|
607
1097
|
y_out = (y_coords < 0) | (y_coords >= self.size[1])
|
|
608
1098
|
x_out = (x_coords < 0) | (x_coords >= self.size[0])
|
|
609
1099
|
out_of_bounds = y_out | x_out
|
|
610
|
-
|
|
611
|
-
|
|
1100
|
+
|
|
1101
|
+
# Handle both object_id grids (2D) and color grids (3D)
|
|
1102
|
+
if len(values.shape) == 3:
|
|
1103
|
+
# Color grid: use PADDING color (0, 0, 0)
|
|
1104
|
+
padding_value = jnp.array([0, 0, 0], dtype=values.dtype)
|
|
1105
|
+
aperture = jnp.where(out_of_bounds[..., None], padding_value, values)
|
|
1106
|
+
else:
|
|
1107
|
+
# Object ID grid: use PADDING index
|
|
1108
|
+
padding_index = self.object_ids[-1]
|
|
1109
|
+
aperture = jnp.where(out_of_bounds, padding_index, values)
|
|
612
1110
|
else:
|
|
613
|
-
|
|
614
|
-
x_coords_mod = jnp.mod(x_coords, self.size[0])
|
|
615
|
-
aperture = object_grid[y_coords_mod, x_coords_mod]
|
|
1111
|
+
aperture = values
|
|
616
1112
|
|
|
617
1113
|
return aperture
|
|
618
1114
|
|
|
619
1115
|
def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
|
|
620
1116
|
"""Get observation based on observation_type and full_world."""
|
|
621
|
-
|
|
622
|
-
|
|
1117
|
+
obs_grid = state.object_state.object_id
|
|
1118
|
+
color_grid = state.object_state.color
|
|
623
1119
|
|
|
624
1120
|
if self.full_world:
|
|
625
1121
|
return self._get_world_obs(obs_grid, state)
|
|
626
1122
|
else:
|
|
627
1123
|
grid = self._get_aperture(obs_grid, state.pos)
|
|
628
|
-
|
|
1124
|
+
color_grid = self._get_aperture(color_grid, state.pos)
|
|
1125
|
+
return self._get_aperture_obs(grid, color_grid, state)
|
|
629
1126
|
|
|
630
1127
|
def _get_world_obs(self, obs_grid: jax.Array, state: EnvState) -> jax.Array:
|
|
631
1128
|
"""Get world observation."""
|
|
@@ -642,12 +1139,14 @@ class ForagaxEnv(environment.Environment):
|
|
|
642
1139
|
obs = obs.at[state.pos[1], state.pos[0], -1].set(1)
|
|
643
1140
|
return obs
|
|
644
1141
|
elif self.observation_type == "rgb":
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
1142
|
+
# Use state colors directly (supports dynamic biomes)
|
|
1143
|
+
colors = state.object_state.color / 255.0
|
|
1144
|
+
|
|
1145
|
+
# Mask empty cells (object_id == 0) to white
|
|
1146
|
+
empty_mask = obs_grid == 0
|
|
1147
|
+
white_color = jnp.ones((self.size[1], self.size[0], 3), dtype=jnp.float32)
|
|
1148
|
+
obs = jnp.where(empty_mask[..., None], white_color, colors)
|
|
1149
|
+
|
|
651
1150
|
return obs
|
|
652
1151
|
elif self.observation_type == "color":
|
|
653
1152
|
# Handle case with no objects (only EMPTY)
|
|
@@ -664,17 +1163,24 @@ class ForagaxEnv(environment.Environment):
|
|
|
664
1163
|
else:
|
|
665
1164
|
raise ValueError(f"Unknown observation_type: {self.observation_type}")
|
|
666
1165
|
|
|
667
|
-
def _get_aperture_obs(
|
|
1166
|
+
def _get_aperture_obs(
|
|
1167
|
+
self, aperture: jax.Array, color_aperture: jax.Array, state: EnvState
|
|
1168
|
+
) -> jax.Array:
|
|
668
1169
|
"""Get aperture observation."""
|
|
669
1170
|
if self.observation_type == "object":
|
|
670
1171
|
num_obj_types = len(self.object_ids)
|
|
671
1172
|
obs = jax.nn.one_hot(aperture, num_obj_types, axis=-1)
|
|
672
1173
|
return obs
|
|
673
1174
|
elif self.observation_type == "rgb":
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
1175
|
+
# Use the color aperture that was passed in
|
|
1176
|
+
aperture_colors = color_aperture / 255.0
|
|
1177
|
+
|
|
1178
|
+
# Mask empty cells (object_id == 0) to white
|
|
1179
|
+
empty_mask = aperture == 0
|
|
1180
|
+
white_color = jnp.ones(aperture_colors.shape, dtype=jnp.float32)
|
|
1181
|
+
|
|
1182
|
+
obs = jnp.where(empty_mask[..., None], white_color, aperture_colors)
|
|
1183
|
+
|
|
678
1184
|
return obs
|
|
679
1185
|
elif self.observation_type == "color":
|
|
680
1186
|
# Handle case with no objects (only EMPTY)
|
|
@@ -721,48 +1227,47 @@ class ForagaxEnv(environment.Environment):
|
|
|
721
1227
|
|
|
722
1228
|
if is_world_mode:
|
|
723
1229
|
# Create an RGB image from the object grid
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
1230
|
+
# Use stateful object colors if dynamic_biomes is enabled, else use default colors
|
|
1231
|
+
if self.dynamic_biomes:
|
|
1232
|
+
# Use per-instance colors from state
|
|
1233
|
+
img = state.object_state.color.copy()
|
|
1234
|
+
# Mask empty cells (object_id == 0) to white
|
|
1235
|
+
empty_mask = state.object_state.object_id == 0
|
|
1236
|
+
white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
|
|
1237
|
+
img = jnp.where(empty_mask[..., None], white_color, img)
|
|
1238
|
+
else:
|
|
1239
|
+
# Use default object colors
|
|
1240
|
+
img = jnp.zeros((self.size[1], self.size[0], 3))
|
|
1241
|
+
render_grid = state.object_state.object_id
|
|
727
1242
|
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
1243
|
+
def update_image(i, img):
|
|
1244
|
+
color = self.object_colors[i]
|
|
1245
|
+
mask = render_grid == i
|
|
1246
|
+
img = jnp.where(mask[..., None], color, img)
|
|
1247
|
+
return img
|
|
733
1248
|
|
|
734
|
-
|
|
1249
|
+
img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
|
|
735
1250
|
|
|
736
1251
|
# Tint the agent's aperture
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
1252
|
+
y_coords, x_coords, y_coords_adj, x_coords_adj = (
|
|
1253
|
+
self._compute_aperture_coordinates(state.pos)
|
|
1254
|
+
)
|
|
740
1255
|
|
|
741
1256
|
alpha = 0.2
|
|
742
1257
|
agent_color = jnp.array(AGENT.color)
|
|
743
1258
|
|
|
744
|
-
# Create indices for the aperture
|
|
745
|
-
y_offsets = jnp.arange(ap_h)
|
|
746
|
-
x_offsets = jnp.arange(ap_w)
|
|
747
|
-
y_coords_original = start_y + y_offsets[:, None]
|
|
748
|
-
x_coords_original = start_x + x_offsets
|
|
749
|
-
|
|
750
1259
|
if self.nowrap:
|
|
751
|
-
y_coords = jnp.clip(y_coords_original, 0, self.size[1] - 1)
|
|
752
|
-
x_coords = jnp.clip(x_coords_original, 0, self.size[0] - 1)
|
|
753
1260
|
# Create tint mask: any in-bounds original position maps to a cell makes it tinted
|
|
754
1261
|
tint_mask = jnp.zeros((self.size[1], self.size[0]), dtype=int)
|
|
755
|
-
tint_mask = tint_mask.at[
|
|
1262
|
+
tint_mask = tint_mask.at[y_coords_adj, x_coords_adj].set(1)
|
|
756
1263
|
# Apply tint to masked positions
|
|
757
1264
|
original_colors = img
|
|
758
1265
|
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
759
1266
|
img = jnp.where(tint_mask[..., None], tinted_colors, img)
|
|
760
1267
|
else:
|
|
761
|
-
|
|
762
|
-
x_coords = jnp.mod(x_coords_original, self.size[0])
|
|
763
|
-
original_colors = img[y_coords, x_coords]
|
|
1268
|
+
original_colors = img[y_coords_adj, x_coords_adj]
|
|
764
1269
|
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
765
|
-
img = img.at[
|
|
1270
|
+
img = img.at[y_coords_adj, x_coords_adj].set(tinted_colors)
|
|
766
1271
|
|
|
767
1272
|
# Agent color
|
|
768
1273
|
img = img.at[state.pos[1], state.pos[0]].set(jnp.array(AGENT.color))
|
|
@@ -775,6 +1280,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
775
1280
|
|
|
776
1281
|
if is_true_mode:
|
|
777
1282
|
# Apply true object borders by overlaying true colors on border pixels
|
|
1283
|
+
render_grid = state.object_state.object_id
|
|
778
1284
|
img = apply_true_borders(
|
|
779
1285
|
img, render_grid, self.size, len(self.object_ids)
|
|
780
1286
|
)
|
|
@@ -787,10 +1293,35 @@ class ForagaxEnv(environment.Environment):
|
|
|
787
1293
|
img = img.at[:, col_indices].set(grid_color)
|
|
788
1294
|
|
|
789
1295
|
elif is_aperture_mode:
|
|
790
|
-
obs_grid =
|
|
1296
|
+
obs_grid = state.object_state.object_id
|
|
791
1297
|
aperture = self._get_aperture(obs_grid, state.pos)
|
|
792
|
-
|
|
793
|
-
|
|
1298
|
+
|
|
1299
|
+
if self.dynamic_biomes:
|
|
1300
|
+
# Use per-instance colors from state - extract aperture view
|
|
1301
|
+
y_coords, x_coords, y_coords_adj, x_coords_adj = (
|
|
1302
|
+
self._compute_aperture_coordinates(state.pos)
|
|
1303
|
+
)
|
|
1304
|
+
img = state.object_state.color[y_coords_adj, x_coords_adj]
|
|
1305
|
+
|
|
1306
|
+
# Mask empty cells (object_id == 0) to white
|
|
1307
|
+
aperture_object_ids = state.object_state.object_id[
|
|
1308
|
+
y_coords_adj, x_coords_adj
|
|
1309
|
+
]
|
|
1310
|
+
empty_mask = aperture_object_ids == 0
|
|
1311
|
+
white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
|
|
1312
|
+
img = jnp.where(empty_mask[..., None], white_color, img)
|
|
1313
|
+
|
|
1314
|
+
if self.nowrap:
|
|
1315
|
+
# For out-of-bounds, use padding object color
|
|
1316
|
+
y_out = (y_coords < 0) | (y_coords >= self.size[1])
|
|
1317
|
+
x_out = (x_coords < 0) | (x_coords >= self.size[0])
|
|
1318
|
+
out_of_bounds = y_out | x_out
|
|
1319
|
+
padding_color = jnp.array(self.objects[-1].color, dtype=jnp.float32)
|
|
1320
|
+
img = jnp.where(out_of_bounds[..., None], padding_color, img)
|
|
1321
|
+
else:
|
|
1322
|
+
# Use default object colors
|
|
1323
|
+
aperture_one_hot = jax.nn.one_hot(aperture, len(self.object_ids))
|
|
1324
|
+
img = jnp.tensordot(aperture_one_hot, self.object_colors, axes=1)
|
|
794
1325
|
|
|
795
1326
|
# Draw agent in the center
|
|
796
1327
|
center_y, center_x = self.aperture_size[1] // 2, self.aperture_size[0] // 2
|