continual-foragax 0.40.0__py3-none-any.whl → 0.41.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {continual_foragax-0.40.0.dist-info → continual_foragax-0.41.0.dist-info}/METADATA +1 -1
- {continual_foragax-0.40.0.dist-info → continual_foragax-0.41.0.dist-info}/RECORD +8 -8
- foragax/env.py +514 -389
- foragax/objects.py +10 -7
- foragax/rendering.py +9 -33
- {continual_foragax-0.40.0.dist-info → continual_foragax-0.41.0.dist-info}/WHEEL +0 -0
- {continual_foragax-0.40.0.dist-info → continual_foragax-0.41.0.dist-info}/entry_points.txt +0 -0
- {continual_foragax-0.40.0.dist-info → continual_foragax-0.41.0.dist-info}/top_level.txt +0 -0
foragax/env.py
CHANGED
|
@@ -26,6 +26,15 @@ from foragax.rendering import apply_true_borders
|
|
|
26
26
|
from foragax.weather import get_temperature
|
|
27
27
|
|
|
28
28
|
|
|
29
|
+
ID_DTYPE = jnp.int32 # Object type ID (0 = empty, >0 = object type)
|
|
30
|
+
TIMER_DTYPE = jnp.int32 # Respawn countdown (0 = no timer, >0 = countdown)
|
|
31
|
+
TIME_DTYPE = jnp.int32 # Timesteps (spawn time, current time)
|
|
32
|
+
PARAM_DTYPE = jnp.float16 # Per-instance object parameters
|
|
33
|
+
COLOR_DTYPE = jnp.uint8 # RGB color channels (0-255)
|
|
34
|
+
BIOME_ID_DTYPE = jnp.int16 # Biome assignment for each cell
|
|
35
|
+
REWARD_DTYPE = jnp.float32 # Reward values
|
|
36
|
+
|
|
37
|
+
|
|
29
38
|
class Actions(IntEnum):
|
|
30
39
|
DOWN = 0
|
|
31
40
|
RIGHT = 1
|
|
@@ -74,14 +83,14 @@ class ObjectState:
|
|
|
74
83
|
"""Create an empty ObjectState for the given grid size."""
|
|
75
84
|
h, w = size[1], size[0]
|
|
76
85
|
return cls(
|
|
77
|
-
object_id=jnp.zeros((h, w), dtype=
|
|
78
|
-
respawn_timer=jnp.zeros((h, w), dtype=
|
|
79
|
-
respawn_object_id=jnp.zeros((h, w), dtype=
|
|
80
|
-
spawn_time=jnp.zeros((h, w), dtype=
|
|
81
|
-
color=jnp.full((h, w, 3), 255, dtype=
|
|
82
|
-
generation=jnp.zeros((h, w), dtype=
|
|
83
|
-
state_params=jnp.zeros((h, w, num_params), dtype=
|
|
84
|
-
biome_id=jnp.full((h, w), -1, dtype=
|
|
86
|
+
object_id=jnp.zeros((h, w), dtype=ID_DTYPE),
|
|
87
|
+
respawn_timer=jnp.zeros((h, w), dtype=TIMER_DTYPE),
|
|
88
|
+
respawn_object_id=jnp.zeros((h, w), dtype=ID_DTYPE),
|
|
89
|
+
spawn_time=jnp.zeros((h, w), dtype=TIME_DTYPE),
|
|
90
|
+
color=jnp.full((h, w, 3), 255, dtype=COLOR_DTYPE),
|
|
91
|
+
generation=jnp.zeros((h, w), dtype=ID_DTYPE),
|
|
92
|
+
state_params=jnp.zeros((h, w, num_params), dtype=PARAM_DTYPE),
|
|
93
|
+
biome_id=jnp.full((h, w), -1, dtype=BIOME_ID_DTYPE),
|
|
85
94
|
)
|
|
86
95
|
|
|
87
96
|
|
|
@@ -228,13 +237,13 @@ class ForagaxEnv(environment.Environment):
|
|
|
228
237
|
# Create mask for the biome
|
|
229
238
|
start = jax.lax.select(
|
|
230
239
|
self.biome_starts[i, 0] == -1,
|
|
231
|
-
jnp.array([0, 0]),
|
|
232
|
-
self.biome_starts[i],
|
|
240
|
+
jnp.array([0, 0], dtype=jnp.int32),
|
|
241
|
+
self.biome_starts[i].astype(jnp.int32),
|
|
233
242
|
)
|
|
234
243
|
stop = jax.lax.select(
|
|
235
244
|
self.biome_stops[i, 0] == -1,
|
|
236
|
-
jnp.array(self.size),
|
|
237
|
-
self.biome_stops[i],
|
|
245
|
+
jnp.array(self.size, dtype=jnp.int32),
|
|
246
|
+
self.biome_stops[i].astype(jnp.int32),
|
|
238
247
|
)
|
|
239
248
|
rows = jnp.arange(self.size[1])[:, None]
|
|
240
249
|
cols = jnp.arange(self.size[0])
|
|
@@ -273,12 +282,18 @@ class ForagaxEnv(environment.Environment):
|
|
|
273
282
|
# color_indices maps from object_id-1 to color_channel_index
|
|
274
283
|
self.object_to_color_map = color_indices
|
|
275
284
|
|
|
285
|
+
# Rendering constants
|
|
286
|
+
self.agent_color_jax = jnp.array(AGENT.color, dtype=jnp.uint8)
|
|
287
|
+
self.white_color_jax = jnp.array([255, 255, 255], dtype=jnp.uint8)
|
|
288
|
+
self.grid_color_jax = jnp.zeros(3, dtype=jnp.uint8)
|
|
289
|
+
|
|
276
290
|
@property
|
|
277
291
|
def default_params(self) -> EnvParams:
|
|
278
292
|
return EnvParams(
|
|
279
293
|
max_steps_in_episode=None,
|
|
280
294
|
)
|
|
281
295
|
|
|
296
|
+
@partial(jax.named_call, name="_place_timer")
|
|
282
297
|
def _place_timer(
|
|
283
298
|
self,
|
|
284
299
|
object_state: ObjectState,
|
|
@@ -302,13 +317,24 @@ class ForagaxEnv(environment.Environment):
|
|
|
302
317
|
Returns:
|
|
303
318
|
Updated object_state with timer placed
|
|
304
319
|
"""
|
|
320
|
+
# Ensure inputs match ObjectState dtypes
|
|
321
|
+
object_type = jnp.array(object_type, dtype=ID_DTYPE)
|
|
322
|
+
timer_val = jnp.array(timer_val, dtype=TIMER_DTYPE)
|
|
323
|
+
y = jnp.array(y, dtype=jnp.int32)
|
|
324
|
+
x = jnp.array(x, dtype=jnp.int32)
|
|
305
325
|
|
|
306
326
|
# Handle permanent removal (timer_val == 0)
|
|
307
327
|
def place_empty():
|
|
308
328
|
return object_state.replace(
|
|
309
|
-
object_id=object_state.object_id.at[y, x].set(
|
|
310
|
-
|
|
311
|
-
|
|
329
|
+
object_id=object_state.object_id.at[y, x].set(
|
|
330
|
+
jnp.array(0, dtype=ID_DTYPE)
|
|
331
|
+
),
|
|
332
|
+
respawn_timer=object_state.respawn_timer.at[y, x].set(
|
|
333
|
+
jnp.array(0, dtype=TIMER_DTYPE)
|
|
334
|
+
),
|
|
335
|
+
respawn_object_id=object_state.respawn_object_id.at[y, x].set(
|
|
336
|
+
jnp.array(0, dtype=ID_DTYPE)
|
|
337
|
+
),
|
|
312
338
|
)
|
|
313
339
|
|
|
314
340
|
# Handle timer placement
|
|
@@ -316,7 +342,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
316
342
|
# Non-random: place at original position
|
|
317
343
|
def place_at_position():
|
|
318
344
|
return object_state.replace(
|
|
319
|
-
object_id=object_state.object_id.at[y, x].set(
|
|
345
|
+
object_id=object_state.object_id.at[y, x].set(
|
|
346
|
+
jnp.array(0, dtype=ID_DTYPE)
|
|
347
|
+
),
|
|
320
348
|
respawn_timer=object_state.respawn_timer.at[y, x].set(timer_val),
|
|
321
349
|
respawn_object_id=object_state.respawn_object_id.at[y, x].set(
|
|
322
350
|
object_type
|
|
@@ -326,9 +354,15 @@ class ForagaxEnv(environment.Environment):
|
|
|
326
354
|
# Random: place at random location in same biome
|
|
327
355
|
def place_randomly():
|
|
328
356
|
# Clear the collected object's position
|
|
329
|
-
new_object_id = object_state.object_id.at[y, x].set(
|
|
330
|
-
|
|
331
|
-
|
|
357
|
+
new_object_id = object_state.object_id.at[y, x].set(
|
|
358
|
+
jnp.array(0, dtype=ID_DTYPE)
|
|
359
|
+
)
|
|
360
|
+
new_respawn_timer = object_state.respawn_timer.at[y, x].set(
|
|
361
|
+
jnp.array(0, dtype=TIMER_DTYPE)
|
|
362
|
+
)
|
|
363
|
+
new_respawn_object_id = object_state.respawn_object_id.at[y, x].set(
|
|
364
|
+
jnp.array(0, dtype=ID_DTYPE)
|
|
365
|
+
)
|
|
332
366
|
|
|
333
367
|
# Find valid spawn locations in the same biome
|
|
334
368
|
biome_id = object_state.biome_id[y, x]
|
|
@@ -336,13 +370,15 @@ class ForagaxEnv(environment.Environment):
|
|
|
336
370
|
empty_mask = new_object_id == 0
|
|
337
371
|
no_timer_mask = new_respawn_timer == 0
|
|
338
372
|
valid_spawn_mask = biome_mask & empty_mask & no_timer_mask
|
|
339
|
-
num_valid_spawns = jnp.sum(valid_spawn_mask)
|
|
373
|
+
num_valid_spawns = jnp.sum(valid_spawn_mask, dtype=jnp.int32)
|
|
340
374
|
|
|
341
375
|
y_indices, x_indices = jnp.nonzero(
|
|
342
376
|
valid_spawn_mask, size=self.size[0] * self.size[1], fill_value=-1
|
|
343
377
|
)
|
|
344
378
|
valid_spawn_indices = jnp.stack([y_indices, x_indices], axis=1)
|
|
345
|
-
random_idx = jax.random.randint(
|
|
379
|
+
random_idx = jax.random.randint(
|
|
380
|
+
rand_key, (), jnp.array(0, dtype=jnp.int32), num_valid_spawns
|
|
381
|
+
)
|
|
346
382
|
new_spawn_pos = valid_spawn_indices[random_idx]
|
|
347
383
|
|
|
348
384
|
# Place timer at the new random position
|
|
@@ -363,17 +399,11 @@ class ForagaxEnv(environment.Environment):
|
|
|
363
399
|
|
|
364
400
|
return jax.lax.cond(timer_val == 0, place_empty, place_timer)
|
|
365
401
|
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
action: Union[int, float, jax.Array],
|
|
371
|
-
params: EnvParams,
|
|
372
|
-
) -> Tuple[jax.Array, EnvState, jax.Array, jax.Array, Dict[Any, Any]]:
|
|
373
|
-
"""Perform single timestep state transition."""
|
|
402
|
+
@partial(jax.named_call, name="move_agent")
|
|
403
|
+
def _move_agent(
|
|
404
|
+
self, state: EnvState, action: Union[int, float, jax.Array]
|
|
405
|
+
) -> Tuple[jax.Array, jax.Array]:
|
|
374
406
|
current_objects = state.object_state.object_id
|
|
375
|
-
|
|
376
|
-
# 1. UPDATE AGENT POSITION
|
|
377
407
|
direction = DIRECTIONS[action]
|
|
378
408
|
new_pos = state.pos + direction
|
|
379
409
|
|
|
@@ -386,28 +416,44 @@ class ForagaxEnv(environment.Environment):
|
|
|
386
416
|
|
|
387
417
|
# Check for blocking objects
|
|
388
418
|
obj_at_new_pos = current_objects[new_pos[1], new_pos[0]]
|
|
389
|
-
is_blocking = self.object_blocking[obj_at_new_pos]
|
|
390
|
-
pos =
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
419
|
+
is_blocking = self.object_blocking[obj_at_new_pos.astype(jnp.int32)]
|
|
420
|
+
pos = jnp.where(
|
|
421
|
+
is_blocking[..., None],
|
|
422
|
+
state.pos.astype(jnp.int32),
|
|
423
|
+
new_pos.astype(jnp.int32),
|
|
424
|
+
)
|
|
425
|
+
return pos, new_pos
|
|
396
426
|
|
|
397
|
-
|
|
398
|
-
|
|
427
|
+
@partial(jax.named_call, name="compute_reward")
|
|
428
|
+
def _compute_reward(
|
|
429
|
+
self,
|
|
430
|
+
state: EnvState,
|
|
431
|
+
pos: jax.Array,
|
|
432
|
+
key: jax.Array,
|
|
433
|
+
should_collect: jax.Array,
|
|
434
|
+
digestion_buffer: jax.Array,
|
|
435
|
+
obj_at_pos: jax.Array,
|
|
436
|
+
) -> Tuple[jax.Array, jax.Array]:
|
|
399
437
|
key, reward_subkey = jax.random.split(key)
|
|
400
438
|
|
|
401
439
|
object_params = state.object_state.state_params[pos[1], pos[0]]
|
|
402
440
|
object_reward = jax.lax.switch(
|
|
403
|
-
obj_at_pos
|
|
441
|
+
obj_at_pos.astype(jnp.int32),
|
|
442
|
+
self.reward_fns,
|
|
443
|
+
state.time,
|
|
444
|
+
reward_subkey,
|
|
445
|
+
object_params.astype(jnp.float32),
|
|
404
446
|
)
|
|
405
447
|
|
|
406
448
|
key, digestion_subkey = jax.random.split(key)
|
|
407
449
|
reward_delay = jax.lax.switch(
|
|
408
450
|
obj_at_pos, self.reward_delay_fns, state.time, digestion_subkey
|
|
409
451
|
)
|
|
410
|
-
reward = jnp.where(
|
|
452
|
+
reward = jnp.where(
|
|
453
|
+
should_collect & (reward_delay == jnp.array(0, dtype=jnp.int32)),
|
|
454
|
+
object_reward,
|
|
455
|
+
0.0,
|
|
456
|
+
)
|
|
411
457
|
if self.max_reward_delay > 0:
|
|
412
458
|
# Add delayed rewards to buffer
|
|
413
459
|
digestion_buffer = jax.lax.cond(
|
|
@@ -421,20 +467,33 @@ class ForagaxEnv(environment.Environment):
|
|
|
421
467
|
current_index = state.time % self.max_reward_delay
|
|
422
468
|
reward += digestion_buffer[current_index]
|
|
423
469
|
digestion_buffer = digestion_buffer.at[current_index].set(0.0)
|
|
470
|
+
return reward, digestion_buffer
|
|
424
471
|
|
|
472
|
+
@partial(jax.named_call, name="respawn_logic")
|
|
473
|
+
def _respawn_logic(
|
|
474
|
+
self,
|
|
475
|
+
state: EnvState,
|
|
476
|
+
pos: jax.Array,
|
|
477
|
+
key: jax.Array,
|
|
478
|
+
current_objects: jax.Array,
|
|
479
|
+
is_collectable: jax.Array,
|
|
480
|
+
obj_at_pos: jax.Array,
|
|
481
|
+
) -> ObjectState:
|
|
425
482
|
# 3. HANDLE OBJECT COLLECTION AND RESPAWNING
|
|
426
483
|
key, regen_subkey, rand_key = jax.random.split(key, 3)
|
|
427
484
|
|
|
428
485
|
# Decrement respawn timers
|
|
429
|
-
has_timer = state.object_state.respawn_timer > 0
|
|
486
|
+
has_timer = state.object_state.respawn_timer > jnp.array(0, dtype=TIMER_DTYPE)
|
|
430
487
|
new_respawn_timer = jnp.where(
|
|
431
488
|
has_timer,
|
|
432
|
-
state.object_state.respawn_timer - 1,
|
|
489
|
+
state.object_state.respawn_timer - jnp.array(1, dtype=TIMER_DTYPE),
|
|
433
490
|
state.object_state.respawn_timer,
|
|
434
491
|
)
|
|
435
492
|
|
|
436
493
|
# Track which cells have timers that just reached 0
|
|
437
|
-
just_respawned = has_timer & (
|
|
494
|
+
just_respawned = has_timer & (
|
|
495
|
+
new_respawn_timer == jnp.array(0, dtype=TIMER_DTYPE)
|
|
496
|
+
)
|
|
438
497
|
|
|
439
498
|
# Respawn objects where timer reached 0
|
|
440
499
|
new_object_id = jnp.where(
|
|
@@ -446,7 +505,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
446
505
|
# Clear respawn_object_id for cells that just respawned
|
|
447
506
|
new_respawn_object_id = jnp.where(
|
|
448
507
|
just_respawned,
|
|
449
|
-
0,
|
|
508
|
+
jnp.array(0, dtype=ID_DTYPE),
|
|
450
509
|
state.object_state.respawn_object_id,
|
|
451
510
|
)
|
|
452
511
|
|
|
@@ -459,16 +518,19 @@ class ForagaxEnv(environment.Environment):
|
|
|
459
518
|
regen_delay = jax.lax.switch(
|
|
460
519
|
obj_at_pos, self.regen_delay_fns, state.time, regen_subkey
|
|
461
520
|
)
|
|
521
|
+
# Cast timer_countdown to match ObjectState.respawn_timer dtype
|
|
462
522
|
timer_countdown = jax.lax.cond(
|
|
463
523
|
regen_delay == jnp.iinfo(jnp.int32).max,
|
|
464
|
-
lambda: 0, # No timer (permanent removal)
|
|
465
|
-
lambda: regen_delay + 1, # Timer countdown
|
|
524
|
+
lambda: jnp.array(0, dtype=TIMER_DTYPE), # No timer (permanent removal)
|
|
525
|
+
lambda: (regen_delay + 1).astype(TIMER_DTYPE), # Timer countdown
|
|
466
526
|
)
|
|
467
527
|
|
|
468
528
|
# If collected, replace object with timer; otherwise, keep it
|
|
469
529
|
val_at_pos = current_objects[pos[1], pos[0]]
|
|
470
530
|
# Use original should_collect for consumption tracking
|
|
471
|
-
should_collect_now = is_collectable & (
|
|
531
|
+
should_collect_now = is_collectable & (
|
|
532
|
+
val_at_pos > jnp.array(0, dtype=ID_DTYPE)
|
|
533
|
+
)
|
|
472
534
|
|
|
473
535
|
# Create updated object state with new respawn_timer, object_id, and spawn_time
|
|
474
536
|
object_state = state.object_state.replace(
|
|
@@ -492,12 +554,17 @@ class ForagaxEnv(environment.Environment):
|
|
|
492
554
|
),
|
|
493
555
|
lambda: object_state,
|
|
494
556
|
)
|
|
557
|
+
return object_state
|
|
495
558
|
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
559
|
+
@partial(jax.named_call, name="dynamic_biomes")
|
|
560
|
+
def _dynamic_biomes(
|
|
561
|
+
self,
|
|
562
|
+
state: EnvState,
|
|
563
|
+
pos: jax.Array,
|
|
564
|
+
key: jax.Array,
|
|
565
|
+
object_state: ObjectState,
|
|
566
|
+
should_collect: jax.Array,
|
|
567
|
+
) -> Tuple[ObjectState, BiomeState]:
|
|
501
568
|
if self.dynamic_biomes:
|
|
502
569
|
# Update consumption count if an object was collected
|
|
503
570
|
# Only count if the object belongs to the current generation of its biome
|
|
@@ -509,7 +576,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
509
576
|
biome_consumption_count = state.biome_state.consumption_count
|
|
510
577
|
biome_consumption_count = jax.lax.cond(
|
|
511
578
|
should_collect & is_current_generation,
|
|
512
|
-
lambda: biome_consumption_count.at[collected_biome_id].add(
|
|
579
|
+
lambda: biome_consumption_count.at[collected_biome_id].add(
|
|
580
|
+
jnp.array(1, dtype=ID_DTYPE)
|
|
581
|
+
),
|
|
513
582
|
lambda: biome_consumption_count,
|
|
514
583
|
)
|
|
515
584
|
|
|
@@ -526,8 +595,73 @@ class ForagaxEnv(environment.Environment):
|
|
|
526
595
|
state.time,
|
|
527
596
|
respawn_key,
|
|
528
597
|
)
|
|
598
|
+
return object_state, biome_state
|
|
529
599
|
else:
|
|
530
|
-
|
|
600
|
+
return object_state, state.biome_state
|
|
601
|
+
|
|
602
|
+
@partial(jax.named_call, name="reward_grid")
|
|
603
|
+
def _reward_grid(self, state: EnvState, object_state: ObjectState) -> jax.Array:
|
|
604
|
+
# Compute reward at each grid position
|
|
605
|
+
fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
|
|
606
|
+
|
|
607
|
+
def compute_reward(obj_id, params):
|
|
608
|
+
return jax.lax.cond(
|
|
609
|
+
obj_id > jnp.array(0, dtype=ID_DTYPE),
|
|
610
|
+
lambda: jax.lax.switch(
|
|
611
|
+
obj_id.astype(jnp.int32),
|
|
612
|
+
self.reward_fns,
|
|
613
|
+
state.time,
|
|
614
|
+
fixed_key,
|
|
615
|
+
params.astype(jnp.float32),
|
|
616
|
+
),
|
|
617
|
+
lambda: 0.0,
|
|
618
|
+
)
|
|
619
|
+
|
|
620
|
+
reward_grid = jax.vmap(jax.vmap(compute_reward))(
|
|
621
|
+
object_state.object_id.astype(ID_DTYPE),
|
|
622
|
+
object_state.state_params.astype(PARAM_DTYPE),
|
|
623
|
+
)
|
|
624
|
+
return reward_grid
|
|
625
|
+
|
|
626
|
+
def step_env(
|
|
627
|
+
self,
|
|
628
|
+
key: jax.Array,
|
|
629
|
+
state: EnvState,
|
|
630
|
+
action: Union[int, float, jax.Array],
|
|
631
|
+
params: EnvParams,
|
|
632
|
+
) -> Tuple[jax.Array, EnvState, jax.Array, jax.Array, Dict[Any, Any]]:
|
|
633
|
+
"""Perform single timestep state transition."""
|
|
634
|
+
current_objects = state.object_state.object_id
|
|
635
|
+
pos, new_pos = self._move_agent(state, action)
|
|
636
|
+
|
|
637
|
+
with jax.named_scope("compute_reward"):
|
|
638
|
+
# 2. HANDLE COLLISIONS AND REWARDS
|
|
639
|
+
obj_at_pos = current_objects[pos[1], pos[0]]
|
|
640
|
+
is_collectable = self.object_collectable[obj_at_pos]
|
|
641
|
+
should_collect = is_collectable & (
|
|
642
|
+
obj_at_pos > jnp.array(0, dtype=ID_DTYPE)
|
|
643
|
+
)
|
|
644
|
+
|
|
645
|
+
# Handle digestion: add reward to buffer if collected
|
|
646
|
+
digestion_buffer = state.digestion_buffer
|
|
647
|
+
key, reward_subkey = jax.random.split(key)
|
|
648
|
+
|
|
649
|
+
reward, digestion_buffer = self._compute_reward(
|
|
650
|
+
state, pos, key, should_collect, digestion_buffer, obj_at_pos
|
|
651
|
+
)
|
|
652
|
+
|
|
653
|
+
object_state = self._respawn_logic(
|
|
654
|
+
state, pos, key, current_objects, is_collectable, obj_at_pos
|
|
655
|
+
)
|
|
656
|
+
|
|
657
|
+
# 3.5. HANDLE OBJECT EXPIRY
|
|
658
|
+
# Only process expiry if there are objects that can expire
|
|
659
|
+
key, object_state = self.expire_objects(key, state, object_state)
|
|
660
|
+
|
|
661
|
+
# 3.6. HANDLE DYNAMIC BIOME CONSUMPTION AND RESPAWNING
|
|
662
|
+
object_state, biome_state = self._dynamic_biomes(
|
|
663
|
+
state, pos, key, object_state, should_collect
|
|
664
|
+
)
|
|
531
665
|
|
|
532
666
|
info = {"discount": self.discount(state, params)}
|
|
533
667
|
temperatures = jnp.zeros(len(self.objects))
|
|
@@ -538,33 +672,40 @@ class ForagaxEnv(environment.Environment):
|
|
|
538
672
|
)
|
|
539
673
|
info["temperatures"] = temperatures
|
|
540
674
|
info["biome_id"] = object_state.biome_id[pos[1], pos[0]]
|
|
541
|
-
info["object_collected_id"] =
|
|
675
|
+
info["object_collected_id"] = jnp.where(
|
|
676
|
+
should_collect,
|
|
677
|
+
obj_at_pos.astype(ID_DTYPE),
|
|
678
|
+
jnp.array(-1, dtype=ID_DTYPE),
|
|
679
|
+
)
|
|
542
680
|
|
|
543
681
|
# 4. UPDATE STATE
|
|
682
|
+
# Ensure all fields have canonical dtypes for consistency (e.g., for gymnax step selection)
|
|
683
|
+
object_state = object_state.replace(
|
|
684
|
+
object_id=object_state.object_id.astype(ID_DTYPE),
|
|
685
|
+
respawn_timer=object_state.respawn_timer.astype(TIMER_DTYPE),
|
|
686
|
+
respawn_object_id=object_state.respawn_object_id.astype(ID_DTYPE),
|
|
687
|
+
spawn_time=object_state.spawn_time.astype(TIME_DTYPE),
|
|
688
|
+
color=object_state.color.astype(COLOR_DTYPE),
|
|
689
|
+
generation=object_state.generation.astype(ID_DTYPE),
|
|
690
|
+
state_params=object_state.state_params.astype(PARAM_DTYPE),
|
|
691
|
+
biome_id=object_state.biome_id.astype(BIOME_ID_DTYPE),
|
|
692
|
+
)
|
|
693
|
+
biome_state = biome_state.replace(
|
|
694
|
+
consumption_count=biome_state.consumption_count.astype(ID_DTYPE),
|
|
695
|
+
total_objects=biome_state.total_objects.astype(ID_DTYPE),
|
|
696
|
+
generation=biome_state.generation.astype(ID_DTYPE),
|
|
697
|
+
)
|
|
698
|
+
|
|
544
699
|
state = EnvState(
|
|
545
|
-
pos=pos,
|
|
546
|
-
time=state.time + 1,
|
|
547
|
-
digestion_buffer=digestion_buffer,
|
|
700
|
+
pos=pos.astype(jnp.int32),
|
|
701
|
+
time=jnp.array(state.time + 1, dtype=jnp.int32),
|
|
702
|
+
digestion_buffer=digestion_buffer.astype(REWARD_DTYPE),
|
|
548
703
|
object_state=object_state,
|
|
549
704
|
biome_state=biome_state,
|
|
550
705
|
)
|
|
551
706
|
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
def compute_reward(obj_id, params):
|
|
556
|
-
return jax.lax.cond(
|
|
557
|
-
obj_id > 0,
|
|
558
|
-
lambda: jax.lax.switch(
|
|
559
|
-
obj_id, self.reward_fns, state.time, fixed_key, params
|
|
560
|
-
),
|
|
561
|
-
lambda: 0.0,
|
|
562
|
-
)
|
|
563
|
-
|
|
564
|
-
reward_grid = jax.vmap(jax.vmap(compute_reward))(
|
|
565
|
-
object_state.object_id, object_state.state_params
|
|
566
|
-
)
|
|
567
|
-
info["rewards"] = reward_grid
|
|
707
|
+
reward_grid = self._reward_grid(state, object_state)
|
|
708
|
+
info["rewards"] = reward_grid.astype(jnp.float16)
|
|
568
709
|
|
|
569
710
|
done = self.is_terminal(state, params)
|
|
570
711
|
return (
|
|
@@ -575,6 +716,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
575
716
|
info,
|
|
576
717
|
)
|
|
577
718
|
|
|
719
|
+
@partial(jax.named_call, name="expire_objects")
|
|
578
720
|
def expire_objects(
|
|
579
721
|
self, key, state, object_state: ObjectState
|
|
580
722
|
) -> Tuple[jax.Array, ObjectState]:
|
|
@@ -591,8 +733,8 @@ class ForagaxEnv(environment.Environment):
|
|
|
591
733
|
# Check if object should expire (age >= expiry_time and expiry_time >= 0)
|
|
592
734
|
should_expire = (
|
|
593
735
|
(object_ages >= expiry_times)
|
|
594
|
-
& (expiry_times >= 0)
|
|
595
|
-
& (current_objects_for_expiry > 0)
|
|
736
|
+
& (expiry_times >= jnp.array(0, dtype=TIME_DTYPE))
|
|
737
|
+
& (current_objects_for_expiry > jnp.array(0, dtype=ID_DTYPE))
|
|
596
738
|
)
|
|
597
739
|
|
|
598
740
|
# Only process expiry if there are actually objects to expire
|
|
@@ -627,8 +769,8 @@ class ForagaxEnv(environment.Environment):
|
|
|
627
769
|
)
|
|
628
770
|
timer_countdown = jax.lax.cond(
|
|
629
771
|
exp_delay == jnp.iinfo(jnp.int32).max,
|
|
630
|
-
lambda: 0,
|
|
631
|
-
lambda: exp_delay + 1,
|
|
772
|
+
lambda: jnp.array(0, dtype=TIMER_DTYPE),
|
|
773
|
+
lambda: (exp_delay + 1).astype(TIMER_DTYPE),
|
|
632
774
|
)
|
|
633
775
|
|
|
634
776
|
respawn_random = self.object_random_respawn[obj_id]
|
|
@@ -688,147 +830,189 @@ class ForagaxEnv(environment.Environment):
|
|
|
688
830
|
biome_state.consumption_count >= self.biome_consumption_threshold
|
|
689
831
|
)
|
|
690
832
|
|
|
691
|
-
|
|
692
|
-
key, subkey = jax.random.split(key)
|
|
693
|
-
biome_keys = jax.random.split(subkey, num_biomes)
|
|
833
|
+
any_respawn = jnp.any(should_respawn)
|
|
694
834
|
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
835
|
+
def do_respawn(args):
|
|
836
|
+
(
|
|
837
|
+
object_state,
|
|
838
|
+
biome_state,
|
|
839
|
+
should_respawn,
|
|
840
|
+
key,
|
|
841
|
+
) = args
|
|
842
|
+
# Split key for all biomes in parallel
|
|
843
|
+
key, subkey = jax.random.split(key)
|
|
844
|
+
biome_keys = jax.random.split(subkey, num_biomes)
|
|
845
|
+
|
|
846
|
+
# Compute all new spawns in parallel using vmap for random, switch for deterministic
|
|
847
|
+
if self.deterministic_spawn:
|
|
848
|
+
# Use switch to dispatch to concrete biome spawns for deterministic
|
|
849
|
+
def make_spawn_fn(biome_idx):
|
|
850
|
+
def spawn_fn(key):
|
|
851
|
+
return self._spawn_biome_objects(
|
|
852
|
+
biome_idx, key, deterministic=True
|
|
853
|
+
)
|
|
701
854
|
|
|
702
|
-
|
|
855
|
+
return spawn_fn
|
|
703
856
|
|
|
704
|
-
|
|
857
|
+
spawn_fns = [make_spawn_fn(idx) for idx in range(num_biomes)]
|
|
705
858
|
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
859
|
+
# Apply switch for each biome
|
|
860
|
+
all_new_objects_list = []
|
|
861
|
+
all_new_colors_list = []
|
|
862
|
+
all_new_params_list = []
|
|
863
|
+
for i in range(num_biomes):
|
|
864
|
+
obj, col, par = jax.lax.switch(i, spawn_fns, biome_keys[i])
|
|
865
|
+
all_new_objects_list.append(obj)
|
|
866
|
+
all_new_colors_list.append(col)
|
|
867
|
+
all_new_params_list.append(par)
|
|
868
|
+
|
|
869
|
+
all_new_objects = jnp.stack(all_new_objects_list)
|
|
870
|
+
all_new_colors = jnp.stack(all_new_colors_list)
|
|
871
|
+
all_new_params = jnp.stack(all_new_params_list)
|
|
872
|
+
else:
|
|
873
|
+
# Random spawn works with vmap
|
|
874
|
+
all_new_objects, all_new_colors, all_new_params = jax.vmap(
|
|
875
|
+
lambda i, k: self._spawn_biome_objects(i, k, deterministic=False)
|
|
876
|
+
)(jnp.arange(num_biomes), biome_keys)
|
|
877
|
+
|
|
878
|
+
# Initialize updated grids
|
|
879
|
+
new_obj_id = object_state.object_id
|
|
880
|
+
new_color = object_state.color
|
|
881
|
+
new_params = object_state.state_params
|
|
882
|
+
new_spawn = object_state.spawn_time
|
|
883
|
+
new_gen = object_state.generation
|
|
884
|
+
|
|
885
|
+
# Update biome state
|
|
886
|
+
new_consumption_count = jnp.where(
|
|
887
|
+
should_respawn,
|
|
888
|
+
jnp.array(0, dtype=ID_DTYPE),
|
|
889
|
+
biome_state.consumption_count,
|
|
890
|
+
)
|
|
891
|
+
new_generation = biome_state.generation + should_respawn.astype(ID_DTYPE)
|
|
737
892
|
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
893
|
+
# Compute new total objects for respawning biomes
|
|
894
|
+
def count_objects(i):
|
|
895
|
+
return jnp.sum(
|
|
896
|
+
(all_new_objects[i] > 0) & self.biome_masks_array[i], dtype=ID_DTYPE
|
|
897
|
+
)
|
|
741
898
|
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
899
|
+
new_object_counts = jax.vmap(count_objects)(jnp.arange(num_biomes))
|
|
900
|
+
new_total_objects = jnp.where(
|
|
901
|
+
should_respawn, new_object_counts, biome_state.total_objects
|
|
902
|
+
)
|
|
746
903
|
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
904
|
+
new_biome_state = BiomeState(
|
|
905
|
+
consumption_count=new_consumption_count,
|
|
906
|
+
total_objects=new_total_objects,
|
|
907
|
+
generation=new_generation,
|
|
908
|
+
)
|
|
752
909
|
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
910
|
+
# Update grids for respawning biomes
|
|
911
|
+
for i in range(num_biomes):
|
|
912
|
+
biome_mask = self.biome_masks_array[i]
|
|
913
|
+
new_gen_value = new_biome_state.generation[i]
|
|
757
914
|
|
|
758
|
-
|
|
759
|
-
|
|
915
|
+
# Update mask: biome area AND needs respawn
|
|
916
|
+
should_update = biome_mask & should_respawn[i][..., None]
|
|
760
917
|
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
new_spawn_valid = all_new_objects[i] > 0
|
|
918
|
+
# 1. Merge: Overwrite with new objects if present, otherwise keep existing
|
|
919
|
+
new_spawn_valid = all_new_objects[i] > 0
|
|
764
920
|
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
921
|
+
merged_objs = jnp.where(new_spawn_valid, all_new_objects[i], new_obj_id)
|
|
922
|
+
merged_colors = jnp.where(
|
|
923
|
+
new_spawn_valid[..., None], all_new_colors[i], new_color
|
|
924
|
+
)
|
|
925
|
+
merged_params = jnp.where(
|
|
926
|
+
new_spawn_valid[..., None], all_new_params[i], new_params
|
|
927
|
+
)
|
|
772
928
|
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
merged_spawn = jnp.where(new_spawn_valid, current_time, new_spawn)
|
|
929
|
+
merged_gen = jnp.where(new_spawn_valid, new_gen_value, new_gen)
|
|
930
|
+
merged_spawn = jnp.where(new_spawn_valid, current_time, new_spawn)
|
|
776
931
|
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
932
|
+
if self.dynamic_biome_spawn_empty > 0:
|
|
933
|
+
key, dropout_key = jax.random.split(key)
|
|
934
|
+
keep_mask = jax.random.bernoulli(
|
|
935
|
+
dropout_key,
|
|
936
|
+
1.0 - self.dynamic_biome_spawn_empty,
|
|
937
|
+
merged_objs.shape,
|
|
938
|
+
)
|
|
939
|
+
dropout_mask = should_update & keep_mask
|
|
940
|
+
final_objs = jnp.where(
|
|
941
|
+
dropout_mask, merged_objs, jnp.array(0, dtype=ID_DTYPE)
|
|
942
|
+
)
|
|
943
|
+
final_colors = jnp.where(
|
|
944
|
+
dropout_mask[..., None],
|
|
945
|
+
merged_colors,
|
|
946
|
+
jnp.array(0, dtype=COLOR_DTYPE),
|
|
947
|
+
)
|
|
948
|
+
final_params = jnp.where(
|
|
949
|
+
dropout_mask[..., None],
|
|
950
|
+
merged_params,
|
|
951
|
+
jnp.array(0, dtype=PARAM_DTYPE),
|
|
952
|
+
)
|
|
953
|
+
final_gen = jnp.where(
|
|
954
|
+
dropout_mask, merged_gen, jnp.array(0, dtype=ID_DTYPE)
|
|
955
|
+
)
|
|
956
|
+
final_spawn = jnp.where(
|
|
957
|
+
dropout_mask, merged_spawn, jnp.array(0, dtype=TIME_DTYPE)
|
|
958
|
+
)
|
|
959
|
+
else:
|
|
960
|
+
final_objs = merged_objs
|
|
961
|
+
final_colors = merged_colors
|
|
962
|
+
final_params = merged_params
|
|
963
|
+
final_gen = merged_gen
|
|
964
|
+
final_spawn = merged_spawn
|
|
965
|
+
|
|
966
|
+
new_obj_id = jnp.where(should_update, final_objs, new_obj_id)
|
|
967
|
+
new_color = jnp.where(should_update[..., None], final_colors, new_color)
|
|
968
|
+
new_params = jnp.where(
|
|
969
|
+
should_update[..., None], final_params, new_params
|
|
783
970
|
)
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
# Apply dropout only to the merged result and associated metadata
|
|
787
|
-
final_objs = jnp.where(dropout_mask, merged_objs, 0)
|
|
788
|
-
final_colors = jnp.where(dropout_mask[..., None], merged_colors, 0)
|
|
789
|
-
final_params = jnp.where(dropout_mask[..., None], merged_params, 0)
|
|
790
|
-
final_gen = jnp.where(dropout_mask, merged_gen, 0)
|
|
791
|
-
final_spawn = jnp.where(dropout_mask, merged_spawn, 0)
|
|
792
|
-
else:
|
|
793
|
-
final_objs = merged_objs
|
|
794
|
-
final_colors = merged_colors
|
|
795
|
-
final_params = merged_params
|
|
796
|
-
final_gen = merged_gen
|
|
797
|
-
final_spawn = merged_spawn
|
|
798
|
-
|
|
799
|
-
# 3. Write back: Only update where should_update is true
|
|
800
|
-
new_obj_id = jnp.where(should_update, final_objs, new_obj_id)
|
|
801
|
-
new_color = jnp.where(should_update[..., None], final_colors, new_color)
|
|
802
|
-
new_params = jnp.where(should_update[..., None], final_params, new_params)
|
|
803
|
-
new_gen = jnp.where(should_update, final_gen, new_gen)
|
|
804
|
-
new_spawn = jnp.where(should_update, final_spawn, new_spawn)
|
|
805
|
-
|
|
806
|
-
# Clear timers in respawning biomes
|
|
807
|
-
new_respawn_timer = object_state.respawn_timer
|
|
808
|
-
new_respawn_object_id = object_state.respawn_object_id
|
|
809
|
-
for i in range(num_biomes):
|
|
810
|
-
biome_mask = self.biome_masks_array[i]
|
|
811
|
-
should_clear = biome_mask & should_respawn[i][..., None]
|
|
812
|
-
new_respawn_timer = jnp.where(should_clear, 0, new_respawn_timer)
|
|
813
|
-
new_respawn_object_id = jnp.where(should_clear, 0, new_respawn_object_id)
|
|
971
|
+
new_gen = jnp.where(should_update, final_gen, new_gen)
|
|
972
|
+
new_spawn = jnp.where(should_update, final_spawn, new_spawn)
|
|
814
973
|
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
974
|
+
# Clear timers in respawning biomes
|
|
975
|
+
new_respawn_timer = object_state.respawn_timer
|
|
976
|
+
new_respawn_object_id = object_state.respawn_object_id
|
|
977
|
+
for i in range(num_biomes):
|
|
978
|
+
biome_mask = self.biome_masks_array[i]
|
|
979
|
+
should_clear = biome_mask & should_respawn[i][..., None]
|
|
980
|
+
new_respawn_timer = jnp.where(
|
|
981
|
+
should_clear, jnp.array(0, dtype=TIMER_DTYPE), new_respawn_timer
|
|
982
|
+
)
|
|
983
|
+
new_respawn_object_id = jnp.where(
|
|
984
|
+
should_clear, jnp.array(0, dtype=ID_DTYPE), new_respawn_object_id
|
|
985
|
+
)
|
|
986
|
+
|
|
987
|
+
new_object_state = object_state.replace(
|
|
988
|
+
object_id=new_obj_id,
|
|
989
|
+
respawn_timer=new_respawn_timer,
|
|
990
|
+
respawn_object_id=new_respawn_object_id,
|
|
991
|
+
color=new_color,
|
|
992
|
+
state_params=new_params,
|
|
993
|
+
generation=new_gen,
|
|
994
|
+
spawn_time=new_spawn,
|
|
995
|
+
)
|
|
996
|
+
return new_object_state, new_biome_state, key
|
|
997
|
+
|
|
998
|
+
def no_respawn(args):
|
|
999
|
+
object_state, biome_state, _, key = args
|
|
1000
|
+
return object_state, biome_state, key
|
|
1001
|
+
|
|
1002
|
+
object_state, biome_state, key = jax.lax.cond(
|
|
1003
|
+
any_respawn,
|
|
1004
|
+
do_respawn,
|
|
1005
|
+
no_respawn,
|
|
1006
|
+
(object_state, biome_state, should_respawn, key),
|
|
823
1007
|
)
|
|
824
1008
|
|
|
825
|
-
return object_state,
|
|
1009
|
+
return object_state, biome_state, key
|
|
826
1010
|
|
|
827
1011
|
def reset_env(
|
|
828
1012
|
self, key: jax.Array, params: EnvParams
|
|
829
1013
|
) -> Tuple[jax.Array, EnvState]:
|
|
830
1014
|
"""Reset environment state."""
|
|
831
|
-
num_object_params =
|
|
1015
|
+
num_object_params = 3 + 2 * self.num_fourier_terms
|
|
832
1016
|
object_state = ObjectState.create_empty(self.size, num_object_params)
|
|
833
1017
|
|
|
834
1018
|
key, iter_key = jax.random.split(key)
|
|
@@ -840,7 +1024,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
840
1024
|
|
|
841
1025
|
# Set biome_id
|
|
842
1026
|
object_state = object_state.replace(
|
|
843
|
-
biome_id=jnp.where(
|
|
1027
|
+
biome_id=jnp.where(
|
|
1028
|
+
mask, jnp.array(i, dtype=BIOME_ID_DTYPE), object_state.biome_id
|
|
1029
|
+
)
|
|
844
1030
|
)
|
|
845
1031
|
|
|
846
1032
|
# Use unified spawn method
|
|
@@ -862,31 +1048,45 @@ class ForagaxEnv(environment.Environment):
|
|
|
862
1048
|
|
|
863
1049
|
# Initialize biome consumption tracking
|
|
864
1050
|
num_biomes = self.biome_object_frequencies.shape[0]
|
|
865
|
-
biome_consumption_count = jnp.zeros(num_biomes, dtype=
|
|
866
|
-
biome_total_objects = jnp.zeros(num_biomes, dtype=
|
|
1051
|
+
biome_consumption_count = jnp.zeros(num_biomes, dtype=ID_DTYPE)
|
|
1052
|
+
biome_total_objects = jnp.zeros(num_biomes, dtype=ID_DTYPE)
|
|
867
1053
|
|
|
868
1054
|
# Count objects in each biome
|
|
869
1055
|
for i in range(num_biomes):
|
|
870
1056
|
mask = self.biome_masks[i]
|
|
871
1057
|
# Count non-empty objects (object_id > 0)
|
|
872
|
-
total = jnp.sum((object_state.object_id > 0) & mask)
|
|
1058
|
+
total = jnp.sum((object_state.object_id > 0) & mask, dtype=ID_DTYPE)
|
|
873
1059
|
biome_total_objects = biome_total_objects.at[i].set(total)
|
|
874
1060
|
|
|
875
|
-
biome_generation = jnp.zeros(num_biomes, dtype=
|
|
1061
|
+
biome_generation = jnp.zeros(num_biomes, dtype=ID_DTYPE)
|
|
1062
|
+
|
|
1063
|
+
# Final state cleanup to ensure type consistency
|
|
1064
|
+
object_state = object_state.replace(
|
|
1065
|
+
object_id=object_state.object_id.astype(ID_DTYPE),
|
|
1066
|
+
respawn_timer=object_state.respawn_timer.astype(TIMER_DTYPE),
|
|
1067
|
+
respawn_object_id=object_state.respawn_object_id.astype(ID_DTYPE),
|
|
1068
|
+
spawn_time=object_state.spawn_time.astype(TIME_DTYPE),
|
|
1069
|
+
color=object_state.color.astype(COLOR_DTYPE),
|
|
1070
|
+
generation=object_state.generation.astype(ID_DTYPE),
|
|
1071
|
+
state_params=object_state.state_params.astype(PARAM_DTYPE),
|
|
1072
|
+
biome_id=object_state.biome_id.astype(BIOME_ID_DTYPE),
|
|
1073
|
+
)
|
|
876
1074
|
|
|
877
1075
|
state = EnvState(
|
|
878
|
-
pos=agent_pos,
|
|
879
|
-
time=0,
|
|
880
|
-
digestion_buffer=jnp.zeros((self.max_reward_delay,)),
|
|
1076
|
+
pos=agent_pos.astype(jnp.int32),
|
|
1077
|
+
time=jnp.array(0, dtype=jnp.int32),
|
|
1078
|
+
digestion_buffer=jnp.zeros((self.max_reward_delay,), dtype=REWARD_DTYPE),
|
|
881
1079
|
object_state=object_state,
|
|
882
1080
|
biome_state=BiomeState(
|
|
883
|
-
consumption_count=biome_consumption_count,
|
|
884
|
-
total_objects=biome_total_objects,
|
|
885
|
-
generation=biome_generation,
|
|
1081
|
+
consumption_count=biome_consumption_count.astype(ID_DTYPE),
|
|
1082
|
+
total_objects=biome_total_objects.astype(ID_DTYPE),
|
|
1083
|
+
generation=biome_generation.astype(ID_DTYPE),
|
|
886
1084
|
),
|
|
887
1085
|
)
|
|
888
1086
|
|
|
889
|
-
return
|
|
1087
|
+
return jax.lax.stop_gradient(
|
|
1088
|
+
self.get_obs(state, params)
|
|
1089
|
+
), jax.lax.stop_gradient(state)
|
|
890
1090
|
|
|
891
1091
|
def _spawn_biome_objects(
|
|
892
1092
|
self,
|
|
@@ -918,16 +1118,17 @@ class ForagaxEnv(environment.Environment):
|
|
|
918
1118
|
biome_size = int(self.biome_sizes[biome_idx])
|
|
919
1119
|
|
|
920
1120
|
grid = jnp.linspace(0, 1, biome_size, endpoint=False)
|
|
921
|
-
biome_objects_flat =
|
|
922
|
-
|
|
923
|
-
|
|
1121
|
+
biome_objects_flat = (
|
|
1122
|
+
len(biome_freqs)
|
|
1123
|
+
- jnp.searchsorted(jnp.cumsum(biome_freqs[::-1]), grid, side="right")
|
|
1124
|
+
).astype(ID_DTYPE)
|
|
924
1125
|
biome_objects_flat = jax.random.permutation(spawn_key, biome_objects_flat)
|
|
925
1126
|
|
|
926
1127
|
# Reshape to match biome dimensions (use concrete dimensions)
|
|
927
1128
|
biome_objects = biome_objects_flat.reshape(biome_height, biome_width)
|
|
928
1129
|
|
|
929
1130
|
# Place in full grid using slicing with static bounds
|
|
930
|
-
object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=
|
|
1131
|
+
object_grid = jnp.zeros((self.size[1], self.size[0]), dtype=ID_DTYPE)
|
|
931
1132
|
object_grid = object_grid.at[
|
|
932
1133
|
biome_start[1] : biome_stop[1], biome_start[0] : biome_stop[0]
|
|
933
1134
|
].set(biome_objects)
|
|
@@ -941,10 +1142,10 @@ class ForagaxEnv(environment.Environment):
|
|
|
941
1142
|
)
|
|
942
1143
|
object_grid = (
|
|
943
1144
|
jnp.searchsorted(cumulative_freqs, grid_rand, side="right") - 1
|
|
944
|
-
)
|
|
1145
|
+
).astype(ID_DTYPE)
|
|
945
1146
|
|
|
946
1147
|
# Initialize color grid
|
|
947
|
-
color_grid = jnp.full((self.size[1], self.size[0], 3), 255, dtype=
|
|
1148
|
+
color_grid = jnp.full((self.size[1], self.size[0], 3), 255, dtype=COLOR_DTYPE)
|
|
948
1149
|
|
|
949
1150
|
# Sample ONE color per object type in this biome (not per instance)
|
|
950
1151
|
# This gives objects of the same type the same color within a biome generation
|
|
@@ -956,9 +1157,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
956
1157
|
biome_object_colors = jax.random.randint(
|
|
957
1158
|
color_key,
|
|
958
1159
|
(num_actual_objects, 3),
|
|
959
|
-
minval=0,
|
|
960
|
-
maxval=256,
|
|
961
|
-
dtype=
|
|
1160
|
+
minval=jnp.array(0, dtype=COLOR_DTYPE),
|
|
1161
|
+
maxval=jnp.array(256, dtype=jnp.int32), # randint range is often int32
|
|
1162
|
+
dtype=COLOR_DTYPE,
|
|
962
1163
|
)
|
|
963
1164
|
|
|
964
1165
|
# Assign colors based on object type (starting from index 1)
|
|
@@ -970,9 +1171,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
970
1171
|
color_grid = jnp.where(obj_mask[..., None], obj_color, color_grid)
|
|
971
1172
|
|
|
972
1173
|
# Initialize parameters grid
|
|
973
|
-
num_object_params =
|
|
1174
|
+
num_object_params = 3 + 2 * self.num_fourier_terms
|
|
974
1175
|
params_grid = jnp.zeros(
|
|
975
|
-
(self.size[1], self.size[0], num_object_params), dtype=
|
|
1176
|
+
(self.size[1], self.size[0], num_object_params), dtype=PARAM_DTYPE
|
|
976
1177
|
)
|
|
977
1178
|
|
|
978
1179
|
# Generate per-object parameters for each object type
|
|
@@ -999,7 +1200,9 @@ class ForagaxEnv(environment.Environment):
|
|
|
999
1200
|
|
|
1000
1201
|
# Assign to all objects of this type in this biome
|
|
1001
1202
|
obj_mask = (object_grid == obj_idx) & biome_mask
|
|
1002
|
-
params_grid = jnp.where(
|
|
1203
|
+
params_grid = jnp.where(
|
|
1204
|
+
obj_mask[..., None], obj_params.astype(PARAM_DTYPE), params_grid
|
|
1205
|
+
)
|
|
1003
1206
|
|
|
1004
1207
|
return object_grid, color_grid, params_grid
|
|
1005
1208
|
|
|
@@ -1024,7 +1227,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
1024
1227
|
|
|
1025
1228
|
def state_space(self, params: EnvParams) -> spaces.Dict:
|
|
1026
1229
|
"""State space of the environment."""
|
|
1027
|
-
num_object_params =
|
|
1230
|
+
num_object_params = 3 + 2 * self.num_fourier_terms
|
|
1028
1231
|
return spaces.Dict(
|
|
1029
1232
|
{
|
|
1030
1233
|
"pos": spaces.Box(0, max(self.size), (2,), int),
|
|
@@ -1147,7 +1350,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
1147
1350
|
aperture = jnp.where(out_of_bounds[..., None], padding_value, values)
|
|
1148
1351
|
else:
|
|
1149
1352
|
# Object ID grid: use PADDING index
|
|
1150
|
-
padding_index = self.object_ids[-1]
|
|
1353
|
+
padding_index = self.object_ids[-1].astype(values.dtype)
|
|
1151
1354
|
aperture = jnp.where(out_of_bounds, padding_index, values)
|
|
1152
1355
|
else:
|
|
1153
1356
|
aperture = values
|
|
@@ -1196,9 +1399,11 @@ class ForagaxEnv(environment.Environment):
|
|
|
1196
1399
|
return jnp.zeros(obs_grid.shape + (0,), dtype=jnp.float32)
|
|
1197
1400
|
# Map object IDs to color channel indices
|
|
1198
1401
|
color_channels = jnp.where(
|
|
1199
|
-
obs_grid == 0,
|
|
1200
|
-
-1,
|
|
1201
|
-
jnp.take(
|
|
1402
|
+
obs_grid == jnp.array(0, dtype=obs_grid.dtype),
|
|
1403
|
+
jnp.array(-1, dtype=jnp.int32),
|
|
1404
|
+
jnp.take(
|
|
1405
|
+
self.object_to_color_map, obs_grid.astype(jnp.int32) - 1, axis=0
|
|
1406
|
+
),
|
|
1202
1407
|
)
|
|
1203
1408
|
obs = jax.nn.one_hot(color_channels, self.num_color_channels, axis=-1)
|
|
1204
1409
|
return obs
|
|
@@ -1218,9 +1423,8 @@ class ForagaxEnv(environment.Environment):
|
|
|
1218
1423
|
aperture_colors = color_aperture / 255.0
|
|
1219
1424
|
|
|
1220
1425
|
# Mask empty cells (object_id == 0) to white
|
|
1221
|
-
empty_mask = aperture == 0
|
|
1426
|
+
empty_mask = aperture == jnp.array(0, dtype=aperture.dtype)
|
|
1222
1427
|
white_color = jnp.ones(aperture_colors.shape, dtype=jnp.float32)
|
|
1223
|
-
|
|
1224
1428
|
obs = jnp.where(empty_mask[..., None], white_color, aperture_colors)
|
|
1225
1429
|
|
|
1226
1430
|
return obs
|
|
@@ -1260,12 +1464,15 @@ class ForagaxEnv(environment.Environment):
|
|
|
1260
1464
|
|
|
1261
1465
|
return spaces.Box(0, 1, obs_shape, float)
|
|
1262
1466
|
|
|
1263
|
-
def _compute_reward_grid(
|
|
1264
|
-
|
|
1467
|
+
def _compute_reward_grid(
|
|
1468
|
+
self, state: EnvState, object_id=None, state_params=None
|
|
1469
|
+
) -> jax.Array:
|
|
1470
|
+
"""Compute rewards for given positions. If no grid provided, uses full world."""
|
|
1471
|
+
if object_id is None:
|
|
1472
|
+
object_id = state.object_state.object_id
|
|
1473
|
+
if state_params is None:
|
|
1474
|
+
state_params = state.object_state.state_params
|
|
1265
1475
|
|
|
1266
|
-
Returns:
|
|
1267
|
-
Array of shape (H, W) with reward values for each cell
|
|
1268
|
-
"""
|
|
1269
1476
|
fixed_key = jax.random.key(0) # Fixed key for deterministic reward computation
|
|
1270
1477
|
|
|
1271
1478
|
def compute_reward(obj_id, params):
|
|
@@ -1277,9 +1484,7 @@ class ForagaxEnv(environment.Environment):
|
|
|
1277
1484
|
lambda: 0.0,
|
|
1278
1485
|
)
|
|
1279
1486
|
|
|
1280
|
-
reward_grid = jax.vmap(jax.vmap(compute_reward))(
|
|
1281
|
-
state.object_state.object_id, state.object_state.state_params
|
|
1282
|
-
)
|
|
1487
|
+
reward_grid = jax.vmap(jax.vmap(compute_reward))(object_id, state_params)
|
|
1283
1488
|
return reward_grid
|
|
1284
1489
|
|
|
1285
1490
|
def _reward_to_color(self, reward: jax.Array) -> jax.Array:
|
|
@@ -1369,172 +1574,87 @@ class ForagaxEnv(environment.Environment):
|
|
|
1369
1574
|
|
|
1370
1575
|
img = jax.lax.fori_loop(0, len(self.object_ids), update_image, img)
|
|
1371
1576
|
|
|
1372
|
-
|
|
1373
|
-
|
|
1374
|
-
img = jax.image.resize(
|
|
1375
|
-
img,
|
|
1376
|
-
(self.size[1] * 3, self.size[0] * 3, 3),
|
|
1377
|
-
jax.image.ResizeMethod.NEAREST,
|
|
1378
|
-
)
|
|
1379
|
-
|
|
1380
|
-
# Compute rewards for all cells
|
|
1381
|
-
reward_grid = self._compute_reward_grid(state)
|
|
1382
|
-
|
|
1383
|
-
# Convert rewards to colors
|
|
1384
|
-
reward_colors = self._reward_to_color(reward_grid)
|
|
1385
|
-
|
|
1386
|
-
# Resize reward colors to match 3x scale and place in middle cells
|
|
1387
|
-
# We need to place reward colors at positions (i*3+1, j*3+1) for each (i,j)
|
|
1388
|
-
# Create index arrays for middle cells
|
|
1389
|
-
i_indices = jnp.arange(self.size[1])[:, None] * 3 + 1
|
|
1390
|
-
j_indices = jnp.arange(self.size[0])[None, :] * 3 + 1
|
|
1391
|
-
|
|
1392
|
-
# Broadcast and set middle cells
|
|
1393
|
-
img = img.at[i_indices, j_indices].set(reward_colors)
|
|
1394
|
-
|
|
1395
|
-
# Tint the agent's aperture
|
|
1577
|
+
# Define constants for all world modes
|
|
1578
|
+
alpha = 0.2
|
|
1396
1579
|
y_coords, x_coords, y_coords_adj, x_coords_adj = (
|
|
1397
1580
|
self._compute_aperture_coordinates(state.pos)
|
|
1398
1581
|
)
|
|
1399
1582
|
|
|
1400
|
-
alpha = 0.2
|
|
1401
|
-
agent_color = jnp.array(AGENT.color)
|
|
1402
|
-
|
|
1403
1583
|
if is_reward_mode:
|
|
1404
|
-
#
|
|
1405
|
-
|
|
1406
|
-
|
|
1407
|
-
|
|
1408
|
-
(self.size[1] * 3, self.size[0] * 3), dtype=bool
|
|
1409
|
-
)
|
|
1410
|
-
|
|
1411
|
-
# For each aperture cell, tint all 9 cells in its 3x3 block
|
|
1412
|
-
# Create meshgrid to get all aperture cell coordinates
|
|
1413
|
-
y_grid, x_grid = jnp.meshgrid(
|
|
1414
|
-
y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
|
|
1415
|
-
)
|
|
1416
|
-
y_flat = y_grid.flatten()
|
|
1417
|
-
x_flat = x_grid.flatten()
|
|
1418
|
-
|
|
1419
|
-
# Create offset arrays for 3x3 blocks
|
|
1420
|
-
offsets = jnp.array(
|
|
1421
|
-
[[dy, dx] for dy in range(3) for dx in range(3)]
|
|
1422
|
-
)
|
|
1423
|
-
|
|
1424
|
-
# For each aperture cell, expand to 9 cells
|
|
1425
|
-
# We need to repeat each cell coordinate 9 times, then add offsets
|
|
1426
|
-
num_aperture_cells = y_flat.size
|
|
1427
|
-
y_base = jnp.repeat(
|
|
1428
|
-
y_flat * 3, 9
|
|
1429
|
-
) # Repeat each y coord 9 times and scale by 3
|
|
1430
|
-
x_base = jnp.repeat(
|
|
1431
|
-
x_flat * 3, 9
|
|
1432
|
-
) # Repeat each x coord 9 times and scale by 3
|
|
1433
|
-
y_offsets = jnp.tile(
|
|
1434
|
-
offsets[:, 0], num_aperture_cells
|
|
1435
|
-
) # Tile all 9 offsets
|
|
1436
|
-
x_offsets = jnp.tile(
|
|
1437
|
-
offsets[:, 1], num_aperture_cells
|
|
1438
|
-
) # Tile all 9 offsets
|
|
1439
|
-
y_expanded = y_base + y_offsets
|
|
1440
|
-
x_expanded = x_base + x_offsets
|
|
1441
|
-
|
|
1442
|
-
tint_mask = tint_mask.at[y_expanded, x_expanded].set(True)
|
|
1443
|
-
|
|
1444
|
-
original_colors = img
|
|
1445
|
-
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
1446
|
-
img = jnp.where(tint_mask[..., None], tinted_colors, img)
|
|
1447
|
-
else:
|
|
1448
|
-
# Tint all 9 cells in each 3x3 block for aperture cells
|
|
1449
|
-
# Create meshgrid to get all aperture cell coordinates
|
|
1450
|
-
y_grid, x_grid = jnp.meshgrid(
|
|
1451
|
-
y_coords_adj.flatten(), x_coords_adj.flatten(), indexing="ij"
|
|
1452
|
-
)
|
|
1453
|
-
y_flat = y_grid.flatten()
|
|
1454
|
-
x_flat = x_grid.flatten()
|
|
1584
|
+
# Construct 3x intermediate image
|
|
1585
|
+
# Each cell is 3x3, with reward color in center
|
|
1586
|
+
reward_grid = self._compute_reward_grid(state)
|
|
1587
|
+
reward_colors = self._reward_to_color(reward_grid)
|
|
1455
1588
|
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
|
|
1459
|
-
|
|
1589
|
+
# Each cell has its base color in 8 pixels and reward color in 1 (center)
|
|
1590
|
+
# Create a 3x3 pattern mask for center pixels
|
|
1591
|
+
cell_mask = jnp.array(
|
|
1592
|
+
[[False, False, False], [False, True, False], [False, False, False]]
|
|
1593
|
+
)
|
|
1594
|
+
grid_reward_mask = jnp.tile(cell_mask, (self.size[1], self.size[0]))
|
|
1460
1595
|
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
|
|
1464
|
-
|
|
1465
|
-
y_flat * 3, 9
|
|
1466
|
-
) # Repeat each y coord 9 times and scale by 3
|
|
1467
|
-
x_base = jnp.repeat(
|
|
1468
|
-
x_flat * 3, 9
|
|
1469
|
-
) # Repeat each x coord 9 times and scale by 3
|
|
1470
|
-
y_offsets = jnp.tile(
|
|
1471
|
-
offsets[:, 0], num_aperture_cells
|
|
1472
|
-
) # Tile all 9 offsets
|
|
1473
|
-
x_offsets = jnp.tile(
|
|
1474
|
-
offsets[:, 1], num_aperture_cells
|
|
1475
|
-
) # Tile all 9 offsets
|
|
1476
|
-
y_expanded = y_base + y_offsets
|
|
1477
|
-
x_expanded = x_base + x_offsets
|
|
1478
|
-
|
|
1479
|
-
# Get original colors and tint them
|
|
1480
|
-
original_colors = img[y_expanded, x_expanded]
|
|
1481
|
-
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
1482
|
-
img = img.at[y_expanded, x_expanded].set(tinted_colors)
|
|
1483
|
-
|
|
1484
|
-
# Agent color - set all 9 cells of the agent's 3x3 block
|
|
1485
|
-
agent_y, agent_x = state.pos[1], state.pos[0]
|
|
1486
|
-
agent_offsets = jnp.array(
|
|
1487
|
-
[[dy, dx] for dy in range(3) for dx in range(3)]
|
|
1596
|
+
# Repeat base colors and rewards to 3x3
|
|
1597
|
+
base_img_x3 = jnp.repeat(jnp.repeat(img, 3, axis=0), 3, axis=1)
|
|
1598
|
+
reward_colors_x3 = jnp.repeat(
|
|
1599
|
+
jnp.repeat(reward_colors, 3, axis=0), 3, axis=1
|
|
1488
1600
|
)
|
|
1489
|
-
|
|
1490
|
-
|
|
1491
|
-
img =
|
|
1492
|
-
|
|
1601
|
+
|
|
1602
|
+
# Composite base and reward colors
|
|
1603
|
+
img = jnp.where(
|
|
1604
|
+
grid_reward_mask[..., None], reward_colors_x3, base_img_x3
|
|
1493
1605
|
)
|
|
1494
1606
|
|
|
1495
|
-
#
|
|
1496
|
-
|
|
1497
|
-
|
|
1498
|
-
|
|
1499
|
-
|
|
1607
|
+
# Tint the aperture region at 3x scale
|
|
1608
|
+
aperture_mask = jnp.zeros((self.size[1], self.size[0]), dtype=bool)
|
|
1609
|
+
aperture_mask = aperture_mask.at[y_coords_adj, x_coords_adj].set(True)
|
|
1610
|
+
aperture_mask_x3 = jnp.repeat(
|
|
1611
|
+
jnp.repeat(aperture_mask, 3, axis=0), 3, axis=1
|
|
1500
1612
|
)
|
|
1613
|
+
|
|
1614
|
+
tinted_img = (
|
|
1615
|
+
(1.0 - alpha) * img.astype(jnp.float32)
|
|
1616
|
+
+ alpha * self.agent_color_jax.astype(jnp.float32)
|
|
1617
|
+
).astype(jnp.uint8)
|
|
1618
|
+
img = jnp.where(aperture_mask_x3[..., None], tinted_img, img)
|
|
1619
|
+
|
|
1620
|
+
# Set agent center block
|
|
1621
|
+
agent_mask = jnp.zeros((self.size[1], self.size[0]), dtype=bool)
|
|
1622
|
+
agent_mask = agent_mask.at[state.pos[1], state.pos[0]].set(True)
|
|
1623
|
+
agent_mask_x3 = jnp.repeat(jnp.repeat(agent_mask, 3, axis=0), 3, axis=1)
|
|
1624
|
+
img = jnp.where(agent_mask_x3[..., None], self.agent_color_jax, img)
|
|
1625
|
+
|
|
1626
|
+
# Final scale by 8 to get 24x
|
|
1627
|
+
img = jnp.repeat(jnp.repeat(img, 8, axis=0), 8, axis=1)
|
|
1501
1628
|
else:
|
|
1502
1629
|
# Standard rendering without reward visualization
|
|
1503
|
-
|
|
1504
|
-
|
|
1505
|
-
tint_mask = jnp.zeros((self.size[1], self.size[0]), dtype=int)
|
|
1506
|
-
tint_mask = tint_mask.at[y_coords_adj, x_coords_adj].set(1)
|
|
1507
|
-
# Apply tint to masked positions
|
|
1508
|
-
original_colors = img
|
|
1509
|
-
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
1510
|
-
img = jnp.where(tint_mask[..., None], tinted_colors, img)
|
|
1511
|
-
else:
|
|
1512
|
-
original_colors = img[y_coords_adj, x_coords_adj]
|
|
1513
|
-
tinted_colors = (1 - alpha) * original_colors + alpha * agent_color
|
|
1514
|
-
img = img.at[y_coords_adj, x_coords_adj].set(tinted_colors)
|
|
1630
|
+
aperture_mask = jnp.zeros((self.size[1], self.size[0]), dtype=bool)
|
|
1631
|
+
aperture_mask = aperture_mask.at[y_coords_adj, x_coords_adj].set(True)
|
|
1515
1632
|
|
|
1516
|
-
|
|
1517
|
-
|
|
1633
|
+
tinted_img = (
|
|
1634
|
+
(1.0 - alpha) * img.astype(jnp.float32)
|
|
1635
|
+
+ alpha * self.agent_color_jax.astype(jnp.float32)
|
|
1636
|
+
).astype(jnp.uint8)
|
|
1637
|
+
img = jnp.where(aperture_mask[..., None], tinted_img, img)
|
|
1518
1638
|
|
|
1519
|
-
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
)
|
|
1639
|
+
# Set agent
|
|
1640
|
+
img = img.at[state.pos[1], state.pos[0]].set(self.agent_color_jax)
|
|
1641
|
+
# Scale by 24
|
|
1642
|
+
img = jnp.repeat(jnp.repeat(img, 24, axis=0), 24, axis=1)
|
|
1524
1643
|
|
|
1525
1644
|
if is_true_mode:
|
|
1526
|
-
# Apply true object borders
|
|
1527
|
-
render_grid = state.object_state.object_id
|
|
1645
|
+
# Apply true object borders
|
|
1528
1646
|
img = apply_true_borders(
|
|
1529
|
-
img,
|
|
1647
|
+
img, state.object_state.object_id, self.size, len(self.object_ids)
|
|
1530
1648
|
)
|
|
1531
1649
|
|
|
1532
|
-
# Add grid lines
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1650
|
+
# Add grid lines using masking instead of slice-setting
|
|
1651
|
+
row_grid = (jnp.arange(self.size[1] * 24) % 24) == 0
|
|
1652
|
+
col_grid = (jnp.arange(self.size[0] * 24) % 24) == 0
|
|
1653
|
+
# skip first rows/cols as they are borders or managed by caller
|
|
1654
|
+
row_grid = row_grid.at[0].set(False)
|
|
1655
|
+
col_grid = col_grid.at[0].set(False)
|
|
1656
|
+
grid_mask = row_grid[:, None] | col_grid[None, :]
|
|
1657
|
+
img = jnp.where(grid_mask[..., None], self.grid_color_jax, img)
|
|
1538
1658
|
|
|
1539
1659
|
elif is_aperture_mode:
|
|
1540
1660
|
obs_grid = state.object_state.object_id
|
|
@@ -1581,11 +1701,16 @@ class ForagaxEnv(environment.Environment):
|
|
|
1581
1701
|
self._compute_aperture_coordinates(state.pos)
|
|
1582
1702
|
)
|
|
1583
1703
|
|
|
1584
|
-
# Get reward grid for
|
|
1585
|
-
|
|
1586
|
-
|
|
1587
|
-
|
|
1588
|
-
|
|
1704
|
+
# Get reward grid only for aperture region
|
|
1705
|
+
aperture_object_ids = state.object_state.object_id[
|
|
1706
|
+
y_coords_adj, x_coords_adj
|
|
1707
|
+
]
|
|
1708
|
+
aperture_params = state.object_state.state_params[
|
|
1709
|
+
y_coords_adj, x_coords_adj
|
|
1710
|
+
]
|
|
1711
|
+
aperture_rewards = self._compute_reward_grid(
|
|
1712
|
+
state, aperture_object_ids, aperture_params
|
|
1713
|
+
)
|
|
1589
1714
|
|
|
1590
1715
|
# Convert rewards to colors
|
|
1591
1716
|
reward_colors = self._reward_to_color(aperture_rewards)
|