continual-foragax 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foragax/env.py CHANGED
@@ -19,6 +19,7 @@ from foragax.objects import (
19
19
  EMPTY,
20
20
  PADDING,
21
21
  BaseForagaxObject,
22
+ FourierObject,
22
23
  WeatherObject,
23
24
  )
24
25
  from foragax.rendering import apply_true_borders
@@ -42,6 +43,48 @@ DIRECTIONS = jnp.array(
42
43
  )
43
44
 
44
45
 
46
@struct.dataclass
class ObjectState:
    """Grid-aligned record of everything tracked for the objects in the world.

    Every field is indexed by cell as ``[row, col]`` (height first, width
    second); ``color`` and ``state_params`` carry one extra trailing axis.
    """

    # (H, W) object type per cell: 0 means empty, positive values are types.
    object_id: jax.Array
    # (H, W) respawn countdown: 0 means no timer, positive is steps remaining.
    respawn_timer: jax.Array
    # (H, W) object type that appears in the cell when its timer reaches 0.
    respawn_object_id: jax.Array
    # (H, W) timestep at which the object was spawned (used for expiry).
    spawn_time: jax.Array
    # (H, W, 3) RGB color of each object instance (used by dynamic biomes).
    color: jax.Array
    # (H, W) biome generation each object belongs to.
    generation: jax.Array
    # (H, W, N) per-instance parameters, e.g. Fourier coefficients.
    state_params: jax.Array
    # (H, W) biome each cell belongs to.
    biome_id: jax.Array

    @classmethod
    def create_empty(cls, size: Tuple[int, int], num_params: int) -> "ObjectState":
        """Build an ObjectState describing a grid with no objects on it.

        Args:
            size: Grid extent given as (width, height).
            num_params: Length of the per-instance parameter vector.

        Returns:
            A state whose integer grids are all zero, whose colors are all
            255, whose parameters are all 0.0 and whose biome ids are all -1.
        """
        width, height = size[0], size[1]
        cell_shape = (height, width)

        def blank_ints() -> jax.Array:
            # Fresh zeroed integer grid, one entry per cell.
            return jnp.zeros(cell_shape, dtype=int)

        return cls(
            object_id=blank_ints(),
            respawn_timer=blank_ints(),
            respawn_object_id=blank_ints(),
            spawn_time=blank_ints(),
            color=jnp.full(cell_shape + (3,), 255, dtype=jnp.uint8),
            generation=blank_ints(),
            state_params=jnp.zeros(cell_shape + (num_params,), dtype=jnp.float32),
            biome_id=jnp.full(cell_shape, -1, dtype=int),
        )
86
+
87
+
45
88
  @dataclass
46
89
  class Biome:
47
90
  # Object generation frequencies for this biome
@@ -55,14 +98,22 @@ class EnvParams(environment.EnvParams):
55
98
  max_steps_in_episode: Union[int, None]
56
99
 
57
100
 
101
@struct.dataclass
class BiomeState:
    """Per-biome bookkeeping; each field holds one entry per biome."""

    # Number of objects consumed in each biome.
    consumption_count: jax.Array
    # Total number of objects spawned in each biome.
    total_objects: jax.Array
    # Current generation counter of each biome.
    generation: jax.Array
108
+
109
+
58
110
  @struct.dataclass
59
111
  class EnvState(environment.EnvState):
60
112
  pos: jax.Array
61
- object_grid: jax.Array
62
- biome_grid: jax.Array
63
113
  time: int
64
114
  digestion_buffer: jax.Array
65
- object_spawn_time_grid: jax.Array
115
+ object_state: ObjectState
116
+ biome_state: BiomeState
66
117
 
67
118
 
68
119
  class ForagaxEnv(environment.Environment):
@@ -79,6 +130,8 @@ class ForagaxEnv(environment.Environment):
79
130
  deterministic_spawn: bool = False,
80
131
  teleport_interval: Optional[int] = None,
81
132
  observation_type: str = "object",
133
+ dynamic_biomes: bool = False,
134
+ biome_consumption_threshold: float = 0.9,
82
135
  ):
83
136
  super().__init__()
84
137
  self._name = name
@@ -100,11 +153,24 @@ class ForagaxEnv(environment.Environment):
100
153
  self.nowrap = nowrap
101
154
  self.deterministic_spawn = deterministic_spawn
102
155
  self.teleport_interval = teleport_interval
156
+ self.dynamic_biomes = dynamic_biomes
157
+ self.biome_consumption_threshold = biome_consumption_threshold
158
+
103
159
  objects = (EMPTY,) + objects
104
160
  if self.nowrap and not self.full_world:
105
161
  objects = objects + (PADDING,)
106
162
  self.objects = objects
107
163
 
164
+ # Infer num_fourier_terms from objects
165
+ self.num_fourier_terms = max(
166
+ (
167
+ obj.num_fourier_terms
168
+ for obj in self.objects
169
+ if isinstance(obj, FourierObject)
170
+ ),
171
+ default=0,
172
+ )
173
+
108
174
  # JIT-compatible versions of object and biome properties
109
175
  self.object_ids = jnp.arange(len(objects))
110
176
  self.object_blocking = jnp.array([o.blocking for o in objects])
@@ -122,6 +188,9 @@ class ForagaxEnv(environment.Environment):
122
188
  [o.expiry_time if o.expiry_time is not None else -1 for o in objects]
123
189
  )
124
190
 
191
+ # Check if any objects can expire
192
+ self.has_expiring_objects = jnp.any(self.object_expiry_time >= 0)
193
+
125
194
  # Compute reward steps per object (using max_reward_delay attribute)
126
195
  object_max_reward_delay = jnp.array([o.max_reward_delay for o in objects])
127
196
  self.max_reward_delay = (
@@ -138,6 +207,7 @@ class ForagaxEnv(environment.Environment):
138
207
  [b.stop if b.stop is not None else (-1, -1) for b in biomes]
139
208
  )
140
209
  self.biome_sizes = np.prod(self.biome_stops - self.biome_starts, axis=1)
210
+ self.biome_sizes_jax = jnp.array(self.biome_sizes) # JAX version for indexing
141
211
  self.biome_starts_jax = jnp.array(self.biome_starts)
142
212
  self.biome_stops_jax = jnp.array(self.biome_stops)
143
213
  biome_centers = []
@@ -149,6 +219,7 @@ class ForagaxEnv(environment.Environment):
149
219
  biome_centers.append((center_x, center_y))
150
220
  self.biome_centers_jax = jnp.array(biome_centers)
151
221
  self.biome_masks = []
222
+ biome_masks_array = []
152
223
  for i in range(self.biome_object_frequencies.shape[0]):
153
224
  # Create mask for the biome
154
225
  start = jax.lax.select(
@@ -170,6 +241,10 @@ class ForagaxEnv(environment.Environment):
170
241
  & (cols < stop[0])
171
242
  )
172
243
  self.biome_masks.append(mask)
244
+ biome_masks_array.append(mask)
245
+
246
+ # Convert to JAX array for indexing in JIT-compiled code
247
+ self.biome_masks_array = jnp.array(biome_masks_array)
173
248
 
174
249
  # Compute unique colors and mapping for partial observability (for 'color' observation_type)
175
250
  # Exclude EMPTY (index 0) from color channels
@@ -200,46 +275,89 @@ class ForagaxEnv(environment.Environment):
200
275
  max_steps_in_episode=None,
201
276
  )
202
277
 
203
- def _place_timer_at_position(
204
- self, grid: jax.Array, y: int, x: int, timer_val: int
205
- ) -> jax.Array:
206
- """Place a timer at a specific position."""
207
- return grid.at[y, x].set(timer_val)
208
-
209
- def _place_timer_at_random_position(
278
+ def _place_timer(
210
279
  self,
211
- grid: jax.Array,
280
+ object_state: ObjectState,
212
281
  y: int,
213
282
  x: int,
283
+ object_type: int,
214
284
  timer_val: int,
215
- biome_grid: jax.Array,
285
+ random_respawn: bool,
216
286
  rand_key: jax.Array,
217
- ) -> jax.Array:
218
- """Place a timer at a random valid position within the same biome."""
219
- # Set the original position to empty temporarily
220
- grid_temp = grid.at[y, x].set(0)
221
-
222
- # Find all valid spawn locations (empty cells within the same biome)
223
- biome_id = biome_grid[y, x]
224
- biome_mask = biome_grid == biome_id
225
- empty_mask = grid_temp == 0
226
- valid_spawn_mask = biome_mask & empty_mask
287
+ ) -> ObjectState:
288
+ """Place a timer at position or randomly within the same biome.
289
+
290
+ Args:
291
+ object_state: Current object state
292
+ y, x: Original position
293
+ object_type: The object type ID that will respawn (0 for permanent removal)
294
+ timer_val: Timer countdown value (0 for permanent removal, positive for countdown)
295
+ random_respawn: If True, place at random location in same biome
296
+ rand_key: Random key for random placement
297
+
298
+ Returns:
299
+ Updated object_state with timer placed
300
+ """
301
+
302
+ # Handle permanent removal (timer_val == 0)
303
+ def place_empty():
304
+ return object_state.replace(
305
+ object_id=object_state.object_id.at[y, x].set(0),
306
+ respawn_timer=object_state.respawn_timer.at[y, x].set(0),
307
+ respawn_object_id=object_state.respawn_object_id.at[y, x].set(0),
308
+ )
227
309
 
228
- num_valid_spawns = jnp.sum(valid_spawn_mask)
310
+ # Handle timer placement
311
+ def place_timer():
312
+ # Non-random: place at original position
313
+ def place_at_position():
314
+ return object_state.replace(
315
+ object_id=object_state.object_id.at[y, x].set(0),
316
+ respawn_timer=object_state.respawn_timer.at[y, x].set(timer_val),
317
+ respawn_object_id=object_state.respawn_object_id.at[y, x].set(
318
+ object_type
319
+ ),
320
+ )
229
321
 
230
- # Get indices of valid spawn locations, padded to a static size
231
- y_indices, x_indices = jnp.nonzero(
232
- valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
233
- )
234
- valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
322
+ # Random: place at random location in same biome
323
+ def place_randomly():
324
+ # Clear the collected object's position
325
+ new_object_id = object_state.object_id.at[y, x].set(0)
326
+ new_respawn_timer = object_state.respawn_timer.at[y, x].set(0)
327
+ new_respawn_object_id = object_state.respawn_object_id.at[y, x].set(0)
328
+
329
+ # Find valid spawn locations in the same biome
330
+ biome_id = object_state.biome_id[y, x]
331
+ biome_mask = object_state.biome_id == biome_id
332
+ empty_mask = new_object_id == 0
333
+ no_timer_mask = new_respawn_timer == 0
334
+ valid_spawn_mask = biome_mask & empty_mask & no_timer_mask
335
+ num_valid_spawns = jnp.sum(valid_spawn_mask)
336
+
337
+ y_indices, x_indices = jnp.nonzero(
338
+ valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
339
+ )
340
+ valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
341
+ random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
342
+ new_spawn_pos = valid_spawn_indices[random_idx]
343
+
344
+ # Place timer at the new random position
345
+ new_respawn_timer = new_respawn_timer.at[
346
+ new_spawn_pos[0], new_spawn_pos[1]
347
+ ].set(timer_val)
348
+ new_respawn_object_id = new_respawn_object_id.at[
349
+ new_spawn_pos[0], new_spawn_pos[1]
350
+ ].set(object_type)
351
+
352
+ return object_state.replace(
353
+ object_id=new_object_id,
354
+ respawn_timer=new_respawn_timer,
355
+ respawn_object_id=new_respawn_object_id,
356
+ )
235
357
 
236
- # Select a random valid location
237
- random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
238
- new_spawn_pos = valid_spawn_indices[random_idx]
358
+ return jax.lax.cond(random_respawn, place_randomly, place_at_position)
239
359
 
240
- # Place the timer at the new random position
241
- new_grid = grid_temp.at[new_spawn_pos[0], new_spawn_pos[1]].set(timer_val)
242
- return new_grid
360
+ return jax.lax.cond(timer_val == 0, place_empty, place_timer)
243
361
 
244
362
  def step_env(
245
363
  self,
@@ -249,9 +367,7 @@ class ForagaxEnv(environment.Environment):
249
367
  params: EnvParams,
250
368
  ) -> Tuple[jax.Array, EnvState, jax.Array, jax.Array, Dict[Any, Any]]:
251
369
  """Perform single timestep state transition."""
252
- num_obj_types = len(self.object_ids)
253
- # Decode the object grid: positive values are objects, negative are timers (treat as empty)
254
- current_objects = jnp.maximum(0, state.object_grid)
370
+ current_objects = state.object_state.object_id
255
371
 
256
372
  # 1. UPDATE AGENT POSITION
257
373
  direction = DIRECTIONS[action]
@@ -294,9 +410,12 @@ class ForagaxEnv(environment.Environment):
294
410
  # Handle digestion: add reward to buffer if collected
295
411
  digestion_buffer = state.digestion_buffer
296
412
  key, reward_subkey = jax.random.split(key)
413
+
414
+ object_params = state.object_state.state_params[pos[1], pos[0]]
297
415
  object_reward = jax.lax.switch(
298
- obj_at_pos, self.reward_fns, state.time, reward_subkey
416
+ obj_at_pos, self.reward_fns, state.time, reward_subkey, object_params
299
417
  )
418
+
300
419
  key, digestion_subkey = jax.random.split(key)
301
420
  reward_delay = jax.lax.switch(
302
421
  obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
@@ -319,136 +438,109 @@ class ForagaxEnv(environment.Environment):
319
438
  # 3. HANDLE OBJECT COLLECTION AND RESPAWNING
320
439
  key, regen_subkey, rand_key = jax.random.split(key, 3)
321
440
 
322
- # Decrement timers (stored as negative values)
323
- is_timer = state.object_grid < 0
324
- object_grid = jnp.where(
325
- is_timer, state.object_grid + num_obj_types, state.object_grid
441
+ # Decrement respawn timers
442
+ has_timer = state.object_state.respawn_timer > 0
443
+ new_respawn_timer = jnp.where(
444
+ has_timer,
445
+ state.object_state.respawn_timer - 1,
446
+ state.object_state.respawn_timer,
326
447
  )
327
448
 
328
- # Track which objects just respawned (timer -> object transition)
329
- # An object respawned if it was negative and is now positive
330
- just_respawned = is_timer & (object_grid > 0)
449
+ # Track which cells have timers that just reached 0
450
+ just_respawned = has_timer & (new_respawn_timer == 0)
451
+
452
+ # Respawn objects where timer reached 0
453
+ new_object_id = jnp.where(
454
+ just_respawned,
455
+ state.object_state.respawn_object_id,
456
+ state.object_state.object_id,
457
+ )
458
+
459
+ # Clear respawn_object_id for cells that just respawned
460
+ new_respawn_object_id = jnp.where(
461
+ just_respawned,
462
+ 0,
463
+ state.object_state.respawn_object_id,
464
+ )
331
465
 
332
466
  # Update spawn times for objects that just respawned
333
- object_spawn_time_grid = jnp.where(
334
- just_respawned, state.time, state.object_spawn_time_grid
467
+ spawn_time = jnp.where(
468
+ just_respawned, state.time, state.object_state.spawn_time
335
469
  )
336
470
 
337
471
  # Collect object: set a timer
338
472
  regen_delay = jax.lax.switch(
339
473
  obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
340
474
  )
341
- encoded_timer = obj_at_pos - ((regen_delay + 1) * num_obj_types)
475
+ timer_countdown = jax.lax.cond(
476
+ regen_delay == jnp.iinfo(jnp.int32).max,
477
+ lambda: 0, # No timer (permanent removal)
478
+ lambda: regen_delay + 1, # Timer countdown
479
+ )
342
480
 
343
481
  # If collected, replace object with timer; otherwise, keep it
344
- val_at_pos = object_grid[pos[1], pos[0]]
345
- should_collect = is_collectable & (val_at_pos > 0)
482
+ val_at_pos = current_objects[pos[1], pos[0]]
483
+ # Use original should_collect for consumption tracking
484
+ should_collect_now = is_collectable & (val_at_pos > 0)
485
+
486
+ # Create updated object state with new respawn_timer, object_id, and spawn_time
487
+ object_state = state.object_state.replace(
488
+ object_id=new_object_id,
489
+ respawn_timer=new_respawn_timer,
490
+ respawn_object_id=new_respawn_object_id,
491
+ spawn_time=spawn_time,
492
+ )
346
493
 
347
- # When not collecting, the value at the position remains unchanged.
348
- # When collecting, we either place the timer at the current position or a random one.
349
- def do_collection():
350
- return jax.lax.cond(
494
+ # Place timer on collection
495
+ object_state = jax.lax.cond(
496
+ should_collect_now,
497
+ lambda: self._place_timer(
498
+ object_state,
499
+ pos[1],
500
+ pos[0],
501
+ obj_at_pos, # object type
502
+ timer_countdown, # timer value
351
503
  self.object_random_respawn[obj_at_pos],
352
- lambda: self._place_timer_at_random_position(
353
- object_grid,
354
- pos[1],
355
- pos[0],
356
- encoded_timer,
357
- state.biome_grid,
358
- rand_key,
359
- ),
360
- lambda: self._place_timer_at_position(
361
- object_grid, pos[1], pos[0], encoded_timer
362
- ),
363
- )
364
-
365
- def no_collection():
366
- return object_grid
367
-
368
- object_grid = jax.lax.cond(
369
- should_collect,
370
- do_collection,
371
- no_collection,
504
+ rand_key,
505
+ ),
506
+ lambda: object_state,
372
507
  )
373
508
 
374
509
  # 3.5. HANDLE OBJECT EXPIRY
375
- # Check each cell for objects that have exceeded their expiry time
376
- current_objects_for_expiry = jnp.maximum(0, object_grid)
377
-
378
- # Calculate age of each object (current_time - spawn_time)
379
- object_ages = state.time - object_spawn_time_grid
380
-
381
- # Get expiry time for each object type in the grid
382
- expiry_times = self.object_expiry_time[current_objects_for_expiry]
383
-
384
- # Check if object should expire (age >= expiry_time and expiry_time >= 0)
385
- should_expire = (
386
- (object_ages >= expiry_times)
387
- & (expiry_times >= 0)
388
- & (current_objects_for_expiry > 0)
389
- )
390
-
391
- # For expired objects, calculate expiry regen delay
392
- key, expiry_key = jax.random.split(key)
393
-
394
- # Process expiry for all positions that need it
395
- def process_expiry(y, x, grid, spawn_grid, key):
396
- obj_id = current_objects_for_expiry[y, x]
397
- should_exp = should_expire[y, x]
398
-
399
- def expire_object():
400
- # Get expiry regen delay for this object
401
- exp_key = jax.random.fold_in(key, y * self.size[0] + x)
402
- exp_delay = jax.lax.switch(
403
- obj_id, self.expiry_regen_delay_fns, state.time, exp_key
404
- )
405
- encoded_exp_timer = obj_id - ((exp_delay + 1) * num_obj_types)
406
-
407
- # Check if this object should respawn randomly
408
- should_random_respawn = self.object_random_respawn[obj_id]
409
-
410
- # Use second split for randomness in random placement
411
- rand_key = jax.random.split(exp_key)[1]
412
-
413
- # Place timer either at current position or random position
414
- new_grid = jax.lax.cond(
415
- should_random_respawn,
416
- lambda: self._place_timer_at_random_position(
417
- grid, y, x, encoded_exp_timer, state.biome_grid, rand_key
418
- ),
419
- lambda: self._place_timer_at_position(
420
- grid, y, x, encoded_exp_timer
421
- ),
422
- )
423
-
424
- return new_grid, spawn_grid
425
-
426
- def no_expire():
427
- return grid, spawn_grid
428
-
429
- return jax.lax.cond(should_exp, expire_object, no_expire)
430
-
431
- # Apply expiry to all cells (vectorized)
432
- def scan_expiry_row(carry, y):
433
- grid, spawn_grid, key = carry
434
-
435
- def scan_expiry_col(carry_col, x):
436
- grid_col, spawn_grid_col, key_col = carry_col
437
- grid_col, spawn_grid_col = process_expiry(
438
- y, x, grid_col, spawn_grid_col, key_col
439
- )
440
- return (grid_col, spawn_grid_col, key_col), None
441
-
442
- (grid, spawn_grid, key), _ = jax.lax.scan(
443
- scan_expiry_col, (grid, spawn_grid, key), jnp.arange(self.size[0])
510
+ # Only process expiry if there are objects that can expire
511
+ key, object_state = self.expire_objects(key, state, object_state)
512
+
513
+ # 3.6. HANDLE DYNAMIC BIOME CONSUMPTION AND RESPAWNING
514
+ if self.dynamic_biomes:
515
+ # Update consumption count if an object was collected
516
+ # Only count if the object belongs to the current generation of its biome
517
+ collected_biome_id = object_state.biome_id[pos[1], pos[0]]
518
+ object_gen_at_pos = object_state.generation[pos[1], pos[0]]
519
+ current_biome_gen = state.biome_state.generation[collected_biome_id]
520
+ is_current_generation = object_gen_at_pos == current_biome_gen
521
+
522
+ biome_consumption_count = state.biome_state.consumption_count
523
+ biome_consumption_count = jax.lax.cond(
524
+ should_collect & is_current_generation,
525
+ lambda: biome_consumption_count.at[collected_biome_id].add(1),
526
+ lambda: biome_consumption_count,
444
527
  )
445
- return (grid, spawn_grid, key), None
446
528
 
447
- (object_grid, object_spawn_time_grid, _), _ = jax.lax.scan(
448
- scan_expiry_row,
449
- (object_grid, object_spawn_time_grid, expiry_key),
450
- jnp.arange(self.size[1]),
451
- )
529
+ # Check each biome for threshold crossing and respawn if needed
530
+ key, respawn_key = jax.random.split(key)
531
+ biome_state = BiomeState(
532
+ consumption_count=biome_consumption_count,
533
+ total_objects=state.biome_state.total_objects,
534
+ generation=state.biome_state.generation,
535
+ )
536
+ object_state, biome_state, respawn_key = self._check_and_respawn_biomes(
537
+ object_state,
538
+ biome_state,
539
+ state.time,
540
+ respawn_key,
541
+ )
542
+ else:
543
+ biome_state = state.biome_state
452
544
 
453
545
  info = {"discount": self.discount(state, params)}
454
546
  temperatures = jnp.zeros(len(self.objects))
@@ -458,17 +550,16 @@ class ForagaxEnv(environment.Environment):
458
550
  get_temperature(obj.rewards, state.time, obj.repeat)
459
551
  )
460
552
  info["temperatures"] = temperatures
461
- info["biome_id"] = state.biome_grid[pos[1], pos[0]]
553
+ info["biome_id"] = object_state.biome_id[pos[1], pos[0]]
462
554
  info["object_collected_id"] = jax.lax.select(should_collect, obj_at_pos, -1)
463
555
 
464
556
  # 4. UPDATE STATE
465
557
  state = EnvState(
466
558
  pos=pos,
467
- object_grid=object_grid,
468
- biome_grid=state.biome_grid,
469
559
  time=state.time + 1,
470
560
  digestion_buffer=digestion_buffer,
471
- object_spawn_time_grid=object_spawn_time_grid,
561
+ object_state=object_state,
562
+ biome_state=biome_state,
472
563
  )
473
564
 
474
565
  done = self.is_terminal(state, params)
@@ -480,60 +571,395 @@ class ForagaxEnv(environment.Environment):
480
571
  info,
481
572
  )
482
573
 
574
def expire_objects(
    self, key, state, object_state: ObjectState
) -> Tuple[jax.Array, ObjectState]:
    """Replace objects that have reached their expiry time with respawn timers.

    Args:
        key: PRNG key; it is split only when at least one object expires.
        state: Environment state; only ``state.time`` is read here.
        object_state: Object grids to sweep for expired objects.

    Returns:
        ``(key, object_state)``: the key to keep using afterwards and the
        object state with every expired object handed to ``_place_timer``.
    """
    # No object type declares an expiry time: nothing can ever expire.
    if not self.has_expiring_objects:
        return key, object_state

    grid_ids = object_state.object_id
    width = self.size[0]
    num_cells = width * self.size[1]

    # An occupied cell expires once its age reaches its type's expiry time;
    # a negative expiry time marks a type that never expires.
    lifetimes = self.object_expiry_time[grid_ids]
    ages = state.time - object_state.spawn_time
    expired = (ages >= lifetimes) & (lifetimes >= 0) & (grid_ids > 0)

    def _expire_all():
        # Fixed-size coordinate lists keep this traceable; unused slots are -1.
        rows, cols = jnp.nonzero(expired, size=num_cells, fill_value=-1)
        next_key, base_key = jax.random.split(key)

        def _step(current, cell):
            row, col = cell

            def _expire_cell():
                type_id = grid_ids[row, col]
                # Derive a per-cell key from the flattened cell index.
                cell_key = jax.random.fold_in(base_key, row * width + col)
                delay = jax.lax.switch(
                    type_id, self.expiry_regen_delay_fns, state.time, cell_key
                )
                # An int32-max delay means permanent removal (timer value 0).
                countdown = jax.lax.cond(
                    delay == jnp.iinfo(jnp.int32).max,
                    lambda: 0,
                    lambda: delay + 1,
                )
                placement_key = jax.random.split(cell_key)[1]
                return self._place_timer(
                    current,
                    row,
                    col,
                    type_id,
                    countdown,
                    self.object_random_respawn[type_id],
                    placement_key,
                )

            # Padding slots (negative coordinates) leave the state untouched.
            is_real_cell = (row >= 0) & (col >= 0)
            return jax.lax.cond(is_real_cell, _expire_cell, lambda: current), None

        swept, _ = jax.lax.scan(_step, object_state, (rows, cols))
        return next_key, swept

    def _expire_none():
        return key, object_state

    # Skip the whole sweep when no cell is due to expire this step.
    return jax.lax.cond(jnp.sum(expired) > 0, _expire_all, _expire_none)
664
+
665
def _check_and_respawn_biomes(
    self,
    object_state: ObjectState,
    biome_state: BiomeState,
    current_time: int,
    key: jax.Array,
) -> Tuple[ObjectState, BiomeState, jax.Array]:
    """Respawn every biome whose consumption rate reached the threshold.

    A biome's consumption rate is ``consumption_count / max(1, total_objects)``.
    For each biome at or above ``self.biome_consumption_threshold``:
    its consumption count is reset to 0, its generation is incremented, its
    total object count is recomputed from the fresh spawn, all respawn timers
    inside the biome are cleared, and every cell where the fresh spawn placed
    an object is overwritten (id, color, params, generation, spawn time).
    Cells where the fresh spawn is empty keep their current object.

    Args:
        object_state: Current per-cell object state.
        biome_state: Current per-biome counters.
        current_time: Timestep recorded as spawn time of new objects.
        key: PRNG key used to draw the candidate spawns.

    Returns:
        ``(object_state, biome_state, key)`` with the updated states and the
        key to keep using afterwards.
    """
    num_biomes = self.biome_object_frequencies.shape[0]
    biome_masks = self.biome_masks_array

    consumption_rates = biome_state.consumption_count / jnp.maximum(
        1.0, biome_state.total_objects.astype(float)
    )
    should_respawn = consumption_rates >= self.biome_consumption_threshold

    # One independent key per biome.
    key, subkey = jax.random.split(key)
    biome_keys = jax.random.split(subkey, num_biomes)

    # Candidate spawns are drawn for every biome; they are only applied below
    # where should_respawn is set.
    if self.deterministic_spawn:
        # Deterministic spawning needs a concrete biome index. The index is
        # already a Python int here, so call the spawner directly: wrapping
        # it in jax.lax.switch would re-trace every branch for each biome
        # and yield the same values.
        spawns = [
            self._spawn_biome_objects(idx, biome_keys[idx], deterministic=True)
            for idx in range(num_biomes)
        ]
        all_new_objects = jnp.stack([spawn[0] for spawn in spawns])
        all_new_colors = jnp.stack([spawn[1] for spawn in spawns])
        all_new_params = jnp.stack([spawn[2] for spawn in spawns])
    else:
        # Random spawning works with a traced index, so vmap over biomes.
        all_new_objects, all_new_colors, all_new_params = jax.vmap(
            lambda i, k: self._spawn_biome_objects(i, k, deterministic=False)
        )(jnp.arange(num_biomes), biome_keys)

    # Objects each candidate spawn places inside its own biome.
    new_object_counts = jnp.sum((all_new_objects > 0) & biome_masks, axis=(1, 2))

    new_biome_state = BiomeState(
        consumption_count=jnp.where(
            should_respawn, 0, biome_state.consumption_count
        ),
        total_objects=jnp.where(
            should_respawn, new_object_counts, biome_state.total_objects
        ),
        generation=biome_state.generation + should_respawn.astype(int),
    )

    new_obj_id = object_state.object_id
    new_color = object_state.color
    new_params = object_state.state_params
    new_spawn = object_state.spawn_time
    new_gen = object_state.generation
    new_respawn_timer = object_state.respawn_timer
    new_respawn_object_id = object_state.respawn_object_id

    for i in range(num_biomes):
        # Cells of this biome, selected only if the biome is respawning.
        respawning_cells = biome_masks[i] & should_respawn[i]
        # Subset of those cells that receive a freshly spawned object.
        is_new_object = (all_new_objects[i] > 0) & respawning_cells

        new_obj_id = jnp.where(is_new_object, all_new_objects[i], new_obj_id)
        new_color = jnp.where(
            is_new_object[..., None], all_new_colors[i], new_color
        )
        new_params = jnp.where(
            is_new_object[..., None], all_new_params[i], new_params
        )
        new_gen = jnp.where(is_new_object, new_biome_state.generation[i], new_gen)
        new_spawn = jnp.where(is_new_object, current_time, new_spawn)

        # Pending respawn timers anywhere in a respawning biome are dropped.
        new_respawn_timer = jnp.where(respawning_cells, 0, new_respawn_timer)
        new_respawn_object_id = jnp.where(
            respawning_cells, 0, new_respawn_object_id
        )

    object_state = object_state.replace(
        object_id=new_obj_id,
        respawn_timer=new_respawn_timer,
        respawn_object_id=new_respawn_object_id,
        color=new_color,
        state_params=new_params,
        generation=new_gen,
        spawn_time=new_spawn,
    )

    return object_state, new_biome_state, key
784
+
483
785
  def reset_env(
484
786
  self, key: jax.Array, params: EnvParams
485
787
  ) -> Tuple[jax.Array, EnvState]:
486
788
  """Reset environment state."""
487
- object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
488
- biome_grid = jnp.full((self.size[1], self.size[0]), -1, dtype=int)
789
+ num_object_params = 2 + 2 * self.num_fourier_terms
790
+ object_state = ObjectState.create_empty(self.size, num_object_params)
791
+
489
792
  key, iter_key = jax.random.split(key)
793
+
794
+ # Spawn objects in each biome using unified method
490
795
  for i in range(self.biome_object_frequencies.shape[0]):
491
796
  iter_key, biome_key = jax.random.split(iter_key)
492
797
  mask = self.biome_masks[i]
493
- biome_grid = jnp.where(mask, i, biome_grid)
494
798
 
495
- if self.deterministic_spawn:
496
- biome_objects = self.generate_biome_new(i, biome_key)
497
- object_grid = object_grid.at[mask].set(biome_objects)
498
- else:
499
- biome_objects = self.generate_biome_old(i, biome_key)
500
- object_grid = jnp.where(mask, biome_objects, object_grid)
799
+ # Set biome_id
800
+ object_state = object_state.replace(
801
+ biome_id=jnp.where(mask, i, object_state.biome_id)
802
+ )
803
+
804
+ # Use unified spawn method
805
+ biome_objects, biome_colors, biome_object_params = (
806
+ self._spawn_biome_objects(i, biome_key, self.deterministic_spawn)
807
+ )
808
+
809
+ # Merge biome objects/colors/params into object_state
810
+ object_state = object_state.replace(
811
+ object_id=jnp.where(mask, biome_objects, object_state.object_id),
812
+ color=jnp.where(mask[..., None], biome_colors, object_state.color),
813
+ state_params=jnp.where(
814
+ mask[..., None], biome_object_params, object_state.state_params
815
+ ),
816
+ )
501
817
 
502
818
  # Place agent in the center of the world
503
819
  agent_pos = jnp.array([self.size[0] // 2, self.size[1] // 2])
504
820
 
505
- # Initialize spawn times to 0 (all objects spawn at time 0)
506
- object_spawn_time_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
821
+ # Initialize biome consumption tracking
822
+ num_biomes = self.biome_object_frequencies.shape[0]
823
+ biome_consumption_count = jnp.zeros(num_biomes, dtype=int)
824
+ biome_total_objects = jnp.zeros(num_biomes, dtype=int)
825
+
826
+ # Count objects in each biome
827
+ for i in range(num_biomes):
828
+ mask = self.biome_masks[i]
829
+ # Count non-empty objects (object_id > 0)
830
+ total = jnp.sum((object_state.object_id > 0) & mask)
831
+ biome_total_objects = biome_total_objects.at[i].set(total)
832
+
833
+ biome_generation = jnp.zeros(num_biomes, dtype=int)
507
834
 
508
835
  state = EnvState(
509
836
  pos=agent_pos,
510
- object_grid=object_grid,
511
- biome_grid=biome_grid,
512
837
  time=0,
513
838
  digestion_buffer=jnp.zeros((self.max_reward_delay,)),
514
- object_spawn_time_grid=object_spawn_time_grid,
839
+ object_state=object_state,
840
+ biome_state=BiomeState(
841
+ consumption_count=biome_consumption_count,
842
+ total_objects=biome_total_objects,
843
+ generation=biome_generation,
844
+ ),
515
845
  )
516
846
 
517
847
  return self.get_obs(state, params), state
518
848
 
519
- def generate_biome_old(self, i: int, biome_key: jax.Array):
520
- biome_freqs = self.biome_object_frequencies[i]
521
- grid_rand = jax.random.uniform(biome_key, (self.size[1], self.size[0]))
522
- empty_freq = 1.0 - jnp.sum(biome_freqs)
523
- all_freqs = jnp.concatenate([jnp.array([empty_freq]), biome_freqs])
524
- cumulative_freqs = jnp.cumsum(jnp.concatenate([jnp.array([0.0]), all_freqs]))
525
- biome_objects = jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
526
- return biome_objects
527
-
528
- def generate_biome_new(self, i: int, biome_key: jax.Array):
529
- biome_freqs = self.biome_object_frequencies[i]
530
- grid = jnp.linspace(0, 1, self.biome_sizes[i], endpoint=False)
531
- biome_objects = len(biome_freqs) - jnp.searchsorted(
532
- jnp.cumsum(biome_freqs[::-1]), grid, side="right"
849
+ def _spawn_biome_objects(
850
+ self,
851
+ biome_idx: int,
852
+ key: jax.Array,
853
+ deterministic: bool = False,
854
+ ) -> Tuple[jax.Array, jax.Array, jax.Array]:
855
+ """Spawn objects in a biome.
856
+
857
+ Returns:
858
+ object_grid: (H, W) array of object IDs
859
+ color_grid: (H, W, 3) array of RGB colors
860
+ state_grid: (H, W, num_state_params) array of object state parameters
861
+ """
862
+ biome_freqs = self.biome_object_frequencies[biome_idx]
863
+ biome_mask = self.biome_masks_array[biome_idx]
864
+
865
+ key, spawn_key, color_key, params_key = jax.random.split(key, 4)
866
+
867
+ # Generate object IDs using deterministic or random spawn
868
+ if deterministic:
869
+ # Deterministic spawn: exact number of each object type
870
+ # NOTE: Requires concrete biome_idx to compute size at trace time
871
+ # Get static biome bounds
872
+ biome_start = self.biome_starts[biome_idx]
873
+ biome_stop = self.biome_stops[biome_idx]
874
+ biome_height = biome_stop[1] - biome_start[1]
875
+ biome_width = biome_stop[0] - biome_start[0]
876
+ biome_size = int(self.biome_sizes[biome_idx])
877
+
878
+ grid = jnp.linspace(0, 1, biome_size, endpoint=False)
879
+ biome_objects_flat = len(biome_freqs) - jnp.searchsorted(
880
+ jnp.cumsum(biome_freqs[::-1]), grid, side="right"
881
+ )
882
+ biome_objects_flat = jax.random.permutation(spawn_key, biome_objects_flat)
883
+
884
+ # Reshape to match biome dimensions (use concrete dimensions)
885
+ biome_objects = biome_objects_flat.reshape(biome_height, biome_width)
886
+
887
+ # Place in full grid using slicing with static bounds
888
+ object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
889
+ object_grid = object_grid.at[
890
+ biome_start[1] : biome_stop[1], biome_start[0] : biome_stop[0]
891
+ ].set(biome_objects)
892
+ else:
893
+ # Random spawn: probabilistic placement (works with traced biome_idx)
894
+ grid_rand = jax.random.uniform(spawn_key, (self.size[1], self.size[0]))
895
+ empty_freq = 1.0 - jnp.sum(biome_freqs)
896
+ all_freqs = jnp.concatenate([jnp.array([empty_freq]), biome_freqs])
897
+ cumulative_freqs = jnp.cumsum(
898
+ jnp.concatenate([jnp.array([0.0]), all_freqs])
899
+ )
900
+ object_grid = (
901
+ jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
902
+ )
903
+
904
+ # Initialize color grid
905
+ color_grid = jnp.full((self.size[1], self.size[0], 3), 255, dtype=jnp.uint8)
906
+
907
+ # Sample ONE color per object type in this biome (not per instance)
908
+ # This gives objects of the same type the same color within a biome generation
909
+ # Skip index 0 (EMPTY object) - only sample colors for actual objects
910
+ num_object_types = len(self.objects)
911
+ num_actual_objects = num_object_types - 1 # Exclude EMPTY
912
+
913
+ if num_actual_objects > 0:
914
+ biome_object_colors = jax.random.randint(
915
+ color_key,
916
+ (num_actual_objects, 3),
917
+ minval=0,
918
+ maxval=256,
919
+ dtype=jnp.uint8,
920
+ )
921
+
922
+ # Assign colors based on object type (starting from index 1)
923
+ for obj_idx in range(1, num_object_types):
924
+ obj_mask = (object_grid == obj_idx) & biome_mask
925
+ obj_color = biome_object_colors[
926
+ obj_idx - 1
927
+ ] # Offset by 1 since we skip EMPTY
928
+ color_grid = jnp.where(obj_mask[..., None], obj_color, color_grid)
929
+
930
+ # Initialize parameters grid
931
+ num_object_params = 2 + 2 * self.num_fourier_terms
932
+ params_grid = jnp.zeros(
933
+ (self.size[1], self.size[0], num_object_params), dtype=jnp.float32
533
934
  )
534
- flat_biome_objects = biome_objects.flatten()
535
- shuffled_objects = jax.random.permutation(biome_key, flat_biome_objects)
536
- return shuffled_objects
935
+
936
+ # Generate per-object parameters for each object type
937
+ for obj_idx in range(num_object_types):
938
+ # Get params for this object type - this happens at trace time
939
+ params_key, obj_key = jax.random.split(params_key)
940
+ obj_params = self.objects[obj_idx].get_state(obj_key)
941
+
942
+ # Skip if no params (e.g., for EMPTY or default objects)
943
+ if len(obj_params) == 0:
944
+ continue
945
+
946
+ # Ensure params match expected size
947
+ if len(obj_params) != num_object_params:
948
+ if len(obj_params) < num_object_params:
949
+ obj_params = jnp.pad(
950
+ obj_params,
951
+ (0, num_object_params - len(obj_params)),
952
+ constant_values=0.0,
953
+ )
954
+ else:
955
+ # Truncate if too long
956
+ obj_params = obj_params[:num_object_params]
957
+
958
+ # Assign to all objects of this type in this biome
959
+ obj_mask = (object_grid == obj_idx) & biome_mask
960
+ params_grid = jnp.where(obj_mask[..., None], obj_params, params_grid)
961
+
962
+ return object_grid, color_grid, params_grid
537
963
 
538
964
  def is_terminal(self, state: EnvState, params: EnvParams) -> jax.Array:
539
965
  """Foragax is a continuing environment."""
@@ -556,21 +982,10 @@ class ForagaxEnv(environment.Environment):
556
982
 
557
983
  def state_space(self, params: EnvParams) -> spaces.Dict:
558
984
  """State space of the environment."""
985
+ num_object_params = 2 + 2 * self.num_fourier_terms
559
986
  return spaces.Dict(
560
987
  {
561
988
  "pos": spaces.Box(0, max(self.size), (2,), int),
562
- "object_grid": spaces.Box(
563
- -1000 * len(self.object_ids),
564
- len(self.object_ids),
565
- (self.size[1], self.size[0]),
566
- int,
567
- ),
568
- "biome_grid": spaces.Box(
569
- 0,
570
- self.biome_object_frequencies.shape[0],
571
- (self.size[1], self.size[0]),
572
- int,
573
- ),
574
989
  "time": spaces.Discrete(params.max_steps_in_episode),
575
990
  "digestion_buffer": spaces.Box(
576
991
  -jnp.inf,
@@ -578,17 +993,79 @@ class ForagaxEnv(environment.Environment):
578
993
  (self.max_reward_delay,),
579
994
  float,
580
995
  ),
581
- "object_spawn_time_grid": spaces.Box(
582
- 0,
583
- jnp.inf,
584
- (self.size[1], self.size[0]),
585
- int,
996
+ "object_state": spaces.Dict(
997
+ {
998
+ "object_id": spaces.Box(
999
+ -1000 * len(self.object_ids),
1000
+ len(self.object_ids),
1001
+ (self.size[1], self.size[0]),
1002
+ int,
1003
+ ),
1004
+ "spawn_time": spaces.Box(
1005
+ 0,
1006
+ jnp.inf,
1007
+ (self.size[1], self.size[0]),
1008
+ int,
1009
+ ),
1010
+ "color": spaces.Box(
1011
+ 0,
1012
+ 255,
1013
+ (self.size[1], self.size[0], 3),
1014
+ int,
1015
+ ),
1016
+ "generation": spaces.Box(
1017
+ 0,
1018
+ jnp.inf,
1019
+ (self.size[1], self.size[0]),
1020
+ int,
1021
+ ),
1022
+ "state_params": spaces.Box(
1023
+ -jnp.inf,
1024
+ jnp.inf,
1025
+ (self.size[1], self.size[0], num_object_params),
1026
+ float,
1027
+ ),
1028
+ "biome_id": spaces.Box(
1029
+ -1,
1030
+ self.biome_object_frequencies.shape[0],
1031
+ (self.size[1], self.size[0]),
1032
+ int,
1033
+ ),
1034
+ }
1035
+ ),
1036
+ "biome_state": spaces.Dict(
1037
+ {
1038
+ "consumption_count": spaces.Box(
1039
+ 0,
1040
+ jnp.inf,
1041
+ (self.biome_object_frequencies.shape[0],),
1042
+ int,
1043
+ ),
1044
+ "total_objects": spaces.Box(
1045
+ 0,
1046
+ jnp.inf,
1047
+ (self.biome_object_frequencies.shape[0],),
1048
+ int,
1049
+ ),
1050
+ "generation": spaces.Box(
1051
+ 0,
1052
+ jnp.inf,
1053
+ (self.biome_object_frequencies.shape[0],),
1054
+ int,
1055
+ ),
1056
+ }
586
1057
  ),
587
1058
  }
588
1059
  )
589
1060
 
590
- def _get_aperture(self, object_grid: jax.Array, pos: jax.Array) -> jax.Array:
591
- """Extract the aperture view from the object grid."""
1061
+ def _compute_aperture_coordinates(
1062
+ self, pos: jax.Array
1063
+ ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
1064
+ """Compute aperture coordinates for the given position.
1065
+
1066
+ Returns:
1067
+ (y_coords, x_coords, y_coords_clamped/mod, x_coords_clamped/mod)
1068
+ """
592
1069
  ap_h, ap_w = self.aperture_size
593
1070
  start_y = pos[1] - ap_h // 2
594
1071
  start_x = pos[0] - ap_w // 2
@@ -599,33 +1076,53 @@ class ForagaxEnv(environment.Environment):
599
1076
  x_coords = start_x + x_offsets
600
1077
 
601
1078
  if self.nowrap:
602
- # Clamp coordinates to bounds
603
- y_coords_clamped = jnp.clip(y_coords, 0, self.size[1] - 1)
604
- x_coords_clamped = jnp.clip(x_coords, 0, self.size[0] - 1)
605
- values = object_grid[y_coords_clamped, x_coords_clamped]
606
- # Mark out-of-bounds positions with -1
1079
+ y_coords_adj = jnp.clip(y_coords, 0, self.size[1] - 1)
1080
+ x_coords_adj = jnp.clip(x_coords, 0, self.size[0] - 1)
1081
+ else:
1082
+ y_coords_adj = jnp.mod(y_coords, self.size[1])
1083
+ x_coords_adj = jnp.mod(x_coords, self.size[0])
1084
+
1085
+ return y_coords, x_coords, y_coords_adj, x_coords_adj
1086
+
1087
+ def _get_aperture(self, object_id_grid: jax.Array, pos: jax.Array) -> jax.Array:
1088
+ """Extract the aperture view from the object id grid."""
1089
+ y_coords, x_coords, y_coords_adj, x_coords_adj = (
1090
+ self._compute_aperture_coordinates(pos)
1091
+ )
1092
+
1093
+ values = object_id_grid[y_coords_adj, x_coords_adj]
1094
+
1095
+ if self.nowrap:
1096
+ # Mark out-of-bounds positions with padding
607
1097
  y_out = (y_coords < 0) | (y_coords >= self.size[1])
608
1098
  x_out = (x_coords < 0) | (x_coords >= self.size[0])
609
1099
  out_of_bounds = y_out | x_out
610
- padding_index = self.object_ids[-1]
611
- aperture = jnp.where(out_of_bounds, padding_index, values)
1100
+
1101
+ # Handle both object_id grids (2D) and color grids (3D)
1102
+ if len(values.shape) == 3:
1103
+ # Color grid: use PADDING color (0, 0, 0)
1104
+ padding_value = jnp.array([0, 0, 0], dtype=values.dtype)
1105
+ aperture = jnp.where(out_of_bounds[..., None], padding_value, values)
1106
+ else:
1107
+ # Object ID grid: use PADDING index
1108
+ padding_index = self.object_ids[-1]
1109
+ aperture = jnp.where(out_of_bounds, padding_index, values)
612
1110
  else:
613
- y_coords_mod = jnp.mod(y_coords, self.size[1])
614
- x_coords_mod = jnp.mod(x_coords, self.size[0])
615
- aperture = object_grid[y_coords_mod, x_coords_mod]
1111
+ aperture = values
616
1112
 
617
1113
  return aperture
618
1114
 
619
1115
  def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
620
1116
  """Get observation based on observation_type and full_world."""
621
- # Decode grid for observation
622
- obs_grid = jnp.maximum(0, state.object_grid)
1117
+ obs_grid = state.object_state.object_id
1118
+ color_grid = state.object_state.color
623
1119
 
624
1120
  if self.full_world:
625
1121
  return self._get_world_obs(obs_grid, state)
626
1122
  else:
627
1123
  grid = self._get_aperture(obs_grid, state.pos)
628
- return self._get_aperture_obs(grid, state)
1124
+ color_grid = self._get_aperture(color_grid, state.pos)
1125
+ return self._get_aperture_obs(grid, color_grid, state)
629
1126
 
630
1127
  def _get_world_obs(self, obs_grid: jax.Array, state: EnvState) -> jax.Array:
631
1128
  """Get world observation."""
@@ -642,12 +1139,14 @@ class ForagaxEnv(environment.Environment):
642
1139
  obs = obs.at[state.pos[1], state.pos[0], -1].set(1)
643
1140
  return obs
644
1141
  elif self.observation_type == "rgb":
645
- obs = jax.nn.one_hot(obs_grid, num_obj_types)
646
- # Agent position
647
- obs = obs.at[state.pos[1], state.pos[0], :].set(0)
648
- obs = obs.at[state.pos[1], state.pos[0], -1].set(1)
649
- colors = self.object_colors / 255.0
650
- obs = jnp.tensordot(obs, colors, axes=1)
1142
+ # Use state colors directly (supports dynamic biomes)
1143
+ colors = state.object_state.color / 255.0
1144
+
1145
+ # Mask empty cells (object_id == 0) to white
1146
+ empty_mask = obs_grid == 0
1147
+ white_color = jnp.ones((self.size[1], self.size[0], 3), dtype=jnp.float32)
1148
+ obs = jnp.where(empty_mask[..., None], white_color, colors)
1149
+
651
1150
  return obs
652
1151
  elif self.observation_type == "color":
653
1152
  # Handle case with no objects (only EMPTY)
@@ -664,17 +1163,24 @@ class ForagaxEnv(environment.Environment):
664
1163
  else:
665
1164
  raise ValueError(f"Unknown observation_type: {self.observation_type}")
666
1165
 
667
- def _get_aperture_obs(self, aperture: jax.Array, state: EnvState) -> jax.Array:
1166
+ def _get_aperture_obs(
1167
+ self, aperture: jax.Array, color_aperture: jax.Array, state: EnvState
1168
+ ) -> jax.Array:
668
1169
  """Get aperture observation."""
669
1170
  if self.observation_type == "object":
670
1171
  num_obj_types = len(self.object_ids)
671
1172
  obs = jax.nn.one_hot(aperture, num_obj_types, axis=-1)
672
1173
  return obs
673
1174
  elif self.observation_type == "rgb":
674
- num_obj_types = len(self.object_ids)
675
- aperture_one_hot = jax.nn.one_hot(aperture, num_obj_types)
676
- colors = self.object_colors / 255.0
677
- obs = jnp.tensordot(aperture_one_hot, colors, axes=1)
1175
+ # Use the color aperture that was passed in
1176
+ aperture_colors = color_aperture / 255.0
1177
+
1178
+ # Mask empty cells (object_id == 0) to white
1179
+ empty_mask = aperture == 0
1180
+ white_color = jnp.ones(aperture_colors.shape, dtype=jnp.float32)
1181
+
1182
+ obs = jnp.where(empty_mask[..., None], white_color, aperture_colors)
1183
+
678
1184
  return obs
679
1185
  elif self.observation_type == "color":
680
1186
  # Handle case with no objects (only EMPTY)
@@ -721,48 +1227,47 @@ class ForagaxEnv(environment.Environment):
721
1227
 
722
1228
  if is_world_mode:
723
1229
  # Create an RGB image from the object grid
724
- img = jnp.zeros((self.size[1], self.size[0], 3))
725
- # Decode grid for rendering: non-negative are objects, negative are empty
726
- render_grid = jnp.maximum(0, state.object_grid)
1230
+ # Use stateful object colors if dynamic_biomes is enabled, else use default colors
1231
+ if self.dynamic_biomes:
1232
+ # Use per-instance colors from state
1233
+ img = state.object_state.color.copy()
1234
+ # Mask empty cells (object_id == 0) to white
1235
+ empty_mask = state.object_state.object_id == 0
1236
+ white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
1237
+ img = jnp.where(empty_mask[..., None], white_color, img)
1238
+ else:
1239
+ # Use default object colors
1240
+ img = jnp.zeros((self.size[1], self.size[0], 3))
1241
+ render_grid = state.object_state.object_id
727
1242
 
728
- def update_image(i, img):
729
- color = self.object_colors[i]
730
- mask = render_grid == i
731
- img = jnp.where(mask[..., None], color, img)
732
- return img
1243
+ def update_image(i, img):
1244
+ color = self.object_colors[i]
1245
+ mask = render_grid == i
1246
+ img = jnp.where(mask[..., None], color, img)
1247
+ return img
733
1248
 
734
- img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
1249
+ img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
735
1250
 
736
1251
  # Tint the agent's aperture
737
- ap_h, ap_w = self.aperture_size
738
- start_y = state.pos[1] - ap_h // 2
739
- start_x = state.pos[0] - ap_w // 2
1252
+ y_coords, x_coords, y_coords_adj, x_coords_adj = (
1253
+ self._compute_aperture_coordinates(state.pos)
1254
+ )
740
1255
 
741
1256
  alpha = 0.2
742
1257
  agent_color = jnp.array(AGENT.color)
743
1258
 
744
- # Create indices for the aperture
745
- y_offsets = jnp.arange(ap_h)
746
- x_offsets = jnp.arange(ap_w)
747
- y_coords_original = start_y + y_offsets[:, None]
748
- x_coords_original = start_x + x_offsets
749
-
750
1259
  if self.nowrap:
751
- y_coords = jnp.clip(y_coords_original, 0, self.size[1] - 1)
752
- x_coords = jnp.clip(x_coords_original, 0, self.size[0] - 1)
753
1260
  # Create tint mask: any in-bounds original position maps to a cell makes it tinted
754
1261
  tint_mask = jnp.zeros((self.size[1], self.size[0]), dtype=int)
755
- tint_mask = tint_mask.at[y_coords, x_coords].set(1)
1262
+ tint_mask = tint_mask.at[y_coords_adj, x_coords_adj].set(1)
756
1263
  # Apply tint to masked positions
757
1264
  original_colors = img
758
1265
  tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
759
1266
  img = jnp.where(tint_mask[..., None], tinted_colors, img)
760
1267
  else:
761
- y_coords = jnp.mod(y_coords_original, self.size[1])
762
- x_coords = jnp.mod(x_coords_original, self.size[0])
763
- original_colors = img[y_coords, x_coords]
1268
+ original_colors = img[y_coords_adj, x_coords_adj]
764
1269
  tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
765
- img = img.at[y_coords, x_coords].set(tinted_colors)
1270
+ img = img.at[y_coords_adj, x_coords_adj].set(tinted_colors)
766
1271
 
767
1272
  # Agent color
768
1273
  img = img.at[state.pos[1], state.pos[0]].set(jnp.array(AGENT.color))
@@ -775,6 +1280,7 @@ class ForagaxEnv(environment.Environment):
775
1280
 
776
1281
  if is_true_mode:
777
1282
  # Apply true object borders by overlaying true colors on border pixels
1283
+ render_grid = state.object_state.object_id
778
1284
  img = apply_true_borders(
779
1285
  img, render_grid, self.size, len(self.object_ids)
780
1286
  )
@@ -787,10 +1293,35 @@ class ForagaxEnv(environment.Environment):
787
1293
  img = img.at[:, col_indices].set(grid_color)
788
1294
 
789
1295
  elif is_aperture_mode:
790
- obs_grid = jnp.maximum(0, state.object_grid)
1296
+ obs_grid = state.object_state.object_id
791
1297
  aperture = self._get_aperture(obs_grid, state.pos)
792
- aperture_one_hot = jax.nn.one_hot(aperture, len(self.object_ids))
793
- img = jnp.tensordot(aperture_one_hot, self.object_colors, axes=1)
1298
+
1299
+ if self.dynamic_biomes:
1300
+ # Use per-instance colors from state - extract aperture view
1301
+ y_coords, x_coords, y_coords_adj, x_coords_adj = (
1302
+ self._compute_aperture_coordinates(state.pos)
1303
+ )
1304
+ img = state.object_state.color[y_coords_adj, x_coords_adj]
1305
+
1306
+ # Mask empty cells (object_id == 0) to white
1307
+ aperture_object_ids = state.object_state.object_id[
1308
+ y_coords_adj, x_coords_adj
1309
+ ]
1310
+ empty_mask = aperture_object_ids == 0
1311
+ white_color = jnp.array([255, 255, 255], dtype=jnp.uint8)
1312
+ img = jnp.where(empty_mask[..., None], white_color, img)
1313
+
1314
+ if self.nowrap:
1315
+ # For out-of-bounds, use padding object color
1316
+ y_out = (y_coords < 0) | (y_coords >= self.size[1])
1317
+ x_out = (x_coords < 0) | (x_coords >= self.size[0])
1318
+ out_of_bounds = y_out | x_out
1319
+ padding_color = jnp.array(self.objects[-1].color, dtype=jnp.float32)
1320
+ img = jnp.where(out_of_bounds[..., None], padding_color, img)
1321
+ else:
1322
+ # Use default object colors
1323
+ aperture_one_hot = jax.nn.one_hot(aperture, len(self.object_ids))
1324
+ img = jnp.tensordot(aperture_one_hot, self.object_colors, axes=1)
794
1325
 
795
1326
  # Draw agent in the center
796
1327
  center_y, center_x = self.aperture_size[1] // 2, self.aperture_size[0] // 2