continual-foragax 0.39.0__py3-none-any.whl → 0.41.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foragax/env.py CHANGED
@@ -6,7 +6,7 @@ Source: https://github.com/andnp/Forager
6
6
  from dataclasses import dataclass
7
7
  from enum import IntEnum
8
8
  from functools import partial
9
- from typing import Any, Dict, Optional, Tuple, Union
9
+ from typing import Any, Dict, Tuple, Union
10
10
 
11
11
  import jax
12
12
  import jax.numpy as jnp
@@ -26,6 +26,15 @@ from foragax.rendering import apply_true_borders
26
26
  from foragax.weather import get_temperature
27
27
 
28
28
 
29
+ ID_DTYPE = jnp.int32 # Object type ID (0 = empty, >0 = object type)
30
+ TIMER_DTYPE = jnp.int32 # Respawn countdown (0 = no timer, >0 = countdown)
31
+ TIME_DTYPE = jnp.int32 # Timesteps (spawn time, current time)
32
+ PARAM_DTYPE = jnp.float16 # Per-instance object parameters
33
+ COLOR_DTYPE = jnp.uint8 # RGB color channels (0-255)
34
+ BIOME_ID_DTYPE = jnp.int16 # Biome assignment for each cell
35
+ REWARD_DTYPE = jnp.float32 # Reward values
36
+
37
+
29
38
  class Actions(IntEnum):
30
39
  DOWN = 0
31
40
  RIGHT = 1
@@ -74,14 +83,14 @@ class ObjectState:
74
83
  """Create an empty ObjectState for the given grid size."""
75
84
  h, w = size[1], size[0]
76
85
  return cls(
77
- object_id=jnp.zeros((h, w), dtype=int),
78
- respawn_timer=jnp.zeros((h, w), dtype=int),
79
- respawn_object_id=jnp.zeros((h, w), dtype=int),
80
- spawn_time=jnp.zeros((h, w), dtype=int),
81
- color=jnp.full((h, w, 3), 255, dtype=jnp.uint8),
82
- generation=jnp.zeros((h, w), dtype=int),
83
- state_params=jnp.zeros((h, w, num_params), dtype=jnp.float32),
84
- biome_id=jnp.full((h, w), -1, dtype=int),
86
+ object_id=jnp.zeros((h, w), dtype=ID_DTYPE),
87
+ respawn_timer=jnp.zeros((h, w), dtype=TIMER_DTYPE),
88
+ respawn_object_id=jnp.zeros((h, w), dtype=ID_DTYPE),
89
+ spawn_time=jnp.zeros((h, w), dtype=TIME_DTYPE),
90
+ color=jnp.full((h, w, 3), 255, dtype=COLOR_DTYPE),
91
+ generation=jnp.zeros((h, w), dtype=ID_DTYPE),
92
+ state_params=jnp.zeros((h, w, num_params), dtype=PARAM_DTYPE),
93
+ biome_id=jnp.full((h, w), -1, dtype=BIOME_ID_DTYPE),
85
94
  )
86
95
 
87
96
 
@@ -128,7 +137,6 @@ class ForagaxEnv(environment.Environment):
128
137
  biomes: Tuple[Biome, ...] = (Biome(object_frequencies=()),),
129
138
  nowrap: bool = False,
130
139
  deterministic_spawn: bool = False,
131
- teleport_interval: Optional[int] = None,
132
140
  observation_type: str = "object",
133
141
  dynamic_biomes: bool = False,
134
142
  biome_consumption_threshold: float = 0.9,
@@ -154,7 +162,6 @@ class ForagaxEnv(environment.Environment):
154
162
  self.observation_type = observation_type
155
163
  self.nowrap = nowrap
156
164
  self.deterministic_spawn = deterministic_spawn
157
- self.teleport_interval = teleport_interval
158
165
  self.dynamic_biomes = dynamic_biomes
159
166
  self.biome_consumption_threshold = biome_consumption_threshold
160
167
  self.dynamic_biome_spawn_empty = dynamic_biome_spawn_empty
@@ -230,13 +237,13 @@ class ForagaxEnv(environment.Environment):
230
237
  # Create mask for the biome
231
238
  start = jax.lax.select(
232
239
  self.biome_starts[i, 0] == -1,
233
- jnp.array([0, 0]),
234
- self.biome_starts[i],
240
+ jnp.array([0, 0], dtype=jnp.int32),
241
+ self.biome_starts[i].astype(jnp.int32),
235
242
  )
236
243
  stop = jax.lax.select(
237
244
  self.biome_stops[i, 0] == -1,
238
- jnp.array(self.size),
239
- self.biome_stops[i],
245
+ jnp.array(self.size, dtype=jnp.int32),
246
+ self.biome_stops[i].astype(jnp.int32),
240
247
  )
241
248
  rows = jnp.arange(self.size[1])[:, None]
242
249
  cols = jnp.arange(self.size[0])
@@ -275,12 +282,18 @@ class ForagaxEnv(environment.Environment):
275
282
  # color_indices maps from object_id-1 to color_channel_index
276
283
  self.object_to_color_map = color_indices
277
284
 
285
+ # Rendering constants
286
+ self.agent_color_jax = jnp.array(AGENT.color, dtype=jnp.uint8)
287
+ self.white_color_jax = jnp.array([255, 255, 255], dtype=jnp.uint8)
288
+ self.grid_color_jax = jnp.zeros(3, dtype=jnp.uint8)
289
+
278
290
  @property
279
291
  def default_params(self) -> EnvParams:
280
292
  return EnvParams(
281
293
  max_steps_in_episode=None,
282
294
  )
283
295
 
296
+ @partial(jax.named_call, name="_place_timer")
284
297
  def _place_timer(
285
298
  self,
286
299
  object_state: ObjectState,
@@ -304,13 +317,24 @@ class ForagaxEnv(environment.Environment):
304
317
  Returns:
305
318
  Updated object_state with timer placed
306
319
  """
320
+ # Ensure inputs match ObjectState dtypes
321
+ object_type = jnp.array(object_type, dtype=ID_DTYPE)
322
+ timer_val = jnp.array(timer_val, dtype=TIMER_DTYPE)
323
+ y = jnp.array(y, dtype=jnp.int32)
324
+ x = jnp.array(x, dtype=jnp.int32)
307
325
 
308
326
  # Handle permanent removal (timer_val == 0)
309
327
  def place_empty():
310
328
  return object_state.replace(
311
- object_id=object_state.object_id.at[y, x].set(0),
312
- respawn_timer=object_state.respawn_timer.at[y, x].set(0),
313
- respawn_object_id=object_state.respawn_object_id.at[y, x].set(0),
329
+ object_id=object_state.object_id.at[y, x].set(
330
+ jnp.array(0, dtype=ID_DTYPE)
331
+ ),
332
+ respawn_timer=object_state.respawn_timer.at[y, x].set(
333
+ jnp.array(0, dtype=TIMER_DTYPE)
334
+ ),
335
+ respawn_object_id=object_state.respawn_object_id.at[y, x].set(
336
+ jnp.array(0, dtype=ID_DTYPE)
337
+ ),
314
338
  )
315
339
 
316
340
  # Handle timer placement
@@ -318,7 +342,9 @@ class ForagaxEnv(environment.Environment):
318
342
  # Non-random: place at original position
319
343
  def place_at_position():
320
344
  return object_state.replace(
321
- object_id=object_state.object_id.at[y, x].set(0),
345
+ object_id=object_state.object_id.at[y, x].set(
346
+ jnp.array(0, dtype=ID_DTYPE)
347
+ ),
322
348
  respawn_timer=object_state.respawn_timer.at[y, x].set(timer_val),
323
349
  respawn_object_id=object_state.respawn_object_id.at[y, x].set(
324
350
  object_type
@@ -328,9 +354,15 @@ class ForagaxEnv(environment.Environment):
328
354
  # Random: place at random location in same biome
329
355
  def place_randomly():
330
356
  # Clear the collected object's position
331
- new_object_id = object_state.object_id.at[y, x].set(0)
332
- new_respawn_timer = object_state.respawn_timer.at[y, x].set(0)
333
- new_respawn_object_id = object_state.respawn_object_id.at[y, x].set(0)
357
+ new_object_id = object_state.object_id.at[y, x].set(
358
+ jnp.array(0, dtype=ID_DTYPE)
359
+ )
360
+ new_respawn_timer = object_state.respawn_timer.at[y, x].set(
361
+ jnp.array(0, dtype=TIMER_DTYPE)
362
+ )
363
+ new_respawn_object_id = object_state.respawn_object_id.at[y, x].set(
364
+ jnp.array(0, dtype=ID_DTYPE)
365
+ )
334
366
 
335
367
  # Find valid spawn locations in the same biome
336
368
  biome_id = object_state.biome_id[y, x]
@@ -338,13 +370,15 @@ class ForagaxEnv(environment.Environment):
338
370
  empty_mask = new_object_id == 0
339
371
  no_timer_mask = new_respawn_timer == 0
340
372
  valid_spawn_mask = biome_mask & empty_mask & no_timer_mask
341
- num_valid_spawns = jnp.sum(valid_spawn_mask)
373
+ num_valid_spawns = jnp.sum(valid_spawn_mask, dtype=jnp.int32)
342
374
 
343
375
  y_indices, x_indices = jnp.nonzero(
344
376
  valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
345
377
  )
346
378
  valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
347
- random_idx = jax.random.randint(rand_key, (), 0, num_valid_spawns)
379
+ random_idx = jax.random.randint(
380
+ rand_key, (), jnp.array(0, dtype=jnp.int32), num_valid_spawns
381
+ )
348
382
  new_spawn_pos = valid_spawn_indices[random_idx]
349
383
 
350
384
  # Place timer at the new random position
@@ -365,17 +399,11 @@ class ForagaxEnv(environment.Environment):
365
399
 
366
400
  return jax.lax.cond(timer_val == 0, place_empty, place_timer)
367
401
 
368
- def step_env(
369
- self,
370
- key: jax.Array,
371
- state: EnvState,
372
- action: Union[int, float, jax.Array],
373
- params: EnvParams,
374
- ) -> Tuple[jax.Array, EnvState, jax.Array, jax.Array, Dict[Any, Any]]:
375
- """Perform single timestep state transition."""
402
+ @partial(jax.named_call, name="move_agent")
403
+ def _move_agent(
404
+ self, state: EnvState, action: Union[int, float, jax.Array]
405
+ ) -> Tuple[jax.Array, jax.Array]:
376
406
  current_objects = state.object_state.object_id
377
-
378
- # 1. UPDATE AGENT POSITION
379
407
  direction = DIRECTIONS[action]
380
408
  new_pos = state.pos + direction
381
409
 
@@ -388,45 +416,44 @@ class ForagaxEnv(environment.Environment):
388
416
 
389
417
  # Check for blocking objects
390
418
  obj_at_new_pos = current_objects[new_pos[1], new_pos[0]]
391
- is_blocking = self.object_blocking[obj_at_new_pos]
392
- pos = jax.lax.select(is_blocking, state.pos, new_pos)
419
+ is_blocking = self.object_blocking[obj_at_new_pos.astype(jnp.int32)]
420
+ pos = jnp.where(
421
+ is_blocking[..., None],
422
+ state.pos.astype(jnp.int32),
423
+ new_pos.astype(jnp.int32),
424
+ )
425
+ return pos, new_pos
393
426
 
394
- # Check for automatic teleport
395
- if self.teleport_interval is not None:
396
- should_teleport = jnp.mod(state.time + 1, self.teleport_interval) == 0
397
- else:
398
- should_teleport = False
399
-
400
- def teleport_fn():
401
- # Calculate squared distances from current position to each biome center
402
- diffs = self.biome_centers_jax - pos
403
- distances = jnp.sum(diffs**2, axis=1)
404
- # Find the index of the furthest biome center
405
- furthest_idx = jnp.argmax(distances)
406
- new_pos = self.biome_centers_jax[furthest_idx]
407
- return new_pos
408
-
409
- pos = jax.lax.cond(should_teleport, teleport_fn, lambda: pos)
410
-
411
- # 2. HANDLE COLLISIONS AND REWARDS
412
- obj_at_pos = current_objects[pos[1], pos[0]]
413
- is_collectable = self.object_collectable[obj_at_pos]
414
- should_collect = is_collectable & (obj_at_pos > 0)
415
-
416
- # Handle digestion: add reward to buffer if collected
417
- digestion_buffer = state.digestion_buffer
427
+ @partial(jax.named_call, name="compute_reward")
428
+ def _compute_reward(
429
+ self,
430
+ state: EnvState,
431
+ pos: jax.Array,
432
+ key: jax.Array,
433
+ should_collect: jax.Array,
434
+ digestion_buffer: jax.Array,
435
+ obj_at_pos: jax.Array,
436
+ ) -> Tuple[jax.Array, jax.Array]:
418
437
  key, reward_subkey = jax.random.split(key)
419
438
 
420
439
  object_params = state.object_state.state_params[pos[1], pos[0]]
421
440
  object_reward = jax.lax.switch(
422
- obj_at_pos, self.reward_fns, state.time, reward_subkey, object_params
441
+ obj_at_pos.astype(jnp.int32),
442
+ self.reward_fns,
443
+ state.time,
444
+ reward_subkey,
445
+ object_params.astype(jnp.float32),
423
446
  )
424
447
 
425
448
  key, digestion_subkey = jax.random.split(key)
426
449
  reward_delay = jax.lax.switch(
427
450
  obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
428
451
  )
429
- reward = jnp.where(should_collect & (reward_delay == 0), object_reward, 0.0)
452
+ reward = jnp.where(
453
+ should_collect & (reward_delay == jnp.array(0, dtype=jnp.int32)),
454
+ object_reward,
455
+ 0.0,
456
+ )
430
457
  if self.max_reward_delay > 0:
431
458
  # Add delayed rewards to buffer
432
459
  digestion_buffer = jax.lax.cond(
@@ -440,20 +467,33 @@ class ForagaxEnv(environment.Environment):
440
467
  current_index = state.time % self.max_reward_delay
441
468
  reward += digestion_buffer[current_index]
442
469
  digestion_buffer = digestion_buffer.at[current_index].set(0.0)
470
+ return reward, digestion_buffer
443
471
 
472
+ @partial(jax.named_call, name="respawn_logic")
473
+ def _respawn_logic(
474
+ self,
475
+ state: EnvState,
476
+ pos: jax.Array,
477
+ key: jax.Array,
478
+ current_objects: jax.Array,
479
+ is_collectable: jax.Array,
480
+ obj_at_pos: jax.Array,
481
+ ) -> ObjectState:
444
482
  # 3. HANDLE OBJECT COLLECTION AND RESPAWNING
445
483
  key, regen_subkey, rand_key = jax.random.split(key, 3)
446
484
 
447
485
  # Decrement respawn timers
448
- has_timer = state.object_state.respawn_timer > 0
486
+ has_timer = state.object_state.respawn_timer > jnp.array(0, dtype=TIMER_DTYPE)
449
487
  new_respawn_timer = jnp.where(
450
488
  has_timer,
451
- state.object_state.respawn_timer - 1,
489
+ state.object_state.respawn_timer - jnp.array(1, dtype=TIMER_DTYPE),
452
490
  state.object_state.respawn_timer,
453
491
  )
454
492
 
455
493
  # Track which cells have timers that just reached 0
456
- just_respawned = has_timer & (new_respawn_timer == 0)
494
+ just_respawned = has_timer & (
495
+ new_respawn_timer == jnp.array(0, dtype=TIMER_DTYPE)
496
+ )
457
497
 
458
498
  # Respawn objects where timer reached 0
459
499
  new_object_id = jnp.where(
@@ -465,7 +505,7 @@ class ForagaxEnv(environment.Environment):
465
505
  # Clear respawn_object_id for cells that just respawned
466
506
  new_respawn_object_id = jnp.where(
467
507
  just_respawned,
468
- 0,
508
+ jnp.array(0, dtype=ID_DTYPE),
469
509
  state.object_state.respawn_object_id,
470
510
  )
471
511
 
@@ -478,16 +518,19 @@ class ForagaxEnv(environment.Environment):
478
518
  regen_delay = jax.lax.switch(
479
519
  obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
480
520
  )
521
+ # Cast timer_countdown to match ObjectState.respawn_timer dtype
481
522
  timer_countdown = jax.lax.cond(
482
523
  regen_delay == jnp.iinfo(jnp.int32).max,
483
- lambda: 0, # No timer (permanent removal)
484
- lambda: regen_delay + 1, # Timer countdown
524
+ lambda: jnp.array(0, dtype=TIMER_DTYPE), # No timer (permanent removal)
525
+ lambda: (regen_delay + 1).astype(TIMER_DTYPE), # Timer countdown
485
526
  )
486
527
 
487
528
  # If collected, replace object with timer; otherwise, keep it
488
529
  val_at_pos = current_objects[pos[1], pos[0]]
489
530
  # Use original should_collect for consumption tracking
490
- should_collect_now = is_collectable & (val_at_pos > 0)
531
+ should_collect_now = is_collectable & (
532
+ val_at_pos > jnp.array(0, dtype=ID_DTYPE)
533
+ )
491
534
 
492
535
  # Create updated object state with new respawn_timer, object_id, and spawn_time
493
536
  object_state = state.object_state.replace(
@@ -511,12 +554,17 @@ class ForagaxEnv(environment.Environment):
511
554
  ),
512
555
  lambda: object_state,
513
556
  )
557
+ return object_state
514
558
 
515
- # 3.5. HANDLE OBJECT EXPIRY
516
- # Only process expiry if there are objects that can expire
517
- key, object_state = self.expire_objects(key, state, object_state)
518
-
519
- # 3.6. HANDLE DYNAMIC BIOME CONSUMPTION AND RESPAWNING
559
+ @partial(jax.named_call, name="dynamic_biomes")
560
+ def _dynamic_biomes(
561
+ self,
562
+ state: EnvState,
563
+ pos: jax.Array,
564
+ key: jax.Array,
565
+ object_state: ObjectState,
566
+ should_collect: jax.Array,
567
+ ) -> Tuple[ObjectState, BiomeState]:
520
568
  if self.dynamic_biomes:
521
569
  # Update consumption count if an object was collected
522
570
  # Only count if the object belongs to the current generation of its biome
@@ -528,7 +576,9 @@ class ForagaxEnv(environment.Environment):
528
576
  biome_consumption_count = state.biome_state.consumption_count
529
577
  biome_consumption_count = jax.lax.cond(
530
578
  should_collect & is_current_generation,
531
- lambda: biome_consumption_count.at[collected_biome_id].add(1),
579
+ lambda: biome_consumption_count.at[collected_biome_id].add(
580
+ jnp.array(1, dtype=ID_DTYPE)
581
+ ),
532
582
  lambda: biome_consumption_count,
533
583
  )
534
584
 
@@ -545,8 +595,73 @@ class ForagaxEnv(environment.Environment):
545
595
  state.time,
546
596
  respawn_key,
547
597
  )
598
+ return object_state, biome_state
548
599
  else:
549
- biome_state = state.biome_state
600
+ return object_state, state.biome_state
601
+
602
+ @partial(jax.named_call, name="reward_grid")
603
+ def _reward_grid(self, state: EnvState, object_state: ObjectState) -> jax.Array:
604
+ # Compute reward at each grid position
605
+ fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
606
+
607
+ def compute_reward(obj_id, params):
608
+ return jax.lax.cond(
609
+ obj_id > jnp.array(0, dtype=ID_DTYPE),
610
+ lambda: jax.lax.switch(
611
+ obj_id.astype(jnp.int32),
612
+ self.reward_fns,
613
+ state.time,
614
+ fixed_key,
615
+ params.astype(jnp.float32),
616
+ ),
617
+ lambda: 0.0,
618
+ )
619
+
620
+ reward_grid = jax.vmap(jax.vmap(compute_reward))(
621
+ object_state.object_id.astype(ID_DTYPE),
622
+ object_state.state_params.astype(PARAM_DTYPE),
623
+ )
624
+ return reward_grid
625
+
626
+ def step_env(
627
+ self,
628
+ key: jax.Array,
629
+ state: EnvState,
630
+ action: Union[int, float, jax.Array],
631
+ params: EnvParams,
632
+ ) -> Tuple[jax.Array, EnvState, jax.Array, jax.Array, Dict[Any, Any]]:
633
+ """Perform single timestep state transition."""
634
+ current_objects = state.object_state.object_id
635
+ pos, new_pos = self._move_agent(state, action)
636
+
637
+ with jax.named_scope("compute_reward"):
638
+ # 2. HANDLE COLLISIONS AND REWARDS
639
+ obj_at_pos = current_objects[pos[1], pos[0]]
640
+ is_collectable = self.object_collectable[obj_at_pos]
641
+ should_collect = is_collectable & (
642
+ obj_at_pos > jnp.array(0, dtype=ID_DTYPE)
643
+ )
644
+
645
+ # Handle digestion: add reward to buffer if collected
646
+ digestion_buffer = state.digestion_buffer
647
+ key, reward_subkey = jax.random.split(key)
648
+
649
+ reward, digestion_buffer = self._compute_reward(
650
+ state, pos, key, should_collect, digestion_buffer, obj_at_pos
651
+ )
652
+
653
+ object_state = self._respawn_logic(
654
+ state, pos, key, current_objects, is_collectable, obj_at_pos
655
+ )
656
+
657
+ # 3.5. HANDLE OBJECT EXPIRY
658
+ # Only process expiry if there are objects that can expire
659
+ key, object_state = self.expire_objects(key, state, object_state)
660
+
661
+ # 3.6. HANDLE DYNAMIC BIOME CONSUMPTION AND RESPAWNING
662
+ object_state, biome_state = self._dynamic_biomes(
663
+ state, pos, key, object_state, should_collect
664
+ )
550
665
 
551
666
  info = {"discount": self.discount(state, params)}
552
667
  temperatures = jnp.zeros(len(self.objects))
@@ -557,33 +672,40 @@ class ForagaxEnv(environment.Environment):
557
672
  )
558
673
  info["temperatures"] = temperatures
559
674
  info["biome_id"] = object_state.biome_id[pos[1], pos[0]]
560
- info["object_collected_id"] = jax.lax.select(should_collect, obj_at_pos, -1)
675
+ info["object_collected_id"] = jnp.where(
676
+ should_collect,
677
+ obj_at_pos.astype(ID_DTYPE),
678
+ jnp.array(-1, dtype=ID_DTYPE),
679
+ )
561
680
 
562
681
  # 4. UPDATE STATE
682
+ # Ensure all fields have canonical dtypes for consistency (e.g., for gymnax step selection)
683
+ object_state = object_state.replace(
684
+ object_id=object_state.object_id.astype(ID_DTYPE),
685
+ respawn_timer=object_state.respawn_timer.astype(TIMER_DTYPE),
686
+ respawn_object_id=object_state.respawn_object_id.astype(ID_DTYPE),
687
+ spawn_time=object_state.spawn_time.astype(TIME_DTYPE),
688
+ color=object_state.color.astype(COLOR_DTYPE),
689
+ generation=object_state.generation.astype(ID_DTYPE),
690
+ state_params=object_state.state_params.astype(PARAM_DTYPE),
691
+ biome_id=object_state.biome_id.astype(BIOME_ID_DTYPE),
692
+ )
693
+ biome_state = biome_state.replace(
694
+ consumption_count=biome_state.consumption_count.astype(ID_DTYPE),
695
+ total_objects=biome_state.total_objects.astype(ID_DTYPE),
696
+ generation=biome_state.generation.astype(ID_DTYPE),
697
+ )
698
+
563
699
  state = EnvState(
564
- pos=pos,
565
- time=state.time + 1,
566
- digestion_buffer=digestion_buffer,
700
+ pos=pos.astype(jnp.int32),
701
+ time=jnp.array(state.time + 1, dtype=jnp.int32),
702
+ digestion_buffer=digestion_buffer.astype(REWARD_DTYPE),
567
703
  object_state=object_state,
568
704
  biome_state=biome_state,
569
705
  )
570
706
 
571
- # Compute reward at each grid position
572
- fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
573
-
574
- def compute_reward(obj_id, params):
575
- return jax.lax.cond(
576
- obj_id > 0,
577
- lambda: jax.lax.switch(
578
- obj_id, self.reward_fns, state.time, fixed_key, params
579
- ),
580
- lambda: 0.0,
581
- )
582
-
583
- reward_grid = jax.vmap(jax.vmap(compute_reward))(
584
- object_state.object_id, object_state.state_params
585
- )
586
- info["rewards"] = reward_grid
707
+ reward_grid = self._reward_grid(state, object_state)
708
+ info["rewards"] = reward_grid.astype(jnp.float16)
587
709
 
588
710
  done = self.is_terminal(state, params)
589
711
  return (
@@ -594,6 +716,7 @@ class ForagaxEnv(environment.Environment):
594
716
  info,
595
717
  )
596
718
 
719
+ @partial(jax.named_call, name="expire_objects")
597
720
  def expire_objects(
598
721
  self, key, state, object_state: ObjectState
599
722
  ) -> Tuple[jax.Array, ObjectState]:
@@ -610,8 +733,8 @@ class ForagaxEnv(environment.Environment):
610
733
  # Check if object should expire (age >= expiry_time and expiry_time >= 0)
611
734
  should_expire = (
612
735
  (object_ages >= expiry_times)
613
- & (expiry_times >= 0)
614
- & (current_objects_for_expiry > 0)
736
+ & (expiry_times >= jnp.array(0, dtype=TIME_DTYPE))
737
+ & (current_objects_for_expiry > jnp.array(0, dtype=ID_DTYPE))
615
738
  )
616
739
 
617
740
  # Only process expiry if there are actually objects to expire
@@ -646,8 +769,8 @@ class ForagaxEnv(environment.Environment):
646
769
  )
647
770
  timer_countdown = jax.lax.cond(
648
771
  exp_delay == jnp.iinfo(jnp.int32).max,
649
- lambda: 0,
650
- lambda: exp_delay + 1,
772
+ lambda: jnp.array(0, dtype=TIMER_DTYPE),
773
+ lambda: (exp_delay + 1).astype(TIMER_DTYPE),
651
774
  )
652
775
 
653
776
  respawn_random = self.object_random_respawn[obj_id]
@@ -707,147 +830,189 @@ class ForagaxEnv(environment.Environment):
707
830
  biome_state.consumption_count >= self.biome_consumption_threshold
708
831
  )
709
832
 
710
- # Split key for all biomes in parallel
711
- key, subkey = jax.random.split(key)
712
- biome_keys = jax.random.split(subkey, num_biomes)
833
+ any_respawn = jnp.any(should_respawn)
713
834
 
714
- # Compute all new spawns in parallel using vmap for random, switch for deterministic
715
- if self.deterministic_spawn:
716
- # Use switch to dispatch to concrete biome spawns for deterministic
717
- def make_spawn_fn(biome_idx):
718
- def spawn_fn(key):
719
- return self._spawn_biome_objects(biome_idx, key, deterministic=True)
835
+ def do_respawn(args):
836
+ (
837
+ object_state,
838
+ biome_state,
839
+ should_respawn,
840
+ key,
841
+ ) = args
842
+ # Split key for all biomes in parallel
843
+ key, subkey = jax.random.split(key)
844
+ biome_keys = jax.random.split(subkey, num_biomes)
845
+
846
+ # Compute all new spawns in parallel using vmap for random, switch for deterministic
847
+ if self.deterministic_spawn:
848
+ # Use switch to dispatch to concrete biome spawns for deterministic
849
+ def make_spawn_fn(biome_idx):
850
+ def spawn_fn(key):
851
+ return self._spawn_biome_objects(
852
+ biome_idx, key, deterministic=True
853
+ )
720
854
 
721
- return spawn_fn
855
+ return spawn_fn
722
856
 
723
- spawn_fns = [make_spawn_fn(idx) for idx in range(num_biomes)]
857
+ spawn_fns = [make_spawn_fn(idx) for idx in range(num_biomes)]
724
858
 
725
- # Apply switch for each biome
726
- all_new_objects_list = []
727
- all_new_colors_list = []
728
- all_new_params_list = []
729
- for i in range(num_biomes):
730
- obj, col, par = jax.lax.switch(i, spawn_fns, biome_keys[i])
731
- all_new_objects_list.append(obj)
732
- all_new_colors_list.append(col)
733
- all_new_params_list.append(par)
734
-
735
- all_new_objects = jnp.stack(all_new_objects_list)
736
- all_new_colors = jnp.stack(all_new_colors_list)
737
- all_new_params = jnp.stack(all_new_params_list)
738
- else:
739
- # Random spawn works with vmap
740
- all_new_objects, all_new_colors, all_new_params = jax.vmap(
741
- lambda i, k: self._spawn_biome_objects(i, k, deterministic=False)
742
- )(jnp.arange(num_biomes), biome_keys)
743
-
744
- # Initialize updated grids
745
- new_obj_id = object_state.object_id
746
- new_color = object_state.color
747
- new_params = object_state.state_params
748
- new_spawn = object_state.spawn_time
749
- new_gen = object_state.generation
750
-
751
- # Update biome state
752
- new_consumption_count = jnp.where(
753
- should_respawn, 0, biome_state.consumption_count
754
- )
755
- new_generation = biome_state.generation + should_respawn.astype(int)
859
+ # Apply switch for each biome
860
+ all_new_objects_list = []
861
+ all_new_colors_list = []
862
+ all_new_params_list = []
863
+ for i in range(num_biomes):
864
+ obj, col, par = jax.lax.switch(i, spawn_fns, biome_keys[i])
865
+ all_new_objects_list.append(obj)
866
+ all_new_colors_list.append(col)
867
+ all_new_params_list.append(par)
868
+
869
+ all_new_objects = jnp.stack(all_new_objects_list)
870
+ all_new_colors = jnp.stack(all_new_colors_list)
871
+ all_new_params = jnp.stack(all_new_params_list)
872
+ else:
873
+ # Random spawn works with vmap
874
+ all_new_objects, all_new_colors, all_new_params = jax.vmap(
875
+ lambda i, k: self._spawn_biome_objects(i, k, deterministic=False)
876
+ )(jnp.arange(num_biomes), biome_keys)
877
+
878
+ # Initialize updated grids
879
+ new_obj_id = object_state.object_id
880
+ new_color = object_state.color
881
+ new_params = object_state.state_params
882
+ new_spawn = object_state.spawn_time
883
+ new_gen = object_state.generation
884
+
885
+ # Update biome state
886
+ new_consumption_count = jnp.where(
887
+ should_respawn,
888
+ jnp.array(0, dtype=ID_DTYPE),
889
+ biome_state.consumption_count,
890
+ )
891
+ new_generation = biome_state.generation + should_respawn.astype(ID_DTYPE)
756
892
 
757
- # Compute new total objects for respawning biomes
758
- def count_objects(i):
759
- return jnp.sum((all_new_objects[i] > 0) & self.biome_masks_array[i])
893
+ # Compute new total objects for respawning biomes
894
+ def count_objects(i):
895
+ return jnp.sum(
896
+ (all_new_objects[i] > 0) & self.biome_masks_array[i], dtype=ID_DTYPE
897
+ )
760
898
 
761
- new_object_counts = jax.vmap(count_objects)(jnp.arange(num_biomes))
762
- new_total_objects = jnp.where(
763
- should_respawn, new_object_counts, biome_state.total_objects
764
- )
899
+ new_object_counts = jax.vmap(count_objects)(jnp.arange(num_biomes))
900
+ new_total_objects = jnp.where(
901
+ should_respawn, new_object_counts, biome_state.total_objects
902
+ )
765
903
 
766
- new_biome_state = BiomeState(
767
- consumption_count=new_consumption_count,
768
- total_objects=new_total_objects,
769
- generation=new_generation,
770
- )
904
+ new_biome_state = BiomeState(
905
+ consumption_count=new_consumption_count,
906
+ total_objects=new_total_objects,
907
+ generation=new_generation,
908
+ )
771
909
 
772
- # Update grids for respawning biomes
773
- for i in range(num_biomes):
774
- biome_mask = self.biome_masks_array[i]
775
- new_gen_value = new_biome_state.generation[i]
910
+ # Update grids for respawning biomes
911
+ for i in range(num_biomes):
912
+ biome_mask = self.biome_masks_array[i]
913
+ new_gen_value = new_biome_state.generation[i]
776
914
 
777
- # Update mask: biome area AND needs respawn
778
- should_update = biome_mask & should_respawn[i][..., None]
915
+ # Update mask: biome area AND needs respawn
916
+ should_update = biome_mask & should_respawn[i][..., None]
779
917
 
780
- # 1. Merge: Overwrite with new objects if present, otherwise keep existing
781
- # If the new spawn has an object, we take it. If it's empty, we keep whatever was there.
782
- new_spawn_valid = all_new_objects[i] > 0
918
+ # 1. Merge: Overwrite with new objects if present, otherwise keep existing
919
+ new_spawn_valid = all_new_objects[i] > 0
783
920
 
784
- merged_objs = jnp.where(new_spawn_valid, all_new_objects[i], new_obj_id)
785
- merged_colors = jnp.where(
786
- new_spawn_valid[..., None], all_new_colors[i], new_color
787
- )
788
- merged_params = jnp.where(
789
- new_spawn_valid[..., None], all_new_params[i], new_params
790
- )
921
+ merged_objs = jnp.where(new_spawn_valid, all_new_objects[i], new_obj_id)
922
+ merged_colors = jnp.where(
923
+ new_spawn_valid[..., None], all_new_colors[i], new_color
924
+ )
925
+ merged_params = jnp.where(
926
+ new_spawn_valid[..., None], all_new_params[i], new_params
927
+ )
791
928
 
792
- # For generation/spawn time, update only if we took the NEW object
793
- merged_gen = jnp.where(new_spawn_valid, new_gen_value, new_gen)
794
- merged_spawn = jnp.where(new_spawn_valid, current_time, new_spawn)
929
+ merged_gen = jnp.where(new_spawn_valid, new_gen_value, new_gen)
930
+ merged_spawn = jnp.where(new_spawn_valid, current_time, new_spawn)
795
931
 
796
- # 2. Apply dropout to the MERGED result (only where we are allowed to update)
797
- if self.dynamic_biome_spawn_empty > 0:
798
- key, dropout_key = jax.random.split(key)
799
- # Dropout applies only to cells that are both in the biome AND need updating
800
- keep_mask = jax.random.bernoulli(
801
- dropout_key, 1.0 - self.dynamic_biome_spawn_empty, merged_objs.shape
932
+ if self.dynamic_biome_spawn_empty > 0:
933
+ key, dropout_key = jax.random.split(key)
934
+ keep_mask = jax.random.bernoulli(
935
+ dropout_key,
936
+ 1.0 - self.dynamic_biome_spawn_empty,
937
+ merged_objs.shape,
938
+ )
939
+ dropout_mask = should_update & keep_mask
940
+ final_objs = jnp.where(
941
+ dropout_mask, merged_objs, jnp.array(0, dtype=ID_DTYPE)
942
+ )
943
+ final_colors = jnp.where(
944
+ dropout_mask[..., None],
945
+ merged_colors,
946
+ jnp.array(0, dtype=COLOR_DTYPE),
947
+ )
948
+ final_params = jnp.where(
949
+ dropout_mask[..., None],
950
+ merged_params,
951
+ jnp.array(0, dtype=PARAM_DTYPE),
952
+ )
953
+ final_gen = jnp.where(
954
+ dropout_mask, merged_gen, jnp.array(0, dtype=ID_DTYPE)
955
+ )
956
+ final_spawn = jnp.where(
957
+ dropout_mask, merged_spawn, jnp.array(0, dtype=TIME_DTYPE)
958
+ )
959
+ else:
960
+ final_objs = merged_objs
961
+ final_colors = merged_colors
962
+ final_params = merged_params
963
+ final_gen = merged_gen
964
+ final_spawn = merged_spawn
965
+
966
+ new_obj_id = jnp.where(should_update, final_objs, new_obj_id)
967
+ new_color = jnp.where(should_update[..., None], final_colors, new_color)
968
+ new_params = jnp.where(
969
+ should_update[..., None], final_params, new_params
802
970
  )
803
- # Only apply dropout where should_update is true; elsewhere, keep merged_objs
804
- dropout_mask = should_update & keep_mask
805
- # Apply dropout only to the merged result and associated metadata
806
- final_objs = jnp.where(dropout_mask, merged_objs, 0)
807
- final_colors = jnp.where(dropout_mask[..., None], merged_colors, 0)
808
- final_params = jnp.where(dropout_mask[..., None], merged_params, 0)
809
- final_gen = jnp.where(dropout_mask, merged_gen, 0)
810
- final_spawn = jnp.where(dropout_mask, merged_spawn, 0)
811
- else:
812
- final_objs = merged_objs
813
- final_colors = merged_colors
814
- final_params = merged_params
815
- final_gen = merged_gen
816
- final_spawn = merged_spawn
817
-
818
- # 3. Write back: Only update where should_update is true
819
- new_obj_id = jnp.where(should_update, final_objs, new_obj_id)
820
- new_color = jnp.where(should_update[..., None], final_colors, new_color)
821
- new_params = jnp.where(should_update[..., None], final_params, new_params)
822
- new_gen = jnp.where(should_update, final_gen, new_gen)
823
- new_spawn = jnp.where(should_update, final_spawn, new_spawn)
824
-
825
- # Clear timers in respawning biomes
826
- new_respawn_timer = object_state.respawn_timer
827
- new_respawn_object_id = object_state.respawn_object_id
828
- for i in range(num_biomes):
829
- biome_mask = self.biome_masks_array[i]
830
- should_clear = biome_mask & should_respawn[i][..., None]
831
- new_respawn_timer = jnp.where(should_clear, 0, new_respawn_timer)
832
- new_respawn_object_id = jnp.where(should_clear, 0, new_respawn_object_id)
971
+ new_gen = jnp.where(should_update, final_gen, new_gen)
972
+ new_spawn = jnp.where(should_update, final_spawn, new_spawn)
833
973
 
834
- object_state = object_state.replace(
835
- object_id=new_obj_id,
836
- respawn_timer=new_respawn_timer,
837
- respawn_object_id=new_respawn_object_id,
838
- color=new_color,
839
- state_params=new_params,
840
- generation=new_gen,
841
- spawn_time=new_spawn,
974
+ # Clear timers in respawning biomes
975
+ new_respawn_timer = object_state.respawn_timer
976
+ new_respawn_object_id = object_state.respawn_object_id
977
+ for i in range(num_biomes):
978
+ biome_mask = self.biome_masks_array[i]
979
+ should_clear = biome_mask & should_respawn[i][..., None]
980
+ new_respawn_timer = jnp.where(
981
+ should_clear, jnp.array(0, dtype=TIMER_DTYPE), new_respawn_timer
982
+ )
983
+ new_respawn_object_id = jnp.where(
984
+ should_clear, jnp.array(0, dtype=ID_DTYPE), new_respawn_object_id
985
+ )
986
+
987
+ new_object_state = object_state.replace(
988
+ object_id=new_obj_id,
989
+ respawn_timer=new_respawn_timer,
990
+ respawn_object_id=new_respawn_object_id,
991
+ color=new_color,
992
+ state_params=new_params,
993
+ generation=new_gen,
994
+ spawn_time=new_spawn,
995
+ )
996
+ return new_object_state, new_biome_state, key
997
+
998
+ def no_respawn(args):
999
+ object_state, biome_state, _, key = args
1000
+ return object_state, biome_state, key
1001
+
1002
+ object_state, biome_state, key = jax.lax.cond(
1003
+ any_respawn,
1004
+ do_respawn,
1005
+ no_respawn,
1006
+ (object_state, biome_state, should_respawn, key),
842
1007
  )
843
1008
 
844
- return object_state, new_biome_state, key
1009
+ return object_state, biome_state, key
845
1010
 
846
1011
  def reset_env(
847
1012
  self, key: jax.Array, params: EnvParams
848
1013
  ) -> Tuple[jax.Array, EnvState]:
849
1014
  """Reset environment state."""
850
- num_object_params = 2 + 2 * self.num_fourier_terms
1015
+ num_object_params = 3 + 2 * self.num_fourier_terms
851
1016
  object_state = ObjectState.create_empty(self.size, num_object_params)
852
1017
 
853
1018
  key, iter_key = jax.random.split(key)
@@ -859,7 +1024,9 @@ class ForagaxEnv(environment.Environment):
859
1024
 
860
1025
  # Set biome_id
861
1026
  object_state = object_state.replace(
862
- biome_id=jnp.where(mask, i, object_state.biome_id)
1027
+ biome_id=jnp.where(
1028
+ mask, jnp.array(i, dtype=BIOME_ID_DTYPE), object_state.biome_id
1029
+ )
863
1030
  )
864
1031
 
865
1032
  # Use unified spawn method
@@ -881,31 +1048,45 @@ class ForagaxEnv(environment.Environment):
881
1048
 
882
1049
  # Initialize biome consumption tracking
883
1050
  num_biomes = self.biome_object_frequencies.shape[0]
884
- biome_consumption_count = jnp.zeros(num_biomes, dtype=int)
885
- biome_total_objects = jnp.zeros(num_biomes, dtype=int)
1051
+ biome_consumption_count = jnp.zeros(num_biomes, dtype=ID_DTYPE)
1052
+ biome_total_objects = jnp.zeros(num_biomes, dtype=ID_DTYPE)
886
1053
 
887
1054
  # Count objects in each biome
888
1055
  for i in range(num_biomes):
889
1056
  mask = self.biome_masks[i]
890
1057
  # Count non-empty objects (object_id > 0)
891
- total = jnp.sum((object_state.object_id > 0) & mask)
1058
+ total = jnp.sum((object_state.object_id > 0) & mask, dtype=ID_DTYPE)
892
1059
  biome_total_objects = biome_total_objects.at[i].set(total)
893
1060
 
894
- biome_generation = jnp.zeros(num_biomes, dtype=int)
1061
+ biome_generation = jnp.zeros(num_biomes, dtype=ID_DTYPE)
1062
+
1063
+ # Final state cleanup to ensure type consistency
1064
+ object_state = object_state.replace(
1065
+ object_id=object_state.object_id.astype(ID_DTYPE),
1066
+ respawn_timer=object_state.respawn_timer.astype(TIMER_DTYPE),
1067
+ respawn_object_id=object_state.respawn_object_id.astype(ID_DTYPE),
1068
+ spawn_time=object_state.spawn_time.astype(TIME_DTYPE),
1069
+ color=object_state.color.astype(COLOR_DTYPE),
1070
+ generation=object_state.generation.astype(ID_DTYPE),
1071
+ state_params=object_state.state_params.astype(PARAM_DTYPE),
1072
+ biome_id=object_state.biome_id.astype(BIOME_ID_DTYPE),
1073
+ )
895
1074
 
896
1075
  state = EnvState(
897
- pos=agent_pos,
898
- time=0,
899
- digestion_buffer=jnp.zeros((self.max_reward_delay,)),
1076
+ pos=agent_pos.astype(jnp.int32),
1077
+ time=jnp.array(0, dtype=jnp.int32),
1078
+ digestion_buffer=jnp.zeros((self.max_reward_delay,), dtype=REWARD_DTYPE),
900
1079
  object_state=object_state,
901
1080
  biome_state=BiomeState(
902
- consumption_count=biome_consumption_count,
903
- total_objects=biome_total_objects,
904
- generation=biome_generation,
1081
+ consumption_count=biome_consumption_count.astype(ID_DTYPE),
1082
+ total_objects=biome_total_objects.astype(ID_DTYPE),
1083
+ generation=biome_generation.astype(ID_DTYPE),
905
1084
  ),
906
1085
  )
907
1086
 
908
- return self.get_obs(state, params), state
1087
+ return jax.lax.stop_gradient(
1088
+ self.get_obs(state, params)
1089
+ ), jax.lax.stop_gradient(state)
909
1090
 
910
1091
  def _spawn_biome_objects(
911
1092
  self,
@@ -937,16 +1118,17 @@ class ForagaxEnv(environment.Environment):
937
1118
  biome_size = int(self.biome_sizes[biome_idx])
938
1119
 
939
1120
  grid = jnp.linspace(0, 1, biome_size, endpoint=False)
940
- biome_objects_flat = len(biome_freqs) - jnp.searchsorted(
941
- jnp.cumsum(biome_freqs[::-1]), grid, side="right"
942
- )
1121
+ biome_objects_flat = (
1122
+ len(biome_freqs)
1123
+ - jnp.searchsorted(jnp.cumsum(biome_freqs[::-1]), grid, side="right")
1124
+ ).astype(ID_DTYPE)
943
1125
  biome_objects_flat = jax.random.permutation(spawn_key, biome_objects_flat)
944
1126
 
945
1127
  # Reshape to match biome dimensions (use concrete dimensions)
946
1128
  biome_objects = biome_objects_flat.reshape(biome_height, biome_width)
947
1129
 
948
1130
  # Place in full grid using slicing with static bounds
949
- object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=int)
1131
+ object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=ID_DTYPE)
950
1132
  object_grid = object_grid.at[
951
1133
  biome_start[1] : biome_stop[1], biome_start[0] : biome_stop[0]
952
1134
  ].set(biome_objects)
@@ -960,10 +1142,10 @@ class ForagaxEnv(environment.Environment):
960
1142
  )
961
1143
  object_grid = (
962
1144
  jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
963
- )
1145
+ ).astype(ID_DTYPE)
964
1146
 
965
1147
  # Initialize color grid
966
- color_grid = jnp.full((self.size[1], self.size[0], 3), 255, dtype=jnp.uint8)
1148
+ color_grid = jnp.full((self.size[1], self.size[0], 3), 255, dtype=COLOR_DTYPE)
967
1149
 
968
1150
  # Sample ONE color per object type in this biome (not per instance)
969
1151
  # This gives objects of the same type the same color within a biome generation
@@ -975,9 +1157,9 @@ class ForagaxEnv(environment.Environment):
975
1157
  biome_object_colors = jax.random.randint(
976
1158
  color_key,
977
1159
  (num_actual_objects, 3),
978
- minval=0,
979
- maxval=256,
980
- dtype=jnp.uint8,
1160
+ minval=jnp.array(0, dtype=COLOR_DTYPE),
1161
+ maxval=jnp.array(256, dtype=jnp.int32), # randint range is often int32
1162
+ dtype=COLOR_DTYPE,
981
1163
  )
982
1164
 
983
1165
  # Assign colors based on object type (starting from index 1)
@@ -989,9 +1171,9 @@ class ForagaxEnv(environment.Environment):
989
1171
  color_grid = jnp.where(obj_mask[..., None], obj_color, color_grid)
990
1172
 
991
1173
  # Initialize parameters grid
992
- num_object_params = 2 + 2 * self.num_fourier_terms
1174
+ num_object_params = 3 + 2 * self.num_fourier_terms
993
1175
  params_grid = jnp.zeros(
994
- (self.size[1], self.size[0], num_object_params), dtype=jnp.float32
1176
+ (self.size[1], self.size[0], num_object_params), dtype=PARAM_DTYPE
995
1177
  )
996
1178
 
997
1179
  # Generate per-object parameters for each object type
@@ -1018,7 +1200,9 @@ class ForagaxEnv(environment.Environment):
1018
1200
 
1019
1201
  # Assign to all objects of this type in this biome
1020
1202
  obj_mask = (object_grid == obj_idx) & biome_mask
1021
- params_grid = jnp.where(obj_mask[..., None], obj_params, params_grid)
1203
+ params_grid = jnp.where(
1204
+ obj_mask[..., None], obj_params.astype(PARAM_DTYPE), params_grid
1205
+ )
1022
1206
 
1023
1207
  return object_grid, color_grid, params_grid
1024
1208
 
@@ -1043,7 +1227,7 @@ class ForagaxEnv(environment.Environment):
1043
1227
 
1044
1228
  def state_space(self, params: EnvParams) -> spaces.Dict:
1045
1229
  """State space of the environment."""
1046
- num_object_params = 2 + 2 * self.num_fourier_terms
1230
+ num_object_params = 3 + 2 * self.num_fourier_terms
1047
1231
  return spaces.Dict(
1048
1232
  {
1049
1233
  "pos": spaces.Box(0, max(self.size), (2,), int),
@@ -1166,7 +1350,7 @@ class ForagaxEnv(environment.Environment):
1166
1350
  aperture = jnp.where(out_of_bounds[..., None], padding_value, values)
1167
1351
  else:
1168
1352
  # Object ID grid: use PADDING index
1169
- padding_index = self.object_ids[-1]
1353
+ padding_index = self.object_ids[-1].astype(values.dtype)
1170
1354
  aperture = jnp.where(out_of_bounds, padding_index, values)
1171
1355
  else:
1172
1356
  aperture = values
@@ -1215,9 +1399,11 @@ class ForagaxEnv(environment.Environment):
1215
1399
  return jnp.zeros(obs_grid.shape + (0,), dtype=jnp.float32)
1216
1400
  # Map object IDs to color channel indices
1217
1401
  color_channels = jnp.where(
1218
- obs_grid == 0,
1219
- -1,
1220
- jnp.take(self.object_to_color_map, obs_grid - 1, axis=0),
1402
+ obs_grid == jnp.array(0, dtype=obs_grid.dtype),
1403
+ jnp.array(-1, dtype=jnp.int32),
1404
+ jnp.take(
1405
+ self.object_to_color_map, obs_grid.astype(jnp.int32) - 1, axis=0
1406
+ ),
1221
1407
  )
1222
1408
  obs = jax.nn.one_hot(color_channels, self.num_color_channels, axis=-1)
1223
1409
  return obs
@@ -1237,9 +1423,8 @@ class ForagaxEnv(environment.Environment):
1237
1423
  aperture_colors = color_aperture / 255.0
1238
1424
 
1239
1425
  # Mask empty cells (object_id == 0) to white
1240
- empty_mask = aperture == 0
1426
+ empty_mask = aperture == jnp.array(0, dtype=aperture.dtype)
1241
1427
  white_color = jnp.ones(aperture_colors.shape, dtype=jnp.float32)
1242
-
1243
1428
  obs = jnp.where(empty_mask[..., None], white_color, aperture_colors)
1244
1429
 
1245
1430
  return obs
@@ -1279,12 +1464,15 @@ class ForagaxEnv(environment.Environment):
1279
1464
 
1280
1465
  return spaces.Box(0, 1, obs_shape, float)
1281
1466
 
1282
- def _compute_reward_grid(self, state: EnvState) -> jax.Array:
1283
- """Compute rewards for all grid positions.
1467
+ def _compute_reward_grid(
1468
+ self, state: EnvState, object_id=None, state_params=None
1469
+ ) -> jax.Array:
1470
+ """Compute rewards for given positions. If no grid provided, uses full world."""
1471
+ if object_id is None:
1472
+ object_id = state.object_state.object_id
1473
+ if state_params is None:
1474
+ state_params = state.object_state.state_params
1284
1475
 
1285
- Returns:
1286
- Array of shape (H, W) with reward values for each cell
1287
- """
1288
1476
  fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
1289
1477
 
1290
1478
  def compute_reward(obj_id, params):
@@ -1296,9 +1484,7 @@ class ForagaxEnv(environment.Environment):
1296
1484
  lambda: 0.0,
1297
1485
  )
1298
1486
 
1299
- reward_grid = jax.vmap(jax.vmap(compute_reward))(
1300
- state.object_state.object_id, state.object_state.state_params
1301
- )
1487
+ reward_grid = jax.vmap(jax.vmap(compute_reward))(object_id, state_params)
1302
1488
  return reward_grid
1303
1489
 
1304
1490
  def _reward_to_color(self, reward: jax.Array) -> jax.Array:
@@ -1388,172 +1574,87 @@ class ForagaxEnv(environment.Environment):
1388
1574
 
1389
1575
  img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
1390
1576
 
1391
- if is_reward_mode:
1392
- # Scale image by 3 to create space for reward visualization
1393
- img = jax.image.resize(
1394
- img,
1395
- (self.size[1] * 3, self.size[0] * 3, 3),
1396
- jax.image.ResizeMethod.NEAREST,
1397
- )
1398
-
1399
- # Compute rewards for all cells
1400
- reward_grid = self._compute_reward_grid(state)
1401
-
1402
- # Convert rewards to colors
1403
- reward_colors = self._reward_to_color(reward_grid)
1404
-
1405
- # Resize reward colors to match 3x scale and place in middle cells
1406
- # We need to place reward colors at positions (i*3+1, j*3+1) for each (i,j)
1407
- # Create index arrays for middle cells
1408
- i_indices = jnp.arange(self.size[1])[:, None] * 3 + 1
1409
- j_indices = jnp.arange(self.size[0])[None, :] * 3 + 1
1410
-
1411
- # Broadcast and set middle cells
1412
- img = img.at[i_indices, j_indices].set(reward_colors)
1413
-
1414
- # Tint the agent's aperture
1577
+ # Define constants for all world modes
1578
+ alpha = 0.2
1415
1579
  y_coords, x_coords, y_coords_adj, x_coords_adj = (
1416
1580
  self._compute_aperture_coordinates(state.pos)
1417
1581
  )
1418
1582
 
1419
- alpha = 0.2
1420
- agent_color = jnp.array(AGENT.color)
1421
-
1422
1583
  if is_reward_mode:
1423
- # For reward mode, we need to adjust coordinates for 3x scaled image
1424
- if self.nowrap:
1425
- # Create tint mask for 3x scaled image
1426
- tint_mask = jnp.zeros(
1427
- (self.size[1] * 3, self.size[0] * 3), dtype=bool
1428
- )
1429
-
1430
- # For each aperture cell, tint all 9 cells in its 3x3 block
1431
- # Create meshgrid to get all aperture cell coordinates
1432
- y_grid, x_grid = jnp.meshgrid(
1433
- y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
1434
- )
1435
- y_flat = y_grid.flatten()
1436
- x_flat = x_grid.flatten()
1437
-
1438
- # Create offset arrays for 3x3 blocks
1439
- offsets = jnp.array(
1440
- [[dy, dx] for dy in range(3) for dx in range(3)]
1441
- )
1442
-
1443
- # For each aperture cell, expand to 9 cells
1444
- # We need to repeat each cell coordinate 9 times, then add offsets
1445
- num_aperture_cells = y_flat.size
1446
- y_base = jnp.repeat(
1447
- y_flat * 3, 9
1448
- ) # Repeat each y coord 9 times and scale by 3
1449
- x_base = jnp.repeat(
1450
- x_flat * 3, 9
1451
- ) # Repeat each x coord 9 times and scale by 3
1452
- y_offsets = jnp.tile(
1453
- offsets[:, 0], num_aperture_cells
1454
- ) # Tile all 9 offsets
1455
- x_offsets = jnp.tile(
1456
- offsets[:, 1], num_aperture_cells
1457
- ) # Tile all 9 offsets
1458
- y_expanded = y_base + y_offsets
1459
- x_expanded = x_base + x_offsets
1460
-
1461
- tint_mask = tint_mask.at[y_expanded, x_expanded].set(True)
1462
-
1463
- original_colors = img
1464
- tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1465
- img = jnp.where(tint_mask[..., None], tinted_colors, img)
1466
- else:
1467
- # Tint all 9 cells in each 3x3 block for aperture cells
1468
- # Create meshgrid to get all aperture cell coordinates
1469
- y_grid, x_grid = jnp.meshgrid(
1470
- y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
1471
- )
1472
- y_flat = y_grid.flatten()
1473
- x_flat = x_grid.flatten()
1584
+ # Construct 3x intermediate image
1585
+ # Each cell is 3x3, with reward color in center
1586
+ reward_grid = self._compute_reward_grid(state)
1587
+ reward_colors = self._reward_to_color(reward_grid)
1474
1588
 
1475
- # Create offset arrays for 3x3 blocks
1476
- offsets = jnp.array(
1477
- [[dy, dx] for dy in range(3) for dx in range(3)]
1478
- )
1589
+ # Each cell has its base color in 8 pixels and reward color in 1 (center)
1590
+ # Create a 3x3 pattern mask for center pixels
1591
+ cell_mask = jnp.array(
1592
+ [[False, False, False], [False, True, False], [False, False, False]]
1593
+ )
1594
+ grid_reward_mask = jnp.tile(cell_mask, (self.size[1], self.size[0]))
1479
1595
 
1480
- # For each aperture cell, expand to 9 cells
1481
- # We need to repeat each cell coordinate 9 times, then add offsets
1482
- num_aperture_cells = y_flat.size
1483
- y_base = jnp.repeat(
1484
- y_flat * 3, 9
1485
- ) # Repeat each y coord 9 times and scale by 3
1486
- x_base = jnp.repeat(
1487
- x_flat * 3, 9
1488
- ) # Repeat each x coord 9 times and scale by 3
1489
- y_offsets = jnp.tile(
1490
- offsets[:, 0], num_aperture_cells
1491
- ) # Tile all 9 offsets
1492
- x_offsets = jnp.tile(
1493
- offsets[:, 1], num_aperture_cells
1494
- ) # Tile all 9 offsets
1495
- y_expanded = y_base + y_offsets
1496
- x_expanded = x_base + x_offsets
1497
-
1498
- # Get original colors and tint them
1499
- original_colors = img[y_expanded, x_expanded]
1500
- tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1501
- img = img.at[y_expanded, x_expanded].set(tinted_colors)
1502
-
1503
- # Agent color - set all 9 cells of the agent's 3x3 block
1504
- agent_y, agent_x = state.pos[1], state.pos[0]
1505
- agent_offsets = jnp.array(
1506
- [[dy, dx] for dy in range(3) for dx in range(3)]
1596
+ # Repeat base colors and rewards to 3x3
1597
+ base_img_x3 = jnp.repeat(jnp.repeat(img, 3, axis=0), 3, axis=1)
1598
+ reward_colors_x3 = jnp.repeat(
1599
+ jnp.repeat(reward_colors, 3, axis=0), 3, axis=1
1507
1600
  )
1508
- agent_y_cells = agent_y * 3 + agent_offsets[:, 0]
1509
- agent_x_cells = agent_x * 3 + agent_offsets[:, 1]
1510
- img = img.at[agent_y_cells, agent_x_cells].set(
1511
- jnp.array(AGENT.color, dtype=jnp.uint8)
1601
+
1602
+ # Composite base and reward colors
1603
+ img = jnp.where(
1604
+ grid_reward_mask[..., None], reward_colors_x3, base_img_x3
1512
1605
  )
1513
1606
 
1514
- # Scale by 8 to final size
1515
- img = jax.image.resize(
1516
- img,
1517
- (self.size[1] * 24, self.size[0] * 24, 3),
1518
- jax.image.ResizeMethod.NEAREST,
1607
+ # Tint the aperture region at 3x scale
1608
+ aperture_mask = jnp.zeros((self.size[1], self.size[0]), dtype=bool)
1609
+ aperture_mask = aperture_mask.at[y_coords_adj, x_coords_adj].set(True)
1610
+ aperture_mask_x3 = jnp.repeat(
1611
+ jnp.repeat(aperture_mask, 3, axis=0), 3, axis=1
1519
1612
  )
1613
+
1614
+ tinted_img = (
1615
+ (1.0 - alpha) * img.astype(jnp.float32)
1616
+ + alpha * self.agent_color_jax.astype(jnp.float32)
1617
+ ).astype(jnp.uint8)
1618
+ img = jnp.where(aperture_mask_x3[..., None], tinted_img, img)
1619
+
1620
+ # Set agent center block
1621
+ agent_mask = jnp.zeros((self.size[1], self.size[0]), dtype=bool)
1622
+ agent_mask = agent_mask.at[state.pos[1], state.pos[0]].set(True)
1623
+ agent_mask_x3 = jnp.repeat(jnp.repeat(agent_mask, 3, axis=0), 3, axis=1)
1624
+ img = jnp.where(agent_mask_x3[..., None], self.agent_color_jax, img)
1625
+
1626
+ # Final scale by 8 to get 24x
1627
+ img = jnp.repeat(jnp.repeat(img, 8, axis=0), 8, axis=1)
1520
1628
  else:
1521
1629
  # Standard rendering without reward visualization
1522
- if self.nowrap:
1523
- # Create tint mask: any in-bounds original position maps to a cell makes it tinted
1524
- tint_mask = jnp.zeros((self.size[1], self.size[0]), dtype=int)
1525
- tint_mask = tint_mask.at[y_coords_adj, x_coords_adj].set(1)
1526
- # Apply tint to masked positions
1527
- original_colors = img
1528
- tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1529
- img = jnp.where(tint_mask[..., None], tinted_colors, img)
1530
- else:
1531
- original_colors = img[y_coords_adj, x_coords_adj]
1532
- tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
1533
- img = img.at[y_coords_adj, x_coords_adj].set(tinted_colors)
1630
+ aperture_mask = jnp.zeros((self.size[1], self.size[0]), dtype=bool)
1631
+ aperture_mask = aperture_mask.at[y_coords_adj, x_coords_adj].set(True)
1534
1632
 
1535
- # Agent color
1536
- img = img.at[state.pos[1], state.pos[0]].set(jnp.array(AGENT.color))
1633
+ tinted_img = (
1634
+ (1.0 - alpha) * img.astype(jnp.float32)
1635
+ + alpha * self.agent_color_jax.astype(jnp.float32)
1636
+ ).astype(jnp.uint8)
1637
+ img = jnp.where(aperture_mask[..., None], tinted_img, img)
1537
1638
 
1538
- img = jax.image.resize(
1539
- img,
1540
- (self.size[1] * 24, self.size[0] * 24, 3),
1541
- jax.image.ResizeMethod.NEAREST,
1542
- )
1639
+ # Set agent
1640
+ img = img.at[state.pos[1], state.pos[0]].set(self.agent_color_jax)
1641
+ # Scale by 24
1642
+ img = jnp.repeat(jnp.repeat(img, 24, axis=0), 24, axis=1)
1543
1643
 
1544
1644
  if is_true_mode:
1545
- # Apply true object borders by overlaying true colors on border pixels
1546
- render_grid = state.object_state.object_id
1645
+ # Apply true object borders
1547
1646
  img = apply_true_borders(
1548
- img, render_grid, self.size, len(self.object_ids)
1647
+ img, state.object_state.object_id, self.size, len(self.object_ids)
1549
1648
  )
1550
1649
 
1551
- # Add grid lines for world mode
1552
- grid_color = jnp.zeros(3, dtype=jnp.uint8)
1553
- row_indices = jnp.arange(1, self.size[1]) * 24
1554
- col_indices = jnp.arange(1, self.size[0]) * 24
1555
- img = img.at[row_indices, :].set(grid_color)
1556
- img = img.at[:, col_indices].set(grid_color)
1650
+ # Add grid lines using masking instead of slice-setting
1651
+ row_grid = (jnp.arange(self.size[1] * 24) % 24) == 0
1652
+ col_grid = (jnp.arange(self.size[0] * 24) % 24) == 0
1653
+ # skip first rows/cols as they are borders or managed by caller
1654
+ row_grid = row_grid.at[0].set(False)
1655
+ col_grid = col_grid.at[0].set(False)
1656
+ grid_mask = row_grid[:, None] | col_grid[None, :]
1657
+ img = jnp.where(grid_mask[..., None], self.grid_color_jax, img)
1557
1658
 
1558
1659
  elif is_aperture_mode:
1559
1660
  obs_grid = state.object_state.object_id
@@ -1600,11 +1701,16 @@ class ForagaxEnv(environment.Environment):
1600
1701
  self._compute_aperture_coordinates(state.pos)
1601
1702
  )
1602
1703
 
1603
- # Get reward grid for the full world
1604
- full_reward_grid = self._compute_reward_grid(state)
1605
-
1606
- # Extract aperture rewards
1607
- aperture_rewards = full_reward_grid[y_coords_adj, x_coords_adj]
1704
+ # Get reward grid only for aperture region
1705
+ aperture_object_ids = state.object_state.object_id[
1706
+ y_coords_adj, x_coords_adj
1707
+ ]
1708
+ aperture_params = state.object_state.state_params[
1709
+ y_coords_adj, x_coords_adj
1710
+ ]
1711
+ aperture_rewards = self._compute_reward_grid(
1712
+ state, aperture_object_ids, aperture_params
1713
+ )
1608
1714
 
1609
1715
  # Convert rewards to colors
1610
1716
  reward_colors = self._reward_to_color(aperture_rewards)