continual-foragax 0.30.1__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foragax/env.py CHANGED
@@ -19,6 +19,7 @@ from foragax.objects import (
19
19
  EMPTY,
20
20
  PADDING,
21
21
  BaseForagaxObject,
22
+ FourierObject,
22
23
  WeatherObject,
23
24
  )
24
25
  from foragax.rendering import apply_true_borders
@@ -42,6 +43,48 @@ DIRECTIONS = jnp.array(
42
43
  )
43
44
 
44
45
 
@struct.dataclass
class ObjectState:
    """Per-cell object bookkeeping for the whole grid.

    Every field is an array whose leading dimensions are the grid ``(H, W)``.

    Attributes:
        object_id: Object type occupying each cell (0 = empty, >0 = type index).
        respawn_timer: Remaining respawn countdown per cell (0 = no pending respawn).
        respawn_object_id: Object type that appears once the countdown reaches 0.
        spawn_time: Timestep at which the cell's current object was spawned
            (used for expiry).
        color: ``(H, W, 3)`` RGB color of each object instance (dynamic biomes).
        generation: Biome generation that produced each object.
        state_params: ``(H, W, N)`` per-instance parameters (e.g. Fourier coefficients).
        biome_id: Biome each cell belongs to (-1 when no biome has been assigned).
    """

    object_id: jax.Array  # (H, W) object type, 0 = empty
    respawn_timer: jax.Array  # (H, W) countdown, 0 = inactive
    respawn_object_id: jax.Array  # (H, W) type spawned when the countdown ends
    spawn_time: jax.Array  # (H, W) timestep the object appeared
    color: jax.Array  # (H, W, 3) RGB per instance
    generation: jax.Array  # (H, W) biome generation number
    state_params: jax.Array  # (H, W, N) per-instance parameters
    biome_id: jax.Array  # (H, W) biome assignment, -1 = unassigned

    @classmethod
    def create_empty(cls, size: Tuple[int, int], num_params: int) -> "ObjectState":
        """Return an all-empty state for a grid given as ``size = (width, height)``.

        Integer grids start at 0, colors start white (255), parameters start at
        0.0 and every cell is marked as belonging to no biome (-1).
        """
        width, height = size[0], size[1]
        grid_shape = (height, width)

        def int_zeros():
            # A fresh int grid for each field (no shared buffers).
            return jnp.zeros(grid_shape, dtype=int)

        return cls(
            object_id=int_zeros(),
            respawn_timer=int_zeros(),
            respawn_object_id=int_zeros(),
            spawn_time=int_zeros(),
            color=jnp.full(grid_shape + (3,), 255, dtype=jnp.uint8),
            generation=int_zeros(),
            state_params=jnp.zeros(grid_shape + (num_params,), dtype=jnp.float32),
            biome_id=jnp.full(grid_shape, -1, dtype=int),
        )
86
+
87
+
45
88
  @dataclass
46
89
  class Biome:
47
90
  # Object generation frequencies for this biome
@@ -55,13 +98,22 @@ class EnvParams(environment.EnvParams):
55
98
  max_steps_in_episode: Union[int, None]
56
99
 
57
100
 
@struct.dataclass
class BiomeState:
    """Biome-level tracking state; every field has shape (num_biomes,).

    Each array is indexed by biome id. The ratio
    ``consumption_count / max(1, total_objects)`` is compared against the
    environment's biome consumption threshold to decide when a biome is
    regenerated (dynamic biomes).
    """

    consumption_count: jax.Array  # objects consumed per biome (current generation only)
    total_objects: jax.Array  # objects placed in each biome by its latest (re)spawn
    generation: jax.Array  # current generation number per biome (0 after reset)
108
+
109
+
@struct.dataclass
class EnvState(environment.EnvState):
    """Full state of a Foragax environment between steps."""

    pos: jax.Array  # agent position as (x, y); grids are indexed [y, x]
    time: int  # current timestep, starts at 0 on reset
    digestion_buffer: jax.Array  # buffer of delayed rewards, length max_reward_delay
    object_state: ObjectState  # per-cell object information
    biome_state: BiomeState  # per-biome consumption / generation tracking
65
117
 
66
118
 
67
119
  class ForagaxEnv(environment.Environment):
@@ -78,6 +130,8 @@ class ForagaxEnv(environment.Environment):
78
130
  deterministic_spawn: bool = False,
79
131
  teleport_interval: Optional[int] = None,
80
132
  observation_type: str = "object",
133
+ dynamic_biomes: bool = False,
134
+ biome_consumption_threshold: float = 0.9,
81
135
  ):
82
136
  super().__init__()
83
137
  self._name = name
@@ -99,11 +153,24 @@ class ForagaxEnv(environment.Environment):
99
153
  self.nowrap = nowrap
100
154
  self.deterministic_spawn = deterministic_spawn
101
155
  self.teleport_interval = teleport_interval
156
+ self.dynamic_biomes = dynamic_biomes
157
+ self.biome_consumption_threshold = biome_consumption_threshold
158
+
102
159
  objects = (EMPTY,) + objects
103
160
  if self.nowrap and not self.full_world:
104
161
  objects = objects + (PADDING,)
105
162
  self.objects = objects
106
163
 
164
+ # Infer num_fourier_terms from objects
165
+ self.num_fourier_terms = max(
166
+ (
167
+ obj.num_fourier_terms
168
+ for obj in self.objects
169
+ if isinstance(obj, FourierObject)
170
+ ),
171
+ default=0,
172
+ )
173
+
107
174
  # JIT-compatible versions of object and biome properties
108
175
  self.object_ids = jnp.arange(len(objects))
109
176
  self.object_blocking = jnp.array([o.blocking for o in objects])
@@ -114,6 +181,15 @@ class ForagaxEnv(environment.Environment):
114
181
  self.reward_fns = [o.reward for o in objects]
115
182
  self.regen_delay_fns = [o.regen_delay for o in objects]
116
183
  self.reward_delay_fns = [o.reward_delay for o in objects]
184
+ self.expiry_regen_delay_fns = [o.expiry_regen_delay for o in objects]
185
+
186
+ # Expiry times per object (None becomes -1 for no expiry)
187
+ self.object_expiry_time = jnp.array(
188
+ [o.expiry_time if o.expiry_time is not None else -1 for o in objects]
189
+ )
190
+
191
+ # Check if any objects can expire
192
+ self.has_expiring_objects = jnp.any(self.object_expiry_time >= 0)
117
193
 
118
194
  # Compute reward steps per object (using max_reward_delay attribute)
119
195
  object_max_reward_delay = jnp.array([o.max_reward_delay for o in objects])
@@ -131,6 +207,7 @@ class ForagaxEnv(environment.Environment):
131
207
  [b.stop if b.stop is not None else (-1, -1) for b in biomes]
132
208
  )
133
209
  self.biome_sizes = np.prod(self.biome_stops - self.biome_starts, axis=1)
210
+ self.biome_sizes_jax = jnp.array(self.biome_sizes) # JAX version for indexing
134
211
  self.biome_starts_jax = jnp.array(self.biome_starts)
135
212
  self.biome_stops_jax = jnp.array(self.biome_stops)
136
213
  biome_centers = []
@@ -142,6 +219,7 @@ class ForagaxEnv(environment.Environment):
142
219
  biome_centers.append((center_x, center_y))
143
220
  self.biome_centers_jax = jnp.array(biome_centers)
144
221
  self.biome_masks = []
222
+ biome_masks_array = []
145
223
  for i in range(self.biome_object_frequencies.shape[0]):
146
224
  # Create mask for the biome
147
225
  start = jax.lax.select(
@@ -163,6 +241,10 @@ class ForagaxEnv(environment.Environment):
163
241
  & (cols < stop[0])
164
242
  )
165
243
  self.biome_masks.append(mask)
244
+ biome_masks_array.append(mask)
245
+
246
+ # Convert to JAX array for indexing in JIT-compiled code
247
+ self.biome_masks_array = jnp.array(biome_masks_array)
166
248
 
167
249
  # Compute unique colors and mapping for partial observability (for 'color' observation_type)
168
250
  # Exclude EMPTY (index 0) from color channels
@@ -193,6 +275,90 @@ class ForagaxEnv(environment.Environment):
193
275
  max_steps_in_episode=None,
194
276
  )
195
277
 
def _place_timer(
    self,
    object_state: ObjectState,
    y: int,
    x: int,
    object_type: int,
    timer_val: int,
    random_respawn: bool,
    rand_key: jax.Array,
) -> ObjectState:
    """Vacate cell (y, x) and schedule (or cancel) the respawn of its object.

    Three outcomes, selected with ``jax.lax.cond``:

    * ``timer_val == 0``: permanent removal. The cell is emptied and no timer
      is placed anywhere.
    * ``random_respawn`` false: the countdown ``timer_val`` for ``object_type``
      is placed on cell (y, x) itself.
    * ``random_respawn`` true: the countdown is placed on a uniformly chosen
      cell of the same biome that is empty and has no pending timer. The
      vacated cell is itself eligible, so a candidate always exists.

    Args:
        object_state: Current object state.
        y, x: Grid coordinates of the object being removed.
        object_type: Object type ID that will respawn (0 for permanent removal).
        timer_val: Countdown value (0 for permanent removal).
        random_respawn: Whether to relocate the timer within the biome.
        rand_key: PRNG key used only for random relocation.

    Returns:
        The updated ``ObjectState``. Color and parameter grids are untouched.
    """
    # Every outcome starts by emptying the source cell.
    vacated_ids = object_state.object_id.at[y, x].set(0)
    vacated_timers = object_state.respawn_timer.at[y, x].set(0)
    vacated_respawn_ids = object_state.respawn_object_id.at[y, x].set(0)

    def remove_for_good():
        return object_state.replace(
            object_id=vacated_ids,
            respawn_timer=vacated_timers,
            respawn_object_id=vacated_respawn_ids,
        )

    def schedule_in_place():
        return object_state.replace(
            object_id=vacated_ids,
            respawn_timer=object_state.respawn_timer.at[y, x].set(timer_val),
            respawn_object_id=object_state.respawn_object_id.at[y, x].set(
                object_type
            ),
        )

    def schedule_elsewhere():
        same_biome = object_state.biome_id == object_state.biome_id[y, x]
        candidates = same_biome & (vacated_ids == 0) & (vacated_timers == 0)
        n_candidates = jnp.sum(candidates)

        # Fixed-size nonzero keeps shapes static under jit; valid hits come first.
        rows, cols = jnp.nonzero(
            candidates, size=self.size[0] * self.size[1], fill_value=-1
        )
        pick = jax.random.randint(rand_key, (), 0, n_candidates)
        target_y, target_x = rows[pick], cols[pick]

        return object_state.replace(
            object_id=vacated_ids,
            respawn_timer=vacated_timers.at[target_y, target_x].set(timer_val),
            respawn_object_id=vacated_respawn_ids.at[target_y, target_x].set(
                object_type
            ),
        )

    def schedule_respawn():
        return jax.lax.cond(random_respawn, schedule_elsewhere, schedule_in_place)

    return jax.lax.cond(timer_val == 0, remove_for_good, schedule_respawn)
361
+
196
362
  def step_env(
197
363
  self,
198
364
  key: jax.Array,
@@ -201,9 +367,7 @@ class ForagaxEnv(environment.Environment):
201
367
  params: EnvParams,
202
368
  ) -> Tuple[jax.Array, EnvState, jax.Array, jax.Array, Dict[Any, Any]]:
203
369
  """Perform single timestep state transition."""
204
- num_obj_types = len(self.object_ids)
205
- # Decode the object grid: positive values are objects, negative are timers (treat as empty)
206
- current_objects = jnp.maximum(0, state.object_grid)
370
+ current_objects = state.object_state.object_id
207
371
 
208
372
  # 1. UPDATE AGENT POSITION
209
373
  direction = DIRECTIONS[action]
@@ -246,9 +410,12 @@ class ForagaxEnv(environment.Environment):
246
410
  # Handle digestion: add reward to buffer if collected
247
411
  digestion_buffer = state.digestion_buffer
248
412
  key, reward_subkey = jax.random.split(key)
413
+
414
+ object_params = state.object_state.state_params[pos[1], pos[0]]
249
415
  object_reward = jax.lax.switch(
250
- obj_at_pos, self.reward_fns, state.time, reward_subkey
416
+ obj_at_pos, self.reward_fns, state.time, reward_subkey, object_params
251
417
  )
418
+
252
419
  key, digestion_subkey = jax.random.split(key)
253
420
  reward_delay = jax.lax.switch(
254
421
  obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
@@ -271,61 +438,120 @@ class ForagaxEnv(environment.Environment):
271
438
  # 3. HANDLE OBJECT COLLECTION AND RESPAWNING
272
439
  key, regen_subkey, rand_key = jax.random.split(key, 3)
273
440
 
274
- # Decrement timers (stored as negative values)
275
- is_timer = state.object_grid < 0
276
- object_grid = jnp.where(
277
- is_timer, state.object_grid + num_obj_types, state.object_grid
441
+ # Decrement respawn timers
442
+ has_timer = state.object_state.respawn_timer > 0
443
+ new_respawn_timer = jnp.where(
444
+ has_timer,
445
+ state.object_state.respawn_timer - 1,
446
+ state.object_state.respawn_timer,
447
+ )
448
+
449
+ # Track which cells have timers that just reached 0
450
+ just_respawned = has_timer & (new_respawn_timer == 0)
451
+
452
+ # Respawn objects where timer reached 0
453
+ new_object_id = jnp.where(
454
+ just_respawned,
455
+ state.object_state.respawn_object_id,
456
+ state.object_state.object_id,
457
+ )
458
+
459
+ # Clear respawn_object_id for cells that just respawned
460
+ new_respawn_object_id = jnp.where(
461
+ just_respawned,
462
+ 0,
463
+ state.object_state.respawn_object_id,
464
+ )
465
+
466
+ # Update spawn times for objects that just respawned
467
+ spawn_time = jnp.where(
468
+ just_respawned, state.time, state.object_state.spawn_time
278
469
  )
279
470
 
280
471
  # Collect object: set a timer
281
472
  regen_delay = jax.lax.switch(
282
473
  obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
283
474
  )
284
- encoded_timer = obj_at_pos - ((regen_delay + 1) * num_obj_types)
285
-
286
- def place_at_current_pos(current_grid, timer_val):
287
- return current_grid.at[pos[1], pos[0]].set(timer_val)
475
+ timer_countdown = jax.lax.cond(
476
+ regen_delay == jnp.iinfo(jnp.int32).max,
477
+ lambda: 0, # No timer (permanent removal)
478
+ lambda: regen_delay + 1, # Timer countdown
479
+ )
288
480
 
289
- def place_at_random_pos(current_grid, timer_val):
290
- # Set the collected position to empty temporarily
291
- grid = current_grid.at[pos[1], pos[0]].set(0)
481
+ # If collected, replace object with timer; otherwise, keep it
482
+ val_at_pos = current_objects[pos[1], pos[0]]
483
+ # Use original should_collect for consumption tracking
484
+ should_collect_now = is_collectable & (val_at_pos > 0)
485
+
486
+ # Create updated object state with new respawn_timer, object_id, and spawn_time
487
+ object_state = state.object_state.replace(
488
+ object_id=new_object_id,
489
+ respawn_timer=new_respawn_timer,
490
+ respawn_object_id=new_respawn_object_id,
491
+ spawn_time=spawn_time,
492
+ )
292
493
 
293
- # Find all valid spawn locations (empty cells within the same biome)
294
- biome_id = state.biome_grid[pos[1], pos[0]]
295
- biome_mask = state.biome_grid == biome_id
296
- empty_mask = grid == 0
297
- valid_spawn_mask = biome_mask & empty_mask
494
+ # Place timer on collection
495
+ object_state = jax.lax.cond(
496
+ should_collect_now,
497
+ lambda: self._place_timer(
498
+ object_state,
499
+ pos[1],
500
+ pos[0],
501
+ obj_at_pos, # object type
502
+ timer_countdown, # timer value
503
+ self.object_random_respawn[obj_at_pos],
504
+ rand_key,
505
+ ),
506
+ lambda: object_state,
507
+ )
298
508
 
299
- num_valid_spawns = jnp.sum(valid_spawn_mask)
509
+ # Clear color grid when object is collected
510
+ object_state = jax.lax.cond(
511
+ should_collect_now,
512
+ lambda: object_state.replace(
513
+ color=object_state.color.at[pos[1], pos[0]].set(
514
+ jnp.full((3,), 255, dtype=jnp.uint8)
515
+ )
516
+ ),
517
+ lambda: object_state,
518
+ )
300
519
 
301
- # Get indices of valid spawn locations, padded to a static size
302
- y_indices, x_indices = jnp.nonzero(
303
- valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
520
+ # 3.5. HANDLE OBJECT EXPIRY
521
+ # Only process expiry if there are objects that can expire
522
+ key, object_state = self.expire_objects(key, state, object_state)
523
+
524
+ # 3.6. HANDLE DYNAMIC BIOME CONSUMPTION AND RESPAWNING
525
+ if self.dynamic_biomes:
526
+ # Update consumption count if an object was collected
527
+ # Only count if the object belongs to the current generation of its biome
528
+ collected_biome_id = object_state.biome_id[pos[1], pos[0]]
529
+ object_gen_at_pos = object_state.generation[pos[1], pos[0]]
530
+ current_biome_gen = state.biome_state.generation[collected_biome_id]
531
+ is_current_generation = object_gen_at_pos == current_biome_gen
532
+
533
+ biome_consumption_count = state.biome_state.consumption_count
534
+ biome_consumption_count = jax.lax.cond(
535
+ should_collect & is_current_generation,
536
+ lambda: biome_consumption_count.at[collected_biome_id].add(1),
537
+ lambda: biome_consumption_count,
304
538
  )
305
- valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
306
-
307
- # Select a random valid location
308
- random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
309
- new_spawn_pos = valid_spawn_indices[random_idx]
310
539
 
311
- # Place the timer at the new random position
312
- return grid.at[new_spawn_pos[0], new_spawn_pos[1]].set(timer_val)
313
-
314
- # If collected, replace object with timer; otherwise, keep it
315
- val_at_pos = object_grid[pos[1], pos[0]]
316
- should_collect = is_collectable & (val_at_pos > 0)
317
-
318
- # When not collecting, the value at the position remains unchanged.
319
- # When collecting, we either place the timer at the current position or a random one.
320
- object_grid = jax.lax.cond(
321
- should_collect,
322
- lambda: jax.lax.cond(
323
- self.object_random_respawn[obj_at_pos],
324
- lambda: place_at_random_pos(object_grid, encoded_timer),
325
- lambda: place_at_current_pos(object_grid, encoded_timer),
326
- ),
327
- lambda: object_grid,
328
- )
540
+ # Check each biome for threshold crossing and respawn if needed
541
+ key, respawn_key = jax.random.split(key)
542
+ biome_state = BiomeState(
543
+ consumption_count=biome_consumption_count,
544
+ total_objects=state.biome_state.total_objects,
545
+ generation=state.biome_state.generation,
546
+ )
547
+ object_state, biome_state, respawn_key = self._check_and_respawn_biomes(
548
+ object_state,
549
+ biome_state,
550
+ state.time,
551
+ respawn_key,
552
+ )
553
+ else:
554
+ biome_state = state.biome_state
329
555
 
330
556
  info = {"discount": self.discount(state, params)}
331
557
  temperatures = jnp.zeros(len(self.objects))
@@ -335,16 +561,16 @@ class ForagaxEnv(environment.Environment):
335
561
  get_temperature(obj.rewards, state.time, obj.repeat)
336
562
  )
337
563
  info["temperatures"] = temperatures
338
- info["biome_id"] = state.biome_grid[pos[1], pos[0]]
564
+ info["biome_id"] = object_state.biome_id[pos[1], pos[0]]
339
565
  info["object_collected_id"] = jax.lax.select(should_collect, obj_at_pos, -1)
340
566
 
341
567
  # 4. UPDATE STATE
342
568
  state = EnvState(
343
569
  pos=pos,
344
- object_grid=object_grid,
345
- biome_grid=state.biome_grid,
346
570
  time=state.time + 1,
347
571
  digestion_buffer=digestion_buffer,
572
+ object_state=object_state,
573
+ biome_state=biome_state,
348
574
  )
349
575
 
350
576
  done = self.is_terminal(state, params)
@@ -356,56 +582,401 @@ class ForagaxEnv(environment.Environment):
356
582
  info,
357
583
  )
358
584
 
def expire_objects(
    self, key, state, object_state: ObjectState
) -> Tuple[jax.Array, ObjectState]:
    """Remove objects whose age has reached their type's expiry time.

    An object in cell (y, x) expires when
    ``state.time - object_state.spawn_time[y, x] >= expiry_time`` for its type.
    Types whose expiry time is negative never expire.

    Each expired object is handled like a collection:
    * the type's expiry regen-delay function yields the respawn delay, where
      ``jnp.iinfo(jnp.int32).max`` means permanent removal;
    * ``_place_timer`` schedules the respawn, at the same cell or at a random
      cell of the same biome according to ``self.object_random_respawn``;
    * the cell's color is reset to white.

    Expiring cells are processed in row-major order. Each cell draws its
    randomness from a key folded with its own position.

    Args:
        key: PRNG key.
        state: Current environment state; only ``state.time`` is read.
        object_state: Object state to expire objects from.

    Returns:
        ``(key, object_state)``. ``key`` is split only when at least one object
        expires; otherwise the input key is returned unchanged.
    """
    # Decided at construction time: if no object type has an expiry time,
    # there is nothing to do (and no per-step cost).
    if not self.has_expiring_objects:
        return key, object_state

    object_ids = object_state.object_id
    ages = state.time - object_state.spawn_time
    expiry_times = self.object_expiry_time[object_ids]
    should_expire = (ages >= expiry_times) & (expiry_times >= 0) & (object_ids > 0)
    num_expiring = jnp.sum(should_expire)

    def process_expiries():
        # Fixed-size nonzero keeps shapes static; the first `num_expiring`
        # entries are the real hits, the rest is -1 padding that the loop
        # below never visits.
        y_indices, x_indices = jnp.nonzero(
            should_expire, size=self.size[0] * self.size[1], fill_value=-1
        )
        key_local, expiry_key = jax.random.split(key)

        def expire_one(i, obj_state):
            y = y_indices[i]
            x = x_indices[i]
            # Types come from the grid as it was before this step's expiries;
            # earlier expiries only vacate their own cells or place timers on
            # empty cells, so this still matches the live object at (y, x).
            obj_id = object_ids[y, x]
            exp_key = jax.random.fold_in(expiry_key, y * self.size[0] + x)
            exp_delay = jax.lax.switch(
                obj_id, self.expiry_regen_delay_fns, state.time, exp_key
            )
            timer_countdown = jax.lax.cond(
                exp_delay == jnp.iinfo(jnp.int32).max,
                lambda: 0,  # No timer (permanent removal)
                lambda: exp_delay + 1,  # Timer countdown
            )

            rand_key = jax.random.split(exp_key)[1]
            obj_state = self._place_timer(
                obj_state,
                y,
                x,
                obj_id,
                timer_countdown,
                self.object_random_respawn[obj_id],
                rand_key,
            )

            # Clear the color of the vacated cell.
            empty_color = jnp.full((3,), 255, dtype=jnp.uint8)
            return obj_state.replace(color=obj_state.color.at[y, x].set(empty_color))

        # Iterate only over the actual hits instead of scanning all H*W cells.
        new_object_state = jax.lax.fori_loop(
            0, num_expiring, expire_one, object_state
        )
        return key_local, new_object_state

    def no_expiries():
        return key, object_state

    return jax.lax.cond(num_expiring > 0, process_expiries, no_expiries)
681
+
def _check_and_respawn_biomes(
    self,
    object_state: ObjectState,
    biome_state: BiomeState,
    current_time: int,
    key: jax.Array,
) -> Tuple[ObjectState, BiomeState, jax.Array]:
    """Regenerate every biome whose consumption rate reached the threshold.

    A biome's consumption rate is ``consumption_count / max(1, total_objects)``.
    For each biome whose rate is at least ``self.biome_consumption_threshold``:

    * a fresh layout (object ids, colors, per-instance params) is spawned for
      the biome. It is written only into cells inside the biome mask where the
      new layout has an object; other cells keep their current contents;
    * written cells get ``generation`` set to the biome's new generation and
      ``spawn_time`` set to ``current_time``;
    * every respawn timer inside the biome is cleared;
    * the biome's consumption count resets to 0 and its generation is
      incremented. ``total_objects`` becomes the number of newly spawned
      objects inside the biome mask.

    Biomes below the threshold are untouched. When no biome qualifies, the
    spawn work is skipped entirely.

    Args:
        object_state: Per-cell object state.
        biome_state: Per-biome consumption / generation tracking.
        current_time: Timestep recorded as the spawn time of new objects.
        key: PRNG key.

    Returns:
        ``(object_state, biome_state, key)`` with ``key`` advanced.
    """
    num_biomes = self.biome_object_frequencies.shape[0]

    # max(1, .) guards biomes whose latest spawn produced no objects.
    consumption_rates = biome_state.consumption_count / jnp.maximum(
        1.0, biome_state.total_objects.astype(float)
    )
    should_respawn = consumption_rates >= self.biome_consumption_threshold

    key, subkey = jax.random.split(key)
    biome_keys = jax.random.split(subkey, num_biomes)

    def spawn_all_biomes():
        """Return stacked (objects, colors, params) candidate layouts, one per biome."""
        if self.deterministic_spawn:
            # Deterministic spawning sizes its arrays from the biome index, so
            # it needs a concrete Python int. Calling it directly avoids a
            # lax.switch, which would trace every biome's branch for each
            # biome only to select branch `i`.
            spawns = [
                self._spawn_biome_objects(i, biome_keys[i], deterministic=True)
                for i in range(num_biomes)
            ]
            objects, colors, params = zip(*spawns)
            return jnp.stack(objects), jnp.stack(colors), jnp.stack(params)
        # Random spawning works with a traced biome index.
        return jax.vmap(
            lambda i, k: self._spawn_biome_objects(i, k, deterministic=False)
        )(jnp.arange(num_biomes), biome_keys)

    def regenerate():
        new_objects, new_colors, new_params = spawn_all_biomes()

        new_biome_state = BiomeState(
            consumption_count=jnp.where(
                should_respawn, 0, biome_state.consumption_count
            ),
            total_objects=jnp.where(
                should_respawn,
                # Objects per biome counted inside each biome's own mask.
                jnp.sum((new_objects > 0) & self.biome_masks_array, axis=(1, 2)),
                biome_state.total_objects,
            ),
            generation=biome_state.generation + should_respawn.astype(int),
        )

        obj_id = object_state.object_id
        color = object_state.color
        state_params = object_state.state_params
        spawn_time = object_state.spawn_time
        generation = object_state.generation
        respawn_timer = object_state.respawn_timer
        respawn_object_id = object_state.respawn_object_id

        for i in range(num_biomes):
            # Cells of biome i, only when biome i is being regenerated.
            in_respawning_biome = self.biome_masks_array[i] & should_respawn[i]
            is_new_object = in_respawning_biome & (new_objects[i] > 0)

            obj_id = jnp.where(is_new_object, new_objects[i], obj_id)
            color = jnp.where(is_new_object[..., None], new_colors[i], color)
            state_params = jnp.where(
                is_new_object[..., None], new_params[i], state_params
            )
            generation = jnp.where(
                is_new_object, new_biome_state.generation[i], generation
            )
            spawn_time = jnp.where(is_new_object, current_time, spawn_time)

            # Drop pending respawns anywhere in a regenerated biome.
            respawn_timer = jnp.where(in_respawning_biome, 0, respawn_timer)
            respawn_object_id = jnp.where(in_respawning_biome, 0, respawn_object_id)

        regenerated_state = object_state.replace(
            object_id=obj_id,
            respawn_timer=respawn_timer,
            respawn_object_id=respawn_object_id,
            color=color,
            state_params=state_params,
            generation=generation,
            spawn_time=spawn_time,
        )
        return regenerated_state, new_biome_state

    def keep():
        return object_state, biome_state

    # With no biome over threshold, every update above is a no-op, so skip it.
    new_object_state, new_biome_state = jax.lax.cond(
        jnp.any(should_respawn), regenerate, keep
    )
    return new_object_state, new_biome_state, key
801
+
def reset_env(
    self, key: jax.Array, params: EnvParams
) -> Tuple[jax.Array, EnvState]:
    """Reset environment state.

    Builds a fresh grid:
    * each biome, in index order, stamps its id into ``biome_id`` and its
      spawned objects, colors and parameters into the cells of its mask.
      Layouts come from ``_spawn_biome_objects`` (deterministic or random per
      ``self.deterministic_spawn``). Later biomes overwrite earlier ones where
      masks overlap;
    * the agent is placed at the grid center;
    * per-biome tracking starts at generation 0 with zero consumption, and
      ``total_objects`` counts the non-empty cells in each biome mask after
      all biomes were placed.

    Returns:
        ``(observation, state)``.
    """
    n_biomes = self.biome_object_frequencies.shape[0]
    grid_state = ObjectState.create_empty(self.size, 2 + 2 * self.num_fourier_terms)

    _, iter_key = jax.random.split(key)

    for biome in range(n_biomes):
        iter_key, biome_key = jax.random.split(iter_key)
        inside = self.biome_masks[biome]
        ids, colors, obj_params = self._spawn_biome_objects(
            biome, biome_key, self.deterministic_spawn
        )
        grid_state = grid_state.replace(
            biome_id=jnp.where(inside, biome, grid_state.biome_id),
            object_id=jnp.where(inside, ids, grid_state.object_id),
            color=jnp.where(inside[..., None], colors, grid_state.color),
            state_params=jnp.where(
                inside[..., None], obj_params, grid_state.state_params
            ),
        )

    occupied = grid_state.object_id > 0
    tracking = BiomeState(
        consumption_count=jnp.zeros(n_biomes, dtype=int),
        total_objects=jnp.array(
            [jnp.sum(occupied & self.biome_masks[b]) for b in range(n_biomes)],
            dtype=int,
        ),
        generation=jnp.zeros(n_biomes, dtype=int),
    )

    state = EnvState(
        pos=jnp.array([self.size[0] // 2, self.size[1] // 2]),
        time=0,
        digestion_buffer=jnp.zeros((self.max_reward_delay,)),
        object_state=grid_state,
        biome_state=tracking,
    )
    return self.get_obs(state, params), state
390
865
 
def _spawn_biome_objects(
    self,
    biome_idx: int,
    key: jax.Array,
    deterministic: bool = False,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    """Spawn a candidate layout of objects for one biome.

    Two spawning modes:
    * Deterministic: the biome receives each object type in proportion to its
      frequency (quantile assignment), randomly permuted. The result is placed
      in the biome's rectangle; everything else is 0. This needs a concrete
      Python-int ``biome_idx``, because array sizes come from it.
    * Random: every grid cell independently draws a type from the biome's
      frequencies, the remainder being empty. Ids are drawn for the whole grid,
      so callers must mask them. This works with a traced ``biome_idx``.

    Colors are sampled once per object type (not per instance). Per-type
    parameters come from each object's ``get_state``, padded or truncated to
    ``2 + 2 * self.num_fourier_terms``. Both are written only inside the
    biome mask.

    Returns:
        object_grid: (H, W) array of object IDs
        color_grid: (H, W, 3) uint8 RGB colors (255 outside the biome / for empty)
        state_grid: (H, W, num_state_params) float32 per-object parameters
    """
    freqs = self.biome_object_frequencies[biome_idx]
    in_biome = self.biome_masks_array[biome_idx]
    height, width = self.size[1], self.size[0]
    n_types = len(self.objects)
    n_params = 2 + 2 * self.num_fourier_terms

    _, spawn_key, color_key, params_key = jax.random.split(key, 4)

    if not deterministic:
        # Inverse-CDF sampling with slot 0 reserved for "empty".
        uniform = jax.random.uniform(spawn_key, (height, width))
        with_empty = jnp.concatenate([jnp.array([1.0 - jnp.sum(freqs)]), freqs])
        edges = jnp.cumsum(jnp.concatenate([jnp.array([0.0]), with_empty]))
        object_grid = jnp.searchsorted(edges, uniform, side="right") - 1
    else:
        lo = self.biome_starts[biome_idx]
        hi = self.biome_stops[biome_idx]
        n_cells = int(self.biome_sizes[biome_idx])

        # Evenly spaced quantiles give an exact per-type count for the biome.
        quantiles = jnp.linspace(0, 1, n_cells, endpoint=False)
        ids = len(freqs) - jnp.searchsorted(
            jnp.cumsum(freqs[::-1]), quantiles, side="right"
        )
        ids = jax.random.permutation(spawn_key, ids)
        block = ids.reshape(hi[1] - lo[1], hi[0] - lo[0])

        object_grid = (
            jnp.zeros((height, width), dtype=int)
            .at[lo[1] : hi[1], lo[0] : hi[0]]
            .set(block)
        )

    color_grid = jnp.full((height, width, 3), 255, dtype=jnp.uint8)
    if n_types > 1:
        # One color per non-empty type (row t - 1 belongs to type t).
        palette = jax.random.randint(
            color_key,
            (n_types - 1, 3),
            minval=0,
            maxval=256,
            dtype=jnp.uint8,
        )
        for type_id in range(1, n_types):
            hit = in_biome & (object_grid == type_id)
            color_grid = jnp.where(hit[..., None], palette[type_id - 1], color_grid)

    params_grid = jnp.zeros((height, width, n_params), dtype=jnp.float32)
    for type_id, obj in enumerate(self.objects):
        # A key is consumed for every type, including those with no params.
        params_key, type_key = jax.random.split(params_key)
        values = obj.get_state(type_key)
        n_values = len(values)
        if n_values == 0:
            continue
        if n_values < n_params:
            values = jnp.pad(values, (0, n_params - n_values), constant_values=0.0)
        elif n_values > n_params:
            values = values[:n_params]
        hit = in_biome & (object_grid == type_id)
        params_grid = jnp.where(hit[..., None], values, params_grid)

    return object_grid, color_grid, params_grid
409
980
 
410
981
  def is_terminal(self, state: EnvState, params: EnvParams) -> jax.Array:
411
982
  """Foragax is a continuing environment."""
@@ -428,21 +999,10 @@ class ForagaxEnv(environment.Environment):
428
999
 
429
1000
  def state_space(self, params: EnvParams) -> spaces.Dict:
430
1001
  """State space of the environment."""
1002
+ num_object_params = 2 + 2 * self.num_fourier_terms
431
1003
  return spaces.Dict(
432
1004
  {
433
1005
  "pos": spaces.Box(0, max(self.size), (2,), int),
434
- "object_grid": spaces.Box(
435
- -1000 * len(self.object_ids),
436
- len(self.object_ids),
437
- (self.size[1], self.size[0]),
438
- int,
439
- ),
440
- "biome_grid": spaces.Box(
441
- 0,
442
- self.biome_object_frequencies.shape[0],
443
- (self.size[1], self.size[0]),
444
- int,
445
- ),
446
1006
  "time": spaces.Discrete(params.max_steps_in_episode),
447
1007
  "digestion_buffer": spaces.Box(
448
1008
  -jnp.inf,
@@ -450,11 +1010,79 @@ class ForagaxEnv(environment.Environment):
450
1010
  (self.max_reward_delay,),
451
1011
  float,
452
1012
  ),
1013
+ "object_state": spaces.Dict(
1014
+ {
1015
+ "object_id": spaces.Box(
1016
+ -1000 * len(self.object_ids),
1017
+ len(self.object_ids),
1018
+ (self.size[1], self.size[0]),
1019
+ int,
1020
+ ),
1021
+ "spawn_time": spaces.Box(
1022
+ 0,
1023
+ jnp.inf,
1024
+ (self.size[1], self.size[0]),
1025
+ int,
1026
+ ),
1027
+ "color": spaces.Box(
1028
+ 0,
1029
+ 255,
1030
+ (self.size[1], self.size[0], 3),
1031
+ int,
1032
+ ),
1033
+ "generation": spaces.Box(
1034
+ 0,
1035
+ jnp.inf,
1036
+ (self.size[1], self.size[0]),
1037
+ int,
1038
+ ),
1039
+ "state_params": spaces.Box(
1040
+ -jnp.inf,
1041
+ jnp.inf,
1042
+ (self.size[1], self.size[0], num_object_params),
1043
+ float,
1044
+ ),
1045
+ "biome_id": spaces.Box(
1046
+ -1,
1047
+ self.biome_object_frequencies.shape[0],
1048
+ (self.size[1], self.size[0]),
1049
+ int,
1050
+ ),
1051
+ }
1052
+ ),
1053
+ "biome_state": spaces.Dict(
1054
+ {
1055
+ "consumption_count": spaces.Box(
1056
+ 0,
1057
+ jnp.inf,
1058
+ (self.biome_object_frequencies.shape[0],),
1059
+ int,
1060
+ ),
1061
+ "total_objects": spaces.Box(
1062
+ 0,
1063
+ jnp.inf,
1064
+ (self.biome_object_frequencies.shape[0],),
1065
+ int,
1066
+ ),
1067
+ "generation": spaces.Box(
1068
+ 0,
1069
+ jnp.inf,
1070
+ (self.biome_object_frequencies.shape[0],),
1071
+ int,
1072
+ ),
1073
+ }
1074
+ ),
453
1075
  }
454
1076
  )
455
1077
 
456
- def _get_aperture(self, object_grid: jax.Array, pos: jax.Array) -> jax.Array:
457
- """Extract the aperture view from the object grid."""
1078
+ def _compute_aperture_coordinates(
1079
+ self, pos: jax.Array
1080
+ ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
1081
+ """Compute aperture coordinates for the given position.
1082
+
1083
+ Returns:
1084
+ (y_coords, x_coords, y_coords_clamped/mod, x_coords_clamped/mod)
1085
+ """
458
1086
  ap_h, ap_w = self.aperture_size
459
1087
  start_y = pos[1] - ap_h // 2
460
1088
  start_x = pos[0] - ap_w // 2
@@ -465,27 +1093,37 @@ class ForagaxEnv(environment.Environment):
465
1093
  x_coords = start_x + x_offsets
466
1094
 
467
1095
  if self.nowrap:
468
- # Clamp coordinates to bounds
469
- y_coords_clamped = jnp.clip(y_coords, 0, self.size[1] - 1)
470
- x_coords_clamped = jnp.clip(x_coords, 0, self.size[0] - 1)
471
- values = object_grid[y_coords_clamped, x_coords_clamped]
472
- # Mark out-of-bounds positions with -1
1096
+ y_coords_adj = jnp.clip(y_coords, 0, self.size[1] - 1)
1097
+ x_coords_adj = jnp.clip(x_coords, 0, self.size[0] - 1)
1098
+ else:
1099
+ y_coords_adj = jnp.mod(y_coords, self.size[1])
1100
+ x_coords_adj = jnp.mod(x_coords, self.size[0])
1101
+
1102
+ return y_coords, x_coords, y_coords_adj, x_coords_adj
1103
+
1104
+ def _get_aperture(self, object_id_grid: jax.Array, pos: jax.Array) -> jax.Array:
1105
+ """Extract the aperture view from the object id grid."""
1106
+ y_coords, x_coords, y_coords_adj, x_coords_adj = (
1107
+ self._compute_aperture_coordinates(pos)
1108
+ )
1109
+
1110
+ values = object_id_grid[y_coords_adj, x_coords_adj]
1111
+
1112
+ if self.nowrap:
1113
+ # Mark out-of-bounds positions with padding
473
1114
  y_out = (y_coords < 0) | (y_coords >= self.size[1])
474
1115
  x_out = (x_coords < 0) | (x_coords >= self.size[0])
475
1116
  out_of_bounds = y_out | x_out
476
1117
  padding_index = self.object_ids[-1]
477
1118
  aperture = jnp.where(out_of_bounds, padding_index, values)
478
1119
  else:
479
- y_coords_mod = jnp.mod(y_coords, self.size[1])
480
- x_coords_mod = jnp.mod(x_coords, self.size[0])
481
- aperture = object_grid[y_coords_mod, x_coords_mod]
1120
+ aperture = values
482
1121
 
483
1122
  return aperture
484
1123
 
485
1124
  def get_obs(self, state: EnvState, params: EnvParams, key=None) -> jax.Array:
486
1125
  """Get observation based on observation_type and full_world."""
487
- # Decode grid for observation
488
- obs_grid = jnp.maximum(0, state.object_grid)
1126
+ obs_grid = state.object_state.object_id
489
1127
 
490
1128
  if self.full_world:
491
1129
  return self._get_world_obs(obs_grid, state)
@@ -587,48 +1225,43 @@ class ForagaxEnv(environment.Environment):
587
1225
 
588
1226
  if is_world_mode:
589
1227
  # Create an RGB image from the object grid
590
- img = jnp.zeros((self.size[1], self.size[0], 3))
591
- # Decode grid for rendering: non-negative are objects, negative are empty
592
- render_grid = jnp.maximum(0, state.object_grid)
1228
+ # Use stateful object colors if dynamic_biomes is enabled, else use default colors
1229
+ if self.dynamic_biomes:
1230
+ # Use per-instance colors from state
1231
+ img = state.object_state.color.copy()
1232
+ else:
1233
+ # Use default object colors
1234
+ img = jnp.zeros((self.size[1], self.size[0], 3))
1235
+ render_grid = state.object_state.object_id
593
1236
 
594
- def update_image(i, img):
595
- color = self.object_colors[i]
596
- mask = render_grid == i
597
- img = jnp.where(mask[..., None], color, img)
598
- return img
1237
+ def update_image(i, img):
1238
+ color = self.object_colors[i]
1239
+ mask = render_grid == i
1240
+ img = jnp.where(mask[..., None], color, img)
1241
+ return img
599
1242
 
600
- img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
1243
+ img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
601
1244
 
602
1245
  # Tint the agent's aperture
603
- ap_h, ap_w = self.aperture_size
604
- start_y = state.pos[1] - ap_h // 2
605
- start_x = state.pos[0] - ap_w // 2
1246
+ y_coords, x_coords, y_coords_adj, x_coords_adj = (
1247
+ self._compute_aperture_coordinates(state.pos)
1248
+ )
606
1249
 
607
1250
  alpha = 0.2
608
1251
  agent_color = jnp.array(AGENT.color)
609
1252
 
610
- # Create indices for the aperture
611
- y_offsets = jnp.arange(ap_h)
612
- x_offsets = jnp.arange(ap_w)
613
- y_coords_original = start_y + y_offsets[:, None]
614
- x_coords_original = start_x + x_offsets
615
-
616
1253
  if self.nowrap:
617
- y_coords = jnp.clip(y_coords_original, 0, self.size[1] - 1)
618
- x_coords = jnp.clip(x_coords_original, 0, self.size[0] - 1)
619
1254
  # Create tint mask: any in-bounds original position maps to a cell makes it tinted
620
1255
  tint_mask = jnp.zeros((self.size[1], self.size[0]), dtype=int)
621
- tint_mask = tint_mask.at[y_coords, x_coords].set(1)
1256
+ tint_mask = tint_mask.at[y_coords_adj, x_coords_adj].set(1)
622
1257
  # Apply tint to masked positions
623
1258
  original_colors = img
624
1259
  tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
625
1260
  img = jnp.where(tint_mask[..., None], tinted_colors, img)
626
1261
  else:
627
- y_coords = jnp.mod(y_coords_original, self.size[1])
628
- x_coords = jnp.mod(x_coords_original, self.size[0])
629
- original_colors = img[y_coords, x_coords]
1262
+ original_colors = img[y_coords_adj, x_coords_adj]
630
1263
  tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
631
- img = img.at[y_coords, x_coords].set(tinted_colors)
1264
+ img = img.at[y_coords_adj, x_coords_adj].set(tinted_colors)
632
1265
 
633
1266
  # Agent color
634
1267
  img = img.at[state.pos[1], state.pos[0]].set(jnp.array(AGENT.color))
@@ -641,6 +1274,7 @@ class ForagaxEnv(environment.Environment):
641
1274
 
642
1275
  if is_true_mode:
643
1276
  # Apply true object borders by overlaying true colors on border pixels
1277
+ render_grid = state.object_state.object_id
644
1278
  img = apply_true_borders(
645
1279
  img, render_grid, self.size, len(self.object_ids)
646
1280
  )
@@ -653,10 +1287,27 @@ class ForagaxEnv(environment.Environment):
653
1287
  img = img.at[:, col_indices].set(grid_color)
654
1288
 
655
1289
  elif is_aperture_mode:
656
- obs_grid = jnp.maximum(0, state.object_grid)
1290
+ obs_grid = state.object_state.object_id
657
1291
  aperture = self._get_aperture(obs_grid, state.pos)
658
- aperture_one_hot = jax.nn.one_hot(aperture, len(self.object_ids))
659
- img = jnp.tensordot(aperture_one_hot, self.object_colors, axes=1)
1292
+
1293
+ if self.dynamic_biomes:
1294
+ # Use per-instance colors from state - extract aperture view
1295
+ y_coords, x_coords, y_coords_adj, x_coords_adj = (
1296
+ self._compute_aperture_coordinates(state.pos)
1297
+ )
1298
+ img = state.object_state.color[y_coords_adj, x_coords_adj]
1299
+
1300
+ if self.nowrap:
1301
+ # For out-of-bounds, use padding object color
1302
+ y_out = (y_coords < 0) | (y_coords >= self.size[1])
1303
+ x_out = (x_coords < 0) | (x_coords >= self.size[0])
1304
+ out_of_bounds = y_out | x_out
1305
+ padding_color = jnp.array(self.objects[-1].color, dtype=jnp.float32)
1306
+ img = jnp.where(out_of_bounds[..., None], padding_color, img)
1307
+ else:
1308
+ # Use default object colors
1309
+ aperture_one_hot = jax.nn.one_hot(aperture, len(self.object_ids))
1310
+ img = jnp.tensordot(aperture_one_hot, self.object_colors, axes=1)
660
1311
 
661
1312
  # Draw agent in the center
662
1313
  center_y, center_x = self.aperture_size[1] // 2, self.aperture_size[0] // 2