continual-foragax 0.40.0__py3-none-any.whl → 0.41.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foragax/env.py CHANGED
@@ -26,6 +26,15 @@ from foragax.rendering import apply_true_borders
26
26
  from foragax.weather import get_temperature
27
27
 
28
28
 
29
+ ID_DTYPE = jnp.int32 # Object type ID (0 = empty, >0 = object type)
30
+ TIMER_DTYPE = jnp.int32 # Respawn countdown (0 = no timer, >0 = countdown)
31
+ TIME_DTYPE = jnp.int32 # Timesteps (spawn time, current time)
32
+ PARAM_DTYPE = jnp.float16 # Per-instance object parameters
33
+ COLOR_DTYPE = jnp.uint8 # RGB color channels (0-255)
34
+ BIOME_ID_DTYPE = jnp.int16 # Biome assignment for each cell
35
+ REWARD_DTYPE = jnp.float32 # Reward values
36
+
37
+
29
38
  class Actions(IntEnum):
30
39
  DOWN = 0
31
40
  RIGHT = 1
@@ -74,14 +83,14 @@ class ObjectState:
74
83
  """Create an empty ObjectState for the given grid size."""
75
84
  h, w = size[1], size[0]
76
85
  return cls(
77
- object_id=jnp.zeros((h, w), dtype=int),
78
- respawn_timer=jnp.zeros((h, w), dtype=int),
79
- respawn_object_id=jnp.zeros((h, w), dtype=int),
80
- spawn_time=jnp.zeros((h, w), dtype=int),
81
- color=jnp.full((h, w, 3), 255, dtype=jnp.uint8),
82
- generation=jnp.zeros((h, w), dtype=int),
83
- state_params=jnp.zeros((h, w, num_params), dtype=jnp.float32),
84
- biome_id=jnp.full((h, w), -1, dtype=int),
86
+ object_id=jnp.zeros((h, w), dtype=ID_DTYPE),
87
+ respawn_timer=jnp.zeros((h, w), dtype=TIMER_DTYPE),
88
+ respawn_object_id=jnp.zeros((h, w), dtype=ID_DTYPE),
89
+ spawn_time=jnp.zeros((h, w), dtype=TIME_DTYPE),
90
+ color=jnp.full((h, w, 3), 255, dtype=COLOR_DTYPE),
91
+ generation=jnp.zeros((h, w), dtype=ID_DTYPE),
92
+ state_params=jnp.zeros((h, w, num_params), dtype=PARAM_DTYPE),
93
+ biome_id=jnp.full((h, w), -1, dtype=BIOME_ID_DTYPE),
85
94
  )
86
95
 
87
96
 
@@ -228,13 +237,13 @@ class ForagaxEnv(environment.Environment):
228
237
  # Create mask for the biome
229
238
  start = jax.lax.select(
230
239
  self.biome_starts[i, 0] == -1,
231
- jnp.array([0, 0]),
232
- self.biome_starts[i],
240
+ jnp.array([0, 0], dtype=jnp.int32),
241
+ self.biome_starts[i].astype(jnp.int32),
233
242
  )
234
243
  stop = jax.lax.select(
235
244
  self.biome_stops[i, 0] == -1,
236
- jnp.array(self.size),
237
- self.biome_stops[i],
245
+ jnp.array(self.size, dtype=jnp.int32),
246
+ self.biome_stops[i].astype(jnp.int32),
238
247
  )
239
248
  rows = jnp.arange(self.size[1])[:, None]
240
249
  cols = jnp.arange(self.size[0])
@@ -273,12 +282,18 @@ class ForagaxEnv(environment.Environment):
273
282
  # color_indices maps from object_id-1 to color_channel_index
274
283
  self.object_to_color_map = color_indices
275
284
 
285
+ # Rendering constants
286
+ self.agent_color_jax = jnp.array(AGENT.color, dtype=jnp.uint8)
287
+ self.white_color_jax = jnp.array([255, 255, 255], dtype=jnp.uint8)
288
+ self.grid_color_jax = jnp.zeros(3, dtype=jnp.uint8)
289
+
276
290
  @property
277
291
  def default_params(self) -> EnvParams:
278
292
  return EnvParams(
279
293
  max_steps_in_episode=None,
280
294
  )
281
295
 
296
+ @partial(jax.named_call, name="_place_timer")
282
297
  def _place_timer(
283
298
  self,
284
299
  object_state: ObjectState,
@@ -302,13 +317,24 @@ class ForagaxEnv(environment.Environment):
302
317
  Returns:
303
318
  Updated object_state with timer placed
304
319
  """
320
+ # Ensure inputs match ObjectState dtypes
321
+ object_type = jnp.array(object_type, dtype=ID_DTYPE)
322
+ timer_val = jnp.array(timer_val, dtype=TIMER_DTYPE)
323
+ y = jnp.array(y, dtype=jnp.int32)
324
+ x = jnp.array(x, dtype=jnp.int32)
305
325
 
306
326
  # Handle permanent removal (timer_val == 0)
307
327
  def place_empty():
308
328
  return object_state.replace(
309
- object_id=object_state.object_id.at[y, x].set(0),
310
- respawn_timer=object_state.respawn_timer.at[y, x].set(0),
311
- respawn_object_id=object_state.respawn_object_id.at[y, x].set(0),
329
+ object_id=object_state.object_id.at[y, x].set(
330
+ jnp.array(0, dtype=ID_DTYPE)
331
+ ),
332
+ respawn_timer=object_state.respawn_timer.at[y, x].set(
333
+ jnp.array(0, dtype=TIMER_DTYPE)
334
+ ),
335
+ respawn_object_id=object_state.respawn_object_id.at[y, x].set(
336
+ jnp.array(0, dtype=ID_DTYPE)
337
+ ),
312
338
  )
313
339
 
314
340
  # Handle timer placement
@@ -316,7 +342,9 @@ class ForagaxEnv(environment.Environment):
316
342
  # Non-random: place at original position
317
343
  def place_at_position():
318
344
  return object_state.replace(
319
- object_id=object_state.object_id.at[y, x].set(0),
345
+ object_id=object_state.object_id.at[y, x].set(
346
+ jnp.array(0, dtype=ID_DTYPE)
347
+ ),
320
348
  respawn_timer=object_state.respawn_timer.at[y, x].set(timer_val),
321
349
  respawn_object_id=object_state.respawn_object_id.at[y, x].set(
322
350
  object_type
@@ -326,9 +354,15 @@ class ForagaxEnv(environment.Environment):
326
354
  # Random: place at random location in same biome
327
355
  def place_randomly():
328
356
  # Clear the collected object's position
329
- new_object_id = object_state.object_id.at[y, x].set(0)
330
- new_respawn_timer = object_state.respawn_timer.at[y, x].set(0)
331
- new_respawn_object_id = object_state.respawn_object_id.at[y, x].set(0)
357
+ new_object_id = object_state.object_id.at[y, x].set(
358
+ jnp.array(0, dtype=ID_DTYPE)
359
+ )
360
+ new_respawn_timer = object_state.respawn_timer.at[y, x].set(
361
+ jnp.array(0, dtype=TIMER_DTYPE)
362
+ )
363
+ new_respawn_object_id = object_state.respawn_object_id.at[y, x].set(
364
+ jnp.array(0, dtype=ID_DTYPE)
365
+ )
332
366
 
333
367
  # Find valid spawn locations in the same biome
334
368
  biome_id = object_state.biome_id[y, x]
@@ -336,13 +370,15 @@ class ForagaxEnv(environment.Environment):
336
370
  empty_mask = new_object_id == 0
337
371
  no_timer_mask = new_respawn_timer == 0
338
372
  valid_spawn_mask = biome_mask & empty_mask & no_timer_mask
339
- num_valid_spawns = jnp.sum(valid_spawn_mask)
373
+ num_valid_spawns = jnp.sum(valid_spawn_mask, dtype=jnp.int32)
340
374
 
341
375
  y_indices, x_indices = jnp.nonzero(
342
376
  valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
343
377
  )
344
378
  valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
345
- random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
379
+ random_idx = jax.random.randint(
380
+ rand_key, (), jnp.array(0, dtype=jnp.int32), num_valid_spawns
381
+ )
346
382
  new_spawn_pos = valid_spawn_indices[random_idx]
347
383
 
348
384
  # Place timer at the new random position
@@ -363,17 +399,11 @@ class ForagaxEnv(environment.Environment):
363
399
 
364
400
  return jax.lax.cond(timer_val == 0, place_empty, place_timer)
365
401
 
366
- def step_env(
367
- self,
368
- key: jax.Array,
369
- state: EnvState,
370
- action: Union[int, float, jax.Array],
371
- params: EnvParams,
372
- ) -> Tuple[jax.Array, EnvState, jax.Array, jax.Array, Dict[Any, Any]]:
373
- """Perform single timestep state transition."""
402
+ @partial(jax.named_call, name="move_agent")
403
+ def _move_agent(
404
+ self, state: EnvState, action: Union[int, float, jax.Array]
405
+ ) -> Tuple[jax.Array, jax.Array]:
374
406
  current_objects = state.object_state.object_id
375
-
376
- # 1. UPDATE AGENT POSITION
377
407
  direction = DIRECTIONS[action]
378
408
  new_pos = state.pos + direction
379
409
 
@@ -386,28 +416,44 @@ class ForagaxEnv(environment.Environment):
386
416
 
387
417
  # Check for blocking objects
388
418
  obj_at_new_pos = current_objects[new_pos[1], new_pos[0]]
389
- is_blocking = self.object_blocking[obj_at_new_pos]
390
- pos = jax.lax.select(is_blocking, state.pos, new_pos)
391
-
392
- # 2. HANDLE COLLISIONS AND REWARDS
393
- obj_at_pos = current_objects[pos[1], pos[0]]
394
- is_collectable = self.object_collectable[obj_at_pos]
395
- should_collect = is_collectable & (obj_at_pos > 0)
419
+ is_blocking = self.object_blocking[obj_at_new_pos.astype(jnp.int32)]
420
+ pos = jnp.where(
421
+ is_blocking[..., None],
422
+ state.pos.astype(jnp.int32),
423
+ new_pos.astype(jnp.int32),
424
+ )
425
+ return pos, new_pos
396
426
 
397
- # Handle digestion: add reward to buffer if collected
398
- digestion_buffer = state.digestion_buffer
427
+ @partial(jax.named_call, name="compute_reward")
428
+ def _compute_reward(
429
+ self,
430
+ state: EnvState,
431
+ pos: jax.Array,
432
+ key: jax.Array,
433
+ should_collect: jax.Array,
434
+ digestion_buffer: jax.Array,
435
+ obj_at_pos: jax.Array,
436
+ ) -> Tuple[jax.Array, jax.Array]:
399
437
  key, reward_subkey = jax.random.split(key)
400
438
 
401
439
  object_params = state.object_state.state_params[pos[1], pos[0]]
402
440
  object_reward = jax.lax.switch(
403
- obj_at_pos, self.reward_fns, state.time, reward_subkey, object_params
441
+ obj_at_pos.astype(jnp.int32),
442
+ self.reward_fns,
443
+ state.time,
444
+ reward_subkey,
445
+ object_params.astype(jnp.float32),
404
446
  )
405
447
 
406
448
  key, digestion_subkey = jax.random.split(key)
407
449
  reward_delay = jax.lax.switch(
408
450
  obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
409
451
  )
410
- reward = jnp.where(should_collect & (reward_delay == 0), object_reward, 0.0)
452
+ reward = jnp.where(
453
+ should_collect & (reward_delay == jnp.array(0, dtype=jnp.int32)),
454
+ object_reward,
455
+ 0.0,
456
+ )
411
457
  if self.max_reward_delay > 0:
412
458
  # Add delayed rewards to buffer
413
459
  digestion_buffer = jax.lax.cond(
@@ -421,20 +467,33 @@ class ForagaxEnv(environment.Environment):
421
467
  current_index = state.time % self.max_reward_delay
422
468
  reward += digestion_buffer[current_index]
423
469
  digestion_buffer = digestion_buffer.at[current_index].set(0.0)
470
+ return reward, digestion_buffer
424
471
 
472
+ @partial(jax.named_call, name="respawn_logic")
473
+ def _respawn_logic(
474
+ self,
475
+ state: EnvState,
476
+ pos: jax.Array,
477
+ key: jax.Array,
478
+ current_objects: jax.Array,
479
+ is_collectable: jax.Array,
480
+ obj_at_pos: jax.Array,
481
+ ) -> ObjectState:
425
482
  # 3. HANDLE OBJECT COLLECTION AND RESPAWNING
426
483
  key, regen_subkey, rand_key = jax.random.split(key, 3)
427
484
 
428
485
  # Decrement respawn timers
429
- has_timer = state.object_state.respawn_timer > 0
486
+ has_timer = state.object_state.respawn_timer > jnp.array(0, dtype=TIMER_DTYPE)
430
487
  new_respawn_timer = jnp.where(
431
488
  has_timer,
432
- state.object_state.respawn_timer - 1,
489
+ state.object_state.respawn_timer - jnp.array(1, dtype=TIMER_DTYPE),
433
490
  state.object_state.respawn_timer,
434
491
  )
435
492
 
436
493
  # Track which cells have timers that just reached 0
437
- just_respawned = has_timer & (new_respawn_timer == 0)
494
+ just_respawned = has_timer & (
495
+ new_respawn_timer == jnp.array(0, dtype=TIMER_DTYPE)
496
+ )
438
497
 
439
498
  # Respawn objects where timer reached 0
440
499
  new_object_id = jnp.where(
@@ -446,7 +505,7 @@ class ForagaxEnv(environment.Environment):
446
505
  # Clear respawn_object_id for cells that just respawned
447
506
  new_respawn_object_id = jnp.where(
448
507
  just_respawned,
449
- 0,
508
+ jnp.array(0, dtype=ID_DTYPE),
450
509
  state.object_state.respawn_object_id,
451
510
  )
452
511
 
@@ -459,16 +518,19 @@ class ForagaxEnv(environment.Environment):
459
518
  regen_delay = jax.lax.switch(
460
519
  obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
461
520
  )
521
+ # Cast timer_countdown to match ObjectState.respawn_timer dtype
462
522
  timer_countdown = jax.lax.cond(
463
523
  regen_delay == jnp.iinfo(jnp.int32).max,
464
- lambda: 0, # No timer (permanent removal)
465
- lambda: regen_delay + 1, # Timer countdown
524
+ lambda: jnp.array(0, dtype=TIMER_DTYPE), # No timer (permanent removal)
525
+ lambda: (regen_delay + 1).astype(TIMER_DTYPE), # Timer countdown
466
526
  )
467
527
 
468
528
  # If collected, replace object with timer; otherwise, keep it
469
529
  val_at_pos = current_objects[pos[1], pos[0]]
470
530
  # Use original should_collect for consumption tracking
471
- should_collect_now = is_collectable & (val_at_pos > 0)
531
+ should_collect_now = is_collectable & (
532
+ val_at_pos > jnp.array(0, dtype=ID_DTYPE)
533
+ )
472
534
 
473
535
  # Create updated object state with new respawn_timer, object_id, and spawn_time
474
536
  object_state = state.object_state.replace(
@@ -492,12 +554,17 @@ class ForagaxEnv(environment.Environment):
492
554
  ),
493
555
  lambda: object_state,
494
556
  )
557
+ return object_state
495
558
 
496
- # 3.5. HANDLE OBJECT EXPIRY
497
- # Only process expiry if there are objects that can expire
498
- key, object_state = self.expire_objects(key, state, object_state)
499
-
500
- # 3.6. HANDLE DYNAMIC BIOME CONSUMPTION AND RESPAWNING
559
+ @partial(jax.named_call, name="dynamic_biomes")
560
+ def _dynamic_biomes(
561
+ self,
562
+ state: EnvState,
563
+ pos: jax.Array,
564
+ key: jax.Array,
565
+ object_state: ObjectState,
566
+ should_collect: jax.Array,
567
+ ) -> Tuple[ObjectState, BiomeState]:
501
568
  if self.dynamic_biomes:
502
569
  # Update consumption count if an object was collected
503
570
  # Only count if the object belongs to the current generation of its biome
@@ -509,7 +576,9 @@ class ForagaxEnv(environment.Environment):
509
576
  biome_consumption_count = state.biome_state.consumption_count
510
577
  biome_consumption_count = jax.lax.cond(
511
578
  should_collect & is_current_generation,
512
- lambda: biome_consumption_count.at[collected_biome_id].add(1),
579
+ lambda: biome_consumption_count.at[collected_biome_id].add(
580
+ jnp.array(1, dtype=ID_DTYPE)
581
+ ),
513
582
  lambda: biome_consumption_count,
514
583
  )
515
584
 
@@ -526,8 +595,73 @@ class ForagaxEnv(environment.Environment):
526
595
  state.time,
527
596
  respawn_key,
528
597
  )
598
+ return object_state, biome_state
529
599
  else:
530
- biome_state = state.biome_state
600
+ return object_state, state.biome_state
601
+
602
+ @partial(jax.named_call, name="reward_grid")
603
+ def _reward_grid(self, state: EnvState, object_state: ObjectState) -> jax.Array:
604
+ # Compute reward at each grid position
605
+ fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
606
+
607
+ def compute_reward(obj_id, params):
608
+ return jax.lax.cond(
609
+ obj_id > jnp.array(0, dtype=ID_DTYPE),
610
+ lambda: jax.lax.switch(
611
+ obj_id.astype(jnp.int32),
612
+ self.reward_fns,
613
+ state.time,
614
+ fixed_key,
615
+ params.astype(jnp.float32),
616
+ ),
617
+ lambda: 0.0,
618
+ )
619
+
620
+ reward_grid = jax.vmap(jax.vmap(compute_reward))(
621
+ object_state.object_id.astype(ID_DTYPE),
622
+ object_state.state_params.astype(PARAM_DTYPE),
623
+ )
624
+ return reward_grid
625
+
626
+ def step_env(
627
+ self,
628
+ key: jax.Array,
629
+ state: EnvState,
630
+ action: Union[int, float, jax.Array],
631
+ params: EnvParams,
632
+ ) -> Tuple[jax.Array, EnvState, jax.Array, jax.Array, Dict[Any, Any]]:
633
+ """Perform single timestep state transition."""
634
+ current_objects = state.object_state.object_id
635
+ pos, new_pos = self._move_agent(state, action)
636
+
637
+ with jax.named_scope("compute_reward"):
638
+ # 2. HANDLE COLLISIONS AND REWARDS
639
+ obj_at_pos = current_objects[pos[1], pos[0]]
640
+ is_collectable = self.object_collectable[obj_at_pos]
641
+ should_collect = is_collectable & (
642
+ obj_at_pos > jnp.array(0, dtype=ID_DTYPE)
643
+ )
644
+
645
+ # Handle digestion: add reward to buffer if collected
646
+ digestion_buffer = state.digestion_buffer
647
+ key, reward_subkey = jax.random.split(key)
648
+
649
+ reward, digestion_buffer = self._compute_reward(
650
+ state, pos, key, should_collect, digestion_buffer, obj_at_pos
651
+ )
652
+
653
+ object_state = self._respawn_logic(
654
+ state, pos, key, current_objects, is_collectable, obj_at_pos
655
+ )
656
+
657
+ # 3.5. HANDLE OBJECT EXPIRY
658
+ # Only process expiry if there are objects that can expire
659
+ key, object_state = self.expire_objects(key, state, object_state)
660
+
661
+ # 3.6. HANDLE DYNAMIC BIOME CONSUMPTION AND RESPAWNING
662
+ object_state, biome_state = self._dynamic_biomes(
663
+ state, pos, key, object_state, should_collect
664
+ )
531
665
 
532
666
  info = {"discount": self.discount(state, params)}
533
667
  temperatures = jnp.zeros(len(self.objects))
@@ -538,33 +672,40 @@ class ForagaxEnv(environment.Environment):
538
672
  )
539
673
  info["temperatures"] = temperatures
540
674
  info["biome_id"] = object_state.biome_id[pos[1], pos[0]]
541
- info["object_collected_id"] = jax.lax.select(should_collect, obj_at_pos, -1)
675
+ info["object_collected_id"] = jnp.where(
676
+ should_collect,
677
+ obj_at_pos.astype(ID_DTYPE),
678
+ jnp.array(-1, dtype=ID_DTYPE),
679
+ )
542
680
 
543
681
  # 4. UPDATE STATE
682
+ # Ensure all fields have canonical dtypes for consistency (e.g., for gymnax step selection)
683
+ object_state = object_state.replace(
684
+ object_id=object_state.object_id.astype(ID_DTYPE),
685
+ respawn_timer=object_state.respawn_timer.astype(TIMER_DTYPE),
686
+ respawn_object_id=object_state.respawn_object_id.astype(ID_DTYPE),
687
+ spawn_time=object_state.spawn_time.astype(TIME_DTYPE),
688
+ color=object_state.color.astype(COLOR_DTYPE),
689
+ generation=object_state.generation.astype(ID_DTYPE),
690
+ state_params=object_state.state_params.astype(PARAM_DTYPE),
691
+ biome_id=object_state.biome_id.astype(BIOME_ID_DTYPE),
692
+ )
693
+ biome_state = biome_state.replace(
694
+ consumption_count=biome_state.consumption_count.astype(ID_DTYPE),
695
+ total_objects=biome_state.total_objects.astype(ID_DTYPE),
696
+ generation=biome_state.generation.astype(ID_DTYPE),
697
+ )
698
+
544
699
  state = EnvState(
545
- pos=pos,
546
- time=state.time + 1,
547
- digestion_buffer=digestion_buffer,
700
+ pos=pos.astype(jnp.int32),
701
+ time=jnp.array(state.time + 1, dtype=jnp.int32),
702
+ digestion_buffer=digestion_buffer.astype(REWARD_DTYPE),
548
703
  object_state=object_state,
549
704
  biome_state=biome_state,
550
705
  )
551
706
 
552
- # Compute reward at each grid position
553
- fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
554
-
555
- def compute_reward(obj_id, params):
556
- return jax.lax.cond(
557
- obj_id > 0,
558
- lambda: jax.lax.switch(
559
- obj_id, self.reward_fns, state.time, fixed_key, params
560
- ),
561
- lambda: 0.0,
562
- )
563
-
564
- reward_grid = jax.vmap(jax.vmap(compute_reward))(
565
- object_state.object_id, object_state.state_params
566
- )
567
- info["rewards"] = reward_grid
707
+ reward_grid = self._reward_grid(state, object_state)
708
+ info["rewards"] = reward_grid.astype(jnp.float16)
568
709
 
569
710
  done = self.is_terminal(state, params)
570
711
  return (
@@ -575,6 +716,7 @@ class ForagaxEnv(environment.Environment):
575
716
  info,
576
717
  )
577
718
 
719
+ @partial(jax.named_call, name="expire_objects")
578
720
  def expire_objects(
579
721
  self, key, state, object_state: ObjectState
580
722
  ) -> Tuple[jax.Array, ObjectState]:
@@ -591,8 +733,8 @@ class ForagaxEnv(environment.Environment):
591
733
  # Check if object should expire (age >= expiry_time and expiry_time >= 0)
592
734
  should_expire = (
593
735
  (object_ages >= expiry_times)
594
- & (expiry_times >= 0)
595
- & (current_objects_for_expiry > 0)
736
+ & (expiry_times >= jnp.array(0, dtype=TIME_DTYPE))
737
+ & (current_objects_for_expiry > jnp.array(0, dtype=ID_DTYPE))
596
738
  )
597
739
 
598
740
  # Only process expiry if there are actually objects to expire
@@ -627,8 +769,8 @@ class ForagaxEnv(environment.Environment):
627
769
  )
628
770
  timer_countdown = jax.lax.cond(
629
771
  exp_delay == jnp.iinfo(jnp.int32).max,
630
- lambda: 0,
631
- lambda: exp_delay + 1,
772
+ lambda: jnp.array(0, dtype=TIMER_DTYPE),
773
+ lambda: (exp_delay + 1).astype(TIMER_DTYPE),
632
774
  )
633
775
 
634
776
  respawn_random = self.object_random_respawn[obj_id]
@@ -688,147 +830,189 @@ class ForagaxEnv(environment.Environment):
688
830
  biome_state.consumption_count >= self.biome_consumption_threshold
689
831
  )
690
832
 
691
- # Split key for all biomes in parallel
692
- key, subkey = jax.random.split(key)
693
- biome_keys = jax.random.split(subkey, num_biomes)
833
+ any_respawn = jnp.any(should_respawn)
694
834
 
695
- # Compute all new spawns in parallel using vmap for random, switch for deterministic
696
- if self.deterministic_spawn:
697
- # Use switch to dispatch to concrete biome spawns for deterministic
698
- def make_spawn_fn(biome_idx):
699
- def spawn_fn(key):
700
- return self._spawn_biome_objects(biome_idx, key, deterministic=True)
835
+ def do_respawn(args):
836
+ (
837
+ object_state,
838
+ biome_state,
839
+ should_respawn,
840
+ key,
841
+ ) = args
842
+ # Split key for all biomes in parallel
843
+ key, subkey = jax.random.split(key)
844
+ biome_keys = jax.random.split(subkey, num_biomes)
845
+
846
+ # Compute all new spawns in parallel using vmap for random, switch for deterministic
847
+ if self.deterministic_spawn:
848
+ # Use switch to dispatch to concrete biome spawns for deterministic
849
+ def make_spawn_fn(biome_idx):
850
+ def spawn_fn(key):
851
+ return self._spawn_biome_objects(
852
+ biome_idx, key, deterministic=True
853
+ )
701
854
 
702
- return spawn_fn
855
+ return spawn_fn
703
856
 
704
- spawn_fns = [make_spawn_fn(idx) for idx in range(num_biomes)]
857
+ spawn_fns = [make_spawn_fn(idx) for idx in range(num_biomes)]
705
858
 
706
- # Apply switch for each biome
707
- all_new_objects_list = []
708
- all_new_colors_list = []
709
- all_new_params_list = []
710
- for i in range(num_biomes):
711
- obj, col, par = jax.lax.switch(i, spawn_fns, biome_keys[i])
712
- all_new_objects_list.append(obj)
713
- all_new_colors_list.append(col)
714
- all_new_params_list.append(par)
715
-
716
- all_new_objects = jnp.stack(all_new_objects_list)
717
- all_new_colors = jnp.stack(all_new_colors_list)
718
- all_new_params = jnp.stack(all_new_params_list)
719
- else:
720
- # Random spawn works with vmap
721
- all_new_objects, all_new_colors, all_new_params = jax.vmap(
722
- lambda i, k: self._spawn_biome_objects(i, k, deterministic=False)
723
- )(jnp.arange(num_biomes), biome_keys)
724
-
725
- # Initialize updated grids
726
- new_obj_id = object_state.object_id
727
- new_color = object_state.color
728
- new_params = object_state.state_params
729
- new_spawn = object_state.spawn_time
730
- new_gen = object_state.generation
731
-
732
- # Update biome state
733
- new_consumption_count = jnp.where(
734
- should_respawn, 0, biome_state.consumption_count
735
- )
736
- new_generation = biome_state.generation + should_respawn.astype(int)
859
+ # Apply switch for each biome
860
+ all_new_objects_list = []
861
+ all_new_colors_list = []
862
+ all_new_params_list = []
863
+ for i in range(num_biomes):
864
+ obj, col, par = jax.lax.switch(i, spawn_fns, biome_keys[i])
865
+ all_new_objects_list.append(obj)
866
+ all_new_colors_list.append(col)
867
+ all_new_params_list.append(par)
868
+
869
+ all_new_objects = jnp.stack(all_new_objects_list)
870
+ all_new_colors = jnp.stack(all_new_colors_list)
871
+ all_new_params = jnp.stack(all_new_params_list)
872
+ else:
873
+ # Random spawn works with vmap
874
+ all_new_objects, all_new_colors, all_new_params = jax.vmap(
875
+ lambda i, k: self._spawn_biome_objects(i, k, deterministic=False)
876
+ )(jnp.arange(num_biomes), biome_keys)
877
+
878
+ # Initialize updated grids
879
+ new_obj_id = object_state.object_id
880
+ new_color = object_state.color
881
+ new_params = object_state.state_params
882
+ new_spawn = object_state.spawn_time
883
+ new_gen = object_state.generation
884
+
885
+ # Update biome state
886
+ new_consumption_count = jnp.where(
887
+ should_respawn,
888
+ jnp.array(0, dtype=ID_DTYPE),
889
+ biome_state.consumption_count,
890
+ )
891
+ new_generation = biome_state.generation + should_respawn.astype(ID_DTYPE)
737
892
 
738
- # Compute new total objects for respawning biomes
739
- def count_objects(i):
740
- return jnp.sum((all_new_objects[i] > 0) & self.biome_masks_array[i])
893
+ # Compute new total objects for respawning biomes
894
+ def count_objects(i):
895
+ return jnp.sum(
896
+ (all_new_objects[i] > 0) & self.biome_masks_array[i], dtype=ID_DTYPE
897
+ )
741
898
 
742
- new_object_counts = jax.vmap(count_objects)(jnp.arange(num_biomes))
743
- new_total_objects = jnp.where(
744
- should_respawn, new_object_counts, biome_state.total_objects
745
- )
899
+ new_object_counts = jax.vmap(count_objects)(jnp.arange(num_biomes))
900
+ new_total_objects = jnp.where(
901
+ should_respawn, new_object_counts, biome_state.total_objects
902
+ )
746
903
 
747
- new_biome_state = BiomeState(
748
- consumption_count=new_consumption_count,
749
- total_objects=new_total_objects,
750
- generation=new_generation,
751
- )
904
+ new_biome_state = BiomeState(
905
+ consumption_count=new_consumption_count,
906
+ total_objects=new_total_objects,
907
+ generation=new_generation,
908
+ )
752
909
 
753
- # Update grids for respawning biomes
754
- for i in range(num_biomes):
755
- biome_mask = self.biome_masks_array[i]
756
- new_gen_value = new_biome_state.generation[i]
910
+ # Update grids for respawning biomes
911
+ for i in range(num_biomes):
912
+ biome_mask = self.biome_masks_array[i]
913
+ new_gen_value = new_biome_state.generation[i]
757
914
 
758
- # Update mask: biome area AND needs respawn
759
- should_update = biome_mask & should_respawn[i][..., None]
915
+ # Update mask: biome area AND needs respawn
916
+ should_update = biome_mask & should_respawn[i][..., None]
760
917
 
761
- # 1. Merge: Overwrite with new objects if present, otherwise keep existing
762
- # If the new spawn has an object, we take it. If it's empty, we keep whatever was there.
763
- new_spawn_valid = all_new_objects[i] > 0
918
+ # 1. Merge: Overwrite with new objects if present, otherwise keep existing
919
+ new_spawn_valid = all_new_objects[i] > 0
764
920
 
765
- merged_objs = jnp.where(new_spawn_valid, all_new_objects[i], new_obj_id)
766
- merged_colors = jnp.where(
767
- new_spawn_valid[..., None], all_new_colors[i], new_color
768
- )
769
- merged_params = jnp.where(
770
- new_spawn_valid[..., None], all_new_params[i], new_params
771
- )
921
+ merged_objs = jnp.where(new_spawn_valid, all_new_objects[i], new_obj_id)
922
+ merged_colors = jnp.where(
923
+ new_spawn_valid[..., None], all_new_colors[i], new_color
924
+ )
925
+ merged_params = jnp.where(
926
+ new_spawn_valid[..., None], all_new_params[i], new_params
927
+ )
772
928
 
773
- # For generation/spawn time, update only if we took the NEW object
774
- merged_gen = jnp.where(new_spawn_valid, new_gen_value, new_gen)
775
- merged_spawn = jnp.where(new_spawn_valid, current_time, new_spawn)
929
+ merged_gen = jnp.where(new_spawn_valid, new_gen_value, new_gen)
930
+ merged_spawn = jnp.where(new_spawn_valid, current_time, new_spawn)
776
931
 
777
- # 2. Apply dropout to the MERGED result (only where we are allowed to update)
778
- if self.dynamic_biome_spawn_empty > 0:
779
- key, dropout_key = jax.random.split(key)
780
- # Dropout applies only to cells that are both in the biome AND need updating
781
- keep_mask = jax.random.bernoulli(
782
- dropout_key, 1.0 - self.dynamic_biome_spawn_empty, merged_objs.shape
932
+ if self.dynamic_biome_spawn_empty > 0:
933
+ key, dropout_key = jax.random.split(key)
934
+ keep_mask = jax.random.bernoulli(
935
+ dropout_key,
936
+ 1.0 - self.dynamic_biome_spawn_empty,
937
+ merged_objs.shape,
938
+ )
939
+ dropout_mask = should_update & keep_mask
940
+ final_objs = jnp.where(
941
+ dropout_mask, merged_objs, jnp.array(0, dtype=ID_DTYPE)
942
+ )
943
+ final_colors = jnp.where(
944
+ dropout_mask[..., None],
945
+ merged_colors,
946
+ jnp.array(0, dtype=COLOR_DTYPE),
947
+ )
948
+ final_params = jnp.where(
949
+ dropout_mask[..., None],
950
+ merged_params,
951
+ jnp.array(0, dtype=PARAM_DTYPE),
952
+ )
953
+ final_gen = jnp.where(
954
+ dropout_mask, merged_gen, jnp.array(0, dtype=ID_DTYPE)
955
+ )
956
+ final_spawn = jnp.where(
957
+ dropout_mask, merged_spawn, jnp.array(0, dtype=TIME_DTYPE)
958
+ )
959
+ else:
960
+ final_objs = merged_objs
961
+ final_colors = merged_colors
962
+ final_params = merged_params
963
+ final_gen = merged_gen
964
+ final_spawn = merged_spawn
965
+
966
+ new_obj_id = jnp.where(should_update, final_objs, new_obj_id)
967
+ new_color = jnp.where(should_update[..., None], final_colors, new_color)
968
+ new_params = jnp.where(
969
+ should_update[..., None], final_params, new_params
783
970
  )
784
- # Only apply dropout where should_update is true; elsewhere, keep merged_objs
785
- dropout_mask = should_update & keep_mask
786
- # Apply dropout only to the merged result and associated metadata
787
- final_objs = jnp.where(dropout_mask, merged_objs, 0)
788
- final_colors = jnp.where(dropout_mask[..., None], merged_colors, 0)
789
- final_params = jnp.where(dropout_mask[..., None], merged_params, 0)
790
- final_gen = jnp.where(dropout_mask, merged_gen, 0)
791
- final_spawn = jnp.where(dropout_mask, merged_spawn, 0)
792
- else:
793
- final_objs = merged_objs
794
- final_colors = merged_colors
795
- final_params = merged_params
796
- final_gen = merged_gen
797
- final_spawn = merged_spawn
798
-
799
- # 3. Write back: Only update where should_update is true
800
- new_obj_id = jnp.where(should_update, final_objs, new_obj_id)
801
- new_color = jnp.where(should_update[..., None], final_colors, new_color)
802
- new_params = jnp.where(should_update[..., None], final_params, new_params)
803
- new_gen = jnp.where(should_update, final_gen, new_gen)
804
- new_spawn = jnp.where(should_update, final_spawn, new_spawn)
805
-
806
- # Clear timers in respawning biomes
807
- new_respawn_timer = object_state.respawn_timer
808
- new_respawn_object_id = object_state.respawn_object_id
809
- for i in range(num_biomes):
810
- biome_mask = self.biome_masks_array[i]
811
- should_clear = biome_mask & should_respawn[i][..., None]
812
- new_respawn_timer = jnp.where(should_clear, 0, new_respawn_timer)
813
- new_respawn_object_id = jnp.where(should_clear, 0, new_respawn_object_id)
971
+ new_gen = jnp.where(should_update, final_gen, new_gen)
972
+ new_spawn = jnp.where(should_update, final_spawn, new_spawn)
814
973
 
815
- object_state = object_state.replace(
816
- object_id=new_obj_id,
817
- respawn_timer=new_respawn_timer,
818
- respawn_object_id=new_respawn_object_id,
819
- color=new_color,
820
- state_params=new_params,
821
- generation=new_gen,
822
- spawn_time=new_spawn,
974
+ # Clear timers in respawning biomes
975
+ new_respawn_timer = object_state.respawn_timer
976
+ new_respawn_object_id = object_state.respawn_object_id
977
+ for i in range(num_biomes):
978
+ biome_mask = self.biome_masks_array[i]
979
+ should_clear = biome_mask & should_respawn[i][..., None]
980
+ new_respawn_timer = jnp.where(
981
+ should_clear, jnp.array(0, dtype=TIMER_DTYPE), new_respawn_timer
982
+ )
983
+ new_respawn_object_id = jnp.where(
984
+ should_clear, jnp.array(0, dtype=ID_DTYPE), new_respawn_object_id
985
+ )
986
+
987
+ new_object_state = object_state.replace(
988
+ object_id=new_obj_id,
989
+ respawn_timer=new_respawn_timer,
990
+ respawn_object_id=new_respawn_object_id,
991
+ color=new_color,
992
+ state_params=new_params,
993
+ generation=new_gen,
994
+ spawn_time=new_spawn,
995
+ )
996
+ return new_object_state, new_biome_state, key
997
+
998
+ def no_respawn(args):
999
+ object_state, biome_state, _, key = args
1000
+ return object_state, biome_state, key
1001
+
1002
+ object_state, biome_state, key = jax.lax.cond(
1003
+ any_respawn,
1004
+ do_respawn,
1005
+ no_respawn,
1006
+ (object_state, biome_state, should_respawn, key),
823
1007
  )
824
1008
 
825
- return object_state, new_biome_state, key
1009
+ return object_state, biome_state, key
826
1010
 
827
1011
  def reset_env(
828
1012
  self, key: jax.Array, params: EnvParams
829
1013
  ) -> Tuple[jax.Array, EnvState]:
830
1014
  """Reset environment state."""
831
- num_object_params = 2 + 2 * self.num_fourier_terms
1015
+ num_object_params = 3 + 2 * self.num_fourier_terms
832
1016
  object_state = ObjectState.create_empty(self.size, num_object_params)
833
1017
 
834
1018
  key, iter_key = jax.random.split(key)
@@ -840,7 +1024,9 @@ class ForagaxEnv(environment.Environment):
840
1024
 
841
1025
  # Set biome_id
842
1026
  object_state = object_state.replace(
843
- biome_id=jnp.where(mask, i, object_state.biome_id)
1027
+ biome_id=jnp.where(
1028
+ mask, jnp.array(i, dtype=BIOME_ID_DTYPE), object_state.biome_id
1029
+ )
844
1030
  )
845
1031
 
846
1032
  # Use unified spawn method
@@ -862,31 +1048,45 @@ class ForagaxEnv(environment.Environment):
862
1048
 
863
1049
  # Initialize biome consumption tracking
864
1050
  num_biomes = self.biome_object_frequencies.shape[0]
865
- biome_consumption_count = jnp.zeros(num_biomes, dtype=int)
866
- biome_total_objects = jnp.zeros(num_biomes, dtype=int)
1051
+ biome_consumption_count = jnp.zeros(num_biomes, dtype=ID_DTYPE)
1052
+ biome_total_objects = jnp.zeros(num_biomes, dtype=ID_DTYPE)
867
1053
 
868
1054
  # Count objects in each biome
869
1055
  for i in range(num_biomes):
870
1056
  mask = self.biome_masks[i]
871
1057
  # Count non-empty objects (object_id > 0)
872
- total = jnp.sum((object_state.object_id > 0) & mask)
1058
+ total = jnp.sum((object_state.object_id > 0) & mask, dtype=ID_DTYPE)
873
1059
  biome_total_objects = biome_total_objects.at[i].set(total)
874
1060
 
875
- biome_generation = jnp.zeros(num_biomes, dtype=int)
1061
+ biome_generation = jnp.zeros(num_biomes, dtype=ID_DTYPE)
1062
+
1063
+ # Final state cleanup to ensure type consistency
1064
+ object_state = object_state.replace(
1065
+ object_id=object_state.object_id.astype(ID_DTYPE),
1066
+ respawn_timer=object_state.respawn_timer.astype(TIMER_DTYPE),
1067
+ respawn_object_id=object_state.respawn_object_id.astype(ID_DTYPE),
1068
+ spawn_time=object_state.spawn_time.astype(TIME_DTYPE),
1069
+ color=object_state.color.astype(COLOR_DTYPE),
1070
+ generation=object_state.generation.astype(ID_DTYPE),
1071
+ state_params=object_state.state_params.astype(PARAM_DTYPE),
1072
+ biome_id=object_state.biome_id.astype(BIOME_ID_DTYPE),
1073
+ )
876
1074
 
877
1075
  state = EnvState(
878
- pos=agent_pos,
879
- time=0,
880
- digestion_buffer=jnp.zeros((self.max_reward_delay,)),
1076
+ pos=agent_pos.astype(jnp.int32),
1077
+ time=jnp.array(0, dtype=jnp.int32),
1078
+ digestion_buffer=jnp.zeros((self.max_reward_delay,), dtype=REWARD_DTYPE),
881
1079
  object_state=object_state,
882
1080
  biome_state=BiomeState(
883
- consumption_count=biome_consumption_count,
884
- total_objects=biome_total_objects,
885
- generation=biome_generation,
1081
+ consumption_count=biome_consumption_count.astype(ID_DTYPE),
1082
+ total_objects=biome_total_objects.astype(ID_DTYPE),
1083
+ generation=biome_generation.astype(ID_DTYPE),
886
1084
  ),
887
1085
  )
888
1086
 
889
- return self.get_obs(state, params), state
1087
+ return jax.lax.stop_gradient(
1088
+ self.get_obs(state, params)
1089
+ ), jax.lax.stop_gradient(state)
890
1090
 
891
1091
  def _spawn_biome_objects(
892
1092
  self,
@@ -918,16 +1118,17 @@ class ForagaxEnv(environment.Environment):
918
1118
  biome_size = int(self.biome_sizes[biome_idx])
919
1119
 
920
1120
  grid = jnp.linspace(0, 1, biome_size, endpoint=False)
921
- biome_objects_flat = len(biome_freqs) - jnp.searchsorted(
922
- jnp.cumsum(biome_freqs[::-1]), grid, side="right"
923
- )
1121
+ biome_objects_flat = (
1122
+ len(biome_freqs)
1123
+ - jnp.searchsorted(jnp.cumsum(biome_freqs[::-1]), grid, side="right")
1124
+ ).astype(ID_DTYPE)
924
1125
  biome_objects_flat = jax.random.permutation(spawn_key, biome_objects_flat)
925
1126
 
926
1127
  # Reshape to match biome dimensions (use concrete dimensions)
927
1128
  biome_objects = biome_objects_flat.reshape(biome_height, biome_width)
928
1129
 
929
1130
  # Place in full grid using slicing with static bounds
930
- object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
1131
+ object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=ID_DTYPE)
931
1132
  object_grid = object_grid.at[
932
1133
  biome_start[1] : biome_stop[1], biome_start[0] : biome_stop[0]
933
1134
  ].set(biome_objects)
@@ -941,10 +1142,10 @@ class ForagaxEnv(environment.Environment):
941
1142
  )
942
1143
  object_grid = (
943
1144
  jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
944
- )
1145
+ ).astype(ID_DTYPE)
945
1146
 
946
1147
  # Initialize color grid
947
- color_grid = jnp.full((self.size[1], self.size[0], 3), 255, dtype=jnp.uint8)
1148
+ color_grid = jnp.full((self.size[1], self.size[0], 3), 255, dtype=COLOR_DTYPE)
948
1149
 
949
1150
  # Sample ONE color per object type in this biome (not per instance)
950
1151
  # This gives objects of the same type the same color within a biome generation
@@ -956,9 +1157,9 @@ class ForagaxEnv(environment.Environment):
956
1157
  biome_object_colors = jax.random.randint(
957
1158
  color_key,
958
1159
  (num_actual_objects, 3),
959
- minval=0,
960
- maxval=256,
961
- dtype=jnp.uint8,
1160
+ minval=jnp.array(0, dtype=COLOR_DTYPE),
1161
+ maxval=jnp.array(256, dtype=jnp.int32), # randint range is often int32
1162
+ dtype=COLOR_DTYPE,
962
1163
  )
963
1164
 
964
1165
  # Assign colors based on object type (starting from index 1)
@@ -970,9 +1171,9 @@ class ForagaxEnv(environment.Environment):
970
1171
  color_grid = jnp.where(obj_mask[..., None], obj_color, color_grid)
971
1172
 
972
1173
  # Initialize parameters grid
973
- num_object_params = 2 + 2 * self.num_fourier_terms
1174
+ num_object_params = 3 + 2 * self.num_fourier_terms
974
1175
  params_grid = jnp.zeros(
975
- (self.size[1], self.size[0], num_object_params), dtype=jnp.float32
1176
+ (self.size[1], self.size[0], num_object_params), dtype=PARAM_DTYPE
976
1177
  )
977
1178
 
978
1179
  # Generate per-object parameters for each object type
@@ -999,7 +1200,9 @@ class ForagaxEnv(environment.Environment):
999
1200
 
1000
1201
  # Assign to all objects of this type in this biome
1001
1202
  obj_mask = (object_grid == obj_idx) & biome_mask
1002
- params_grid = jnp.where(obj_mask[..., None], obj_params, params_grid)
1203
+ params_grid = jnp.where(
1204
+ obj_mask[..., None], obj_params.astype(PARAM_DTYPE), params_grid
1205
+ )
1003
1206
 
1004
1207
  return object_grid, color_grid, params_grid
1005
1208
 
@@ -1024,7 +1227,7 @@ class ForagaxEnv(environment.Environment):
1024
1227
 
1025
1228
  def state_space(self, params: EnvParams) -> spaces.Dict:
1026
1229
  """State space of the environment."""
1027
- num_object_params = 2 + 2 * self.num_fourier_terms
1230
+ num_object_params = 3 + 2 * self.num_fourier_terms
1028
1231
  return spaces.Dict(
1029
1232
  {
1030
1233
  "pos": spaces.Box(0, max(self.size), (2,), int),
@@ -1147,7 +1350,7 @@ class ForagaxEnv(environment.Environment):
1147
1350
  aperture = jnp.where(out_of_bounds[..., None], padding_value, values)
1148
1351
  else:
1149
1352
  # Object ID grid: use PADDING index
1150
- padding_index = self.object_ids[-1]
1353
+ padding_index = self.object_ids[-1].astype(values.dtype)
1151
1354
  aperture = jnp.where(out_of_bounds, padding_index, values)
1152
1355
  else:
1153
1356
  aperture = values
@@ -1196,9 +1399,11 @@ class ForagaxEnv(environment.Environment):
1196
1399
  return jnp.zeros(obs_grid.shape + (0,), dtype=jnp.float32)
1197
1400
  # Map object IDs to color channel indices
1198
1401
  color_channels = jnp.where(
1199
- obs_grid == 0,
1200
- -1,
1201
- jnp.take(self.object_to_color_map, obs_grid - 1, axis=0),
1402
+ obs_grid == jnp.array(0, dtype=obs_grid.dtype),
1403
+ jnp.array(-1, dtype=jnp.int32),
1404
+ jnp.take(
1405
+ self.object_to_color_map, obs_grid.astype(jnp.int32) - 1, axis=0
1406
+ ),
1202
1407
  )
1203
1408
  obs = jax.nn.one_hot(color_channels, self.num_color_channels, axis=-1)
1204
1409
  return obs
@@ -1218,9 +1423,8 @@ class ForagaxEnv(environment.Environment):
1218
1423
  aperture_colors = color_aperture / 255.0
1219
1424
 
1220
1425
  # Mask empty cells (object_id == 0) to white
1221
- empty_mask = aperture == 0
1426
+ empty_mask = aperture == jnp.array(0, dtype=aperture.dtype)
1222
1427
  white_color = jnp.ones(aperture_colors.shape, dtype=jnp.float32)
1223
-
1224
1428
  obs = jnp.where(empty_mask[..., None], white_color, aperture_colors)
1225
1429
 
1226
1430
  return obs
@@ -1260,12 +1464,15 @@ class ForagaxEnv(environment.Environment):
1260
1464
 
1261
1465
  return spaces.Box(0, 1, obs_shape, float)
1262
1466
 
1263
- def _compute_reward_grid(self, state: EnvState) -> jax.Array:
1264
- """Compute rewards for all grid positions.
1467
+ def _compute_reward_grid(
1468
+ self, state: EnvState, object_id=None, state_params=None
1469
+ ) -> jax.Array:
1470
+ """Compute rewards for given positions. If no grid provided, uses full world."""
1471
+ if object_id is None:
1472
+ object_id = state.object_state.object_id
1473
+ if state_params is None:
1474
+ state_params = state.object_state.state_params
1265
1475
 
1266
- Returns:
1267
- Array of shape (H, W) with reward values for each cell
1268
- """
1269
1476
  fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
1270
1477
 
1271
1478
  def compute_reward(obj_id, params):
@@ -1277,9 +1484,7 @@ class ForagaxEnv(environment.Environment):
1277
1484
  lambda: 0.0,
1278
1485
  )
1279
1486
 
1280
- reward_grid = jax.vmap(jax.vmap(compute_reward))(
1281
- state.object_state.object_id, state.object_state.state_params
1282
- )
1487
+ reward_grid = jax.vmap(jax.vmap(compute_reward))(object_id, state_params)
1283
1488
  return reward_grid
1284
1489
 
1285
1490
  def _reward_to_color(self, reward: jax.Array) -> jax.Array:
@@ -1369,172 +1574,87 @@ class ForagaxEnv(environment.Environment):
1369
1574
 
1370
1575
  img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
1371
1576
 
1372
- if is_reward_mode:
1373
- # Scale image by 3 to create space for reward visualization
1374
- img = jax.image.resize(
1375
- img,
1376
- (self.size[1] * 3, self.size[0] * 3, 3),
1377
- jax.image.ResizeMethod.NEAREST,
1378
- )
1379
-
1380
- # Compute rewards for all cells
1381
- reward_grid = self._compute_reward_grid(state)
1382
-
1383
- # Convert rewards to colors
1384
- reward_colors = self._reward_to_color(reward_grid)
1385
-
1386
- # Resize reward colors to match 3x scale and place in middle cells
1387
- # We need to place reward colors at positions (i*3+1, j*3+1) for each (i,j)
1388
- # Create index arrays for middle cells
1389
- i_indices = jnp.arange(self.size[1])[:, None] * 3 + 1
1390
- j_indices = jnp.arange(self.size[0])[None, :] * 3 + 1
1391
-
1392
- # Broadcast and set middle cells
1393
- img = img.at[i_indices, j_indices].set(reward_colors)
1394
-
1395
- # Tint the agent's aperture
1577
+ # Define constants for all world modes
1578
+ alpha = 0.2
1396
1579
  y_coords, x_coords, y_coords_adj, x_coords_adj = (
1397
1580
  self._compute_aperture_coordinates(state.pos)
1398
1581
  )
1399
1582
 
1400
- alpha = 0.2
1401
- agent_color = jnp.array(AGENT.color)
1402
-
1403
1583
  if is_reward_mode:
1404
- # For reward mode, we need to adjust coordinates for 3x scaled image
1405
- if self.nowrap:
1406
- # Create tint mask for 3x scaled image
1407
- tint_mask = jnp.zeros(
1408
- (self.size[1] * 3, self.size[0] * 3), dtype=bool
1409
- )
1410
-
1411
- # For each aperture cell, tint all 9 cells in its 3x3 block
1412
- # Create meshgrid to get all aperture cell coordinates
1413
- y_grid, x_grid = jnp.meshgrid(
1414
- y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
1415
- )
1416
- y_flat = y_grid.flatten()
1417
- x_flat = x_grid.flatten()
1418
-
1419
- # Create offset arrays for 3x3 blocks
1420
- offsets = jnp.array(
1421
- [[dy, dx] for dy in range(3) for dx in range(3)]
1422
- )
1423
-
1424
- # For each aperture cell, expand to 9 cells
1425
- # We need to repeat each cell coordinate 9 times, then add offsets
1426
- num_aperture_cells = y_flat.size
1427
- y_base = jnp.repeat(
1428
- y_flat * 3, 9
1429
- ) # Repeat each y coord 9 times and scale by 3
1430
- x_base = jnp.repeat(
1431
- x_flat * 3, 9
1432
- ) # Repeat each x coord 9 times and scale by 3
1433
- y_offsets = jnp.tile(
1434
- offsets[:, 0], num_aperture_cells
1435
- ) # Tile all 9 offsets
1436
- x_offsets = jnp.tile(
1437
- offsets[:, 1], num_aperture_cells
1438
- ) # Tile all 9 offsets
1439
- y_expanded = y_base + y_offsets
1440
- x_expanded = x_base + x_offsets
1441
-
1442
- tint_mask = tint_mask.at[y_expanded, x_expanded].set(True)
1443
-
1444
- original_colors = img
1445
- tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1446
- img = jnp.where(tint_mask[..., None], tinted_colors, img)
1447
- else:
1448
- # Tint all 9 cells in each 3x3 block for aperture cells
1449
- # Create meshgrid to get all aperture cell coordinates
1450
- y_grid, x_grid = jnp.meshgrid(
1451
- y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
1452
- )
1453
- y_flat = y_grid.flatten()
1454
- x_flat = x_grid.flatten()
1584
+ # Construct 3x intermediate image
1585
+ # Each cell is 3x3, with reward color in center
1586
+ reward_grid = self._compute_reward_grid(state)
1587
+ reward_colors = self._reward_to_color(reward_grid)
1455
1588
 
1456
- # Create offset arrays for 3x3 blocks
1457
- offsets = jnp.array(
1458
- [[dy, dx] for dy in range(3) for dx in range(3)]
1459
- )
1589
+ # Each cell has its base color in 8 pixels and reward color in 1 (center)
1590
+ # Create a 3x3 pattern mask for center pixels
1591
+ cell_mask = jnp.array(
1592
+ [[False, False, False], [False, True, False], [False, False, False]]
1593
+ )
1594
+ grid_reward_mask = jnp.tile(cell_mask, (self.size[1], self.size[0]))
1460
1595
 
1461
- # For each aperture cell, expand to 9 cells
1462
- # We need to repeat each cell coordinate 9 times, then add offsets
1463
- num_aperture_cells = y_flat.size
1464
- y_base = jnp.repeat(
1465
- y_flat * 3, 9
1466
- ) # Repeat each y coord 9 times and scale by 3
1467
- x_base = jnp.repeat(
1468
- x_flat * 3, 9
1469
- ) # Repeat each x coord 9 times and scale by 3
1470
- y_offsets = jnp.tile(
1471
- offsets[:, 0], num_aperture_cells
1472
- ) # Tile all 9 offsets
1473
- x_offsets = jnp.tile(
1474
- offsets[:, 1], num_aperture_cells
1475
- ) # Tile all 9 offsets
1476
- y_expanded = y_base + y_offsets
1477
- x_expanded = x_base + x_offsets
1478
-
1479
- # Get original colors and tint them
1480
- original_colors = img[y_expanded, x_expanded]
1481
- tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1482
- img = img.at[y_expanded, x_expanded].set(tinted_colors)
1483
-
1484
- # Agent color - set all 9 cells of the agent's 3x3 block
1485
- agent_y, agent_x = state.pos[1], state.pos[0]
1486
- agent_offsets = jnp.array(
1487
- [[dy, dx] for dy in range(3) for dx in range(3)]
1596
+ # Repeat base colors and rewards to 3x3
1597
+ base_img_x3 = jnp.repeat(jnp.repeat(img, 3, axis=0), 3, axis=1)
1598
+ reward_colors_x3 = jnp.repeat(
1599
+ jnp.repeat(reward_colors, 3, axis=0), 3, axis=1
1488
1600
  )
1489
- agent_y_cells = agent_y * 3 + agent_offsets[:, 0]
1490
- agent_x_cells = agent_x * 3 + agent_offsets[:, 1]
1491
- img = img.at[agent_y_cells, agent_x_cells].set(
1492
- jnp.array(AGENT.color, dtype=jnp.uint8)
1601
+
1602
+ # Composite base and reward colors
1603
+ img = jnp.where(
1604
+ grid_reward_mask[..., None], reward_colors_x3, base_img_x3
1493
1605
  )
1494
1606
 
1495
- # Scale by 8 to final size
1496
- img = jax.image.resize(
1497
- img,
1498
- (self.size[1] * 24, self.size[0] * 24, 3),
1499
- jax.image.ResizeMethod.NEAREST,
1607
+ # Tint the aperture region at 3x scale
1608
+ aperture_mask = jnp.zeros((self.size[1], self.size[0]), dtype=bool)
1609
+ aperture_mask = aperture_mask.at[y_coords_adj, x_coords_adj].set(True)
1610
+ aperture_mask_x3 = jnp.repeat(
1611
+ jnp.repeat(aperture_mask, 3, axis=0), 3, axis=1
1500
1612
  )
1613
+
1614
+ tinted_img = (
1615
+ (1.0 - alpha) * img.astype(jnp.float32)
1616
+ + alpha * self.agent_color_jax.astype(jnp.float32)
1617
+ ).astype(jnp.uint8)
1618
+ img = jnp.where(aperture_mask_x3[..., None], tinted_img, img)
1619
+
1620
+ # Set agent center block
1621
+ agent_mask = jnp.zeros((self.size[1], self.size[0]), dtype=bool)
1622
+ agent_mask = agent_mask.at[state.pos[1], state.pos[0]].set(True)
1623
+ agent_mask_x3 = jnp.repeat(jnp.repeat(agent_mask, 3, axis=0), 3, axis=1)
1624
+ img = jnp.where(agent_mask_x3[..., None], self.agent_color_jax, img)
1625
+
1626
+ # Final scale by 8 to get 24x
1627
+ img = jnp.repeat(jnp.repeat(img, 8, axis=0), 8, axis=1)
1501
1628
  else:
1502
1629
  # Standard rendering without reward visualization
1503
- if self.nowrap:
1504
- # Create tint mask: any in-bounds original position maps to a cell makes it tinted
1505
- tint_mask = jnp.zeros((self.size[1], self.size[0]), dtype=int)
1506
- tint_mask = tint_mask.at[y_coords_adj, x_coords_adj].set(1)
1507
- # Apply tint to masked positions
1508
- original_colors = img
1509
- tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1510
- img = jnp.where(tint_mask[..., None], tinted_colors, img)
1511
- else:
1512
- original_colors = img[y_coords_adj, x_coords_adj]
1513
- tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1514
- img = img.at[y_coords_adj, x_coords_adj].set(tinted_colors)
1630
+ aperture_mask = jnp.zeros((self.size[1], self.size[0]), dtype=bool)
1631
+ aperture_mask = aperture_mask.at[y_coords_adj, x_coords_adj].set(True)
1515
1632
 
1516
- # Agent color
1517
- img = img.at[state.pos[1], state.pos[0]].set(jnp.array(AGENT.color))
1633
+ tinted_img = (
1634
+ (1.0 - alpha) * img.astype(jnp.float32)
1635
+ + alpha * self.agent_color_jax.astype(jnp.float32)
1636
+ ).astype(jnp.uint8)
1637
+ img = jnp.where(aperture_mask[..., None], tinted_img, img)
1518
1638
 
1519
- img = jax.image.resize(
1520
- img,
1521
- (self.size[1] * 24, self.size[0] * 24, 3),
1522
- jax.image.ResizeMethod.NEAREST,
1523
- )
1639
+ # Set agent
1640
+ img = img.at[state.pos[1], state.pos[0]].set(self.agent_color_jax)
1641
+ # Scale by 24
1642
+ img = jnp.repeat(jnp.repeat(img, 24, axis=0), 24, axis=1)
1524
1643
 
1525
1644
  if is_true_mode:
1526
- # Apply true object borders by overlaying true colors on border pixels
1527
- render_grid = state.object_state.object_id
1645
+ # Apply true object borders
1528
1646
  img = apply_true_borders(
1529
- img, render_grid, self.size, len(self.object_ids)
1647
+ img, state.object_state.object_id, self.size, len(self.object_ids)
1530
1648
  )
1531
1649
 
1532
- # Add grid lines for world mode
1533
- grid_color = jnp.zeros(3, dtype=jnp.uint8)
1534
- row_indices = jnp.arange(1, self.size[1]) * 24
1535
- col_indices = jnp.arange(1, self.size[0]) * 24
1536
- img = img.at[row_indices, :].set(grid_color)
1537
- img = img.at[:, col_indices].set(grid_color)
1650
+ # Add grid lines using masking instead of slice-setting
1651
+ row_grid = (jnp.arange(self.size[1] * 24) % 24) == 0
1652
+ col_grid = (jnp.arange(self.size[0] * 24) % 24) == 0
1653
+ # skip first rows/cols as they are borders or managed by caller
1654
+ row_grid = row_grid.at[0].set(False)
1655
+ col_grid = col_grid.at[0].set(False)
1656
+ grid_mask = row_grid[:, None] | col_grid[None, :]
1657
+ img = jnp.where(grid_mask[..., None], self.grid_color_jax, img)
1538
1658
 
1539
1659
  elif is_aperture_mode:
1540
1660
  obs_grid = state.object_state.object_id
@@ -1581,11 +1701,16 @@ class ForagaxEnv(environment.Environment):
1581
1701
  self._compute_aperture_coordinates(state.pos)
1582
1702
  )
1583
1703
 
1584
- # Get reward grid for the full world
1585
- full_reward_grid = self._compute_reward_grid(state)
1586
-
1587
- # Extract aperture rewards
1588
- aperture_rewards = full_reward_grid[y_coords_adj, x_coords_adj]
1704
+ # Get reward grid only for aperture region
1705
+ aperture_object_ids = state.object_state.object_id[
1706
+ y_coords_adj, x_coords_adj
1707
+ ]
1708
+ aperture_params = state.object_state.state_params[
1709
+ y_coords_adj, x_coords_adj
1710
+ ]
1711
+ aperture_rewards = self._compute_reward_grid(
1712
+ state, aperture_object_ids, aperture_params
1713
+ )
1589
1714
 
1590
1715
  # Convert rewards to colors
1591
1716
  reward_colors = self._reward_to_color(aperture_rewards)